|
import transformers |
|
import streamlit as st |
|
from transformers import GPT2LMHeadModel, GPT2Tokenizer |
|
import numpy as np |
|
from PIL import Image |
|
import torch |
|
|
|
# Page header.
st.title("""
History Mystery
""")

# Generation controls. Positional slider arguments are
# (label, min_value, max_value, value).
temperature = st.slider("Градус дичи", 1, 20, 1)
# The default has to sit inside the 40-120 range; 70 matches the max_length
# that generate_text uses.
max_len = st.slider(" Длина сгенерированного отрывка", 40, 120, 70)
|
|
|
|
|
|
|
|
|
|
|
# st.cache_resource is Streamlit's cache for shared objects such as models.
# Releases that predate it only offer the legacy st.cache, where
# allow_output_mutation=True stops Streamlit from hashing the returned model
# to look for mutations.
_cache_model = getattr(st, "cache_resource", None)
if _cache_model is None:
    _cache_model = st.cache(allow_output_mutation=True)


@_cache_model
def load_gpt():
    """Load the GPT-2 model and tokenizer used for generation.

    The pretrained ``sberbank-ai/rugpt3small_based_on_gpt2`` checkpoint is
    loaded first, then the weights stored in ``model_history.pt`` are loaded
    into it on the CPU. The result is cached by the decorator, so the load is
    not repeated on every call.

    Returns:
        tuple: ``(model_GPT, tokenizer_GPT)``.
    """
    model_name = 'sberbank-ai/rugpt3small_based_on_gpt2'

    model_GPT = GPT2LMHeadModel.from_pretrained(
        model_name,
        output_attentions=False,
        output_hidden_states=False,
    )
    # output_attentions / output_hidden_states are model options, so they are
    # not passed to the tokenizer.
    tokenizer_GPT = GPT2Tokenizer.from_pretrained(model_name)

    # NOTE: torch.load unpickles the file; model_history.pt must come from a
    # trusted source. The relative path is resolved against the working
    # directory the app is started from.
    state_dict = torch.load('model_history.pt', map_location=torch.device('cpu'))
    model_GPT.load_state_dict(state_dict)
    return model_GPT, tokenizer_GPT
|
|
|
|
|
|
|
|
|
def generate_text(model_GPT, tokenizer_GPT, prompt, temperature=1.0, max_length=70):
    """Continue ``prompt`` with the language model and return the decoded text.

    Args:
        model_GPT: Model object exposing ``generate(**kwargs)``.
        tokenizer_GPT: Tokenizer object exposing ``encode`` and ``decode``.
        prompt: Text to continue.
        temperature: Sampling temperature. Converted to ``float`` before it is
            passed to ``generate`` so integer values are accepted.
        max_length: Value passed to ``generate`` as ``max_length``; it bounds
            the whole sequence, prompt tokens included.

    Returns:
        str: The first generated sequence decoded with special tokens
        skipped. Returns ``""`` for an empty prompt, and ``prompt`` unchanged
        when the encoded prompt already holds ``max_length`` tokens or more.
    """
    if not prompt:
        # An empty prompt gives the model nothing to continue.
        return ""

    input_ids = tokenizer_GPT.encode(prompt, return_tensors='pt')

    if input_ids.shape[-1] >= max_length:
        # The prompt already fills the length budget, so no token can be added.
        return prompt

    output = model_GPT.generate(
        input_ids=input_ids,
        max_length=max_length,
        num_beams=5,
        do_sample=True,
        temperature=float(temperature),
        top_k=50,
        top_p=0.6,
        no_repeat_ngram_size=3,
        # Only the first sequence is decoded below, so request just one.
        num_return_sequences=1,
    )

    return tokenizer_GPT.decode(output[0], skip_special_tokens=True)


def main():
    """Render the generation form and show the continuation on demand.

    Loads the model and tokenizer through ``load_gpt``, asks for a prompt, and
    when the button is pressed writes the text produced by ``generate_text``
    using the module-level ``temperature`` and ``max_len`` slider values.
    """
    model_GPT, tokenizer_GPT = load_gpt()
    st.write("""
# GPT-3 генерация текста
""")

    prompt = st.text_area("Какую фразу нужно продолжить:", value="В средние века на руси")

    if st.button("За работу!"):
        generated_text = generate_text(
            model_GPT,
            tokenizer_GPT,
            prompt,
            temperature=temperature,
            max_length=max_len,
        )
        st.subheader("Продолжение:")
        st.write(generated_text)


# Build the page only when this file is executed as the main script.
if __name__ == "__main__":
    main()
|
|
|
|