Spaces:
Build error
Build error
# Copyright (c) 2022, Lawrence Livermore National Security, LLC. | |
# All rights reserved. | |
# See the top-level LICENSE and NOTICE files for details. | |
# LLNL-CODE-838964 | |
# SPDX-License-Identifier: Apache-2.0-with-LLVM-exception | |
import sys | |
import json | |
from math import ceil | |
import torch | |
import numpy as np | |
from torch import tensor | |
from torch.nn.functional import log_softmax | |
from torch.distributions.categorical import Categorical | |
from transformers import T5Tokenizer, T5ForConditionalGeneration | |
# Load the UnifiedQA checkpoint once at import time and place it on the GPU
# when one is available, otherwise on the CPU.
model_name = "allenai/unifiedqa-v2-t5-large-1363200"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
tokenizer = T5Tokenizer.from_pretrained(model_name)
model = T5ForConditionalGeneration.from_pretrained(model_name)
model.to(device)
def get_inputs(contexts_json, ranked_contexts_json):
    """Load one question and its ordered contexts from two JSON files.

    Args:
        contexts_json: path to a JSON object mapping context id -> record with
            a ``'text'`` field.
        ranked_contexts_json: path to a JSON object mapping question id ->
            record with ``'text'`` (the question), ``'ranks'`` (ordered list
            of context ids) and ``'scores'`` (one score per ranked context,
            presumably aligned with ``'ranks'`` -- confirm with the producer).

    Returns:
        ``(question, contexts, context_scores)``: the question text, the
        context texts in ``'ranks'`` order, and the ``'scores'`` list as-is.

    Raises:
        IndexError: if ``ranked_contexts_json`` holds no questions.
        KeyError: if a ranked context id is missing from ``contexts_json``.
    """
    # JSON is UTF-8 by specification; don't depend on the platform locale.
    with open(contexts_json, 'rt', encoding='utf-8') as fp:
        contexts = json.load(fp)
    with open(ranked_contexts_json, 'rt', encoding='utf-8') as fp:
        ranked_contexts = json.load(fp)

    # Only the first question in the file is used; any others are ignored.
    question_id = list(ranked_contexts)[0]
    entry = ranked_contexts[question_id]
    question = entry['text']
    context_ids_sorted = entry['ranks']
    context_scores = entry['scores']
    context_texts = [contexts[context_id]['text'] for context_id in context_ids_sorted]
    return question, context_texts, context_scores
def get_tokens(text, tokenizer, max_tokens):
    """Encode ``text`` into a fixed-length batch of input ids.

    The text is truncated and padded to exactly ``max_tokens`` tokens and
    returned as PyTorch tensors, so the result is expected to have shape
    ``(1, max_tokens)``. Only ``input_ids`` is returned; the attention mask is
    discarded.

    Calls the tokenizer directly rather than through ``encode_plus``, which
    Hugging Face documents as deprecated in favour of ``__call__``. The
    arguments are the same.
    """
    encoded = tokenizer(
        text,
        return_tensors='pt',
        max_length=max_tokens,
        padding='max_length',
        truncation=True,
    )
    return encoded['input_ids']
def prepare_inputs(tokenizer, max_tokens, context, question):
    """Build and encode the model input for one question/context pair.

    The question comes first, then the separator, then the context. The
    separator is a space, a literal backslash followed by the letter n (two
    characters, not a newline), and another space. This is presumably the
    UnifiedQA question/context delimiter. The joined string is encoded with
    ``get_tokens`` to ``max_tokens`` ids.
    """
    separator = ' \\n '
    joined = f'{question}{separator}{context}'
    return get_tokens(joined, tokenizer, max_tokens)
def get_outputs(model, tokenizer, input_tokens, max_tokens):
    """Generate an answer for ``input_tokens`` and score how uncertain it is.

    ``model.generate`` is called with ``max_length=max_tokens`` and asked to
    return per-step scores. Two statistics are derived from the scores of the
    answer tokens (the end-of-sequence step is excluded):

    * ``entropy``: mean entropy of the per-step token distributions;
    * ``sentence_std``: ``exp`` of the sample standard deviation of the
      per-step maximum log-probabilities, or 0 when only one step is scored.

    The returned uncertainty is ``entropy * sentence_std``.

    Args:
        model: object exposing ``generate`` (a T5 seq2seq model in this file).
        tokenizer: object exposing ``decode`` and ``eos_token_id``.
        input_tokens: encoded input ids; assumed to hold a single sequence
            (batch size 1).
        max_tokens: ``max_length`` passed to ``generate``.

    Returns:
        ``(pred_str, uncertainty)``:
        - ``pred_str`` is the decoded answer, lower-cased and with special
          tokens removed.
        - ``uncertainty`` is a float. It is NaN when the answer has no scored
          tokens, i.e. the model emitted the end token immediately.
    """
    output_dict = model.generate(
        input_tokens,
        output_scores=True,
        return_dict_in_generate=True,
        max_length=max_tokens,
    )
    # Take the first sequence explicitly; squeeze() would collapse a length-1
    # sequence into a bare int.
    pred_tokens = output_dict['sequences'][0].tolist()
    pred_str = tokenizer.decode(pred_tokens, skip_special_tokens=True).lower()

    step_scores = list(output_dict['scores'])
    # Exclude the end-of-sequence step from the metrics. If generation stopped
    # because it reached max_length, the last step is a real answer token and
    # must be kept.
    if step_scores and pred_tokens and pred_tokens[-1] == tokenizer.eos_token_id:
        step_scores = step_scores[:-1]
    if not step_scores:
        # Nothing to measure: mean/std over zero steps is undefined.
        return pred_str, float('nan')

    # (num_steps, vocab) log-probabilities for the single sequence in the batch.
    log_probs = log_softmax(torch.stack([scores[0] for scores in step_scores]), dim=-1)
    token_entropy = Categorical(logits=log_probs).entropy()
    max_log_probs = log_probs.max(dim=-1).values

    entropy = token_entropy.mean()
    # The sample std of a single value is undefined (NaN), so one-step answers get 0.
    sentence_std = 0 if len(step_scores) == 1 else max_log_probs.std(unbiased=True).exp()
    uncertainty = (entropy * sentence_std).item()
    return pred_str, uncertainty
def run_model(model, tokenizer, device, question, contexts, context_scores, k_percent=0.1, min_k=10, max_k=25, uncertainty_thresh=3, max_tokens=512):
    """Answer ``question`` once per context, using the leading top-k contexts.

    ``k = min(max(ceil(k_percent * len(contexts)), min_k), max_k)``, and the
    first ``k`` entries of ``contexts``/``context_scores`` are used. This
    assumes ``contexts`` is already ordered best-first.

    Args:
        model: model passed through to ``get_outputs``.
        tokenizer: tokenizer passed to ``prepare_inputs`` and ``get_outputs``.
        device: device the encoded inputs are moved to.
        question: question text.
        contexts: context texts, best-first.
        context_scores: scores for ``contexts``; sliced the same way.
        k_percent: fraction (not percentage) of contexts to use, e.g. 0.1 for
            10%.
        min_k: lower bound on k; fewer are used only if fewer contexts exist.
            Setting it too small reduces recall.
        max_k: upper bound on k. It wins over ``min_k`` if the two conflict.
            Setting it too large reduces precision.
        uncertainty_thresh: currently unused; kept for API compatibility.
            Intended for uncertainty-based filtering, with recommended values
            2-5 (lower means more aggressive filtering).
        max_tokens: max length of both the encoded input and the generated
            answer. Defaults to 512, the previously hard-coded value.

    Returns:
        dict with keys ``'contexts'``, ``'answers'``, ``'context_scores'`` and
        ``'uncertainty'``. Each value is a list aligned by index with the
        selected contexts.
    """
    k = min(max(ceil(k_percent * len(contexts)), min_k), max_k)
    top_contexts = contexts[:k]
    top_scores = context_scores[:k]

    answers = []
    uncertainty = []
    for context in top_contexts:
        input_tokens = prepare_inputs(tokenizer, max_tokens, context, question).to(device)
        pred_str, context_uncertainty = get_outputs(model, tokenizer, input_tokens, max_tokens)
        answers.append(pred_str)
        uncertainty.append(context_uncertainty)

    return {'contexts': top_contexts, 'answers': answers, 'context_scores': top_scores, 'uncertainty': uncertainty}
def get_qa_results(contexts_json, ranked_contexts_json, topk):
    """Answer the question stored in JSON files against its best contexts.

    The question and ordered contexts are read with ``get_inputs``. They are
    then run through the module-level ``model``/``tokenizer``/``device``
    without autograd tracking. With ``k_percent=1.0`` and ``min_k=1``, at most
    ``topk`` of the leading contexts are answered.

    Returns:
        The result dict produced by ``run_model``.
    """
    question, contexts, context_scores = get_inputs(contexts_json, ranked_contexts_json)
    with torch.inference_mode():
        results = run_model(
            model,
            tokenizer,
            device,
            question,
            contexts,
            context_scores,
            k_percent=1.0,
            min_k=1,
            max_k=topk,
        )
    return results
def get_qa_results_in_memory(contexts, ranked_contexts, topk):
    """In-memory variant of ``get_qa_results``: same inputs, already parsed.

    Args:
        contexts: mapping context id -> record with a ``'text'`` field.
        ranked_contexts: mapping question id -> record with ``'text'``,
            ``'ranks'`` (ordered context ids) and ``'scores'``. Only the first
            question is used.
        topk: maximum number of leading contexts to answer.

    Returns:
        The result dict produced by ``run_model``.
    """
    first_question_id = list(ranked_contexts)[0]
    entry = ranked_contexts[first_question_id]
    question = entry['text']
    ranked_ids = entry['ranks']
    scores = entry['scores']
    ordered_texts = [contexts[cid]['text'] for cid in ranked_ids]

    with torch.inference_mode():
        return run_model(
            model,
            tokenizer,
            device,
            question,
            ordered_texts,
            scores,
            k_percent=1.0,
            min_k=1,
            max_k=topk,
        )
def load_custom_model(finetuned_model_path):
    """Replace the module-level ``tokenizer`` and ``model`` with a checkpoint.

    Both are loaded, and the model is moved to the module-level ``device``,
    before either global is rebound. If loading or the device move fails, the
    previously loaded tokenizer/model pair stays in place. The old code could
    leave a new tokenizer paired with the old model.

    Args:
        finetuned_model_path: path or hub id accepted by ``from_pretrained``.
    """
    global tokenizer
    global model
    new_tokenizer = T5Tokenizer.from_pretrained(finetuned_model_path)
    new_model = T5ForConditionalGeneration.from_pretrained(finetuned_model_path)
    new_model.to(device)
    # Publish both only after everything above succeeded.
    tokenizer, model = new_tokenizer, new_model
def get_qa_results_in_memory_finetuned_unifiedqa(question, context_scores, contexts, topk):
    """Answer ``question`` against already-ordered ``contexts``.

    Uses whatever is currently bound to the module-level ``model`` and
    ``tokenizer``, for example after ``load_custom_model``. Note the argument
    order: scores come before contexts here, unlike ``run_model``.

    Returns:
        The result dict produced by ``run_model`` for at most ``topk`` of the
        leading contexts.
    """
    selection = dict(k_percent=1.0, min_k=1, max_k=topk)
    with torch.inference_mode():
        return run_model(model, tokenizer, device, question, contexts, context_scores, **selection)