# borrs / RAG.py
# Author: JosueElias — commit b50e558 ("Working on frontend.")
# os.environ['CUDA_VISIBLE_DEVICES'] ='0'
from dataset_with_embeddings import datasetx
from transformers import AutoModelForMultipleChoice
from transformers import AutoTokenizer
from nltk.corpus import stopwords
from nltk.tokenize import word_tokenize
from transformers import AutoTokenizer, AutoModel
import torch
import pandas as pd
# Hugging Face Hub identifiers for the two model/tokenizer pairs:
# - generator: multiple-choice reader that scores answer candidates
# - query: sentence-embedding encoder used for passage retrieval
GENERATOR_MODEL = "JosueElias/pipeline_generator_model"
GENERATOR_TOKENIZER = "JosueElias/pipeline_generator_tokenizer"
QUERY_MODEL = "JosueElias/pipeline_query_model"
QUERY_TOKENIZER = "JosueElias/pipeline_query_tokenizer"
# Target device for both models.
DEVICE = "cpu" # cpu or cuda
class Pipeline:
    """Retrieval-augmented pipeline for multiple-choice question answering.

    A query encoder embeds the question and retrieves the top-k most
    similar passages from ``datasetx`` (nearest-neighbour search over the
    "embeddings" index); a multiple-choice model then scores the five
    answer options (A-E) against that retrieved context and the best
    option's letter is returned.
    """

    # ---- init class
    def __init__(self):
        # Generator: scores [context, question + candidate answer] pairs.
        self.model = AutoModelForMultipleChoice.from_pretrained(GENERATOR_MODEL)
        self.tokenizer = AutoTokenizer.from_pretrained(GENERATOR_TOKENIZER)
        # Query encoder: produces sentence embeddings for retrieval.
        self.semModel = AutoModel.from_pretrained(QUERY_MODEL)
        self.semTokenizer = AutoTokenizer.from_pretrained(QUERY_TOKENIZER)
        self.device = torch.device(DEVICE)
        self.semModel.to(self.device)
        self.model.to(self.device)
        # Inference-only pipeline: switch to eval mode so dropout is
        # disabled and repeated calls give identical scores.
        self.semModel.eval()
        self.model.eval()

    # ---- utils functions
    def convert_to_letter(self, a):
        """Map an option index (0-4) to its letter ("A"-"E").

        Returns None for any other value, matching the original
        if-chain's fall-through behaviour.
        """
        return {0: "A", 1: "B", 2: "C", 3: "D", 4: "E"}.get(a)

    def filter_stopwords(self, example_sent):
        """Return ``example_sent`` with English stopwords removed (case-insensitive)."""
        stop_words = set(stopwords.words('english'))
        word_tokens = word_tokenize(example_sent)
        filtered_sentence = [w for w in word_tokens if w.lower() not in stop_words]
        return " ".join(filtered_sentence)

    def cls_pooling(self, model_output):
        # Pool with the model's pooler_output rather than the raw
        # last_hidden_state[:, 0] CLS vector.
        return model_output.pooler_output

    def get_embeddings(self, text_list):
        """Embed a list of strings with the query encoder.

        Returns the pooled (len(text_list), dim) tensor on ``self.device``.
        """
        encoded_input = self.semTokenizer(
            text_list, padding=True, truncation=True, return_tensors="pt", max_length=512
        )
        encoded_input = {k: v.to(self.device) for k, v in encoded_input.items()}
        # No gradients needed at inference time: avoids building the
        # autograd graph and saves memory.
        with torch.no_grad():
            model_output = self.semModel(**encoded_input)
        return self.cls_pooling(model_output)

    # ---- retriever
    def get_context_from_text(self, question):
        """Retrieve the 5 passages most similar to ``question``.

        Passages are sorted by descending similarity score and rendered
        as a single "==<section>== <text> " concatenated string.
        """
        question_embedding = self.get_embeddings([question]).cpu().detach().numpy()
        scores, samples = datasetx.get_nearest_examples(
            "embeddings", question_embedding, k=5
        )
        samples_df = pd.DataFrame.from_dict(samples)
        samples_df["scores"] = scores
        samples_df.sort_values("scores", ascending=False, inplace=True)
        # join is linear in total length, unlike repeated string +=.
        return "".join(
            f"=={row.section}== {row.text} " for _, row in samples_df.iterrows()
        )

    # ---- generator
    # Prompt layout: [CLS] context #### question? [SEP] answer [SEP]
    def create_tokens(self, quetion_and_options, context):
        """Tokenize the five (context, question+option) pairs for the generator.

        ``quetion_and_options`` is a dict with keys "prompt" and "A".."E".
        Returns ``(inputs, labels)`` where labels is a dummy 0 tensor —
        only the logits are used downstream.
        """
        question = quetion_and_options["prompt"]
        candidates = [
            "#### " + question + " [SEP] " + quetion_and_options[option] + " [SEP]"
            for option in ("A", "B", "C", "D", "E")
        ]
        prompt = "[CLS]" + context
        # Special tokens are written by hand above, hence
        # add_special_tokens=False; truncation="only_first" trims the
        # (long) context, never the question/answer candidate.
        inputs = self.tokenizer(
            [[prompt, candidate] for candidate in candidates],
            return_tensors="pt",
            padding=True,
            truncation="only_first",
            max_length=512,
            add_special_tokens=False,
        )
        labels = torch.tensor(0).unsqueeze(0)
        return (inputs, labels)

    def infer_answer(self, mi_tupla):
        """Run the multiple-choice model on ``(inputs, labels)``.

        Returns a (1, 3) tensor with the indices of the top-3 options,
        best first.
        """
        (inputs, labels) = mi_tupla
        inputs = {k: v.to(self.device) for k, v in inputs.items()}
        labels = labels.to(self.device)
        # unsqueeze(0) reshapes each (num_choices, seq_len) tensor into
        # the (batch=1, num_choices, seq_len) layout the model expects.
        with torch.no_grad():
            outputs = self.model(
                **{k: v.unsqueeze(0) for k, v in inputs.items()}, labels=labels
            )
        logits = outputs.logits
        _, topk_indices = torch.topk(logits, k=3, dim=1)
        return topk_indices

    # ---- retriever + generator
    def give_the_best_answer(self, dict_with_all_the_info):
        """Answer a multiple-choice question end to end.

        ``dict_with_all_the_info`` holds the question under "prompt" and
        the options under "A".."E"; returns the best option's letter.
        """
        context = self.get_context_from_text(dict_with_all_the_info["prompt"])
        token_batch = self.create_tokens(dict_with_all_the_info, context)
        top_indices = self.infer_answer(token_batch)
        return self.convert_to_letter(int(top_indices[0][0]))
# Module-level singleton used by the rest of the app. NOTE: constructing
# it loads all four models/tokenizers via from_pretrained, so importing
# this module is slow and may hit the network on first run.
pipeline = Pipeline()