microvnn commited on
Commit
d0ffff2
·
verified ·
1 Parent(s): ac247d7

Upload inference.py

Browse files
Files changed (1) hide show
  1. inference.py +170 -0
inference.py ADDED
@@ -0,0 +1,170 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import torch
3
+ import time
4
+ import json
5
+ from typing import List
6
+ import torch.nn.functional as F
7
+ from transformers import AutoTokenizer, BertForTokenClassification
8
+ from nltk import sent_tokenize
9
+ from regex_tokenlizer import tokenize as word_tokenize
10
+
11
+
12
+ def batch_run(lst, batch_size):
13
+ """
14
+ Function to run through a list in batches.
15
+
16
+ Parameters:
17
+ lst (list): The input list to be processed in batches.
18
+ batch_size (int): The size of each batch.
19
+
20
+ Yields:
21
+ list: A batch of the original list.
22
+ """
23
+ for i in range(0, len(lst), batch_size):
24
+ yield lst[i : i + batch_size]
25
+
26
+
27
+ class DetectFeatures(object):
28
+ def __init__(self, tokens, input_ids, token_type_ids, attention_mask, label_masks):
29
+ self.input_ids = torch.as_tensor(input_ids, dtype=torch.long)
30
+ self.token_type_ids = torch.as_tensor(token_type_ids, dtype=torch.long)
31
+ self.attention_mask = torch.as_tensor(attention_mask, dtype=torch.long)
32
+ self.label_masks = torch.as_tensor(label_masks, dtype=torch.long)
33
+ self.tokens = tokens
34
+ # keys
35
+ self.keys = ["input_ids", "token_type_ids", "attention_mask", "label_masks"]
36
+
37
+ def __repr__(self) -> str:
38
+ return str(dict(tokens=self.tokens, features=[K for K in self.keys]))
39
+
40
+
41
+ def load_model(device, checkpoint_path) -> BertForTokenClassification:
42
+ start_time = time.time()
43
+ model = BertForTokenClassification.from_pretrained(checkpoint_path)
44
+ logging.info(f"BertForTokenClassification load [{checkpoint_path}]")
45
+ model.to(device)
46
+ model.eval()
47
+ process_time = time.time() - start_time
48
+ logging.info(f"Model Loaded in {process_time} seconds (device={device})")
49
+ return model
50
+
51
+
52
+ class BertDetecton:
53
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
54
+ model_name = "microvnn/bert-base-vn-mistakes-detector"
55
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
56
+ model: BertForTokenClassification = load_model(device, model_name)
57
+ max_seq_len = 200
58
+ batch_size = 2
59
+ threshold = 0.75
60
+
61
+ def sentence_processing(self, tokens):
62
+ encoding = self.tokenizer(
63
+ tokens,
64
+ padding="max_length",
65
+ truncation=True,
66
+ max_length=self.max_seq_len,
67
+ return_tensors="pt",
68
+ is_split_into_words=True,
69
+ )
70
+ word_ids = encoding.word_ids()
71
+ previous_word_idx = None
72
+ offset_mappings = []
73
+ for word_idx in word_ids:
74
+ if word_idx is None:
75
+ offset_mappings.append(-100)
76
+ elif word_idx != previous_word_idx:
77
+ offset_mappings.append(1)
78
+ else:
79
+ offset_mappings.append(-100)
80
+ previous_word_idx = word_idx
81
+ item = {key: val.squeeze(0) for key, val in encoding.items()}
82
+ item["label_masks"] = torch.tensor(offset_mappings)
83
+ item["tokens"] = tokens
84
+ return DetectFeatures(**item)
85
+
86
+ def pre_processing(self, doc) -> List[DetectFeatures]:
87
+ lst_ret = []
88
+ for S in doc:
89
+ item: DetectFeatures = self.sentence_processing(S)
90
+ lst_ret.append(item)
91
+ return lst_ret
92
+
93
+ def post_processing(self, logits, batch: List[DetectFeatures]):
94
+ label_masks = torch.stack([F.label_masks for F in batch])
95
+ sents = [S.tokens for S in batch]
96
+ label_masks = label_masks == 1
97
+ probs = F.softmax(logits, dim=-1)
98
+ predictions = torch.argmax(logits, dim=-1)
99
+ predictions = predictions.detach().cpu().numpy()
100
+ label_masks = label_masks.detach().cpu().numpy()
101
+ probs = probs.detach().cpu().numpy()
102
+ lst_pred = []
103
+ for prod, label, mark in zip(probs, predictions, label_masks):
104
+ lst_pred.append(
105
+ [
106
+ (round(float(P[L]), 4), L)
107
+ for P, L, M in zip(prod, label, mark)
108
+ if M == True
109
+ ]
110
+ )
111
+ # assert len(lst_pred) == len(lst_pred)
112
+ result = []
113
+ for tokens, pred in zip(sents, lst_pred):
114
+ assert len(tokens) == len(pred)
115
+ result.append(
116
+ [
117
+ [TOKEN, int(LABEL), float(SCORE)]
118
+ for TOKEN, (SCORE, LABEL) in zip(tokens, pred)
119
+ ]
120
+ )
121
+ return result
122
+
123
+ def forward(self, doc):
124
+ sentences = sent_tokenize(doc)
125
+ return self.forwards(sentences)
126
+
127
+ def forwards(self, sentences):
128
+ sentences = [word_tokenize(S) for S in sentences]
129
+ features = self.pre_processing(sentences)
130
+ lst_detections = []
131
+ for batch in batch_run(features, self.batch_size):
132
+ try:
133
+ input_ids = torch.stack([F.input_ids for F in batch]).to(self.device)
134
+ attention_mask = torch.stack([F.attention_mask for F in batch]).to(
135
+ self.device
136
+ )
137
+ outputs = self.model(input_ids=input_ids, attention_mask=attention_mask)
138
+ logits = outputs.logits
139
+ lst_detections.extend(self.post_processing(logits, batch))
140
+ except:
141
+ pass
142
+ return [" ".join(s) for s in sentences], lst_detections
143
+
144
+
145
+ def test_sentence():
146
+
147
+ correction: BertDetecton = BertDetecton()
148
+ doc = """HLV Mikel Arteta (trái) và Pep Guardiola trong trận Arsenal hòa Man City 2-2 trên sân Etihad, Manchester, Anh ngày 22/9. Ảnh: AP.
149
+
150
+ HLV Mikel Arteta cũng không phòng ngự tiêu cực như đồn đoán, khi cho các cầu thủ Arsenal pressing mạnh mẽ, theo một kèm một và sẵng sàng rời vị trí để gây áp lực. Lối chơi này có ưu và nhượt điểm. Nó để lộ ra khoảng trosng cho Man City ghi bàn mở tỷ số, nhưng cũng dán tiếp giúp Arsenal ghi hai bàn để dẫn ngược ngay trong hiệp lột.
151
+
152
+ Phút 9, Savinho loại bỏ Riccardo Calafiori, xộc vào trung lộ rồi chọt khe cho Erling Haaland đâm vào vòng cấm dứt điểm về góc gần, hạ thủ thàng David Raya. Đây là hệ quả từ việc Arsenal áp dụng chiến thuật pressing một kèm một. Khi Savinho thoát khỏi sự kèm cặp của người tương ứng, khoảng trống mở ra ở trun lộ.
153
+ """
154
+ # doc = """ ám chỉ đồng độiSon Heung-min, rồi nói thêm"""
155
+ # doc = """Nếu không đáp ứng được thì đứng trách họ rời xa mình.
156
+ # """
157
+ # doc = """Lực sỹ Trần Lê Quốc Toàn: Lỡ HCV vì hơn 0,08kg."""
158
+ sentences, results = correction.forward(doc)
159
+ print(results)
160
+ import itertools
161
+
162
+ error_words = list(
163
+ itertools.chain.from_iterable(
164
+ [(W, L, S) for W, L, S in sent if L > 0] for sent in results
165
+ )
166
+ )
167
+ print("error_words=", list(error_words), len(error_words))
168
+
169
+ if __name__ == "__main__":
170
+ test_sentence()