Update modeling_internlm_xcomposer2.py

#14
by yuhangzang - opened
Files changed (1) hide show
  1. modeling_internlm_xcomposer2.py +81 -44
modeling_internlm_xcomposer2.py CHANGED
@@ -287,69 +287,93 @@ class InternLMXComposer2ForCausalLM(InternLM2PreTrainedModel):
287
  }
288
  return inputs, wrap_im_mask, temp_len
289
 
290
def interleav_wrap(self, img_list, text_list):
    """Interleave text and image embeddings for every (image, text) pair.

    Each text is split on the literal ``'<ImageHere>'`` marker. The text
    pieces are tokenized and embedded, and the i-th image embedding returned
    by ``self.img2emb`` is inserted after the i-th piece.

    Args:
        img_list: Iterable of per-sample image inputs; each item is passed
            to ``self.img2emb``.
        text_list: Iterable of per-sample text containers; only element
            ``[0]`` of each item is used as the prompt string.

    Returns:
        Tuple ``(wrap_embeds, wrap_atts, wrap_target, wrap_im_mask)``, each
        truncated to ``self.max_length`` along dim 1 and concatenated over
        the samples along dim 0.
    """
    batch_embeds, batch_atts = [], []
    batch_targets, batch_im_masks = [], []

    for image, text in zip(img_list, text_list):
        img_embeds, atts_img, img_target = self.img2emb(image)
        prompt = text[0]
        pieces = prompt.split('<ImageHere>')

        token_chunks, embed_chunks, att_chunks, mask_chunks = [], [], [], []
        running_len = 0
        # The first two dims of the image embeddings are read as
        # (number of images, tokens per image).
        image_nums, im_len = img_embeds.shape[:2]
        need_bos = True

        for piece_idx, piece in enumerate(pieces):
            if piece:
                encoded = self.tokenizer(
                    piece,
                    return_tensors='pt',
                    padding='longest',
                    add_special_tokens=need_bos,
                ).to(self.device)
                # Special tokens are only added for the first non-empty piece.
                need_bos = False

                token_chunks.append(encoded.input_ids)
                piece_embeds = self.model.tok_embeddings(encoded.input_ids)
                embed_chunks.append(piece_embeds)
                att_chunks.append(encoded.attention_mask)
                # Text positions are marked with 0 in the image mask.
                mask_chunks.append(
                    torch.zeros(piece_embeds.shape[:2]).to(self.device))
                running_len += piece_embeds.shape[1]

            if piece_idx < image_nums:
                image_att = atts_img[piece_idx].unsqueeze(0)
                token_chunks.append(img_target[piece_idx].unsqueeze(0))
                embed_chunks.append(img_embeds[piece_idx].unsqueeze(0))
                att_chunks.append(image_att)
                # Image positions are marked with 1 in the image mask.
                mask_chunks.append(torch.ones_like(image_att))
                running_len += im_len

            # Stop collecting once the sequence already exceeds the limit;
            # the overshoot is cut off by the slicing below.
            if running_len > self.max_length:
                break

        sample_tokens = torch.cat(token_chunks, dim=1)
        sample_embeds = torch.cat(embed_chunks, dim=1)
        sample_atts = torch.cat(att_chunks, dim=1)
        sample_im_mask = torch.cat(mask_chunks, dim=1)

        # Targets are derived from the full (untruncated) token sequence.
        sample_target = self.mask_human_targets(sample_tokens).to(self.device)

        limit = self.max_length
        batch_embeds.append(sample_embeds[:, :limit].to(self.device))
        batch_atts.append(sample_atts[:, :limit].to(self.device))
        batch_targets.append(sample_target[:, :limit].to(self.device))
        batch_im_masks.append(sample_im_mask[:, :limit].to(self.device))

    # NOTE(review): no padding is applied, so this concatenation presumably
    # requires every sample to have the same length - confirm with callers.
    wrap_embeds = torch.cat(batch_embeds)
    wrap_atts = torch.cat(batch_atts)
    wrap_target = torch.cat(batch_targets)
    wrap_im_mask = torch.cat(batch_im_masks)
    return wrap_embeds, wrap_atts, wrap_target, wrap_im_mask
353
 
354
  def mask_human_targets(self, input_ids, pure=False):
355
  target_batch = []
@@ -416,9 +440,22 @@ class InternLMXComposer2ForCausalLM(InternLM2PreTrainedModel):
416
  text = samples['text_input']
417
  # encode image
418
  if has_img:
419
- image = samples['image']
 
 
 
 
 
 
 
 
 
 
 
 
 
420
  to_regress_embeds, attention_mask, targets, im_mask = self.interleav_wrap(
421
- image, text)
422
  else:
423
  to_regress_tokens, targets = self.text2emb(
424
  text, add_special_tokens=True)
 
287
  }
288
  return inputs, wrap_im_mask, temp_len
289
 
290
def interleav_wrap(self, img_list, text_list, image_nums):
    """Build a padded batch of interleaved text/image input embeddings.

    All images are encoded in one pass through ``self.vit`` and
    ``self.vision_proj``. Each sample's text is split on the literal
    ``'<ImageHere>'`` marker; text pieces are tokenized and embedded, and
    the sample's i-th image embedding is inserted after the i-th piece.
    Samples are then right-padded (or truncated) to a common length of
    ``min(longest sample, self.max_length)``.

    Args:
        img_list: Images of the whole batch, flattened in sample order;
            passed as-is to ``self.vit``.
        text_list: Container whose element ``[0]`` is the sequence of
            per-sample prompt strings (only ``text_list[0]`` is read).
        image_nums: Number of images belonging to each sample, in the same
            order as the prompts. A sample may have zero images.

    Returns:
        Tuple ``(inputs_embeds, attention_mask, targets, im_mask)``, each
        concatenated over the samples along dim 0:
        ``attention_mask`` is 1 on real positions and 0 on padding,
        ``targets`` is padded with -100, and ``im_mask`` is 1 on image
        positions and 0 on text and padding positions.
    """
    # Encode every image of the batch at once. ``img_split`` is consumed
    # below as the per-image length along dim 1 of ``img_embeds``.
    img_embeds, img_split = self.vit(
        img_list, self.plora_glb_GN, self.plora_sub_GN)
    img_embeds = self.vision_proj(img_embeds)

    sample_embeds, sample_im_masks, sample_targets = [], [], []

    # Images are consumed strictly in order across samples, so a running
    # cursor replaces re-summing ``img_split`` for every image.
    image_cursor = 0
    embed_offset = 0

    for sample_idx, text in enumerate(text_list[0]):
        image_num = image_nums[sample_idx]

        images = []
        for _ in range(image_num):
            span = img_split[image_cursor]
            images.append(img_embeds[:, embed_offset:embed_offset + span])
            embed_offset += int(span)
            image_cursor += 1

        # A single image without an explicit marker goes before the text.
        if image_num == 1 and '<ImageHere>' not in text:
            text = '<ImageHere>' + text
        parts = text.split('<ImageHere>')

        wrap_tokens, wrap_embeds, wrap_im_mask = [], [], []
        temp_len = 0
        need_bos = True
        for part_idx, part in enumerate(parts):
            if part:
                part_tokens = self.tokenizer(
                    part,
                    return_tensors='pt',
                    padding='longest',
                    add_special_tokens=need_bos,
                ).to(self.device)
                # Special tokens are only added for the first non-empty part.
                need_bos = False

                wrap_tokens.append(part_tokens.input_ids)
                part_embeds = self.model.tok_embeddings(part_tokens.input_ids)
                wrap_embeds.append(part_embeds)
                # Text positions are marked with 0 in the image mask.
                wrap_im_mask.append(
                    torch.zeros(part_embeds.shape[:2]).to(self.device))
                temp_len += part_embeds.shape[1]

            if part_idx < image_num:
                image = images[part_idx]
                wrap_embeds.append(image)
                # Image positions carry -100 in the token stream handed to
                # ``mask_human_targets``.
                wrap_tokens.append(
                    torch.full(image.shape[:2], -100, dtype=torch.long,
                               device=self.device))
                wrap_im_mask.append(torch.ones(image.shape[:2]).to(self.device))
                temp_len += image.shape[1]

            # Stop collecting once the limit is exceeded; the overshoot is
            # cut off when the batch is assembled below.
            if temp_len > self.max_length:
                break

        wrap_tokens = torch.cat(wrap_tokens, dim=1)
        wrap_embeds = torch.cat(wrap_embeds, dim=1)
        wrap_im_mask = torch.cat(wrap_im_mask, dim=1)

        wrap_target = self.mask_human_targets(wrap_tokens).to(self.device)

        sample_embeds.append(wrap_embeds)
        sample_im_masks.append(wrap_im_mask)
        sample_targets.append(wrap_target)

    batch_len = min(
        max(embeds.shape[1] for embeds in sample_embeds), self.max_length)

    # Embedding of a single pad token, repeated to fill short samples.
    pad_ids = torch.full((1, 1), self.tokenizer.pad_token_id,
                         dtype=torch.long, device=self.device)
    pad_emb = self.model.tok_embeddings(pad_ids)

    final_input, final_atts, final_tars, final_mask = [], [], [], []
    for embeds, tars, mask in zip(
            sample_embeds, sample_targets, sample_im_masks):
        cur_len = embeds.shape[1]
        if cur_len >= batch_len:
            final_input.append(embeds[:, :batch_len])
            final_atts.append(
                torch.ones(1, batch_len, dtype=tars.dtype,
                           device=self.device))
            final_tars.append(tars[:, :batch_len])
            final_mask.append(mask[:, :batch_len])
        else:
            pad_len = batch_len - cur_len
            final_input.append(
                torch.cat([embeds, pad_emb.repeat(1, pad_len, 1)], dim=1))
            atts = torch.cat(
                [torch.ones(1, cur_len), torch.zeros(1, pad_len)], dim=1)
            final_atts.append(atts.to(tars.dtype).to(self.device))
            final_tars.append(torch.cat(
                [tars,
                 torch.full((1, pad_len), -100, dtype=tars.dtype,
                            device=self.device)],
                dim=1))
            # Pad with the mask's own dtype so the concatenation does not
            # depend on implicit type promotion.
            final_mask.append(torch.cat(
                [mask,
                 torch.zeros((1, pad_len), dtype=mask.dtype,
                             device=self.device)],
                dim=1))

    inputs_embeds = torch.cat(final_input, dim=0)
    attention_mask = torch.cat(final_atts, dim=0)
    targets = torch.cat(final_tars, dim=0)
    im_mask = torch.cat(final_mask, dim=0)

    return inputs_embeds, attention_mask, targets, im_mask
 
 
 
 
377
 
378
  def mask_human_targets(self, input_ids, pure=False):
379
  target_batch = []
 
440
  text = samples['text_input']
441
  # encode image
442
  if has_img:
443
+ image = samples['image'][0]
444
+ bs = len(samples['text_input'][0])
445
+ image_nums = []
446
+ temp_image = []
447
+ for im in image:
448
+ if type(im) is list:
449
+ image_nums.append(len(im))
450
+ temp_image.extend(im)
451
+ else:
452
+ image_nums.append(1)
453
+ temp_image.append(im)
454
+ image = temp_image
455
+ assert type(image) is list and len(image_nums) == bs
456
+
457
  to_regress_embeds, attention_mask, targets, im_mask = self.interleav_wrap(
458
+ image, text, image_nums)
459
  else:
460
  to_regress_tokens, targets = self.text2emb(
461
  text, add_special_tokens=True)