"""Gradio demo: prompt-based sentiment classification with OpenPrompt.

The user supplies a sentence, a template containing ``[SENTENCE]`` and
``[MASK]`` markers, and space-separated label words for each class. The
template and label words are turned into an OpenPrompt ``ManualTemplate`` /
``ManualVerbalizer`` pair, and the class whose label words score highest at
the mask position is returned.
"""

import functools
import logging

import gradio as gr
import torch
from openprompt import PromptDataLoader
from openprompt import PromptForClassification
from openprompt.data_utils import InputExample
from openprompt.plms import load_plm
from openprompt.prompts import ManualTemplate
from openprompt.prompts import ManualVerbalizer
from transformers import pipeline  # noqa: F401  (kept from the original file; not used below)

logger = logging.getLogger(__name__)

# Checkpoint used by the demo.
MODEL_NAME = "CCCC/ARCH_tuned_bert"

# Model family name that openprompt.plms.load_plm expects for each checkpoint.
MODEL_TYPES = {
    "CCCC/ARCH_tuned_bert": "bert",
    "bert-base-uncased": "bert",
    "roberta-base": "roberta",
    "yiyanghkust/finbert-pretrain": "bert",
    "facebook/opt-125m": "opt",
    "facebook/opt-350m": "opt",
}

# Output classes, in the order the verbalizer scores them.
CLASSES = ("positive", "neutral", "negative")


@functools.lru_cache(maxsize=None)
def _load_model(model_name):
    """Load the PLM, tokenizer, config and wrapper class once per model name.

    The original code called ``load_plm`` on every request. Caching avoids
    re-reading the checkpoint for each sentence. The key space is the handful
    of entries in ``MODEL_TYPES``, so an unbounded cache is fine.
    """
    return load_plm(MODEL_TYPES[model_name], model_name)


def _split_label_words(words, class_name):
    """Split a space-separated textbox value into a non-empty list of label words.

    ``str.split()`` (no argument) drops the empty strings that ``split(" ")``
    produced for repeated, leading or trailing spaces. Empty label words cannot
    be tokenized by the verbalizer.

    Raises:
        ValueError: if no label word was given for ``class_name``.
    """
    label_words = words.split()
    if not label_words:
        raise ValueError(
            f"Please provide at least one label word for the {class_name!r} class."
        )
    return label_words


def sentiment_analysis(sentence, template, positive, neutral, negative):
    """Classify ``sentence`` as positive / neutral / negative via prompting.

    Args:
        sentence: Text to classify.
        template: Prompt template; must contain the literal markers
            ``[SENTENCE]`` (replaced by the input text) and ``[MASK]``
            (the position whose label-word scores decide the class).
        positive, neutral, negative: Space-separated label words for each class.

    Returns:
        One of ``"positive"``, ``"neutral"`` or ``"negative"``.

    Raises:
        ValueError: if the template lacks a marker or a class has no label words.
    """
    if "[SENTENCE]" not in template or "[MASK]" not in template:
        raise ValueError(
            "Your template must have a [SENTENCE] token and a [MASK] token."
        )
    # Translate the user-facing markers into OpenPrompt template syntax.
    template = template.replace("[SENTENCE]", '{"placeholder":"text_a"}')
    template = template.replace("[MASK]", '{"mask"}')

    label_words = {
        class_name: _split_label_words(words, class_name)
        for class_name, words in zip(CLASSES, (positive, neutral, negative))
    }
    logger.debug("label words: %s", label_words)

    plm, tokenizer, _model_config, wrapper_class = _load_model(MODEL_NAME)

    prompt_template = ManualTemplate(
        text=template,
        tokenizer=tokenizer,
    )
    prompt_verbalizer = ManualVerbalizer(
        classes=list(CLASSES),
        label_words=label_words,
        tokenizer=tokenizer,
    )
    test_dataloader = PromptDataLoader(
        # The label is a dummy value; it is required by InputExample but unused here.
        dataset=[InputExample(guid=0, text_a=sentence, label=0)],
        tokenizer=tokenizer,
        template=prompt_template,
        tokenizer_wrapper_class=wrapper_class,
        batch_size=1,
        max_seq_length=512,
    )
    prompt_model = PromptForClassification(
        plm=plm,
        template=prompt_template,
        verbalizer=prompt_verbalizer,
        freeze_plm=False,  # whether or not to freeze the pretrained language model
    )
    # Inference only: disable dropout and gradient tracking.
    prompt_model.eval()

    with torch.no_grad():
        # One example with batch_size=1, so the loader yields exactly one batch.
        inputs = next(iter(test_dataloader))
        logits = prompt_model(inputs)
    return CLASSES[torch.argmax(logits, dim=-1)[0].item()]


demo = gr.Interface(
    fn=sentiment_analysis,
    inputs=[
        gr.Textbox(placeholder="Enter sentence here.", label="sentence"),
        gr.Textbox(
            placeholder="Your template must have a [SENTENCE] token and a [MASK] token.",
            label="template",
        ),
        gr.Textbox(placeholder="Separate words with Spaces.", label="positive"),
        gr.Textbox(placeholder="Separate words with Spaces.", label="neutral"),
        gr.Textbox(placeholder="Separate words with Spaces.", label="negative"),
    ],
    outputs="text",
)

if __name__ == "__main__":
    demo.launch(server_port=8080)