microvnn's picture
Upload inference.py
d0ffff2 verified
import logging
import torch
import time
import json
from typing import List
import torch.nn.functional as F
from transformers import AutoTokenizer, BertForTokenClassification
from nltk import sent_tokenize
from regex_tokenlizer import tokenize as word_tokenize
def batch_run(lst, batch_size):
"""
Function to run through a list in batches.
Parameters:
lst (list): The input list to be processed in batches.
batch_size (int): The size of each batch.
Yields:
list: A batch of the original list.
"""
for i in range(0, len(lst), batch_size):
yield lst[i : i + batch_size]
class DetectFeatures(object):
def __init__(self, tokens, input_ids, token_type_ids, attention_mask, label_masks):
self.input_ids = torch.as_tensor(input_ids, dtype=torch.long)
self.token_type_ids = torch.as_tensor(token_type_ids, dtype=torch.long)
self.attention_mask = torch.as_tensor(attention_mask, dtype=torch.long)
self.label_masks = torch.as_tensor(label_masks, dtype=torch.long)
self.tokens = tokens
# keys
self.keys = ["input_ids", "token_type_ids", "attention_mask", "label_masks"]
def __repr__(self) -> str:
return str(dict(tokens=self.tokens, features=[K for K in self.keys]))
def load_model(device, checkpoint_path) -> BertForTokenClassification:
start_time = time.time()
model = BertForTokenClassification.from_pretrained(checkpoint_path)
logging.info(f"BertForTokenClassification load [{checkpoint_path}]")
model.to(device)
model.eval()
process_time = time.time() - start_time
logging.info(f"Model Loaded in {process_time} seconds (device={device})")
return model
class BertDetecton:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model_name = "microvnn/bert-base-vn-mistakes-detector"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model: BertForTokenClassification = load_model(device, model_name)
max_seq_len = 200
batch_size = 2
threshold = 0.75
def sentence_processing(self, tokens):
encoding = self.tokenizer(
tokens,
padding="max_length",
truncation=True,
max_length=self.max_seq_len,
return_tensors="pt",
is_split_into_words=True,
)
word_ids = encoding.word_ids()
previous_word_idx = None
offset_mappings = []
for word_idx in word_ids:
if word_idx is None:
offset_mappings.append(-100)
elif word_idx != previous_word_idx:
offset_mappings.append(1)
else:
offset_mappings.append(-100)
previous_word_idx = word_idx
item = {key: val.squeeze(0) for key, val in encoding.items()}
item["label_masks"] = torch.tensor(offset_mappings)
item["tokens"] = tokens
return DetectFeatures(**item)
def pre_processing(self, doc) -> List[DetectFeatures]:
lst_ret = []
for S in doc:
item: DetectFeatures = self.sentence_processing(S)
lst_ret.append(item)
return lst_ret
def post_processing(self, logits, batch: List[DetectFeatures]):
label_masks = torch.stack([F.label_masks for F in batch])
sents = [S.tokens for S in batch]
label_masks = label_masks == 1
probs = F.softmax(logits, dim=-1)
predictions = torch.argmax(logits, dim=-1)
predictions = predictions.detach().cpu().numpy()
label_masks = label_masks.detach().cpu().numpy()
probs = probs.detach().cpu().numpy()
lst_pred = []
for prod, label, mark in zip(probs, predictions, label_masks):
lst_pred.append(
[
(round(float(P[L]), 4), L)
for P, L, M in zip(prod, label, mark)
if M == True
]
)
# assert len(lst_pred) == len(lst_pred)
result = []
for tokens, pred in zip(sents, lst_pred):
assert len(tokens) == len(pred)
result.append(
[
[TOKEN, int(LABEL), float(SCORE)]
for TOKEN, (SCORE, LABEL) in zip(tokens, pred)
]
)
return result
def forward(self, doc):
sentences = sent_tokenize(doc)
return self.forwards(sentences)
def forwards(self, sentences):
sentences = [word_tokenize(S) for S in sentences]
features = self.pre_processing(sentences)
lst_detections = []
for batch in batch_run(features, self.batch_size):
try:
input_ids = torch.stack([F.input_ids for F in batch]).to(self.device)
attention_mask = torch.stack([F.attention_mask for F in batch]).to(
self.device
)
outputs = self.model(input_ids=input_ids, attention_mask=attention_mask)
logits = outputs.logits
lst_detections.extend(self.post_processing(logits, batch))
except:
pass
return [" ".join(s) for s in sentences], lst_detections
def test_sentence():
correction: BertDetecton = BertDetecton()
doc = """HLV Mikel Arteta (trái) và Pep Guardiola trong trận Arsenal hòa Man City 2-2 trên sân Etihad, Manchester, Anh ngày 22/9. Ảnh: AP.
HLV Mikel Arteta cũng không phòng ngự tiêu cực như đồn đoán, khi cho các cầu thủ Arsenal pressing mạnh mẽ, theo một kèm một và sẵng sàng rời vị trí để gây áp lực. Lối chơi này có ưu và nhượt điểm. Nó để lộ ra khoảng trosng cho Man City ghi bàn mở tỷ số, nhưng cũng dán tiếp giúp Arsenal ghi hai bàn để dẫn ngược ngay trong hiệp lột.
Phút 9, Savinho loại bỏ Riccardo Calafiori, xộc vào trung lộ rồi chọt khe cho Erling Haaland đâm vào vòng cấm dứt điểm về góc gần, hạ thủ thàng David Raya. Đây là hệ quả từ việc Arsenal áp dụng chiến thuật pressing một kèm một. Khi Savinho thoát khỏi sự kèm cặp của người tương ứng, khoảng trống mở ra ở trun lộ.
"""
# doc = """ ám chỉ đồng độiSon Heung-min, rồi nói thêm"""
# doc = """Nếu không đáp ứng được thì đứng trách họ rời xa mình.
# """
# doc = """Lực sỹ Trần Lê Quốc Toàn: Lỡ HCV vì hơn 0,08kg."""
sentences, results = correction.forward(doc)
print(results)
import itertools
error_words = list(
itertools.chain.from_iterable(
[(W, L, S) for W, L, S in sent if L > 0] for sent in results
)
)
print("error_words=", list(error_words), len(error_words))
if __name__ == "__main__":
test_sentence()