# thesizer / app.py
# the following code is adapted from this Hugging Face tutorial:
# https://huggingface.co/learn/cookbook/rag_zephyr_langchain
# langchain
from typing import TypedDict
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough
from langchain_huggingface import HuggingFacePipeline
# huggingface
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers import pipeline
# pytorch
import torch
# gradio
import gradio as gr
# stdlib
from asyncio import sleep
# local
from vector_store import get_document_database


class ChatMessage(TypedDict):
role: str
metadata: dict
content: str
# MODEL_NAME = "meta-llama/Llama-3.2-3B"
# MODEL_NAME = "google/gemma-7b"
MODEL_NAME = "google/gemma-2-2b-it"
# bnb_config = BitsAndBytesConfig(
# load_in_4bit=True,
# bnb_4bit_use_double_quant=True,
# bnb_4bit_quant_type="nf4",
# bnb_4bit_compute_dtype=torch.bfloat16
# )
model = AutoModelForCausalLM.from_pretrained(
MODEL_NAME,
# quantization_config=bnb_config,
# device_map="cpu",
torch_dtype=torch.bfloat16
)
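
# if GPU hardware were available, the model could instead be loaded with
# device_map="auto" (requires the accelerate package) and the commented-out
# BitsAndBytesConfig above passed as quantization_config to cut memory use
# (this would also need `from transformers import BitsAndBytesConfig` and the
# bitsandbytes package). A rough sketch, not what this Space currently does:
#
# model = AutoModelForCausalLM.from_pretrained(
#     MODEL_NAME,
#     quantization_config=bnb_config,
#     device_map="auto",
# )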
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
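
# generation settings below: sampling with a low temperature keeps answers
# focused, repetition_penalty discourages the model from looping, and
# max_new_tokens caps the length of a single reply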
text_generation_pipeline = pipeline(
model=model,
tokenizer=tokenizer,
task="text-generation",
temperature=0.2,
do_sample=True,
repetition_penalty=1.1,
return_full_text=True,
max_new_tokens=400,
)
llm = HuggingFacePipeline(pipeline=text_generation_pipeline)
# generate a vector store
print("creating the document database")
db = get_document_database("learning_material/*/*/*")
print("Document database is ready")
def generate_prompt(message_history: list[ChatMessage], max_history=5):
    # create the prompt template as a chat prompt
    # so that it can easily be extended later
# https://www.mirascope.com/post/langchain-prompt-template
prompt_template = ChatPromptTemplate([
("system", """You are 'thesizer', a HAMK thesis assistant.
You will help the user with the technicalities of writing a thesis
for HAMK. If you can't find the answer in the context given to you,
you will tell the user that you cannot assist with that specific topic.
You speak both Finnish and English, following the user's language.
Continue the conversation with a single response from the AI."""),
("system", "{context}"),
])
    # include example greetings in the prompt when the conversation is just starting
if len(message_history) < 4:
prompt_template.append(
("assistant", "Hei! Kuinka voin auttaa opinnäytetyösi kanssa?"),
)
prompt_template.append(
("assistant", "Hello! How can I help you with authoring your thesis?"),
)
    # add the chat messages (only include the max_history most recent ones)
for message in message_history[-max_history:]:
prompt_template.append(
(message["role"], message["content"])
)
    # this is here so that the stupid thing won't start roleplaying as the user
    # and making up the conversation by itself
prompt_template.append(
("assistant", "<RESPONSE>:")
)
return prompt_template
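
# example of the message shape generate_prompt expects (hypothetical content):
#
# generate_prompt([
#     {"role": "user", "metadata": {}, "content": "How should I cite sources?"},
# ])
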
async def generate_answer(message_history: list[ChatMessage]):
# initialize the similarity search
n_of_best_results = 4
retriever = db.as_retriever(
search_type="similarity", search_kwargs={"k": n_of_best_results})
print("generating prompt")
prompt = generate_prompt(message_history, max_history=5)
print("prompt is ready")
# create the pipeline for generating a response
    # RunnablePassthrough forwards the invoke argument (the user input) unchanged
retrieval_chain = (
{"context": retriever, "user_input": RunnablePassthrough()}
| prompt
| llm
| StrOutputParser()
)
    # fetch the context using the latest message as the search query
user_input = message_history[-1]["content"]
print("invoking")
response = retrieval_chain.invoke(user_input)
print("response recieved from invoke")
# debugging
print("=====raw response=====")
print(response)
    # get the next response from the AI:
    # drop everything up to the last user input, then take what follows
    # the <RESPONSE> marker
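    # e.g. a hypothetical raw output of the form
    #   "...<chat history>...<user question><RESPONSE>: <answer text>"
    # is reduced to just "<answer text>"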
parsed_answer = response.split(
str(user_input)
).pop().split("<RESPONSE>:", 1).pop().strip()
print(repr(parsed_answer))
    # replace double newlines with <br> tags, since gr.Chatbot does not
    # handle newlines well
return parsed_answer.replace("\n\n", "<br>")


def update_chat(user_message: str, history: list):
return "", history + [{"role": "user", "content": user_message}]


async def handle_conversation(
history: list[ChatMessage],
characters_per_second=80
):
bot_message = await generate_answer(history)
new_message: ChatMessage = {"role": "assistant",
"metadata": {"title": None},
"content": ""}
history.append(new_message)
for character in bot_message:
history[-1]['content'] += character
yield history
await sleep(1 / characters_per_second)


def create_interface():
with gr.Blocks() as interface:
gr.Markdown("# 📄 Thesizer: HAMK Thesis Assistant")
gr.Markdown("Ask for help with authoring the HAMK thesis!")
gr.Markdown("## Chat interface")
with gr.Column():
chatbot = gr.Chatbot(type="messages")
with gr.Row():
user_input = gr.Textbox(
label="You:",
placeholder="Type your message here...",
show_label=False
)
send_button = gr.Button("Send")
        # handle the send button: update_chat echoes the user message
        # immediately, then handle_conversation streams the bot reply
send_button.click(
fn=update_chat,
inputs=[user_input, chatbot],
outputs=[user_input, chatbot],
queue=False
).then(
fn=handle_conversation,
inputs=chatbot,
outputs=chatbot
)
        # allow sending by pressing enter instead of clicking the button
user_input.submit(
fn=update_chat,
inputs=[user_input, chatbot],
outputs=[user_input, chatbot],
queue=False
).then(
fn=handle_conversation,
inputs=chatbot,
outputs=chatbot
)
return interface


if __name__ == "__main__":
create_interface().launch()