adorkin commited on
Commit
bc7f7c5
·
1 Parent(s): f7fd10d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +28 -24
app.py CHANGED
@@ -3,6 +3,30 @@ from transformers import AutoTokenizer, AutoModelForSequenceClassification
3
  import numpy as np
4
  import torch
5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6
  def main():
7
 
8
  st.set_page_config( # Alternate names: setup_page, page, layout
@@ -34,31 +58,10 @@ def main():
34
  tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
35
  model = AutoModelForSequenceClassification.from_pretrained(BASE_MODEL)
36
 
37
- def preprocess(text):
38
- new_text = []
39
- for t in text.split(" "):
40
- t = '@user' if t.startswith('@') and len(t) > 1 else t
41
- t = 'http' if t.startswith('http') else t
42
- new_text.append(t)
43
- return " ".join(new_text)
44
-
45
- def get_top_emojis(text, top_n=TOP_N):
46
- preprocessed = preprocess(text)
47
- inputs = tokenizer(preprocessed, return_tensors="pt")
48
- preds = model(**inputs).logits
49
- scores = torch.nn.functional.softmax(preds, dim=-1).detach().numpy()
50
- ranking = np.argsort(scores)
51
- print(ranking)
52
- ranking = ranking.squeeze()[::-1][:top_n]
53
- print(scores)
54
- print(ranking)
55
- print(model.config.id2label)
56
- emojis = [model.config.id2label[i] for i in ranking]
57
- return ', '.join(map(str, emojis))
58
 
59
  # Define function to run when submit is clicked
60
  def submit(message):
61
- if len(message)>0:
62
  st.header(get_top_emojis(message))
63
  else:
64
  st.error("The text can't be empty")
@@ -68,8 +71,9 @@ def main():
68
  submit(message)
69
 
70
  st.text('')
71
- st.markdown('<span style="color:blue; font-size:10px">App created by [@AlekseyDorkin](https://huggingface.co/AlekseyDorkin) \
72
- and [@akshay7](https://huggingface.co/akshay7)</span>',unsafe_allow_html=True)
 
73
 
74
 
75
  if __name__ == "__main__":
 
3
  import numpy as np
4
  import torch
5
 
6
+
7
+ def preprocess(text):
8
+ new_text = []
9
+ for t in text.split(" "):
10
+ t = '@user' if t.startswith('@') and len(t) > 1 else t
11
+ t = 'http' if t.startswith('http') else t
12
+ new_text.append(t)
13
+ return " ".join(new_text)
14
+
15
+
16
+ def get_top_emojis(text, top_n=TOP_N):
17
+ preprocessed = preprocess(text)
18
+ inputs = tokenizer(preprocessed, return_tensors="pt")
19
+ preds = model(**inputs).logits
20
+ scores = torch.nn.functional.softmax(preds, dim=-1).detach().numpy()
21
+ ranking = np.argsort(scores)
22
+ print(ranking)
23
+ ranking = ranking.squeeze()[::-1][:top_n]
24
+ print(scores)
25
+ print(ranking)
26
+ print(model.config.id2label)
27
+ emojis = [model.config.id2label[i] for i in ranking]
28
+ return '/t'.join(map(str, emojis))
29
+
30
  def main():
31
 
32
  st.set_page_config( # Alternate names: setup_page, page, layout
 
58
  tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
59
  model = AutoModelForSequenceClassification.from_pretrained(BASE_MODEL)
60
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
61
 
62
  # Define function to run when submit is clicked
63
  def submit(message):
64
+ if len(message) > 0:
65
  st.header(get_top_emojis(message))
66
  else:
67
  st.error("The text can't be empty")
 
71
  submit(message)
72
 
73
  st.text('')
74
+ st.markdown('''<span style="color:blue; font-size:10px">App created by [@AlekseyDorkin](https://huggingface.co/AlekseyDorkin
75
+ and [@akshay7](https://huggingface.co/akshay7)</span>''',
76
+ unsafe_allow_html=True)
77
 
78
 
79
  if __name__ == "__main__":