|
def sentiment_analysis(sentence, template, model_name, positive, neutral, negative):
    """Classify ``sentence`` as positive / neutral / negative with a prompted PLM.

    The user template is rewritten into OpenPrompt ``ManualTemplate`` syntax,
    the three label-word strings become a ``ManualVerbalizer``, and the
    checkpoint named by ``model_name`` scores the single input sentence.

    Args:
        sentence: Text to classify.
        template: Prompt text that must contain a ``[SENTENCE]`` token (where
            the input sentence is inserted) and a ``[MASK]`` token (the
            position whose predictions the verbalizer reads).
        model_name: One of the checkpoint names listed in ``type_dic`` below.
        positive: Whitespace-separated label words for the positive class.
        neutral: Whitespace-separated label words for the neutral class.
        negative: Whitespace-separated label words for the negative class.

    Returns:
        The predicted class name: ``'positive'``, ``'neutral'`` or ``'negative'``.

    Raises:
        ValueError: if the template lacks ``[SENTENCE]``/``[MASK]`` or a class
            has no label words.
        KeyError: if ``model_name`` is not a supported checkpoint.
    """
    # Fail early with a readable message instead of an opaque library error.
    if '[SENTENCE]' not in template or '[MASK]' not in template:
        raise ValueError(
            "template must contain both a [SENTENCE] token and a [MASK] token"
        )

    # Translate the UI's simple tokens into OpenPrompt template markup.
    template = template.replace('[SENTENCE]', '{"placeholder":"text_a"}')
    template = template.replace('[MASK]', '{"mask"}')

    classes = ['positive', 'neutral', 'negative']

    # split() with no argument collapses runs of whitespace and ignores
    # leading/trailing blanks; split(" ") would create empty-string label words.
    label_words = {
        "positive": positive.split(),
        "neutral": neutral.split(),
        "negative": negative.split(),
    }
    print(label_words)

    empty_classes = [name for name, words in label_words.items() if not words]
    if empty_classes:
        raise ValueError(f"no label words given for: {', '.join(empty_classes)}")

    # Supported checkpoints -> model family expected by load_plm.
    type_dic = {
        "CCCC/ARCH_tuned_bert": "bert",
        "bert-base-uncased": "bert",
        "roberta-base": "roberta",
        "yiyanghkust/finbert-pretrain": "bert",
        "facebook/opt-125m": "opt",
        "facebook/opt-350m": "opt",
    }
    try:
        model_type = type_dic[model_name]
    except KeyError:
        # Keep the original exception type, but say what is supported.
        raise KeyError(
            f"unsupported model {model_name!r}; expected one of {sorted(type_dic)}"
        ) from None

    testdata = [InputExample(guid=0, text_a=sentence, label=0)]

    plm, tokenizer, model_config, WrapperClass = load_plm(model_type, model_name)

    promptTemplate = ManualTemplate(
        text=template,
        tokenizer=tokenizer,
    )
    promptVerbalizer = ManualVerbalizer(
        classes=classes,
        label_words=label_words,
        tokenizer=tokenizer,
    )
    test_dataloader = PromptDataLoader(
        dataset=testdata,
        tokenizer=tokenizer,
        template=promptTemplate,
        tokenizer_wrapper_class=WrapperClass,
        batch_size=1,
        max_seq_length=512,
    )
    prompt_model = PromptForClassification(
        plm=plm,
        template=promptTemplate,
        verbalizer=promptVerbalizer,
        freeze_plm=False,
    )

    # Inference only: eval() disables dropout so the prediction is
    # deterministic, and no_grad() skips building the autograd graph.
    prompt_model.eval()
    with torch.no_grad():
        # One example with batch_size=1, so this loop runs exactly once.
        for inputs in test_dataloader:
            logits = prompt_model(inputs)

    return classes[torch.argmax(logits, dim=-1)[0].item()]
|
|
|
|
|
# Gradio UI: the six inputs are passed positionally to
# sentiment_analysis(sentence, template, model_name, positive, neutral, negative),
# and the returned class name is displayed as text.
demo = gr.Interface(
    fn=sentiment_analysis,
    inputs=[
        gr.Textbox(placeholder="Enter sentence here.", label="sentence"),
        gr.Textbox(
            placeholder="Your template must have a [SENTENCE] token and a [MASK] token.",
            label="template",
        ),
        # NOTE(review): of these choices only "bert-base-uncased" is a key of
        # the type_dic inside sentiment_analysis; the four *_tuned_robert
        # entries currently raise KeyError there. Confirm their Hugging Face
        # repo IDs and model families, then register them (or fix these names).
        gr.Radio(
            choices=[
                "ARCH_tuned_robert",
                "FNCH_tuned_robert",
                "AREN_tuned_robert",
                "FNEN_tuned_robert",
                "bert-base-uncased",
            ],
            label="model choices",
        ),
        gr.Textbox(placeholder="Separate words with Spaces.", label="positive label words"),
        gr.Textbox(placeholder="Separate words with Spaces.", label="neutral label words"),
        gr.Textbox(placeholder="Separate words with Spaces.", label="negative label words"),
    ],
    outputs="text",
)

demo.launch()