annafil commited on
Commit
2846f20
·
verified ·
1 Parent(s): f877cfa

Initial commit

Browse files
Files changed (1) hide show
  1. streamlit_app.py +79 -0
streamlit_app.py ADDED
@@ -0,0 +1,79 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import replicate
3
+ import os
4
+
5
+ # App title
6
+ st.set_page_config(page_title="Chat with Snowflake Arctic")
7
+
8
+ # Replicate Credentials
9
+ with st.sidebar:
10
+ st.title('Chat with Snowflake Arctic')
11
+ if 'REPLICATE_API_TOKEN' in st.secrets:
12
+ st.success('API token loaded!', icon='✅')
13
+ replicate_api = st.secrets['REPLICATE_API_TOKEN']
14
+ else:
15
+ replicate_api = st.text_input('Enter Replicate API token:', type='password')
16
+ if not (replicate_api.startswith('r8_') and len(replicate_api)==40):
17
+ st.warning('Please enter your Replicate API token.', icon='⚠️')
18
+ st.markdown("**Don't have an API token?** Head over to [Replicate](https://replicate.com) to sign up for one.")
19
+ else:
20
+ st.success('API token loaded!', icon='✅')
21
+
22
+ os.environ['REPLICATE_API_TOKEN'] = replicate_api
23
+ st.subheader("Adjust model parameters")
24
+ temperature = st.sidebar.slider('temperature', min_value=0.01, max_value=5.0, value=0.6, step=0.01)
25
+ top_p = st.sidebar.slider('top_p', min_value=0.01, max_value=1.0, value=0.9, step=0.01)
26
+ max_length = st.sidebar.slider('max_length', min_value=32, max_value=128, value=120, step=8)
27
+
28
+ # Store LLM-generated responses
29
+ if "messages" not in st.session_state.keys():
30
+ st.session_state.messages = [{"role": "assistant", "content": "Ask me anything"}]
31
+
32
+ # Display or clear chat messages
33
+ for message in st.session_state.messages:
34
+ with st.chat_message(message["role"]):
35
+ st.write(message["content"])
36
+
37
+ def clear_chat_history():
38
+ st.session_state.messages = [{"role": "assistant", "content": "Ask me anything"}]
39
+ st.sidebar.button('Clear chat history', on_click=clear_chat_history)
40
+
41
+ # Function for generating Snowflake Arctic response
42
+ def generate_arctic_response(prompt_input):
43
+ string_dialogue = "You are a helpful assistant. You do not respond as 'User' or pretend to be 'User.' You only respond once as 'Assistant.'"
44
+ for dict_message in st.session_state.messages:
45
+ if dict_message["role"] == "user":
46
+ string_dialogue += "<|im_start|>user\n" + dict_message["content"] + "\n<|eot_id|>\n"
47
+ else:
48
+ string_dialogue += "<|im_start|>assistant\n" + dict_message["content"] + "\n<|eot_id|>\n"
49
+
50
+ for event in replicate.stream("snowflake/snowflake-arctic-instruct",
51
+ input={"prompt": f"""
52
+ <|im_start|>system
53
+ You're a helpful assistant<|im_end|>
54
+ <|im_start|>user
55
+ {string_dialogue}
56
+ <|im_end|>
57
+ <|im_start|>assistant
58
+ """,
59
+ "prompt_template": r"{prompt}",
60
+ "temperature": temperature,
61
+ "top_p": top_p,
62
+ "max_length": max_length,
63
+ "repetition_penalty":1}):
64
+ yield str(event)
65
+
66
+ # User-provided prompt
67
+ if prompt := st.chat_input(disabled=not replicate_api):
68
+ st.session_state.messages.append({"role": "user", "content": prompt})
69
+ with st.chat_message("user"):
70
+ st.write(prompt)
71
+
72
+ # Generate a new response if last message is not from assistant
73
+ if st.session_state.messages[-1]["role"] != "assistant":
74
+ with st.chat_message("assistant"):
75
+ with st.spinner("Thinking..."):
76
+ response = generate_arctic_response(prompt)
77
+ full_response = st.write_stream(response)
78
+ message = {"role": "assistant", "content": full_response}
79
+ st.session_state.messages.append(message)