import copy
import json

import faiss
import numpy as np
import torch
import torch.nn.functional as F
from torch.cuda.amp import autocast
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm
# Russian literal meaning "Full text of the court document: ".
# It is not referenced in this part of the file; presumably used elsewhere to
# prefix or split off the full document text -- confirm with its callers.
court_text_splitter: str = "Весь текст судебного документа: "
class FaissDocsDataset(Dataset): | |
def __init__(self, data): | | = data | |
def __len__(self): | |
return len( | |
def __getitem__(self, idx): | |
return[idx] | |
def preprocess_inputs(inputs, device):
    """Select index 0 of axis 1 of every tensor and move it to ``device``.

    Args:
        inputs: mapping of model input names to tensors. Assumed to be shaped
            (batch, 1, seq_len), e.g. collated per-sample tokenizer outputs
            that each carried a leading batch axis of size 1 -- TODO confirm.
        device: target device passed to ``Tensor.to``.

    Returns:
        New dict with the same keys and the sliced, moved tensors.
    """
    moved = {}
    for name, tensor in inputs.items():
        # Drop the per-sample singleton axis introduced before collation.
        moved[name] = tensor[:, 0, :].to(device)
    return moved
def get_subsets_for_db(subsets, data_ids, all_docs):
    """Collect the reference database for the documents in the given subsets.

    Args:
        subsets: iterable of subset names; each must be a key of ``data_ids``.
        data_ids: mapping subset name -> iterable of (hashable) document ids.
        all_docs: mapping document key -> document dict with at least an
            ``'id'`` and an ``'added_refs'`` mapping (ref name -> text).

    Returns:
        dict ref name -> text, merged over every selected document in
        ``all_docs`` order. A ref appearing in several documents keeps its
        first-seen position and the text of the last document containing it.

    Raises:
        KeyError: if a subset name is missing from ``data_ids`` or a
            document lacks ``'id'`` / ``'added_refs'``.
    """
    # A set makes the per-document membership test O(1) instead of a scan
    # over every id of every requested subset.
    wanted_ids = {doc_id for name in subsets for doc_id in data_ids[name]}
    db_data = {}
    for doc in all_docs.values():
        if doc['id'] in wanted_ids:
            # dict.update keeps existing key order and overwrites the value,
            # exactly like re-assigning in a dict comprehension.
            db_data.update(doc['added_refs'])
    return db_data
def get_subsets_for_qa(subsets, data_ids, all_docs):
    """Select the documents whose id belongs to any of the given subsets.

    Args:
        subsets: iterable of subset names; each must be a key of ``data_ids``.
        data_ids: mapping subset name -> iterable of (hashable) document ids.
        all_docs: mapping document key -> document dict with an ``'id'``.

    Returns:
        dict with the matching ``all_docs`` entries, in ``all_docs`` order.
        Document objects are shared with the input, not copied.

    Raises:
        KeyError: if a subset name is missing from ``data_ids`` or a
            document lacks ``'id'``.
    """
    # Set membership instead of scanning a flattened id list per document.
    wanted_ids = {doc_id for name in subsets for doc_id in data_ids[name]}
    return {key: doc for key, doc in all_docs.items() if doc['id'] in wanted_ids}
def filter_db_data_types(text_parts, db_data_in):
    """Keep only DB entries whose ref name contains any of ``text_parts``.

    Args:
        text_parts: iterable of substrings to look for in each ref name.
        db_data_in: mapping ref name -> text; not modified.

    Returns:
        New dict of the matching entries in input order. Values are deep
        copies, so mutating them does not affect ``db_data_in``. An empty
        ``text_parts`` yields an empty dict.
    """
    # Select first, then deep-copy only the survivors (in a single deepcopy
    # call, so shared structure between them is preserved). This gives the
    # same result as copying the whole DB up front, without duplicating
    # entries that are about to be discarded.
    selected = {ref: text for ref, text in db_data_in.items()
                if any(part in ref for part in text_parts)}
    return copy.deepcopy(selected)
def filter_qa_data_types(text_parts, all_docs_in):
    """Restrict every document's ``added_refs`` to refs matching ``text_parts``.

    Args:
        text_parts: iterable of substrings; a ref is kept when its name
            contains any of them.
        all_docs_in: mapping doc key -> document dict with an
            ``'added_refs'`` mapping (ref name -> text); not modified.

    Returns:
        A deep copy of ``all_docs_in`` (same keys, same order) in which each
        non-empty ``added_refs`` is replaced by its filtered subset. Documents
        with empty ``added_refs`` are returned unchanged (as copies).
    """
    all_docs = copy.deepcopy(all_docs_in)
    for doc in all_docs.values():
        refs = doc['added_refs']
        if not refs:
            # Nothing to filter; keep the copied empty container as-is.
            continue
        doc['added_refs'] = {ref: text for ref, text in refs.items()
                             if any(part in ref for part in text_parts)}
    return all_docs
def db_tokenization(filtered_db_data, tokenizer, max_len=510):
    """Tokenize each DB reference text, prefixed with ``"passage: "``.

    Args:
        filtered_db_data: mapping ref name -> reference text.
        tokenizer: callable invoked as ``tokenizer(text, return_tensors="pt",
            padding='max_length', truncation=True, max_length=max_len)``;
            presumably a Hugging Face tokenizer.
        max_len: forwarded as ``max_length``.

    Returns:
        ``(index_keys, index_toks)``: two dicts keyed by the same running
        position (0, 1, ...) in iteration order of ``filtered_db_data``.
        ``index_keys`` maps it to the ref name and ``index_toks`` to the
        tokenizer output for that ref's text.
    """
    index_keys = {}
    index_toks = {}
    entries = tqdm(filtered_db_data.items(), desc="Tokenizing DB refs")
    for position, (ref_name, ref_text) in enumerate(entries):
        index_keys[position] = ref_name
        passage = "passage: " + ref_text
        index_toks[position] = tokenizer(
            passage,
            return_tensors="pt",
            padding='max_length',
            truncation=True,
            max_length=max_len,
        )
    return index_keys, index_toks
def qa_tokenization(all_docs_qa, tokenizer, max_len=510):
    """Tokenize every QA document's title + question as a ``"query: "`` input.

    Args:
        all_docs_qa: mapping doc key -> document dict with ``'title'``,
            ``'question'`` and an ``'added_refs'`` mapping.
        tokenizer: callable invoked as ``tokenizer(text, return_tensors="pt",
            padding='max_length', truncation=True, max_length=max_len)``.
        max_len: forwarded as ``max_length``.

    Returns:
        ``(val_questions, val_refs)``. ``val_questions`` is the list of
        tokenizer outputs in document order. ``val_refs`` maps the same
        position (0, 1, ...) to the list of that document's ref names.
    """
    val_questions = []
    val_refs = {}
    progress = tqdm(all_docs_qa.values(), desc="Tokenizing QA docs")
    for position, doc in enumerate(progress):
        question_text = doc['title'] + '\n' + doc['question']
        encoded = tokenizer(
            "query: " + question_text,
            return_tensors="pt",
            padding='max_length',
            truncation=True,
            max_length=max_len,
        )
        val_questions.append(encoded)
        val_refs[position] = list(doc['added_refs'].keys())
    return val_questions, val_refs
def query_tokenization(text, tokenizer, max_len=510):
    """Tokenize a single free-text query, prefixed with ``"query: "``.

    Args:
        text: query string.
        tokenizer: callable invoked as ``tokenizer(text, return_tensors="pt",
            padding='max_length', truncation=True, max_length=max_len)``.
        max_len: forwarded as ``max_length``.

    Returns:
        Whatever ``tokenizer`` returns for the prefixed text.
    """
    prefixed = "query: " + text
    return tokenizer(
        prefixed,
        return_tensors="pt",
        padding='max_length',
        truncation=True,
        max_length=max_len,
    )
def query_embed_extraction(tokens, model, do_normalization=True):
    """Embed tokenized input(s) with ``model`` and return a NumPy array.

    Args:
        tokens: mapping of model input names to tensors (e.g. the output of
            ``query_tokenization``); each tensor is moved to ``model.device``.
        model: module exposing ``.device``. Its output must have
            ``last_hidden_state``, assumed (batch, seq_len, hidden) as in
            Hugging Face models; position 0 along the sequence axis is used
            as the embedding.
        do_normalization: if True, L2-normalize along the last dimension.

    Returns:
        float32 array of the first-position hidden states, one row per input.
    """
    model.eval()
    device = model.device
    with torch.no_grad():
        with autocast():
            inputs = {name: tensor.to(device) for name, tensor in tokens.items()}
            outputs = model(**inputs)
        # Under CUDA autocast the hidden states are typically float16. Upcast
        # so normalization runs in full precision and the result is float32,
        # the dtype FAISS indexes operate on.
        embedding = outputs.last_hidden_state[:, 0].float().cpu()
        if do_normalization:
            embedding = F.normalize(embedding, dim=-1)
    return embedding.numpy()
def _first_token_embeddings(items, model, device, batch_size, desc):
    """Batch ``items`` through ``model``; stack ``last_hidden_state[:, 0]`` as float32 on CPU."""
    loader = DataLoader(FaissDocsDataset(items), batch_size=batch_size)
    chunks = []
    for batch in tqdm(loader, desc=desc):
        with autocast():
            outputs = model(**preprocess_inputs(batch, device))
        # Upcast: under CUDA autocast these states are typically float16.
        chunks.append(outputs.last_hidden_state[:, 0].float().cpu())
    # One cat over per-batch tensors (rows stay in input order). An empty
    # ``items`` leaves nothing to concatenate, so torch.cat fails.
    return torch.cat(chunks)


def extract_text_embeddings(index_toks, val_questions, model,
                            do_normalization=True, faiss_batch_size=16):
    """Compute embeddings for DB passages and validation questions.

    Args:
        index_toks: mapping position -> tokenized passage (as produced by
            ``db_tokenization``); embedded in the mapping's value order.
        val_questions: list of tokenized questions (as produced by
            ``qa_tokenization``).
        model: module exposing ``.device``; its output must have
            ``last_hidden_state``. Position 0 along the sequence axis is the
            embedding.
        do_normalization: if True, L2-normalize every embedding along the
            last dimension.
        faiss_batch_size: DataLoader batch size for both inputs.

    Returns:
        ``(docs_embeddings, question_embeddings)`` as float32 NumPy arrays,
        one row per passage / question in input order.
    """
    model.eval()
    device = model.device
    with torch.no_grad():
        docs_embeds = _first_token_embeddings(
            list(index_toks.values()), model, device,
            faiss_batch_size, "db_embeds_extraction")
        questions_embeds = _first_token_embeddings(
            val_questions, model, device,
            faiss_batch_size, "qu_embeds_extraction")
        if do_normalization:
            docs_embeds = F.normalize(docs_embeds, dim=-1)
            questions_embeds = F.normalize(questions_embeds, dim=-1)
    return docs_embeds.numpy(), questions_embeds.numpy()
def filter_ref_parts(ref_dict, filter_parts):
    """Remove filtered tokens from every reference string.

    Each reference is split on whitespace. Tokens containing any substring
    from ``filter_parts`` are dropped, and the remaining tokens are re-joined
    with single spaces, so runs of whitespace collapse as a side effect.

    Args:
        ref_dict: mapping key -> list of reference strings.
        filter_parts: iterable of substrings marking tokens to drop.

    Returns:
        New dict with the same keys and cleaned reference lists.
    """
    def _clean(ref):
        # Keep a token only if it contains none of the filter substrings.
        return " ".join(token for token in ref.split()
                        if not any(part in token for part in filter_parts))

    return {key: [_clean(ref) for ref in refs] for key, refs in ref_dict.items()}
def get_final_metrics(pred, true, categories, top_k_values,
                      metrics_func, metrics_func_params):
    """Evaluate ``metrics_func`` for every (top_k, category) pair.

    For each ``k`` in ``top_k_values`` and each category, the predictions and
    ground truth are restricted with ``get_exact_ctg_data``. Then
    ``metrics_func(ctg_pred, ctg_true, k, **metrics_func_params)`` is called.
    Every value of the dict it returns is multiplied by 100 and rounded to 6
    decimals (presumably fractions -> percentages). The dict is updated in
    place and stored as-is.

    Returns:
        ``{top_k: {category: metrics_dict}}``.
    """
    results = {}
    for k in top_k_values:
        per_category = {}
        for category in categories:
            category_pred, category_true = get_exact_ctg_data(pred, true, category)
            scores = metrics_func(category_pred, category_true, k,
                                  **metrics_func_params)
            # Only existing keys are reassigned, so iterating items() is safe.
            for name, value in scores.items():
                scores[name] = round(value * 100, 6)
            per_category[category] = scores
        results[k] = per_category
    return results
def get_exact_ctg_data(pred_in, true_in, ctg):
    """Restrict predicted and true reference lists to one category.

    A reference belongs to category ``ctg`` when ``ctg`` is a substring of
    it. The special category ``"all"`` returns both inputs unchanged (the
    same objects, not copies).

    NOTE: predictions are paired with ground truth by iteration position,
    not by key. Output keys come from ``true_in``, and pairing stops at the
    shorter of the two mappings. Both must list questions in the same order.

    Returns:
        ``(filtered_pred, filtered_true)`` dicts keyed like ``true_in``.
    """
    if ctg == "all":
        return pred_in, true_in
    out_pred, out_true = {}, {}
    for key, pred_refs, true_refs in zip(true_in, pred_in.values(), true_in.values()):
        out_true[key] = [ref for ref in true_refs if ctg in ref]
        out_pred[key] = [ref for ref in pred_refs if ctg in ref]
    return out_pred, out_true
def print_metrics(metrics, ref_categories):
    """Print a tab-separated metrics table to stdout.

    The header lists metric names with their ``@k`` suffix stripped. These
    are taken from the first category of the first top-k entry of
    ``metrics``. Then, for each category in ``ref_categories`` order and
    each top-k in ``metrics`` order where that category is present, one row
    ``"<short>@<k>"`` is printed, followed by the values formatted with 3
    decimals and zero-padded to width 6.

    Args:
        metrics: ``{top_k: {category: {metric_name: value}}}``.
        ref_categories: mapping category -> short label used in row names.
    """
    first_top_k_block = list(metrics.values())[0]
    first_scores = list(first_top_k_block.values())[0]
    header = [tag.split('@')[0] for tag in first_scores]
    print('\t', *header, sep='\t')
    for ctg, ctg_short in ref_categories.items():
        for top_k, per_category in metrics.items():
            for ctg_tag, scores in per_category.items():
                if ctg_tag != ctg:
                    continue
                cells = [f"{value:.3f}".zfill(6) for value in scores.values()]
                print(f"{ctg_short}@{top_k}", *cells, sep='\t\t')