papew28 committed on
Commit f1d5671 (verified)
Parent: 4a37028

Update app.py

Files changed (1)
  1. app.py +42 -42
app.py CHANGED
@@ -1,43 +1,43 @@
-import streamlit as st
-from transformers import RobertaTokenizer,AutoModelForSequenceClassification
-import torch
-
-state_dict=torch.load("models/fine_tuned_roberta.bin",map_location=torch.device("cpu"))
-tokenizer=RobertaTokenizer.from_pretrained("roberta-base")
-model = AutoModelForSequenceClassification.from_pretrained('roberta-base',
-    problem_type="multi_label_classification",
-    num_labels=3
-)
-model.load_state_dict(state_dict)
-device = torch.device("cpu")
-model.to(device)
-
-
-def main():
-    st.title("Classification de séquence")
-
-    title = st.text_input("Titre")
-    post = st.text_area("Post")
-    comment = st.text_area("Commentaire")
-
-    if st.button("Tester"):
-        result = get_predictions(title, post, comment)
-        st.success(result)
-@st.cache_data
-def get_predictions(title, post, commentaire):
-    model.eval()
-    inputs = tokenizer("title of the post: " + title + "\n" + "post: " + post + "\n" + "comment: " + commentaire, return_tensors="pt", padding=True, truncation=True, max_length=512)
-    input_ids = inputs['input_ids'].to(device)
-    attention_mask = inputs['attention_mask'].to(device)
-    with torch.no_grad():
-        outputs = model(input_ids, attention_mask=attention_mask)
-    logits = outputs.logits
-    _, preds = torch.max(logits, dim=1)
-    id2label = {
-        0: "neutral",
-        1: "with palestine",
-        2: "with israel"
-    }
-    return id2label[preds.item()]
-if __name__ == "__main__":
+import streamlit as st
+from transformers import RobertaTokenizer,AutoModelForSequenceClassification
+import torch
+
+state_dict=torch.load("fine_tuned_roberta.bin",map_location=torch.device("cpu"))
+tokenizer=RobertaTokenizer.from_pretrained("roberta-base")
+model = AutoModelForSequenceClassification.from_pretrained('roberta-base',
+    problem_type="multi_label_classification",
+    num_labels=3
+)
+model.load_state_dict(state_dict)
+device = torch.device("cpu")
+model.to(device)
+
+
+def main():
+    st.title("Classification de séquence")
+
+    title = st.text_input("Titre")
+    post = st.text_area("Post")
+    comment = st.text_area("Commentaire")
+
+    if st.button("Tester"):
+        result = get_predictions(title, post, comment)
+        st.success(result)
+@st.cache_data
+def get_predictions(title, post, commentaire):
+    model.eval()
+    inputs = tokenizer("title of the post: " + title + "\n" + "post: " + post + "\n" + "comment: " + commentaire, return_tensors="pt", padding=True, truncation=True, max_length=512)
+    input_ids = inputs['input_ids'].to(device)
+    attention_mask = inputs['attention_mask'].to(device)
+    with torch.no_grad():
+        outputs = model(input_ids, attention_mask=attention_mask)
+    logits = outputs.logits
+    _, preds = torch.max(logits, dim=1)
+    id2label = {
+        0: "neutral",
+        1: "with palestine",
+        2: "with israel"
+    }
+    return id2label[preds.item()]
+if __name__ == "__main__":
     main()
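
The only functional change in this commit is the checkpoint path passed to torch.load: "models/fine_tuned_roberta.bin" becomes "fine_tuned_roberta.bin", so the fine-tuned weights are now expected next to app.py instead of in a models/ subdirectory. Below is a minimal sketch of how that load could be made independent of the working directory Streamlit is launched from; the pathlib resolution and the existence check are illustrative assumptions, not part of the commit.

# Illustrative sketch only, not part of the commit: resolve the checkpoint
# relative to this file so torch.load works regardless of the directory
# Streamlit was started from.
from pathlib import Path

import torch

# Filename taken from the updated torch.load call in the diff above.
CHECKPOINT = Path(__file__).resolve().parent / "fine_tuned_roberta.bin"

if not CHECKPOINT.exists():
    raise FileNotFoundError(f"Fine-tuned weights not found at {CHECKPOINT}")

state_dict = torch.load(CHECKPOINT, map_location=torch.device("cpu"))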