chore: update gradient checkpointing
model.py
CHANGED
@@ -311,8 +311,8 @@ class StripedHyena(nn.Module):
         self.embedding_layer = VocabParallelEmbedding(config)
         self.norm = RMSNorm(config) if config.get("final_norm", True) else None
         self.unembed = self.emb if config.tie_embeddings else VocabParallelEmbedding(config)
-        self.
-
+        self.gradient_checkpointing = False
+
         if config.get("use_flashfft", "False"):
             raise NotImplementedError("Please use standalone SH code for other custom kernels")
         else:

@@ -349,8 +349,18 @@ class StripedHyena(nn.Module):
         if type(padding_mask) == torch.Tensor:
             x = x * padding_mask[..., None]
 
-        for
-
+        for block_idx, block in enumerate(self.blocks):
+            if self.gradient_checkpointing and self.training:
+                def create_custom_forward(module):
+                    def custom_forward(*inputs):
+                        # None for past_key_value
+                        return module(*inputs, inference_params=None, padding_mask=padding_mask)
+
+                    return custom_forward
+
+                x, _ = checkpoint(create_custom_forward(block), x, use_reentrant=False)
+            else:
+                x, _ = block(x, inference_params=None, padding_mask=padding_mask)
         return x, None
 
     def initialize_inference_params(self):
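The new branch calls checkpoint(...) directly, so it assumes model.py already has that name in scope (for example via from torch.utils.checkpoint import checkpoint). The sketch below shows, hypothetically, how training code might use the flag added in this commit; the config object, the import path, the batch shapes, and the forward-call signature are illustrative assumptions and are not part of the diff.

# Hypothetical training-side sketch (not part of this commit).
# Assumes `config` is a valid StripedHyena config and that the model's forward
# accepts token ids and returns a (hidden_states, state) pair, as suggested by
# the `return x, None` in the hunk above.
import torch
from model import StripedHyena

model = StripedHyena(config)
model.gradient_checkpointing = True   # flag introduced in this commit; defaults to False
model.train()                         # checkpointing only activates when self.training is True

tokens = torch.randint(0, 512, (2, 1024))   # illustrative batch of token ids
out, _ = model(tokens)                      # block activations are recomputed during backward
loss = out.float().pow(2).mean()            # instead of being stored, trading extra compute
loss.backward()                             # for lower peak memory

Passing use_reentrant=False opts into PyTorch's non-reentrant checkpointing, which works with keyword arguments such as padding_mask and avoids the known limitations of the re-entrant implementation.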