Tags: Text Generation · Transformers · Safetensors · English · stripedhyena · custom_code
Zymrael committed
Commit 521ac0e · 1 parent: 8711fb6

chore: update gradient checkpointing

Files changed (1)
  1. model.py +14 -4
model.py CHANGED
@@ -311,8 +311,8 @@ class StripedHyena(nn.Module):
         self.embedding_layer = VocabParallelEmbedding(config)
         self.norm = RMSNorm(config) if config.get("final_norm", True) else None
         self.unembed = self.emb if config.tie_embeddings else VocabParallelEmbedding(config)
-        self.scratchpad = None
-
+        self.gradient_checkpointing = False
+
         if config.get("use_flashfft", "False"):
             raise NotImplementedError("Please use standalone SH code for other custom kernels")
         else:
@@ -349,8 +349,18 @@ class StripedHyena(nn.Module):
         if type(padding_mask) == torch.Tensor:
             x = x * padding_mask[..., None]

-        for _, block in enumerate(self.blocks):
-            x, _ = block(x, inference_params=None, padding_mask=padding_mask)
+        for block_idx, block in enumerate(self.blocks):
+            if self.gradient_checkpointing and self.training:
+                def create_custom_forward(module):
+                    def custom_forward(*inputs):
+                        # None for past_key_value
+                        return module(*inputs, inference_params=None, padding_mask=padding_mask)
+
+                    return custom_forward
+
+                x, _ = checkpoint(create_custom_forward(block), x, use_reentrant=False)
+            else:
+                x, _ = block(x, inference_params=None, padding_mask=padding_mask)
         return x, None

     def initialize_inference_params(self):
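
For context on the pattern used in the new branch: torch.utils.checkpoint.checkpoint discards a block's intermediate activations during the forward pass and recomputes them during backward, trading extra compute for lower training memory. The hunks above assume checkpoint is already imported elsewhere in model.py (e.g. from torch.utils.checkpoint). Below is a minimal, self-contained sketch of the same pattern; ToyBlock and ToyStack are illustrative stand-ins and are not code from this repository.

# Minimal sketch of the gradient-checkpointing pattern introduced above.
# ToyBlock / ToyStack are illustrative stand-ins, not part of model.py.
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


class ToyBlock(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.proj = nn.Linear(dim, dim)

    def forward(self, x, padding_mask=None):
        if padding_mask is not None:
            x = x * padding_mask[..., None]
        # Return a 2-tuple to mirror the (x, inference_params) shape of the real blocks.
        return torch.relu(self.proj(x)), None


class ToyStack(nn.Module):
    def __init__(self, dim=64, depth=4):
        super().__init__()
        self.blocks = nn.ModuleList(ToyBlock(dim) for _ in range(depth))
        self.gradient_checkpointing = False  # same flag the commit adds

    def forward(self, x, padding_mask=None):
        for block in self.blocks:
            if self.gradient_checkpointing and self.training:
                # Activations inside the block are recomputed in backward
                # instead of being kept alive for the whole forward pass.
                def create_custom_forward(module):
                    def custom_forward(*inputs):
                        return module(*inputs, padding_mask=padding_mask)
                    return custom_forward

                x, _ = checkpoint(create_custom_forward(block), x, use_reentrant=False)
            else:
                x, _ = block(x, padding_mask=padding_mask)
        return x


if __name__ == "__main__":
    model = ToyStack()
    model.gradient_checkpointing = True  # opt in before training
    model.train()
    x = torch.randn(2, 16, 64, requires_grad=True)
    mask = torch.ones(2, 16)
    loss = model(x, padding_mask=mask).sum()
    loss.backward()  # each block's forward is re-run here to rebuild activations
    print(x.grad.shape)  # torch.Size([2, 16, 64])

A caller opts in the same way the updated model does: set the gradient_checkpointing flag to True before training. During inference, or when the flag is left at its default False, the blocks run normally and nothing is recomputed.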