"""prompt_templates_demo.py — templates, few-shot, chain-of-thought
Run with venv active:  python prompt_templates_demo.py
"""
from dotenv import load_dotenv
from langchain_openai import ChatOpenAI
from langchain_core.prompts import ChatPromptTemplate, PromptTemplate
from langchain_core.messages import HumanMessage

# Populate os.environ from a local .env file before the chat client is built
# (presumably this is where OPENAI_API_KEY comes from — confirm for your setup).
load_dotenv()

# temperature=0 keeps the answers as repeatable as the API allows, which makes
# the techniques below easier to compare against each other.
llm = ChatOpenAI(temperature=0, model="gpt-4o-mini")

# Every prompting technique below is run over this same list of questions.
questions = [
    "What does the HTML <a> tag do?",
    "What does the <title> tag do?",
]

# ChatPromptTemplate — a system turn sets the rules, a human turn carries the question
chat_prompt = ChatPromptTemplate.from_messages([
    ("system", "Keep each reply to one short sentence."),
    ("human", "{question}"),
])

# Piping the template into the model renders the messages and sends them in one step.
chat_chain = chat_prompt | llm

print("=== ChatPromptTemplate ===")
for question in questions:
    answer = chat_chain.invoke({"question": question})
    # One print emits the Q line, the A line, and the trailing blank line.
    print(f"Q: {question}\nA: {answer.content}\n")

# PromptTemplate — render one plain string, then wrap it in a single HumanMessage
text_prompt = PromptTemplate.from_template(
    "Keep each reply to one short sentence.\n\nQuestion: {question}"
)

print("=== PromptTemplate ===")
for question in questions:
    # The instruction and the question travel together in the same user turn;
    # there is no separate system message here.
    human_turn = HumanMessage(content=text_prompt.format(question=question))
    answer_text = llm.invoke([human_turn]).content
    print(f"Q: {question}\nA: {answer_text}\n")

# Few-shot — worked examples are replayed as earlier human/ai turns so the model
# imitates their format before it sees the real question.
few_shot_examples = [
    ("What does the <p> tag do?", "The <p> tag marks a paragraph of text."),
    ("What does the <h1> tag do?", "The <h1> tag marks the main heading on a page."),
]

few_shot_prompt = ChatPromptTemplate.from_messages(
    [("system", "Answer in one short sentence.")]
    + [
        turn
        for example_q, example_a in few_shot_examples
        for turn in (("human", example_q), ("ai", example_a))
    ]
    + [("human", "{question}")]
)

print("=== Few-shot ===")
for question in questions:
    answer = llm.invoke(few_shot_prompt.format_messages(question=question))
    print(f"Q: {question}\nA: {answer.content}\n")

# Chain-of-thought — short steps, then a final Answer line
#
# The system instruction spells out the exact shape of the last line. The
# earlier wording ("print Answer: on its own line") used a program verb and
# literally asked for a bare "Answer:" label. The intent is a closing line that
# carries the one-sentence answer, consistent with the other sections.
cot_prompt = ChatPromptTemplate.from_messages([
    ("system", (
        "For each HTML question: write Step 1 and Step 2, "
        "then finish with a single final line that starts with 'Answer:' "
        "followed by a one-sentence answer."
    )),
    ("human", "{question}"),
])

print("=== Chain-of-thought ===")
for q in questions:
    messages = cot_prompt.format_messages(question=q)
    reply = llm.invoke(messages)
    print(f"Q: {q}")
    # The reply is multi-line (steps + Answer line), so print it verbatim
    # rather than prefixing it with "A:" like the single-sentence sections.
    print(reply.content)
    print()