TerminatorPower commited on
Commit
6582310
·
verified ·
1 Parent(s): ed9687a

Create predict.py

Browse files
Files changed (1) hide show
  1. predict.py +39 -0
predict.py ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
3
+
4
+ # Load the model and tokenizer
5
+ model_name = "TerminatorPower/bert-news-classif-turkish"
6
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
7
+ model = AutoModelForSequenceClassification.from_pretrained(model_name)
8
+ model.eval()
9
+
10
+ # Load the reverse label mapping
11
+ reverse_label_mapping = {
12
+ 0: "label_0",
13
+ 1: "label_1",
14
+ 2: "label_2",
15
+ 3: "label_3",
16
+ 4: "label_4",
17
+ 5: "label_5",
18
+ 6: "label_6",
19
+ 7: "label_7",
20
+ 8: "label_8",
21
+ 9: "label_9",
22
+ 10: "label_10",
23
+ 11: "label_11",
24
+ 12: "siyaset" # Example: Map index 12 back to "siyaset"
25
+ }
26
+
27
+ def predict(text):
28
+ inputs = tokenizer(text, return_tensors="pt", truncation=True, padding="max_length", max_length=512)
29
+ inputs = {key: value.to("cuda" if torch.cuda.is_available() else "cpu") for key, value in inputs.items()}
30
+ model.to(inputs["input_ids"].device)
31
+ with torch.no_grad():
32
+ outputs = model(**inputs)
33
+ predictions = torch.argmax(outputs.logits, dim=1)
34
+ predicted_label = reverse_label_mapping[predictions.item()]
35
+ return predicted_label
36
+
37
+ if __name__ == "__main__":
38
+ text = input()
39
+ print(f"Predicted label: {predict(text)}")