Spaces:
Runtime error
Runtime error
import copy | |
import torch | |
import torch.nn.functional as F | |
from transformers import GPTNeoForCausalLM, AutoTokenizer, pipeline | |
import numpy as np | |
from tqdm import trange | |
import streamlit as st | |
def set_seed(seed):
    """Seed the NumPy and PyTorch random number generators.

    Args:
        seed: Integer seed applied to ``np.random``, ``torch`` and, when
            possible, every CUDA device.
    """
    np.random.seed(seed)
    torch.manual_seed(seed)
    try:
        torch.cuda.manual_seed_all(seed)
    except (RuntimeError, AssertionError):
        # Best effort: CUDA seeding is optional, so a missing or broken CUDA
        # runtime must not stop the app. Only CUDA-related failures are
        # ignored; KeyboardInterrupt/SystemExit still propagate.
        pass
# Maps each selectable checkpoint id to the (model class, tokenizer class)
# pair that load_model() uses to instantiate it. Both checkpoints share the
# same classes, so the table is built from the list of ids.
MODEL_CLASSES = {
    checkpoint: (GPTNeoForCausalLM, AutoTokenizer)
    for checkpoint in (
        'lcw99/gpt-neo-1.3B-ko-fp16',
        'lcw99/gpt-neo-1.3B-ko',
    )
}
def load_model(model_name):
    """Load the model and tokenizer registered under *model_name*.

    Args:
        model_name: A key of ``MODEL_CLASSES``; it is also passed to
            ``from_pretrained`` as the checkpoint identifier.

    Returns:
        A ``(model, tokenizer)`` tuple, with the model switched to eval mode.

    Raises:
        KeyError: If *model_name* is not a key of ``MODEL_CLASSES``.
    """
    # NOTE: a `@st.cache` decorator was previously disabled on this function,
    # so each call loads the checkpoint again.
    model_class, tokenizer_class = MODEL_CLASSES[model_name]
    model = model_class.from_pretrained(
        model_name,
        torch_dtype=torch.float32,
        low_cpu_mem_usage=True,
        use_cache=False,
        gradient_checkpointing=False,
        # Device placement is decided here, at load time.
        device_map='auto',
    )
    tokenizer = tokenizer_class.from_pretrained(model_name)
    # No explicit `model.to(...)`: the weights were already placed through
    # `device_map='auto'`, and the previous call depended on a global `device`
    # that only exists when the file runs as a script.
    model.eval()
    return model, tokenizer
if __name__ == "__main__":
    # Streamlit entry point: collect the settings from the sidebar, read a
    # prompt, load the selected model and display one generated continuation.

    # Sidebar selectors.
    model_name = st.sidebar.selectbox("Model", list(MODEL_CLASSES.keys()))
    length = st.sidebar.slider("Length", 50, 2048, 100)
    temperature = st.sidebar.slider("Temperature", 0.0, 3.0, 0.8)
    top_k = st.sidebar.slider("Top K", 0, 10, 0)
    top_p = st.sidebar.slider("Top P", 0.0, 1.0, 0.7)

    st.title("Text generation with GPT-neo Korean")
    # The prompt/placeholder strings were mis-encoded (UTF-8 shown as a
    # single-byte code page); they are restored to the intended Korean text.
    raw_text = st.text_input(
        "시작하는 문장을 입력하고 엔터를 치세요.",
        placeholder="골프를 잘 치고 싶다면,",
        key="text_input1",
    )

    if raw_text:
        st.write(raw_text)

        with st.spinner(f'loading model({model_name}) wait...'):
            # Kept as a module-level name on purpose: other code in this file
            # may read `device` as a global.
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
            model, tokenizer = load_model(model_name)

        with st.spinner('Generating text wait...'):
            encoded_input = tokenizer(raw_text, return_tensors='pt')

            # Sampling requires a strictly positive temperature; a slider
            # value of 0.0 is treated as a request for greedy decoding.
            use_sampling = temperature > 0.0
            generation_kwargs = {
                'max_length': length,
                'min_length': 20,
                'do_sample': use_sampling,
            }
            if use_sampling:
                # Forward every sampling control, including the temperature
                # slider, which was previously read but never used.
                generation_kwargs.update(
                    temperature=temperature,
                    top_p=top_p,
                    top_k=top_k,
                )

            output_sequences = model.generate(
                input_ids=encoded_input['input_ids'].to(device),
                attention_mask=encoded_input['attention_mask'].to(device),
                **generation_kwargs,
            )
            generated = tokenizer.decode(output_sequences[0], skip_special_tokens=True)
            st.write(generated)