donjuanplatinum commited on
Commit
c37fce3
·
1 Parent(s): fd3502c

Upload run_demo.py

Browse files
Files changed (1) hide show
  1. run_demo.py +192 -0
run_demo.py ADDED
@@ -0,0 +1,192 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+ import numpy
4
+ import torch
5
+ import random
6
+ import gradio as gr
7
+
8
+ from transformers import AutoTokenizer, AutoModel
9
+
10
+ def get_model():
11
+ tokenizer = AutoTokenizer.from_pretrained("THUDM/codegeex2-6b", trust_remote_code=True)
12
+ model = AutoModel.from_pretrained("THUDM/codegeex2-6b", trust_remote_code=True).to('cpu')
13
+ # 如需实现多显卡模型加载,请将上面一行注释并启用一下两行,"num_gpus"调整为自己需求的显卡数量 / To enable Multiple GPUs model loading, please uncomment the line above and enable the following two lines. Adjust "num_gpus" to the desired number of graphics cards.
14
+ # from gpus import load_model_on_gpus
15
+ # model = load_model_on_gpus("THUDM/codegeex2-6b", num_gpus=2)
16
+ model = model.eval()
17
+ return tokenizer, model
18
+
19
+ tokenizer, model = get_model()
20
+
21
+ examples = []
22
+ with open(os.path.join(os.path.split(os.path.realpath(__file__))[0], "example_inputs.jsonl"), "r", encoding="utf-8") as f:
23
+ for line in f:
24
+ examples.append(list(json.loads(line).values()))
25
+
26
+
27
+ LANGUAGE_TAG = {
28
+ "Abap" : "* language: Abap",
29
+ "ActionScript" : "// language: ActionScript",
30
+ "Ada" : "-- language: Ada",
31
+ "Agda" : "-- language: Agda",
32
+ "ANTLR" : "// language: ANTLR",
33
+ "AppleScript" : "-- language: AppleScript",
34
+ "Assembly" : "; language: Assembly",
35
+ "Augeas" : "// language: Augeas",
36
+ "AWK" : "// language: AWK",
37
+ "Basic" : "' language: Basic",
38
+ "C" : "// language: C",
39
+ "C#" : "// language: C#",
40
+ "C++" : "// language: C++",
41
+ "CMake" : "# language: CMake",
42
+ "Cobol" : "// language: Cobol",
43
+ "CSS" : "/* language: CSS */",
44
+ "CUDA" : "// language: Cuda",
45
+ "Dart" : "// language: Dart",
46
+ "Delphi" : "{language: Delphi}",
47
+ "Dockerfile" : "# language: Dockerfile",
48
+ "Elixir" : "# language: Elixir",
49
+ "Erlang" : f"% language: Erlang",
50
+ "Excel" : "' language: Excel",
51
+ "F#" : "// language: F#",
52
+ "Fortran" : "!language: Fortran",
53
+ "GDScript" : "# language: GDScript",
54
+ "GLSL" : "// language: GLSL",
55
+ "Go" : "// language: Go",
56
+ "Groovy" : "// language: Groovy",
57
+ "Haskell" : "-- language: Haskell",
58
+ "HTML" : "<!--language: HTML-->",
59
+ "Isabelle" : "(*language: Isabelle*)",
60
+ "Java" : "// language: Java",
61
+ "JavaScript" : "// language: JavaScript",
62
+ "Julia" : "# language: Julia",
63
+ "Kotlin" : "// language: Kotlin",
64
+ "Lean" : "-- language: Lean",
65
+ "Lisp" : "; language: Lisp",
66
+ "Lua" : "// language: Lua",
67
+ "Markdown" : "<!--language: Markdown-->",
68
+ "Matlab" : f"% language: Matlab",
69
+ "Objective-C" : "// language: Objective-C",
70
+ "Objective-C++": "// language: Objective-C++",
71
+ "Pascal" : "// language: Pascal",
72
+ "Perl" : "# language: Perl",
73
+ "PHP" : "// language: PHP",
74
+ "PowerShell" : "# language: PowerShell",
75
+ "Prolog" : f"% language: Prolog",
76
+ "Python" : "# language: Python",
77
+ "R" : "# language: R",
78
+ "Racket" : "; language: Racket",
79
+ "RMarkdown" : "# language: RMarkdown",
80
+ "Ruby" : "# language: Ruby",
81
+ "Rust" : "// language: Rust",
82
+ "Scala" : "// language: Scala",
83
+ "Scheme" : "; language: Scheme",
84
+ "Shell" : "# language: Shell",
85
+ "Solidity" : "// language: Solidity",
86
+ "SPARQL" : "# language: SPARQL",
87
+ "SQL" : "-- language: SQL",
88
+ "Swift" : "// language: swift",
89
+ "TeX" : f"% language: TeX",
90
+ "Thrift" : "/* language: Thrift */",
91
+ "TypeScript" : "// language: TypeScript",
92
+ "Vue" : "<!--language: Vue-->",
93
+ "Verilog" : "// language: Verilog",
94
+ "Visual Basic" : "' language: Visual Basic",
95
+ }
96
+
97
+
98
+ def set_random_seed(seed):
99
+ """Set random seed for reproducability."""
100
+ random.seed(seed)
101
+ numpy.random.seed(seed)
102
+ torch.manual_seed(seed)
103
+
104
+
105
+ def main():
106
+ def predict(
107
+ prompt,
108
+ lang,
109
+ seed,
110
+ out_seq_length,
111
+ temperature,
112
+ top_k,
113
+ top_p,
114
+ ):
115
+ set_random_seed(seed)
116
+ if lang != "None":
117
+ prompt = LANGUAGE_TAG[lang] + "\n" + prompt
118
+
119
+ inputs = tokenizer([prompt], return_tensors="pt")
120
+ inputs = inputs.to(model.device)
121
+ outputs = model.generate(**inputs,
122
+ max_length=inputs['input_ids'].shape[-1] + out_seq_length,
123
+ do_sample=True,
124
+ top_p=top_p,
125
+ top_k=top_k,
126
+ temperature=temperature,
127
+ pad_token_id=2,
128
+ eos_token_id=2)
129
+ response = tokenizer.decode(outputs[0])
130
+
131
+ return response
132
+
133
+ with gr.Blocks(title="CodeGeeX2 DEMO") as demo:
134
+ gr.Markdown(
135
+ """
136
+ <p align="center">
137
+ <img src="https://raw.githubusercontent.com/THUDM/CodeGeeX2/main/resources/codegeex_logo.png">
138
+ </p>
139
+ """)
140
+ gr.Markdown(
141
+ """
142
+ <p align="center">
143
+ 🏠 <a href="https://codegeex.cn" target="_blank">Homepage</a>|💻 <a href="https://github.com/THUDM/CodeGeeX2" target="_blank">GitHub</a>|🛠 Tools <a href="https://marketplace.visualstudio.com/items?itemName=aminer.codegeex" target="_blank">VS Code</a>, <a href="https://plugins.jetbrains.com/plugin/20587-codegeex" target="_blank">Jetbrains</a>|🤗 <a href="https://huggingface.co/THUDM/codegeex2-6b" target="_blank">HF Repo</a>|📄 <a href="https://arxiv.org/abs/2303.17568" target="_blank">Paper</a>
144
+ </p>
145
+ """)
146
+ gr.Markdown(
147
+ """
148
+ This is the DEMO for CodeGeeX2. Please note that:
149
+ * CodeGeeX2 is a base model, which is not instruction-tuned for chatting. It can do tasks like code completion/translation/explaination. To try the instruction-tuned version in CodeGeeX plugins ([VS Code](https://marketplace.visualstudio.com/items?itemName=aminer.codegeex), [Jetbrains](https://plugins.jetbrains.com/plugin/20587-codegeex)).
150
+ * Programming languages can be controled by adding `language tag`, e.g., `# language: Python`. The format should be respected to ensure performance, full list can be found [here](https://github.com/THUDM/CodeGeeX2/blob/main/evaluation/utils.py#L14).
151
+ * Write comments under the format of the selected programming language to achieve better results, see examples below.
152
+ """)
153
+
154
+ with gr.Row():
155
+ with gr.Column():
156
+ prompt = gr.Textbox(lines=13, placeholder='Please enter the description or select an example input below.',label='Input')
157
+ with gr.Row():
158
+ gen = gr.Button("Generate")
159
+ clr = gr.Button("Clear")
160
+
161
+ outputs = gr.Textbox(lines=15, label='Output')
162
+
163
+ gr.Markdown(
164
+ """
165
+ Generation Parameter
166
+ """)
167
+
168
+ with gr.Row():
169
+ with gr.Row():
170
+ seed = gr.Slider(maximum=10000, value=8888, step=1, label='Seed')
171
+ with gr.Row():
172
+ out_seq_length = gr.Slider(maximum=8192, value=128, minimum=1, step=1, label='Output Sequence Length')
173
+ temperature = gr.Slider(maximum=1, value=0.2, minimum=0, label='Temperature')
174
+ with gr.Row():
175
+ top_k = gr.Slider(maximum=100, value=0, minimum=0, step=1, label='Top K')
176
+ top_p = gr.Slider(maximum=1, value=0.95, minimum=0, label='Top P')
177
+ with gr.Row():
178
+ lang = gr.Radio(
179
+ choices=["None"] + list(LANGUAGE_TAG.keys()), value='None', label='Programming Language')
180
+ inputs = [prompt, lang, seed, out_seq_length, temperature, top_k, top_p]
181
+ gen.click(fn=predict, inputs=inputs, outputs=outputs)
182
+ clr.click(fn=lambda value: gr.update(value=""), inputs=clr, outputs=prompt)
183
+
184
+ gr_examples = gr.Examples(examples=examples, inputs=[prompt, lang],
185
+ label="Example Inputs (Click to insert an examplet it into the input box)",
186
+ examples_per_page=20)
187
+
188
+ demo.launch(share=True)
189
+
190
+ if __name__ == '__main__':
191
+ with torch.no_grad():
192
+ main()