MykolaL commited on
Commit
83f5a68
·
verified ·
1 Parent(s): b0bac6d

Upload EVPRefer_warp

Browse files
Files changed (2) hide show
  1. model.py +3 -2
  2. model.safetensors +2 -2
model.py CHANGED
@@ -286,7 +286,8 @@ class EVPRefer(nn.Module):
286
 
287
  self.classifier = SimpleDecoding(dims=neck_dim)
288
 
289
- self.gamma = nn.Parameter(torch.ones(token_embed_dim) * 1e-4)
 
290
  self.aggregation = InverseMultiAttentiveFeatureRefinement([320,680,1320,1280])
291
  self.clip_model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")
292
  self.tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
@@ -310,7 +311,7 @@ class EVPRefer(nn.Module):
310
  latents = latents / 4.7164
311
 
312
  l_feats = self.clip_model(input_ids=input_ids).last_hidden_state
313
- c_crossattn = self.text_adapter(latents, l_feats, self.gamma) # NOTE: here the c_crossattn should be expand_dim as latents
314
  t = torch.ones((img.shape[0],), device=img.device).long()
315
  outs = self.unet(latents, t, c_crossattn=[c_crossattn])
316
 
 
286
 
287
  self.classifier = SimpleDecoding(dims=neck_dim)
288
 
289
+ self.my_gamma = nn.Parameter(torch.ones(token_embed_dim) * 1e-4)
290
+
291
  self.aggregation = InverseMultiAttentiveFeatureRefinement([320,680,1320,1280])
292
  self.clip_model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")
293
  self.tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
 
311
  latents = latents / 4.7164
312
 
313
  l_feats = self.clip_model(input_ids=input_ids).last_hidden_state
314
+ c_crossattn = self.text_adapter(latents, l_feats, self.my_gamma) # NOTE: here the c_crossattn should be expand_dim as latents
315
  t = torch.ones((img.shape[0],), device=img.device).long()
316
  outs = self.unet(latents, t, c_crossattn=[c_crossattn])
317
 
model.safetensors CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:8cb0158d61ae68cf82b2a478c35521123c8f7e89af87888349c97363392824a0
3
- size 4317953152
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:cc513a54bae634771a5a1ad6f86e8cb390600a01a94ae45b1eafe5d2325d9eb8
3
+ size 4317953160