from utils.logger import get_logger
import numpy as np
from rapidfuzz.distance.Levenshtein import normalized_distance
import multiprocessing
from multiprocessing import Pool
import time
import utils.alignment as alignment


def _get_mned_metric_from_TruePredict(true_text, predict_text):
    """Return the normalized Levenshtein distance between two strings.

    Thin wrapper around rapidfuzz's ``normalized_distance``.
    """
    return normalized_distance(predict_text, true_text)


def get_mned_metric_from_TruePredict(batch_true_text, batch_predict_text):
    """Return the mean normalized edit distance over a batch of text pairs.

    Pairs are formed with ``zip``, so surplus items in the longer batch are
    ignored. An empty batch yields 0.0 instead of raising
    ``ZeroDivisionError``.
    """
    distances = [
        _get_mned_metric_from_TruePredict(true_text, predict_text)
        for true_text, predict_text in zip(batch_true_text, batch_predict_text)
    ]
    if not distances:
        return 0.0
    return sum(distances) / len(distances)


def get_metric_for_tfm(batch_predicts, batch_targets, batch_length):
    """Count matching / mismatching token positions across a batch.

    For each sequence the first and last elements are dropped (presumably
    BOS/EOS markers -- TODO confirm) and only the first ``length`` remaining
    positions are compared element-wise.

    Returns:
        ``(num_correct, num_wrong)`` summed over the batch.
    """
    num_correct, num_wrong = 0, 0
    for predict, target, length in zip(batch_predicts, batch_targets, batch_length):
        # Strip the boundary elements, then truncate to the valid length.
        predict = np.array(predict[1:-1][0:length])
        target = np.array(target[1:-1][0:length])
        num_correct += np.sum(predict == target)
        num_wrong += np.sum(predict != target)
    return num_correct, num_wrong


def _mark_merged(pred_word, prev_sep, next_sep, gap_symbol):
    """Wrap ``pred_word`` in "@@" on each side whose reference space aligned to a gap."""
    if prev_sep is not None and prev_sep == gap_symbol:
        pred_word = "@@" + pred_word
    if next_sep is not None and next_sep == gap_symbol:
        pred_word = pred_word + "@@"
    return pred_word


def allign_seq2trueseq(seq, true_seq, gap_symbol="-"):
    """Split an aligned prediction into chunks, one per reference word.

    ``seq`` and ``true_seq`` must be equal-length aligned strings. Words are
    delimited by spaces in ``true_seq``; the characters of ``seq`` at the same
    positions form the corresponding predicted chunk. A chunk is prefixed
    with ``"@@"`` when the reference space before it was aligned to a gap in
    ``seq`` and suffixed with ``"@@"`` when the space after it was (the
    prediction lacks that word boundary). Gap symbols are removed from the
    reference words but kept in the predicted chunks.

    Returns:
        ``(seq_list, true_list)``: predicted chunks and reference words, one
        entry per reference word including the final one, so the length is
        always ``true_seq.count(" ") + 1``.
    """
    assert len(true_seq) == len(seq)
    seq_list = []
    true_list = []
    true_word = ""
    pred_word = ""
    prev_sep = None
    for true_char, pred_char in zip(true_seq, seq):
        if true_char != " ":
            true_word += true_char
            pred_word += pred_char
            continue
        # A reference space closes the current word; remember whether the
        # prediction had a gap at that position.
        next_sep = gap_symbol if pred_char == gap_symbol else " "
        true_list.append(true_word.replace(gap_symbol, ""))
        seq_list.append(_mark_merged(pred_word, prev_sep, next_sep, gap_symbol))
        true_word = ""
        pred_word = ""
        prev_sep = next_sep
    # The last word is not followed by a reference space, so the loop never
    # emits it; flush it here so it is scored like every other word.
    true_list.append(true_word.replace(gap_symbol, ""))
    seq_list.append(_mark_merged(pred_word, prev_sep, None, gap_symbol))
    return seq_list, true_list


def align_2seq2trueseq(wrong_text, pred_text, true_text, gap_symbol="-"):
    """Align the wrong and the predicted text to the reference, word by word.

    Assumes ``alignment.needle(a, b, gap)`` returns the two aligned strings in
    argument order -- TODO confirm.

    Returns:
        List of ``(wrong_chunk, pred_chunk, true_word)`` triples.
    """
    assert gap_symbol is not None and len(gap_symbol) == 1
    wrong_seq, true_seq = alignment.needle(wrong_text, true_text, gap_symbol)
    wrong_list, true_list = allign_seq2trueseq(wrong_seq, true_seq, gap_symbol)
    pred_seq, true_seq = alignment.needle(pred_text, true_text, gap_symbol)
    pred_list, _ = allign_seq2trueseq(pred_seq, true_seq, gap_symbol)
    return list(zip(wrong_list, pred_list, true_list))


def _word_error_count(predict, true):
    """Weight of a wrong predicted chunk: 1, or the number of extra
    space-separated pieces when the prediction split the word."""
    n_pred = len(predict.split(" "))
    n_true = len(true.split(" "))
    if n_pred == n_true:
        return 1
    penalty = n_pred - n_true
    assert penalty > 0
    return penalty


def _get_metric_from_TrueWrongPredictV3(true_text, wrong_text, predict_text, vocab=None):
    """Count word-level TP / FP / FN for one (true, wrong, predicted) triple.

    For each reference word:

    * the wrong chunk equals the reference (nothing to fix): a differing
      prediction is a false positive;
    * the wrong chunk differs (an error to fix): a prediction equal to the
      reference is a true positive, anything else a false negative.

    FP/FN are weighted by ``_word_error_count``. When ``vocab`` is given, the
    alignment gap symbol is picked from vocabulary characters absent from all
    three texts; otherwise ``"-"`` is used.

    Returns:
        ``(TP, FP, FN)``.
    """
    gap_symbol = "-"
    if vocab is not None:
        # Skip the first four vocabulary entries (presumably special tokens --
        # TODO confirm). A gap must be a single character (required by
        # align_2seq2trueseq) and never a space, since spaces delimit words.
        all_symbols = set(list(vocab.chartoken2idx.keys())[4:])
        used_symbols = set(wrong_text + predict_text + true_text)
        usable_symbols = {
            s for s in all_symbols - used_symbols
            if isinstance(s, str) and len(s) == 1 and s != " "
        }
        assert len(usable_symbols) > 0
        if "-" not in usable_symbols:
            gap_symbol = usable_symbols.pop()

    aligned_words = align_2seq2trueseq(wrong_text, predict_text, true_text, gap_symbol)
    TP, FP, FN = 0, 0, 0
    for wrong, predict, true in aligned_words:
        if wrong == true:
            # NOTE(review): predict[:-2] == true also accepts the reference
            # followed by ANY two characters (meant for a trailing "@@"
            # marker?) -- confirm this leniency is intended.
            if predict[:-2] == true or predict == true:
                continue
            FP += _word_error_count(predict, true)
        elif predict == true:
            TP += 1
        else:
            FN += _word_error_count(predict, true)
    return TP, FP, FN


def worker_task(true_text, wrong_text, predict_text, vocab):
    """Pool entry point: score one sentence triple and return ``(TP, FP, FN)``."""
    _TP, _FP, _FN = _get_metric_from_TrueWrongPredictV3(true_text, wrong_text, predict_text, vocab)
    return (_TP, _FP, _FN)


def get_metric_from_TrueWrongPredictV3(batch_true_text, batch_wrong_text, batch_predict_text, vocab, twp_logger):
    """Sum word-level TP / FP / FN over a batch, scoring sentences in a process pool.

    Args:
        batch_true_text, batch_wrong_text, batch_predict_text: parallel
            sequences of sentences (zipped, so surplus items are ignored).
        vocab: object exposing ``chartoken2idx``; used to choose a gap symbol.
        twp_logger: optional logger; when truthy, each triple and its counts
            are logged with ``file_only=True``.

    Returns:
        ``(TPs, FPs, FNs)`` summed over the batch.
    """
    assert vocab is not None
    data = [
        (true_text, wrong_text, pred_text, vocab)
        for true_text, wrong_text, pred_text in zip(batch_true_text, batch_wrong_text, batch_predict_text)
    ]
    if not data:
        return 0, 0, 0
    # Use roughly a third of the CPUs, but at least one worker: Pool(0)
    # raises ValueError on machines with fewer than three CPUs.
    processes = max(1, multiprocessing.cpu_count() // 3)
    with Pool(processes) as pool:
        results = pool.starmap(worker_task, data)

    TPs, FPs, FNs = 0, 0, 0
    for i, (tp, fp, fn) in enumerate(results):
        TPs += tp
        FPs += fp
        FNs += fn
        if twp_logger:
            twp_logger.log(batch_true_text[i], file_only=True)
            twp_logger.log(batch_wrong_text[i], file_only=True)
            twp_logger.log(batch_predict_text[i], file_only=True)
            twp_logger.log(f"{tp} - {fp} - {fn}", file_only=True)
    return TPs, FPs, FNs