# Winnie / bot /chatbot.py
# lewiswu1209's picture
# Refactoring
# f3c6b77
# raw
# history blame
# 4.57 kB
from tokenize import tokenize
import torch
import torch.nn.functional as F
from transformers import BertTokenizer, GPT2LMHeadModel
def replace_special_tokens(tokens:list, special_token_dict:dict)->list:
    """Return a copy of ``tokens`` with mapped tokens substituted.

    Args:
        tokens: Tokens to scan, in order.
        special_token_dict: Mapping from a token to the value that should
            replace it.

    Returns:
        A new list of the same length as ``tokens``. Every token that is a
        key of ``special_token_dict`` is replaced by the mapped value; all
        other tokens are kept unchanged. ``tokens`` itself is not modified.
    """
    # dict.get with the token as its own default does the membership test
    # and the lookup in one step.
    return [special_token_dict.get(token, token) for token in tokens]
def top_k_top_p_filtering(logits, top_k:int=0, top_p:float=0.0, filter_value:float=-float('Inf')):
    """Mask a logits tensor in place with top-k and/or nucleus (top-p) filtering.

    Filtering is applied along the last dimension, so both a single
    distribution of shape ``(vocab,)`` and batched input of shape
    ``(..., vocab)`` are supported.

    Args:
        logits: Tensor of unnormalised scores. It is modified in place.
        top_k: When > 0, keep only the ``top_k`` highest logits per row
            (entries tied with the k-th largest value are kept as well).
            Values larger than the vocabulary size are clamped.
        top_p: When > 0.0, keep the smallest prefix of the descending-sorted
            distribution whose cumulative probability exceeds ``top_p``.
            The most probable token is always kept.
        filter_value: Value written into every filtered-out position.

    Returns:
        The same ``logits`` tensor object, after masking.
    """
    top_k = min(top_k, logits.size(-1))
    if top_k > 0:
        # Threshold is the k-th largest logit of each row; the trailing None
        # keeps a broadcastable last dimension.
        kth_largest = torch.topk(logits, top_k)[0][..., -1, None]
        logits[logits < kth_largest] = filter_value
    if top_p > 0.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
        sorted_indices_to_remove = cumulative_probs > top_p
        # Shift the mask right by one so the first token that crosses the
        # threshold is still kept, and never drop the top-ranked token.
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = False
        # Map the mask from sorted order back to vocabulary order. Scattering
        # along the last dim (instead of fancy-indexing with a flat index
        # list) is what makes this correct for batched input too.
        indices_to_remove = torch.zeros_like(sorted_indices_to_remove).scatter(
            -1, sorted_indices, sorted_indices_to_remove
        )
        logits[indices_to_remove] = filter_value
    return logits
class ChatBot():
    """Multi-turn chat wrapper around a GPT-2 LM-head model and a BERT tokenizer.

    The dialogue context is encoded as ``[CLS] utt_1 [SEP] utt_2 [SEP] ...``
    and a reply is sampled token by token until ``[SEP]`` is drawn or the
    length limit is reached.
    """

    @staticmethod
    def get_chat_bot(checkpoint:str, vocab_path:str = None, special_token_list:list = None)->"ChatBot":
        """Build a ``ChatBot`` from a pretrained checkpoint.

        Args:
            checkpoint: Name or path passed to ``from_pretrained``.
            vocab_path: Optional vocabulary file; when given, the tokenizer is
                built from it instead of from ``checkpoint``.
            special_token_list: Extra special tokens to register on the
                tokenizer. ``None`` means no extra tokens.

        Returns:
            A ready-to-use ``ChatBot`` instance.
        """
        tokenizer = ChatBot.get_tokenizer(checkpoint, vocab_path, special_token_list)
        model = GPT2LMHeadModel.from_pretrained(checkpoint)
        return ChatBot(tokenizer, model)

    @staticmethod
    def get_tokenizer(checkpoint:str, vocab_path:str = None, special_token_list:list = None)->object:
        """Create the ``BertTokenizer`` used by the bot.

        Args:
            checkpoint: Name or path used when ``vocab_path`` is ``None``.
            vocab_path: Optional vocabulary file to build the tokenizer from.
            special_token_list: Extra special tokens to register. ``None``
                means no extra tokens.

        Returns:
            The configured tokenizer.
        """
        # None replaces the former mutable ``[]`` default.
        if special_token_list is None:
            special_token_list = []
        if vocab_path is None:
            tokenizer = BertTokenizer.from_pretrained(checkpoint, sep_token="[SEP]", pad_token="[PAD]", cls_token="[CLS]")
        else:
            tokenizer = BertTokenizer(vocab_path, sep_token="[SEP]", pad_token="[PAD]", cls_token="[CLS]")
        tokenizer.add_special_tokens({'additional_special_tokens': special_token_list})
        return tokenizer

    def __init__(self, tokenizer:object, model:object)->None:
        """Store the tokenizer and move ``model`` to CUDA when available, else CPU."""
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.tokenizer = tokenizer
        self.model = model
        self.model.to(self.device)

    def convert_ids_to_tokens(self, ids:list):
        """Map token ids to token strings via the tokenizer."""
        return self.tokenizer.convert_ids_to_tokens(ids)

    def convert_ids_to_text(self, ids):
        """Map token ids to tokens and concatenate them without a separator."""
        return "".join(self.convert_ids_to_tokens(ids))

    def convert_text_to_ids(self, text, add_special_tokens=False):
        """Encode ``text`` to token ids (no special tokens unless requested)."""
        return self.tokenizer.encode(text, add_special_tokens=add_special_tokens)

    @staticmethod
    def _apply_repetition_penalty(logits, token_ids, repetition_penalty):
        """Penalise ``token_ids`` in a 1-D ``logits`` tensor, in place.

        Positive logits are divided by the penalty and negative logits are
        multiplied by it, so a penalty > 1 always lowers the score. (Plain
        division would move a negative logit towards zero and make the token
        more likely.)

        Args:
            logits: 1-D tensor of scores; modified in place.
            token_ids: Iterable of vocabulary indices to penalise; duplicates
                are ignored so each index is penalised once.
            repetition_penalty: Penalty factor.

        Returns:
            The same ``logits`` tensor.
        """
        unique_ids = sorted(set(token_ids))
        if not unique_ids:
            return logits
        index = torch.tensor(unique_ids, dtype=torch.long, device=logits.device)
        scores = logits[index]
        logits[index] = torch.where(scores < 0, scores * repetition_penalty, scores / repetition_penalty)
        return logits

    def get_prediction(self, input_tensor, input_ids, repetition_penalty, temperature, top_k, top_p, max_length:int=64):
        """Sample a reply token by token.

        Args:
            input_tensor: Context ids as a tensor; it is indexed as
                ``logits[0, -1, :]`` and concatenated along dim 1, i.e. it is
                used as a single-row batch.
            input_ids: The same context as a plain list of ids; used to
                decide which tokens receive the repetition penalty.
            repetition_penalty: Penalty applied to tokens already present in
                the context (except ``[SEP]``) or already generated.
            temperature: Divisor applied to the logits before filtering.
            top_k: Forwarded to ``top_k_top_p_filtering``.
            top_p: Forwarded to ``top_k_top_p_filtering``.
            max_length: Maximum number of tokens to generate.

        Returns:
            List of generated token ids, not including the terminating
            ``[SEP]``.
        """
        self.model.eval()
        # Loop invariants: look these up once instead of on every step.
        sep_token_id = self.tokenizer.sep_token_id
        unk_token_id = self.tokenizer.convert_tokens_to_ids('[UNK]')
        # One combined set, so a token that occurs both in the prompt and in
        # the reply is penalised once per step rather than twice.
        penalized_ids = set(input_ids)
        penalized_ids.discard(sep_token_id)
        generated_ids = []
        # Inference only: skip autograd bookkeeping.
        with torch.no_grad():
            for _ in range(max_length):
                output_pt = self.model(input_tensor)
                next_token_logits = output_pt.logits[0, -1, :]
                next_token_logits = self._apply_repetition_penalty(next_token_logits, penalized_ids, repetition_penalty)
                next_token_logits = next_token_logits / temperature
                # Never sample the unknown token.
                next_token_logits[unk_token_id] = -float('Inf')
                filtered_logits = top_k_top_p_filtering(next_token_logits, top_k=top_k, top_p=top_p)
                next_token = torch.multinomial(F.softmax(filtered_logits, dim=-1), num_samples=1)
                token_id = next_token.item()
                if token_id == sep_token_id:
                    break
                input_tensor = torch.cat((input_tensor, next_token.unsqueeze(0)), dim=1)
                generated_ids.append(token_id)
                penalized_ids.add(token_id)
        return generated_ids

    def chat(self, text:str, history:list, role_card:dict=None) -> tuple:
        """Generate a reply to ``text`` given the running ``history``.

        Args:
            text: The user's new utterance.
            history: List of previous utterances, each a list of token ids.
                It is extended in place with the encoded ``text`` and with
                the generated reply.
            role_card: Optional mapping from generated tokens to replacement
                strings, applied to the reply before joining. ``None`` means
                no replacements.

        Returns:
            A ``(reply_text, history)`` tuple, where ``history`` is the same
            list object that was passed in.
        """
        # None replaces the former mutable ``{}`` default.
        if role_card is None:
            role_card = {}
        text_ids = self.convert_text_to_ids(text)
        history.append(text_ids)
        # Context layout: [CLS] utt [SEP] utt [SEP] ... over the last 50 turns.
        # NOTE(review): nothing here bounds the total token count against the
        # model's maximum context length - confirm callers keep turns short.
        input_ids = [self.tokenizer.cls_token_id]
        for history_utr in history[-50:]:
            input_ids.extend(history_utr)
            input_ids.append(self.tokenizer.sep_token_id)
        input_tensor = torch.tensor(input_ids).to(self.device).unsqueeze(0)
        generated_ids = self.get_prediction(input_tensor, input_ids, repetition_penalty=1.2, temperature=0.73, top_k=10, top_p=0.7)
        history.append(generated_ids)
        generated_tokens = replace_special_tokens(self.convert_ids_to_tokens(generated_ids), role_card)
        return "".join(generated_tokens), history