cccc's picture
Update app.py
da4e37a
raw
history blame
2.74 kB
import functools

import gradio as gr
import torch
from openprompt import PromptDataLoader
from openprompt import PromptForClassification
from openprompt.data_utils import InputExample
from openprompt.plms import load_plm
from openprompt.prompts import ManualTemplate
from openprompt.prompts import ManualVerbalizer
from transformers import pipeline
# Hugging Face model id -> OpenPrompt backbone family passed to load_plm.
MODEL_TYPES = {
    "CCCC/ARCH_tuned_bert": "bert",
    "bert-base-uncased": "bert",
    "roberta-base": "roberta",
    "yiyanghkust/finbert-pretrain": "bert",
    "facebook/opt-125m": "opt",
    "facebook/opt-350m": "opt",
}
DEFAULT_MODEL_NAME = "CCCC/ARCH_tuned_bert"
# Output labels; index i of the model's logits corresponds to CLASSES[i].
CLASSES = ["positive", "neutral", "negative"]


@functools.lru_cache(maxsize=len(MODEL_TYPES))
def _load_backbone(model_name):
    """Load (plm, tokenizer, model_config, WrapperClass) once per model name.

    Loading pretrained weights is the expensive step, so it is cached instead
    of being repeated on every request. sentiment_analysis only runs
    inference with the cached objects.
    """
    return load_plm(MODEL_TYPES[model_name], model_name)


def _split_label_words(text, label):
    """Return the whitespace-separated words of *text*; reject empty input.

    str.split() with no separator drops empty strings, so repeated, leading
    or trailing spaces do not yield "" label words. An empty label word has
    no tokens and would turn its class's averaged score into NaN.
    """
    words = text.split()
    if not words:
        raise ValueError(f"Please provide at least one {label} label word.")
    return words


def sentiment_analysis(sentence, template, positive, neutral, negative,
                       model_name=DEFAULT_MODEL_NAME):
    """Classify *sentence* as positive / neutral / negative via a manual prompt.

    Args:
        sentence: Text to classify; substituted for the [SENTENCE] token.
        template: Prompt containing a [SENTENCE] token and exactly one [MASK]
            token, e.g. "[SENTENCE] The outlook is [MASK]."
        positive, neutral, negative: Space-separated label words per class.
        model_name: Key of MODEL_TYPES selecting the pretrained backbone.

    Returns:
        One of "positive", "neutral", "negative".

    Raises:
        ValueError: if the template tokens, a label-word box, or model_name
            is invalid.
    """
    # Validate cheap user input before touching the (expensive) model.
    if "[SENTENCE]" not in template or template.count("[MASK]") != 1:
        raise ValueError(
            "Template must contain a [SENTENCE] token and exactly one [MASK] token."
        )
    if model_name not in MODEL_TYPES:
        raise ValueError(f"Unsupported model: {model_name!r}")
    label_words = {
        label: _split_label_words(text, label)
        for label, text in zip(CLASSES, (positive, neutral, negative))
    }

    # Translate the UI tokens into OpenPrompt's ManualTemplate syntax.
    template = template.replace("[SENTENCE]", '{"placeholder":"text_a"}')
    template = template.replace("[MASK]", '{"mask"}')

    plm, tokenizer, _, wrapper_class = _load_backbone(model_name)
    prompt_template = ManualTemplate(text=template, tokenizer=tokenizer)
    verbalizer = ManualVerbalizer(
        classes=CLASSES,
        label_words=label_words,
        tokenizer=tokenizer,
    )
    dataloader = PromptDataLoader(
        dataset=[InputExample(guid=0, text_a=sentence, label=0)],
        tokenizer=tokenizer,
        template=prompt_template,
        tokenizer_wrapper_class=wrapper_class,
        batch_size=1,
        max_seq_length=512,
    )
    prompt_model = PromptForClassification(
        plm=plm,
        template=prompt_template,
        verbalizer=verbalizer,
        freeze_plm=False,  # whether or not to freeze the pretrained language model
    )

    # Inference only: disable dropout and skip building the autograd graph.
    prompt_model.eval()
    with torch.no_grad():
        # The dataset holds exactly one example, so one batch covers it.
        logits = prompt_model(next(iter(dataloader)))
    return CLASSES[int(torch.argmax(logits, dim=-1)[0])]
# Gradio UI: one textbox per positional argument of sentiment_analysis,
# in call order (sentence, template, then the three label-word boxes).
demo = gr.Interface(
    fn=sentiment_analysis,
    inputs=[
        gr.Textbox(placeholder="Enter sentence here.", label="sentence"),
        gr.Textbox(
            placeholder="Your template must have a [SENTENCE] token and a [MASK] token.",
            label="template",
        ),
        # The three label-word boxes differ only in their label.
        *(
            gr.Textbox(placeholder="Separate words with Spaces.", label=polarity)
            for polarity in ("positive", "neutral", "negative")
        ),
    ],
    outputs="text",
)
demo.launch(server_port=8080)