import transformers import streamlit as st from transformers import GPT2LMHeadModel, GPT2Tokenizer import numpy as np from PIL import Image import torch st.title(""" History Mystery """) # Добавление слайдера temperature = st.slider("Градус дичи", 1, 20, 1) max_len = st.slider(" Длина сгенерированного отрывка", 40, 120, 2) # Загрузка модели и токенизатора # model = GPT2LMHeadModel.from_pretrained('sberbank-ai/rugpt3small_based_on_gpt2') # tokenizer = GPT2Tokenizer.from_pretrained('sberbank-ai/rugpt3small_based_on_gpt2') # #Задаем класс модели (уже в streamlit/tg_bot) @st.cache def load_gpt(): model_GPT = GPT2LMHeadModel.from_pretrained( 'sberbank-ai/rugpt3small_based_on_gpt2', output_attentions = False, output_hidden_states = False, ) tokenizer_GPT = GPT2Tokenizer.from_pretrained( 'sberbank-ai/rugpt3small_based_on_gpt2', output_attentions = False, output_hidden_states = False, ) model_GPT.load_state_dict(torch.load('model_history.pt', map_location=torch.device('cpu'))) return model_GPT, tokenizer_GPT # # Вешаем сохраненные веса на нашу модель # Функция для генерации текста def generate_text(model_GPT, tokenizer_GPT, prompt): # Преобразование входной строки в токены input_ids = tokenizer_GPT.encode(prompt, return_tensors='pt') # Генерация текста output = model_GPT.generate(input_ids=input_ids, max_length=70, num_beams=5, do_sample=True, temperature=1., top_k=50, top_p=0.6, no_repeat_ngram_size=3, num_return_sequences=3) # Декодирование сгенерированного текста generated_text = tokenizer_GPT.decode(output[0], skip_special_tokens=True) return generated_text # Streamlit приложение def main(): model_GPT, tokenizer_GPT = load_gpt() st.write(""" # GPT-3 генерация текста """) # Ввод строки пользователем prompt = st.text_area("Какую фразу нужно продолжить:", value="В средние века на руси") # # Генерация текста по введенной строке # generated_text = generate_text(prompt) # Создание кнопки "Сгенерировать" generate_button = st.button("За работу!") # Обработка события нажатия кнопки if generate_button: # Вывод сгенерированного текста generated_text = generate_text(model_GPT, tokenizer_GPT, prompt) st.subheader("Продолжение:") st.write(generated_text) if __name__ == "__main__": main()