StarFox7 committed on
Commit
95381c3
·
1 Parent(s): be9bc52

Create app.py

Files changed (1)
  1. app.py +328 -0
app.py ADDED
@@ -0,0 +1,328 @@
+ """Run the Gradio chat app for Llama-2-ko-7B-chat (ggml)."""
+ # pylint: disable=line-too-long, broad-exception-caught, invalid-name, missing-function-docstring, too-many-instance-attributes, missing-class-docstring
+ # ruff: noqa: E501
+ import os
+ import platform
+ import random
+ import time
+ from dataclasses import asdict, dataclass
+ from pathlib import Path
+
+ # from types import SimpleNamespace
+ import gradio as gr
+ import psutil
+ from about_time import about_time
+ from ctransformers import AutoModelForCausalLM
+ from dl_hf_model import dl_hf_model
+ from loguru import logger
+
+ filename_list = [
+     "Llama-2-ko-7B-chat-ggml-q4_0.bin",
+ ]
+
+ url = "https://huggingface.co/StarFox7/Llama-2-ko-7B-chat-ggml/blob/main/Llama-2-ko-7B-chat-ggml-q4_0.bin"
+
+ prompt_template = "Q: {question}. A: "  # placeholder name must match generate()'s .format(question=...)
+
+ _ = [elm for elm in prompt_template.splitlines() if elm.strip()]
+ stop_string = [elm.split(":")[0] + ":" for elm in _][-1]  # "Q:"; [-2] would raise IndexError on a one-line template
+
+ logger.debug(f"{stop_string=} not used")
+
+ _ = (psutil.cpu_count(logical=False) or 2) - 1  # cpu_count(logical=False) can return None
+ cpu_count: int = int(_) if _ else 1
+ logger.debug(f"{cpu_count=}")
+
+ LLM = None
+
+ try:
+     model_loc, file_size = dl_hf_model(url)
+ except Exception as exc_:
+     logger.error(exc_)
+     raise SystemExit(1) from exc_
+
+ LLM = AutoModelForCausalLM.from_pretrained(
+     model_loc,
+     model_type="llama",
+     # threads=cpu_count,
+ )
+
+ logger.info(f"done load llm {model_loc=} {file_size=}G")
+
+ os.environ["TZ"] = "Asia/Seoul"
+ try:
+     time.tzset()  # type: ignore # pylint: disable=no-member
+ except Exception:
+     # time.tzset() is not available on Windows
+     logger.warning("Windows, can't run time.tzset()")
+
+ _ = """
+ ns = SimpleNamespace(
+     response="",
+     generator=(_ for _ in []),
+ )
+ # """
+
+ @dataclass
+ class GenerationConfig:
+     temperature: float = 0.7
+     top_k: int = 50
+     top_p: float = 0.9
+     repetition_penalty: float = 1.0
+     max_new_tokens: int = 1024
+     seed: int = 42
+     reset: bool = False
+     stream: bool = True
+     # threads: int = cpu_count
+     # stop: list[str] = field(default_factory=lambda: [stop_string])
+
+
+ def generate(
+     question: str,
+     llm=LLM,
+     config: GenerationConfig = GenerationConfig(),
+ ):
+     """Run model inference; returns a generator of tokens when config.stream is True."""
+     # _ = prompt_template.format(question=question)
+     # print(_)
+
+     prompt = prompt_template.format(question=question)
+
+     return llm(
+         prompt,
+         **asdict(config),
+     )
+
+
+ logger.debug(f"{asdict(GenerationConfig())=}")
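+
+ # Usage sketch (illustrative only, not executed by the app): since GenerationConfig.stream
+ # defaults to True, generate() returns a token generator that can be consumed directly, e.g.
+ #
+ #     for tok in generate("인생이란 뭘까요?"):
+ #         print(tok, end="", flush=True)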
+
+
+ def user(user_message, history):
+     # return user_message, history + [[user_message, None]]
+     history.append([user_message, None])
+     return user_message, history  # keep user_message
+
+
+ def user1(user_message, history):
+     # return user_message, history + [[user_message, None]]
+     history.append([user_message, None])
+     return "", history  # clear user_message
+
+
+ def bot_(history):
+     user_message = history[-1][0]
+     resp = random.choice(["How are you?", "I love you", "I'm very hungry"])
+     bot_message = user_message + ": " + resp
+     history[-1][1] = ""
+     for character in bot_message:
+         history[-1][1] += character
+         time.sleep(0.02)
+         yield history
+
+     history[-1][1] = resp
+     yield history
+
+
+ def bot(history):
+     user_message = history[-1][0]
+     response = []
+
+     logger.debug(f"{user_message=}")
+
+     with about_time() as atime:  # type: ignore
+         flag = 1
+         prefix = ""
+         then = time.time()
+
+         logger.debug("about to generate")
+
+         config = GenerationConfig(reset=True)
+         for elm in generate(user_message, config=config):
+             if flag == 1:
+                 logger.debug("in the loop")
+                 prefix = f"({time.time() - then:.2f}s) "
+                 flag = 0
+                 print(prefix, end="", flush=True)
+                 logger.debug(f"{prefix=}")
+             print(elm, end="", flush=True)
+             # logger.debug(f"{elm}")
+
+             response.append(elm)
+             history[-1][1] = prefix + "".join(response)
+             yield history
+
+     _ = (
+         f"(time elapsed: {atime.duration_human}, "  # type: ignore
+         f"{atime.duration / max(len(''.join(response)), 1):.2f}s/char)"  # type: ignore  # max() avoids div-by-zero on an empty response
+     )
+
+     history[-1][1] = "".join(response) + f"\n{_}"
+     yield history
+
+
+ def predict_api(prompt):
+     logger.debug(f"{prompt=}")
+     try:
+         # user_prompt = prompt
+         config = GenerationConfig(
+             temperature=0.2,
+             top_k=10,
+             top_p=0.9,
+             repetition_penalty=1.0,
+             max_new_tokens=512,  # adjust as needed
+             seed=42,
+             reset=True,  # reset history (cache)
+             stream=False,
+             # threads=cpu_count,
+             # stop=prompt_prefix[1:2],
+         )
+
+         response = generate(
+             prompt,
+             config=config,
+         )
+
+         logger.debug(f"api: {response=}")
+     except Exception as exc:
+         logger.error(exc)
+         response = f"{exc=}"
+     # bot = {"inputs": [response]}
+     # bot = [(prompt, response)]
+
+     return response
+
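+
+ # predict_api() is exposed below via api_name="api" (non-streaming; returns a plain string).
+ # Sketch of calling it remotely with gradio_client, assuming the package is installed and the
+ # Space or local URL is filled in:
+ #
+ #     from gradio_client import Client
+ #     client = Client("<space-or-local-url>")
+ #     print(client.predict("인생이란 뭘까요?", api_name="/api"))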
+
+ css = """
+     .importantButton {
+         background: linear-gradient(45deg, #7e0570,#5d1c99, #6e00ff) !important;
+         border: none !important;
+     }
+     .importantButton:hover {
+         background: linear-gradient(45deg, #ff00e0,#8500ff, #6e00ff) !important;
+         border: none !important;
+     }
+     .disclaimer {font-variant-caps: all-small-caps; font-size: xx-small;}
+     .xsmall {font-size: x-small;}
+ """
+
+ examples_list = [
+     ["인생이란 뭘까요?"],  # "What is life?"
+ ]
+
+ logger.info("start block")
+
+ with gr.Blocks(
+     title=f"{Path(model_loc).name}",
+     theme=gr.themes.Soft(text_size="sm", spacing_size="sm"),
+     css=css,
+ ) as block:
+     # buff_var = gr.State("")
+     with gr.Accordion("🎈 Info", open=False):
+         # gr.HTML(
+         #     """<center><a href="https://huggingface.co/spaces/mikeee/mpt-30b-chat?duplicate=true"><img src="https://bit.ly/3gLdBN6" alt="Duplicate"></a> and spin a CPU UPGRADE to avoid the queue</center>"""
+         # )
+         gr.Markdown(
+             f"""<h5><center>{Path(model_loc).name}</center></h5>
+             Most examples are meant for another model;
+             you should probably test with related
+             prompts of your own.""",
+             elem_classes="xsmall",
+         )
+
+     # chatbot = gr.Chatbot().style(height=700)  # 500
+     chatbot = gr.Chatbot(height=500)
+
+     # buff = gr.Textbox(show_label=False, visible=True)
+
+     with gr.Row():
+         with gr.Column(scale=5):
+             msg = gr.Textbox(
+                 label="Chat Message Box",
+                 placeholder="Ask me anything (press Shift+Enter or click Submit to send)",
+                 show_label=False,
+                 # container=False,
+                 lines=6,
+                 max_lines=30,
+                 show_copy_button=True,
+                 # ).style(container=False)
+             )
+         with gr.Column(scale=1, min_width=50):
+             with gr.Row():
+                 submit = gr.Button("Submit", elem_classes="xsmall")
+                 stop = gr.Button("Stop", visible=True)
+                 clear = gr.Button("Clear History", visible=True)
+     with gr.Row(visible=False):
+         with gr.Accordion("Advanced Options:", open=False):
+             with gr.Row():
+                 with gr.Column(scale=2):
+                     system = gr.Textbox(
+                         label="System Prompt",
+                         value=prompt_template,
+                         show_label=False,
+                         container=False,
+                         # ).style(container=False)
+                     )
+                 with gr.Column():
+                     with gr.Row():
+                         change = gr.Button("Change System Prompt")
+                         reset = gr.Button("Reset System Prompt")
+
+     with gr.Accordion("Example Inputs", open=True):
+         examples = gr.Examples(
+             examples=examples_list,
+             inputs=[msg],
+             examples_per_page=40,
+         )
+
+     # with gr.Row():
+     with gr.Accordion("Disclaimer", open=False):
+         _ = Path(model_loc).name
+         gr.Markdown(
+             f"Disclaimer: {_} can produce factually incorrect output, and should not be relied on to produce "
+             f"factually accurate information. {_} was trained on various public datasets; while great efforts "
+             "have been taken to clean the pretraining data, it is possible that this model could generate lewd, "
+             "biased, or otherwise offensive outputs.",
+             elem_classes=["disclaimer"],
+         )
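+
+     # Event wiring: msg.submit and the Submit button first run user()/user1() to append the
+     # message to the chat history (user1 also clears the textbox), then chain into bot(),
+     # which streams the reply into the Chatbot; the Stop button cancels both event chains.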
+
+     msg_submit_event = msg.submit(
+         # fn=conversation.user_turn,
+         fn=user,
+         inputs=[msg, chatbot],
+         outputs=[msg, chatbot],
+         queue=True,
+         show_progress="full",
+         # api_name=None,
+     ).then(bot, chatbot, chatbot, queue=True)
+     submit_click_event = submit.click(
+         fn=user1,  # clear msg
+         inputs=[msg, chatbot],
+         outputs=[msg, chatbot],
+         queue=True,
+         show_progress="full",
+     ).then(bot, chatbot, chatbot, queue=True)
+     stop.click(
+         fn=None,
+         inputs=None,
+         outputs=None,
+         cancels=[msg_submit_event, submit_click_event],
+         queue=False,
+     )
+     clear.click(lambda: None, None, chatbot, queue=False)
+
+     with gr.Accordion("For Chat/Translation API", open=False, visible=False):
+         input_text = gr.Text()
+         api_btn = gr.Button("Go", variant="primary")
+         out_text = gr.Text()
+
+     api_btn.click(
+         predict_api,
+         input_text,
+         out_text,
+         api_name="api",
+     )
+
+ concurrency_count = 1
+ logger.info(f"{concurrency_count=}")
+
+ block.queue(concurrency_count=concurrency_count, max_size=5).launch(debug=True)
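+
+ # Note: queue(concurrency_count=...) is the Gradio 3.x queue API; running `python app.py`
+ # launches the app once the dependencies (gradio, ctransformers, dl_hf_model, about_time,
+ # psutil, loguru) are installed.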