# NOTE: removed non-Python extraction residue (file-size banner, commit hashes,
# and bare line numbers) that preceded the module and made it unparseable.
import transformers
import string

# Chatbot checkpoints used by this module, in fixed order: GODEL first,
# then the large and small Blenderbot variants.
model_names = ['microsoft/GODEL-v1_1-large-seq2seq',
               'facebook/blenderbot-1B-distill',
               'facebook/blenderbot_small-90M']

# Tokenizer class for each checkpoint, positionally aligned with model_names.
tokenizers = [
    tokenizer_cls.from_pretrained(name)
    for tokenizer_cls, name in zip(
        (transformers.AutoTokenizer,
         transformers.BlenderbotTokenizer,
         transformers.BlenderbotSmallTokenizer),
        model_names,
    )
]

# Seq2seq model for each checkpoint, positionally aligned with model_names.
model = [
    model_cls.from_pretrained(name)
    for model_cls, name in zip(
        (transformers.AutoModelForSeq2SeqLM,
         transformers.BlenderbotForConditionalGeneration,
         transformers.BlenderbotSmallForConditionalGeneration),
        model_names,
    )
]


def generate_text(text, context, model_name, model, tokenizer, minimum=15, maximum=300):
    """Generate a chatbot reply for `text` given the conversation `context`.

    Args:
        text: the latest user utterance.
        context: prior conversation turns, separated by tab characters.
        model_name: checkpoint name; 'GODEL' in the name selects GODEL-style
            prompting, anything else gets Blenderbot-style prompting.
        model: a loaded seq2seq model with a `.generate()` method.
        tokenizer: the tokenizer matching `model`.
        minimum: minimum number of newly generated tokens.
        maximum: maximum number of newly generated tokens.

    Returns:
        The decoded, capitalization-normalized model reply (str).
    """
    if 'GODEL' in model_name:
        # GODEL expects an instruction prefix and turns separated by ' EOS '.
        text = f'Instruction: you need to response discreetly. [CONTEXT] {context} {text}'
        # BUG FIX: str.replace returns a new string; the original call
        # discarded the result, so tabs were never converted to ' EOS '.
        text = text.replace('\t', ' EOS ')
    else:
        # Blenderbot-style models take newline-separated turns.
        text = f'{context} {text}'
        text = text.replace('\t', '\n')
    input_ids = tokenizer(text, return_tensors="pt").input_ids
    # Nucleus sampling (top_p) with bounded reply length.
    outputs = model.generate(input_ids, max_new_tokens=maximum,
                             min_new_tokens=minimum, top_p=0.9, do_sample=True)
    output = tokenizer.decode(outputs[0], skip_special_tokens=True)
    return capitalization(output)


def capitalization(line):
    """Normalize casing: capitalize sentence starts and the pronoun 'i'.

    Args:
        line: text to normalize; may be empty.

    Returns:
        The text with every sentence (split on '.', '?', '!') starting with
        an upper-case letter, standalone 'i' upper-cased, runs of whitespace
        collapsed to single spaces, and the original final character kept.
        An empty string is returned unchanged.
    """
    if not line:
        # BUG FIX: the original indexed line[-1] and raised IndexError on ''.
        return line
    # Set the final character aside so sentence splitting below does not
    # yield a trailing empty fragment for the closing punctuation mark.
    line, end = line[:-1], line[-1]
    for mark in '.?!':
        # Rebuild the text with a normalized '<mark> ' separator, upper-casing
        # each fragment's first letter.  Fragments of one character or less are
        # dropped as before; BUG FIX: fragments that strip to '' (whitespace
        # only) used to raise IndexError on [0] and are now dropped too.
        fragments = [part.strip() for part in line.split(mark) if len(part) > 1]
        line = f'{mark} '.join(frag[0].upper() + frag[1:] for frag in fragments if frag)
    # Upper-case a standalone 'i', ignoring any punctuation attached to it.
    line = ' '.join(word.capitalize()
                    if word.translate(str.maketrans('', '', string.punctuation)) == 'i'
                    else word
                    for word in line.split())
    return line + end