# the following code is adapted from this Hugging Face tutorial
# https://huggingface.co/learn/cookbook/rag_zephyr_langchain
# langchain
from typing import TypedDict
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough
from langchain_huggingface import HuggingFacePipeline
# huggingface
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from transformers import pipeline
# pytorch
import torch
# gradio
import gradio as gr
# stdlib
from asyncio import sleep
# local
from vector_store import get_document_database


class ChatMessage(TypedDict):
    role: str
    metadata: dict
    content: str


# MODEL_NAME = "meta-llama/Llama-3.2-3B"
# MODEL_NAME = "google/gemma-7b"
MODEL_NAME = "google/gemma-2-2b-it"

# bnb_config = BitsAndBytesConfig(
#     load_in_4bit=True,
#     bnb_4bit_use_double_quant=True,
#     bnb_4bit_quant_type="nf4",
#     bnb_4bit_compute_dtype=torch.bfloat16
# )
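# note: the commented 4-bit NF4 quantization above can be re-enabled by
# passing quantization_config=bnb_config to from_pretrained below; this
# assumes the bitsandbytes package is installed and a CUDA GPU is available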

model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    # quantization_config=bnb_config,
    device_map="auto",
    torch_dtype=torch.bfloat16
)

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

text_generation_pipeline = pipeline(
    model=model,
    tokenizer=tokenizer,
    task="text-generation",
    temperature=0.2,
    do_sample=True,
    repetition_penalty=1.1,
    return_full_text=True,
    max_new_tokens=400,
)

llm = HuggingFacePipeline(pipeline=text_generation_pipeline)
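
# minimal smoke test (commented out): HuggingFacePipeline is a LangChain
# Runnable, so it can be invoked directly with a plain string prompt
# print(llm.invoke("Briefly introduce yourself."))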


def generate_prompt(message_history: list[ChatMessage], max_history=5):
    # build the prompt as a chat-style template so it can easily be
    # extended message by message
    # https://www.mirascope.com/post/langchain-prompt-template
    prompt_template = ChatPromptTemplate([
        ("system", """You are 'thesizer', a HAMK thesis assistant.
        You will help the user with technicalities on writing a thesis
        for hamk. If you can't find the answer from the context given to you,
        you will tell the user that you cannot assist with the specific topic.
        You speak both Finnish and English by following the user's language.
        Continue the conversation with a single response from the AI."""),
        ("system", "{context}"),
    ])

    # seed the prompt with example greetings while the conversation is short
    if len(message_history) < 4:
        prompt_template.append(
            ("assistant", "Hei! Kuinka voin auttaa opinnäytetyösi kanssa?"),
        )
        prompt_template.append(
            ("assistant", "Hello! How can I help you with authoring your thesis?"),
        )

    # add the chat history (only the last max_history messages);
    # escape braces so user text is not parsed as template variables
    for message in message_history[-max_history:]:
        content = message["content"].replace("{", "{{").replace("}", "}}")
        prompt_template.append(
            (message["role"], content)
        )

    # end with an assistant prefix so the model answers as the assistant
    # instead of continuing the conversation on the user's behalf
    prompt_template.append(
        ("assistant", "<RESPONSE>:")
    )

    return prompt_template
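
# usage sketch (commented out) with a hypothetical one-message history;
# format_messages fills the {context} placeholder left in the template:
# history: list[ChatMessage] = [
#     {"role": "user", "metadata": {}, "content": "How long is the abstract?"}]
# print(generate_prompt(history).format_messages(context="(retrieved docs)"))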


async def generate_answer(message_history: list[ChatMessage]):
    # load (or build) the vector store from the learning material
    db = await get_document_database("learning_material/*/*/*")

    # initialize the similarity search
    n_of_best_results = 4
    retriever = db.as_retriever(
        search_type="similarity", search_kwargs={"k": n_of_best_results})

    prompt = generate_prompt(message_history, max_history=5)

    # build the RAG chain: the string passed to invoke goes to the retriever
    # as the search query, and RunnablePassthrough forwards it unchanged
    retrieval_chain = (
        {"context": retriever, "user_input": RunnablePassthrough()}
        | prompt
        | llm
        | StrOutputParser()
    )

    # run the chain, using the latest message as the retrieval query
    user_input = message_history[-1]["content"]
    response = retrieval_chain.invoke(user_input)

    # # debugging
    # print("=====raw response=====")
    # print(response)

    # recover the assistant's reply from the full generated text: drop
    # everything up to the last echo of the user's input, then take the
    # text after the first "<RESPONSE>:" marker
    parsed_answer = response.split(
        str(user_input)
    ).pop().split("<RESPONSE>:", 1).pop().strip()

    # print(repr(parsed_answer))  # debugging

    # replace paragraph breaks (double newlines) with <br> tags, since
    # gr.Chatbot does not render them well
    return parsed_answer.replace("\n\n", "<br>")
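
# one-off query sketch (commented out), assuming the learning_material
# directory exists; generate_answer is async, hence asyncio.run:
# import asyncio
# history: list[ChatMessage] = [
#     {"role": "user", "metadata": {}, "content": "How do I cite sources?"}]
# print(asyncio.run(generate_answer(history)))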


def update_chat(user_message: str, history: list):
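    # clear the textbox and append the user's message to the chat history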
    return "", history + [{"role": "user", "content": user_message}]


async def handle_conversation(
    history: list[ChatMessage],
    characters_per_second=80
):
    bot_message = await generate_answer(history)
    new_message: ChatMessage = {"role": "assistant",
                                "metadata": {"title": None},
                                "content": ""}
    history.append(new_message)
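    # stream the reply one character at a time; Gradio re-renders the
    # Chatbot on every yield, producing a typewriter effect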
    for character in bot_message:
        history[-1]['content'] += character
        yield history
        await sleep(1 / characters_per_second)


def create_interface():
    with gr.Blocks() as interface:
        gr.Markdown("# 📄 Thesizer: HAMK Thesis Assistant")
        gr.Markdown("Ask for help with authoring the HAMK thesis!")

        gr.Markdown("## Chat interface")

        with gr.Column():
            chatbot = gr.Chatbot(type="messages")

            with gr.Row():
                user_input = gr.Textbox(
                    label="You:",
                    placeholder="Type your message here...",
                    show_label=False
                )
                send_button = gr.Button("Send")

        # handle the messages being sent
        send_button.click(
            fn=update_chat,
            inputs=[user_input, chatbot],
            outputs=[user_input, chatbot],
            queue=False
        ).then(
            fn=handle_conversation,
            inputs=chatbot,
            outputs=chatbot
        )

        # pressing enter instead of the button
        user_input.submit(
            fn=update_chat,
            inputs=[user_input, chatbot],
            outputs=[user_input, chatbot],
            queue=False
        ).then(
            fn=handle_conversation,
            inputs=chatbot,
            outputs=chatbot
        )
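
        # note: the two identical wirings above could be merged with
        # gr.on(triggers=[send_button.click, user_input.submit], ...)
        # in Gradio 4+; they are kept separate here for clarity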

    return interface


if __name__ == "__main__":
    create_interface().launch()