# Author: HYCCC
# Last commit: 8554332 ("Update app.py")
import gradio as gr
from transformers import pipeline
import torch
import pandas as pd
from openprompt.plms import load_plm
from openprompt import PromptDataLoader
from openprompt.prompts import ManualVerbalizer
from openprompt.prompts import ManualTemplate
from openprompt.data_utils import InputExample
from openprompt import PromptForClassification
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
def readLMwords(path="LoughranMcDonald_MasterDictionary_2020.csv"):
    """Load the Loughran-McDonald word lists used as verbalizer label words.

    A word belongs to a category when that category's column holds a
    nonzero value for its row.

    Args:
        path: CSV file with at least the columns "Word", "Positive",
            "Negative" and "Uncertainty". Defaults to the file shipped
            next to the app.

    Returns:
        A tuple ``(positive, negative, uncertainty)`` of lists of lowercased
        words, each in file order.
    """
    categories = ("Positive", "Negative", "Uncertainty")
    # Read "Word" literally: pandas' default NA parsing would turn real
    # dictionary entries such as "NULL" or "NA" into NaN, which then leak into
    # the label-word lists as floats. Empty cells in the category columns are
    # still parsed as NaN, exactly as with the default settings.
    alldata = pd.read_csv(
        path,
        keep_default_na=False,
        na_values={column: [""] for column in categories},
    )
    words = alldata["Word"].astype(str).str.lower()
    non_empty = words != ""
    # NOTE(review): if the dictionary marks removed words with negative
    # values, `!= 0` still includes them -- confirm against its documentation.
    return tuple(
        list(words[(alldata[column] != 0) & non_empty]) for column in categories
    )
def sentiment_analysis(sentence, model_name):
    """Classify every line of ``sentence`` as positive, neutral or negative.

    Uses an OpenPrompt prompt-learning classifier on top of one of the HYCCC
    RoBERTa checkpoints. The verbalizer's label words come from
    ``readLMwords()``; its uncertainty list is used for the "neutral" class.
    For checkpoints whose name contains "Chinese", the input is first
    machine-translated to English with Helsinki-NLP/opus-mt-zh-en.

    Args:
        sentence: Text from the UI; each non-blank line is one sentence.
        model_name: Checkpoint name without the "HYCCC/" prefix, e.g.
            "RoBERTa_English_FinancialNews_tuned".

    Returns:
        One output line per input sentence, formatted as
        "<label>, <original line>". An empty string when the input has no
        non-blank lines.

    Raises:
        gr.Error: if no model is selected or the name is not a known
            checkpoint (Gradio shows the message to the user).
    """
    type_dic = {
        "HYCCC/RoBERTa_Chinese_AnnualReport_tuned": "roberta",
        "HYCCC/RoBERTa_Chinese_FinancialNews_tuned": "roberta",
        "HYCCC/RoBERTa_English_AnnualReport_tuned": "roberta",
        "HYCCC/RoBERTa_English_FinancialNews_tuned": "roberta",
    }
    if not model_name:
        # An unselected gr.Radio delivers None; "HYCCC/" + None would raise a
        # bare TypeError instead of a readable message.
        raise gr.Error("Please select a model.")
    model_name = "HYCCC/" + model_name
    if model_name not in type_dic:
        raise gr.Error(f"Unknown model: {model_name}")

    # One sentence per line. Blank lines are dropped (they would otherwise be
    # classified as empty sentences); strip() also removes a stray "\r".
    raw_sentences = [
        line.strip() for line in (sentence or "").splitlines() if line.strip()
    ]
    if not raw_sentences:
        return ""

    template = '{"placeholder":"text_a"} Shares are {"mask"}.'
    classes = ['positive', 'neutral', 'negative']
    positive, negative, neutral = readLMwords()
    label_words = {
        "positive": positive,
        "neutral": neutral,
        "negative": negative,
    }

    if 'Chinese' in model_name:
        sentences = _translate_zh_to_en(raw_sentences)
    else:
        sentences = raw_sentences

    predictions = _classify(
        sentences, model_name, type_dic[model_name], template, classes, label_words
    )
    # Echo the user's original (untranslated) line next to its label.
    return '\n'.join(
        f"{classes[res]}, {raw}" for res, raw in zip(predictions, raw_sentences)
    )


def _translate_zh_to_en(lines):
    """Translate a list of Chinese strings to English, one output per input."""
    mt_tokenizer = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-zh-en")
    mt_model = AutoModelForSeq2SeqLM.from_pretrained("Helsinki-NLP/opus-mt-zh-en")
    # truncation=True: without it, an over-long line makes generate() fail.
    encoded = mt_tokenizer(lines, return_tensors="pt", padding=True, truncation=True)
    generated = mt_model.generate(**encoded)
    return mt_tokenizer.batch_decode(generated, skip_special_tokens=True)


def _classify(sentences, model_name, model_type, template, classes, label_words):
    """Run the prompt classifier and return one class index per sentence."""
    # label=0 is a constant dummy; only the argmax below is used.
    testdata = [
        InputExample(guid=i, text_a=text, label=0)
        for i, text in enumerate(sentences)
    ]
    plm, tokenizer, _, WrapperClass = load_plm(model_type, model_name)
    promptTemplate = ManualTemplate(
        text=template,
        tokenizer=tokenizer,
    )
    promptVerbalizer = ManualVerbalizer(
        classes=classes,
        label_words=label_words,
        tokenizer=tokenizer,
    )
    # Results are paired with inputs by position, so this assumes the loader
    # keeps input order (no shuffling) -- TODO confirm for the OpenPrompt version.
    test_dataloader = PromptDataLoader(
        dataset=testdata,
        tokenizer=tokenizer,
        template=promptTemplate,
        tokenizer_wrapper_class=WrapperClass,
        batch_size=4,
        max_seq_length=512,
    )
    prompt_model = PromptForClassification(
        plm=plm,
        template=promptTemplate,
        verbalizer=promptVerbalizer,
        freeze_plm=True,
    )
    prompt_model.eval()  # inference only: make sure dropout is disabled
    predictions = []
    with torch.no_grad():  # no autograd graph is needed for inference
        for inputs in test_dataloader:
            logits = prompt_model(inputs)
            predictions.extend(torch.argmax(logits, dim=-1).tolist())
    return predictions
# Checkpoint names offered in the UI; sentiment_analysis prefixes "HYCCC/".
MODEL_CHOICES = [
    "RoBERTa_Chinese_AnnualReport_tuned",
    "RoBERTa_Chinese_FinancialNews_tuned",
    "RoBERTa_English_AnnualReport_tuned",
    "RoBERTa_English_FinancialNews_tuned",
]

# Gradio UI: free text (one sentence per line) plus a model selector in,
# one "<label>, <sentence>" line per input sentence out.
demo = gr.Interface(
    fn=sentiment_analysis,
    inputs=[
        # The old placeholder asked users to separate sentences with a literal
        # '\n'; the app actually splits on real line breaks.
        gr.TextArea(
            placeholder="Enter sentence here. If you have multiple sentences, put each one on its own line.",
            label="Sentence",
            lines=5,
            max_lines=10,
        ),
        gr.Radio(choices=MODEL_CHOICES, label="Model Selection"),
    ],
    outputs=gr.TextArea(label="Sentiment", lines=5, show_copy_button=True, max_lines=10),
    title="Prompt Learning-Based Disclosure Sentiment Detection",
)

if __name__ == "__main__":
    # Only start the server when run as a script, not when imported.
    demo.launch()