Asankhaya Sharma commited on
Commit
dfd217b
·
1 Parent(s): 6128070

new format for chat

Browse files
Files changed (3) hide show
  1. main.py +92 -37
  2. question.py +0 -85
  3. requirements.txt +3 -2
main.py CHANGED
@@ -1,42 +1,76 @@
1
  # main.py
2
  import os
3
- import tempfile
4
-
5
  import streamlit as st
6
- from question import chat_with_doc
7
- from langchain.embeddings import HuggingFaceInferenceAPIEmbeddings
8
- from langchain.vectorstores import SupabaseVectorStore
 
 
 
 
 
 
 
9
  from supabase import Client, create_client
10
- from stats import get_usage
 
11
 
12
  supabase_url = st.secrets.SUPABASE_URL
13
  supabase_key = st.secrets.SUPABASE_KEY
14
  openai_api_key = st.secrets.openai_api_key
15
  anthropic_api_key = st.secrets.anthropic_api_key
16
  hf_api_key = st.secrets.hf_api_key
17
- supabase: Client = create_client(supabase_url, supabase_key)
18
- self_hosted = st.secrets.self_hosted
19
  username = st.secrets.username
20
 
21
- # embeddings = OpenAIEmbeddings(openai_api_key=openai_api_key)
 
22
 
23
  embeddings = HuggingFaceInferenceAPIEmbeddings(
24
  api_key=hf_api_key,
25
  model_name="BAAI/bge-large-en-v1.5"
26
  )
27
 
28
- vector_store = SupabaseVectorStore(supabase, embeddings, query_name='match_documents', table_name="documents")
 
29
 
30
- models = ["meta-llama/Llama-2-70b-chat-hf", "mistralai/Mixtral-8x7B-Instruct-v0.1"]
 
31
 
32
- if openai_api_key:
33
- models += ["gpt-3.5-turbo", "gpt-4"]
 
 
34
 
35
- if anthropic_api_key:
36
- models += ["claude-v1", "claude-v1.3",
37
- "claude-instant-v1-100k", "claude-instant-v1.1-100k"]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
38
 
39
-
40
  # Set the theme
41
  st.set_page_config(
42
  page_title="Securade.ai - Safety Copilot",
@@ -54,25 +88,46 @@ st.title("👷‍♂️ Safety Copilot 🦺")
54
 
55
  st.markdown("Chat with your personal safety assistant about any health & safety related queries.")
56
  st.markdown("Up-to-date with latest OSH regulations for Singapore, Indonesia, Malaysia & other parts of Asia.")
 
57
 
58
- st.markdown("---\n\n")
59
-
60
- # Initialize session state variables
61
- if 'model' not in st.session_state:
62
- st.session_state['model'] = "meta-llama/Llama-2-70b-chat-hf"
63
- if 'temperature' not in st.session_state:
64
- st.session_state['temperature'] = 0.1
65
- if 'chunk_size' not in st.session_state:
66
- st.session_state['chunk_size'] = 500
67
- if 'chunk_overlap' not in st.session_state:
68
- st.session_state['chunk_overlap'] = 0
69
- if 'max_tokens' not in st.session_state:
70
- st.session_state['max_tokens'] = 500
71
- if 'username' not in st.session_state:
72
- st.session_state['username'] = username
73
-
74
- stats = str(get_usage(supabase))
75
-
76
- chat_with_doc(st.session_state['model'], vector_store, stats_db=supabase, stats=stats)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
77
 
78
- st.markdown("---\n\n")
 
 
 
 
 
1
  # main.py
2
  import os
 
 
3
  import streamlit as st
4
+ import anthropic
5
+
6
+ from langchain_community.embeddings import HuggingFaceInferenceAPIEmbeddings
7
+ from langchain_community.vectorstores import SupabaseVectorStore
8
+ from langchain_community.llms import HuggingFaceEndpoint
9
+ from langchain_community.vectorstores import SupabaseVectorStore
10
+
11
+ from langchain.chains import ConversationalRetrievalChain
12
+ from langchain.memory import ConversationBufferMemory
13
+
14
  from supabase import Client, create_client
15
+ from streamlit.logger import get_logger
16
+ from stats import get_usage, add_usage
17
 
18
  supabase_url = st.secrets.SUPABASE_URL
19
  supabase_key = st.secrets.SUPABASE_KEY
20
  openai_api_key = st.secrets.openai_api_key
21
  anthropic_api_key = st.secrets.anthropic_api_key
22
  hf_api_key = st.secrets.hf_api_key
 
 
23
  username = st.secrets.username
24
 
25
+ supabase: Client = create_client(supabase_url, supabase_key)
26
+ logger = get_logger(__name__)
27
 
28
  embeddings = HuggingFaceInferenceAPIEmbeddings(
29
  api_key=hf_api_key,
30
  model_name="BAAI/bge-large-en-v1.5"
31
  )
32
 
33
+ if 'chat_history' not in st.session_state:
34
+ st.session_state['chat_history'] = []
35
 
36
+ vector_store = SupabaseVectorStore(supabase, embeddings, query_name='match_documents', table_name="documents")
37
+ memory = ConversationBufferMemory(memory_key="chat_history", input_key='question', output_key='answer', return_messages=True)
38
 
39
+ model = "meta-llama/Llama-2-70b-chat-hf" #mistralai/Mixtral-8x7B-Instruct-v0.1
40
+ temperature = 0.1
41
+ max_tokens = 500
42
+ stats = str(get_usage(supabase))
43
 
44
+ def response_generator(query):
45
+ qa = None
46
+ add_usage(supabase, "chat", "prompt" + query, {"model": model, "temperature": temperature})
47
+ logger.info('Using HF model %s', model)
48
+ # print(st.session_state['max_tokens'])
49
+ endpoint_url = ("https://api-inference.huggingface.co/models/"+ model)
50
+ model_kwargs = {"temperature" : temperature,
51
+ "max_new_tokens" : max_tokens,
52
+ "return_full_text" : False}
53
+ hf = HuggingFaceEndpoint(
54
+ endpoint_url=endpoint_url,
55
+ task="text-generation",
56
+ huggingfacehub_api_token=hf_api_key,
57
+ model_kwargs=model_kwargs
58
+ )
59
+ qa = ConversationalRetrievalChain.from_llm(hf, retriever=vector_store.as_retriever(search_kwargs={"score_threshold": 0.6, "k": 4,"filter": {"user": username}}), memory=memory, verbose=True, return_source_documents=True)
60
+
61
+ # Generate model's response
62
+ model_response = qa({"question": query})
63
+ logger.info('Result: %s', model_response["answer"])
64
+ sources = model_response["source_documents"]
65
+ logger.info('Sources: %s', model_response["source_documents"])
66
+
67
+ if len(sources) > 0:
68
+ response = model_response["answer"]
69
+ else:
70
+ response = "I am sorry, I do not have enough information to provide an answer. If there is a public source of data that you would like to add, please email [email protected]."
71
+
72
+ return response
73
 
 
74
  # Set the theme
75
  st.set_page_config(
76
  page_title="Securade.ai - Safety Copilot",
 
88
 
89
  st.markdown("Chat with your personal safety assistant about any health & safety related queries.")
90
  st.markdown("Up-to-date with latest OSH regulations for Singapore, Indonesia, Malaysia & other parts of Asia.")
91
+ st.markdown("_"+ stats + " queries answered!_")
92
 
93
+ if 'chat_history' not in st.session_state:
94
+ st.session_state['chat_history'] = []
95
+
96
+ # Display chat messages from history on app rerun
97
+ for message in st.session_state.chat_history:
98
+ with st.chat_message(message["role"]):
99
+ st.markdown(message["content"])
100
+
101
+ # Accept user input
102
+ if prompt := st.chat_input("Ask a question"):
103
+ # print(prompt)
104
+ # Add user message to chat history
105
+ st.session_state.chat_history.append({"role": "user", "content": prompt})
106
+ # Display user message in chat message container
107
+ with st.chat_message("user"):
108
+ st.markdown(prompt)
109
+
110
+ with st.spinner('Safety briefing in progress... Your customized guidance is en route.'):
111
+ response = response_generator(prompt)
112
+
113
+ # Display assistant response in chat message container
114
+ with st.chat_message("assistant"):
115
+ st.markdown(response)
116
+ # Add assistant response to chat history
117
+ # print(response)
118
+ st.session_state.chat_history.append({"role": "assistant", "content": response})
119
+
120
+ # query = st.text_area("## Ask a question (" + stats + " queries answered so far)", max_chars=500)
121
+ # columns = st.columns(2)
122
+ # with columns[0]:
123
+ # button = st.button("Ask")
124
+ # with columns[1]:
125
+ # clear_history = st.button("Clear History", type='secondary')
126
+
127
+ # st.markdown("---\n\n")
128
 
129
+ # if clear_history:
130
+ # # Clear memory in Langchain
131
+ # memory.clear()
132
+ # st.session_state['chat_history'] = []
133
+ # st.experimental_rerun()
question.py DELETED
@@ -1,85 +0,0 @@
1
- import anthropic
2
- import streamlit as st
3
- from streamlit.logger import get_logger
4
- from langchain.chains import ConversationalRetrievalChain
5
- from langchain.memory import ConversationBufferMemory
6
- from langchain.llms import OpenAI
7
- from langchain.llms import HuggingFaceEndpoint
8
- from langchain.chat_models import ChatAnthropic
9
- from langchain.vectorstores import SupabaseVectorStore
10
- from stats import add_usage
11
-
12
- memory = ConversationBufferMemory(memory_key="chat_history", input_key='question', output_key='answer', return_messages=True)
13
- openai_api_key = st.secrets.openai_api_key
14
- anthropic_api_key = st.secrets.anthropic_api_key
15
- hf_api_key = st.secrets.hf_api_key
16
- logger = get_logger(__name__)
17
-
18
- def chat_with_doc(model, vector_store: SupabaseVectorStore, stats_db, stats):
19
-
20
- if 'chat_history' not in st.session_state:
21
- st.session_state['chat_history'] = []
22
-
23
- query = st.text_area("## Ask a question (" + stats + " queries answered so far)", max_chars=500)
24
- columns = st.columns(2)
25
- with columns[0]:
26
- button = st.button("Ask")
27
- with columns[1]:
28
- clear_history = st.button("Clear History", type='secondary')
29
-
30
- st.markdown("---\n\n")
31
-
32
- if clear_history:
33
- # Clear memory in Langchain
34
- memory.clear()
35
- st.session_state['chat_history'] = []
36
- st.experimental_rerun()
37
-
38
- if button:
39
- qa = None
40
- add_usage(stats_db, "chat", "prompt" + query, {"model": model, "temperature": st.session_state['temperature']})
41
- if model.startswith("gpt"):
42
- logger.info('Using OpenAI model %s', model)
43
- qa = ConversationalRetrievalChain.from_llm(
44
- OpenAI(
45
- model_name=st.session_state['model'], openai_api_key=openai_api_key, temperature=st.session_state['temperature'], max_tokens=st.session_state['max_tokens']), vector_store.as_retriever(), memory=memory, verbose=True)
46
- elif anthropic_api_key and model.startswith("claude"):
47
- logger.info('Using Anthropics model %s', model)
48
- qa = ConversationalRetrievalChain.from_llm(
49
- ChatAnthropic(
50
- model=st.session_state['model'], anthropic_api_key=anthropic_api_key, temperature=st.session_state['temperature'], max_tokens_to_sample=st.session_state['max_tokens']), vector_store.as_retriever(), memory=memory, verbose=True, max_tokens_limit=102400)
51
- elif hf_api_key:
52
- logger.info('Using HF model %s', model)
53
- # print(st.session_state['max_tokens'])
54
- endpoint_url = ("https://api-inference.huggingface.co/models/"+ model)
55
- model_kwargs = {"temperature" : st.session_state['temperature'],
56
- "max_new_tokens" : st.session_state['max_tokens'],
57
- "return_full_text" : False}
58
- hf = HuggingFaceEndpoint(
59
- endpoint_url=endpoint_url,
60
- task="text-generation",
61
- huggingfacehub_api_token=hf_api_key,
62
- model_kwargs=model_kwargs
63
- )
64
- qa = ConversationalRetrievalChain.from_llm(hf, retriever=vector_store.as_retriever(search_kwargs={"score_threshold": 0.6, "k": 4,"filter": {"user": st.session_state["username"]}}), memory=memory, verbose=True, return_source_documents=True)
65
-
66
- print("Question>")
67
- print(query)
68
- st.session_state['chat_history'].append(("You", query))
69
-
70
- # Generate model's response and add it to chat history
71
- model_response = qa({"question": query})
72
- logger.info('Result: %s', model_response["answer"])
73
- sources = model_response["source_documents"]
74
- logger.info('Sources: %s', model_response["source_documents"])
75
-
76
- if len(sources) > 0:
77
- st.session_state['chat_history'].append(("Safety Copilot", model_response["answer"]))
78
- else:
79
- st.session_state['chat_history'].append(("Safety Copilot", "I am sorry, I do not have enough information to provide an answer. If there is a public source of data that you would like to add, please email [email protected]."))
80
-
81
- # Display chat history
82
- st.empty()
83
- chat_history = st.session_state['chat_history']
84
- for speaker, text in chat_history:
85
- st.markdown(f"**{speaker}:** {text}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
requirements.txt CHANGED
@@ -1,9 +1,10 @@
1
- langchain==0.1.0
 
2
  Markdown==3.4.3
3
  openai==0.27.6
4
  pdf2image==1.16.3
5
  pypdf==3.8.1
6
- streamlit==1.22.0
7
  StrEnum==0.4.10
8
  supabase==1.0.3
9
  tiktoken==0.4.0
 
1
+ langchain-community==0.20.0
2
+ langchain==0.1.7
3
  Markdown==3.4.3
4
  openai==0.27.6
5
  pdf2image==1.16.3
6
  pypdf==3.8.1
7
+ streamlit==1.31.0
8
  StrEnum==0.4.10
9
  supabase==1.0.3
10
  tiktoken==0.4.0