ajitrajasekharan committed
Commit 424d29e · 1 Parent(s): a03c359

Update app.py

Files changed (1):
  1. app.py +2 -3
app.py CHANGED
@@ -33,15 +33,13 @@ def decode(tokenizer, pred_idx, top_clean):
     return '\n'.join(tokens[:top_clean])
 
 def encode(tokenizer, text_sentence, add_special_tokens=True):
-    bert_tokenizer = st.session_state['bert_tokenizer']
-    bert_model = st.session_state['bert_model']
 
     text_sentence = text_sentence.replace('<mask>', tokenizer.mask_token)
     # if <mask> is the last token, append a "." so that models dont predict punctuation.
     #if tokenizer.mask_token == text_sentence.split()[-1]:
     #    text_sentence += ' .'
 
-    tokenized_text = bert_tokenizer.tokenize(text_sentence)
+    tokenized_text = tokenizer.tokenize(text_sentence)
     input_ids = torch.tensor([tokenizer.encode(text_sentence, add_special_tokens=add_special_tokens)])
     if (tokenizer.mask_token in text_sentence.split()):
         mask_idx = torch.where(input_ids == tokenizer.mask_token_id)[1].tolist()[0]
@@ -52,6 +50,7 @@ def encode(tokenizer, text_sentence, add_special_tokens=True):
 def get_all_predictions(text_sentence, model_name,top_clean=5):
     bert_tokenizer = st.session_state['bert_tokenizer']
     bert_model = st.session_state['bert_model']
+    top_k = st.session_state['top_k']
 
     # ========================= BERT =================================
     input_ids, mask_idx,tokenized_text = encode(bert_tokenizer, text_sentence)
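
Net effect of the commit: encode() no longer pulls a BERT tokenizer and model from Streamlit session state but tokenizes with the tokenizer argument it receives, presumably so it works with whichever tokenizer the caller passes in; get_all_predictions() additionally reads top_k from session state. Below is a minimal standalone sketch of the fixed encode() path, not the full app. The model name "bert-base-cased", the example sentence, and the assumption that "<mask>" is always present in the input are illustrative, not from the commit.

import torch
from transformers import AutoTokenizer

def encode(tokenizer, text_sentence, add_special_tokens=True):
    # Map the generic "<mask>" placeholder to this tokenizer's own mask token.
    text_sentence = text_sentence.replace('<mask>', tokenizer.mask_token)
    # Post-fix behavior: tokenize with the tokenizer argument,
    # not a BERT tokenizer fetched from st.session_state.
    tokenized_text = tokenizer.tokenize(text_sentence)
    input_ids = torch.tensor([tokenizer.encode(text_sentence, add_special_tokens=add_special_tokens)])
    # Locate the first mask position among the encoded ids
    # (this sketch assumes a mask token is present).
    mask_idx = torch.where(input_ids == tokenizer.mask_token_id)[1].tolist()[0]
    return input_ids, mask_idx, tokenized_text

tokenizer = AutoTokenizer.from_pretrained('bert-base-cased')
input_ids, mask_idx, tokenized_text = encode(tokenizer, 'Paris is the <mask> of France.')
print(mask_idx, tokenized_text)

Passing the tokenizer explicitly keeps encode() model-agnostic; after this change, only get_all_predictions() touches Streamlit session state.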