osliusarenko commited on
Commit
d65cf4e
·
1 Parent(s): a212615

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +10 -15
README.md CHANGED
@@ -17,24 +17,19 @@ This is a baseline RoBERTa-base model for the delicate text detection task.
17
  Here's a short usage example with the torch library in a binary classification task:
18
 
19
  ```python
20
- from transformers import AutoTokenizer, AutoModelForSequenceClassification
21
- import torch
22
 
23
- tokenizer = AutoTokenizer.from_pretrained("grammarly/detexd-roberta")
24
- model = AutoModelForSequenceClassification.from_pretrained("grammarly/detexd-roberta")
25
- model.eval()
26
 
27
- def predict_binary_score(text: str, break_class_ix=3):
28
- with torch.no_grad():
29
- # get multiclass probability scores
30
- logits = model(**tokenizer(text, return_tensors='pt'))[0]
31
- probs = torch.nn.functional.softmax(logits, dim=-1)
32
 
33
- # convert to a binary prediction by summing the probability scores
34
- # for the higher-index classes, as defined by break_class_ix
35
- bin_score = probs[..., break_class_ix:].sum(dim=-1)
36
-
37
- return bin_score.item()
38
 
39
  def predict_delicate(text: str, threshold=0.72496545):
40
  return predict_binary_score(text) > threshold
 
17
  Here's a short usage example with the torch library in a binary classification task:
18
 
19
  ```python
20
+ from transformers import pipeline
 
21
 
22
+ classifier = pipeline("text-classification", model="grammarly/detexd-roberta-base")
 
 
23
 
24
+ def predict_binary_score(text: str):
25
+ # get multiclass probability scores
26
+ scores = classifier(text, top_k=None)
 
 
27
 
28
+ # convert to a single score by summing the probability scores
29
+ # for the higher-index classes
30
+ return sum(score['score']
31
+ for score in scores
32
+ if score['label'] in ('LABEL_3', 'LABEL_4', 'LABEL_5'))
33
 
34
  def predict_delicate(text: str, threshold=0.72496545):
35
  return predict_binary_score(text) > threshold