michael-guenther commited on
Commit
1c61b96
·
1 Parent(s): 95b4916

support activation checkpointing

Browse files
Files changed (1) hide show
  1. modeling_xlm_roberta.py +46 -5
modeling_xlm_roberta.py CHANGED
@@ -17,6 +17,7 @@ from functools import partial
17
  import torch
18
  import torch.nn as nn
19
  import torch.nn.functional as F
 
20
  from einops import rearrange
21
  from transformers import PretrainedConfig
22
  from transformers.modeling_utils import PreTrainedModel
@@ -42,7 +43,6 @@ from .embedding import XLMRobertaEmbeddings
42
  from .mha import MHA
43
  from .mlp import FusedMLP, Mlp
44
 
45
- # from flash_attn.utils.pretrained import state_dict_from_pretrained
46
 
47
  try:
48
  from flash_attn.ops.fused_dense import FusedDense
@@ -166,6 +166,15 @@ class XLMRobertaEncoder(nn.Module):
166
  self.layers = nn.ModuleList(
167
  [create_block(config, layer_idx=i) for i in range(config.num_hidden_layers)]
168
  )
 
 
 
 
 
 
 
 
 
169
 
170
  def forward(self, hidden_states, key_padding_mask=None, subset_mask=None):
171
  """If subset_mask is not None, we only want output for the subset of the sequence.
@@ -177,7 +186,15 @@ class XLMRobertaEncoder(nn.Module):
177
  {"key_padding_mask": key_padding_mask.bool()} if key_padding_mask is not None else None
178
  )
179
  for layer in self.layers:
180
- hidden_states = layer(hidden_states, mixer_kwargs=mixer_kwargs)
 
 
 
 
 
 
 
 
181
  if subset_mask is not None:
182
  hidden_states = hidden_states[subset_mask]
183
  else:
@@ -188,11 +205,27 @@ class XLMRobertaEncoder(nn.Module):
188
  mixer_kwargs = {"cu_seqlens": cu_seqlens, "max_seqlen": max_seqlen_in_batch}
189
  if subset_mask is None:
190
  for layer in self.layers:
191
- hidden_states = layer(hidden_states, mixer_kwargs=mixer_kwargs)
 
 
 
 
 
 
 
 
192
  hidden_states = pad_input(hidden_states, indices, batch, seqlen)
193
  else:
194
  for layer in self.layers[:-1]:
195
- hidden_states = layer(hidden_states, mixer_kwargs=mixer_kwargs)
 
 
 
 
 
 
 
 
196
  if key_padding_mask is not None:
197
  subset_idx = torch.nonzero(
198
  subset_mask[key_padding_mask], as_tuple=False
@@ -218,7 +251,15 @@ class XLMRobertaEncoder(nn.Module):
218
  "cu_seqlens_k": cu_seqlens,
219
  "max_seqlen_k": max_seqlen_in_batch,
220
  }
221
- hidden_states = self.layers[-1](hidden_states_subset, mixer_kwargs=mixer_kwargs)
 
 
 
 
 
 
 
 
222
  return hidden_states
223
 
224
 
 
17
  import torch
18
  import torch.nn as nn
19
  import torch.nn.functional as F
20
+ import torch.utils.checkpoint
21
  from einops import rearrange
22
  from transformers import PretrainedConfig
23
  from transformers.modeling_utils import PreTrainedModel
 
43
  from .mha import MHA
44
  from .mlp import FusedMLP, Mlp
45
 
 
46
 
47
  try:
48
  from flash_attn.ops.fused_dense import FusedDense
 
166
  self.layers = nn.ModuleList(
167
  [create_block(config, layer_idx=i) for i in range(config.num_hidden_layers)]
168
  )
169
+ self._grad_checkpointing = False
170
+
171
+ @property
172
+ def gradient_checkpointing(self):
173
+ return self._grad_checkpointing
174
+
175
+ @gradient_checkpointing.setter
176
+ def gradient_checkpointing(self, value):
177
+ self._grad_checkpointing = value
178
 
179
  def forward(self, hidden_states, key_padding_mask=None, subset_mask=None):
180
  """If subset_mask is not None, we only want output for the subset of the sequence.
 
186
  {"key_padding_mask": key_padding_mask.bool()} if key_padding_mask is not None else None
187
  )
188
  for layer in self.layers:
189
+ if self._grad_checkpointing:
190
+ hidden_states = torch.utils.checkpoint.checkpoint(
191
+ layer,
192
+ hidden_states,
193
+ use_reentrant=False,
194
+ mixer_kwargs=mixer_kwargs
195
+ )
196
+ else:
197
+ hidden_states = layer(hidden_states, mixer_kwargs=mixer_kwargs)
198
  if subset_mask is not None:
199
  hidden_states = hidden_states[subset_mask]
200
  else:
 
205
  mixer_kwargs = {"cu_seqlens": cu_seqlens, "max_seqlen": max_seqlen_in_batch}
206
  if subset_mask is None:
207
  for layer in self.layers:
208
+ if self._grad_checkpointing:
209
+ hidden_states = torch.utils.checkpoint.checkpoint(
210
+ layer,
211
+ hidden_states,
212
+ use_reentrant=False,
213
+ mixer_kwargs=mixer_kwargs
214
+ )
215
+ else:
216
+ hidden_states = layer(hidden_states, mixer_kwargs=mixer_kwargs)
217
  hidden_states = pad_input(hidden_states, indices, batch, seqlen)
218
  else:
219
  for layer in self.layers[:-1]:
220
+ if self._grad_checkpointing:
221
+ hidden_states = torch.utils.checkpoint.checkpoint(
222
+ layer,
223
+ hidden_states,
224
+ use_reentrant=False,
225
+ mixer_kwargs=mixer_kwargs
226
+ )
227
+ else:
228
+ hidden_states = layer(hidden_states, mixer_kwargs=mixer_kwargs)
229
  if key_padding_mask is not None:
230
  subset_idx = torch.nonzero(
231
  subset_mask[key_padding_mask], as_tuple=False
 
251
  "cu_seqlens_k": cu_seqlens,
252
  "max_seqlen_k": max_seqlen_in_batch,
253
  }
254
+ if self._grad_checkpointing:
255
+ torch.utils.checkpoint.checkpoint(
256
+ self.layers[-1],
257
+ hidden_states_subset,
258
+ use_reentrant=False,
259
+ mixer_kwargs=mixer_kwargs
260
+ )
261
+ else:
262
+ hidden_states = self.layers[-1](hidden_states_subset, mixer_kwargs=mixer_kwargs)
263
  return hidden_states
264
 
265