# NOTE: HuggingFace Spaces page-status text ("Spaces: Sleeping") captured by
# extraction — not part of the program.
import os

import gradio as gr
import pandas as pd
from langchain.chains import RetrievalQA
from langchain.chat_models import ChatOpenAI
from langchain.embeddings import OpenAIEmbeddings
from langchain.llms import OpenAI
from langchain.vectorstores import Chroma
def Loading():
    """Return the (Korean) "loading data..." status message shown in the UI."""
    status_message = "데이터 로딩 중..."
    return status_message
def LoadData(openai_key):
    """Set the OpenAI API key and build the global Chroma retriever.

    Args:
        openai_key: API key string entered by the user in the Gradio textbox.

    Returns:
        A (Korean) status string for the status textbox: "ready" on success,
        or a "please enter your API key" message when no key was given.
    """
    # Guard clause rejects missing *and* empty keys: the original
    # `is not None` check let an empty textbox value through, which only
    # failed later inside the embeddings call.
    if not openai_key:
        return "사용하시는 API Key를 입력하여 주시기 바랍니다."

    os.environ["OPENAI_API_KEY"] = openai_key

    # Load the pre-built vector store persisted on disk.
    persist_directory = 'realdb_LLM'
    embedding = OpenAIEmbeddings()
    vectordb = Chroma(
        persist_directory=persist_directory,
        embedding_function=embedding
    )

    # Shared with respond(); k=1 returns only the single closest document.
    global retriever
    retriever = vectordb.as_retriever(search_kwargs={"k": 1})
    return "준비 완료"
# Handles one chat turn: runs retrieval-augmented QA and updates the history.
def respond(message, chat_history, temperature, top_p):
    """Answer `message` via RetrievalQA and append the turn to `chat_history`.

    Args:
        message: User question from the input textbox.
        chat_history: Gradio chatbot history, a list of (user, bot) tuples;
            mutated in place.
        temperature, top_p: Sampling parameters forwarded to the OpenAI LLM.

    Returns:
        (new_textbox_value, updated_history) — the empty-ish first element
        clears the input box, as Gradio expects.
    """
    try:
        # Built per-request so the sliders' current values take effect.
        qa_chain = RetrievalQA.from_chain_type(
            llm=OpenAI(temperature=temperature, top_p=top_p),
            chain_type="stuff",
            retriever=retriever,
        )
        result = qa_chain(message)
        bot_message = result['result']
        # Record the (user message, bot reply) pair in the chat history.
        chat_history.append((message, bot_message))
        return "", chat_history
    except Exception as err:
        # Narrowed from a bare `except:` so KeyboardInterrupt/SystemExit still
        # propagate, and the real cause is logged instead of silently dropped.
        # User-facing behavior (best-effort error bubble) is preserved.
        print(f"respond() failed: {err!r}")
        chat_history.append(("", "API Key 입력 안맞"))
        return " ", chat_history
# Chatbot header HTML (the Korean subtitle was encoding-garbled in the
# extracted source and has been reconstructed).
title = """
<div style="text-align: center; max-width: 500px; margin: 0 auto;">
  <div>
    <h1>Pretraining Chatbot V2 Real</h1>
  </div>
  <p style="margin-bottom: 10px; font-size: 94%">
    OpenAI LLM를 이용한 Chatbot (Similarity)
  </p>
</div>
"""

# Page styling: center the main column and cap its width.
css = """
#col-container {max-width: 700px; margin-left: auto; margin-right: auto;}
"""
# UI layout — nesting reconstructed from the widget/scale structure, since the
# extracted source lost all indentation.
with gr.Blocks(css=css) as UnivChatbot:
    with gr.Column(elem_id="col-container"):
        gr.HTML(title)

        # Row 1: API key input + status display.
        with gr.Row():
            with gr.Column(scale=3):
                openai_key = gr.Textbox(
                    label="You OpenAI API key",
                    type="password",
                    placeholder="OpenAI Key Type",
                    elem_id="InputKey",
                    show_label=False,
                    container=False,
                )
            with gr.Column(scale=1):
                langchain_status = gr.Textbox(
                    placeholder="Status",
                    interactive=False,
                    show_label=False,
                    container=False,
                )

        # Row 2: sampling controls + key-confirm button.
        with gr.Row():
            with gr.Column(scale=4):
                temperature = gr.Slider(
                    label="Temperature",
                    minimum=0,
                    maximum=2.0,
                    step=0.01,
                    value=0.7,
                )
            with gr.Column(scale=4):
                top_p = gr.Slider(
                    label="Top_p",
                    minimum=0,
                    maximum=1,
                    step=0.01,
                    value=0.5,
                )
            with gr.Column(scale=1):
                chk_key = gr.Button("확인", variant="primary")

        chatbot = gr.Chatbot(label="대화 챗봇시스템 (OpenAI LLM)", elem_id="chatbot")

        # Row 3: question input + send/reset buttons.
        with gr.Row():
            with gr.Column(scale=9):
                msg = gr.Textbox(
                    label="입력",
                    placeholder="궁금하신 내역을 입력하여 주세요.",
                    elem_id="InputQuery",
                    show_label=False,
                    container=False,
                )
            with gr.Row():
                with gr.Column(scale=1):
                    submit = gr.Button("전송", variant="primary")
                with gr.Column(scale=1):
                    clear = gr.Button("초기화", variant="stop")

    # "확인": store the key and build the retriever.
    chk_key.click(
        fn=LoadData,
        inputs=[openai_key],
        outputs=[langchain_status],
        queue=False,
    )

    # Submitting the textbox (Enter) or clicking "전송" both call respond().
    msg.submit(
        fn=respond,
        inputs=[msg, chatbot, temperature, top_p],
        outputs=[msg, chatbot],
    )
    submit.click(respond, [msg, chatbot, temperature, top_p], [msg, chatbot])

    # "초기화" clears the chat history.
    clear.click(lambda: None, None, chatbot, queue=False)

UnivChatbot.launch()