# Spaces: Runtime error
# File size: 1,697 Bytes
# Revisions: 7c0199b 9a869b4 5a50f61 1659ccb b24eb89
import transformers
import string
# Checkpoints used by this script. Position i in `model_names`, `tokenizers`
# and `model` always refers to the same checkpoint.
model_names = [
    'microsoft/GODEL-v1_1-large-seq2seq',
    'facebook/blenderbot-1B-distill',
    'facebook/blenderbot_small-90M',
]

# One tokenizer per checkpoint, each loaded through its matching class.
tokenizers = [
    tokenizer_cls.from_pretrained(checkpoint)
    for tokenizer_cls, checkpoint in zip(
        (
            transformers.AutoTokenizer,
            transformers.BlenderbotTokenizer,
            transformers.BlenderbotSmallTokenizer,
        ),
        model_names,
    )
]

# One seq2seq model per checkpoint, loaded after all tokenizers are ready.
model = [
    model_cls.from_pretrained(checkpoint)
    for model_cls, checkpoint in zip(
        (
            transformers.AutoModelForSeq2SeqLM,
            transformers.BlenderbotForConditionalGeneration,
            transformers.BlenderbotSmallForConditionalGeneration,
        ),
        model_names,
    )
]
def generate_text(text, context, model_name, model, tokenizer, minimum=15, maximum=300):
    """Generate a reply to ``text`` given the preceding dialogue ``context``.

    Args:
        text: The latest message to respond to.
        context: Earlier dialogue text. Tab characters in the combined prompt
            are rewritten to a model-specific separator (presumably tabs
            delimit dialogue turns; confirm against the caller).
        model_name: Checkpoint name. A name containing ``'GODEL'`` selects the
            instruction-style prompt; anything else gets the plain prompt.
        model: Object exposing ``generate(input_ids, ...)``.
        tokenizer: Callable returning an object with ``input_ids``; must also
            expose ``decode``.
        minimum: Forwarded to ``generate`` as ``min_new_tokens``.
        maximum: Forwarded to ``generate`` as ``max_new_tokens``.

    Returns:
        The first generated sequence, decoded without special tokens and
        passed through ``capitalization``.
    """
    if 'GODEL' in model_name:
        prompt = f'Instruction: you need to response discreetly. [CONTEXT] {context} {text}'
        # str.replace returns a new string, so the result has to be re-bound;
        # otherwise the tabs reach the tokenizer unchanged.
        prompt = prompt.replace('\t', ' EOS ')
    else:
        prompt = f'{context} {text}'.replace('\t', '\n')

    input_ids = tokenizer(prompt, return_tensors="pt").input_ids
    # Nucleus sampling (top_p=0.9), so repeated calls may give different replies.
    outputs = model.generate(
        input_ids,
        max_new_tokens=maximum,
        min_new_tokens=minimum,
        top_p=0.9,
        do_sample=True,
    )
    reply = tokenizer.decode(outputs[0], skip_special_tokens=True)
    return capitalization(reply)
def capitalization(line):
    """Capitalize sentence starts and the standalone pronoun "i" in ``line``.

    The final character is set aside and re-attached untouched, so a closing
    punctuation mark is not treated as a sentence separator. The remaining
    text is split on each of ``.``, ``?`` and ``!`` in turn; every non-blank
    segment is stripped, gets its first character upper-cased, and the
    segments are re-joined with the mark followed by one space. Finally,
    whitespace is collapsed to single spaces and any word that equals ``i``
    once punctuation is removed is capitalized.

    Args:
        line: Text to tidy up. May be empty.

    Returns:
        The re-cased text; an empty input is returned unchanged.
    """
    if not line:
        # There is no final character to split off.
        return line

    line, end = line[:-1], line[-1]
    for mark in '.?!':
        sentences = []
        for part in line.split(mark):
            part = part.strip()
            # Test the stripped text: blank segments have no first character
            # to upper-case, while one-character sentences must be kept.
            if part:
                sentences.append(part[0].upper() + part[1:])
        line = f'{mark} '.join(sentences)

    # Build the punctuation-stripping table once rather than once per word.
    strip_punctuation = str.maketrans('', '', string.punctuation)
    words = [
        word.capitalize() if word.translate(strip_punctuation) == 'i' else word
        for word in line.split()
    ]
    return ' '.join(words) + end