DrishtiSharma commited on
Commit
1faf3be
·
verified ·
1 Parent(s): 2da304e

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +272 -0
app.py ADDED
@@ -0,0 +1,272 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import json
3
+ from typing import Iterable
4
+ from moa.agent import MOAgent
5
+ from moa.agent.moa import ResponseChunk
6
+ from streamlit_ace import st_ace
7
+ import copy
8
+
9
+ # Default configuration
10
+ default_config = {
11
+ "main_model": "llama3-70b-8192",
12
+ "cycles": 3,
13
+ "layer_agent_config": {}
14
+ }
15
+
16
+ layer_agent_config_def = {
17
+ "layer_agent_1": {
18
+ "system_prompt": "Think through your response step by step. {helper_response}",
19
+ "model_name": "llama3-8b-8192"
20
+ },
21
+ "layer_agent_2": {
22
+ "system_prompt": "Respond with a thought and then your response to the question. {helper_response}",
23
+ "model_name": "gemma-7b-it",
24
+ "temperature": 0.7
25
+ },
26
+ "layer_agent_3": {
27
+ "system_prompt": "You are an expert at logic and reasoning. Always take a logical approach to the answer. {helper_response}",
28
+ "model_name": "llama3-8b-8192"
29
+ },
30
+
31
+ }
32
+
33
+ # Recommended Configuration
34
+
35
+ rec_config = {
36
+ "main_model": "llama3-70b-8192",
37
+ "cycles": 2,
38
+ "layer_agent_config": {}
39
+ }
40
+
41
+ layer_agent_config_rec = {
42
+ "layer_agent_1": {
43
+ "system_prompt": "Think through your response step by step. {helper_response}",
44
+ "model_name": "llama3-8b-8192",
45
+ "temperature": 0.1
46
+ },
47
+ "layer_agent_2": {
48
+ "system_prompt": "Respond with a thought and then your response to the question. {helper_response}",
49
+ "model_name": "llama3-8b-8192",
50
+ "temperature": 0.2
51
+ },
52
+ "layer_agent_3": {
53
+ "system_prompt": "You are an expert at logic and reasoning. Always take a logical approach to the answer. {helper_response}",
54
+ "model_name": "llama3-8b-8192",
55
+ "temperature": 0.4
56
+ },
57
+ "layer_agent_4": {
58
+ "system_prompt": "You are an expert planner agent. Create a plan for how to answer the human's query. {helper_response}",
59
+ "model_name": "mixtral-8x7b-32768",
60
+ "temperature": 0.5
61
+ },
62
+ }
63
+
64
+
65
+ def stream_response(messages: Iterable[ResponseChunk]):
66
+ layer_outputs = {}
67
+ for message in messages:
68
+ if message['response_type'] == 'intermediate':
69
+ layer = message['metadata']['layer']
70
+ if layer not in layer_outputs:
71
+ layer_outputs[layer] = []
72
+ layer_outputs[layer].append(message['delta'])
73
+ else:
74
+ # Display accumulated layer outputs
75
+ for layer, outputs in layer_outputs.items():
76
+ st.write(f"Layer {layer}")
77
+ cols = st.columns(len(outputs))
78
+ for i, output in enumerate(outputs):
79
+ with cols[i]:
80
+ st.expander(label=f"Agent {i+1}", expanded=False).write(output)
81
+
82
+ # Clear layer outputs for the next iteration
83
+ layer_outputs = {}
84
+
85
+ # Yield the main agent's output
86
+ yield message['delta']
87
+
88
+ def set_moa_agent(
89
+ main_model: str = default_config['main_model'],
90
+ cycles: int = default_config['cycles'],
91
+ layer_agent_config: dict[dict[str, any]] = copy.deepcopy(layer_agent_config_def),
92
+ main_model_temperature: float = 0.1,
93
+ override: bool = False
94
+ ):
95
+ if override or ("main_model" not in st.session_state):
96
+ st.session_state.main_model = main_model
97
+ else:
98
+ if "main_model" not in st.session_state: st.session_state.main_model = main_model
99
+
100
+ if override or ("cycles" not in st.session_state):
101
+ st.session_state.cycles = cycles
102
+ else:
103
+ if "cycles" not in st.session_state: st.session_state.cycles = cycles
104
+
105
+ if override or ("layer_agent_config" not in st.session_state):
106
+ st.session_state.layer_agent_config = layer_agent_config
107
+ else:
108
+ if "layer_agent_config" not in st.session_state: st.session_state.layer_agent_config = layer_agent_config
109
+
110
+ if override or ("main_temp" not in st.session_state):
111
+ st.session_state.main_temp = main_model_temperature
112
+ else:
113
+ if "main_temp" not in st.session_state: st.session_state.main_temp = main_model_temperature
114
+
115
+ cls_ly_conf = copy.deepcopy(st.session_state.layer_agent_config)
116
+
117
+ if override or ("moa_agent" not in st.session_state):
118
+ st.session_state.moa_agent = MOAgent.from_config(
119
+ main_model=st.session_state.main_model,
120
+ cycles=st.session_state.cycles,
121
+ layer_agent_config=cls_ly_conf,
122
+ temperature=st.session_state.main_temp
123
+ )
124
+
125
+ del cls_ly_conf
126
+ del layer_agent_config
127
+
128
+ st.set_page_config(
129
+ page_title="Karios Agents Powered by Groq",
130
+ page_icon='static/favicon.ico',
131
+ menu_items={
132
+ 'About': "## Groq Mixture-Of-Agents \n Powered by [Groq](https://groq.com)"
133
+ },
134
+ layout="wide"
135
+ )
136
+ valid_model_names = [
137
+ 'llama3-70b-8192',
138
+ 'llama3-8b-8192',
139
+ 'gemma-7b-it',
140
+ 'gemma2-9b-it',
141
+ 'mixtral-8x7b-32768'
142
+ ]
143
+
144
+ st.markdown("<a href='https://groq.com'><img src='app/static/banner.png' width='500'></a>", unsafe_allow_html=True)
145
+ st.write("---")
146
+
147
+
148
+
149
+ # Initialize session state
150
+ if "messages" not in st.session_state:
151
+ st.session_state.messages = []
152
+
153
+ set_moa_agent()
154
+
155
+ # Sidebar for configuration
156
+ with st.sidebar:
157
+ # config_form = st.form("Agent Configuration", border=False)
158
+ st.title("MOA Configuration")
159
+ with st.form("Agent Configuration", border=False):
160
+ if st.form_submit_button("Use Recommended Config"):
161
+ try:
162
+ set_moa_agent(
163
+ main_model=rec_config['main_model'],
164
+ cycles=rec_config['cycles'],
165
+ layer_agent_config=layer_agent_config_rec,
166
+ override=True
167
+ )
168
+ st.session_state.messages = []
169
+ st.success("Configuration updated successfully!")
170
+ except json.JSONDecodeError:
171
+ st.error("Invalid JSON in Layer Agent Configuration. Please check your input.")
172
+ except Exception as e:
173
+ st.error(f"Error updating configuration: {str(e)}")
174
+ # Main model selection
175
+ new_main_model = st.selectbox(
176
+ "Select Main Model",
177
+ options=valid_model_names,
178
+ index=valid_model_names.index(st.session_state.main_model)
179
+ )
180
+
181
+ # Cycles input
182
+ new_cycles = st.number_input(
183
+ "Number of Layers",
184
+ min_value=1,
185
+ max_value=10,
186
+ value=st.session_state.cycles
187
+ )
188
+
189
+ # Main Model Temperature
190
+ main_temperature = st.number_input(
191
+ label="Main Model Temperature",
192
+ value=0.1,
193
+ min_value=0.0,
194
+ max_value=1.0,
195
+ step=0.1
196
+ )
197
+
198
+ # Layer agent configuration
199
+ tooltip = "Agents in the layer agent configuration run in parallel _per cycle_. Each layer agent supports all initialization parameters of [Langchain's ChatGroq](https://api.python.langchain.com/en/latest/chat_models/langchain_groq.chat_models.ChatGroq.html) class as valid dictionary fields."
200
+ st.markdown("Layer Agent Config", help=tooltip)
201
+ new_layer_agent_config = st_ace(
202
+ value=json.dumps(st.session_state.layer_agent_config, indent=2),
203
+ language='json',
204
+ placeholder="Layer Agent Configuration (JSON)",
205
+ show_gutter=False,
206
+ wrap=True,
207
+ auto_update=True
208
+ )
209
+
210
+ if st.form_submit_button("Update Configuration"):
211
+ try:
212
+ new_layer_config = json.loads(new_layer_agent_config)
213
+ set_moa_agent(
214
+ main_model=new_main_model,
215
+ cycles=new_cycles,
216
+ layer_agent_config=new_layer_config,
217
+ main_model_temperature=main_temperature,
218
+ override=True
219
+ )
220
+ st.session_state.messages = []
221
+ st.success("Configuration updated successfully!")
222
+ except json.JSONDecodeError:
223
+ st.error("Invalid JSON in Layer Agent Configuration. Please check your input.")
224
+ except Exception as e:
225
+ st.error(f"Error updating configuration: {str(e)}")
226
+
227
+ st.markdown("---")
228
+ st.markdown("""
229
+ ### Credits
230
+ - MOA: [Together AI](https://www.together.ai/blog/together-moa)
231
+ - LLMs: [Groq](https://groq.com/)
232
+ - Paper: [arXiv:2406.04692](https://arxiv.org/abs/2406.04692)
233
+ """)
234
+
235
+ # Main app layout
236
+ st.header("Karios Agents", anchor=False)
237
+ st.write("A this project oversees implementation of Mixture of Agents architecture Powered by Groq LLMs.")
238
+ # st.image("./static/moa_groq.svg", caption="Mixture of Agents Workflow", width=1000)
239
+
240
+ # Display current configuration
241
+ with st.expander("Current MOA Configuration", expanded=False):
242
+ st.markdown(f"**Main Model**: ``{st.session_state.main_model}``")
243
+ st.markdown(f"**Main Model Temperature**: ``{st.session_state.main_temp:.1f}``")
244
+ st.markdown(f"**Layers**: ``{st.session_state.cycles}``")
245
+ st.markdown(f"**Layer Agents Config**:")
246
+ new_layer_agent_config = st_ace(
247
+ value=json.dumps(st.session_state.layer_agent_config, indent=2),
248
+ language='json',
249
+ placeholder="Layer Agent Configuration (JSON)",
250
+ show_gutter=False,
251
+ wrap=True,
252
+ readonly=True,
253
+ auto_update=True
254
+ )
255
+
256
+ # Chat interface
257
+ for message in st.session_state.messages:
258
+ with st.chat_message(message["role"]):
259
+ st.markdown(message["content"])
260
+
261
+ if query := st.chat_input("Ask a question"):
262
+ st.session_state.messages.append({"role": "user", "content": query})
263
+ with st.chat_message("user"):
264
+ st.write(query)
265
+
266
+ moa_agent: MOAgent = st.session_state.moa_agent
267
+ with st.chat_message("assistant"):
268
+ message_placeholder = st.empty()
269
+ ast_mess = stream_response(moa_agent.chat(query, output_format='json'))
270
+ response = st.write_stream(ast_mess)
271
+
272
+ st.session_state.messages.append({"role": "assistant", "content": response})