File size: 3,093 Bytes
418d903
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
import os
import warnings
import requests
import torch
import streamlit as st
from streamlit_lottie import st_lottie
from transformers import AutoTokenizer, AutoModelWithLMHead

# Suppress every warning raised from this point on (no category filter is
# given, so all Warning subclasses are ignored).
warnings.filterwarnings("ignore")

# NOTE(review): Streamlit documents set_page_config as needing to be the first
# st.* command of the script — keep it above all other Streamlit calls.
st.set_page_config(layout='centered', page_title='GPT2-Horoscopes')

def load_lottieurl(url: str):
    """Download a Lottie animation and return its parsed JSON.

    Adapted from
    https://github.com/tylerjrichards/streamlit_goodreads_app/blob/master/books.py

    Args:
        url: Address of the Lottie JSON document.

    Returns:
        The decoded JSON payload, or ``None`` when the animation cannot be
        obtained (network failure, timeout, non-200 status, invalid JSON).
    """
    try:
        # Bounded wait: without a timeout a stalled server would block the
        # whole script run indefinitely.
        r = requests.get(url, timeout=10)
    except requests.RequestException:
        return None
    if r.status_code != 200:
        return None
    try:
        return r.json()
    except ValueError:
        # The JSON decode errors raised by Response.json() derive from
        # ValueError; treat an unparsable body like any other failed fetch.
        return None

# Decorative header animation. load_lottieurl returns None when the download
# fails, so only render the widget when there is actually something to show
# instead of handing None to st_lottie.
lottie_book = load_lottieurl('https://assets2.lottiefiles.com/packages/lf20_WL3aE7.json')
if lottie_book is not None:
    st_lottie(lottie_book, speed=1, height=200, key="initial")

# Page title and introductory text.
st.markdown('# GPT2-Horoscopes!')
st.markdown("""
Hello! This lovely app lets GPT-2 write awesome horoscopes for you. All you need to do
is select your sign and choose the horoscope category :)  
""")
st.markdown("""
*If you are interested in the fine-tuned model, you can visit the [Model Hub](https://huggingface.co/shahp7575/gpt2-horoscopes) or 
my [GitHub Repo](https://github.com/shahp7575/gpt2-horoscopes).*
""")


@st.cache(allow_output_mutation=True, max_entries=1)
def download_model():
    """Load the fine-tuned horoscope checkpoint together with its tokenizer.

    Both objects come from the same ``shahp7575/gpt2-horoscopes`` identifier.
    The function is wrapped in ``st.cache`` with ``max_entries=1``.

    Returns:
        A ``(model, tokenizer)`` tuple — the model is the first element.
    """
    checkpoint = 'shahp7575/gpt2-horoscopes'
    tok = AutoTokenizer.from_pretrained(checkpoint)
    lm = AutoModelWithLMHead.from_pretrained(checkpoint)
    return lm, tok
# Module-level handles used by the generation code below; download_model is
# wrapped in st.cache, and it returns the model first, then the tokenizer.
model, tokenizer = download_model()

def make_prompt(category):
    """Build the generation prompt for a horoscope category.

    The category is placed between the ``<|category|>`` and ``<|horoscope|>``
    markers, separated from each by a single space.
    """
    template = "<|category|> {} <|horoscope|>"
    return template.format(category)

def generate(prompt, model, tokenizer, temperature, num_outputs, top_k,
             *, max_length=300, top_p=0.95):
    """Sample continuations of ``prompt`` from ``model``.

    Args:
        prompt: Encoded prompt, forwarded unchanged as the first positional
            argument of ``model.generate``.
        model: Object exposing a ``generate`` method that accepts the
            sampling keyword arguments used below.
        tokenizer: Not used by this function; kept so existing callers that
            pass it keep working.
        temperature: Sampling temperature.
        num_outputs: Number of sequences to request
            (``num_return_sequences``).
        top_k: Top-k sampling cutoff.
        max_length: Upper bound on the generated sequence length.
            Defaults to the previously hard-coded 300.
        top_p: Nucleus-sampling probability mass. Defaults to the previously
            hard-coded 0.95.

    Returns:
        Whatever ``model.generate`` returns for these settings.
    """
    return model.generate(
        prompt,
        do_sample=True,
        top_k=top_k,
        max_length=max_length,
        top_p=top_p,
        temperature=temperature,
        num_return_sequences=num_outputs,
    )
    
# Choices offered by the two dropdowns below.
ZODIAC_SIGNS = (
    'Aquarius', 'Pisces', 'Aries', 'Taurus', 'Gemini', 'Cancer',
    'Leo', 'Virgo', 'Libra', 'Scorpio', 'Sagittarius', 'Capricorn',
)
HOROSCOPE_CATEGORIES = ('general', 'career', 'love', 'wellness', 'birthday')

# --- Input widgets -----------------------------------------------------------
with st.beta_container():
    # NOTE(review): the selected sign is stored in `horoscope`, but nothing in
    # this section passes it on — only the category reaches the prompt.
    horoscope = st.selectbox("Choose Your Sign: ", ZODIAC_SIGNS, index=0)
    choice = st.selectbox("Choose Category:", HOROSCOPE_CATEGORIES, index=0)
    temp_slider = st.slider(
        "Temperature (Higher Value = More randomness)",
        min_value=0.01,
        max_value=1.0,
        value=0.95,
    )

# --- Generation --------------------------------------------------------------
if st.button('Generate Horoscopes!'):
    prompt = make_prompt(choice)
    token_ids = tokenizer.encode(prompt)
    # unsqueeze(0) adds a leading dimension of size 1 in front of the ids.
    prompt_encoded = torch.tensor(token_ids).unsqueeze(0)
    with st.spinner('Generating...'):
        sample_output = generate(
            prompt_encoded,
            model,
            tokenizer,
            temperature=temp_slider,
            num_outputs=1,
            top_k=40,
        )
        final_out = tokenizer.decode(sample_output[0], skip_special_tokens=True)
        # Drop the first len(choice) + 2 characters before display —
        # presumably the echoed category text plus surrounding spaces left
        # once special tokens are skipped; verify against real output.
        prefix_len = len(choice) + 2
        st.write(final_out[prefix_len:])