ajitrajasekharan
commited on
Commit
·
424d29e
1
Parent(s):
a03c359
Update app.py
Browse files
app.py
CHANGED
@@ -33,15 +33,13 @@ def decode(tokenizer, pred_idx, top_clean):
|
|
33 |
return '\n'.join(tokens[:top_clean])
|
34 |
|
35 |
def encode(tokenizer, text_sentence, add_special_tokens=True):
|
36 |
-
bert_tokenizer = st.session_state['bert_tokenizer']
|
37 |
-
bert_model = st.session_state['bert_model']
|
38 |
|
39 |
text_sentence = text_sentence.replace('<mask>', tokenizer.mask_token)
|
40 |
# if <mask> is the last token, append a "." so that models dont predict punctuation.
|
41 |
#if tokenizer.mask_token == text_sentence.split()[-1]:
|
42 |
# text_sentence += ' .'
|
43 |
|
44 |
-
tokenized_text =
|
45 |
input_ids = torch.tensor([tokenizer.encode(text_sentence, add_special_tokens=add_special_tokens)])
|
46 |
if (tokenizer.mask_token in text_sentence.split()):
|
47 |
mask_idx = torch.where(input_ids == tokenizer.mask_token_id)[1].tolist()[0]
|
@@ -52,6 +50,7 @@ def encode(tokenizer, text_sentence, add_special_tokens=True):
|
|
52 |
def get_all_predictions(text_sentence, model_name,top_clean=5):
|
53 |
bert_tokenizer = st.session_state['bert_tokenizer']
|
54 |
bert_model = st.session_state['bert_model']
|
|
|
55 |
|
56 |
# ========================= BERT =================================
|
57 |
input_ids, mask_idx,tokenized_text = encode(bert_tokenizer, text_sentence)
|
|
|
33 |
return '\n'.join(tokens[:top_clean])
|
34 |
|
35 |
def encode(tokenizer, text_sentence, add_special_tokens=True):
|
|
|
|
|
36 |
|
37 |
text_sentence = text_sentence.replace('<mask>', tokenizer.mask_token)
|
38 |
# if <mask> is the last token, append a "." so that models dont predict punctuation.
|
39 |
#if tokenizer.mask_token == text_sentence.split()[-1]:
|
40 |
# text_sentence += ' .'
|
41 |
|
42 |
+
tokenized_text = tokenizer.tokenize(text_sentence)
|
43 |
input_ids = torch.tensor([tokenizer.encode(text_sentence, add_special_tokens=add_special_tokens)])
|
44 |
if (tokenizer.mask_token in text_sentence.split()):
|
45 |
mask_idx = torch.where(input_ids == tokenizer.mask_token_id)[1].tolist()[0]
|
|
|
50 |
def get_all_predictions(text_sentence, model_name,top_clean=5):
|
51 |
bert_tokenizer = st.session_state['bert_tokenizer']
|
52 |
bert_model = st.session_state['bert_model']
|
53 |
+
top_k = st.session_state['top_k']
|
54 |
|
55 |
# ========================= BERT =================================
|
56 |
input_ids, mask_idx,tokenized_text = encode(bert_tokenizer, text_sentence)
|