import torch
import torch.nn.functional as F
from transformers import BertTokenizer, GPT2LMHeadModel


def replace_special_tokens(tokens: list, special_token_dict: dict) -> list:
    """Replace any token found in special_token_dict with its mapped value."""
    return [special_token_dict.get(token, token) for token in tokens]


def top_k_top_p_filtering(logits, top_k: int = 0, top_p: float = 0.0,
                          filter_value: float = -float('Inf')):
    """Filter a 1-D logits tensor using top-k and/or nucleus (top-p) filtering."""
    top_k = min(top_k, logits.size(-1))  # safety check: k cannot exceed vocab size
    if top_k > 0:
        # Remove every token whose logit is below the k-th largest logit.
        indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
        logits[indices_to_remove] = filter_value
    if top_p > 0.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
        # Remove tokens whose cumulative probability exceeds the threshold.
        sorted_indices_to_remove = cumulative_probs > top_p
        # Shift right so the first token above the threshold is kept.
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = 0
        indices_to_remove = sorted_indices[sorted_indices_to_remove]
        logits[indices_to_remove] = filter_value
    return logits


class ChatBot:

    @staticmethod
    def get_chat_bot(checkpoint: str, vocab_path: str = None,
                     special_token_list: list = None) -> "ChatBot":
        """Build a ChatBot from a GPT-2 checkpoint and an optional custom vocab."""
        tokenizer = ChatBot.get_tokenizer(checkpoint, vocab_path, special_token_list)
        model = GPT2LMHeadModel.from_pretrained(checkpoint)
        # Keep the embedding matrix in sync with any added special tokens.
        model.resize_token_embeddings(len(tokenizer))
        return ChatBot(tokenizer, model)

    @staticmethod
    def get_tokenizer(checkpoint: str, vocab_path: str = None,
                      special_token_list: list = None) -> BertTokenizer:
        if vocab_path is None:
            tokenizer = BertTokenizer.from_pretrained(
                checkpoint, sep_token="[SEP]", pad_token="[PAD]", cls_token="[CLS]")
        else:
            tokenizer = BertTokenizer(
                vocab_path, sep_token="[SEP]", pad_token="[PAD]", cls_token="[CLS]")
        if special_token_list:
            tokenizer.add_special_tokens(
                {'additional_special_tokens': special_token_list})
        return tokenizer

    def __init__(self, tokenizer, model) -> None:
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.tokenizer = tokenizer
        self.model = model
        self.model.to(self.device)

    def convert_ids_to_tokens(self, ids: list):
        return self.tokenizer.convert_ids_to_tokens(ids)

    def convert_ids_to_text(self, ids):
        return "".join(self.convert_ids_to_tokens(ids))

    def convert_text_to_ids(self, text, add_special_tokens=False):
        return self.tokenizer.encode(text, add_special_tokens=add_special_tokens)

    def get_prediction(self, input_tensor, input_ids, repetition_penalty,
                       temperature, top_k, top_p, max_len=64):
        """Sample up to max_len tokens autoregressively; stop at [SEP]."""
        self.model.eval()
        generated_ids = []
        with torch.no_grad():
            for _ in range(max_len):
                output = self.model(input_tensor)
                next_token_logits = output.logits[0, -1, :]
                # Penalise tokens already seen in the context or the reply so far.
                for token_id in set(input_ids):
                    if token_id != self.tokenizer.sep_token_id:
                        next_token_logits[token_id] /= repetition_penalty
                for token_id in set(generated_ids):
                    next_token_logits[token_id] /= repetition_penalty
                next_token_logits = next_token_logits / temperature
                # Never sample the unknown token.
                next_token_logits[self.tokenizer.convert_tokens_to_ids('[UNK]')] = -float('Inf')
                filtered_logits = top_k_top_p_filtering(next_token_logits,
                                                        top_k=top_k, top_p=top_p)
                next_token = torch.multinomial(F.softmax(filtered_logits, dim=-1),
                                               num_samples=1)
                if next_token.item() == self.tokenizer.sep_token_id:
                    break
                input_tensor = torch.cat((input_tensor, next_token.unsqueeze(0)), dim=1)
                generated_ids.append(next_token.item())
        return generated_ids

    def chat(self, text: str, history: list, role_card: dict = None) -> tuple:
        """Generate a reply to `text`, updating and returning the dialogue history."""
        text_ids = self.tokenizer.encode(text, add_special_tokens=False)
        history.append(text_ids)
        # Build the input as [CLS] utt1 [SEP] utt2 [SEP] ... from the last 50 turns.
        input_ids = [self.tokenizer.cls_token_id]
        for history_utr in history[-50:]:
            input_ids.extend(history_utr)
            input_ids.append(self.tokenizer.sep_token_id)
        input_tensor = torch.tensor(input_ids).to(self.device).unsqueeze(0)
        generated_ids = self.get_prediction(
            input_tensor, input_ids,
            repetition_penalty=1.2, temperature=0.73, top_k=10, top_p=0.7)
        history.append(generated_ids)
        generated_tokens = replace_special_tokens(
            self.convert_ids_to_tokens(generated_ids), role_card or {})
        return "".join(generated_tokens), history
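

# Minimal usage sketch. Assumption: "model/gpt2-chitchat" is a hypothetical path
# to a GPT-2 dialogue checkpoint with a BERT-style vocabulary; substitute any
# compatible checkpoint. The loop keeps the dialogue history across turns.
if __name__ == "__main__":
    bot = ChatBot.get_chat_bot("model/gpt2-chitchat")
    history = []
    while True:
        user_text = input("user > ")
        if not user_text:
            break
        reply, history = bot.chat(user_text, history)
        print("bot  >", reply)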