Spaces:
Runtime error
Runtime error
from japanese.embedding import encode_sentences, get_cadidate_embeddings | |
from japanese.tokenizer import extract_keyphrase_candidates | |
from japanese.ranker import DirectedCentralityRnak | |
import streamlit as st | |
import torch | |
from transformers import AutoTokenizer | |
from transformers import AutoModel, AutoModelForMaskedLM | |
def extract_keyphrase(text): | |
# load model | |
model = AutoModel.from_pretrained('cl-tohoku/bert-base-japanese') | |
tokenizer = AutoTokenizer.from_pretrained('cl-tohoku/bert-base-japanese') | |
tokens, keyphrases = extract_keyphrase_candidates(text, tokenizer) | |
document_embs = encode_sentences([tokens], tokenizer, model) | |
document_feats = get_cadidate_embeddings([keyphrases], document_embs, [tokens]) | |
ranker = DirectedCentralityRnak(document_feats, beta=0.1, lambda1=1, lambda2=0.9, alpha=1.2, processors=8) | |
return ranker.extract_summary()[0] | |
def preparation(tokenized_text, mask): | |
# [CLS],[SEP]の挿入 | |
tokenized_text.insert(0, '[CLS]') # 単語リストの先頭に[CLS]を付ける | |
tokenized_text.append('[SEP]') # 単語リストの最後に[SEP]を付ける | |
maru = [] | |
for i, word in enumerate(tokenized_text): | |
if word == '。' and i != len(tokenized_text) - 2: # 「。」の位置検出 | |
maru.append(i) | |
for i, loc in enumerate(maru): | |
tokenized_text.insert(loc + 1 + i, '[SEP]') # 単語リストの「。」の次に[SEP]を挿入する | |
# 「□」を[MASK]に置き換え | |
mask_index = [] | |
for index, word in enumerate(tokenized_text): | |
if word == mask: # 「□」の位置検出 | |
tokenized_text[index] = '[MASK]' | |
mask_index.append(index) | |
return tokenized_text, mask_index | |
def mask_prediction(text, mask_word): | |
model = AutoModelForMaskedLM.from_pretrained('cl-tohoku/bert-base-japanese-whole-word-masking') | |
tokenizer = AutoTokenizer.from_pretrained('cl-tohoku/bert-base-japanese') | |
tokens, _ = extract_keyphrase_candidates(text, tokenizer) | |
tokenized_text = tokenizer.tokenize(text) | |
tokenized_text, mask_index = preparation(tokenized_text, mask_word) # [CLS],[SEP],[MASK]の追加 | |
tokens = tokenizer.convert_tokens_to_ids(tokenized_text) # IDリストに変換 | |
tokens_tensor = torch.tensor([tokens]) # IDテンソルに変換 | |
model.eval() | |
with torch.no_grad(): | |
outputs = model(tokens_tensor) | |
predictions = outputs[0] | |
for i in range(len(mask_index)): | |
_, predicted_indexes = torch.topk(predictions[0, mask_index[i]], k=5) | |
predicted_tokens = tokenizer.convert_ids_to_tokens(predicted_indexes.tolist()) | |
return predicted_tokens | |
if __name__ == '__main__': | |
text = st.text_input("origin", "ギリシア人ポリュビオスは,著書『歴史』の中で,ローマ共和政の国制(政治体制)を優れたものと評価している。彼によれば,その国制には,コンスルという王制的要素,元老院という共和制的要素,民衆という民主制的要素が存在しており,これら三者が互いに協調や牽制をしあって均衡しているというのである。ローマ人はこの政治体制を誇りとしており,それは,彼らが自らの国家を指して呼んだ「ローマの元老院と民衆」という名称からも読み取ることができる。") | |
phrases = extract_keyphrase(text) | |
for phrase in phrases: | |
for word in phrase.split("_"): | |
distracters = mask_prediction(text, word) | |
if distracters is None: | |
continue | |
for distracter in distracters: | |
st.write(text.replace(word, distracter)) | |