# Hugging Face Space (status: Sleeping)
# File size: 2,746 bytes
# Blame commits: df2e6a9 e917745 df2e6a9 314019b da4e37a dd6d97a da4e37a e917745 da4e37a dd6d97a da4e37a dd6d97a da4e37a 98dd801 7496119 da4e37a 314019b da4e37a 314019b
import gradio as gr
import torch
from openprompt.plms import load_plm
from openprompt import PromptDataLoader
from openprompt.prompts import ManualVerbalizer
from openprompt.prompts import ManualTemplate
from openprompt.data_utils import InputExample
from openprompt import PromptForClassification
# Hugging Face model id -> OpenPrompt model family passed to ``load_plm``.
MODEL_TYPES = {
    "bert-base-uncased": "bert",
    "roberta-base": "roberta",
    "yiyanghkust/finbert-pretrain": "bert",
}

# Output classes, in the order the verbalizer scores them (index = logit column).
CLASSES = ("positive", "neutral", "negative")

# ``load_plm`` results keyed by model id, so each checkpoint is loaded from
# disk once per process rather than on every request.
_PLM_CACHE = {}


def _load_plm_cached(model_name):
    """Return ``load_plm``'s (plm, tokenizer, config, wrapper) tuple, loading once."""
    if model_name not in _PLM_CACHE:
        _PLM_CACHE[model_name] = load_plm(MODEL_TYPES[model_name], model_name)
    return _PLM_CACHE[model_name]


def _build_template_text(template):
    """Translate the UI's [SENTENCE]/[MASK] tokens into ManualTemplate syntax.

    Raises:
        ValueError: If either token is absent from *template*.
    """
    missing = [tok for tok in ("[SENTENCE]", "[MASK]") if tok not in template]
    if missing:
        raise ValueError(
            "Template must contain both [SENTENCE] and [MASK] tokens; "
            f"missing: {', '.join(missing)}"
        )
    text = template.replace("[SENTENCE]", '{"placeholder":"text_a"}')
    return text.replace("[MASK]", '{"mask"}')


def _parse_label_words(positive, neutral, negative):
    """Split each whitespace-separated word list into a {class: words} mapping.

    ``str.split()`` (no argument) collapses runs of whitespace, so stray or
    repeated spaces never produce empty-string label words.

    Raises:
        ValueError: If any class ends up with no label words.
    """
    label_words = {}
    for cls, raw in zip(CLASSES, (positive, neutral, negative)):
        words = (raw or "").split()
        if not words:
            raise ValueError(f"Please enter at least one {cls} label word.")
        label_words[cls] = words
    return label_words


def sentiment_analysis(sentence, template, model_name, positive, neutral, negative):
    """Classify *sentence* as positive / neutral / negative with a manual prompt.

    No fine-tuning is performed: the pretrained model scores the user-supplied
    label words at the ``[MASK]`` slot of the template, and the class whose
    label words score highest is returned.

    Args:
        sentence: Text to classify; substituted for ``[SENTENCE]``.
        template: Prompt text containing the literal tokens ``[SENTENCE]`` and
            ``[MASK]``.
        model_name: One of the keys of ``MODEL_TYPES``.
        positive: Whitespace-separated label words for the positive class.
        neutral: Whitespace-separated label words for the neutral class.
        negative: Whitespace-separated label words for the negative class.

    Returns:
        ``"positive"``, ``"neutral"`` or ``"negative"``.

    Raises:
        ValueError: If *model_name* is unsupported (including ``None``), the
            template lacks a required token, or a class has no label words.
    """
    # Validate all user input before touching the (slow) model-loading path.
    if model_name not in MODEL_TYPES:
        raise ValueError(
            f"Unsupported model {model_name!r}; choose one of: {', '.join(MODEL_TYPES)}."
        )
    template_text = _build_template_text(template or "")
    label_words = _parse_label_words(positive, neutral, negative)

    plm, tokenizer, _model_config, WrapperClass = _load_plm_cached(model_name)
    prompt_template = ManualTemplate(text=template_text, tokenizer=tokenizer)
    verbalizer = ManualVerbalizer(
        classes=list(CLASSES),
        label_words=label_words,
        tokenizer=tokenizer,
    )
    loader = PromptDataLoader(
        dataset=[InputExample(guid=0, text_a=sentence, label=0)],
        tokenizer=tokenizer,
        template=prompt_template,
        tokenizer_wrapper_class=WrapperClass,
        batch_size=1,
        max_seq_length=512,
    )
    prompt_model = PromptForClassification(
        plm=plm,
        template=prompt_template,
        verbalizer=verbalizer,
        freeze_plm=False,  # irrelevant here: no training happens in this function
    )
    prompt_model.eval()  # inference only: make sure dropout etc. are disabled

    # The dataset holds exactly one example and batch_size is 1 -> one batch.
    batch = next(iter(loader))
    with torch.no_grad():  # no autograd graph needed for a single prediction
        logits = prompt_model(batch)
    return CLASSES[int(torch.argmax(logits, dim=-1)[0])]
# Gradio UI: the input components are listed in the same order as the
# parameters of ``sentiment_analysis`` (sentence, template, model_name,
# positive, neutral, negative), which is how Gradio maps them to arguments.
demo = gr.Interface(
    fn=sentiment_analysis,
    inputs=[
        gr.Textbox(placeholder="Enter sentence here.", label="sentence"),
        gr.Textbox(
            placeholder="Your template must have a [SENTENCE] token and a [MASK] token.",
            label="template",
        ),
        gr.Radio(
            choices=["roberta-base", "bert-base-uncased", "yiyanghkust/finbert-pretrain"],
            # Pre-select a model so the handler never receives None.
            value="roberta-base",
            label="model choice",
        ),
        gr.Textbox(placeholder="Separate words with Spaces.", label="positive label words"),
        gr.Textbox(placeholder="Separate words with Spaces.", label="neutral label words"),
        gr.Textbox(placeholder="Separate words with Spaces.", label="negative label words"),
    ],
    outputs="text",
)

if __name__ == "__main__":
    demo.launch()