# The following code is adapted from this Hugging Face tutorial:
# https://huggingface.co/learn/cookbook/rag_zephyr_langchain
# langchain
from typing import TypedDict
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough
from langchain_huggingface import HuggingFacePipeline
# huggingface
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from transformers import pipeline
# pytorch
import torch
# gradio
import gradio as gr
# stdlib
from asyncio import sleep
# local
from vector_store import get_document_database
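# Shape of one chat message in Gradio's "messages" chatbot format: a role
# ("user" / "assistant"), optional metadata, and the message text.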
class ChatMessage(TypedDict):
role: str
metadata: dict
content: str
# MODEL_NAME = "meta-llama/Llama-3.2-3B"
# MODEL_NAME = "google/gemma-7b"
MODEL_NAME = "google/gemma-2-2b-it"
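# Note: Gemma models are gated on the Hugging Face Hub, so downloading them
# may require accepting the license and authenticating first (for example
# with `huggingface-cli login` or the HF_TOKEN environment variable).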
# bnb_config = BitsAndBytesConfig(
# load_in_4bit=True,
# bnb_4bit_use_double_quant=True,
# bnb_4bit_quant_type="nf4",
# bnb_4bit_compute_dtype=torch.bfloat16
# )
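# To run the model in 4-bit precision instead, uncomment the config above,
# install the `bitsandbytes` package, and pass
# `quantization_config=bnb_config` to `from_pretrained` below.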
model = AutoModelForCausalLM.from_pretrained(
MODEL_NAME,
# quantization_config=bnb_config,
device_map="auto",
torch_dtype=torch.bfloat16
)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
text_generation_pipeline = pipeline(
model=model,
tokenizer=tokenizer,
task="text-generation",
temperature=0.2,
do_sample=True,
repetition_penalty=1.1,
return_full_text=True,
max_new_tokens=400,
)
llm = HuggingFacePipeline(pipeline=text_generation_pipeline)
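# Minimal sanity check (commented out so it does not run on import): the
# HuggingFacePipeline wrapper exposes the standard Runnable interface, so
# the model can be invoked directly with a plain string.
# print(llm.invoke("What is a thesis abstract?"))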
def generate_prompt(message_history: list[ChatMessage], max_history=5):
    # build the prompt as a chat-style template so it can easily be
    # extended with further messages later on:
    # https://www.mirascope.com/post/langchain-prompt-template
prompt_template = ChatPromptTemplate([
("system", """You are 'thesizer', a HAMK thesis assistant.
You will help the user with technicalities on writing a thesis
for hamk. If you can't find the answer from the context given to you,
you will tell the user that you cannot assist with the specific topic.
You speak both Finnish and English by following the user's language.
Continue the conversation with a single response from the AI."""),
("system", "{context}"),
])
    # include the example greetings when the conversation is just starting
if len(message_history) < 4:
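        # Finnish: "Hi! How can I help you with your thesis?"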
prompt_template.append(
("assistant", "Hei! Kuinka voin auttaa opinnäytetyösi kanssa?"),
)
prompt_template.append(
("assistant", "Hello! How can I help you with authoring your thesis?"),
)
    # append the chat history (at most the last max_history messages)
for message in message_history[-max_history:]:
prompt_template.append(
(message["role"], message["content"])
)
    # prime the model with an assistant prefix so it answers as the
    # assistant instead of continuing the conversation as the user
prompt_template.append(
("assistant", "<RESPONSE>:")
)
return prompt_template
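# A quick illustration of the resulting prompt shape (hypothetical history,
# kept as a comment so it does not run on import):
# history: list[ChatMessage] = [
#     {"role": "user", "metadata": {}, "content": "How long should the abstract be?"},
# ]
# print(generate_prompt(history).format_messages(context="..."))
# -> system instructions, example greetings, the history, then "<RESPONSE>:"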
async def generate_answer(message_history: list[ChatMessage]):
    # build the document vector store from the learning material files
db = await get_document_database("learning_material/*/*/*")
# initialize the similarity search
n_of_best_results = 4
retriever = db.as_retriever(
search_type="similarity", search_kwargs={"k": n_of_best_results})
prompt = generate_prompt(message_history, max_history=5)
    # build the chain that generates a response: the retriever fills in
    # {context}, while RunnablePassthrough forwards the invoke() argument
    # through as "user_input"
retrieval_chain = (
{"context": retriever, "user_input": RunnablePassthrough()}
| prompt
| llm
| StrOutputParser()
)
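    # note: the retriever returns a list of Document objects, and the
    # template simply stringifies that list into the "{context}" message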
    # retrieve context using the latest user message as the query
user_input = message_history[-1]["content"]
response = retrieval_chain.invoke(user_input)
# # debugging
# print("=====raw response=====")
# print(response)
    # extract the model's new reply: drop everything up to the latest user
    # input, then take the text after the "<RESPONSE>:" marker
parsed_answer = response.split(
str(user_input)
).pop().split("<RESPONSE>:", 1).pop().strip()
    # print(repr(parsed_answer))  # debug: inspect the parsed answer
    # replace paragraph breaks with <br> tags, since gr.Chatbot does not
    # render plain newlines well
    return parsed_answer.replace("\n\n", "<br>")
def update_chat(user_message: str, history: list):
return "", history + [{"role": "user", "content": user_message}]
async def handle_conversation(
history: list[ChatMessage],
characters_per_second=80
):
bot_message = await generate_answer(history)
new_message: ChatMessage = {"role": "assistant",
"metadata": {"title": None},
"content": ""}
history.append(new_message)
for character in bot_message:
history[-1]['content'] += character
yield history
await sleep(1 / characters_per_second)
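# Gradio treats generator callbacks as streaming updates: every yield above
# re-renders the chat history, which produces the typewriter effect.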
def create_interface():
with gr.Blocks() as interface:
gr.Markdown("# 📄 Thesizer: HAMK Thesis Assistant")
gr.Markdown("Ask for help with authoring the HAMK thesis!")
gr.Markdown("## Chat interface")
with gr.Column():
chatbot = gr.Chatbot(type="messages")
with gr.Row():
user_input = gr.Textbox(
label="You:",
placeholder="Type your message here...",
show_label=False
)
send_button = gr.Button("Send")
        # wire up the send button: first append the user message,
        # then stream the assistant's reply
send_button.click(
fn=update_chat,
inputs=[user_input, chatbot],
outputs=[user_input, chatbot],
queue=False
).then(
fn=handle_conversation,
inputs=chatbot,
outputs=chatbot
)
        # pressing Enter submits the message as well
user_input.submit(
fn=update_chat,
inputs=[user_input, chatbot],
outputs=[user_input, chatbot],
queue=False
).then(
fn=handle_conversation,
inputs=chatbot,
outputs=chatbot
)
return interface
if __name__ == "__main__":
create_interface().launch()
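    # launch() accepts further options when deploying, e.g.
    # server_name="0.0.0.0" to listen on all interfaces or share=True for a
    # temporary public link; the defaults are fine for local use.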