tolgacangoz committed · Commit 82eed0f · verified · 1 Parent(s): 5e1ad73

Upload matryoshka.py

Files changed (1)
  1. unet/matryoshka.py +27 -68
unet/matryoshka.py CHANGED
@@ -420,6 +420,7 @@ class MatryoshkaDDIMScheduler(SchedulerMixin, ConfigMixin):
         self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy().astype(np.int64))
 
         self.scales = None
+        self.schedule_shifted_power = 1.0
 
     def scale_model_input(self, sample: torch.Tensor, timestep: Optional[int] = None) -> torch.Tensor:
         """
@@ -532,6 +533,7 @@ class MatryoshkaDDIMScheduler(SchedulerMixin, ConfigMixin):
 
     def get_schedule_shifted(self, alpha_prod, scale_factor=None):
         if (scale_factor is not None) and (scale_factor > 1):  # rescale noise schedule
+            scale_factor = scale_factor ** self.schedule_shifted_power
             snr = alpha_prod / (1 - alpha_prod)
             scaled_snr = snr / scale_factor
             alpha_prod = 1 / (1 + 1 / scaled_snr)
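The new `schedule_shifted_power` attribute controls how strongly the noise schedule is shifted for the nested (lower-resolution) latents: the scale factor is raised to this power before the SNR is rescaled. A minimal standalone sketch of the math in this hunk, written as a hypothetical free function (the real code is a scheduler method, and the hunk does not show its return statement):

```python
import torch

def get_schedule_shifted(alpha_prod, scale_factor=None, schedule_shifted_power=1.0):
    # Mirrors MatryoshkaDDIMScheduler.get_schedule_shifted after this commit.
    if (scale_factor is not None) and (scale_factor > 1):  # rescale noise schedule
        scale_factor = scale_factor**schedule_shifted_power
        snr = alpha_prod / (1 - alpha_prod)
        scaled_snr = snr / scale_factor
        alpha_prod = 1 / (1 + 1 / scaled_snr)
    return alpha_prod

alpha_prod = torch.tensor(0.9)
print(get_schedule_shifted(alpha_prod, scale_factor=4, schedule_shifted_power=1.0))  # tensor(0.6923)
print(get_schedule_shifted(alpha_prod, scale_factor=4, schedule_shifted_power=2.0))  # tensor(0.3600)
```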
@@ -640,16 +642,16 @@ class MatryoshkaDDIMScheduler(SchedulerMixin, ConfigMixin):
         if self.config.thresholding:
             if len(model_output) > 1:
                 pred_original_sample = [
-                    self._threshold_sample(p_o_s * scale) / scale
-                    for p_o_s, scale in zip(pred_original_sample, self.scales)
+                    self._threshold_sample(p_o_s)
+                    for p_o_s in pred_original_sample
                 ]
             else:
                 pred_original_sample = self._threshold_sample(pred_original_sample)
         elif self.config.clip_sample:
             if len(model_output) > 1:
                 pred_original_sample = [
-                    (p_o_s * scale).clamp(-self.config.clip_sample_range, self.config.clip_sample_range) / scale
-                    for p_o_s, scale in zip(pred_original_sample, self.scales)
+                    p_o_s.clamp(-self.config.clip_sample_range, self.config.clip_sample_range)
+                    for p_o_s in pred_original_sample
                 ]
             else:
                 pred_original_sample = pred_original_sample.clamp(
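With the per-level `self.scales` factored out, thresholding and clipping now act directly on each predicted sample. This is not a pure refactor: `(x * scale).clamp(-r, r) / scale` effectively clips at ±r/scale, whereas `x.clamp(-r, r)` clips at ±r. A toy comparison with arbitrary stand-in values:

```python
import torch

x = torch.tensor([-3.0, -0.8, 0.2, 2.5])  # stand-in for one pred_original_sample
scale, r = 2.0, 1.0                        # stand-ins for a self.scales entry and clip_sample_range

old = (x * scale).clamp(-r, r) / scale     # previous behaviour: effective range is ±r/scale
new = x.clamp(-r, r)                       # behaviour after this commit: range is ±r

print(old)  # tensor([-0.5000, -0.5000,  0.2000,  0.5000])
print(new)  # tensor([-1.0000, -0.8000,  0.2000,  1.0000])
```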
@@ -1440,7 +1442,7 @@ class MatryoshkaTransformerBlock(nn.Module):
             bias=True,
             upcast_attention=upcast_attention,
             pre_only=True,
-            processor=MatryoshkaFusedAttnProcessor1_0_or_2_0(),
+            processor=MatryoshkaFusedAttnProcessor2_0(),
         )
         self.attn1.fuse_projections()
         del self.attn1.to_q
@@ -1458,7 +1460,7 @@ class MatryoshkaTransformerBlock(nn.Module):
             bias=True,
             upcast_attention=upcast_attention,
             pre_only=True,
-            processor=MatryoshkaFusedAttnProcessor1_0_or_2_0(),
+            processor=MatryoshkaFusedAttnProcessor2_0(),
         )
         self.attn2.fuse_projections()
         del self.attn2.to_q
@@ -1517,7 +1519,6 @@ class MatryoshkaTransformerBlock(nn.Module):
                 # **cross_attention_kwargs,
             )
 
-            # attn_output_cond = attn_output_cond.permute(0, 2, 1).contiguous()
             attn_output_cond = self.proj_out(attn_output_cond)
             attn_output_cond = attn_output_cond.permute(0, 2, 1).reshape(batch_size, channels, *spatial_dims)
             hidden_states = hidden_states + attn_output_cond
@@ -1535,7 +1536,7 @@ class MatryoshkaTransformerBlock(nn.Module):
         return hidden_states
 
 
-class MatryoshkaFusedAttnProcessor1_0_or_2_0:
+class MatryoshkaFusedAttnProcessor2_0:
     r"""
     Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). It uses
     fused projection layers. For self-attention modules, all projection matrices (i.e., query, key, value) are fused.
@@ -1548,28 +1549,11 @@ class MatryoshkaFusedAttnProcessor1_0_or_2_0:
     </Tip>
     """
 
-    # def __init__(self):
-    #     if not hasattr(F, "scaled_dot_product_attention"):
-    #         raise ImportError(
-    #             "MatryoshkaFusedAttnProcessor2_0 requires PyTorch 2.x, to use it. Please upgrade PyTorch to > 2.x."
-    #         )
-
-    # TODO: They seem to give different results; but nevertheless can I replace this with torch.nn.functional.scaled_dot_product_attention()?
-    def attention(self, q, k, v, num_heads, mask=None):
-        bs, width, length = q.shape
-        ch = width // num_heads
-        scale = 1 / torch.sqrt(torch.sqrt(torch.tensor(ch)))
-        weight = torch.einsum(
-            "bct,bcs->bts",
-            (q * scale).reshape(bs * num_heads, ch, length),
-            (k * scale).reshape(bs * num_heads, ch, -1),
-        )  # More stable with f16 than dividing afterwards
-        if mask is not None:
-            mask = mask.view(mask.size(0), 1, 1, mask.size(-1)).repeat(1, num_heads, 1, 1).flatten(0, 1)
-            weight = weight.masked_fill(mask == 0, float("-inf"))
-        weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
-        a = torch.einsum("bts,bcs->bct", weight, v.reshape(bs * num_heads, ch, -1))
-        return a.reshape(bs, -1, length)
+    def __init__(self):
+        if not hasattr(F, "scaled_dot_product_attention"):
+            raise ImportError(
+                "MatryoshkaFusedAttnProcessor2_0 requires PyTorch 2.x, to use it. Please upgrade PyTorch to > 2.x."
+            )
 
     def __call__(
         self,
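The deleted `attention` helper is the hand-rolled einsum path that the removed TODO asked about; the processor now relies solely on `F.scaled_dot_product_attention`, and the previously commented-out PyTorch 2.x check becomes a real `__init__` guard. A standalone sanity check (not part of the file) suggesting the two paths agree when no mask is used; note the old helper expects a channel-first (batch, heads * head_dim, seq) layout, so the inputs have to be transposed for SDPA:

```python
import torch
import torch.nn.functional as F

def legacy_attention(q, k, v, num_heads):
    # Copy of the einsum attention removed in this commit, minus the mask branch.
    bs, width, length = q.shape
    ch = width // num_heads
    scale = 1 / torch.sqrt(torch.sqrt(torch.tensor(ch, dtype=q.dtype)))
    weight = torch.einsum(
        "bct,bcs->bts",
        (q * scale).reshape(bs * num_heads, ch, length),
        (k * scale).reshape(bs * num_heads, ch, -1),
    )
    weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
    a = torch.einsum("bts,bcs->bct", weight, v.reshape(bs * num_heads, ch, -1))
    return a.reshape(bs, -1, length)

bs, heads, ch, length = 2, 4, 8, 16
q, k, v = (torch.randn(bs, heads * ch, length) for _ in range(3))

# Same computation via SDPA: reshape to (batch * heads, seq, head_dim) first.
q_t, k_t, v_t = (t.reshape(bs * heads, ch, length).transpose(1, 2) for t in (q, k, v))
sdpa = F.scaled_dot_product_attention(q_t, k_t, v_t)  # default scale = 1 / sqrt(head_dim)
sdpa = sdpa.transpose(1, 2).reshape(bs, heads * ch, length)

print(torch.allclose(legacy_attention(q, k, v, heads), sdpa, atol=1e-5))  # True, up to float error
```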
@@ -1593,26 +1577,12 @@ class MatryoshkaFusedAttnProcessor1_0_or_2_0:
 
         input_ndim = hidden_states.ndim
 
-        if input_ndim == 4:
-            batch_size, channel, height, width = hidden_states.shape
-            # hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
-
-            # batch_size, sequence_length, _ = (
-            #     hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
-            # )
-
-            # if attention_mask is not None:
-            #     attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
-            #     # scaled_dot_product_attention expects attention_mask shape to be
-            #     # (batch, heads, source_length, target_length)
-            #     attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
-
         if attn.group_norm is not None:
-            hidden_states = attn.group_norm(hidden_states)  # .transpose(1, 2)).transpose(1, 2)
+            hidden_states = attn.group_norm(hidden_states)
 
-        # Reshape hidden_states to 2D tensor
-        hidden_states = hidden_states.view(batch_size, channel, height * width).permute(0, 2, 1).contiguous()
-        # Now hidden_states.shape is [batch_size, height * width, channels]
+        if input_ndim == 4:
+            batch_size, channel, height, width = hidden_states.shape
+            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2).contiguous()
 
         if encoder_hidden_states is None:
             qkv = attn.to_qkv(hidden_states)
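After this reordering, `group_norm` is applied to the 4D feature map first, and the flatten to a token sequence (now `transpose(1, 2)` instead of `permute(0, 2, 1)`) happens right before the fused QKV projection. A toy shape walk-through with made-up dimensions:

```python
import torch
from torch import nn

B, C, H, W = 2, 8, 4, 4                                   # arbitrary example sizes
hidden_states = torch.randn(B, C, H, W)
group_norm = nn.GroupNorm(num_groups=4, num_channels=C)   # stand-in for attn.group_norm

hidden_states = group_norm(hidden_states)                                      # (B, C, H, W)
hidden_states = hidden_states.view(B, C, H * W).transpose(1, 2).contiguous()   # (B, H*W, C)
print(hidden_states.shape)  # torch.Size([2, 16, 8])
```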
@@ -1630,11 +1600,6 @@ class MatryoshkaFusedAttnProcessor1_0_or_2_0:
             split_size = kv.shape[-1] // 2
             key, value = torch.split(kv, split_size, dim=-1)
 
-        # if self_attention_output is None:
-        #     query = query.permute(0, 2, 1)
-        #     key = key.permute(0, 2, 1)
-        #     value = value.permute(0, 2, 1)
-
         if attn.norm_q is not None:
             query = attn.norm_q(query)
         if attn.norm_k is not None:
@@ -1659,16 +1624,6 @@ class MatryoshkaFusedAttnProcessor1_0_or_2_0:
             query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
         )
 
-        # the output of sdp = (batch, num_heads, seq_len, head_dim)
-        # TODO: add support for attn.scale when we move to Torch 2.1 if F.scaled_dot_product_attention() is available
-        # hidden_states = self.attention(
-        #     query,
-        #     key,
-        #     value,
-        #     mask=attention_mask,
-        #     num_heads=attn.heads,
-        # )
-
         hidden_states = hidden_states.to(query.dtype)
 
         if self_attention_output is not None:
@@ -1956,7 +1911,7 @@ class MatryoshkaCombinedTimestepTextEmbedding(nn.Module):
             # if self.cond_emb is not None and not added_cond_kwargs.get("from_nested", False):
             return temb_micro_conditioning, conditioning_mask, cond_emb
 
-        return cond_emb, conditioning_mask, cond_emb
+        return None, conditioning_mask, cond_emb
 
 
 @dataclass
@@ -3184,7 +3139,7 @@ class MatryoshkaUNet2DConditionModel(
                     encoder_hidden_states=encoder_hidden_states,
                     attention_mask=attention_mask,
                     cross_attention_kwargs=cross_attention_kwargs,
-                    encoder_attention_mask=encoder_attention_mask,  # cond_mask?
+                    encoder_attention_mask=encoder_attention_mask,
                     **additional_residuals,
                 )
             else:
@@ -3214,7 +3169,7 @@ class MatryoshkaUNet2DConditionModel(
                     encoder_hidden_states=encoder_hidden_states,
                     attention_mask=attention_mask,
                     cross_attention_kwargs=cross_attention_kwargs,
-                    encoder_attention_mask=encoder_attention_mask,  # cond_mask?
+                    encoder_attention_mask=encoder_attention_mask,
                 )
             else:
                 sample = self.mid_block(sample, emb)
@@ -3251,7 +3206,7 @@ class MatryoshkaUNet2DConditionModel(
                     cross_attention_kwargs=cross_attention_kwargs,
                     upsample_size=upsample_size,
                     attention_mask=attention_mask,
-                    encoder_attention_mask=encoder_attention_mask,  # cond_mask?
+                    encoder_attention_mask=encoder_attention_mask,
                 )
             else:
                 sample = upsample_block(
@@ -3699,7 +3654,7 @@ class NestedUNet2DConditionModel(MatryoshkaUNet2DConditionModel):
                     cross_attention_kwargs=cross_attention_kwargs,
                     upsample_size=upsample_size,
                     attention_mask=attention_mask,
-                    encoder_attention_mask=cond_mask[:bh] if cond_mask is not None else cond_mask,  # cond_mask?
+                    encoder_attention_mask=cond_mask[:bh] if cond_mask is not None else cond_mask,
                 )
             else:
                 sample = upsample_block(
@@ -3863,6 +3818,8 @@ class MatryoshkaPipeline(
 
         if hasattr(unet, "nest_ratio"):
             scheduler.scales = unet.nest_ratio + [1]
+            if nesting_level == 2:
+                scheduler.schedule_shifted_power = 2.0
 
         self.register_modules(
             text_encoder=text_encoder,
@@ -3889,12 +3846,14 @@ class MatryoshkaPipeline(
             ).to(self.device)
             self.config.nesting_level = 1
             self.scheduler.scales = self.unet.nest_ratio + [1]
+            self.scheduler.schedule_shifted_power = 1.0
         elif nesting_level == 2:
             self.unet = NestedUNet2DConditionModel.from_pretrained(
                 "tolgacangoz/matryoshka-diffusion-models", subfolder="unet/nesting_level_2"
            ).to(self.device)
             self.config.nesting_level = 2
             self.scheduler.scales = self.unet.nest_ratio + [1]
+            self.scheduler.schedule_shifted_power = 2.0
         else:
             raise ValueError("Currently, nesting levels 0, 1, and 2 are supported.")
 
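Together with the constructor change above, the scheduler's shift power now tracks the nesting level: 2.0 when nesting level 2 is active, 1.0 otherwise. A hypothetical helper (not in the file) that condenses the rule these two hunks implement:

```python
def schedule_shifted_power_for(nesting_level: int) -> float:
    # Mirrors the MatryoshkaPipeline wiring: only nesting level 2 squares the schedule shift.
    return 2.0 if nesting_level == 2 else 1.0

assert schedule_shifted_power_for(1) == 1.0
assert schedule_shifted_power_for(2) == 2.0
```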