michael-guenther
committed on
Commit
·
1c61b96
1
Parent(s):
95b4916
support activation checkpointing
Browse files- modeling_xlm_roberta.py +46 -5
modeling_xlm_roberta.py
CHANGED
@@ -17,6 +17,7 @@ from functools import partial
|
|
17 |
import torch
|
18 |
import torch.nn as nn
|
19 |
import torch.nn.functional as F
|
|
|
20 |
from einops import rearrange
|
21 |
from transformers import PretrainedConfig
|
22 |
from transformers.modeling_utils import PreTrainedModel
|
@@ -42,7 +43,6 @@ from .embedding import XLMRobertaEmbeddings
|
|
42 |
from .mha import MHA
|
43 |
from .mlp import FusedMLP, Mlp
|
44 |
|
45 |
-
# from flash_attn.utils.pretrained import state_dict_from_pretrained
|
46 |
|
47 |
try:
|
48 |
from flash_attn.ops.fused_dense import FusedDense
|
@@ -166,6 +166,15 @@ class XLMRobertaEncoder(nn.Module):
|
|
166 |
self.layers = nn.ModuleList(
|
167 |
[create_block(config, layer_idx=i) for i in range(config.num_hidden_layers)]
|
168 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
169 |
|
170 |
def forward(self, hidden_states, key_padding_mask=None, subset_mask=None):
|
171 |
"""If subset_mask is not None, we only want output for the subset of the sequence.
|
@@ -177,7 +186,15 @@ class XLMRobertaEncoder(nn.Module):
|
|
177 |
{"key_padding_mask": key_padding_mask.bool()} if key_padding_mask is not None else None
|
178 |
)
|
179 |
for layer in self.layers:
|
180 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
181 |
if subset_mask is not None:
|
182 |
hidden_states = hidden_states[subset_mask]
|
183 |
else:
|
@@ -188,11 +205,27 @@ class XLMRobertaEncoder(nn.Module):
|
|
188 |
mixer_kwargs = {"cu_seqlens": cu_seqlens, "max_seqlen": max_seqlen_in_batch}
|
189 |
if subset_mask is None:
|
190 |
for layer in self.layers:
|
191 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
192 |
hidden_states = pad_input(hidden_states, indices, batch, seqlen)
|
193 |
else:
|
194 |
for layer in self.layers[:-1]:
|
195 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
196 |
if key_padding_mask is not None:
|
197 |
subset_idx = torch.nonzero(
|
198 |
subset_mask[key_padding_mask], as_tuple=False
|
@@ -218,7 +251,15 @@ class XLMRobertaEncoder(nn.Module):
|
|
218 |
"cu_seqlens_k": cu_seqlens,
|
219 |
"max_seqlen_k": max_seqlen_in_batch,
|
220 |
}
|
221 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
222 |
return hidden_states
|
223 |
|
224 |
|
|
|
17 |
import torch
|
18 |
import torch.nn as nn
|
19 |
import torch.nn.functional as F
|
20 |
+
import torch.utils.checkpoint
|
21 |
from einops import rearrange
|
22 |
from transformers import PretrainedConfig
|
23 |
from transformers.modeling_utils import PreTrainedModel
|
|
|
43 |
from .mha import MHA
|
44 |
from .mlp import FusedMLP, Mlp
|
45 |
|
|
|
46 |
|
47 |
try:
|
48 |
from flash_attn.ops.fused_dense import FusedDense
|
|
|
166 |
self.layers = nn.ModuleList(
|
167 |
[create_block(config, layer_idx=i) for i in range(config.num_hidden_layers)]
|
168 |
)
|
169 |
+
self._grad_checkpointing = False
|
170 |
+
|
171 |
+
@property
|
172 |
+
def gradient_checkpointing(self):
|
173 |
+
return self._grad_checkpointing
|
174 |
+
|
175 |
+
@gradient_checkpointing.setter
|
176 |
+
def gradient_checkpointing(self, value):
|
177 |
+
self._grad_checkpointing = value
|
178 |
|
179 |
def forward(self, hidden_states, key_padding_mask=None, subset_mask=None):
|
180 |
"""If subset_mask is not None, we only want output for the subset of the sequence.
|
|
|
186 |
{"key_padding_mask": key_padding_mask.bool()} if key_padding_mask is not None else None
|
187 |
)
|
188 |
for layer in self.layers:
|
189 |
+
if self._grad_checkpointing:
|
190 |
+
hidden_states = torch.utils.checkpoint.checkpoint(
|
191 |
+
layer,
|
192 |
+
hidden_states,
|
193 |
+
use_reentrant=False,
|
194 |
+
mixer_kwargs=mixer_kwargs
|
195 |
+
)
|
196 |
+
else:
|
197 |
+
hidden_states = layer(hidden_states, mixer_kwargs=mixer_kwargs)
|
198 |
if subset_mask is not None:
|
199 |
hidden_states = hidden_states[subset_mask]
|
200 |
else:
|
|
|
205 |
mixer_kwargs = {"cu_seqlens": cu_seqlens, "max_seqlen": max_seqlen_in_batch}
|
206 |
if subset_mask is None:
|
207 |
for layer in self.layers:
|
208 |
+
if self._grad_checkpointing:
|
209 |
+
hidden_states = torch.utils.checkpoint.checkpoint(
|
210 |
+
layer,
|
211 |
+
hidden_states,
|
212 |
+
use_reentrant=False,
|
213 |
+
mixer_kwargs=mixer_kwargs
|
214 |
+
)
|
215 |
+
else:
|
216 |
+
hidden_states = layer(hidden_states, mixer_kwargs=mixer_kwargs)
|
217 |
hidden_states = pad_input(hidden_states, indices, batch, seqlen)
|
218 |
else:
|
219 |
for layer in self.layers[:-1]:
|
220 |
+
if self._grad_checkpointing:
|
221 |
+
hidden_states = torch.utils.checkpoint.checkpoint(
|
222 |
+
layer,
|
223 |
+
hidden_states,
|
224 |
+
use_reentrant=False,
|
225 |
+
mixer_kwargs=mixer_kwargs
|
226 |
+
)
|
227 |
+
else:
|
228 |
+
hidden_states = layer(hidden_states, mixer_kwargs=mixer_kwargs)
|
229 |
if key_padding_mask is not None:
|
230 |
subset_idx = torch.nonzero(
|
231 |
subset_mask[key_padding_mask], as_tuple=False
|
|
|
251 |
"cu_seqlens_k": cu_seqlens,
|
252 |
"max_seqlen_k": max_seqlen_in_batch,
|
253 |
}
|
254 |
+
# Run the final layer on the selected subset only. The result must be
# assigned back to `hidden_states` in BOTH branches: previously the
# checkpointed call discarded its return value, so with gradient
# checkpointing enabled the function returned the activations from
# before the last layer instead of the last layer's subset output.
if self._grad_checkpointing:
    hidden_states = torch.utils.checkpoint.checkpoint(
        self.layers[-1],
        hidden_states_subset,
        use_reentrant=False,
        mixer_kwargs=mixer_kwargs,
    )
else:
    hidden_states = self.layers[-1](hidden_states_subset, mixer_kwargs=mixer_kwargs)
|
263 |
return hidden_states
|
264 |
|
265 |
|