import logging import torch import time import json from typing import List import torch.nn.functional as F from transformers import AutoTokenizer, BertForTokenClassification from nltk import sent_tokenize from regex_tokenlizer import tokenize as word_tokenize def batch_run(lst, batch_size): """ Function to run through a list in batches. Parameters: lst (list): The input list to be processed in batches. batch_size (int): The size of each batch. Yields: list: A batch of the original list. """ for i in range(0, len(lst), batch_size): yield lst[i : i + batch_size] class DetectFeatures(object): def __init__(self, tokens, input_ids, token_type_ids, attention_mask, label_masks): self.input_ids = torch.as_tensor(input_ids, dtype=torch.long) self.token_type_ids = torch.as_tensor(token_type_ids, dtype=torch.long) self.attention_mask = torch.as_tensor(attention_mask, dtype=torch.long) self.label_masks = torch.as_tensor(label_masks, dtype=torch.long) self.tokens = tokens # keys self.keys = ["input_ids", "token_type_ids", "attention_mask", "label_masks"] def __repr__(self) -> str: return str(dict(tokens=self.tokens, features=[K for K in self.keys])) def load_model(device, checkpoint_path) -> BertForTokenClassification: start_time = time.time() model = BertForTokenClassification.from_pretrained(checkpoint_path) logging.info(f"BertForTokenClassification load [{checkpoint_path}]") model.to(device) model.eval() process_time = time.time() - start_time logging.info(f"Model Loaded in {process_time} seconds (device={device})") return model class BertDetecton: device = torch.device("cuda" if torch.cuda.is_available() else "cpu") model_name = "microvnn/bert-base-vn-mistakes-detector" tokenizer = AutoTokenizer.from_pretrained(model_name) model: BertForTokenClassification = load_model(device, model_name) max_seq_len = 200 batch_size = 2 threshold = 0.75 def sentence_processing(self, tokens): encoding = self.tokenizer( tokens, padding="max_length", truncation=True, max_length=self.max_seq_len, return_tensors="pt", is_split_into_words=True, ) word_ids = encoding.word_ids() previous_word_idx = None offset_mappings = [] for word_idx in word_ids: if word_idx is None: offset_mappings.append(-100) elif word_idx != previous_word_idx: offset_mappings.append(1) else: offset_mappings.append(-100) previous_word_idx = word_idx item = {key: val.squeeze(0) for key, val in encoding.items()} item["label_masks"] = torch.tensor(offset_mappings) item["tokens"] = tokens return DetectFeatures(**item) def pre_processing(self, doc) -> List[DetectFeatures]: lst_ret = [] for S in doc: item: DetectFeatures = self.sentence_processing(S) lst_ret.append(item) return lst_ret def post_processing(self, logits, batch: List[DetectFeatures]): label_masks = torch.stack([F.label_masks for F in batch]) sents = [S.tokens for S in batch] label_masks = label_masks == 1 probs = F.softmax(logits, dim=-1) predictions = torch.argmax(logits, dim=-1) predictions = predictions.detach().cpu().numpy() label_masks = label_masks.detach().cpu().numpy() probs = probs.detach().cpu().numpy() lst_pred = [] for prod, label, mark in zip(probs, predictions, label_masks): lst_pred.append( [ (round(float(P[L]), 4), L) for P, L, M in zip(prod, label, mark) if M == True ] ) # assert len(lst_pred) == len(lst_pred) result = [] for tokens, pred in zip(sents, lst_pred): assert len(tokens) == len(pred) result.append( [ [TOKEN, int(LABEL), float(SCORE)] for TOKEN, (SCORE, LABEL) in zip(tokens, pred) ] ) return result def forward(self, doc): sentences = sent_tokenize(doc) return self.forwards(sentences) def forwards(self, sentences): sentences = [word_tokenize(S) for S in sentences] features = self.pre_processing(sentences) lst_detections = [] for batch in batch_run(features, self.batch_size): try: input_ids = torch.stack([F.input_ids for F in batch]).to(self.device) attention_mask = torch.stack([F.attention_mask for F in batch]).to( self.device ) outputs = self.model(input_ids=input_ids, attention_mask=attention_mask) logits = outputs.logits lst_detections.extend(self.post_processing(logits, batch)) except: pass return [" ".join(s) for s in sentences], lst_detections def test_sentence(): correction: BertDetecton = BertDetecton() doc = """HLV Mikel Arteta (trái) và Pep Guardiola trong trận Arsenal hòa Man City 2-2 trên sân Etihad, Manchester, Anh ngày 22/9. Ảnh: AP. HLV Mikel Arteta cũng không phòng ngự tiêu cực như đồn đoán, khi cho các cầu thủ Arsenal pressing mạnh mẽ, theo một kèm một và sẵng sàng rời vị trí để gây áp lực. Lối chơi này có ưu và nhượt điểm. Nó để lộ ra khoảng trosng cho Man City ghi bàn mở tỷ số, nhưng cũng dán tiếp giúp Arsenal ghi hai bàn để dẫn ngược ngay trong hiệp lột. Phút 9, Savinho loại bỏ Riccardo Calafiori, xộc vào trung lộ rồi chọt khe cho Erling Haaland đâm vào vòng cấm dứt điểm về góc gần, hạ thủ thàng David Raya. Đây là hệ quả từ việc Arsenal áp dụng chiến thuật pressing một kèm một. Khi Savinho thoát khỏi sự kèm cặp của người tương ứng, khoảng trống mở ra ở trun lộ. """ # doc = """ ám chỉ đồng độiSon Heung-min, rồi nói thêm""" # doc = """Nếu không đáp ứng được thì đứng trách họ rời xa mình. # """ # doc = """Lực sỹ Trần Lê Quốc Toàn: Lỡ HCV vì hơn 0,08kg.""" sentences, results = correction.forward(doc) print(results) import itertools error_words = list( itertools.chain.from_iterable( [(W, L, S) for W, L, S in sent if L > 0] for sent in results ) ) print("error_words=", list(error_words), len(error_words)) if __name__ == "__main__": test_sentence()