akshay7 commited on
Commit
5c4e25b
·
1 Parent(s): f7fd10d

ADD: Caching of model on streamlit

Browse files
Files changed (1) hide show
  1. app.py +10 -10
app.py CHANGED
@@ -15,10 +15,10 @@ def main():
15
  st.title('Emoji-motion!')
16
 
17
  example_prompts = [
 
18
  "Today is going to be awesome!",
19
- "Pity those who don't feel anything at all.",
20
- "I envy people that know love.",
21
- "Nature is so beautiful"]
22
 
23
  example = st.selectbox("Choose a pre-defined example", example_prompts)
24
 
@@ -31,9 +31,6 @@ def main():
31
  BASE_MODEL = st.selectbox("Choose a model", models_to_choose)
32
  TOP_N = 5
33
 
34
- tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
35
- model = AutoModelForSequenceClassification.from_pretrained(BASE_MODEL)
36
-
37
  def preprocess(text):
38
  new_text = []
39
  for t in text.split(" "):
@@ -42,17 +39,20 @@ def main():
42
  new_text.append(t)
43
  return " ".join(new_text)
44
 
 
 
 
 
 
 
45
  def get_top_emojis(text, top_n=TOP_N):
 
46
  preprocessed = preprocess(text)
47
  inputs = tokenizer(preprocessed, return_tensors="pt")
48
  preds = model(**inputs).logits
49
  scores = torch.nn.functional.softmax(preds, dim=-1).detach().numpy()
50
  ranking = np.argsort(scores)
51
- print(ranking)
52
  ranking = ranking.squeeze()[::-1][:top_n]
53
- print(scores)
54
- print(ranking)
55
- print(model.config.id2label)
56
  emojis = [model.config.id2label[i] for i in ranking]
57
  return ', '.join(map(str, emojis))
58
 
 
15
  st.title('Emoji-motion!')
16
 
17
  example_prompts = [
18
+ "This space is lit!!",
19
  "Today is going to be awesome!",
20
+ "I love Machine Learning",
21
+ "Cool cool cool no doubt no doubt no doubt"]
 
22
 
23
  example = st.selectbox("Choose a pre-defined example", example_prompts)
24
 
 
31
  BASE_MODEL = st.selectbox("Choose a model", models_to_choose)
32
  TOP_N = 5
33
 
 
 
 
34
  def preprocess(text):
35
  new_text = []
36
  for t in text.split(" "):
 
39
  new_text.append(t)
40
  return " ".join(new_text)
41
 
42
+ @st.cache(allow_output_mutation=True, suppress_st_warning=True, show_spinner=False)
43
+ def load_model():
44
+ tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
45
+ model = AutoModelForSequenceClassification.from_pretrained(BASE_MODEL)
46
+ return model, tokenizer
47
+
48
  def get_top_emojis(text, top_n=TOP_N):
49
+ model, tokenizer = load_model()
50
  preprocessed = preprocess(text)
51
  inputs = tokenizer(preprocessed, return_tensors="pt")
52
  preds = model(**inputs).logits
53
  scores = torch.nn.functional.softmax(preds, dim=-1).detach().numpy()
54
  ranking = np.argsort(scores)
 
55
  ranking = ranking.squeeze()[::-1][:top_n]
 
 
 
56
  emojis = [model.config.id2label[i] for i in ranking]
57
  return ', '.join(map(str, emojis))
58