chansung commited on
Commit
9f94724
·
1 Parent(s): 72b89d8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +44 -31
app.py CHANGED
@@ -63,10 +63,43 @@ def fill_up_placeholders(txt):
63
  "" if len(placeholders) >= 1 else txt
64
  )
65
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
66
  async def rollback_last(
67
  idx, local_data, chat_state,
68
- global_context, res_temp, res_topk, res_rpen, res_mnts, res_sample, ctx_num_lconv
 
69
  ):
 
 
70
  res = [
71
  chat_state["ppmanager_type"].from_json(json.dumps(ppm))
72
  for ppm in local_data
@@ -80,8 +113,17 @@ async def rollback_last(
80
  PingPong(last_user_message, "")
81
  )
82
  prompt = build_prompts(ppm, global_context, ctx_num_lconv)
 
 
 
 
 
 
 
 
83
  async for result in gen_text(
84
- prompt, hf_model=MODEL_ID, hf_token=TOKEN,
 
85
  parameters={
86
  'max_new_tokens': res_mnts,
87
  'do_sample': res_sample,
@@ -108,35 +150,6 @@ def reset_chat(idx, ld, state):
108
  gr.update(interactive=False),
109
  )
110
 
111
- def internet_search(ppmanager, serper_api_key, global_context, ctx_num_lconv, device="gpu"):
112
- internet_search_ppm = copy.deepcopy(ppmanager)
113
- user_msg = internet_search_ppm.pingpongs[-1].ping
114
- internet_search_prompt = f"My question is '{user_msg}'. Based on the conversation history, give me an appropriate query to answer my question for google search. You should not say more than query. You should not say any words except the query."
115
-
116
- internet_search_ppm.pingpongs[-1].ping = internet_search_prompt
117
- internet_search_prompt = build_prompts(internet_search_ppm, "", win_size=ctx_num_lconv)
118
-
119
- instruction = gen_text_none_stream(internet_search_prompt, hf_model=MODEL_ID, hf_token=TOKEN)
120
- ###
121
-
122
- searcher = SimilaritySearcher.from_pretrained(device=device)
123
- iss = InternetSearchStrategy(
124
- searcher,
125
- instruction=instruction,
126
- serper_api_key=serper_api_key
127
- )(ppmanager)
128
-
129
- step_ppm = None
130
- while True:
131
- try:
132
- step_ppm, _ = next(iss)
133
- yield "", step_ppm.build_uis()
134
- except StopIteration:
135
- break
136
-
137
- search_prompt = build_prompts(step_ppm, global_context, ctx_num_lconv)
138
- yield search_prompt, ppmanager.build_uis()
139
-
140
  async def chat_stream(
141
  idx, local_data, instruction_txtbox, chat_state,
142
  global_context, res_temp, res_topk, res_rpen, res_mnts, res_sample, ctx_num_lconv,
 
63
  "" if len(placeholders) >= 1 else txt
64
  )
65
 
66
+
67
+ def internet_search(ppmanager, serper_api_key, global_context, ctx_num_lconv, device="cuda"):
68
+ internet_search_ppm = copy.deepcopy(ppmanager)
69
+ user_msg = internet_search_ppm.pingpongs[-1].ping
70
+ internet_search_prompt = f"My question is '{user_msg}'. Based on the conversation history, give me an appropriate query to answer my question for google search. You should not say more than query. You should not say any words except the query."
71
+
72
+ internet_search_ppm.pingpongs[-1].ping = internet_search_prompt
73
+ internet_search_prompt = build_prompts(internet_search_ppm, "", win_size=ctx_num_lconv)
74
+
75
+ instruction = gen_text_none_stream(internet_search_prompt, hf_model=MODEL_ID, hf_token=TOKEN)
76
+ ###
77
+
78
+ searcher = SimilaritySearcher.from_pretrained(device=device)
79
+ iss = InternetSearchStrategy(
80
+ searcher,
81
+ instruction=instruction,
82
+ serper_api_key=serper_api_key
83
+ )(ppmanager)
84
+
85
+ step_ppm = None
86
+ while True:
87
+ try:
88
+ step_ppm, _ = next(iss)
89
+ yield "", step_ppm.build_uis()
90
+ except StopIteration:
91
+ break
92
+
93
+ search_prompt = build_prompts(step_ppm, global_context, ctx_num_lconv)
94
+ yield search_prompt, ppmanager.build_uis()
95
+
96
  async def rollback_last(
97
  idx, local_data, chat_state,
98
+ global_context, res_temp, res_topk, res_rpen, res_mnts, res_sample, ctx_num_lconv,
99
+ internet_option, serper_api_key
100
  ):
101
+ internet_option = True if internet_option == "on" else False
102
+
103
  res = [
104
  chat_state["ppmanager_type"].from_json(json.dumps(ppm))
105
  for ppm in local_data
 
113
  PingPong(last_user_message, "")
114
  )
115
  prompt = build_prompts(ppm, global_context, ctx_num_lconv)
116
+
117
+ #######
118
+ if internet_option:
119
+ search_prompt = None
120
+ for tmp_prompt, uis in internet_search(ppm, serper_api_key, global_context, ctx_num_lconv):
121
+ search_prompt = tmp_prompt
122
+ yield "", prompt, uis, str(res), gr.update(interactive=False)
123
+
124
  async for result in gen_text(
125
+ search_prompt if internet_option else prompt,
126
+ hf_model=MODEL_ID, hf_token=TOKEN,
127
  parameters={
128
  'max_new_tokens': res_mnts,
129
  'do_sample': res_sample,
 
150
  gr.update(interactive=False),
151
  )
152
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
153
  async def chat_stream(
154
  idx, local_data, instruction_txtbox, chat_state,
155
  global_context, res_temp, res_topk, res_rpen, res_mnts, res_sample, ctx_num_lconv,