Ptato committed on
Commit
5505986
1 Parent(s): 5cf11d3

table functionality

Browse files
Files changed (3) hide show
  1. app.py +237 -60
  2. requirements.txt +3 -1
  3. test.csv +0 -0
app.py CHANGED
@@ -5,22 +5,126 @@ from transformers import AutoModelForSequenceClassification, AutoTokenizer
5
  import os
6
  import torch
7
  import numpy as np
8
- os.environ['KMP_DUPLICATE_LIB_OK'] = "True"
 
9
 
 
10
 
11
 
12
  st.title("Sentiment Analysis App")
13
-
14
- labels = ['toxic', 'severe_toxic', 'obscene', 'threat', 'insult', 'identity_hate']
 
 
 
 
 
 
15
 
16
  form = st.form(key='Sentiment Analysis')
17
- box = form.selectbox('Select Pre-trained Model:', ['bertweet-base-sentiment-analysis',
18
- 'distilbert-base-uncased-finetuned-sst-2-english',
19
- 'twitter-roberta-base-sentiment',
20
- 'Modified Bert Toxicity Classification'
21
- ], key=1)
 
22
  tweet = form.text_input(label='Enter text to analyze:', value="\"We've seen in the last few months, unprecedented amounts of Voter Fraud.\" @SenTedCruz True!")
23
  submit = form.form_submit_button(label='Submit')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
 
25
  if submit and tweet:
26
  with st.spinner('Analyzing...'):
@@ -32,11 +136,11 @@ if submit and tweet:
32
  else:
33
  col1, col2, col3, col4, col5 = st.columns(5)
34
  if box == 'bertweet-base-sentiment-analysis':
35
- pipeline = pipeline(task="sentiment-analysis", model="finiteautomata/bertweet-base-sentiment-analysis")
36
  elif box == 'twitter-roberta-base-sentiment':
37
- pipeline = pipeline(task="sentiment-analysis", model="cardiffnlp/twitter-roberta-base-sentiment")
38
  elif box == 'distilbert-base-uncased-finetuned-sst-2-english':
39
- pipeline = pipeline(task="sentiment-analysis", model="distilbert-base-uncased-finetuned-sst-2-english")
40
 
41
 
42
  # <--- Unnecessary Testing --->
@@ -53,8 +157,8 @@ if submit and tweet:
53
  predictions = np.zeros(probs.shape)
54
  predictions[np.where(probs >= 0.5)] = 1
55
  # turn predicted id's into actual label names
56
- id2label = {idx: label for idx, label in enumerate(labels)}
57
- predicted_labels = [id2label[idx] for idx, label in enumerate(predictions) if label == 1.0]
58
  print(predicted_labels)
59
  print(predictions[0])
60
  else:
@@ -64,60 +168,133 @@ if submit and tweet:
64
  encoding = {k: v.to(model.device) for k,v in encoding.items()}
65
  predictions = model(**encoding)
66
  print(predictions)
67
- col4
68
- if pipeline:
69
- predictions = pipeline(tweet)
70
  col2.header("Judgement")
71
  else:
72
- col2.header("Toxic?")
73
  col4.header("Toxicity Type")
74
  col5.header("Probability")
75
- print(predictions)
76
 
77
  col1.header("Tweet")
78
  col3.header("Probability")
79
 
80
- col1.subheader(tweet)
81
- for p in predictions:
82
- if box == 'bertweet-base-sentiment-analysis':
83
- if p['label'] == "POS":
84
- col2.success("POSITIVE")
85
- col3.success(f"{ round(p['score'] * 100, 1)}%")
86
- elif p['label'] == "NEU":
87
- col2.warning(f"{ p['label'] }")
88
- col3.warning(f"{round(p['score'] * 100, 1)}%")
89
- else:
90
- col2.error("NEGATIVE")
91
- col3.error(f"{round(p['score'] * 100, 1)}%")
92
- elif box == 'distilbert-base-uncased-finetuned-sst-2-english':
93
- if p['label'] == "POSITIVE":
94
- col2.success("POSITIVE")
95
- col3.success(f"{round(p['score'] * 100, 1)}%")
96
- else:
97
- col2.error("NEGATIVE")
98
- col3.error(f"{round(p['score'] * 100, 1)}%")
99
- elif box == 'twitter-roberta-base-sentiment':
100
- if p['label'] == "LABEL_2":
101
- col2.success("POSITIVE")
102
- col3.success(f"{round(p['score'] * 100, 1)}%")
103
- elif p['label'] == "LABEL_0":
104
- col2.error("NEGATIVE")
105
- col3.error(f"{round(p['score'] * 100, 1)}%")
106
- else:
107
- col2.warning("NEUTRAL")
108
- col3.warning(f"{round(p['score'] * 100, 1)}%")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
109
  else:
110
- if predictions[0] == 0:
111
- col2.success("NO TOXICITY")
112
- col3.success(f"{100 - round(probs[0] * 100, 1)}%")
113
- col4.success("N/A")
114
- col5.success("N/A")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
115
  else:
116
- col2.error("TOXIC")
117
- col3.error(f"{round(probs[0] * 100, 1)}%")
118
- _max = 1
119
- for i in range(2, len(predictions)):
120
- if probs[i] > probs[_max]:
121
- _max = i
122
- col4.error(labels[_max])
123
- col5.error(f"{round(probs[_max] * 100, 1)}%")
 
5
  import os
6
  import torch
7
  import numpy as np
8
+ import pandas as pd
9
+
10
 
11
+ os.environ['KMP_DUPLICATE_LIB_OK'] = "True"
12
 
13
 
14
  st.title("Sentiment Analysis App")
15
+ if 'logs' not in st.session_state:
16
+ st.session_state.logs = dict()
17
+ if 'labels' not in st.session_state:
18
+ st.session_state.labels = ['toxic', 'severe_toxic', 'obscene', 'threat', 'insult', 'identity_hate']
19
+ if 'id2label' not in st.session_state:
20
+ st.session_state.id2label = {idx: label for idx, label in enumerate(st.session_state.labels)}
21
+ if 'filled' not in st.session_state:
22
+ st.session_state.filled = False
23
 
24
  form = st.form(key='Sentiment Analysis')
25
+ st.session_state.options = ['bertweet-base-sentiment-analysis',
26
+ 'distilbert-base-uncased-finetuned-sst-2-english',
27
+ 'twitter-roberta-base-sentiment',
28
+ # 'Modified Bert Toxicity Classification'
29
+ ]
30
+ box = form.selectbox('Select Pre-trained Model:', st.session_state.options, key=1)
31
  tweet = form.text_input(label='Enter text to analyze:', value="\"We've seen in the last few months, unprecedented amounts of Voter Fraud.\" @SenTedCruz True!")
32
  submit = form.form_submit_button(label='Submit')
33
+ if 'df' not in st.session_state:
34
+ st.session_state.df = pd.read_csv("test.csv")
35
+
36
+ if not st.session_state.filled:
37
+ for s in st.session_state.options:
38
+ st.session_state.logs[s] = []
39
+ if not st.session_state.filled:
40
+ st.session_state.filled = True
41
+ for x in range(10):
42
+ print(x)
43
+ text = st.session_state.df["comment_text"].iloc[x][:128]
44
+ for s in st.session_state.options:
45
+ if s == 'bertweet-base-sentiment-analysis':
46
+ pline = pipeline(task="sentiment-analysis", model="finiteautomata/bertweet-base-sentiment-analysis")
47
+ elif s == 'twitter-roberta-base-sentiment':
48
+ pline = pipeline(task="sentiment-analysis", model="cardiffnlp/twitter-roberta-base-sentiment")
49
+ elif s == 'distilbert-base-uncased-finetuned-sst-2-english':
50
+ pline = pipeline(task="sentiment-analysis", model="distilbert-base-uncased-finetuned-sst-2-english")
51
+ else:
52
+ model = AutoModelForSequenceClassification.from_pretrained('./model')
53
+ model.eval()
54
+ tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
55
+ encoding = tokenizer(tweet, return_tensors="pt")
56
+ encoding = {k: v.to(model.device) for k,v in encoding.items()}
57
+ predictions = model(**encoding)
58
+ logits = predictions.logits
59
+ sigmoid = torch.nn.Sigmoid()
60
+ probs = sigmoid(logits.squeeze().cpu())
61
+ predictions = np.zeros(probs.shape)
62
+ predictions[np.where(probs >= 0.5)] = 1
63
+ predicted_labels = [st.session_state.id2label[idx] for idx, label in enumerate(predictions) if label == 1.0]
64
+ log = []
65
+ if pline:
66
+ predictions = pline(text)
67
+ log = [0] * 4
68
+ log[1] = text
69
+ for p in predictions:
70
+ if s == 'bertweet-base-sentiment-analysis':
71
+ if p['label'] == "POS":
72
+ log[0] = 0
73
+ log[2] = "POSITIVE"
74
+ log[3] = f"{ round(p['score'] * 100, 1)}%"
75
+ elif p['label'] == "NEU":
76
+ log[0] = 2
77
+ log[2] = f"{ p['label'] }"
78
+ log[3] = f"{round(p['score'] * 100, 1)}%"
79
+ else:
80
+ log[2] = "NEG"
81
+ log[0] = 1
82
+ log[3] = f"{round(p['score'] * 100, 1)}%"
83
+ elif s == 'distilbert-base-uncased-finetuned-sst-2-english':
84
+ if p['label'] == "POSITIVE":
85
+ log[0] = 0
86
+ log[2] = "POSITIVE"
87
+ log[3] = (f"{round(p['score'] * 100, 1)}%")
88
+ else:
89
+ log[2] = ("NEGATIVE")
90
+ log[0] = 1
91
+ log[3] = (f"{round(p['score'] * 100, 1)}%")
92
+ elif s == 'twitter-roberta-base-sentiment':
93
+ if p['label'] == "LABEL_2":
94
+ log[0] = 0
95
+ log[2] = ("POSITIVE")
96
+ log[3] = (f"{round(p['score'] * 100, 1)}%")
97
+ elif p['label'] == "LABEL_0":
98
+ log[0] = 1
99
+ log[2] = ("NEGATIVE")
100
+ log[3] = f"{round(p['score'] * 100, 1)}%"
101
+ else:
102
+ log[0] = 2
103
+ log[2] = "NEUTRAL"
104
+ log[3] = f"{round(p['score'] * 100, 1)}%"
105
+ else:
106
+ log = [0] * 6
107
+ log[1] = text
108
+ if max(predictions) == 0:
109
+ log[0] = 0
110
+ log[2] = ("NO TOXICITY")
111
+ log[3] = (f"{100 - round(probs[0] * 100, 1)}%")
112
+ log[4] = ("N/A")
113
+ log[5] = ("N/A")
114
+ else:
115
+ log[0] = 1
116
+ _max = 0
117
+ _max2 = 2
118
+ for i in range(1, len(predictions)):
119
+ if probs[i] > probs[_max]:
120
+ _max = i
121
+ if i > 2 and probs[i] > probs[_max2]:
122
+ _max2 = i
123
+ log[2] = (st.session_state.labels[_max])
124
+ log[3] = (f"{round(probs[_max] * 100, 1)}%")
125
+ log[4] = (st.session_state.labels[_max2])
126
+ log[5] = (f"{round(probs[_max2] * 100, 1)}%")
127
+ st.session_state.logs[s].append(log)
128
 
129
  if submit and tweet:
130
  with st.spinner('Analyzing...'):
 
136
  else:
137
  col1, col2, col3, col4, col5 = st.columns(5)
138
  if box == 'bertweet-base-sentiment-analysis':
139
+ pline = pipeline(task="sentiment-analysis", model="finiteautomata/bertweet-base-sentiment-analysis")
140
  elif box == 'twitter-roberta-base-sentiment':
141
+ pline = pipeline(task="sentiment-analysis", model="cardiffnlp/twitter-roberta-base-sentiment")
142
  elif box == 'distilbert-base-uncased-finetuned-sst-2-english':
143
+ pline = pipeline(task="sentiment-analysis", model="distilbert-base-uncased-finetuned-sst-2-english")
144
 
145
 
146
  # <--- Unnecessary Testing --->
 
157
  predictions = np.zeros(probs.shape)
158
  predictions[np.where(probs >= 0.5)] = 1
159
  # turn predicted id's into actual label names
160
+ st.session_state.id2label = {idx: label for idx, label in enumerate(st.session_state.labels)}
161
+ predicted_labels = [st.session_state.id2label[idx] for idx, label in enumerate(predictions) if label == 1.0]
162
  print(predicted_labels)
163
  print(predictions[0])
164
  else:
 
168
  encoding = {k: v.to(model.device) for k,v in encoding.items()}
169
  predictions = model(**encoding)
170
  print(predictions)
171
+ if pline:
172
+ predictions = pline(tweet)
 
173
  col2.header("Judgement")
174
  else:
175
+ col2.header("")
176
  col4.header("Toxicity Type")
177
  col5.header("Probability")
 
178
 
179
  col1.header("Tweet")
180
  col3.header("Probability")
181
 
182
+ if pline:
183
+ log = [0] * 4
184
+ log[1] = tweet
185
+ for p in predictions:
186
+ if box == 'bertweet-base-sentiment-analysis':
187
+ if p['label'] == "POS":
188
+ col1.success(tweet.split("\n")[0][:20])
189
+ log[0] = 0
190
+ col2.success("POS")
191
+ col3.success(f"{ round(p['score'] * 100, 1)}%")
192
+ log[2] = ("POS")
193
+ log[3] = (f"{ round(p['score'] * 100, 1)}%")
194
+ elif p['label'] == "NEU":
195
+ col1.warning(tweet.split("\n")[0][:20])
196
+ log[0] = 2
197
+ col2.warning(f"{ p['label'] }")
198
+ col3.warning(f"{round(p['score'] * 100, 1)}%")
199
+ log[2] = ("NEU")
200
+ log[3] = (f"{round(p['score'] * 100, 1)}%")
201
+ else:
202
+ log[0] = 1
203
+ col1.error(tweet.split("\n")[0][:20])
204
+ col2.error("NEG")
205
+ col3.error(f"{round(p['score'] * 100, 1)}%")
206
+ log[2] = ("NEG")
207
+ log[3] = (f"{round(p['score'] * 100, 1)}%")
208
+ elif box == 'distilbert-base-uncased-finetuned-sst-2-english':
209
+ if p['label'] == "POSITIVE":
210
+ col1.success(tweet.split("\n")[0][:20])
211
+ log[0] = 0
212
+ col2.success("POSITIVE")
213
+ log[2] = "POSITIVE"
214
+ col3.success(f"{round(p['score'] * 100, 1)}%")
215
+ log[3] = f"{round(p['score'] * 100, 1)}%"
216
+ else:
217
+ col2.error("NEGATIVE")
218
+ col1.error(tweet.split("\n")[0][:20])
219
+ log[2] = ("NEGATIVE")
220
+ log[0] = 1
221
+ col3.error(f"{round(p['score'] * 100, 1)}%")
222
+ log[3] = f"{round(p['score'] * 100, 1)}%"
223
+ elif box == 'twitter-roberta-base-sentiment':
224
+ if p['label'] == "LABEL_2":
225
+ log[0] = 0
226
+ col1.success(tweet.split("\n")[0][:20])
227
+ col2.success("POSITIVE")
228
+ col3.success(f"{round(p['score'] * 100, 1)}%")
229
+ log[3] = f"{round(p['score'] * 100, 1)}%"
230
+ log[2] = "POSITIVE"
231
+ elif p['label'] == "LABEL_0":
232
+ log[0] = 1
233
+ col1.error(tweet.split("\n")[0][:20])
234
+ col2.error("NEGATIVE")
235
+ col3.error(f"{round(p['score'] * 100, 1)}%")
236
+ log[3] = f"{round(p['score'] * 100, 1)}%"
237
+ log[2] = "NEGATIVE"
238
+ else:
239
+ log[0] = 2
240
+ col1.warning(tweet.split("\n")[0][:20])
241
+ col2.warning("NEUTRAL")
242
+ col3.warning(f"{round(p['score'] * 100, 1)}%")
243
+ log[3] = f"{round(p['score'] * 100, 1)}%"
244
+ log[2] = "NEUTRAL"
245
+ for a in st.session_state.logs[box][::-1]:
246
+ if a[0] == 0:
247
+ col1.success(a[1].split("\n")[0][:20])
248
+ col2.success(a[2])
249
+ col3.success(a[3])
250
+ elif a[0] == 1:
251
+ col1.error(a[1].split("\n")[0][:20])
252
+ col2.error(a[2])
253
+ col3.error(a[3])
254
+ else:
255
+ col1.warning(a[1].split("\n")[0][:20])
256
+ col2.warning(a[2])
257
+ col3.warning(a[3])
258
+ st.session_state.logs[box].append(log)
259
+ else:
260
+ log = [0] * 6
261
+ log[1] = tweet
262
+ if max(predictions) == 0:
263
+ col1.success(tweet.split("\n")[0][:20])
264
+ col2.success("NO TOXICITY")
265
+ col3.success(f"{100 - round(probs[0] * 100, 1)}%")
266
+ col4.success("N/A")
267
+ col5.success("N/A")
268
  else:
269
+ _max = 0
270
+ _max2 = 2
271
+ for i in range(1, len(predictions)):
272
+ if probs[i] > probs[_max]:
273
+ _max = i
274
+ if i > 2 and probs[i] > probs[_max2]:
275
+ _max2 = i
276
+ col1.error(tweet.split("\n")[0][:20])
277
+ col2.error(st.session_state.labels[_max])
278
+ col3.error(f"{round(probs[_max] * 100, 1)}%")
279
+ col4.error(st.session_state.labels[_max2])
280
+ col5.error(f"{round(probs[_max2] * 100, 1)}%")
281
+ for a in st.session_state.logs[box][::-1]:
282
+ if a[0] == 0:
283
+ col1.success(a[1].split("\n")[0][:20])
284
+ col2.success(a[2])
285
+ col3.success(a[3])
286
+ col4.success(a[4])
287
+ col5.success(a[5])
288
+ elif a[0] == 1:
289
+ col1.error(a[1].split("\n")[0][:20])
290
+ col2.error(a[2])
291
+ col3.error(a[3])
292
+ col4.error(a[4])
293
+ col5.error(a[5])
294
  else:
295
+ col1.warning(a[1].split("\n")[0][:20])
296
+ col2.warning(a[2])
297
+ col3.warning(a[3])
298
+ col4.warning(a[4])
299
+ col5.warning(a[5])
300
+ st.session_state.logs[box].append(log)
 
 
requirements.txt CHANGED
@@ -1,3 +1,5 @@
1
  torch
2
  streamlit
3
- transformers
 
 
 
1
  torch
2
  streamlit
3
+ transformers
4
+ numpy
5
+ pandas
test.csv ADDED
Binary file (60.4 MB). View file