BlinkDL committed
Commit e4a58ac · verified · 1 parent: 81f3630

Update app.py

Files changed (1)
  1. app.py +10 -368
app.py CHANGED
@@ -1,8 +1,9 @@
1
  import os, copy
 
2
  os.environ["RWKV_JIT_ON"] = '1'
3
  os.environ["RWKV_CUDA_ON"] = '1' # if '1' then use CUDA kernel for seq mode (much faster)
4
- # make sure cuda dir is in the same level as modeling_rwkv.py
5
- from modeling_rwkv import RWKV
6
 
7
  import gc, re
8
  import gradio as gr
@@ -18,53 +19,22 @@ nvmlInit()
18
  gpu_h = nvmlDeviceGetHandleByIndex(0)
19
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
20
 
21
- ctx_limit = 2500
22
- gen_limit = 500
23
- gen_limit_long = 800
24
- ENABLE_VISUAL = False
25
 
26
  ########################## text rwkv ################################################################
27
  from rwkv.utils import PIPELINE, PIPELINE_ARGS
28
 
29
- title_v6 = "RWKV-x060-World-3B-v2.1-20240417-ctx4096"
30
- model_path_v6 = hf_hub_download(repo_id="BlinkDL/rwkv-6-world", filename=f"{title_v6}.pth")
31
- # model_path_v6 = '/mnt/e/RWKV-Runner/models/rwkv-final-v6-2.1-3b' # conda activate torch2; cd /mnt/program/_RWKV_/_ref_/_gradio_/RWKV-Gradio-1; python app.py
32
  model_v6 = RWKV(model=model_path_v6, strategy='cuda fp16')
33
  pipeline_v6 = PIPELINE(model_v6, "rwkv_vocab_v20230424")
34
 
35
  args = model_v6.args
36
- eng_name = 'rwkv-x060-eng_single_round_qa-3B-20240516-ctx2048'
37
- chn_name = 'rwkv-x060-chn_single_round_qa-3B-20240516-ctx2048'
38
-
39
- # state_eng_raw = torch.load(f'/mnt/e/RWKV-Runner/models/{eng_name}.pth', map_location=torch.device('cpu'))
40
- # state_chn_raw = torch.load(f'/mnt/e/RWKV-Runner/models/{chn_name}.pth', map_location=torch.device('cpu'))
41
-
42
- eng_file = hf_hub_download(repo_id="BlinkDL/temp-latest-training-models", filename=f"{eng_name}.pth")
43
- chn_file = hf_hub_download(repo_id="BlinkDL/temp-latest-training-models", filename=f"{chn_name}.pth")
44
- state_eng_raw = torch.load(eng_file, map_location=torch.device('cpu'))
45
- state_chn_raw = torch.load(chn_file, map_location=torch.device('cpu'))
46
-
47
- state_eng = [None] * args.n_layer * 3
48
- state_chn = [None] * args.n_layer * 3
49
- for i in range(args.n_layer):
50
- dd = model_v6.strategy[i]
51
- dev = dd.device
52
- atype = dd.atype
53
- state_eng[i*3+0] = torch.zeros(args.n_embd, dtype=atype, requires_grad=False, device=dev).contiguous()
54
- state_chn[i*3+0] = torch.zeros(args.n_embd, dtype=atype, requires_grad=False, device=dev).contiguous()
55
- state_eng[i*3+1] = state_eng_raw[f'blocks.{i}.att.time_state'].transpose(1,2).to(dtype=torch.float, device=dev).requires_grad_(False).contiguous()
56
- state_chn[i*3+1] = state_chn_raw[f'blocks.{i}.att.time_state'].transpose(1,2).to(dtype=torch.float, device=dev).requires_grad_(False).contiguous()
57
- state_eng[i*3+2] = torch.zeros(args.n_embd, dtype=atype, requires_grad=False, device=dev).contiguous()
58
- state_chn[i*3+2] = torch.zeros(args.n_embd, dtype=atype, requires_grad=False, device=dev).contiguous()
59
 
60
  penalty_decay = 0.996
61
 
62
- if ENABLE_VISUAL:
63
- title = "RWKV-5-World-1B5-v2-20231025-ctx4096"
64
- model_path = hf_hub_download(repo_id="BlinkDL/rwkv-5-world", filename=f"{title}.pth")
65
- model = RWKV(model=model_path, strategy='cuda fp16')
66
- pipeline = PIPELINE(model, "rwkv_vocab_v20230424")
67
-
68
  def generate_prompt(instruction, input=""):
69
  instruction = instruction.strip().replace('\r\n','\n').replace('\n\n','\n')
70
  input = input.strip().replace('\r\n','\n').replace('\n\n','\n')
@@ -98,122 +68,7 @@ def evaluate(
98
  occurrence = {}
99
  state = None
100
  for i in range(int(token_count)):
101
- input_ids = pipeline_v6.encode(ctx)[-ctx_limit:] if i == 0 else [token]
102
- out, state = model_v6.forward(tokens=input_ids, state=state)
103
- for n in occurrence:
104
- out[n] -= (args.alpha_presence + occurrence[n] * args.alpha_frequency)
105
-
106
- token = pipeline_v6.sample_logits(out, temperature=args.temperature, top_p=args.top_p)
107
- if token in args.token_stop:
108
- break
109
- all_tokens += [token]
110
- for xxx in occurrence:
111
- occurrence[xxx] *= penalty_decay
112
-
113
- ttt = pipeline_v6.decode([token])
114
- www = 1
115
- if ttt in ' \t0123456789':
116
- www = 0
117
- #elif ttt in '\r\n,.;?!"\':+-*/=#@$%^&_`~|<>\\()[]{},。;“”:?!()【】':
118
- # www = 0.5
119
- if token not in occurrence:
120
- occurrence[token] = www
121
- else:
122
- occurrence[token] += www
123
-
124
- tmp = pipeline_v6.decode(all_tokens[out_last:])
125
- if '\ufffd' not in tmp:
126
- out_str += tmp
127
- yield out_str.strip()
128
- out_last = i + 1
129
-
130
- gpu_info = nvmlDeviceGetMemoryInfo(gpu_h)
131
- timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
132
- print(f'{timestamp} - vram {gpu_info.total} used {gpu_info.used} free {gpu_info.free}')
133
- del out
134
- del state
135
- gc.collect()
136
- torch.cuda.empty_cache()
137
- yield out_str.strip()
138
 
139
- def evaluate_eng(
140
- ctx,
141
- token_count=200,
142
- temperature=1.0,
143
- top_p=0.7,
144
- presencePenalty = 0.1,
145
- countPenalty = 0.1,
146
- ):
147
- args = PIPELINE_ARGS(temperature = max(0.2, float(temperature)), top_p = float(top_p),
148
- alpha_frequency = countPenalty,
149
- alpha_presence = presencePenalty,
150
- token_ban = [], # ban the generation of some tokens
151
- token_stop = [0]) # stop generation whenever you see any token here
152
- ctx = qa_prompt(ctx)
153
- all_tokens = []
154
- out_last = 0
155
- out_str = ''
156
- occurrence = {}
157
- state = copy.deepcopy(state_eng)
158
- for i in range(int(token_count)):
159
- input_ids = pipeline_v6.encode(ctx)[-ctx_limit:] if i == 0 else [token]
160
- out, state = model_v6.forward(tokens=input_ids, state=state)
161
- for n in occurrence:
162
- out[n] -= (args.alpha_presence + occurrence[n] * args.alpha_frequency)
163
-
164
- token = pipeline_v6.sample_logits(out, temperature=args.temperature, top_p=args.top_p)
165
- if token in args.token_stop:
166
- break
167
- all_tokens += [token]
168
- for xxx in occurrence:
169
- occurrence[xxx] *= penalty_decay
170
-
171
- ttt = pipeline_v6.decode([token])
172
- www = 1
173
- if ttt in ' \t0123456789':
174
- www = 0
175
- #elif ttt in '\r\n,.;?!"\':+-*/=#@$%^&_`~|<>\\()[]{},。;“”:?!()【】':
176
- # www = 0.5
177
- if token not in occurrence:
178
- occurrence[token] = www
179
- else:
180
- occurrence[token] += www
181
-
182
- tmp = pipeline_v6.decode(all_tokens[out_last:])
183
- if '\ufffd' not in tmp:
184
- out_str += tmp
185
- yield out_str.strip()
186
- out_last = i + 1
187
-
188
- gpu_info = nvmlDeviceGetMemoryInfo(gpu_h)
189
- timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
190
- print(f'{timestamp} - vram {gpu_info.total} used {gpu_info.used} free {gpu_info.free}')
191
- del out
192
- del state
193
- gc.collect()
194
- torch.cuda.empty_cache()
195
- yield out_str.strip()
196
-
197
- def evaluate_chn(
198
- ctx,
199
- token_count=200,
200
- temperature=1.0,
201
- top_p=0.7,
202
- presencePenalty = 0.1,
203
- countPenalty = 0.1,
204
- ):
205
- args = PIPELINE_ARGS(temperature = max(0.2, float(temperature)), top_p = float(top_p),
206
- alpha_frequency = countPenalty,
207
- alpha_presence = presencePenalty,
208
- token_ban = [], # ban the generation of some tokens
209
- token_stop = [0]) # stop generation whenever you see any token here
210
- ctx = qa_prompt(ctx)
211
- all_tokens = []
212
- out_last = 0
213
- out_str = ''
214
- occurrence = {}
215
- state = copy.deepcopy(state_chn)
216
- for i in range(int(token_count)):
217
  input_ids = pipeline_v6.encode(ctx)[-ctx_limit:] if i == 0 else [token]
218
  out, state = model_v6.forward(tokens=input_ids, state=state)
219
  for n in occurrence:
@@ -267,167 +122,12 @@ examples = [
267
  ['''“当然可以,大宇宙不会因为这五公斤就不坍缩了。”关一帆说,他还有一个没说出来的想法:也许大宇宙真的会因为相差一个原子的质量而由封闭转为开放。大自然的精巧有时超出想象,比如生命的诞生,就需要各项宇宙参数在几亿亿分之一精度上的精确配合。但程心仍然可以留下她的生态球,因为在那无数文明创造的无数小宇宙中,肯定有相当一部分不响应回归运动的号召,所以,大宇宙最终被夺走的质量至少有几亿吨,甚至可能是几亿亿亿吨。\n但愿大宇宙能够忽略这个误差。\n程心和关一帆进入了飞船,智子最后也进来了。她早就不再穿那身华丽的和服了,她现在身着迷彩服,再次成为一名轻捷精悍的战士,她的身上佩带着许多武器和生存装备,最引人注目的是那把插在背后的武士刀。\n“放心,我在,你们就在!”智子对两位人类朋友说。\n聚变发动机启动了,推进器发出幽幽的蓝光,''', gen_limit, 1, 0.3, 0.5, 0.5],
268
  ]
269
 
270
- examples_eng = [
271
- ["How can I craft an engaging story featuring vampires on Mars?", gen_limit_long, 1, 0.2, 0.3, 0.3],
272
- ["Compare the business models of Apple and Google.", gen_limit_long, 1, 0.2, 0.3, 0.3],
273
- ["In JSON format, list the top 5 tourist attractions in Paris.", gen_limit_long, 1, 0.2, 0.3, 0.3],
274
- ["Write an outline for a fantasy novel where dreams can alter reality.", gen_limit_long, 1, 0.2, 0.3, 0.3],
275
- ["Can fish get thirsty?", gen_limit_long, 1, 0.2, 0.3, 0.3],
276
- ["Write a Bash script to check disk usage and send alerts if it's too high.", gen_limit_long, 1, 0.2, 0.3, 0.3],
277
- ["Write a simple webpage. When a user clicks the button, it shows a random joke from a list of 4 jokes.", gen_limit_long, 1, 0.2, 0.3, 0.3],
278
- ]
279
-
280
- examples_chn = [
281
- ["怎样写一个在火星上的吸血鬼的有趣故事?", gen_limit_long, 1, 0.2, 0.3, 0.3],
282
- ["比较苹果和谷歌的商业模式。", gen_limit_long, 1, 0.2, 0.3, 0.3],
283
- ["鱼会口渴吗?", gen_limit_long, 1, 0.2, 0.3, 0.3],
284
- ["以 JSON 格式列举北京的美食。", gen_limit_long, 1, 0.2, 0.3, 0.3],
285
- ["编写一个Bash脚本来检查磁盘使用情况,如果使用量过高则发送警报。", gen_limit_long, 1, 0.2, 0.3, 0.3],
286
- ["用HTML编写一个简单的网站。当用户点击按钮时,从4个笑话的列表中随机显示一个笑话。", gen_limit_long, 1, 0.2, 0.3, 0.3],
287
- ]
288
-
289
- if ENABLE_VISUAL:
290
- ########################## visual rwkv ################################################################
291
- visual_title = 'ViusualRWKV-v5'
292
- rwkv_remote_path = "rwkv1b5-vitl336p14-577token_mix665k_rwkv.pth"
293
- vision_remote_path = "rwkv1b5-vitl336p14-577token_mix665k_visual.pth"
294
- vision_tower_name = 'openai/clip-vit-large-patch14-336'
295
-
296
- model_path = hf_hub_download(repo_id="howard-hou/visualrwkv-5", filename=rwkv_remote_path)
297
- visual_rwkv = RWKV(model=model_path, strategy='cuda fp16')
298
-
299
- ##########################################################################
300
- from modeling_vision import VisionEncoder, VisionEncoderConfig
301
- config = VisionEncoderConfig(n_embd=model.args.n_embd,
302
- vision_tower_name=vision_tower_name,
303
- grid_size=-1)
304
- visual_encoder = VisionEncoder(config)
305
- vision_local_path = hf_hub_download(repo_id="howard-hou/visualrwkv-5", filename=vision_remote_path)
306
- vision_state_dict = torch.load(vision_local_path, map_location='cpu')
307
- visual_encoder.load_state_dict(vision_state_dict)
308
- image_processor = CLIPImageProcessor.from_pretrained(vision_tower_name)
309
- visual_encoder = visual_encoder.to(device)
310
- ##########################################################################
311
- def visual_generate_prompt(instruction):
312
- instruction = instruction.strip().replace('\r\n','\n').replace('\n\n','\n')
313
- return f"\n{instruction}\n\nAssistant:"
314
-
315
- def generate(
316
- ctx,
317
- image_state,
318
- token_count=200,
319
- temperature=1.0,
320
- top_p=0.1,
321
- presencePenalty = 0.0,
322
- countPenalty = 1.0,
323
- ):
324
- args = PIPELINE_ARGS(temperature = 1.0, top_p = 0.1,
325
- alpha_frequency = 1.0,
326
- alpha_presence = 0.0,
327
- token_ban = [], # ban the generation of some tokens
328
- token_stop = [0, 261]) # stop generation whenever you see any token here
329
- ctx = ctx.strip()
330
- all_tokens = []
331
- out_last = 0
332
- out_str = ''
333
- occurrence = {}
334
- for i in range(int(token_count)):
335
- if i == 0:
336
- input_ids = pipeline.encode(ctx)[-ctx_limit:]
337
- out, state = visual_rwkv.forward(tokens=input_ids, state=image_state)
338
- else:
339
- input_ids = [token]
340
- out, state = visual_rwkv.forward(tokens=input_ids, state=state)
341
- for n in occurrence:
342
- out[n] -= (args.alpha_presence + occurrence[n] * args.alpha_frequency)
343
-
344
- token = pipeline.sample_logits(out, temperature=args.temperature, top_p=args.top_p)
345
- if token in args.token_stop:
346
- break
347
- all_tokens += [token]
348
- for xxx in occurrence:
349
- occurrence[xxx] *= 0.994
350
- if token not in occurrence:
351
- occurrence[token] = 1
352
- else:
353
- occurrence[token] += 1
354
-
355
- tmp = pipeline.decode(all_tokens[out_last:])
356
- if '\ufffd' not in tmp:
357
- out_str += tmp
358
- yield out_str.strip()
359
- out_last = i + 1
360
-
361
- gpu_info = nvmlDeviceGetMemoryInfo(gpu_h)
362
- timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
363
- print(f'{timestamp} - vram {gpu_info.total} used {gpu_info.used} free {gpu_info.free}')
364
- del out
365
- del state
366
- gc.collect()
367
- torch.cuda.empty_cache()
368
- yield out_str.strip()
369
-
370
-
371
- ##########################################################################
372
- cur_dir = os.path.dirname(os.path.abspath(__file__))
373
- visual_examples = [
374
- [
375
- f"{cur_dir}/examples_pizza.jpg",
376
- "What are steps to cook it?"
377
- ],
378
- [
379
- f"{cur_dir}/examples_bluejay.jpg",
380
- "what is the name of this bird?",
381
- ],
382
- [
383
- f"{cur_dir}/examples_woman_and_dog.png",
384
- "describe this image",
385
- ],
386
- ]
387
-
388
-
389
- def pil_image_to_base64(pil_image):
390
- buffered = BytesIO()
391
- pil_image.save(buffered, format="JPEG") # You can change the format as needed (JPEG, PNG, etc.)
392
- # Encodes the image data into base64 format as a bytes object
393
- base64_image = base64.b64encode(buffered.getvalue()).decode('utf-8')
394
- return base64_image
395
-
396
- image_cache = {}
397
- ln0_weight = model.w['blocks.0.ln0.weight'].to(torch.float32).to(device)
398
- ln0_bias = model.w['blocks.0.ln0.bias'].to(torch.float32).to(device)
399
- def compute_image_state(image):
400
- base64_image = pil_image_to_base64(image)
401
- if base64_image in image_cache:
402
- image_state = image_cache[base64_image]
403
- else:
404
- image = image_processor(images=image.convert('RGB'), return_tensors='pt')['pixel_values'].to(device)
405
- image_features = visual_encoder.encode_images(image.unsqueeze(0)).squeeze(0) # [L, D]
406
- # apply layer norm to image feature, very important
407
- image_features = F.layer_norm(image_features,
408
- (image_features.shape[-1],),
409
- weight=ln0_weight,
410
- bias=ln0_bias)
411
- _, image_state = model.forward(embs=image_features, state=None)
412
- image_cache[base64_image] = image_state
413
- return image_state
414
-
415
- def chatbot(image, question):
416
- if image is None:
417
- yield "Please upload an image."
418
- return
419
- image_state = compute_image_state(image)
420
- input_text = visual_generate_prompt(question)
421
- for output in generate(input_text, image_state):
422
- yield output
423
-
424
-
425
  ##################################################################################################################
426
  with gr.Blocks(title=title_v6) as demo:
427
  gr.HTML(f"<div style=\"text-align: center;\">\n<h1>{title_v6}</h1>\n</div>")
428
 
429
  with gr.Tab("=== Base Model (Raw Generation) ==="):
430
- gr.Markdown(f"This is [RWKV-6 World v2](https://huggingface.co/BlinkDL/rwkv-6-world) - a 100% attention-free RNN [RWKV-LM](https://github.com/BlinkDL/RWKV-LM). Supports 100+ world languages and code. Check [300+ Github RWKV projects](https://github.com/search?o=desc&p=1&q=rwkv&s=updated&type=Repositories). *** Can try examples (bottom of page) *** (can edit them). Demo limited to ctxlen {ctx_limit}.")
431
  with gr.Row():
432
  with gr.Column():
433
  prompt = gr.Textbox(lines=2, label="Prompt", value="Assistant: How can we craft an engaging story featuring vampires on Mars? Let's think step by step and provide an expert response.")
@@ -446,63 +146,5 @@ with gr.Blocks(title=title_v6) as demo:
446
  clear.click(lambda: None, [], [output])
447
  data.click(lambda x: x, [data], [prompt, token_count, temperature, top_p, presence_penalty, count_penalty])
448
 
449
- with gr.Tab("=== English Q/A ==="):
450
- gr.Markdown(f"This is [RWKV-6](https://huggingface.co/BlinkDL/rwkv-6-world) state-tuned to [English Q/A](https://huggingface.co/BlinkDL/temp-latest-training-models/blob/main/{eng_name}.pth). RWKV is a 100% attention-free RNN [RWKV-LM](https://github.com/BlinkDL/RWKV-LM), and we have [300+ Github RWKV projects](https://github.com/search?o=desc&p=1&q=rwkv&s=updated&type=Repositories). Demo limited to ctxlen {ctx_limit}.")
451
- with gr.Row():
452
- with gr.Column():
453
- prompt = gr.Textbox(lines=2, label="Prompt", value="How can I craft an engaging story featuring vampires on Mars?")
454
- token_count = gr.Slider(10, gen_limit_long, label="Max Tokens", step=10, value=gen_limit_long)
455
- temperature = gr.Slider(0.2, 2.0, label="Temperature", step=0.1, value=1.0)
456
- top_p = gr.Slider(0.0, 1.0, label="Top P", step=0.05, value=0.2)
457
- presence_penalty = gr.Slider(0.0, 1.0, label="Presence Penalty", step=0.1, value=0.3)
458
- count_penalty = gr.Slider(0.0, 1.0, label="Count Penalty", step=0.1, value=0.3)
459
- with gr.Column():
460
- with gr.Row():
461
- submit = gr.Button("Submit", variant="primary")
462
- clear = gr.Button("Clear", variant="secondary")
463
- output = gr.Textbox(label="Output", lines=30)
464
- data = gr.Dataset(components=[prompt, token_count, temperature, top_p, presence_penalty, count_penalty], samples=examples_eng, samples_per_page=50, label="Examples", headers=["Prompt", "Max Tokens", "Temperature", "Top P", "Presence Penalty", "Count Penalty"])
465
- submit.click(evaluate_eng, [prompt, token_count, temperature, top_p, presence_penalty, count_penalty], [output])
466
- clear.click(lambda: None, [], [output])
467
- data.click(lambda x: x, [data], [prompt, token_count, temperature, top_p, presence_penalty, count_penalty])
468
-
469
- with gr.Tab("=== Chinese Q/A ==="):
470
- gr.Markdown(f"This is [RWKV-6](https://huggingface.co/BlinkDL/rwkv-6-world) state-tuned to [Chinese Q/A](https://huggingface.co/BlinkDL/temp-latest-training-models/blob/main/{chn_name}.pth). RWKV is a 100% attention-free RNN [RWKV-LM](https://github.com/BlinkDL/RWKV-LM), and we have [300+ Github RWKV projects](https://github.com/search?o=desc&p=1&q=rwkv&s=updated&type=Repositories). Demo limited to ctxlen {ctx_limit}.")
471
- with gr.Row():
472
- with gr.Column():
473
- prompt = gr.Textbox(lines=2, label="Prompt", value="怎样写一个在火星上的吸血鬼的有趣故事?")
474
- token_count = gr.Slider(10, gen_limit_long, label="Max Tokens", step=10, value=gen_limit_long)
475
- temperature = gr.Slider(0.2, 2.0, label="Temperature", step=0.1, value=1.0)
476
- top_p = gr.Slider(0.0, 1.0, label="Top P", step=0.05, value=0.2)
477
- presence_penalty = gr.Slider(0.0, 1.0, label="Presence Penalty", step=0.1, value=0.3)
478
- count_penalty = gr.Slider(0.0, 1.0, label="Count Penalty", step=0.1, value=0.3)
479
- with gr.Column():
480
- with gr.Row():
481
- submit = gr.Button("Submit", variant="primary")
482
- clear = gr.Button("Clear", variant="secondary")
483
- output = gr.Textbox(label="Output", lines=30)
484
- data = gr.Dataset(components=[prompt, token_count, temperature, top_p, presence_penalty, count_penalty], samples=examples_chn, samples_per_page=50, label="Examples", headers=["Prompt", "Max Tokens", "Temperature", "Top P", "Presence Penalty", "Count Penalty"])
485
- submit.click(evaluate_chn, [prompt, token_count, temperature, top_p, presence_penalty, count_penalty], [output])
486
- clear.click(lambda: None, [], [output])
487
- data.click(lambda x: x, [data], [prompt, token_count, temperature, top_p, presence_penalty, count_penalty])
488
-
489
- if ENABLE_VISUAL:
490
- with gr.Tab("Visual RWKV-5 1.5B"):
491
- with gr.Row():
492
- with gr.Column():
493
- image = gr.Image(type='pil', label="Image")
494
- with gr.Column():
495
- prompt = gr.Textbox(lines=8, label="Prompt",
496
- value="Render a clear and concise summary of the photo.")
497
- with gr.Row():
498
- submit = gr.Button("Submit", variant="primary")
499
- clear = gr.Button("Clear", variant="secondary")
500
- with gr.Column():
501
- output = gr.Textbox(label="Output", lines=10)
502
- data = gr.Dataset(components=[image, prompt], samples=visual_examples, label="Examples", headers=["Image", "Prompt"])
503
- submit.click(chatbot, [image, prompt], [output])
504
- clear.click(lambda: None, [], [output])
505
- data.click(lambda x: x, [data], [image, prompt])
506
-
507
  demo.queue(concurrency_count=1, max_size=10)
508
- demo.launch(share=False)
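For reference, the English/Chinese Q/A tabs removed above did not start generation from `state = None`; they seeded it with a state-tuned checkpoint. A minimal sketch of how that per-layer state is laid out, assuming the same RWKV-6 `model_v6`/`args` objects and a checkpoint containing `blocks.{i}.att.time_state` tensors (the helper name `build_qa_state` is illustrative, not part of the app):

```python
import copy
import torch

def build_qa_state(model, args, state_raw):
    # The removed code keeps three state entries per layer:
    #   i*3+0  attention token-shift vector  (starts at zeros)
    #   i*3+1  attention time-mixing state   (loaded from the tuned checkpoint)
    #   i*3+2  FFN token-shift vector        (starts at zeros)
    state = [None] * args.n_layer * 3
    for i in range(args.n_layer):
        dd = model.strategy[i]                      # per-layer device/dtype plan
        dev, atype = dd.device, dd.atype
        state[i*3+0] = torch.zeros(args.n_embd, dtype=atype, device=dev)
        state[i*3+1] = (state_raw[f'blocks.{i}.att.time_state']
                        .transpose(1, 2)
                        .to(dtype=torch.float, device=dev)
                        .contiguous())
        state[i*3+2] = torch.zeros(args.n_embd, dtype=atype, device=dev)
    return state

# Each request then deep-copies the template so the tuned state is never mutated:
# state = copy.deepcopy(build_qa_state(model_v6, args, state_eng_raw))
```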
 
1
  import os, copy
2
+ os.environ["RWKV_V7_ON"] = '1'
3
  os.environ["RWKV_JIT_ON"] = '1'
4
  os.environ["RWKV_CUDA_ON"] = '1' # if '1' then use CUDA kernel for seq mode (much faster)
5
+
6
+ from rwkv.model import RWKV
7
 
8
  import gc, re
9
  import gradio as gr
 
19
  gpu_h = nvmlDeviceGetHandleByIndex(0)
20
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
21
 
22
+ ctx_limit = 4096
23
+ gen_limit = 1000
24
 
25
  ########################## text rwkv ################################################################
26
  from rwkv.utils import PIPELINE, PIPELINE_ARGS
27
 
28
+ title_v6 = "RWKV-x070-World-0.1B-v2.8-20241210-ctx4096"
29
+ model_path_v6 = hf_hub_download(repo_id="BlinkDL/rwkv-7-world", filename=f"{title_v6}.pth")
30
+ # model_path_v6 = f'/mnt/e/RWKV-Runner/models/{title_v6}' # conda activate torch2; cd /mnt/program/git-public/RWKV-Gradio-1; python app.py
31
  model_v6 = RWKV(model=model_path_v6, strategy='cuda fp16')
32
  pipeline_v6 = PIPELINE(model_v6, "rwkv_vocab_v20230424")
33
 
34
  args = model_v6.args
35
 
36
  penalty_decay = 0.996
37
 
38
  def generate_prompt(instruction, input=""):
39
  instruction = instruction.strip().replace('\r\n','\n').replace('\n\n','\n')
40
  input = input.strip().replace('\r\n','\n').replace('\n\n','\n')
 
68
  occurrence = {}
69
  state = None
70
  for i in range(int(token_count)):
71
 
72
  input_ids = pipeline_v6.encode(ctx)[-ctx_limit:] if i == 0 else [token]
73
  out, state = model_v6.forward(tokens=input_ids, state=state)
74
  for n in occurrence:
 
122
  ['''“当然可以,大宇宙不会因为这五公斤就不坍缩了。”关一帆说,他还有一个没说出来的想法:也许大宇宙真的会因为相差一个原子的质量而由封闭转为开放。大自然的精巧有时超出想象,比如生命的诞生,就需要各项宇宙参数在几亿亿分之一精度上的精确配合。但程心仍然可以留下她的生态球,因为在那无数文明创造的无数小宇宙中,肯定有相当一部分不响应回归运动的号召,所以,大宇宙最终被夺走的质量至少有几亿吨,甚至可能是几亿亿亿吨。\n但愿大宇宙能够忽略这个误差。\n程心和关一帆进入了飞船,智子最后也进来了。她早就不再穿那身华丽的和服了,她现在身着迷彩服,再次成为一名轻捷精悍的战士,她的身上佩带着许多武器和生存装备,最引人注目的是那把插在背后的武士刀。\n“放心,我在,你们就在!”智子对两位人类朋友说。\n聚变发动机启动了,推进器发出幽幽的蓝光,''', gen_limit, 1, 0.3, 0.5, 0.5],
123
  ]
124
125
  ##################################################################################################################
126
  with gr.Blocks(title=title_v6) as demo:
127
  gr.HTML(f"<div style=\"text-align: center;\">\n<h1>{title_v6}</h1>\n</div>")
128
 
129
  with gr.Tab("=== Base Model (Raw Generation) ==="):
130
+ gr.Markdown(f"This is [RWKV-7 World v2.8](https://huggingface.co/BlinkDL/rwkv-7-world) 0.1B (L12-D768) - a 100% attention-free RNN [RWKV-LM](https://github.com/BlinkDL/RWKV-LM). Supports 100+ world languages and code. Check [400+ Github RWKV projects](https://github.com/search?o=desc&p=1&q=rwkv&s=updated&type=Repositories). *** Can try examples (bottom of page) *** (can edit them). Demo limited to ctxlen {ctx_limit}.")
131
  with gr.Row():
132
  with gr.Column():
133
  prompt = gr.Textbox(lines=2, label="Prompt", value="Assistant: How can we craft an engaging story featuring vampires on Mars? Let's think step by step and provide an expert response.")
 
146
  clear.click(lambda: None, [], [output])
147
  data.click(lambda x: x, [data], [prompt, token_count, temperature, top_p, presence_penalty, count_penalty])
148
 
149
  demo.queue(concurrency_count=1, max_size=10)
150
+ demo.launch(share=False)
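Outside the Gradio UI, the updated app's generation path reduces to a short loop. A condensed sketch of the same steps (RWKV-7 env flags, checkpoint download, penalty-decayed sampling), assuming a CUDA machine and a recent `rwkv` pip package with RWKV-7 support; the function name `generate_once` is illustrative, the default values mirror the Base Model tab, and the real app additionally streams partial decodes and skips penalizing whitespace/digit tokens:

```python
import os
os.environ["RWKV_V7_ON"] = '1'     # enable the RWKV-7 code path in the rwkv package
os.environ["RWKV_JIT_ON"] = '1'
os.environ["RWKV_CUDA_ON"] = '1'   # CUDA kernel for sequence mode (much faster prefill)

from huggingface_hub import hf_hub_download
from rwkv.model import RWKV
from rwkv.utils import PIPELINE

title = "RWKV-x070-World-0.1B-v2.8-20241210-ctx4096"
model_path = hf_hub_download(repo_id="BlinkDL/rwkv-7-world", filename=f"{title}.pth")
model = RWKV(model=model_path, strategy='cuda fp16')
pipeline = PIPELINE(model, "rwkv_vocab_v20230424")

def generate_once(ctx, token_count=1000, temperature=1.0, top_p=0.3,
                  alpha_presence=0.5, alpha_frequency=0.5, penalty_decay=0.996,
                  ctx_limit=4096):
    out_tokens, occurrence, state = [], {}, None
    tokens = pipeline.encode(ctx)[-ctx_limit:]      # trim the prompt to the context limit
    for _ in range(token_count):
        out, state = model.forward(tokens, state)
        for n in occurrence:                        # presence + frequency penalty
            out[n] -= alpha_presence + occurrence[n] * alpha_frequency
        token = pipeline.sample_logits(out, temperature=temperature, top_p=top_p)
        if token == 0:                              # token 0 ends generation
            break
        out_tokens.append(token)
        for n in occurrence:                        # let earlier penalties fade out
            occurrence[n] *= penalty_decay
        occurrence[token] = occurrence.get(token, 0) + 1
        tokens = [token]                            # feed one token at a time afterwards
    return pipeline.decode(out_tokens)

# print(generate_once("Assistant: How can we craft an engaging story featuring vampires on Mars?"))
```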