DrishtiSharma committed on
Commit
801048f
β€’
1 Parent(s): d203126

Update interim.py

Browse files
Files changed (1) hide show
  1. interim.py +100 -133
interim.py CHANGED
@@ -1,8 +1,6 @@
1
- # ref: https://github.com/kram254/Mixture-of-Agents-running-on-Groq/tree/main
2
  import streamlit as st
3
  import json
4
- import asyncio
5
- from typing import Union, Iterable, AsyncIterable
6
  from moa.agent import MOAgent
7
  from moa.agent.moa import ResponseChunk
8
  from streamlit_ace import st_ace
@@ -31,7 +29,7 @@ layer_agent_config_def = {
31
  },
32
  }
33
 
34
- # Recommended configuration
35
  rec_config = {
36
  "main_model": "llama3-70b-8192",
37
  "cycles": 2,
@@ -61,37 +59,41 @@ layer_agent_config_rec = {
61
  },
62
  }
63
 
64
- # Unified streaming function to handle async and sync responses
65
- async def stream_or_async_response(messages: Union[Iterable[ResponseChunk], AsyncIterable[ResponseChunk]]):
66
  layer_outputs = {}
67
-
68
- async def process_message(message):
 
 
 
 
 
 
69
  if message['response_type'] == 'intermediate':
70
  layer = message['metadata']['layer']
71
  if layer not in layer_outputs:
72
  layer_outputs[layer] = []
73
  layer_outputs[layer].append(message['delta'])
 
 
 
 
 
 
 
74
  else:
 
75
  for layer, outputs in layer_outputs.items():
76
- st.write(f"Layer {layer}")
77
- cols = st.columns(len(outputs))
78
- for i, output in enumerate(outputs):
79
- with cols[i]:
80
- st.expander(label=f"Agent {i+1}", expanded=False).write(output)
81
-
82
- layer_outputs.clear()
83
  yield message['delta']
84
 
85
- if isinstance(messages, AsyncIterable):
86
- # Process asynchronous messages
87
- async for message in messages:
88
- await process_message(message)
89
- else:
90
- # Process synchronous messages
91
- for message in messages:
92
- await process_message(message)
93
 
94
- # Set up the MOAgent
95
  def set_moa_agent(
96
  main_model: str = default_config['main_model'],
97
  cycles: int = default_config['cycles'],
@@ -112,7 +114,6 @@ def set_moa_agent(
112
  st.session_state.main_temp = main_model_temperature
113
 
114
  cls_ly_conf = copy.deepcopy(st.session_state.layer_agent_config)
115
-
116
  if override or ("moa_agent" not in st.session_state):
117
  st.session_state.moa_agent = MOAgent.from_config(
118
  main_model=st.session_state.main_model,
@@ -122,16 +123,13 @@ def set_moa_agent(
122
  )
123
 
124
  del cls_ly_conf
125
- del layer_agent_config
126
 
127
- # Streamlit app layout
128
  st.set_page_config(
129
  page_title="Mixture of Agents",
130
- menu_items={
131
- 'About': "## Groq Mixture-Of-Agents \n Powered by [Groq](https://groq.com)"
132
- },
133
- layout="wide"
134
  )
 
135
  valid_model_names = [
136
  'llama3-70b-8192',
137
  'llama3-8b-8192',
@@ -140,100 +138,79 @@ valid_model_names = [
140
  'mixtral-8x7b-32768'
141
  ]
142
 
143
- st.markdown("<a href='https://groq.com'><img src='app/static/banner.png' width='500'></a>", unsafe_allow_html=True)
144
- st.write("---")
145
-
146
- # Initialize session state
147
  if "messages" not in st.session_state:
148
  st.session_state.messages = []
149
 
150
  set_moa_agent()
151
 
152
- # Sidebar for configuration
153
  with st.sidebar:
154
  st.title("MOA Configuration")
155
- with st.form("Agent Configuration", border=False):
156
  if st.form_submit_button("Use Recommended Config"):
157
- try:
158
- set_moa_agent(
159
- main_model=rec_config['main_model'],
160
- cycles=rec_config['cycles'],
161
- layer_agent_config=layer_agent_config_rec,
162
- override=True
163
- )
164
- st.session_state.messages = []
165
- st.success("Configuration updated successfully!")
166
- except Exception as e:
167
- st.error(f"Error updating configuration: {str(e)}")
168
-
169
- # Main model selection
170
- new_main_model = st.selectbox(
171
- "Select Main Model",
172
- options=valid_model_names,
173
- index=valid_model_names.index(st.session_state.main_model)
174
- )
175
-
176
- # Cycles input
177
- new_cycles = st.number_input(
178
- "Number of Layers",
179
- min_value=1,
180
- max_value=10,
181
- value=st.session_state.cycles
182
- )
183
-
184
- # Main Model Temperature
185
- main_temperature = st.number_input(
186
- label="Main Model Temperature",
187
- value=0.1,
188
- min_value=0.0,
189
- max_value=1.0,
190
- step=0.1
191
- )
192
-
193
- # Layer agent configuration
194
- new_layer_agent_config = st_ace(
195
- value=json.dumps(st.session_state.layer_agent_config, indent=2),
196
- language='json',
197
- placeholder="Layer Agent Configuration (JSON)",
198
- show_gutter=False,
199
- wrap=True,
200
- auto_update=True
201
- )
202
-
203
- if st.form_submit_button("Update Configuration"):
204
- try:
205
- new_layer_config = json.loads(new_layer_agent_config)
206
- set_moa_agent(
207
- main_model=new_main_model,
208
- cycles=new_cycles,
209
- layer_agent_config=new_layer_config,
210
- main_model_temperature=main_temperature,
211
- override=True
212
- )
213
- st.session_state.messages = []
214
- st.success("Configuration updated successfully!")
215
- except Exception as e:
216
- st.error(f"Error updating configuration: {str(e)}")
217
 
218
  # Main app layout
219
  st.header("Mixture of Agents")
220
- st.write("This project oversees implementation of Mixture of Agents architecture powered by Groq LLMs.")
221
-
222
- # Display current configuration
223
  with st.expander("Current MOA Configuration", expanded=False):
224
- st.markdown(f"**Main Model**: `{st.session_state.main_model}`")
225
- st.markdown(f"**Main Model Temperature**: `{st.session_state.main_temp:.1f}`")
226
- st.markdown(f"**Layers**: `{st.session_state.cycles}`")
227
- st.markdown("**Layer Agents Config:**")
228
- st_ace(
229
- value=json.dumps(st.session_state.layer_agent_config, indent=2),
230
- language='json',
231
- placeholder="Layer Agent Configuration (JSON)",
232
- show_gutter=False,
233
- wrap=True,
234
- readonly=True,
235
- auto_update=True
236
- )
237
 
238
  # Chat interface
239
  for message in st.session_state.messages:
@@ -241,27 +218,17 @@ for message in st.session_state.messages:
241
  st.markdown(message["content"])
242
 
243
  if query := st.chat_input("Ask a question"):
244
- async def handle_query():
245
- st.session_state.messages.append({"role": "user", "content": query})
246
- with st.chat_message("user"):
247
- st.write(query)
248
-
249
- moa_agent: MOAgent = st.session_state.moa_agent
250
-
251
- with st.chat_message("assistant"):
252
- message_placeholder = st.empty()
253
- messages = moa_agent.chat(query, output_format='json')
254
- async for response in stream_or_async_response(messages):
255
- message_placeholder.markdown(response)
256
-
257
- st.session_state.messages.append({"role": "assistant", "content": response})
258
 
259
- asyncio.run(handle_query())
 
 
 
 
260
 
 
261
 
262
- # Add acknowledgment at the bottom
263
  st.markdown("---")
264
- st.markdown("""
265
- ###
266
- This app is based on [Emmanuel M. Ndaliro's work](https://github.com/kram254/Mixture-of-Agents-running-on-Groq/tree/main).
267
- """)
 
 
1
  import streamlit as st
2
  import json
3
+ from typing import Iterable
 
4
  from moa.agent import MOAgent
5
  from moa.agent.moa import ResponseChunk
6
  from streamlit_ace import st_ace
 
29
  },
30
  }
31
 
32
+ # Recommended Configuration
33
  rec_config = {
34
  "main_model": "llama3-70b-8192",
35
  "cycles": 2,
 
59
  },
60
  }
61
 
62
def stream_response(messages: Iterable[ResponseChunk]):
    """Consume chunks from the MOA agent and render them in Streamlit.

    Intermediate chunks are accumulated per layer and rendered live;
    every non-intermediate chunk's text is yielded so the caller
    (st.write_stream) can display the main model's answer incrementally.

    Args:
        messages: Iterable of ResponseChunk dicts with keys
            'response_type', 'delta' and (for intermediates) 'metadata'.

    Yields:
        str: the 'delta' text of each non-intermediate chunk.
    """
    layer_outputs = {}

    # BUG FIX: the original did `total_steps = len(messages)`, but the
    # caller passes moa_agent.chat(...), a generator, so len() raised
    # TypeError. Only show a determinate progress bar when the iterable
    # is actually sized (list/tuple); otherwise skip it.
    try:
        total_steps = len(messages)
    except TypeError:
        total_steps = None
    progress_bar = st.progress(0) if total_steps else None
    current_step = 0

    for message in messages:
        current_step += 1
        if progress_bar is not None:
            # Clamp in case the iterable yields more items than len() reported.
            progress_bar.progress(min(current_step / total_steps, 1.0))

        if message['response_type'] == 'intermediate':
            layer = message['metadata']['layer']
            if layer not in layer_outputs:
                layer_outputs[layer] = []
            layer_outputs[layer].append(message['delta'])

            # Real-time rendering for intermediate outputs
            with st.container():
                st.markdown(f"**Layer {layer} (In Progress)**")
                for output in layer_outputs[layer]:
                    st.markdown(f"- {output}")
        else:
            # Finalize and display accumulated layer outputs
            for layer, outputs in layer_outputs.items():
                st.markdown(f"### Layer {layer} Final Output")
                for output in outputs:
                    st.write(output)
            layer_outputs = {}  # Reset for next layers

            # Yield the main agent's output
            yield message['delta']

    if progress_bar is not None:
        progress_bar.empty()  # Clear progress bar once done
 
 
 
 
 
 
 
96
 
 
97
  def set_moa_agent(
98
  main_model: str = default_config['main_model'],
99
  cycles: int = default_config['cycles'],
 
114
  st.session_state.main_temp = main_model_temperature
115
 
116
  cls_ly_conf = copy.deepcopy(st.session_state.layer_agent_config)
 
117
  if override or ("moa_agent" not in st.session_state):
118
  st.session_state.moa_agent = MOAgent.from_config(
119
  main_model=st.session_state.main_model,
 
123
  )
124
 
125
  del cls_ly_conf
 
126
 
 
127
# Page chrome: wide layout plus a short "About" entry in the hamburger menu.
st.set_page_config(
    menu_items={'About': "## Mixture-of-Agents\nPowered by Groq"},
    layout="wide",
    page_title="Mixture of Agents",
)
132
+
133
  valid_model_names = [
134
  'llama3-70b-8192',
135
  'llama3-8b-8192',
 
138
  'mixtral-8x7b-32768'
139
  ]
140
 
 
 
 
 
141
# One-time session initialization: chat history, then the MOA agent itself
# (set_moa_agent only builds a new agent when none exists or override=True).
if "messages" not in st.session_state:
    st.session_state.messages = []

set_moa_agent()
145
 
146
# Sidebar Configuration
with st.sidebar:
    st.title("MOA Configuration")
    with st.form("Agent Configuration", clear_on_submit=False):
        if st.form_submit_button("Use Recommended Config"):
            # ROBUSTNESS FIX: set_moa_agent() can raise (e.g. a model name
            # rejected by MOAgent.from_config). The previous revision caught
            # this; restore the guard so a failure is reported in the UI
            # instead of crashing the script run. Mirrors the "Update Config"
            # error handling below.
            try:
                set_moa_agent(
                    main_model=rec_config['main_model'],
                    cycles=rec_config['cycles'],
                    layer_agent_config=layer_agent_config_rec,
                    override=True
                )
                st.session_state.messages = []
                st.success("Configuration updated successfully!")
            except Exception as e:
                st.error(f"Error updating config: {str(e)}")

        # Config toggling
        show_advanced = st.checkbox("Show Advanced Configurations")
        if show_advanced:
            new_main_model = st.selectbox(
                "Main Model",
                valid_model_names,
                index=valid_model_names.index(st.session_state.main_model)
            )

            new_cycles = st.number_input(
                "Number of Layers",
                min_value=1,
                max_value=10,
                value=st.session_state.cycles
            )

            main_temperature = st.slider(
                "Main Model Temperature",
                min_value=0.0,
                max_value=1.0,
                value=st.session_state.main_temp,
                step=0.05
            )

            # Layer-agent config is edited as raw JSON in an Ace editor.
            new_layer_agent_config = st_ace(
                value=json.dumps(st.session_state.layer_agent_config, indent=2),
                language="json",
                show_gutter=False,
                wrap=True,
                auto_update=True
            )

            # NOTE(review): this submit button lives inside the
            # show_advanced branch, so the advanced widgets it reads are
            # always defined when it fires.
            if st.form_submit_button("Update Config"):
                try:
                    parsed_config = json.loads(new_layer_agent_config)
                    set_moa_agent(
                        main_model=new_main_model,
                        cycles=new_cycles,
                        layer_agent_config=parsed_config,
                        main_model_temperature=main_temperature,
                        override=True
                    )
                    st.session_state.messages = []
                    st.success("Configuration updated successfully!")
                except json.JSONDecodeError:
                    st.error("Invalid JSON in Layer Agent Config.")
                except Exception as e:
                    st.error(f"Error updating config: {str(e)}")
 
 
 
208
 
209
# Main app layout
st.header("Mixture of Agents")
st.markdown("Real-time response tracking with intermediate and final results.")

# Read-only view of the active layer-agent configuration.
with st.expander("Current MOA Configuration", expanded=False):
    st.json(st.session_state.layer_agent_config)
 
 
 
 
 
 
 
 
 
 
 
 
214
 
215
  # Chat interface
216
  for message in st.session_state.messages:
 
218
  st.markdown(message["content"])
219
 
220
# Handle a new user question: echo it, stream the MOA answer, persist both.
if query := st.chat_input("Ask a question"):
    st.session_state.messages.append({"role": "user", "content": query})
    with st.chat_message("user"):
        st.markdown(query)

    moa_agent: MOAgent = st.session_state.moa_agent
    with st.chat_message("assistant"):
        # stream_response yields only the main model's deltas;
        # st.write_stream renders them incrementally and returns the
        # concatenated text. (Removed a dead `message_placeholder =
        # st.empty()` local that was assigned but never used.)
        ast_mess = stream_response(moa_agent.chat(query, output_format="json"))
        response = st.write_stream(ast_mess)

    st.session_state.messages.append({"role": "assistant", "content": response})

# Footer / acknowledgment
st.markdown("---")
st.markdown("Powered by [Groq](https://groq.com).")