mingdali committed on
Commit
0ef92f7
·
verified ·
1 Parent(s): 4f8bdcf

Update visual.py

Browse files
Files changed (1) hide show
  1. visual.py +3 -17
visual.py CHANGED
@@ -25,13 +25,11 @@ def sliding_window(matrix, window_size, stride):
25
  window_cols = (width - window_size[1]) // stride + 1
26
  images_448 = F.interpolate(matrix, size=window_size, mode='bicubic')
27
  windows = []
28
- # pdb.set_trace()
29
  for i in range(window_rows):
30
  windows_col = []
31
  for j in range(window_cols):
32
  window = matrix[:,:, i*stride:i*stride+window_size[0], j*stride:j*stride+window_size[1]]
33
  windows.append(window)
34
- # windows.append(windows_col)
35
  windows.append(images_448)
36
  images = torch.cat(windows,dim=1)
37
  images = images.reshape(b*5,c,window_size[0], window_size[0])
@@ -145,12 +143,9 @@ class Resampler(nn.Module):
145
  self.ln_kv = norm_layer(embed_dim)
146
 
147
  self.apply(self._init_weights)
148
- # pdb.set_trace()
149
- #self.load_state_dict(torch.load('/cfs/cfs-lugcocyb/mingdali/code/qWen-VL/vl-chat/attn_params.pth'))
150
 
151
  def _init_weights(self, m):
152
- # self.load_state_dict(torch.load('/cfs/cfs-lugcocyb/mingdali/code/qWen-VL/vl-chat/attn_params.pth'))
153
- #pdb.set_trace()
154
  if isinstance(m, nn.Linear):
155
  trunc_normal_(m.weight, std=.02)
156
  if isinstance(m, nn.Linear) and m.bias is not None:
@@ -160,7 +155,6 @@ class Resampler(nn.Module):
160
  nn.init.constant_(m.weight, 1.0)
161
 
162
  def forward(self, x, attn_mask=None):
163
- #pdb.set_trace()
164
  pos_embed = get_abs_pos(self.pos_embed, x.size(1))
165
 
166
  x = self.kv_proj(x)
@@ -401,7 +395,6 @@ class VisionTransformer(nn.Module):
401
  act_layer=act_layer,
402
  norm_layer=norm_layer,
403
  )
404
- # pdb.set_trace()
405
  self.attn_pool = Resampler(
406
  grid_size=int(math.sqrt(n_queries)),
407
  embed_dim=output_dim,
@@ -418,14 +411,10 @@ class VisionTransformer(nn.Module):
418
  )
419
  self.ln_post = norm_layer(output_dim)
420
  self.proj = nn.Parameter((output_dim** -0.5) * torch.randn(output_dim, output_dim))
421
- # self.attn_pool2.load_state_dict(torch.load('/cfs/cfs-lugcocyb/mingdali/code/qWen-VL/vl-chat/attn_params.pth'))
422
 
423
- # def initialize_vision_modules(self,lpath):
424
- # self.attn_pool2[0].load_state_dict(torch.load(lpath))
425
 
426
  def forward(self, x: torch.Tensor):
427
- #pdb.set_trace()
428
- #torch.save(self.attn_pool.state_dict(), '/cfs/cfs-lugcocyb/mingdali/code/qWen-VL/vl-chat/attn_params.pth')
429
  x = x.to(
430
  dtype=self.transformer.get_cast_dtype(),
431
  device=self.transformer.get_cast_device(),
@@ -442,7 +431,6 @@ class VisionTransformer(nn.Module):
442
  x = x.permute(1, 0, 2) # NLD -> LND
443
  x = self.transformer(x)
444
  x = x.permute(1, 0, 2) # LND -> NLD
445
- # pdb.set_trace()
446
  src_size = int(math.sqrt(x.shape[1]))
447
  x = x.reshape(x.shape[0]//5,5,-1, x.shape[-1])
448
  x1 = x[:,4,:,:]
@@ -454,7 +442,6 @@ class VisionTransformer(nn.Module):
454
  x1 = self.attn_pool(x1)
455
  x = self.post_pro(x)
456
  x1 = self.post_pro(x1)
457
- # return x1
458
  return torch.cat([x,x1],dim=1)
459
 
460
  def post_pro(self, x):
@@ -465,7 +452,7 @@ class VisionTransformer(nn.Module):
465
 
466
  def encode(self, image_paths: List[str]):
467
  images = []
468
- # pdb.set_trace()
469
  for image_path in image_paths:
470
  try:
471
  if image_path.startswith("http://") or image_path.startswith("https://"):
@@ -474,7 +461,6 @@ class VisionTransformer(nn.Module):
474
  image = self.image_transform(Image.open(image_path).convert("RGB"))
475
  except:
476
  image = torch.zeros((3, 448*2, 448*2))
477
- # pdb.set_trace()
478
  images.append(image)
479
  images = torch.stack(images, dim=0)
480
  windows = sliding_window(images,window_size=(448,448),stride=448)
 
25
  window_cols = (width - window_size[1]) // stride + 1
26
  images_448 = F.interpolate(matrix, size=window_size, mode='bicubic')
27
  windows = []
 
28
  for i in range(window_rows):
29
  windows_col = []
30
  for j in range(window_cols):
31
  window = matrix[:,:, i*stride:i*stride+window_size[0], j*stride:j*stride+window_size[1]]
32
  windows.append(window)
 
33
  windows.append(images_448)
34
  images = torch.cat(windows,dim=1)
35
  images = images.reshape(b*5,c,window_size[0], window_size[0])
 
143
  self.ln_kv = norm_layer(embed_dim)
144
 
145
  self.apply(self._init_weights)
146
+
 
147
 
148
  def _init_weights(self, m):
 
 
149
  if isinstance(m, nn.Linear):
150
  trunc_normal_(m.weight, std=.02)
151
  if isinstance(m, nn.Linear) and m.bias is not None:
 
155
  nn.init.constant_(m.weight, 1.0)
156
 
157
  def forward(self, x, attn_mask=None):
 
158
  pos_embed = get_abs_pos(self.pos_embed, x.size(1))
159
 
160
  x = self.kv_proj(x)
 
395
  act_layer=act_layer,
396
  norm_layer=norm_layer,
397
  )
 
398
  self.attn_pool = Resampler(
399
  grid_size=int(math.sqrt(n_queries)),
400
  embed_dim=output_dim,
 
411
  )
412
  self.ln_post = norm_layer(output_dim)
413
  self.proj = nn.Parameter((output_dim** -0.5) * torch.randn(output_dim, output_dim))
 
414
 
 
 
415
 
416
  def forward(self, x: torch.Tensor):
417
+
 
418
  x = x.to(
419
  dtype=self.transformer.get_cast_dtype(),
420
  device=self.transformer.get_cast_device(),
 
431
  x = x.permute(1, 0, 2) # NLD -> LND
432
  x = self.transformer(x)
433
  x = x.permute(1, 0, 2) # LND -> NLD
 
434
  src_size = int(math.sqrt(x.shape[1]))
435
  x = x.reshape(x.shape[0]//5,5,-1, x.shape[-1])
436
  x1 = x[:,4,:,:]
 
442
  x1 = self.attn_pool(x1)
443
  x = self.post_pro(x)
444
  x1 = self.post_pro(x1)
 
445
  return torch.cat([x,x1],dim=1)
446
 
447
  def post_pro(self, x):
 
452
 
453
  def encode(self, image_paths: List[str]):
454
  images = []
455
+
456
  for image_path in image_paths:
457
  try:
458
  if image_path.startswith("http://") or image_path.startswith("https://"):
 
461
  image = self.image_transform(Image.open(image_path).convert("RGB"))
462
  except:
463
  image = torch.zeros((3, 448*2, 448*2))
 
464
  images.append(image)
465
  images = torch.stack(images, dim=0)
466
  windows = sliding_window(images,window_size=(448,448),stride=448)