import torch
import time
from torch.utils.data import DataLoader
from datetime import datetime as dt
from params import *  # provides DEVICE, VALID_BATCH_SIZE, BUCKET_SAMPLING, among others
from models.model import ModelWrapper
from utils.metrics import get_mned_metric_from_TruePredict, get_metric_from_TrueWrongPredictV3
from utils.logger import get_logger
from models.sampler import RandomBatchSampler, BucketBatchSampler
from termcolor import colored
import re

class Corrector:
    def __init__(self, model_wrapper: ModelWrapper):
        self.model_name = model_wrapper.model_name
        self.model = model_wrapper.model
        self.model_wrapper = model_wrapper
        self.logger = get_logger("./log/test.log")
        self.device = DEVICE
        self.model.to(self.device)
        self.logger.log(f"Device: {self.device}")
        self.logger.log("Loaded model")

    def correct_transfomer_with_tr(self, batch, num_beams=2):
        """Correct a single raw text string and post-process the detokenized output."""
        correction = dict()
        original = batch
        # Split into words and isolated punctuation so the input matches the training tokenization.
        splits = re.findall(r"\w\S+\w|\w+|[^\w\s]", batch)
        batch = " ".join(splits)
        with torch.no_grad():
            self.model.eval()
            batch_infer_start = time.time()
            batch = self.model_wrapper.collator.collate([[None, batch, None, None]], type="correct")
            result = self.model.inference(batch['batch_src'], num_beams=num_beams,
                                          tokenAligner=self.model_wrapper.collator.tokenAligner)
            correction['predict_text'] = result
            correction['noised_text'] = batch['noised_texts']
            correction['original_text'] = original
            total_infer_time = time.time() - batch_infer_start
            correction['time'] = total_infer_time
        # Detokenize: remove the spaces inserted around punctuation, brackets, quotes and "%".
        t = re.sub(r"(\s*)([.,:?!;])(\s*)", r"\2\3", correction['predict_text'][0])
        t = re.sub(r"((?P<parenthesis>\()\s)", r"\g<parenthesis>", t)
        t = re.sub(r"(\s(?P<parenthesis>\)))", r"\g<parenthesis>", t)
        t = re.sub(r"((?P<bracket>\[)\s)", r"\g<bracket>", t)
        t = re.sub(r"(\s(?P<bracket>\]))", r"\g<bracket>", t)
        t = re.sub(r"([\'\"])\s(.*)\s([\'\"])", r"\1\2\3", t)
        correction['predict_text'] = re.sub(r"\s(%)", r"\1", t)
        return correction

    def _get_transfomer_with_tr_generations(self, batch, num_beams=2):
        """Run beam-search inference on an already-collated batch and record the inference time."""
        correction = dict()
        with torch.no_grad():
            self.model.eval()
            batch_infer_start = time.time()
            result = self.model.inference(batch['batch_src'], num_beams=num_beams,
                                          tokenAligner=self.model_wrapper.collator.tokenAligner)
            correction['predict_text'] = result
            correction['noised_text'] = batch['noised_texts']
            total_infer_time = time.time() - batch_infer_start
            correction['time'] = total_infer_time
        return correction

    def step(self, batch, num_beams=2):
        outputs = self._get_transfomer_with_tr_generations(batch, num_beams)
        batch_predictions = outputs['predict_text']
        batch_label_texts = batch['label_texts']
        batch_noised_texts = batch['noised_texts']
        return batch_predictions, batch_noised_texts, batch_label_texts

    def _evaluation_loop_autoregressive(self, data_loader, num_beams=2):
        TP, FP, FN = 0, 0, 0
        MNED = 0.0
        O_MNED = 0.0
        total_infer_time = 0.0
        twp_logger = get_logger(f"./log/true_wrong_predict{time.time()}.log")
        with torch.no_grad():
            self.model.eval()
            for step, batch in enumerate(data_loader):
                batch_infer_start = time.time()
                batch_predictions, batch_noised_texts, batch_label_texts = \
                    self.step(batch, num_beams=num_beams)
                batch_infer_time = time.time() - batch_infer_start
                # Accumulate detection counts plus the MNED of the predictions and of the
                # original noised texts (O_MNED) against the labels.
                _TP, _FP, _FN = get_metric_from_TrueWrongPredictV3(
                    batch_label_texts, batch_noised_texts, batch_predictions,
                    self.model_wrapper.tokenAligner.vocab, twp_logger)
                TP += _TP
                FP += _FP
                FN += _FN
                _MNED = get_mned_metric_from_TruePredict(batch_label_texts, batch_predictions)
                MNED += _MNED
                _O_MNED = get_mned_metric_from_TruePredict(batch_label_texts, batch_noised_texts)
                O_MNED += _O_MNED
                info = '{} - Evaluate - iter: {:08d}/{:08d} - TP: {} - FP: {} - FN: {} - _MNED: {:.5f} - _O_MNED: {:.5f} - {} time: {:.2f}s'.format(
                    dt.now(),
                    step,
                    self.test_iters,
                    _TP,
                    _FP,
                    _FN,
                    _MNED,
                    _O_MNED,
                    self.device,
                    batch_infer_time)
                self.logger.log(info)
                torch.cuda.empty_cache()
                total_infer_time += time.time() - batch_infer_start
        return total_infer_time, TP, FP, FN, MNED / len(data_loader), O_MNED / len(data_loader)

    def evaluate(self, dataset, beams: int = None):
        def test_collate_wrapper(batch):
            return self.model_wrapper.collator.collate(batch, type="test")

        if not BUCKET_SAMPLING:
            self.test_sampler = RandomBatchSampler(dataset, VALID_BATCH_SIZE, shuffle=False)
        else:
            self.test_sampler = BucketBatchSampler(dataset, shuffle=True)
        data_loader = DataLoader(dataset=dataset, batch_sampler=self.test_sampler,
                                 collate_fn=test_collate_wrapper)
        self.test_iters = len(data_loader)
        assert beams is not None
        total_infer_time, TP, FP, FN, MNED, O_MNED = self._evaluation_loop_autoregressive(data_loader, num_beams=beams)
        self.logger.log("Total inference time for this data is: {:.4f} secs".format(total_infer_time))
        self.logger.log("###############################################")
        info = f"Metrics for Auto-Regressive with Beam Search number {beams}"
        self.logger.log(colored(info, "green"))
        dc_TP = TP
        dc_FP = FP
        dc_FN = FN
        # The small epsilon guards against division by zero, matching the 1e-8 already used for F1.
        dc_precision = dc_TP / (dc_TP + dc_FP + 1e-8)
        dc_recall = dc_TP / (dc_TP + dc_FN + 1e-8)
        dc_F1 = 2. * dc_precision * dc_recall / (dc_precision + dc_recall + 1e-8)
        self.logger.log(f"TP: {TP}. FP: {FP}. FN: {FN}")
        self.logger.log(f"Precision: {dc_precision}")
        self.logger.log(f"Recall: {dc_recall}")
        self.logger.log(f"F1: {dc_F1}")
        self.logger.log(f"MNED: {MNED}")
        self.logger.log(f"O_MNED: {O_MNED}")
        return
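

# Minimal usage sketch (added for illustration, not part of the original file). The exact
# ModelWrapper constructor arguments and the dataset class are project-specific, so the
# names below (CHECKPOINT_PATH, test_dataset) are hypothetical placeholders.
if __name__ == "__main__":
    model_wrapper = ModelWrapper(CHECKPOINT_PATH)  # hypothetical: real arguments depend on the project
    corrector = Corrector(model_wrapper)

    # Correct a single raw sentence with beam search.
    out = corrector.correct_transfomer_with_tr("thiss is a test sentence .", num_beams=2)
    print(out['predict_text'], out['time'])

    # Full evaluation over a dataset compatible with the project's samplers and collator.
    # corrector.evaluate(test_dataset, beams=2)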