File size: 3,710 Bytes
fce29f6
 
 
 
a95f303
6d0fa58
fce29f6
6d0fa58
fce29f6
6d0fa58
fce29f6
 
 
 
 
 
 
 
 
6d0fa58
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c183948
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
from japanese.embedding import encode_sentences, get_cadidate_embeddings
from japanese.tokenizer import extract_keyphrase_candidates
from japanese.ranker import DirectedCentralityRnak

import streamlit as st
import torch
from transformers import AutoTokenizer
from transformers import AutoModel, AutoModelForMaskedLM

def extract_keyphrase(text):
    # load model
    model = AutoModel.from_pretrained('cl-tohoku/bert-base-japanese')
    tokenizer = AutoTokenizer.from_pretrained('cl-tohoku/bert-base-japanese')

    tokens, keyphrases = extract_keyphrase_candidates(text, tokenizer)

    document_embs = encode_sentences([tokens], tokenizer, model)
    document_feats = get_cadidate_embeddings([keyphrases], document_embs, [tokens])
    ranker = DirectedCentralityRnak(document_feats, beta=0.1, lambda1=1, lambda2=0.9, alpha=1.2, processors=8)
    return ranker.extract_summary()[0]


def preparation(tokenized_text, mask):
    # [CLS],[SEP]の挿入
    tokenized_text.insert(0, '[CLS]')  # 単語リストの先頭に[CLS]を付ける
    tokenized_text.append('[SEP]')  # 単語リストの最後に[SEP]を付ける

    maru = []
    for i, word in enumerate(tokenized_text):
        if word == '。' and i != len(tokenized_text) - 2:  # 「。」の位置検出
            maru.append(i)

    for i, loc in enumerate(maru):
        tokenized_text.insert(loc + 1 + i, '[SEP]')  # 単語リストの「。」の次に[SEP]を挿入する

    # 「□」を[MASK]に置き換え 
    mask_index = []
    for index, word in enumerate(tokenized_text):
        if word == mask:  # 「□」の位置検出
            tokenized_text[index] = '[MASK]'
            mask_index.append(index)

    return tokenized_text, mask_index


def mask_prediction(text, mask_word):
    model = AutoModelForMaskedLM.from_pretrained('cl-tohoku/bert-base-japanese-whole-word-masking')
    tokenizer = AutoTokenizer.from_pretrained('cl-tohoku/bert-base-japanese')

    tokens, _ = extract_keyphrase_candidates(text, tokenizer)
    tokenized_text = tokenizer.tokenize(text)
    tokenized_text, mask_index = preparation(tokenized_text, mask_word)  # [CLS],[SEP],[MASK]の追加
    tokens = tokenizer.convert_tokens_to_ids(tokenized_text)  # IDリストに変換
    tokens_tensor = torch.tensor([tokens])  # IDテンソルに変換

    model.eval()
    with torch.no_grad():
        outputs = model(tokens_tensor)
        predictions = outputs[0]
        for i in range(len(mask_index)):
            _, predicted_indexes = torch.topk(predictions[0, mask_index[i]], k=5)
            predicted_tokens = tokenizer.convert_ids_to_tokens(predicted_indexes.tolist())
            return predicted_tokens


if __name__ == '__main__':
    text = st.text_input("origin", "ギリシア人ポリュビオスは,著書『歴史』の中で,ローマ共和政の国制(政治体制)を優れたものと評価している。彼によれば,その国制には,コンスルという王制的要素,元老院という共和制的要素,民衆という民主制的要素が存在しており,これら三者が互いに協調や牽制をしあって均衡しているというのである。ローマ人はこの政治体制を誇りとしており,それは,彼らが自らの国家を指して呼んだ「ローマの元老院と民衆」という名称からも読み取ることができる。")
    phrases = extract_keyphrase(text)
    for phrase in phrases:
        for word in phrase.split("_"):
            distracters = mask_prediction(text, word)
            if distracters is None:
                continue
            for distracter in distracters:
                st.write(text.replace(word, distracter))