Proposed patch for the Tensor size mismatch error

#15
by claverAI - opened
Files changed (1) hide show
  1. modeling_phi.py +1 -0
modeling_phi.py CHANGED
@@ -315,6 +315,7 @@ class CrossAttention(nn.Module):
   dtype=scores.dtype,
   device=scores.device,
   )
+  key_padding_mask = key_padding_mask[:, :seqlen_k]
   padding_mask.masked_fill_(key_padding_mask, 0.0)
   scores = scores + rearrange(padding_mask, "b s -> b 1 1 s")
 