|
import logging
|
|
import torch
|
|
import time
|
|
import json
|
|
from typing import List
|
|
import torch.nn.functional as F
|
|
from transformers import AutoTokenizer, BertForTokenClassification
|
|
from nltk import sent_tokenize
|
|
from regex_tokenlizer import tokenize as word_tokenize
|
|
|
|
|
|
def batch_run(lst, batch_size):
|
|
"""
|
|
Function to run through a list in batches.
|
|
|
|
Parameters:
|
|
lst (list): The input list to be processed in batches.
|
|
batch_size (int): The size of each batch.
|
|
|
|
Yields:
|
|
list: A batch of the original list.
|
|
"""
|
|
for i in range(0, len(lst), batch_size):
|
|
yield lst[i : i + batch_size]
|
|
|
|
|
|
class DetectFeatures(object):
|
|
def __init__(self, tokens, input_ids, token_type_ids, attention_mask, label_masks):
|
|
self.input_ids = torch.as_tensor(input_ids, dtype=torch.long)
|
|
self.token_type_ids = torch.as_tensor(token_type_ids, dtype=torch.long)
|
|
self.attention_mask = torch.as_tensor(attention_mask, dtype=torch.long)
|
|
self.label_masks = torch.as_tensor(label_masks, dtype=torch.long)
|
|
self.tokens = tokens
|
|
|
|
self.keys = ["input_ids", "token_type_ids", "attention_mask", "label_masks"]
|
|
|
|
def __repr__(self) -> str:
|
|
return str(dict(tokens=self.tokens, features=[K for K in self.keys]))
|
|
|
|
|
|
def load_model(device, checkpoint_path) -> BertForTokenClassification:
|
|
start_time = time.time()
|
|
model = BertForTokenClassification.from_pretrained(checkpoint_path)
|
|
logging.info(f"BertForTokenClassification load [{checkpoint_path}]")
|
|
model.to(device)
|
|
model.eval()
|
|
process_time = time.time() - start_time
|
|
logging.info(f"Model Loaded in {process_time} seconds (device={device})")
|
|
return model
|
|
|
|
|
|
class BertDetecton:
|
|
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
|
model_name = "microvnn/bert-base-vn-mistakes-detector"
|
|
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
|
model: BertForTokenClassification = load_model(device, model_name)
|
|
max_seq_len = 200
|
|
batch_size = 2
|
|
threshold = 0.75
|
|
|
|
def sentence_processing(self, tokens):
|
|
encoding = self.tokenizer(
|
|
tokens,
|
|
padding="max_length",
|
|
truncation=True,
|
|
max_length=self.max_seq_len,
|
|
return_tensors="pt",
|
|
is_split_into_words=True,
|
|
)
|
|
word_ids = encoding.word_ids()
|
|
previous_word_idx = None
|
|
offset_mappings = []
|
|
for word_idx in word_ids:
|
|
if word_idx is None:
|
|
offset_mappings.append(-100)
|
|
elif word_idx != previous_word_idx:
|
|
offset_mappings.append(1)
|
|
else:
|
|
offset_mappings.append(-100)
|
|
previous_word_idx = word_idx
|
|
item = {key: val.squeeze(0) for key, val in encoding.items()}
|
|
item["label_masks"] = torch.tensor(offset_mappings)
|
|
item["tokens"] = tokens
|
|
return DetectFeatures(**item)
|
|
|
|
def pre_processing(self, doc) -> List[DetectFeatures]:
|
|
lst_ret = []
|
|
for S in doc:
|
|
item: DetectFeatures = self.sentence_processing(S)
|
|
lst_ret.append(item)
|
|
return lst_ret
|
|
|
|
def post_processing(self, logits, batch: List[DetectFeatures]):
|
|
label_masks = torch.stack([F.label_masks for F in batch])
|
|
sents = [S.tokens for S in batch]
|
|
label_masks = label_masks == 1
|
|
probs = F.softmax(logits, dim=-1)
|
|
predictions = torch.argmax(logits, dim=-1)
|
|
predictions = predictions.detach().cpu().numpy()
|
|
label_masks = label_masks.detach().cpu().numpy()
|
|
probs = probs.detach().cpu().numpy()
|
|
lst_pred = []
|
|
for prod, label, mark in zip(probs, predictions, label_masks):
|
|
lst_pred.append(
|
|
[
|
|
(round(float(P[L]), 4), L)
|
|
for P, L, M in zip(prod, label, mark)
|
|
if M == True
|
|
]
|
|
)
|
|
|
|
result = []
|
|
for tokens, pred in zip(sents, lst_pred):
|
|
assert len(tokens) == len(pred)
|
|
result.append(
|
|
[
|
|
[TOKEN, int(LABEL), float(SCORE)]
|
|
for TOKEN, (SCORE, LABEL) in zip(tokens, pred)
|
|
]
|
|
)
|
|
return result
|
|
|
|
def forward(self, doc):
|
|
sentences = sent_tokenize(doc)
|
|
return self.forwards(sentences)
|
|
|
|
def forwards(self, sentences):
|
|
sentences = [word_tokenize(S) for S in sentences]
|
|
features = self.pre_processing(sentences)
|
|
lst_detections = []
|
|
for batch in batch_run(features, self.batch_size):
|
|
try:
|
|
input_ids = torch.stack([F.input_ids for F in batch]).to(self.device)
|
|
attention_mask = torch.stack([F.attention_mask for F in batch]).to(
|
|
self.device
|
|
)
|
|
outputs = self.model(input_ids=input_ids, attention_mask=attention_mask)
|
|
logits = outputs.logits
|
|
lst_detections.extend(self.post_processing(logits, batch))
|
|
except:
|
|
pass
|
|
return [" ".join(s) for s in sentences], lst_detections
|
|
|
|
|
|
def test_sentence():
|
|
|
|
correction: BertDetecton = BertDetecton()
|
|
doc = """HLV Mikel Arteta (trái) và Pep Guardiola trong trận Arsenal hòa Man City 2-2 trên sân Etihad, Manchester, Anh ngày 22/9. Ảnh: AP.
|
|
|
|
HLV Mikel Arteta cũng không phòng ngự tiêu cực như đồn đoán, khi cho các cầu thủ Arsenal pressing mạnh mẽ, theo một kèm một và sẵng sàng rời vị trí để gây áp lực. Lối chơi này có ưu và nhượt điểm. Nó để lộ ra khoảng trosng cho Man City ghi bàn mở tỷ số, nhưng cũng dán tiếp giúp Arsenal ghi hai bàn để dẫn ngược ngay trong hiệp lột.
|
|
|
|
Phút 9, Savinho loại bỏ Riccardo Calafiori, xộc vào trung lộ rồi chọt khe cho Erling Haaland đâm vào vòng cấm dứt điểm về góc gần, hạ thủ thàng David Raya. Đây là hệ quả từ việc Arsenal áp dụng chiến thuật pressing một kèm một. Khi Savinho thoát khỏi sự kèm cặp của người tương ứng, khoảng trống mở ra ở trun lộ.
|
|
"""
|
|
|
|
|
|
|
|
|
|
sentences, results = correction.forward(doc)
|
|
print(results)
|
|
import itertools
|
|
|
|
error_words = list(
|
|
itertools.chain.from_iterable(
|
|
[(W, L, S) for W, L, S in sent if L > 0] for sent in results
|
|
)
|
|
)
|
|
print("error_words=", list(error_words), len(error_words))
|
|
|
|
if __name__ == "__main__":
|
|
test_sentence()
|
|
|