Spaces:
Runtime error
Runtime error
ADD: Caching of model on streamlit
Browse files
app.py
CHANGED
@@ -15,10 +15,10 @@ def main():
|
|
15 |
st.title('Emoji-motion!')
|
16 |
|
17 |
example_prompts = [
|
|
|
18 |
"Today is going to be awesome!",
|
19 |
-
"
|
20 |
-
"
|
21 |
-
"Nature is so beautiful"]
|
22 |
|
23 |
example = st.selectbox("Choose a pre-defined example", example_prompts)
|
24 |
|
@@ -31,9 +31,6 @@ def main():
|
|
31 |
BASE_MODEL = st.selectbox("Choose a model", models_to_choose)
|
32 |
TOP_N = 5
|
33 |
|
34 |
-
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
|
35 |
-
model = AutoModelForSequenceClassification.from_pretrained(BASE_MODEL)
|
36 |
-
|
37 |
def preprocess(text):
|
38 |
new_text = []
|
39 |
for t in text.split(" "):
|
@@ -42,17 +39,20 @@ def main():
|
|
42 |
new_text.append(t)
|
43 |
return " ".join(new_text)
|
44 |
|
|
|
|
|
|
|
|
|
|
|
|
|
45 |
def get_top_emojis(text, top_n=TOP_N):
|
|
|
46 |
preprocessed = preprocess(text)
|
47 |
inputs = tokenizer(preprocessed, return_tensors="pt")
|
48 |
preds = model(**inputs).logits
|
49 |
scores = torch.nn.functional.softmax(preds, dim=-1).detach().numpy()
|
50 |
ranking = np.argsort(scores)
|
51 |
-
print(ranking)
|
52 |
ranking = ranking.squeeze()[::-1][:top_n]
|
53 |
-
print(scores)
|
54 |
-
print(ranking)
|
55 |
-
print(model.config.id2label)
|
56 |
emojis = [model.config.id2label[i] for i in ranking]
|
57 |
return ', '.join(map(str, emojis))
|
58 |
|
|
|
15 |
st.title('Emoji-motion!')
|
16 |
|
17 |
example_prompts = [
|
18 |
+
"This space is lit!!",
|
19 |
"Today is going to be awesome!",
|
20 |
+
"I love Machine Learning",
|
21 |
+
"Cool cool cool no doubt no doubt no doubt"]
|
|
|
22 |
|
23 |
example = st.selectbox("Choose a pre-defined example", example_prompts)
|
24 |
|
|
|
31 |
BASE_MODEL = st.selectbox("Choose a model", models_to_choose)
|
32 |
TOP_N = 5
|
33 |
|
|
|
|
|
|
|
34 |
def preprocess(text):
|
35 |
new_text = []
|
36 |
for t in text.split(" "):
|
|
|
39 |
new_text.append(t)
|
40 |
return " ".join(new_text)
|
41 |
|
42 |
+
@st.cache(allow_output_mutation=True, suppress_st_warning=True, show_spinner=False)
|
43 |
+
def load_model():
|
44 |
+
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
|
45 |
+
model = AutoModelForSequenceClassification.from_pretrained(BASE_MODEL)
|
46 |
+
return model, tokenizer
|
47 |
+
|
48 |
def get_top_emojis(text, top_n=TOP_N):
|
49 |
+
model, tokenizer = load_model()
|
50 |
preprocessed = preprocess(text)
|
51 |
inputs = tokenizer(preprocessed, return_tensors="pt")
|
52 |
preds = model(**inputs).logits
|
53 |
scores = torch.nn.functional.softmax(preds, dim=-1).detach().numpy()
|
54 |
ranking = np.argsort(scores)
|
|
|
55 |
ranking = ranking.squeeze()[::-1][:top_n]
|
|
|
|
|
|
|
56 |
emojis = [model.config.id2label[i] for i in ranking]
|
57 |
return ', '.join(map(str, emojis))
|
58 |
|