HannahLin271 commited on
Commit
5371253
·
verified ·
1 Parent(s): d0b1ec0

Update utils.py

Browse files
Files changed (1) hide show
  1. utils.py +10 -10
utils.py CHANGED
@@ -56,7 +56,7 @@ def init_model_from(url, filename):
56
  ckpt_path = Path(out_dir) / filename
57
  ckpt_path.parent.mkdir(parents=True, exist_ok=True)
58
  if not os.path.exists(ckpt_path):
59
- gr.Info('Downloading model...')
60
  download_file(url, ckpt_path)
61
  gr.Info('✅Model downloaded successfully.', duration=2)
62
  checkpoint = torch.load(ckpt_path, map_location=device)
@@ -70,24 +70,24 @@ def init_model_from(url, filename):
70
  model.load_state_dict(state_dict)
71
  return model
72
 
73
- def respond(input, samples, model, encode, decode, max_new_tokens,temperature, top_k): # generation function
 
74
  x = (torch.tensor(encode(input), dtype=torch.long, device=device)[None, ...])
75
  with torch.no_grad():
76
  for k in range(samples):
 
77
  generated = model.generate(x, max_new_tokens, temperature=temperature, top_k=top_k)
78
 
79
  output = decode(generated[0].tolist())
 
 
80
 
81
- match_botoutput = re.search(r'<human>(.*?)<', output)
82
- match_emotion = re.search(r'<emotion>\s*(.*?)\s*<', output)
83
- match_context = re.search(r'<context>\s*(.*?)\s*<', output)
84
  response = ''
85
- emotion = ''
86
- context = ''
87
  if match_botoutput:
88
  try :
89
- response = match_botoutput.group(1).replace('<endOfText>','')
90
  except:
91
- response = match_botoutput.group(1)
92
  #return response, emotion, context
93
- return [input, response]
 
56
  ckpt_path = Path(out_dir) / filename
57
  ckpt_path.parent.mkdir(parents=True, exist_ok=True)
58
  if not os.path.exists(ckpt_path):
59
+ gr.Info('Downloading model...',duration=10)
60
  download_file(url, ckpt_path)
61
  gr.Info('✅Model downloaded successfully.', duration=2)
62
  checkpoint = torch.load(ckpt_path, map_location=device)
 
70
  model.load_state_dict(state_dict)
71
  return model
72
 
73
+ def respond(input, samples, model, encode, decode, max_new_tokens,temperature, top_k):
74
+ input = "<bot> " + input
75
  x = (torch.tensor(encode(input), dtype=torch.long, device=device)[None, ...])
76
  with torch.no_grad():
77
  for k in range(samples):
78
+
79
  generated = model.generate(x, max_new_tokens, temperature=temperature, top_k=top_k)
80
 
81
  output = decode(generated[0].tolist())
82
+ # if input in output:
83
+ # output = output.split(input)[-1].strip() # Take the part after `<input>`
84
 
85
+ match_botoutput = re.search(r'<human>(.*?)<', output, re.DOTALL)
 
 
86
  response = ''
 
 
87
  if match_botoutput:
88
  try :
89
+ response = match_botoutput.group(1).strip()
90
  except:
91
+ response = ''
92
  #return response, emotion, context
93
+ return [input, response, output]