# Scraped page header: author cccc, commit df2e6a9 ("Update app.py"),
# file size 2.63 kB.
import gradio as gr
import torch
def sentiment_analysis(sentence, template, model_name, positive, neutral, negative):
    """Classify *sentence* as "positive", "neutral" or "negative" via a prompt.

    The user template is rewritten into OpenPrompt manual-template syntax
    ("[SENTENCE]" becomes the ``text_a`` placeholder, "[MASK]" becomes the
    mask slot). The whitespace-separated label words of each class form the
    verbalizer. The class with the highest logit from the prompt model is
    returned.

    Args:
        sentence: Text to classify.
        template: Prompt template; must contain "[SENTENCE]" and "[MASK]".
        model_name: Model to load; must be a key of ``type_dic`` below.
        positive: Whitespace-separated label words for the positive class.
        neutral: Whitespace-separated label words for the neutral class.
        negative: Whitespace-separated label words for the negative class.

    Returns:
        The predicted class name.

    Raises:
        ValueError: If the template lacks a required token, a class has no
            label words, or ``model_name`` is not a supported model.

    NOTE(review): InputExample, load_plm, ManualTemplate, ManualVerbalizer,
    PromptDataLoader and PromptForClassification are not imported in the
    visible source; presumably they come from the ``openprompt`` package.
    """
    classes = ["positive", "neutral", "negative"]

    # Validate user input before any model work, so mistakes surface as a
    # clear message instead of an obscure failure during loading/tokenizing.
    for token in ("[SENTENCE]", "[MASK]"):
        if token not in template:
            raise ValueError(f"Your template must have a {token} token.")

    # split() with no argument ignores repeated/leading/trailing whitespace;
    # split(" ") turned such input (or an empty box) into "" label words.
    label_words = {
        "positive": positive.split(),
        "neutral": neutral.split(),
        "negative": negative.split(),
    }
    print(label_words)
    missing = [name for name in classes if not label_words[name]]
    if missing:
        raise ValueError(
            f"Provide at least one label word for: {', '.join(missing)}."
        )

    # Backbone family that load_plm expects for each selectable model.
    type_dic = {
        "CCCC/ARCH_tuned_bert": "bert",
        "bert-base-uncased": "bert",
        "roberta-base": "roberta",
        "yiyanghkust/finbert-pretrain": "bert",
        "facebook/opt-125m": "opt",
        "facebook/opt-350m": "opt",
        # These four are offered by the model radio button in the UI but were
        # absent from this table, so selecting them raised KeyError.
        # NOTE(review): the "roberta" family is assumed from the
        # "_tuned_robert" suffix — confirm against the actual checkpoints.
        "ARCH_tuned_robert": "roberta",
        "FNCH_tuned_robert": "roberta",
        "AREN_tuned_robert": "roberta",
        "FNEN_tuned_robert": "roberta",
    }
    if model_name not in type_dic:
        # Also covers model_name=None (no radio option selected).
        raise ValueError(
            f"Unsupported model {model_name!r}; choose one of: "
            f"{', '.join(type_dic)}."
        )

    # Translate the UI placeholders into OpenPrompt template syntax.
    template = template.replace("[SENTENCE]", '{"placeholder":"text_a"}')
    template = template.replace("[MASK]", '{"mask"}')

    testdata = [InputExample(guid=0, text_a=sentence, label=0)]
    plm, tokenizer, model_config, WrapperClass = load_plm(
        type_dic[model_name], model_name
    )
    promptTemplate = ManualTemplate(
        text=template,
        tokenizer=tokenizer,
    )
    promptVerbalizer = ManualVerbalizer(
        classes=classes,
        label_words=label_words,
        tokenizer=tokenizer,
    )
    test_dataloader = PromptDataLoader(
        dataset=testdata,
        tokenizer=tokenizer,
        template=promptTemplate,
        tokenizer_wrapper_class=WrapperClass,
        batch_size=1,
        max_seq_length=512,
    )
    prompt_model = PromptForClassification(
        plm=plm,
        template=promptTemplate,
        verbalizer=promptVerbalizer,
        freeze_plm=False,  # whether or not to freeze the pretrained language model
    )

    # Exactly one example with batch_size=1, so the first batch is all there is.
    inputs = next(iter(test_dataloader))
    # Inference only: disable training-mode layers and skip autograd bookkeeping.
    prompt_model.eval()
    with torch.no_grad():
        logits = prompt_model(inputs)
    return classes[torch.argmax(logits, dim=-1)[0].item()]
# Web UI: the input components map positionally onto the parameters of
# sentiment_analysis (sentence, template, model_name, positive, neutral,
# negative).
demo = gr.Interface(
    fn=sentiment_analysis,
    inputs=[
        gr.Textbox(placeholder="Enter sentence here.", label="sentence"),
        gr.Textbox(
            placeholder="Your template must have a [SENTENCE] token and a [MASK] token.",
            label="template",
        ),
        gr.Radio(
            choices=[
                "ARCH_tuned_robert",
                "FNCH_tuned_robert",
                "AREN_tuned_robert",
                "FNEN_tuned_robert",
                "bert-base-uncased",
            ],
            # Pre-select a model so submitting without a choice never sends
            # None; "bert-base-uncased" is in sentiment_analysis's model table.
            value="bert-base-uncased",
            label="model choices",
        ),
        gr.Textbox(placeholder="Separate words with spaces.", label="positive label words"),
        gr.Textbox(placeholder="Separate words with spaces.", label="neutral label words"),
        gr.Textbox(placeholder="Separate words with spaces.", label="negative label words"),
    ],
    outputs="text",
)

# Launch only when run as a script, so importing this module (e.g. for
# tests or Gradio's reload mode) does not start a server.
if __name__ == "__main__":
    demo.launch()