# NOTE: the three lines below are scraping artifacts (file metadata, git blame
# hashes, and a line-number gutter), commented out so the file parses as Python.
# File size: 2,743 Bytes
# 3a71eee dd6d97a da4e37a 31f4e47 da4e37a 31f4e47 dd6d97a da4e37a dd6d97a da4e37a dd6d97a da4e37a dd6d97a da4e37a |
# 1 2 3 4 5 6 7 8 9 10 ... 77 |
import gradio as gr
from transformers import pipeline
import torch
from openprompt.plms import load_plm
from openprompt import PromptDataLoader
from openprompt.prompts import ManualVerbalizer
from openprompt.prompts import ManualTemplate
from openprompt.data_utils import InputExample
from openprompt import PromptForClassification
def sentiment_analysis(sentence, template, positive, neutral, negative,
                       model_name="CCCC/ARCH_tuned_bert"):
    """Classify one sentence as positive/neutral/negative via prompt tuning.

    Args:
        sentence: The input sentence to classify.
        template: Prompt template containing the literal tokens ``[SENTENCE]``
            and ``[MASK]``; they are rewritten into OpenPrompt's
            placeholder/mask syntax before use.
        positive: Whitespace-separated verbalizer words for the positive class.
        neutral: Whitespace-separated verbalizer words for the neutral class.
        negative: Whitespace-separated verbalizer words for the negative class.
        model_name: HuggingFace model id; must be a key of the internal
            model-type table below. Defaults to the originally hard-coded model.

    Returns:
        One of ``'positive'``, ``'neutral'``, ``'negative'``.

    Raises:
        KeyError: If ``model_name`` is not in the supported-model table.
        RuntimeError: If the dataloader unexpectedly yields no batches.
    """
    # Translate the user-facing placeholder tokens into OpenPrompt syntax.
    template = template.replace('[SENTENCE]', '{"placeholder":"text_a"}')
    template = template.replace('[MASK]', '{"mask"}')

    classes = ['positive', 'neutral', 'negative']
    # split() with no argument collapses runs of whitespace and drops empty
    # strings, so a trailing space or double space no longer yields '' as a
    # (meaningless) label word.
    label_words = {
        "positive": positive.split(),
        "neutral": neutral.split(),
        "negative": negative.split(),
    }
    print(label_words)

    # Maps each supported HuggingFace model id to its OpenPrompt plm type tag.
    type_dic = {
        "CCCC/ARCH_tuned_bert": "bert",
        "bert-base-uncased": "bert",
        "roberta-base": "roberta",
        "yiyanghkust/finbert-pretrain": "bert",
        "facebook/opt-125m": "opt",
        "facebook/opt-350m": "opt",
    }

    # A single unlabeled example; label=0 is a dummy required by InputExample.
    testdata = [InputExample(guid=0, text_a=sentence, label=0)]
    plm, tokenizer, model_config, WrapperClass = load_plm(
        type_dic[model_name], model_name)

    promptTemplate = ManualTemplate(
        text=template,
        tokenizer=tokenizer,
    )
    promptVerbalizer = ManualVerbalizer(
        classes=classes,
        label_words=label_words,
        tokenizer=tokenizer,
    )
    test_dataloader = PromptDataLoader(
        dataset=testdata,
        tokenizer=tokenizer,
        template=promptTemplate,
        tokenizer_wrapper_class=WrapperClass,
        batch_size=1,
        max_seq_length=512,
    )
    prompt_model = PromptForClassification(
        plm=plm,
        template=promptTemplate,
        verbalizer=promptVerbalizer,
        freeze_plm=False,  # whether or not to freeze the pretrained language model
    )

    # Inference only: eval() disables dropout, no_grad() skips autograd
    # bookkeeping (faster, far less memory).
    prompt_model.eval()
    with torch.no_grad():
        for inputs in test_dataloader:
            logits = prompt_model(inputs)
            # One example with batch_size=1 -> the first batch is the answer.
            return classes[torch.argmax(logits, dim=-1)[0]]

    # Defensive: with one input example this loop always yields a batch, but
    # fail loudly instead of the original's NameError on an undefined `logits`.
    raise RuntimeError("PromptDataLoader yielded no batches")
# Gradio UI: five free-text inputs feeding sentiment_analysis, text output.
demo = gr.Interface(
    fn=sentiment_analysis,
    inputs=[
        gr.Textbox(placeholder="Enter sentence here.", label="sentence"),
        gr.Textbox(placeholder="Your template must have a [SENTENCE] token and a [MASK] token.", label="template"),
        gr.Textbox(placeholder="Separate words with Spaces.", label="positive"),
        gr.Textbox(placeholder="Separate words with Spaces.", label="neutral"),
        gr.Textbox(placeholder="Separate words with Spaces.", label="negative"),
    ],
    outputs="text",
)

# Only start the web server when run as a script; importing this module
# (e.g. for testing or reuse of sentiment_analysis) must not launch it.
if __name__ == "__main__":
    demo.launch(server_port=8080)
# | (scraper artifact, commented out so the file parses)