import torch
import torch.nn.functional as F
from transformers import BertTokenizer, GPT2LMHeadModel
def replace_special_tokens(tokens: list, special_token_dict: dict) -> list:
    """Replace any token found in special_token_dict with its mapped value; keep the rest."""
    replaced_tokens: list = []
    for token in tokens:
        if token in special_token_dict:
            replaced_tokens.append(special_token_dict[token])
        else:
            replaced_tokens.append(token)
    return replaced_tokens
def top_k_top_p_filtering(logits, top_k: int = 0, top_p: float = 0.0, filter_value: float = -float('Inf')):
    """Filter a 1D logits tensor with top-k and/or nucleus (top-p) filtering."""
    top_k = min(top_k, logits.size(-1))  # safety: top_k cannot exceed the vocabulary size
    if top_k > 0:
        # Mask every token whose logit is below the k-th largest logit
        indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
        logits[indices_to_remove] = filter_value
    if top_p > 0.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
        # Mask tokens whose cumulative probability exceeds the nucleus threshold
        sorted_indices_to_remove = cumulative_probs > top_p
        # Shift the mask right so the first token above the threshold is still kept
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = 0
        indices_to_remove = sorted_indices[sorted_indices_to_remove]
        logits[indices_to_remove] = filter_value
    return logits
class ChatBot:

    @staticmethod
    def get_chat_bot(checkpoint: str, vocab_path: str = None, special_token_list: list = []) -> object:
        """Build a ChatBot from a pretrained GPT-2 checkpoint."""
        tokenizer = ChatBot.get_tokenizer(checkpoint, vocab_path, special_token_list)
        model = GPT2LMHeadModel.from_pretrained(checkpoint)
        return ChatBot(tokenizer, model)

    @staticmethod
    def get_tokenizer(checkpoint: str, vocab_path: str = None, special_token_list: list = []) -> object:
        """Load a BertTokenizer from the checkpoint, or from a local vocab file if one is given."""
        if vocab_path is None:
            tokenizer = BertTokenizer.from_pretrained(checkpoint, sep_token="[SEP]", pad_token="[PAD]", cls_token="[CLS]")
        else:
            tokenizer = BertTokenizer(vocab_path, sep_token="[SEP]", pad_token="[PAD]", cls_token="[CLS]")
        tokenizer.add_special_tokens({'additional_special_tokens': special_token_list})
        return tokenizer
    def __init__(self, tokenizer: object, model: object) -> None:
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.tokenizer = tokenizer
        self.model = model
        self.model.to(self.device)

    def convert_ids_to_tokens(self, ids: list):
        return self.tokenizer.convert_ids_to_tokens(ids)

    def convert_ids_to_text(self, ids):
        return "".join(self.convert_ids_to_tokens(ids))

    def convert_text_to_ids(self, text, add_special_tokens=False):
        return self.tokenizer.encode(text, add_special_tokens=add_special_tokens)
    def get_prediction(self, input_tensor, input_ids, repetition_penalty, temperature, top_k, top_p):
        """Autoregressively sample up to 64 response tokens, stopping at [SEP]."""
        self.model.eval()
        generated_ids = []
        with torch.no_grad():
            for _ in range(64):
                output_pt = self.model(input_tensor)
                next_token_logits = output_pt.logits[0, -1, :]
                # Penalize tokens that already appear in the context or in the response so far
                for id in set(input_ids):
                    if id != self.tokenizer.sep_token_id:
                        next_token_logits[id] /= repetition_penalty
                for id in set(generated_ids):
                    next_token_logits[id] /= repetition_penalty
                next_token_logits = next_token_logits / temperature
                # Never sample the unknown token
                next_token_logits[self.tokenizer.convert_tokens_to_ids('[UNK]')] = -float('Inf')
                filtered_logits = top_k_top_p_filtering(next_token_logits, top_k=top_k, top_p=top_p)
                next_token = torch.multinomial(F.softmax(filtered_logits, dim=-1), num_samples=1)
                if next_token.item() == self.tokenizer.sep_token_id:
                    break
                input_tensor = torch.cat((input_tensor, next_token.unsqueeze(0)), dim=1)
                generated_ids.append(next_token.item())
        return generated_ids
    def chat(self, text: str, history: list, role_card: dict = {}) -> tuple:
        """Generate a reply to `text` given the conversation `history`; returns (reply, history)."""
        text_ids = self.tokenizer.encode(text, add_special_tokens=False)
        history.append(text_ids)
        # Build the model input as [CLS] utterance [SEP] utterance [SEP] ... over the last 50 turns
        input_ids = [self.tokenizer.cls_token_id]
        for history_utr in history[-50:]:
            input_ids.extend(history_utr)
            input_ids.append(self.tokenizer.sep_token_id)
        input_tensor = torch.tensor(input_ids).to(self.device).unsqueeze(0)
        generated_ids = self.get_prediction(input_tensor, input_ids, repetition_penalty=1.2, temperature=0.73, top_k=10, top_p=0.7)
        history.append(generated_ids)
        generated_tokens = replace_special_tokens(self.convert_ids_to_tokens(generated_ids), role_card)
        return "".join(generated_tokens), history