dev/streamlit-ui (#12)
modify to sdk streamlit (06d784cd04f9d79f73d733703dc18fcff332a3c0)
README.md
CHANGED
@@ -3,8 +3,8 @@ title: NVDA 日本語版ガイドブックQA
 emoji: 👀
 colorFrom: green
 colorTo: yellow
-sdk:
-sdk_version:
+sdk: streamlit
+sdk_version: 1.25.0
 app_file: app.py
 pinned: false
 license: cc0-1.0
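On Hugging Face Spaces, the sdk and sdk_version fields of the README front matter tell the platform which runtime to launch. They were previously left empty, so setting them to streamlit / 1.25.0 is what makes the Space build and serve app_file: app.py as a Streamlit app.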
app.py
CHANGED
@@ -1,9 +1,13 @@
 from time import time
-import gradio as gr
+from typing import Iterable
+
+# import gradio as gr
+import streamlit as st
 from langchain.chains import RetrievalQA
 from langchain.embeddings import OpenAIEmbeddings
 from langchain.embeddings import HuggingFaceEmbeddings
-from langchain.prompts import PromptTemplate
+
+# from langchain.prompts import PromptTemplate
 import torch
 from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
 from langchain.llms import HuggingFacePipeline
@@ -25,7 +29,7 @@ E5_EMBEDDINGS = HuggingFaceEmbeddings(
     encode_kwargs=E5_ENCODE_KWARGS,
 )
 
-if torch.cuda.is_available():
+if False and torch.cuda.is_available():  # TODO: for local debug
     RINNA_MODEL_NAME = "rinna/bilingual-gpt-neox-4b-instruction-ppo"
     RINNA_TOKENIZER = AutoTokenizer.from_pretrained(RINNA_MODEL_NAME, use_fast=False)
     RINNA_MODEL = AutoModelForCausalLM.from_pretrained(
@@ -86,17 +90,6 @@ def _get_llm_model(
     return llm
 
 
-# prompt_template = """Use the following pieces of context to answer the question at the end. If you don't know the answer, just say that you don't know, don't try to make up an answer.
-
-# {context}
-
-# Question: {question}
-# Answer in Japanese:"""
-# PROMPT = PromptTemplate(
-#     template=prompt_template, input_variables=["context", "question"]
-# )
-
-
 def get_retrieval_qa(
     collection_name: str | None,
     model_name: str | None,
@@ -122,7 +115,6 @@ def get_retrieval_qa(
     llm = _get_llm_model(model_name, temperature)
 
     # chain_type_kwargs = {"prompt": PROMPT}
-
     result = RetrievalQA.from_chain_type(
         llm=llm,
         chain_type="stuff",
@@ -146,11 +138,8 @@ def get_related_url(metadata):
         yield f'<p>URL: <a href="{url}">{url}</a> (category: {category})</p>'
 
 
-def main(
-    query: str, collection_name: str, model_name: str, option: str, temperature: float
-):
+def run_qa(query: str, qa: RetrievalQA) -> tuple[str, str]:
     now = time()
-    qa = get_retrieval_qa(collection_name, model_name, temperature, option)
     try:
         result = qa(query)
     except InvalidRequestError as e:
@@ -163,29 +152,62 @@ def main(
     return result["result"], html
 
 
+def main(
+    query: str,
+    collection_name: str | None,
+    model_name: str | None,
+    option: str | None,
+    temperature: float,
+    e5_option: list[str],
+) -> Iterable[tuple[str, tuple[str, str]]]:
+    qa = get_retrieval_qa(collection_name, model_name, temperature, option)
+    if collection_name == "E5":
+        for option in e5_option:
+            if option == "No":
+                yield "E5 No", run_qa(query, qa)
+            elif option == "Query":
+                yield "E5 Query", run_qa("query: " + query, qa)
+            elif option == "Passage":
+                yield "E5 Passage", run_qa("passage: " + query, qa)
+            else:
+                raise ValueError("Unknow option")
+    else:
+        yield "OpenAI", run_qa(query, qa)
+
+
 AVAILABLE_LLMS = ["GPT-3.5", "GPT-4"]
 
 if RINNA_MODEL is not None:
     AVAILABLE_LLMS.append("rinna")
 
-
-
-
-        gr.Textbox(label="query"),
-        gr.Radio(["E5", "OpenAI"], value="E5", label="Embedding"),
-        gr.Radio(
-            AVAILABLE_LLMS, value="GPT-3.5", label="Model", info="GPU環境だとrinnaが選択可能"
-        ),
-        gr.Radio(
-            ["All", "ja-book", "ja-nvda-user-guide", "en-nvda-user-guide"],
-            value="All",
-            label="絞り込み",
-            info="ドキュメント制限する?",
-        ),
-        gr.Slider(0, 2),
-    ],
-    outputs=[gr.Textbox(label="answer"), gr.outputs.HTML()],
-)
+with st.form("my_form"):
+    query = st.text_input(label="query")
+    collection_name = st.radio(options=["E5", "OpenAI"], label="Embedding")
 
+    # if collection_name == "E5": # TODO : 選択肢で選べるようにする
+    e5_option = st.multiselect("E5 option", ["No", "Query", "Passage"], default="No")
 
-
+    model_name = st.radio(
+        options=AVAILABLE_LLMS,
+        label="Model",
+        help="GPU環境だとrinnaが選択可能",
+    )
+    option = st.radio(
+        options=["All", "ja-book", "ja-nvda-user-guide", "en-nvda-user-guide"],
+        label="絞り込み",
+        help="ドキュメント制限する?",
+    )
+    temperature = st.slider(label="temperature", min_value=0, max_value=2)
+
+    submitted = st.form_submit_button("Submit")
+    if submitted:
+        with st.spinner("Searching..."):
+            results = main(
+                query, collection_name, model_name, option, temperature, e5_option
+            )
+            for type_, (answer, html) in results:
+                with st.container():
+                    st.header(type_)
+                    st.write(answer)
                    st.markdown(html, unsafe_allow_html=True)
+                    st.divider()