File size: 2,928 Bytes
507c1f1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
import transformers
import streamlit as st
from transformers import GPT2LMHeadModel, GPT2Tokenizer
import numpy as np
from PIL import Image
import torch

st.title("""
 History Mystery
 """)
# Добавление слайдера
temperature = st.slider("Градус дичи", 1, 20, 1)
max_len = st.slider(" Длина сгенерированного отрывка", 40, 120, 2)
# Загрузка модели и токенизатора
# model = GPT2LMHeadModel.from_pretrained('sberbank-ai/rugpt3small_based_on_gpt2')
# tokenizer = GPT2Tokenizer.from_pretrained('sberbank-ai/rugpt3small_based_on_gpt2')
# #Задаем класс модели (уже в streamlit/tg_bot)

@st.cache
def load_gpt():
    model_GPT = GPT2LMHeadModel.from_pretrained(
     'sberbank-ai/rugpt3small_based_on_gpt2',
     output_attentions = False,
     output_hidden_states = False,
    )
    tokenizer_GPT = GPT2Tokenizer.from_pretrained(
        'sberbank-ai/rugpt3small_based_on_gpt2',
        output_attentions = False,
        output_hidden_states = False,
        )
    model_GPT.load_state_dict(torch.load('model_history.pt', map_location=torch.device('cpu')))
    return model_GPT, tokenizer_GPT

# # Вешаем сохраненные веса на нашу модель

# Функция для генерации текста
def generate_text(model_GPT, tokenizer_GPT, prompt):
    # Преобразование входной строки в токены
    input_ids = tokenizer_GPT.encode(prompt, return_tensors='pt')

    # Генерация текста
    output = model_GPT.generate(input_ids=input_ids, max_length=70, num_beams=5, do_sample=True,
                            temperature=1., top_k=50, top_p=0.6, no_repeat_ngram_size=3,
                            num_return_sequences=3)

    # Декодирование сгенерированного текста
    generated_text = tokenizer_GPT.decode(output[0], skip_special_tokens=True)

    return generated_text

# Streamlit приложение
def main():
    model_GPT, tokenizer_GPT = load_gpt()
    st.write("""
    # GPT-3 генерация текста
    """)

    # Ввод строки пользователем
    prompt = st.text_area("Какую фразу нужно продолжить:", value="В средние века на руси")

    # # Генерация текста по введенной строке
    # generated_text = generate_text(prompt)
    # Создание кнопки "Сгенерировать"
    generate_button = st.button("За работу!")
    # Обработка события нажатия кнопки
    if generate_button:
    # Вывод сгенерированного текста
        generated_text = generate_text(model_GPT, tokenizer_GPT, prompt)
        st.subheader("Продолжение:")
        st.write(generated_text)



if __name__ == "__main__":
    main()