HannahLin271 committed
Update utils.py
utils.py
CHANGED
@@ -56,7 +56,7 @@ def init_model_from(url, filename):
     ckpt_path = Path(out_dir) / filename
     ckpt_path.parent.mkdir(parents=True, exist_ok=True)
     if not os.path.exists(ckpt_path):
-        gr.Info('Downloading model...')
+        gr.Info('Downloading model...',duration=10)
         download_file(url, ckpt_path)
         gr.Info('✅Model downloaded successfully.', duration=2)
     checkpoint = torch.load(ckpt_path, map_location=device)
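The only change in this hunk is an explicit duration=10 on the 'Downloading model...' toast, presumably so the notice stays on screen while the checkpoint downloads. For reference, the surrounding code is a plain download-once-then-load pattern; the sketch below is a self-contained approximation of it, with urllib.request.urlretrieve standing in for the Space's own download_file helper and the out_dir/device defaults invented for illustration.

import os
from pathlib import Path
from urllib.request import urlretrieve

import torch

def load_checkpoint(url, filename, out_dir="out", device="cpu"):
    # Download the checkpoint only on the first call, then reuse the cached file.
    ckpt_path = Path(out_dir) / filename
    ckpt_path.parent.mkdir(parents=True, exist_ok=True)
    if not os.path.exists(ckpt_path):
        # utils.py calls its own download_file() here and shows gr.Info toasts;
        # urlretrieve is a stand-in so this sketch runs outside the Space.
        urlretrieve(url, str(ckpt_path))
    return torch.load(ckpt_path, map_location=device)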
@@ -70,24 +70,24 @@ def init_model_from(url, filename):
     model.load_state_dict(state_dict)
     return model
 
-def respond(input, samples, model, encode, decode, max_new_tokens,temperature, top_k):
+def respond(input, samples, model, encode, decode, max_new_tokens,temperature, top_k):
+    input = "<bot> " + input
     x = (torch.tensor(encode(input), dtype=torch.long, device=device)[None, ...])
     with torch.no_grad():
         for k in range(samples):
+
             generated = model.generate(x, max_new_tokens, temperature=temperature, top_k=top_k)
 
             output = decode(generated[0].tolist())
+            # if input in output:
+            #     output = output.split(input)[-1].strip() # Take the part after `<input>`
 
-            match_botoutput = re.search(r'<human>(.*?)<', output)
-            match_emotion = re.search(r'<emotion>\s*(.*?)\s*<', output)
-            match_context = re.search(r'<context>\s*(.*?)\s*<', output)
+            match_botoutput = re.search(r'<human>(.*?)<', output, re.DOTALL)
             response = ''
-            emotion = ''
-            context = ''
             if match_botoutput:
                 try :
-                    response = match_botoutput.group(1).
+                    response = match_botoutput.group(1).strip()
                 except:
-                    response =
+                    response = ''
             #return response, emotion, context
-            return [input, response]
+            return [input, response, output]
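In respond(), the commit prepends a '<bot> ' tag to the prompt, adds re.DOTALL so the '<human>' capture can span newlines, strips the captured reply, removes the unused emotion/context matches, and returns the raw decoded output alongside the parsed reply. The parsing step can be checked on its own; the helper name and sample string below are illustrative, but the pattern is the one from the new code.

import re

def extract_reply(output: str) -> str:
    # Same pattern as the updated respond(): the lazy group captures everything
    # between '<human>' and the next '<', and re.DOTALL lets it cross newlines.
    match = re.search(r'<human>(.*?)<', output, re.DOTALL)
    return match.group(1).strip() if match else ''

sample = "<bot> hi there <human> Hello!\nHow are you today? <emotion> joy"
print(extract_reply(sample))  # prints the two-line reply; without re.DOTALL there would be no match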