Spaces:
Sleeping
Sleeping
ivnban27-ctl
commited on
Commit
·
5e4965f
1
Parent(s):
5f8859a
fix on roberta input len
Browse files- README.md +1 -1
- models/ta_models/bp_utils.py +7 -3
- models/ta_models/cpc_utils.py +7 -3
README.md
CHANGED
@@ -1,5 +1,5 @@
|
|
1 |
---
|
2 |
-
title: Conversation Simulator
|
3 |
emoji: 💬
|
4 |
colorFrom: red
|
5 |
colorTo: red
|
|
|
1 |
---
|
2 |
+
title: Conversation Simulator DEV
|
3 |
emoji: 💬
|
4 |
colorFrom: red
|
5 |
colorTo: red
|
models/ta_models/bp_utils.py
CHANGED
@@ -22,9 +22,10 @@ def bp_predict_message(context, input):
|
|
22 |
context,
|
23 |
input,
|
24 |
truncation="only_first",
|
|
|
25 |
)['input_ids']
|
26 |
body_request = {
|
27 |
-
"inputs": [tokenizer.decode(encoding
|
28 |
"params": {
|
29 |
"top_k": None
|
30 |
}
|
@@ -38,8 +39,11 @@ def bp_predict_message(context, input):
|
|
38 |
response = response.json()['predictions'][0]
|
39 |
logger.debug(f"Raw BP prediction is {response}")
|
40 |
return [{k:v > BP_THRESHOLD if k=="score" else v for k,v in dict_.items()} for _, dict_ in response.items() ]
|
41 |
-
|
42 |
-
|
|
|
|
|
|
|
43 |
|
44 |
def bp_push2db(manual_confirmation=None):
|
45 |
if manual_confirmation is None:
|
|
|
22 |
context,
|
23 |
input,
|
24 |
truncation="only_first",
|
25 |
+
max_length = tokenizer.model_max_length - 2,
|
26 |
)['input_ids']
|
27 |
body_request = {
|
28 |
+
"inputs": [tokenizer.decode(encoding[1:-1])],
|
29 |
"params": {
|
30 |
"top_k": None
|
31 |
}
|
|
|
39 |
response = response.json()['predictions'][0]
|
40 |
logger.debug(f"Raw BP prediction is {response}")
|
41 |
return [{k:v > BP_THRESHOLD if k=="score" else v for k,v in dict_.items()} for _, dict_ in response.items() ]
|
42 |
+
else:
|
43 |
+
raise Exception(f"Error in response: {response.json()}")
|
44 |
+
except Exception as e:
|
45 |
+
logger.debug(f"Error in response: {e}")
|
46 |
+
st.switch_page("pages/model_loader.py")
|
47 |
|
48 |
def bp_push2db(manual_confirmation=None):
|
49 |
if manual_confirmation is None:
|
models/ta_models/cpc_utils.py
CHANGED
@@ -22,9 +22,10 @@ def cpc_predict_message(context, input):
|
|
22 |
context,
|
23 |
input,
|
24 |
truncation="only_first",
|
|
|
25 |
)['input_ids']
|
26 |
body_request = {
|
27 |
-
"inputs": [tokenizer.decode(encoding
|
28 |
}
|
29 |
|
30 |
try:
|
@@ -32,8 +33,11 @@ def cpc_predict_message(context, input):
|
|
32 |
response = requests.post(url=CPC_URL, headers=HEADERS, json=body_request)
|
33 |
if response.status_code == 200:
|
34 |
return response.json()['predictions'][0]["0"]["label"]
|
35 |
-
|
36 |
-
|
|
|
|
|
|
|
37 |
|
38 |
def cpc_push2db(is_same):
|
39 |
text_is_same = "SAME" if is_same else "WRONG"
|
|
|
22 |
context,
|
23 |
input,
|
24 |
truncation="only_first",
|
25 |
+
max_length = tokenizer.model_max_length - 2,
|
26 |
)['input_ids']
|
27 |
body_request = {
|
28 |
+
"inputs": [tokenizer.decode(encoding[1:-1])]
|
29 |
}
|
30 |
|
31 |
try:
|
|
|
33 |
response = requests.post(url=CPC_URL, headers=HEADERS, json=body_request)
|
34 |
if response.status_code == 200:
|
35 |
return response.json()['predictions'][0]["0"]["label"]
|
36 |
+
else:
|
37 |
+
raise Exception(f"Error in response: {response.json()}")
|
38 |
+
except Exception as e:
|
39 |
+
logger.debug(f"Error in response: {e}")
|
40 |
+
st.switch_page("pages/model_loader.py")
|
41 |
|
42 |
def cpc_push2db(is_same):
|
43 |
text_is_same = "SAME" if is_same else "WRONG"
|