papew28 committed
Commit 3188bd3 · verified · 1 Parent(s): e610f8f

Update app.py

Files changed (1):
  1. app.py  +51 -28
app.py CHANGED
@@ -1,43 +1,66 @@
  import streamlit as st
- from transformers import RobertaTokenizer,AutoModelForSequenceClassification
  import torch

- state_dict=torch.load("fine_tuned_roberta_comments.bin",map_location=torch.device("cpu"))
- tokenizer=RobertaTokenizer.from_pretrained("roberta-base")
- model = AutoModelForSequenceClassification.from_pretrained('roberta-base',
-     problem_type="multi_label_classification",
-     num_labels=3
- )
- model.load_state_dict(state_dict)
- device = torch.device("cpu")
- model.to(device)


- def main():
-     st.title("Classification de séquence")

-     title = st.text_input("Titre")
-     post = st.text_area("Post")
-     comment = st.text_area("Commentaire")

-     if st.button("Tester"):
-         result = get_predictions(title, post, comment)
-         st.success(result)
  @st.cache_data
- def get_predictions(title, post, commentaire):
-     model.eval()
-     inputs = tokenizer("comment: " + commentaire, return_tensors="pt", padding=True, truncation=True, max_length=512)
      input_ids = inputs['input_ids'].to(device)
      attention_mask = inputs['attention_mask'].to(device)
      with torch.no_grad():
-         outputs = model(input_ids, attention_mask=attention_mask)
      logits = outputs.logits
      _, preds = torch.max(logits, dim=1)
-     id2label = {
-         0: "neutral",
-         1: "with palestine",
-         2: "with israel"
-     }
-     return id2label[preds.item()]
  if __name__ == "__main__":
      main()
 
  import streamlit as st
+ from transformers import RobertaTokenizer, AutoModelForSequenceClassification
  import torch

+ # Chargement des modèles
+ state_dict_comment = torch.load("fine_tuned_roberta_comment.bin", map_location=torch.device("cpu"))
+ state_dict_full = torch.load("fine_tuned_roberta_full.bin", map_location=torch.device("cpu"))

+ tokenizer = RobertaTokenizer.from_pretrained("roberta-base")
+ model_comment = AutoModelForSequenceClassification.from_pretrained('roberta-base', problem_type="multi_label_classification", num_labels=3)
+ model_comment.load_state_dict(state_dict_comment)

+ model_full = AutoModelForSequenceClassification.from_pretrained('roberta-base', problem_type="multi_label_classification", num_labels=3)
+ model_full.load_state_dict(state_dict_full)

+ device = torch.device("cpu")
+ model_comment.to(device)
+ model_full.to(device)

  @st.cache_data
+ def get_predictions_comment(commentaire):
+     model_comment.eval()
+     inputs = tokenizer(commentaire, return_tensors="pt", padding=True, truncation=True, max_length=512)
      input_ids = inputs['input_ids'].to(device)
      attention_mask = inputs['attention_mask'].to(device)
      with torch.no_grad():
+         outputs = model_comment(input_ids, attention_mask=attention_mask)
      logits = outputs.logits
+     probs = torch.softmax(logits, dim=1)
      _, preds = torch.max(logits, dim=1)
+     id2label = {0: "neutral", 1: "with palestine", 2: "with israel"}
+     return id2label[preds.item()], probs.squeeze().tolist()
+
+ @st.cache_data
+ def get_predictions_full(title, post, commentaire):
+     model_full.eval()
+     inputs = tokenizer("title of the post: " + title + "\n" + "post: " + post + "\n" + "comment: " + commentaire, return_tensors="pt", padding=True, truncation=True, max_length=512)
+     input_ids = inputs['input_ids'].to(device)
+     attention_mask = inputs['attention_mask'].to(device)
+     with torch.no_grad():
+         outputs = model_full(input_ids, attention_mask=attention_mask)
+     logits = outputs.logits
+     probs = torch.softmax(logits, dim=1)
+     _, preds = torch.max(logits, dim=1)
+     id2label = {0: "neutral", 1: "with palestine", 2: "with israel"}
+     return id2label[preds.item()], probs.squeeze().tolist()
+
+ def main():
+     st.title("Classification de séquence")
+     title = st.text_input("Titre")
+     post = st.text_area("Post")
+     comment = st.text_area("Commentaire")
+     if st.button("Tester"):
+         if title or post:
+             result, probs = get_predictions_full(title, post, comment)
+         else:
+             result, probs = get_predictions_comment(comment)
+         st.success(result)
+         st.write("Probabilités:")
+         neutral_prob, palestine_prob, israel_prob = probs
+         st.slider("Neutre", 0.0, 1.0, neutral_prob, key="neutral_slider")
+         st.slider("Avec Palestine", 0.0, 1.0, palestine_prob, key="palestine_slider")
+         st.slider("Avec Israël", 0.0, 1.0, israel_prob, key="israel_slider")
+
  if __name__ == "__main__":
      main()
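The app is launched with `streamlit run app.py`. For reference, a minimal sketch of exercising the two prediction paths added above outside the Streamlit UI; it assumes the two checkpoint files referenced in the script (`fine_tuned_roberta_comment.bin`, `fine_tuned_roberta_full.bin`) are present in the working directory, and the example strings are placeholders:

```python
# Smoke test for the routing used in main(): the full-context model
# handles inputs that include a title or post, the comment-only model
# handles a bare comment.
# Importing app executes the torch.load / from_pretrained calls at module
# level, so the checkpoint files must exist locally; st.cache_data may log
# a "no runtime" warning when called outside `streamlit run`.
from app import get_predictions_comment, get_predictions_full

label, probs = get_predictions_full(
    "example title", "example post", "example comment"
)
print(label, probs)  # e.g. "neutral", [p_neutral, p_with_palestine, p_with_israel]

label, probs = get_predictions_comment("example comment")
print(label, probs)
```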