Fabian Lang commited on
Commit
f2aeda3
·
1 Parent(s): 354ae5a

Add application file

Browse files
Files changed (1) hide show
  1. app.py +45 -0
app.py ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # coding: utf-8
3
+
4
+ # In[2]:
5
+
6
+
7
+ import gradio as gr
8
+ import torch
9
+
10
+
11
+ # In[3]:
12
+
13
+
14
+ model_ckpt = "langfab/distilbert-base-uncased-finetuned-movie-genre"
15
+
16
+ from transformers import (AutoTokenizer, AutoConfig,
17
+ AutoModelForSequenceClassification)
18
+
19
+ tokenizer = AutoTokenizer.from_pretrained(model_ckpt)
20
+ config = AutoConfig.from_pretrained(model_ckpt)
21
+ model = AutoModelForSequenceClassification.from_pretrained(model_ckpt,config=config)
22
+
23
+
24
+ # In[4]:
25
+
26
+
27
+ id2label = model.config.id2label
28
+
29
+ def predict(plot):
30
+ encoding = tokenizer(plot, padding=True, truncation=True, return_tensors="pt")
31
+ encoding = {k: v.to(model.device) for k,v in encoding.items()}
32
+
33
+ outputs = model(**encoding)
34
+
35
+ logits = outputs.logits
36
+ logits.shape
37
+
38
+ predictions = torch.nn.functional.softmax(logits.squeeze().cpu(), dim=-1)
39
+ predictions
40
+
41
+ return id2label[int(predictions.argmax())]
42
+
43
+ iface = gr.Interface(title = "Movie Plot Genre Predictor", fn=predict, inputs="text", outputs="text")
44
+ iface.launch(share=True)
45
+