
from tokenize import tokenize
import torch
import torch.nn.functional as F

from transformers import BertTokenizer, GPT2LMHeadModel

def replace_special_tokens(tokens:list, special_token_dict:dict)->list:
    """Map each token through *special_token_dict*, leaving unmapped tokens unchanged.

    Args:
        tokens: sequence of token strings to translate.
        special_token_dict: mapping from a special-token string to its
            replacement text (e.g. a role-card value).

    Returns:
        A new list, same length as *tokens*, with every key of
        *special_token_dict* replaced by its mapped value.
    """
    # dict.get with the token itself as the default replaces the
    # `if token in d.keys(): d[token] else: token` double lookup.
    return [special_token_dict.get(token, token) for token in tokens]

def top_k_top_p_filtering(logits, top_k:int=0, top_p:float=0.0, filter_value:float=-float('Inf')):
    """Filter a logits distribution using top-k and/or nucleus (top-p) filtering.

    Args:
        logits: logits tensor whose LAST dimension is the vocabulary.
            NOTE: the tensor is modified in place (and also returned).
        top_k: keep only the ``top_k`` highest-scoring tokens (0 disables).
        top_p: keep the smallest set of tokens whose cumulative probability
            exceeds ``top_p`` (0.0 disables).
        filter_value: value written into removed positions.

    Returns:
        The same tensor with filtered positions set to ``filter_value``.
    """
    # Cannot keep more tokens than the vocabulary holds.
    top_k = min( top_k, logits.size(-1) )

    if top_k > 0:
        # Remove every logit strictly below the k-th largest one.
        kth_value = torch.topk(logits, top_k)[0][..., -1, None]
        logits[logits < kth_value] = filter_value

    if top_p > 0.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

        # Mark tokens past the nucleus, then shift the mask right by one so
        # the first token that crosses the threshold is itself kept.
        sorted_indices_to_remove = cumulative_probs > top_p
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = 0

        # Scatter the mask back to the original vocabulary order. Unlike the
        # previous `sorted_indices[mask]` fancy indexing, this is correct for
        # batched (multi-dimensional) logits as well as the 1-D case.
        indices_to_remove = torch.zeros_like(logits, dtype=torch.bool).scatter_(
            -1, sorted_indices, sorted_indices_to_remove
        )
        logits[indices_to_remove] = filter_value

    return logits

class ChatBot():
    """Dialogue wrapper pairing a BertTokenizer with a GPT2LMHeadModel.

    Conversation history is a list of token-id lists; each call to
    :meth:`chat` appends the user utterance and the generated reply.
    """

    @staticmethod
    def get_chat_bot(checkpoint:str, vocab_path:str = None, special_token_list:list = None)->object:
        """Build a ChatBot from a pretrained checkpoint directory/name.

        Args:
            checkpoint: model name or path for ``from_pretrained``.
            vocab_path: optional explicit vocab file for the tokenizer.
            special_token_list: extra special tokens to register (default none).

        Returns:
            A ready-to-use ChatBot instance.
        """
        tokenizer = ChatBot.get_tokenizer(checkpoint, vocab_path, special_token_list)
        model = GPT2LMHeadModel.from_pretrained(checkpoint)

        return ChatBot(tokenizer, model)

    @staticmethod
    def get_tokenizer(checkpoint:str, vocab_path:str = None, special_token_list:list = None)->object:
        """Load a BertTokenizer and register any additional special tokens."""
        # `None` sentinel instead of a mutable `[]` default (shared across calls).
        if special_token_list is None:
            special_token_list = []
        if vocab_path is None:
            tokenizer = BertTokenizer.from_pretrained(checkpoint, sep_token="[SEP]", pad_token="[PAD]", cls_token="[CLS]")
        else:
            tokenizer = BertTokenizer(vocab_path, sep_token="[SEP]", pad_token="[PAD]", cls_token="[CLS]")
        tokenizer.add_special_tokens( {'additional_special_tokens':special_token_list} )

        return tokenizer

    def __init__(self, tokenizer:object, model:object)->None:
        """Store tokenizer/model and move the model to GPU when available."""
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.tokenizer = tokenizer
        self.model = model
        self.model.to(self.device)

    def convert_ids_to_tokens(self, ids:list):
        """Map token ids back to token strings via the tokenizer."""
        return self.tokenizer.convert_ids_to_tokens(ids)

    def convert_ids_to_text(self, ids):
        """Join decoded tokens into a single string (no separators)."""
        return "".join( self.convert_ids_to_tokens(ids) )

    def convert_text_to_ids(self, text, add_special_tokens=False):
        """Encode *text* into token ids."""
        return self.tokenizer.encode(text, add_special_tokens=add_special_tokens)

    def get_prediction(self, input_tensor, input_ids, repetition_penalty, temperature, top_k, top_p):
        """Sample a reply token-by-token (up to 64 tokens).

        Args:
            input_tensor: (1, seq_len) tensor of context token ids on self.device.
            input_ids: the same context ids as a flat list (for repetition penalty).
            repetition_penalty: divisor applied to logits of already-seen tokens.
            temperature: softmax temperature.
            top_k / top_p: sampling filter parameters.

        Returns:
            List of generated token ids (without the terminating [SEP]).
        """
        self.model.eval()

        generated_ids = []
        # no_grad: pure inference — avoids accumulating an autograd graph
        # across the (growing) generation loop.
        with torch.no_grad():
            for _ in range(64):
                output_pt = self.model(input_tensor)

                next_token_logits = output_pt.logits[0, -1, :]
                # Penalise tokens already present in the context or already
                # generated, so the bot is less likely to repeat itself.
                for token_id in set(input_ids):
                    if token_id != self.tokenizer.sep_token_id:
                        next_token_logits[token_id] /= repetition_penalty
                for token_id in set(generated_ids):
                    next_token_logits[token_id] /= repetition_penalty
                next_token_logits = next_token_logits / temperature
                # Never sample the unknown token.
                next_token_logits[self.tokenizer.convert_tokens_to_ids('[UNK]')] = -float('Inf')
                filtered_logits = top_k_top_p_filtering(next_token_logits, top_k=top_k, top_p=top_p)
                next_token = torch.multinomial( F.softmax(filtered_logits, dim=-1), num_samples=1 )

                # [SEP] marks end-of-utterance.
                if next_token == self.tokenizer.sep_token_id:
                    break

                input_tensor = torch.cat( (input_tensor, next_token.unsqueeze(0)), dim=1 )
                generated_ids.append( next_token.item() )

        return generated_ids

    def chat(self:object, text:str, history:list, role_card:dict=None) -> tuple:
        """Generate a reply to *text* given the running *history*.

        Args:
            text: the user's new utterance.
            history: list of previous utterances as token-id lists;
                mutated in place (user text and reply are appended).
            role_card: optional mapping used to substitute special tokens
                in the generated reply with display text.

        Returns:
            ``(reply_text, history)`` — NOTE: a tuple, the original
            ``-> str`` annotation was wrong.
        """
        # `None` sentinel instead of a mutable `{}` default.
        if role_card is None:
            role_card = {}
        text_ids = self.tokenizer.encode(text, add_special_tokens=False)
        history.append(text_ids)
        # Context layout: [CLS] utt1 [SEP] utt2 [SEP] ... (last 50 utterances).
        input_ids = [self.tokenizer.cls_token_id]
        for history_utr in history[-50:]:
            input_ids.extend(history_utr)
            input_ids.append(self.tokenizer.sep_token_id)
        input_tensor = torch.tensor(input_ids).to(self.device).unsqueeze(0)
        generated_ids = self.get_prediction(input_tensor, input_ids, repetition_penalty=1.2, temperature=0.73, top_k=10, top_p=0.7)

        history.append(generated_ids)

        generated_tokens = replace_special_tokens( self.convert_ids_to_tokens(generated_ids), role_card )

        return "".join(generated_tokens), history