ajitrajasekharan
commited on
Commit
·
1c53eb1
1
Parent(s):
e32131f
Update app.py
Browse files
app.py
CHANGED
@@ -34,8 +34,8 @@ def encode(tokenizer, text_sentence, add_special_tokens=True):
|
|
34 |
if tokenizer.mask_token == text_sentence.split()[-1]:
|
35 |
text_sentence += ' .'
|
36 |
|
37 |
-
|
38 |
-
|
39 |
return input_ids, mask_idx
|
40 |
|
41 |
def get_all_predictions(text_sentence, top_clean=5):
|
@@ -48,7 +48,7 @@ def get_all_predictions(text_sentence, top_clean=5):
|
|
48 |
|
49 |
def get_bert_prediction(input_text,top_k):
|
50 |
try:
|
51 |
-
input_text += ' <mask>'
|
52 |
res = get_all_predictions(input_text, top_clean=int(top_k))
|
53 |
return res
|
54 |
except Exception as error:
|
|
|
34 |
if tokenizer.mask_token == text_sentence.split()[-1]:
|
35 |
text_sentence += ' .'
|
36 |
|
37 |
+
input_ids = torch.tensor([tokenizer.encode(text_sentence, add_special_tokens=add_special_tokens)])
|
38 |
+
mask_idx = torch.where(input_ids == tokenizer.mask_token_id)[1].tolist()[0]
|
39 |
return input_ids, mask_idx
|
40 |
|
41 |
def get_all_predictions(text_sentence, top_clean=5):
|
|
|
48 |
|
49 |
def get_bert_prediction(input_text,top_k):
|
50 |
try:
|
51 |
+
#input_text += ' <mask>'
|
52 |
res = get_all_predictions(input_text, top_clean=int(top_k))
|
53 |
return res
|
54 |
except Exception as error:
|