# Web file-viewer residue captured with the source (not part of the program):
#   Jaehan's picture | Create app.py | f24c629 | raw | history blame | 900 Bytes
import gradio as grad
import torch
from transformers import BartForSequenceClassification, BartTokenizer
# Name of the pretrained checkpoint; passed to both from_pretrained calls below.
model_name = "facebook/bart-large-mnli"
# The tokenizer and the sequence-classification model are created once, when
# the module is executed, and are then read as globals by classify().
bart_tokenizer = BartTokenizer.from_pretrained(model_name)
model = BartForSequenceClassification.from_pretrained(model_name)
def classify(text: str, label: str) -> float:
    """Return, as a percentage, how strongly ``text`` supports ``label``.

    ``text`` and ``label`` are encoded together as one sequence pair and run
    through the module-level ``model``. Only logit columns 0 and 2 are kept
    (contradiction and entailment per the ``facebook/bart-large-mnli`` model
    card; column 1, "neutral", is discarded). A softmax is taken over that
    pair and the second entry (entailment) is scaled to 0-100.

    Args:
        text: The text to be classified (first sequence of the pair).
        label: The candidate label to test (second sequence of the pair).

    Returns:
        The entailment share of the two-way softmax, multiplied by 100.
    """
    # Truncate only the text so an over-long input cannot exceed the model's
    # maximum sequence length, while the label is always kept whole.
    token_ids = bart_tokenizer.encode(
        text,
        label,
        return_tensors="pt",
        truncation="only_first",
    )
    # Pure inference: skip autograd bookkeeping for the forward pass.
    with torch.no_grad():
        logits = model(token_ids)[0]
    # Columns [0, 2] -> [contradiction, entailment]; drop the neutral column.
    entail_contra_logits = logits[:, [0, 2]]
    probabilities = entail_contra_logits.softmax(dim=1)
    # Index 1 of the reduced pair is the entailment probability.
    return probabilities[:, 1].item() * 100
# Input widgets: the text to classify and the single candidate label.
in_text = grad.Textbox(
    lines=1,
    label="English",
    placeholder="Text to be classified",
)
in_labels = grad.Textbox(
    lines=1,
    label="Label",
    placeholder="Input a label",
)
# Output widget that displays the value returned by classify().
out = grad.Textbox(
    lines=1,
    label="Probability of label being true is ",
)
# Wire classify() to the widgets and start the app when the module runs.
grad.Interface(
    fn=classify,
    inputs=[in_text, in_labels],
    outputs=[out],
).launch()