NeverlandPeter committed on
Commit cbf69e8 · 1 Parent(s): 733fac9

state-tuned
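
In brief: this commit seeds generation with "state-tuned" initial states for the RWKV-6 3B demo. Instead of starting every request from `state = None`, the new English and Chinese Q/A tabs deep-copy a pre-built recurrent state whose per-layer `att.time_state` tensors are downloaded from BlinkDL/temp-latest-training-models. A condensed sketch of the idea, using the names from the new app.py in the diff below (English state only; the Chinese one is built the same way):

# Condensed sketch of the state-tuning setup added in this commit (see the full diff below).
# The rwkv pipeline keeps 3 state slots per layer: att token-shift, att time_state, ffn token-shift.
# Only the middle slot is filled from the tuned checkpoint; the token-shift slots start at zero.
eng_file = hf_hub_download(repo_id="BlinkDL/temp-latest-training-models", filename=f"{eng_name}.pth")
state_eng_raw = torch.load(eng_file)

state_eng = [None] * args.n_layer * 3
for i in range(args.n_layer):
    dd = model_v6.strategy[i]
    state_eng[i*3+0] = torch.zeros(args.n_embd, dtype=dd.atype, device=dd.device).contiguous()
    state_eng[i*3+1] = state_eng_raw[f'blocks.{i}.att.time_state'].transpose(1, 2).to(dtype=torch.float, device=dd.device).contiguous()
    state_eng[i*3+2] = torch.zeros(args.n_embd, dtype=dd.atype, device=dd.device).contiguous()

# Each request then starts from a fresh copy of the tuned state rather than from None:
# state = copy.deepcopy(state_eng)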

Files changed (1)
  1. app.py +350 -154
app.py CHANGED
@@ -1,10 +1,10 @@
- import os
  os.environ["RWKV_JIT_ON"] = '1'
  os.environ["RWKV_CUDA_ON"] = '1' # if '1' then use CUDA kernel for seq mode (much faster)
  # make sure cuda dir is in the same level as modeling_rwkv.py
  from modeling_rwkv import RWKV

- import gc
  import gradio as gr
  import base64
  from io import BytesIO
@@ -20,36 +20,62 @@ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

  ctx_limit = 2500
  gen_limit = 500
  ########################## text rwkv ################################################################
  from rwkv.utils import PIPELINE, PIPELINE_ARGS

  title_v6 = "RWKV-x060-World-3B-v2.1-20240417-ctx4096"
  model_path_v6 = hf_hub_download(repo_id="BlinkDL/rwkv-6-world", filename=f"{title_v6}.pth")
  model_v6 = RWKV(model=model_path_v6, strategy='cuda fp16')
  pipeline_v6 = PIPELINE(model_v6, "rwkv_vocab_v20230424")

- title = "RWKV-5-World-1B5-v2-20231025-ctx4096"
- model_path = hf_hub_download(repo_id="BlinkDL/rwkv-5-world", filename=f"{title}.pth")
- model = RWKV(model=model_path, strategy='cuda fp16')
- pipeline = PIPELINE(model, "rwkv_vocab_v20230424")

  def generate_prompt(instruction, input=""):
      instruction = instruction.strip().replace('\r\n','\n').replace('\n\n','\n')
      input = input.strip().replace('\r\n','\n').replace('\n\n','\n')
      if input:
-         return f"""Instruction: {instruction}
-
- Input: {input}
-
- Response:"""
      else:
-         return f"""User: hi
-
- Assistant: Hi. I am your assistant and I will provide expert full response in full details. Please feel free to ask any question and I will always answer it.
-
- User: {instruction}
-
- Assistant:"""

  def evaluate(
      ctx,
@@ -71,19 +97,19 @@ def evaluate(
      occurrence = {}
      state = None
      for i in range(int(token_count)):
-         input_ids = pipeline.encode(ctx)[-ctx_limit:] if i == 0 else [token]
          out, state = model_v6.forward(tokens=input_ids, state=state)
          for n in occurrence:
              out[n] -= (args.alpha_presence + occurrence[n] * args.alpha_frequency)

-         token = pipeline.sample_logits(out, temperature=args.temperature, top_p=args.top_p)
          if token in args.token_stop:
              break
          all_tokens += [token]
          for xxx in occurrence:
-             occurrence[xxx] *= 0.994

-         ttt = pipeline.decode([token])
          www = 1
          if ttt in ' \t0123456789':
              www = 0
@@ -94,7 +120,7 @@ def evaluate(
          else:
              occurrence[token] += www

-         tmp = pipeline.decode(all_tokens[out_last:])
          if '\ufffd' not in tmp:
              out_str += tmp
              yield out_str.strip()
@@ -109,96 +135,108 @@ def evaluate(
      torch.cuda.empty_cache()
      yield out_str.strip()

- examples = [
-     ["Assistant: How can we craft an engaging story featuring vampires on Mars? Let's think step by step and provide an expert response.", 500, 1, 0.3, 0, 1],
-     ["Assistant: How can we persuade Elon Musk to follow you on Twitter? Let's think step by step and provide an expert response.", 500, 1, 0.3, 0, 1],
-     [generate_prompt("東京で訪れるべき素晴らしい場所とその紹介をいくつか挙げてください。"), 500, 1, 0.3, 0, 1],
-     [generate_prompt("Write a story using the following information.", "A man named Alex chops a tree down."), 500, 1, 0.3, 0, 1],
-     ["A few light taps upon the pane made her turn to the window. It had begun to snow again.", 500, 1, 0.3, 0, 1],
-     ['''Edward: I am Edward Elric from Fullmetal Alchemist.
-
- User: Hello Edward. What have you been up to recently?
-
- Edward:''', 500, 1, 0.3, 0, 1],
-     ['''Japanese: 春の初め、桜の花が満開になる頃、小さな町の片隅にある古びた神社の境内は、特別な雰囲気に包まれていた。
-
- English:''', 500, 1, 0.3, 0, 1],
-     ["En una pequeña aldea escondida entre las montañas de Andalucía, donde las calles aún conservaban el eco de antiguas leyendas, vivía un joven llamado Alejandro.", 500, 1, 0.3, 0, 1],
-     ["Dans le cœur battant de Paris, sous le ciel teinté d'un crépuscule d'or et de pourpre, se tenait une petite librairie oubliée par le temps.", 500, 1, 0.3, 0, 1],
-     ["في تطور مذهل وغير مسبوق، أعلنت السلطات المحلية في العاصمة عن اكتشاف أثري قد يغير مجرى التاريخ كما نعرفه.", 500, 1, 0.3, 0, 1],
-     ['''“当然可以,大宇宙不会因为这五公斤就不坍缩了。”关一帆说,他还有一个没说出来的想法:也许大宇宙真的会因为相差一个原子的质量而由封闭转为开放。大自然的精巧有时超出想象,比如生命的诞生,就需要各项宇宙参数在几亿亿分之一精度上的精确配合。但程心仍然可以留下她的生态球,因为在那无数文明创造的无数小宇宙中,肯定有相当一部分不响应回归运动的号召,所以,大宇宙最终被夺走的质量至少有几亿吨,甚至可能是几亿亿亿吨。
- 但愿大宇宙能够忽略这个误差。
- 程心和关一帆进入了飞船,智子最后也进来了。她早就不再穿那身华丽的和服了,她现在身着迷彩服,再次成为一名轻捷精悍的战士,她的身上佩带着许多武器和生存装备,最引人注目的是那把插在背后的武士刀。
- “放心,我在,你们就在!”智子对两位人类朋友说。
- 聚变发动机启动了,推进器发出幽幽的蓝光,''', 500, 1, 0.3, 0, 1],
- ]

- ########################## visual rwkv ################################################################
- visual_title = 'ViusualRWKV-v5'
- rwkv_remote_path = "rwkv1b5-vitl336p14-577token_mix665k_rwkv.pth"
- vision_remote_path = "rwkv1b5-vitl336p14-577token_mix665k_visual.pth"
- vision_tower_name = 'openai/clip-vit-large-patch14-336'
-
- model_path = hf_hub_download(repo_id="howard-hou/visualrwkv-5", filename=rwkv_remote_path)
- visual_rwkv = RWKV(model=model_path, strategy='cuda fp16')
-
- ##########################################################################
- from modeling_vision import VisionEncoder, VisionEncoderConfig
- config = VisionEncoderConfig(n_embd=model.args.n_embd,
-                              vision_tower_name=vision_tower_name,
-                              grid_size=-1)
- visual_encoder = VisionEncoder(config)
- vision_local_path = hf_hub_download(repo_id="howard-hou/visualrwkv-5", filename=vision_remote_path)
- vision_state_dict = torch.load(vision_local_path, map_location='cpu')
- visual_encoder.load_state_dict(vision_state_dict)
- image_processor = CLIPImageProcessor.from_pretrained(vision_tower_name)
- visual_encoder = visual_encoder.to(device)
- ##########################################################################
- def visual_generate_prompt(instruction):
-     instruction = instruction.strip().replace('\r\n','\n').replace('\n\n','\n')
-     return f"\n{instruction}\n\nAssistant:"

- def generate(
      ctx,
-     image_state,
      token_count=200,
      temperature=1.0,
-     top_p=0.1,
-     presencePenalty = 0.0,
-     countPenalty = 1.0,
  ):
-     args = PIPELINE_ARGS(temperature = 1.0, top_p = 0.1,
-                      alpha_frequency = 1.0,
-                      alpha_presence = 0.0,
-                      token_ban = [], # ban the generation of some tokens
-                      token_stop = [0, 261]) # stop generation whenever you see any token here
-     ctx = ctx.strip()
      all_tokens = []
      out_last = 0
      out_str = ''
      occurrence = {}
      for i in range(int(token_count)):
-         if i == 0:
-             input_ids = pipeline.encode(ctx)[-ctx_limit:]
-             out, state = visual_rwkv.forward(tokens=input_ids, state=image_state)
-         else:
-             input_ids = [token]
-             out, state = visual_rwkv.forward(tokens=input_ids, state=state)
          for n in occurrence:
              out[n] -= (args.alpha_presence + occurrence[n] * args.alpha_frequency)

-         token = pipeline.sample_logits(out, temperature=args.temperature, top_p=args.top_p)
          if token in args.token_stop:
              break
          all_tokens += [token]
          for xxx in occurrence:
-             occurrence[xxx] *= 0.994
          if token not in occurrence:
-             occurrence[token] = 1
          else:
-             occurrence[token] += 1
-
-         tmp = pipeline.decode(all_tokens[out_last:])
          if '\ufffd' not in tmp:
              out_str += tmp
              yield out_str.strip()
@@ -213,65 +251,181 @@ def generate(
      torch.cuda.empty_cache()
      yield out_str.strip()


- ##########################################################################
- cur_dir = os.path.dirname(os.path.abspath(__file__))
- visual_examples = [
-     [
-         f"{cur_dir}/examples_pizza.jpg",
-         "What are steps to cook it?"
-     ],
-     [
-         f"{cur_dir}/examples_bluejay.jpg",
-         "what is the name of this bird?",
-     ],
-     [
-         f"{cur_dir}/examples_woman_and_dog.png",
-         "describe this image",
-     ],
  ]


- def pil_image_to_base64(pil_image):
-     buffered = BytesIO()
-     pil_image.save(buffered, format="JPEG") # You can change the format as needed (JPEG, PNG, etc.)
-     # Encodes the image data into base64 format as a bytes object
-     base64_image = base64.b64encode(buffered.getvalue()).decode('utf-8')
-     return base64_image
-
- image_cache = {}
- ln0_weight = model.w['blocks.0.ln0.weight'].to(torch.float32).to(device)
- ln0_bias = model.w['blocks.0.ln0.bias'].to(torch.float32).to(device)
- def compute_image_state(image):
-     base64_image = pil_image_to_base64(image)
-     if base64_image in image_cache:
-         image_state = image_cache[base64_image]
-     else:
-         image = image_processor(images=image.convert('RGB'), return_tensors='pt')['pixel_values'].to(device)
-         image_features = visual_encoder.encode_images(image.unsqueeze(0)).squeeze(0) # [L, D]
-         # apply layer norm to image feature, very important
-         image_features = F.layer_norm(image_features,
-                                       (image_features.shape[-1],),
-                                       weight=ln0_weight,
-                                       bias=ln0_bias)
-         _, image_state = model.forward(embs=image_features, state=None)
-         image_cache[base64_image] = image_state
-     return image_state
-
- def chatbot(image, question):
-     if image is None:
-         yield "Please upload an image."
-         return
-     image_state = compute_image_state(image)
-     input_text = visual_generate_prompt(question)
-     for output in generate(input_text, image_state):
-         yield output


  ##################################################################################################################
- with gr.Blocks(title=title) as demo:
      gr.HTML(f"<div style=\"text-align: center;\">\n<h1>{title_v6}</h1>\n</div>")
-     with gr.Tab("Raw Generation"):
          gr.Markdown(f"This is [RWKV-6 World v2](https://huggingface.co/BlinkDL/rwkv-6-world) - a 100% attention-free RNN [RWKV-LM](https://github.com/BlinkDL/RWKV-LM). Supports 100+ world languages and code. And we have [300+ Github RWKV projects](https://github.com/search?o=desc&p=1&q=rwkv&s=updated&type=Repositories). *** Please try examples first (bottom of page) *** (edit them to use your question). Demo limited to ctxlen {ctx_limit}. (VisualRWKV is using RWKV5 1.5B)")
          with gr.Row():
              with gr.Column():
@@ -279,8 +433,8 @@ with gr.Blocks(title=title) as demo:
                  token_count = gr.Slider(10, gen_limit, label="Max Tokens", step=10, value=gen_limit)
                  temperature = gr.Slider(0.2, 2.0, label="Temperature", step=0.1, value=1.0)
                  top_p = gr.Slider(0.0, 1.0, label="Top P", step=0.05, value=0.3)
-                 presence_penalty = gr.Slider(0.0, 1.0, label="Presence Penalty", step=0.1, value=0)
-                 count_penalty = gr.Slider(0.0, 1.0, label="Count Penalty", step=0.1, value=1)
              with gr.Column():
                  with gr.Row():
                      submit = gr.Button("Submit", variant="primary")
@@ -290,22 +444,64 @@ with gr.Blocks(title=title) as demo:
          submit.click(evaluate, [prompt, token_count, temperature, top_p, presence_penalty, count_penalty], [output])
          clear.click(lambda: None, [], [output])
          data.click(lambda x: x, [data], [prompt, token_count, temperature, top_p, presence_penalty, count_penalty])
-     with gr.Tab("Visual RWKV"):
          with gr.Row():
              with gr.Column():
-                 image = gr.Image(type='pil', label="Image")
              with gr.Column():
-                 prompt = gr.Textbox(lines=8, label="Prompt",
-                     value="Render a clear and concise summary of the photo.")
                  with gr.Row():
                      submit = gr.Button("Submit", variant="primary")
-                     clear = gr.Button("Clear", variant="secondary")
              with gr.Column():
-                 output = gr.Textbox(label="Output", lines=10)
-         data = gr.Dataset(components=[image, prompt], samples=visual_examples, label="Examples", headers=["Image", "Prompt"])
-         submit.click(chatbot, [image, prompt], [output])
          clear.click(lambda: None, [], [output])
-         data.click(lambda x: x, [data], [image, prompt])

  demo.queue(concurrency_count=1, max_size=10)
  demo.launch(share=False)
 
app.py (updated)
+ import os, copy
  os.environ["RWKV_JIT_ON"] = '1'
  os.environ["RWKV_CUDA_ON"] = '1' # if '1' then use CUDA kernel for seq mode (much faster)
  # make sure cuda dir is in the same level as modeling_rwkv.py
  from modeling_rwkv import RWKV

+ import gc, re
  import gradio as gr
  import base64
  from io import BytesIO

  ctx_limit = 2500
  gen_limit = 500
+ ENABLE_VISUAL = True
+
  ########################## text rwkv ################################################################
  from rwkv.utils import PIPELINE, PIPELINE_ARGS

  title_v6 = "RWKV-x060-World-3B-v2.1-20240417-ctx4096"
  model_path_v6 = hf_hub_download(repo_id="BlinkDL/rwkv-6-world", filename=f"{title_v6}.pth")
+ # model_path_v6 = '/mnt/e/RWKV-Runner/models/rwkv-final-v6-2.1-3b' # conda activate torch2; cd /mnt/program/_RWKV_/_ref_/_gradio_/RWKV-Gradio-1; python app.py
  model_v6 = RWKV(model=model_path_v6, strategy='cuda fp16')
  pipeline_v6 = PIPELINE(model_v6, "rwkv_vocab_v20230424")
+ args = model_v6.args
+ eng_name = 'rwkv-x060-eng_single_round_qa-3B-20240430-ctx1024'
+ chn_name = 'rwkv-x060-chn_single_round_qa-3B-20240505-ctx1024'
+
+ # state_eng_raw = torch.load(f'/mnt/e/RWKV-Runner/models/{eng_name}.pth')
+ # state_chn_raw = torch.load(f'/mnt/e/RWKV-Runner/models/{chn_name}.pth')
+
+ eng_file = hf_hub_download(repo_id="BlinkDL/temp-latest-training-models", filename=f"{eng_name}.pth")
+ chn_file = hf_hub_download(repo_id="BlinkDL/temp-latest-training-models", filename=f"{chn_name}.pth")
+ state_eng_raw = torch.load(eng_file)
+ state_chn_raw = torch.load(chn_file)
+
+ state_eng = [None] * args.n_layer * 3
+ state_chn = [None] * args.n_layer * 3
+ for i in range(args.n_layer):
+     dd = model_v6.strategy[i]
+     dev = dd.device
+     atype = dd.atype
+     state_eng[i*3+0] = torch.zeros(args.n_embd, dtype=atype, requires_grad=False, device=dev).contiguous()
+     state_chn[i*3+0] = torch.zeros(args.n_embd, dtype=atype, requires_grad=False, device=dev).contiguous()
+     state_eng[i*3+1] = state_eng_raw[f'blocks.{i}.att.time_state'].transpose(1,2).to(dtype=torch.float, device=dev).requires_grad_(False).contiguous()
+     state_chn[i*3+1] = state_chn_raw[f'blocks.{i}.att.time_state'].transpose(1,2).to(dtype=torch.float, device=dev).requires_grad_(False).contiguous()
+     state_eng[i*3+2] = torch.zeros(args.n_embd, dtype=atype, requires_grad=False, device=dev).contiguous()
+     state_chn[i*3+2] = torch.zeros(args.n_embd, dtype=atype, requires_grad=False, device=dev).contiguous()
+
+ penalty_decay = 0.996
+
+ if ENABLE_VISUAL:
+     title = "RWKV-5-World-1B5-v2-20231025-ctx4096"
+     model_path = hf_hub_download(repo_id="BlinkDL/rwkv-5-world", filename=f"{title}.pth")
+     model = RWKV(model=model_path, strategy='cuda fp16')
+     pipeline = PIPELINE(model, "rwkv_vocab_v20230424")

  def generate_prompt(instruction, input=""):
      instruction = instruction.strip().replace('\r\n','\n').replace('\n\n','\n')
      input = input.strip().replace('\r\n','\n').replace('\n\n','\n')
      if input:
+         return f"""Instruction: {instruction}\n\nInput: {input}\n\nResponse:"""
      else:
+         return f"""User: hi\n\nAssistant: Hi. I am your assistant and I will provide expert full response in full details. Please feel free to ask any question and I will always answer it.\n\nUser: {instruction}\n\nAssistant:"""

+ def qa_prompt(instruction):
+     instruction = instruction.strip().replace('\r\n','\n')
+     instruction = re.sub(r'\n+', '\n', instruction)
+     return f"User: {instruction}\n\nAssistant:"""

  def evaluate(
      ctx,
      occurrence = {}
      state = None
      for i in range(int(token_count)):
+         input_ids = pipeline_v6.encode(ctx)[-ctx_limit:] if i == 0 else [token]
          out, state = model_v6.forward(tokens=input_ids, state=state)
          for n in occurrence:
              out[n] -= (args.alpha_presence + occurrence[n] * args.alpha_frequency)

+         token = pipeline_v6.sample_logits(out, temperature=args.temperature, top_p=args.top_p)
          if token in args.token_stop:
              break
          all_tokens += [token]
          for xxx in occurrence:
+             occurrence[xxx] *= penalty_decay

+         ttt = pipeline_v6.decode([token])
          www = 1
          if ttt in ' \t0123456789':
              www = 0

          else:
              occurrence[token] += www

+         tmp = pipeline_v6.decode(all_tokens[out_last:])
          if '\ufffd' not in tmp:
              out_str += tmp
              yield out_str.strip()

      torch.cuda.empty_cache()
      yield out_str.strip()

+ def evaluate_eng(
+     ctx,
+     token_count=200,
+     temperature=1.0,
+     top_p=0.7,
+     presencePenalty = 0.1,
+     countPenalty = 0.1,
+ ):
+     args = PIPELINE_ARGS(temperature = max(0.2, float(temperature)), top_p = float(top_p),
+                      alpha_frequency = countPenalty,
+                      alpha_presence = presencePenalty,
+                      token_ban = [], # ban the generation of some tokens
+                      token_stop = [0]) # stop generation whenever you see any token here
+     ctx = qa_prompt(ctx)
+     all_tokens = []
+     out_last = 0
+     out_str = ''
+     occurrence = {}
+     state = copy.deepcopy(state_eng)
+     for i in range(int(token_count)):
+         input_ids = pipeline_v6.encode(ctx)[-ctx_limit:] if i == 0 else [token]
+         out, state = model_v6.forward(tokens=input_ids, state=state)
+         for n in occurrence:
+             out[n] -= (args.alpha_presence + occurrence[n] * args.alpha_frequency)
+
+         token = pipeline_v6.sample_logits(out, temperature=args.temperature, top_p=args.top_p)
+         if token in args.token_stop:
+             break
+         all_tokens += [token]
+         for xxx in occurrence:
+             occurrence[xxx] *= penalty_decay
+
+         ttt = pipeline_v6.decode([token])
+         www = 1
+         if ttt in ' \t0123456789':
+             www = 0
+         #elif ttt in '\r\n,.;?!"\':+-*/=#@$%^&_`~|<>\\()[]{},。;“”:?!()【】':
+         #    www = 0.5
+         if token not in occurrence:
+             occurrence[token] = www
+         else:
+             occurrence[token] += www
+
+         tmp = pipeline_v6.decode(all_tokens[out_last:])
+         if '\ufffd' not in tmp:
+             out_str += tmp
+             yield out_str.strip()
+             out_last = i + 1
+
+     gpu_info = nvmlDeviceGetMemoryInfo(gpu_h)
+     timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
+     print(f'{timestamp} - vram {gpu_info.total} used {gpu_info.used} free {gpu_info.free}')
+     del out
+     del state
+     gc.collect()
+     torch.cuda.empty_cache()
+     yield out_str.strip()
+
+ def evaluate_chn(
      ctx,
      token_count=200,
      temperature=1.0,
+     top_p=0.7,
+     presencePenalty = 0.1,
+     countPenalty = 0.1,
  ):
+     args = PIPELINE_ARGS(temperature = max(0.2, float(temperature)), top_p = float(top_p),
+                      alpha_frequency = countPenalty,
+                      alpha_presence = presencePenalty,
+                      token_ban = [], # ban the generation of some tokens
+                      token_stop = [0]) # stop generation whenever you see any token here
+     ctx = qa_prompt(ctx)
      all_tokens = []
      out_last = 0
      out_str = ''
      occurrence = {}
+     state = copy.deepcopy(state_chn)
      for i in range(int(token_count)):
+         input_ids = pipeline_v6.encode(ctx)[-ctx_limit:] if i == 0 else [token]
+         out, state = model_v6.forward(tokens=input_ids, state=state)
          for n in occurrence:
              out[n] -= (args.alpha_presence + occurrence[n] * args.alpha_frequency)

+         token = pipeline_v6.sample_logits(out, temperature=args.temperature, top_p=args.top_p)
          if token in args.token_stop:
              break
          all_tokens += [token]
          for xxx in occurrence:
+             occurrence[xxx] *= penalty_decay
+
+         ttt = pipeline_v6.decode([token])
+         www = 1
+         if ttt in ' \t0123456789':
+             www = 0
+         #elif ttt in '\r\n,.;?!"\':+-*/=#@$%^&_`~|<>\\()[]{},。;“”:?!()【】':
+         #    www = 0.5
          if token not in occurrence:
+             occurrence[token] = www
          else:
+             occurrence[token] += www
+
+         tmp = pipeline_v6.decode(all_tokens[out_last:])
          if '\ufffd' not in tmp:
              out_str += tmp
              yield out_str.strip()

      torch.cuda.empty_cache()
      yield out_str.strip()

+ examples = [
+     ["Assistant: How can we craft an engaging story featuring vampires on Mars? Let's think step by step and provide an expert response.", gen_limit, 1, 0.3, 0.5, 0.5],
+     ["Assistant: How can we persuade Elon Musk to follow you on Twitter? Let's think step by step and provide an expert response.", gen_limit, 1, 0.3, 0.5, 0.5],
+     [generate_prompt("東京で訪れるべき素晴らしい場所とその紹介をいくつか挙げてください。"), gen_limit, 1, 0.3, 0.5, 0.5],
+     [generate_prompt("Write a story using the following information.", "A man named Alex chops a tree down."), gen_limit, 1, 0.3, 0.5, 0.5],
+     ["A few light taps upon the pane made her turn to the window. It had begun to snow again.", gen_limit, 1, 0.3, 0.5, 0.5],
+     ['''Edward: I am Edward Elric from Fullmetal Alchemist.\n\nUser: Hello Edward. What have you been up to recently?\n\nEdward:''', gen_limit, 1, 0.3, 0.5, 0.5],
+     [generate_prompt("Write a simple website in HTML. When a user clicks the button, it shows a random joke from a list of 4 jokes."), 500, 1, 0.3, 0.5, 0.5],
+     ['''Japanese: 春の初め、桜の花が満開になる頃、小さな町の片隅にある古びた神社の境内は、特別な雰囲気に包まれていた。\n\nEnglish:''', gen_limit, 1, 0.3, 0.5, 0.5],
+     ["En una pequeña aldea escondida entre las montañas de Andalucía, donde las calles aún conservaban el eco de antiguas leyendas, vivía un joven llamado Alejandro.", gen_limit, 1, 0.3, 0.5, 0.5],
+     ["Dans le cœur battant de Paris, sous le ciel teinté d'un crépuscule d'or et de pourpre, se tenait une petite librairie oubliée par le temps.", gen_limit, 1, 0.3, 0.5, 0.5],
+     ["في تطور مذهل وغير مسبوق، أعلنت السلطات المحلية في العاصمة عن اكتشاف أثري قد يغير مجرى التاريخ كما نعرفه.", gen_limit, 1, 0.3, 0.5, 0.5],
+     ['''“当然可以,大宇宙不会因为这五公斤就不坍缩了。”关一帆说,他还有一个没说出来的想法:也许大宇宙真的会因为相差一个原子的质量而由封闭转为开放。大自然的精巧有时超出想象,比如生命的诞生,就需要各项宇宙参数在几亿亿分之一精度上的精确配合。但程心仍然可以留下她的生态球,因为在那无数文明创造的无数小宇宙中,肯定有相当一部分不响应回归运动的号召,所以,大宇宙最终被夺走的质量至少有几亿吨,甚至可能是几亿亿亿吨。\n但愿大宇宙能够忽略这个误差。\n程心和关一帆进入了飞船,智子最后也进来了。她早就不再穿那身华丽的和服了,她现在身着迷彩服,再次成为一名轻捷精悍的战士,她的身上佩带着许多武器和生存装备,最引人注目的是那把插在背后的武士刀。\n“放心,我在,你们就在!”智子对两位人类朋友说。\n聚变发动机启动了,推进器发出幽幽的蓝光,''', gen_limit, 1, 0.3, 0.5, 0.5],
+ ]

+ examples_eng = [
+     ["How can I craft an engaging story featuring vampires on Mars?", gen_limit, 1, 0.2, 0.3, 0.3],
+     ["Compare the business models of Apple and Google.", gen_limit, 1, 0.2, 0.3, 0.3],
+     ["In JSON format, list the top 5 tourist attractions in Paris.", gen_limit, 1, 0.2, 0.3, 0.3],
+     ["Write an outline for a fantasy novel where dreams can alter reality.", gen_limit, 1, 0.2, 0.3, 0.3],
+     ["Can fish get thirsty?", gen_limit, 1, 0.2, 0.3, 0.3],
+     ["Write a Bash script to check disk usage and send alerts if it's too high.", gen_limit, 1, 0.2, 0.3, 0.3],
+     ["Write a simple website in HTML. When a user clicks the button, it shows a random joke from a list of 4 jokes.", gen_limit, 1, 0.2, 0.3, 0.3],
  ]

+ examples_chn = [
+     ["怎样写一个在火星上的吸血鬼的有趣故事?", gen_limit, 1, 0.2, 0.3, 0.3],
+     ["比较苹果和谷歌的商业模式。", gen_limit, 1, 0.2, 0.3, 0.3],
+     ["鱼会口渴吗?", gen_limit, 1, 0.2, 0.3, 0.3],
+     ["以 JSON 格式列举北京的美食。", gen_limit, 1, 0.2, 0.3, 0.3],
+     ["编写一个Bash脚本来检查磁盘使用情况,如果使用量过高则发送警报。", gen_limit, 1, 0.2, 0.3, 0.3],
+     ["用HTML编写一个简单的网站。当用户点击按钮时,从4个笑话的列表中随机显示一个笑话。", gen_limit, 1, 0.2, 0.3, 0.3],
+ ]

+ if ENABLE_VISUAL:
+     ########################## visual rwkv ################################################################
+     visual_title = 'ViusualRWKV-v5'
+     rwkv_remote_path = "rwkv1b5-vitl336p14-577token_mix665k_rwkv.pth"
+     vision_remote_path = "rwkv1b5-vitl336p14-577token_mix665k_visual.pth"
+     vision_tower_name = 'openai/clip-vit-large-patch14-336'
+
+     model_path = hf_hub_download(repo_id="howard-hou/visualrwkv-5", filename=rwkv_remote_path)
+     visual_rwkv = RWKV(model=model_path, strategy='cuda fp16')
+
+     ##########################################################################
+     from modeling_vision import VisionEncoder, VisionEncoderConfig
+     config = VisionEncoderConfig(n_embd=model.args.n_embd,
+                                  vision_tower_name=vision_tower_name,
+                                  grid_size=-1)
+     visual_encoder = VisionEncoder(config)
+     vision_local_path = hf_hub_download(repo_id="howard-hou/visualrwkv-5", filename=vision_remote_path)
+     vision_state_dict = torch.load(vision_local_path, map_location='cpu')
+     visual_encoder.load_state_dict(vision_state_dict)
+     image_processor = CLIPImageProcessor.from_pretrained(vision_tower_name)
+     visual_encoder = visual_encoder.to(device)
+     ##########################################################################
+     def visual_generate_prompt(instruction):
+         instruction = instruction.strip().replace('\r\n','\n').replace('\n\n','\n')
+         return f"\n{instruction}\n\nAssistant:"
+
+     def generate(
+         ctx,
+         image_state,
+         token_count=200,
+         temperature=1.0,
+         top_p=0.1,
+         presencePenalty = 0.0,
+         countPenalty = 1.0,
+     ):
+         args = PIPELINE_ARGS(temperature = 1.0, top_p = 0.1,
+                          alpha_frequency = 1.0,
+                          alpha_presence = 0.0,
+                          token_ban = [], # ban the generation of some tokens
+                          token_stop = [0, 261]) # stop generation whenever you see any token here
+         ctx = ctx.strip()
+         all_tokens = []
+         out_last = 0
+         out_str = ''
+         occurrence = {}
+         for i in range(int(token_count)):
+             if i == 0:
+                 input_ids = pipeline.encode(ctx)[-ctx_limit:]
+                 out, state = visual_rwkv.forward(tokens=input_ids, state=image_state)
+             else:
+                 input_ids = [token]
+                 out, state = visual_rwkv.forward(tokens=input_ids, state=state)
+             for n in occurrence:
+                 out[n] -= (args.alpha_presence + occurrence[n] * args.alpha_frequency)
+
+             token = pipeline.sample_logits(out, temperature=args.temperature, top_p=args.top_p)
+             if token in args.token_stop:
+                 break
+             all_tokens += [token]
+             for xxx in occurrence:
+                 occurrence[xxx] *= 0.994
+             if token not in occurrence:
+                 occurrence[token] = 1
+             else:
+                 occurrence[token] += 1
+
+             tmp = pipeline.decode(all_tokens[out_last:])
+             if '\ufffd' not in tmp:
+                 out_str += tmp
+                 yield out_str.strip()
+                 out_last = i + 1
+
+         gpu_info = nvmlDeviceGetMemoryInfo(gpu_h)
+         timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
+         print(f'{timestamp} - vram {gpu_info.total} used {gpu_info.used} free {gpu_info.free}')
+         del out
+         del state
+         gc.collect()
+         torch.cuda.empty_cache()
+         yield out_str.strip()
+
+
+     ##########################################################################
+     cur_dir = os.path.dirname(os.path.abspath(__file__))
+     visual_examples = [
+         [
+             f"{cur_dir}/examples_pizza.jpg",
+             "What are steps to cook it?"
+         ],
+         [
+             f"{cur_dir}/examples_bluejay.jpg",
+             "what is the name of this bird?",
+         ],
+         [
+             f"{cur_dir}/examples_woman_and_dog.png",
+             "describe this image",
+         ],
+     ]
+
+
+     def pil_image_to_base64(pil_image):
+         buffered = BytesIO()
+         pil_image.save(buffered, format="JPEG") # You can change the format as needed (JPEG, PNG, etc.)
+         # Encodes the image data into base64 format as a bytes object
+         base64_image = base64.b64encode(buffered.getvalue()).decode('utf-8')
+         return base64_image
+
+     image_cache = {}
+     ln0_weight = model.w['blocks.0.ln0.weight'].to(torch.float32).to(device)
+     ln0_bias = model.w['blocks.0.ln0.bias'].to(torch.float32).to(device)
+     def compute_image_state(image):
+         base64_image = pil_image_to_base64(image)
+         if base64_image in image_cache:
+             image_state = image_cache[base64_image]
+         else:
+             image = image_processor(images=image.convert('RGB'), return_tensors='pt')['pixel_values'].to(device)
+             image_features = visual_encoder.encode_images(image.unsqueeze(0)).squeeze(0) # [L, D]
+             # apply layer norm to image feature, very important
+             image_features = F.layer_norm(image_features,
+                                           (image_features.shape[-1],),
+                                           weight=ln0_weight,
+                                           bias=ln0_bias)
+             _, image_state = model.forward(embs=image_features, state=None)
+             image_cache[base64_image] = image_state
+         return image_state
+
+     def chatbot(image, question):
+         if image is None:
+             yield "Please upload an image."
+             return
+         image_state = compute_image_state(image)
+         input_text = visual_generate_prompt(question)
+         for output in generate(input_text, image_state):
+             yield output


  ##################################################################################################################
+ with gr.Blocks(title=title_v6) as demo:
      gr.HTML(f"<div style=\"text-align: center;\">\n<h1>{title_v6}</h1>\n</div>")
+
+     with gr.Tab("Base Model (Raw Generation)"):
          gr.Markdown(f"This is [RWKV-6 World v2](https://huggingface.co/BlinkDL/rwkv-6-world) - a 100% attention-free RNN [RWKV-LM](https://github.com/BlinkDL/RWKV-LM). Supports 100+ world languages and code. And we have [300+ Github RWKV projects](https://github.com/search?o=desc&p=1&q=rwkv&s=updated&type=Repositories). *** Please try examples first (bottom of page) *** (edit them to use your question). Demo limited to ctxlen {ctx_limit}. (VisualRWKV is using RWKV5 1.5B)")
          with gr.Row():
              with gr.Column():

                  token_count = gr.Slider(10, gen_limit, label="Max Tokens", step=10, value=gen_limit)
                  temperature = gr.Slider(0.2, 2.0, label="Temperature", step=0.1, value=1.0)
                  top_p = gr.Slider(0.0, 1.0, label="Top P", step=0.05, value=0.3)
+                 presence_penalty = gr.Slider(0.0, 1.0, label="Presence Penalty", step=0.1, value=0.5)
+                 count_penalty = gr.Slider(0.0, 1.0, label="Count Penalty", step=0.1, value=0.5)
              with gr.Column():
                  with gr.Row():
                      submit = gr.Button("Submit", variant="primary")

          submit.click(evaluate, [prompt, token_count, temperature, top_p, presence_penalty, count_penalty], [output])
          clear.click(lambda: None, [], [output])
          data.click(lambda x: x, [data], [prompt, token_count, temperature, top_p, presence_penalty, count_penalty])
+
+     with gr.Tab("=== English Q/A ==="):
+         gr.Markdown(f"This is [RWKV-6](https://huggingface.co/BlinkDL/rwkv-6-world) state-tuned to [English Q/A](https://huggingface.co/BlinkDL/temp-latest-training-models/blob/main/{eng_name}.pth). RWKV is a 100% attention-free RNN [RWKV-LM](https://github.com/BlinkDL/RWKV-LM), and we have [300+ Github RWKV projects](https://github.com/search?o=desc&p=1&q=rwkv&s=updated&type=Repositories). Demo limited to ctxlen {ctx_limit}.")
          with gr.Row():
              with gr.Column():
+                 prompt = gr.Textbox(lines=2, label="Prompt", value="How can I craft an engaging story featuring vampires on Mars?")
+                 token_count = gr.Slider(10, gen_limit, label="Max Tokens", step=10, value=gen_limit)
+                 temperature = gr.Slider(0.2, 2.0, label="Temperature", step=0.1, value=1.0)
+                 top_p = gr.Slider(0.0, 1.0, label="Top P", step=0.05, value=0.2)
+                 presence_penalty = gr.Slider(0.0, 1.0, label="Presence Penalty", step=0.1, value=0.3)
+                 count_penalty = gr.Slider(0.0, 1.0, label="Count Penalty", step=0.1, value=0.3)
              with gr.Column():
                  with gr.Row():
                      submit = gr.Button("Submit", variant="primary")
+                     clear = gr.Button("Clear", variant="secondary")
+         output = gr.Textbox(label="Output", lines=30)
+         data = gr.Dataset(components=[prompt, token_count, temperature, top_p, presence_penalty, count_penalty], samples=examples_eng, samples_per_page=50, label="Examples", headers=["Prompt", "Max Tokens", "Temperature", "Top P", "Presence Penalty", "Count Penalty"])
+         submit.click(evaluate_eng, [prompt, token_count, temperature, top_p, presence_penalty, count_penalty], [output])
+         clear.click(lambda: None, [], [output])
+         data.click(lambda x: x, [data], [prompt, token_count, temperature, top_p, presence_penalty, count_penalty])
+
+     with gr.Tab("Chinese Q/A"):
+         gr.Markdown(f"This is [RWKV-6](https://huggingface.co/BlinkDL/rwkv-6-world) state-tuned to [Chinese Q/A](https://huggingface.co/BlinkDL/temp-latest-training-models/blob/main/{chn_name}.pth). RWKV is a 100% attention-free RNN [RWKV-LM](https://github.com/BlinkDL/RWKV-LM), and we have [300+ Github RWKV projects](https://github.com/search?o=desc&p=1&q=rwkv&s=updated&type=Repositories). Demo limited to ctxlen {ctx_limit}.")
+         with gr.Row():
+             with gr.Column():
+                 prompt = gr.Textbox(lines=2, label="Prompt", value="怎样写一个在火星上的吸血鬼的有趣故事?")
+                 token_count = gr.Slider(10, gen_limit, label="Max Tokens", step=10, value=gen_limit)
+                 temperature = gr.Slider(0.2, 2.0, label="Temperature", step=0.1, value=1.0)
+                 top_p = gr.Slider(0.0, 1.0, label="Top P", step=0.05, value=0.2)
+                 presence_penalty = gr.Slider(0.0, 1.0, label="Presence Penalty", step=0.1, value=0.3)
+                 count_penalty = gr.Slider(0.0, 1.0, label="Count Penalty", step=0.1, value=0.3)
              with gr.Column():
+                 with gr.Row():
+                     submit = gr.Button("Submit", variant="primary")
+                     clear = gr.Button("Clear", variant="secondary")
+         output = gr.Textbox(label="Output", lines=30)
+         data = gr.Dataset(components=[prompt, token_count, temperature, top_p, presence_penalty, count_penalty], samples=examples_chn, samples_per_page=50, label="Examples", headers=["Prompt", "Max Tokens", "Temperature", "Top P", "Presence Penalty", "Count Penalty"])
+         submit.click(evaluate_chn, [prompt, token_count, temperature, top_p, presence_penalty, count_penalty], [output])
          clear.click(lambda: None, [], [output])
+         data.click(lambda x: x, [data], [prompt, token_count, temperature, top_p, presence_penalty, count_penalty])
+
+     if ENABLE_VISUAL:
+         with gr.Tab("Visual RWKV-5 1.5B"):
+             with gr.Row():
+                 with gr.Column():
+                     image = gr.Image(type='pil', label="Image")
+                 with gr.Column():
+                     prompt = gr.Textbox(lines=8, label="Prompt",
+                         value="Render a clear and concise summary of the photo.")
+                     with gr.Row():
+                         submit = gr.Button("Submit", variant="primary")
+                         clear = gr.Button("Clear", variant="secondary")
+                 with gr.Column():
+                     output = gr.Textbox(label="Output", lines=10)
+             data = gr.Dataset(components=[image, prompt], samples=visual_examples, label="Examples", headers=["Image", "Prompt"])
+             submit.click(chatbot, [image, prompt], [output])
+             clear.click(lambda: None, [], [output])
+             data.click(lambda x: x, [data], [image, prompt])

  demo.queue(concurrency_count=1, max_size=10)
  demo.launch(share=False)
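
For quick reference, the new state-tuned paths are plain streaming generators, so they can also be driven outside the Gradio UI. A minimal, hypothetical usage sketch (it assumes the updated app.py is importable as a module and that the model and state downloads above succeed):

# Hypothetical driver for the state-tuned English Q/A path added in this commit.
# evaluate_eng wraps the prompt with qa_prompt(), seeds the RNN with copy.deepcopy(state_eng),
# and yields the progressively decoded answer string.
from app import evaluate_eng

answer = ""
for partial in evaluate_eng("Can fish get thirsty?", token_count=200, temperature=1.0, top_p=0.2):
    answer = partial
print(answer)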