bstraehle commited on
Commit
ddfaa69
·
1 Parent(s): 1e517cc

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +11 -14
app.py CHANGED
@@ -119,19 +119,7 @@ def rag_chain(llm, prompt, db):
119
  completion = rag_chain({"query": prompt})
120
  return completion, rag_chain
121
 
122
- def wandb_trace(rag_option, prompt, completion, chain, err_msg, start_time_ms, end_time_ms):
123
- result = ""
124
- generation_info = ""
125
- llm_output = ""
126
-
127
- if (rag_option == RAG_OFF):
128
- if (completion.generations[0] != None and completion.generations[0][0] != None):
129
- result = completion.generations[0][0].text
130
- generation_info = completion.generations[0][0].generation_info
131
- llm_output = completion.llm_output
132
- else:
133
- result = completion["result"]
134
-
135
  wandb.init(project = "openai-llm-rag")
136
 
137
  trace = Trace(
@@ -178,6 +166,9 @@ def invoke(openai_api_key, rag_option, prompt):
178
 
179
  chain = None
180
  completion = ""
 
 
 
181
  err_msg = ""
182
 
183
  try:
@@ -191,19 +182,25 @@ def invoke(openai_api_key, rag_option, prompt):
191
  #document_storage_chroma(splits)
192
  db = document_retrieval_chroma(llm, prompt)
193
  completion, chain = rag_chain(llm, prompt, db)
 
194
  elif (rag_option == RAG_MONGODB):
195
  #splits = document_loading_splitting()
196
  #document_storage_mongodb(splits)
197
  db = document_retrieval_mongodb(llm, prompt)
198
  completion, chain = rag_chain(llm, prompt, db)
 
199
  else:
200
  completion, chain = llm_chain(llm, prompt)
 
 
 
 
201
  except Exception as e:
202
  err_msg = e
203
  raise gr.Error(e)
204
  finally:
205
  end_time_ms = round(time.time() * 1000)
206
- wandb_trace(rag_option, prompt, completion, chain, err_msg, start_time_ms, end_time_ms)
207
  return result
208
 
209
  gr.close_all()
 
119
  completion = rag_chain({"query": prompt})
120
  return completion, rag_chain
121
 
122
+ def wandb_trace(rag_option, prompt, completion, result, generation_info, llm_output, chain, err_msg, start_time_ms, end_time_ms):
 
 
 
 
 
 
 
 
 
 
 
 
123
  wandb.init(project = "openai-llm-rag")
124
 
125
  trace = Trace(
 
166
 
167
  chain = None
168
  completion = ""
169
+ result = ""
170
+ generation_info = ""
171
+ llm_output = ""
172
  err_msg = ""
173
 
174
  try:
 
182
  #document_storage_chroma(splits)
183
  db = document_retrieval_chroma(llm, prompt)
184
  completion, chain = rag_chain(llm, prompt, db)
185
+ result = completion["result"]
186
  elif (rag_option == RAG_MONGODB):
187
  #splits = document_loading_splitting()
188
  #document_storage_mongodb(splits)
189
  db = document_retrieval_mongodb(llm, prompt)
190
  completion, chain = rag_chain(llm, prompt, db)
191
+ result = completion["result"]
192
  else:
193
  completion, chain = llm_chain(llm, prompt)
194
+ if (completion.generations[0] != None and completion.generations[0][0] != None):
195
+ result = completion.generations[0][0].text
196
+ generation_info = completion.generations[0][0].generation_info
197
+ llm_output = completion.llm_output
198
  except Exception as e:
199
  err_msg = e
200
  raise gr.Error(e)
201
  finally:
202
  end_time_ms = round(time.time() * 1000)
203
+ wandb_trace(rag_option, prompt, completion, result, generation_info, llm_output, chain, err_msg, start_time_ms, end_time_ms)
204
  return result
205
 
206
  gr.close_all()