ivnban27-ctl commited on
Commit
5e4965f
·
1 Parent(s): 5f8859a

fix on roberta input len

Browse files
README.md CHANGED
@@ -1,5 +1,5 @@
1
  ---
2
- title: Conversation Simulator
3
  emoji: 💬
4
  colorFrom: red
5
  colorTo: red
 
1
  ---
2
+ title: Conversation Simulator DEV
3
  emoji: 💬
4
  colorFrom: red
5
  colorTo: red
models/ta_models/bp_utils.py CHANGED
@@ -22,9 +22,10 @@ def bp_predict_message(context, input):
22
  context,
23
  input,
24
  truncation="only_first",
 
25
  )['input_ids']
26
  body_request = {
27
- "inputs": [tokenizer.decode(encoding)[1:-1]],
28
  "params": {
29
  "top_k": None
30
  }
@@ -38,8 +39,11 @@ def bp_predict_message(context, input):
38
  response = response.json()['predictions'][0]
39
  logger.debug(f"Raw BP prediction is {response}")
40
  return [{k:v > BP_THRESHOLD if k=="score" else v for k,v in dict_.items()} for _, dict_ in response.items() ]
41
- except:
42
- pass
 
 
 
43
 
44
  def bp_push2db(manual_confirmation=None):
45
  if manual_confirmation is None:
 
22
  context,
23
  input,
24
  truncation="only_first",
25
+ max_length = tokenizer.model_max_length - 2,
26
  )['input_ids']
27
  body_request = {
28
+ "inputs": [tokenizer.decode(encoding[1:-1])],
29
  "params": {
30
  "top_k": None
31
  }
 
39
  response = response.json()['predictions'][0]
40
  logger.debug(f"Raw BP prediction is {response}")
41
  return [{k:v > BP_THRESHOLD if k=="score" else v for k,v in dict_.items()} for _, dict_ in response.items() ]
42
+ else:
43
+ raise Exception(f"Error in response: {response.json()}")
44
+ except Exception as e:
45
+ logger.debug(f"Error in response: {e}")
46
+ st.switch_page("pages/model_loader.py")
47
 
48
  def bp_push2db(manual_confirmation=None):
49
  if manual_confirmation is None:
models/ta_models/cpc_utils.py CHANGED
@@ -22,9 +22,10 @@ def cpc_predict_message(context, input):
22
  context,
23
  input,
24
  truncation="only_first",
 
25
  )['input_ids']
26
  body_request = {
27
- "inputs": [tokenizer.decode(encoding)[1:-1]]
28
  }
29
 
30
  try:
@@ -32,8 +33,11 @@ def cpc_predict_message(context, input):
32
  response = requests.post(url=CPC_URL, headers=HEADERS, json=body_request)
33
  if response.status_code == 200:
34
  return response.json()['predictions'][0]["0"]["label"]
35
- except:
36
- pass
 
 
 
37
 
38
  def cpc_push2db(is_same):
39
  text_is_same = "SAME" if is_same else "WRONG"
 
22
  context,
23
  input,
24
  truncation="only_first",
25
+ max_length = tokenizer.model_max_length - 2,
26
  )['input_ids']
27
  body_request = {
28
+ "inputs": [tokenizer.decode(encoding[1:-1])]
29
  }
30
 
31
  try:
 
33
  response = requests.post(url=CPC_URL, headers=HEADERS, json=body_request)
34
  if response.status_code == 200:
35
  return response.json()['predictions'][0]["0"]["label"]
36
+ else:
37
+ raise Exception(f"Error in response: {response.json()}")
38
+ except Exception as e:
39
+ logger.debug(f"Error in response: {e}")
40
+ st.switch_page("pages/model_loader.py")
41
 
42
  def cpc_push2db(is_same):
43
  text_is_same = "SAME" if is_same else "WRONG"