update modeling
Browse files- modeling_gptpangu.py +1 -1
modeling_gptpangu.py
CHANGED
@@ -460,7 +460,7 @@ class GPTPanguForCausalLM(GPTPanguPreTrainedModel):
|
|
460 |
|
461 |
if attention_mask is not None and position_ids is None:
|
462 |
# create position_ids on the fly for batch generation
|
463 |
-
position_ids = attention_mask.
|
464 |
position_ids.masked_fill_(attention_mask == 0, 1)
|
465 |
if past:
|
466 |
position_ids = position_ids[:, -1].unsqueeze(-1)
|
|
|
460 |
|
461 |
if attention_mask is not None and position_ids is None:
|
462 |
# create position_ids on the fly for batch generation
|
463 |
+
position_ids = attention_mask.int().cumsum(-1).long() - 1
|
464 |
position_ids.masked_fill_(attention_mask == 0, 1)
|
465 |
if past:
|
466 |
position_ids = position_ids[:, -1].unsqueeze(-1)
|