ajitrajasekharan commited on
Commit
1c53eb1
·
1 Parent(s): e32131f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +3 -3
app.py CHANGED
@@ -34,8 +34,8 @@ def encode(tokenizer, text_sentence, add_special_tokens=True):
34
  if tokenizer.mask_token == text_sentence.split()[-1]:
35
  text_sentence += ' .'
36
 
37
- input_ids = torch.tensor([tokenizer.encode(text_sentence, add_special_tokens=add_special_tokens)])
38
- mask_idx = torch.where(input_ids == tokenizer.mask_token_id)[1].tolist()[0]
39
  return input_ids, mask_idx
40
 
41
  def get_all_predictions(text_sentence, top_clean=5):
@@ -48,7 +48,7 @@ def get_all_predictions(text_sentence, top_clean=5):
48
 
49
  def get_bert_prediction(input_text,top_k):
50
  try:
51
- input_text += ' <mask>'
52
  res = get_all_predictions(input_text, top_clean=int(top_k))
53
  return res
54
  except Exception as error:
 
34
  if tokenizer.mask_token == text_sentence.split()[-1]:
35
  text_sentence += ' .'
36
 
37
+ input_ids = torch.tensor([tokenizer.encode(text_sentence, add_special_tokens=add_special_tokens)])
38
+ mask_idx = torch.where(input_ids == tokenizer.mask_token_id)[1].tolist()[0]
39
  return input_ids, mask_idx
40
 
41
  def get_all_predictions(text_sentence, top_clean=5):
 
48
 
49
  def get_bert_prediction(input_text,top_k):
50
  try:
51
+ #input_text += ' <mask>'
52
  res = get_all_predictions(input_text, top_clean=int(top_k))
53
  return res
54
  except Exception as error: