Spaces:
Running
Running
import os | |
import warnings | |
import requests | |
import torch | |
import streamlit as st | |
from streamlit_lottie import st_lottie | |
from transformers import AutoTokenizer, AutoModelWithLMHead | |
# Ignore all Python warnings for the lifetime of the app process.
warnings.filterwarnings(action="ignore")

# Page-wide Streamlit settings: centered layout and the page title.
st.set_page_config(page_title='GPT2-Horoscopes', layout='centered')
def load_lottieurl(url: str, timeout: float = 10.0):
    """Download a Lottie animation file and return its parsed JSON content.

    Adapted from
    https://github.com/tylerjrichards/streamlit_goodreads_app/blob/master/books.py

    Args:
        url: Address of the Lottie JSON file.
        timeout: Seconds to wait for the server before giving up. Without a
            timeout ``requests.get`` may block indefinitely and hang the app.

    Returns:
        The decoded JSON payload, or ``None`` when the request fails, the
        server does not answer with HTTP 200, or the body is not valid JSON.
    """
    try:
        response = requests.get(url, timeout=timeout)
    except requests.RequestException:
        # Network problems are treated like a non-200 answer: no animation.
        return None
    if response.status_code != 200:
        return None
    try:
        return response.json()
    except ValueError:
        # The server answered 200 but the body could not be parsed as JSON.
        return None
# Header animation. load_lottieurl returns None when the file cannot be
# fetched, so the animation is only rendered when there is data to show;
# the rest of the page still loads without it.
lottie_book = load_lottieurl('https://assets2.lottiefiles.com/packages/lf20_WL3aE7.json')
if lottie_book is not None:
    st_lottie(lottie_book, speed=1, height=200, key="initial")
# Title, introduction and project links, rendered top to bottom in this order.
for section in (
    '# GPT2-Horoscopes!',
    """
Hello! This lovely app lets GPT-2 write awesome horoscopes for you. All you need to do
is select your sign and choose the horoscope category :)
""",
    """
*If you are interested in the fine-tuned model, you can visit the [Model Hub](https://huggingface.co/shahp7575/gpt2-horoscopes) or
my [GitHub Repo](https://github.com/shahp7575/gpt2-horoscopes).*
""",
):
    st.markdown(section)
@st.cache(allow_output_mutation=True)
def download_model():
    """Load the fine-tuned GPT-2 horoscope model and its tokenizer.

    Streamlit re-executes the whole script on every widget interaction, so
    the result is cached: without the decorator the checkpoint would be
    loaded again on each button click. ``allow_output_mutation=True`` tells
    Streamlit not to hash the returned objects to detect mutation.

    Returns:
        A ``(model, tokenizer)`` tuple loaded from the
        ``shahp7575/gpt2-horoscopes`` checkpoint.
    """
    checkpoint = 'shahp7575/gpt2-horoscopes'
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    # NOTE(review): newer transformers releases deprecate AutoModelWithLMHead
    # in favour of AutoModelForCausalLM; switch when the imports are updated.
    model = AutoModelWithLMHead.from_pretrained(checkpoint)
    return model, tokenizer


# Module-level handles used by the generation code further down the script.
model, tokenizer = download_model()
def make_prompt(category):
    """Build the generation prompt for *category*.

    The category is placed between the ``<|category|>`` and
    ``<|horoscope|>`` markers, separated from each by a single space.
    """
    template = "<|category|> {} <|horoscope|>"
    return template.format(category)
def generate(prompt, model, tokenizer, temperature, num_outputs, top_k,
             max_length=300, top_p=0.95):
    """Sample text continuations for *prompt* from *model*.

    Args:
        prompt: Passed unchanged as the first positional argument of
            ``model.generate``; the caller in this file supplies a tensor of
            encoded prompt token ids with a leading batch dimension.
        model: Object exposing a ``generate`` method.
        tokenizer: Unused here; kept so existing call sites keep working.
        temperature: Sampling temperature forwarded to ``model.generate``.
        num_outputs: Forwarded as ``num_return_sequences``.
        top_k: Forwarded as ``top_k``.
        max_length: Forwarded as ``max_length``. Defaults to the previously
            hard-coded value of 300.
        top_p: Forwarded as ``top_p``. Defaults to the previously
            hard-coded value of 0.95.

    Returns:
        Whatever ``model.generate`` returns for these settings.
    """
    sampling_options = {
        "do_sample": True,
        "top_k": top_k,
        "max_length": max_length,
        "top_p": top_p,
        "temperature": temperature,
        "num_return_sequences": num_outputs,
    }
    return model.generate(prompt, **sampling_options)
# Input widgets, grouped in a single container: zodiac sign, horoscope
# category and sampling temperature.
zodiac_signs = (
    'Aquarius', 'Pisces', 'Aries',
    'Taurus', 'Gemini', 'Cancer',
    'Leo', 'Virgo', 'Libra',
    'Scorpio', 'Sagittarius', 'Capricorn',
)
categories = ('general', 'career', 'love', 'wellness', 'birthday')

with st.beta_container():
    # NOTE(review): the selected sign is not read anywhere else in the
    # visible code; only the category ends up in the prompt.
    horoscope = st.selectbox("Choose Your Sign: ", zodiac_signs, index=0)
    choice = st.selectbox("Choose Category:", categories, index=0)
    temp_slider = st.slider(
        "Temperature (Higher Value = More randomness)",
        min_value=0.01,
        max_value=1.0,
        value=0.95,
    )
# Generate and display one horoscope each time the button is pressed;
# nothing happens on reruns where the button was not clicked.
if st.button('Generate Horoscopes!'):
    prompt = make_prompt(choice)
    token_ids = tokenizer.encode(prompt)
    # unsqueeze(0) adds a leading dimension to the encoded prompt.
    prompt_encoded = torch.tensor(token_ids).unsqueeze(0)
    with st.spinner('Generating...'):
        sample_output = generate(
            prompt_encoded,
            model,
            tokenizer,
            temperature=temp_slider,
            num_outputs=1,
            top_k=40,
        )
        final_out = tokenizer.decode(sample_output[0], skip_special_tokens=True)
    # Drop the first len(choice) + 2 characters of the decoded text --
    # presumably the echoed category and its surrounding spaces; verify
    # against the tokenizer's actual output.
    prefix_length = len(choice) + 2
    st.write(final_out[prefix_length:])