cccc commited on
Commit
da4e37a
·
1 Parent(s): 5513697

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +65 -75
app.py CHANGED
@@ -1,86 +1,76 @@
1
  import gradio as gr
2
  from transformers import pipeline
 
 
3
  from openprompt import PromptDataLoader
4
  from openprompt.prompts import ManualVerbalizer
5
  from openprompt.prompts import ManualTemplate
 
6
  from openprompt import PromptForClassification
7
 
8
- unmasker = pipeline('fill-mask', model="CCCC/ARCH_tuned_bert") #'bert-base-uncased')
9
 
10
- def fill_mask(text):
11
- unmasked = unmasker(text)
12
- output = {}
13
- for unmask in unmasked:
14
- output[unmask["token_str"]] = unmask["score"]
15
- return output
16
 
17
- examples = [["Hello I'm a [MASK] model."], ["[MASK] is my favourite sports."]]
 
 
 
 
 
 
 
 
 
 
 
 
 
18
 
19
- css = """
20
- footer {display:none !important}
21
- .output-markdown{display:none !important}
22
- .gr-button-primary {
23
- z-index: 14;
24
- height: 43px;
25
- width: 130px;
26
- left: 0px;
27
- top: 0px;
28
- padding: 0px;
29
- cursor: pointer !important;
30
- background: none rgb(17, 20, 45) !important;
31
- border: none !important;
32
- text-align: center !important;
33
- font-family: Poppins !important;
34
- font-size: 14px !important;
35
- font-weight: 500 !important;
36
- color: rgb(255, 255, 255) !important;
37
- line-height: 1 !important;
38
- border-radius: 12px !important;
39
- transition: box-shadow 200ms ease 0s, background 200ms ease 0s !important;
40
- box-shadow: none !important;
41
- }
42
- .gr-button-primary:hover{
43
- z-index: 14;
44
- height: 43px;
45
- width: 130px;
46
- left: 0px;
47
- top: 0px;
48
- padding: 0px;
49
- cursor: pointer !important;
50
- background: none rgb(37, 56, 133) !important;
51
- border: none !important;
52
- text-align: center !important;
53
- font-family: Poppins !important;
54
- font-size: 14px !important;
55
- font-weight: 500 !important;
56
- color: rgb(255, 255, 255) !important;
57
- line-height: 1 !important;
58
- border-radius: 12px !important;
59
- transition: box-shadow 200ms ease 0s, background 200ms ease 0s !important;
60
- box-shadow: rgb(0 0 0 / 23%) 0px 1px 7px 0px !important;
61
- }
62
- .hover\:bg-orange-50:hover {
63
- --tw-bg-opacity: 1 !important;
64
- background-color: rgb(229,225,255) !important;
65
- }
66
- .to-orange-200 {
67
- --tw-gradient-to: rgb(37 56 133 / 37%) !important;
68
- }
69
- .from-orange-400 {
70
- --tw-gradient-from: rgb(17, 20, 45) !important;
71
- --tw-gradient-to: rgb(255 150 51 / 0);
72
- --tw-gradient-stops: var(--tw-gradient-from), var(--tw-gradient-to) !important;
73
- }
74
- .group-hover\:from-orange-500{
75
- --tw-gradient-from:rgb(17, 20, 45) !important;
76
- --tw-gradient-to: rgb(37 56 133 / 37%);
77
- --tw-gradient-stops: var(--tw-gradient-from), var(--tw-gradient-to) !important;
78
- }
79
- .group:hover .group-hover\:text-orange-500{
80
- --tw-text-opacity: 1 !important;
81
- color:rgb(37 56 133 / var(--tw-text-opacity)) !important;
82
- }
83
- """
84
 
85
- demo = gr.Interface(fn=fill_mask, inputs=gr.Textbox(lines=1, label="Input"), outputs=gr.Label(label="Output"),title="Fill Mask | Data Science Dojo", theme="light", examples=examples, css=css)
86
- demo.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import gradio as gr
2
  from transformers import pipeline
3
+ import torch
4
+ from openprompt.plms import load_plm
5
  from openprompt import PromptDataLoader
6
  from openprompt.prompts import ManualVerbalizer
7
  from openprompt.prompts import ManualTemplate
8
+ from openprompt.data_utils import InputExample
9
  from openprompt import PromptForClassification
10
 
 
11
 
12
+ def sentiment_analysis(sentence, template, positive, neutral, negative):
13
+ model_name = "CCCC/ARCH_tuned_bert"
14
+ template = template.replace('[SENTENCE]', '{"placeholder":"text_a"}')
15
+ template = template.replace('[MASK]', '{"mask"}')
16
+ classes = ['positive', 'neutral', 'negative']
 
17
 
18
+ label_words = {
19
+ "positive": positive.split(" "),
20
+ "neutral": neutral.split(" "),
21
+ "negative": negative.split(" "),
22
+ }
23
+ print(label_words)
24
+ type_dic = {
25
+ "CCCC/ARCH_tuned_bert":"bert",
26
+ "bert-base-uncased":"bert",
27
+ "roberta-base":"roberta",
28
+ "yiyanghkust/finbert-pretrain":"bert",
29
+ "facebook/opt-125m":"opt",
30
+ "facebook/opt-350m":"opt",
31
+ }
32
 
33
+ testdata = [InputExample(guid=0,text_a=sentence,label=0)]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
34
 
35
+ plm, tokenizer, model_config, WrapperClass = load_plm(type_dic[model_name], model_name)
36
+
37
+ promptTemplate = ManualTemplate(
38
+ text = template,
39
+ tokenizer = tokenizer,
40
+ )
41
+ promptVerbalizer = ManualVerbalizer(
42
+ classes = classes,
43
+ label_words = label_words,
44
+ tokenizer = tokenizer,
45
+ )
46
+ test_dataloader = PromptDataLoader(
47
+ dataset = testdata,
48
+ tokenizer = tokenizer,
49
+ template = promptTemplate,
50
+ tokenizer_wrapper_class = WrapperClass,
51
+ batch_size = 1,
52
+ max_seq_length = 512,
53
+ )
54
+ prompt_model = PromptForClassification(
55
+ plm=plm,
56
+ template=promptTemplate,
57
+ verbalizer=promptVerbalizer,
58
+ freeze_plm=False #whether or not to freeze the pretrained language model
59
+ )
60
+ for step, inputs in enumerate(test_dataloader):
61
+ logits = prompt_model(inputs)
62
+
63
+
64
+ return classes[torch.argmax(logits, dim=-1)[0]]
65
+
66
+
67
+ demo = gr.Interface(fn=sentiment_analysis,
68
+ inputs = [gr.Textbox(placeholder="Enter sentence here.",label="sentence"),
69
+ gr.Textbox(placeholder="Your template must have a [SENTENCE] token and a [MASK] token.",label="template"),
70
+ gr.Textbox(placeholder="Separate words with Spaces.",label="positive"),
71
+ gr.Textbox(placeholder="Separate words with Spaces.",label="neutral"),
72
+ gr.Textbox(placeholder="Separate words with Spaces.",label="negative")
73
+ ],
74
+ outputs="text")
75
+
76
+ demo.launch(server_port=8080)