File size: 2,743 Bytes
3a71eee
dd6d97a
da4e37a
 
31f4e47
 
 
da4e37a
31f4e47
dd6d97a
 
da4e37a
 
 
 
 
dd6d97a
da4e37a
 
 
 
 
 
 
 
 
 
 
 
 
 
dd6d97a
da4e37a
dd6d97a
da4e37a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
import functools

import gradio as gr
import torch
from openprompt import PromptDataLoader
from openprompt import PromptForClassification
from openprompt.data_utils import InputExample
from openprompt.plms import load_plm
from openprompt.prompts import ManualTemplate
from openprompt.prompts import ManualVerbalizer
from transformers import pipeline


def sentiment_analysis(sentence, template, positive, neutral, negative):
    model_name = "CCCC/ARCH_tuned_bert" 
    template = template.replace('[SENTENCE]', '{"placeholder":"text_a"}')
    template = template.replace('[MASK]', '{"mask"}')
    classes = ['positive', 'neutral', 'negative']

    label_words = {
        "positive": positive.split(" "),
        "neutral": neutral.split(" "),
        "negative": negative.split(" "),
    }
    print(label_words)
    type_dic = {
        "CCCC/ARCH_tuned_bert":"bert",
        "bert-base-uncased":"bert",
        "roberta-base":"roberta",
        "yiyanghkust/finbert-pretrain":"bert",
        "facebook/opt-125m":"opt",
        "facebook/opt-350m":"opt",
    }

    testdata = [InputExample(guid=0,text_a=sentence,label=0)]

    plm, tokenizer, model_config, WrapperClass = load_plm(type_dic[model_name], model_name)

    promptTemplate = ManualTemplate(
        text = template,
        tokenizer = tokenizer,
    )
    promptVerbalizer = ManualVerbalizer(
        classes = classes,
        label_words = label_words,
        tokenizer = tokenizer,
    )
    test_dataloader = PromptDataLoader(
        dataset = testdata,
        tokenizer = tokenizer, 
        template = promptTemplate, 
        tokenizer_wrapper_class = WrapperClass,
        batch_size = 1,
        max_seq_length = 512,
    )
    prompt_model = PromptForClassification(
        plm=plm,
        template=promptTemplate, 
        verbalizer=promptVerbalizer, 
        freeze_plm=False #whether or not to freeze the pretrained language model
    )
    for step, inputs in enumerate(test_dataloader):
        logits = prompt_model(inputs)
        
        
    return classes[torch.argmax(logits, dim=-1)[0]]


# One textbox per sentiment_analysis argument, in the same positional order;
# the predicted class label is shown as plain text.
demo = gr.Interface(
    fn=sentiment_analysis,
    inputs=[
        gr.Textbox(placeholder="Enter sentence here.", label="sentence"),
        gr.Textbox(placeholder="Your template must have a [SENTENCE] token and a [MASK] token.", label="template"),
        gr.Textbox(placeholder="Separate words with Spaces.", label="positive"),
        gr.Textbox(placeholder="Separate words with Spaces.", label="neutral"),
        gr.Textbox(placeholder="Separate words with Spaces.", label="negative"),
    ],
    outputs="text",
)

# Only start the server when executed as a script, so importing this module
# does not block on demo.launch().
if __name__ == "__main__":
    demo.launch(server_port=8080)