Jackmin108
commited on
Commit
·
c1736a8
1
Parent(s):
362ef00
feat: adapter masking wip
Browse filesSigned-off-by: Meow <[email protected]>
- embedding.py +21 -4
- modeling_lora.py +10 -2
- modeling_xlm_roberta.py +7 -6
- xlm_padding.py +9 -1
embedding.py
CHANGED
@@ -40,7 +40,7 @@ class XLMRobertaEmbeddings(nn.Module):
|
|
40 |
if self.type_vocab_size > 0:
|
41 |
self.token_type_embeddings = nn.Embedding(type_vocab_size, embed_dim, **factory_kwargs)
|
42 |
|
43 |
-
def forward(self, input_ids, position_ids=None, token_type_ids=None, task_type=None):
|
44 |
"""
|
45 |
input_ids: (batch, seqlen)
|
46 |
position_ids: (batch, seqlen)
|
@@ -55,9 +55,25 @@ class XLMRobertaEmbeddings(nn.Module):
|
|
55 |
emb1 = self.word_embeddings(tensor1, task_type=task_type[0])
|
56 |
emb2 = self.word_embeddings(tensor2, task_type=task_type[1])
|
57 |
embeddings = torch.cat((emb1, emb2), dim=0)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
58 |
else:
|
59 |
-
|
60 |
-
|
|
|
|
|
|
|
|
|
61 |
|
62 |
if self.max_position_embeddings > 0:
|
63 |
if position_ids is None:
|
@@ -79,7 +95,8 @@ class XLMRobertaEmbeddings(nn.Module):
|
|
79 |
emb2 = emb2 + token_type_embs2
|
80 |
embeddings = torch.cat((emb1, emb2), dim=0)
|
81 |
else:
|
82 |
-
|
|
|
83 |
token_type_embeddings = self.token_type_embeddings(token_type_ids, **lora_kwargs)
|
84 |
embeddings = embeddings + token_type_embeddings
|
85 |
return embeddings
|
|
|
40 |
if self.type_vocab_size > 0:
|
41 |
self.token_type_embeddings = nn.Embedding(type_vocab_size, embed_dim, **factory_kwargs)
|
42 |
|
43 |
+
def forward(self, input_ids, position_ids=None, token_type_ids=None, task_type=None, adapter_mask=None):
|
44 |
"""
|
45 |
input_ids: (batch, seqlen)
|
46 |
position_ids: (batch, seqlen)
|
|
|
55 |
emb1 = self.word_embeddings(tensor1, task_type=task_type[0])
|
56 |
emb2 = self.word_embeddings(tensor2, task_type=task_type[1])
|
57 |
embeddings = torch.cat((emb1, emb2), dim=0)
|
58 |
+
|
59 |
+
unique_tasks = torch.unique(adapter_mask).tolist()
|
60 |
+
torch_dtype = next(self.word_embeddings.parameters()).dtype
|
61 |
+
embeddings = torch.empty(*input_ids.shape, self.word_embeddings.embedding_dim, dtype=torch_dtype).to(input_ids.device)
|
62 |
+
for task in unique_tasks:
|
63 |
+
indices = (adapter_mask == task).nonzero(as_tuple=True)[0]
|
64 |
+
inp = input_ids[indices]
|
65 |
+
lora_kwargs = {'task_type': task} if task is not None else {}
|
66 |
+
emb = self.word_embeddings(inp, **lora_kwargs)
|
67 |
+
embeddings[indices] = emb
|
68 |
+
|
69 |
+
exit(0)
|
70 |
else:
|
71 |
+
unique_task = torch.unique(adapter_mask)[0]
|
72 |
+
task1_indices = (adapter_mask == unique_task).nonzero(as_tuple=True)[0]
|
73 |
+
input1 = input_ids[task1_indices]
|
74 |
+
lora_kwargs = {'task_type': unique_task} if unique_task is not None else {}
|
75 |
+
embeddings = self.word_embeddings(input1, **lora_kwargs)
|
76 |
+
|
77 |
|
78 |
if self.max_position_embeddings > 0:
|
79 |
if position_ids is None:
|
|
|
95 |
emb2 = emb2 + token_type_embs2
|
96 |
embeddings = torch.cat((emb1, emb2), dim=0)
|
97 |
else:
|
98 |
+
unique_task = torch.unique(adapter_mask)[0]
|
99 |
+
lora_kwargs = {'task_type': unique_task} if unique_task is not None else {}
|
100 |
token_type_embeddings = self.token_type_embeddings(token_type_ids, **lora_kwargs)
|
101 |
embeddings = embeddings + token_type_embeddings
|
102 |
return embeddings
|
modeling_lora.py
CHANGED
@@ -177,7 +177,11 @@ class LoRAParametrization(nn.Module):
|
|
177 |
)
|
178 |
|
179 |
def new_forward(self, input, task_type, residual=False):
|
180 |
-
|
|
|
|
|
|
|
|
|
181 |
if task_idx is not None:
|
182 |
weights = self.parametrizations.weight[0].lora_forward(self.weight, current_task=task_idx)
|
183 |
else:
|
@@ -205,7 +209,11 @@ class LoRAParametrization(nn.Module):
|
|
205 |
)
|
206 |
|
207 |
def new_forward(self, input, task_type):
|
208 |
-
|
|
|
|
|
|
|
|
|
209 |
if task_idx is not None:
|
210 |
weights = self.parametrizations.weight[0].lora_forward(self.weight, current_task=task_idx)
|
211 |
else:
|
|
|
177 |
)
|
178 |
|
179 |
def new_forward(self, input, task_type, residual=False):
|
180 |
+
if isinstance(task_type, str):
|
181 |
+
task_idx = adaptation_map[task_type] if task_type else None
|
182 |
+
else:
|
183 |
+
task_idx = task_type
|
184 |
+
|
185 |
if task_idx is not None:
|
186 |
weights = self.parametrizations.weight[0].lora_forward(self.weight, current_task=task_idx)
|
187 |
else:
|
|
|
209 |
)
|
210 |
|
211 |
def new_forward(self, input, task_type):
|
212 |
+
if isinstance(task_type, str):
|
213 |
+
task_idx = adaptation_map[task_type] if task_type else None
|
214 |
+
else:
|
215 |
+
task_idx = task_type
|
216 |
+
|
217 |
if task_idx is not None:
|
218 |
weights = self.parametrizations.weight[0].lora_forward(self.weight, current_task=task_idx)
|
219 |
else:
|
modeling_xlm_roberta.py
CHANGED
@@ -204,7 +204,7 @@ class XLMRobertaEncoder(nn.Module):
|
|
204 |
def gradient_checkpointing(self, value):
|
205 |
self._grad_checkpointing = value
|
206 |
|
207 |
-
def forward(self, hidden_states, key_padding_mask=None, subset_mask=None, task_type=None):
|
208 |
"""If subset_mask is not None, we only want output for the subset of the sequence.
|
209 |
This means that we only compute the last layer output for these tokens.
|
210 |
subset_mask: (batch, seqlen), dtype=torch.bool
|
@@ -230,10 +230,10 @@ class XLMRobertaEncoder(nn.Module):
|
|
230 |
hidden_states = hidden_states[subset_mask]
|
231 |
else:
|
232 |
batch, seqlen = hidden_states.shape[:2]
|
233 |
-
hidden_states, indices, cu_seqlens, max_seqlen_in_batch = unpad_input(
|
234 |
-
hidden_states, key_padding_mask
|
235 |
)
|
236 |
-
mixer_kwargs = {"cu_seqlens": cu_seqlens, "max_seqlen": max_seqlen_in_batch, "task_type": task_type}
|
237 |
if subset_mask is None:
|
238 |
for layer in self.layers:
|
239 |
if self._grad_checkpointing:
|
@@ -649,6 +649,7 @@ class XLMRobertaModel(XLMRobertaPreTrainedModel):
|
|
649 |
masked_tokens_mask: (batch, seqlen), dtype=torch.bool
|
650 |
"""
|
651 |
task_type = kwargs.pop('task_type', None)
|
|
|
652 |
if kwargs:
|
653 |
for key, value in kwargs.items():
|
654 |
if value is not None:
|
@@ -662,7 +663,7 @@ class XLMRobertaModel(XLMRobertaPreTrainedModel):
|
|
662 |
)
|
663 |
|
664 |
hidden_states = self.embeddings(
|
665 |
-
input_ids, position_ids=position_ids, token_type_ids=token_type_ids, task_type=task_type
|
666 |
)
|
667 |
# TD [2022-12:18]: Don't need to force residual in fp32
|
668 |
# BERT puts embedding LayerNorm before embedding dropout.
|
@@ -686,7 +687,7 @@ class XLMRobertaModel(XLMRobertaPreTrainedModel):
|
|
686 |
subset_mask = None
|
687 |
|
688 |
sequence_output = self.encoder(
|
689 |
-
hidden_states, key_padding_mask=attention_mask, subset_mask=subset_mask, task_type=task_type
|
690 |
)
|
691 |
|
692 |
if masked_tokens_mask is None:
|
|
|
204 |
def gradient_checkpointing(self, value):
|
205 |
self._grad_checkpointing = value
|
206 |
|
207 |
+
def forward(self, hidden_states, key_padding_mask=None, subset_mask=None, task_type=None, adapter_mask=None):
|
208 |
"""If subset_mask is not None, we only want output for the subset of the sequence.
|
209 |
This means that we only compute the last layer output for these tokens.
|
210 |
subset_mask: (batch, seqlen), dtype=torch.bool
|
|
|
230 |
hidden_states = hidden_states[subset_mask]
|
231 |
else:
|
232 |
batch, seqlen = hidden_states.shape[:2]
|
233 |
+
hidden_states, indices, cu_seqlens, max_seqlen_in_batch, cu_adapter_mask = unpad_input(
|
234 |
+
hidden_states, key_padding_mask, adapter_mask
|
235 |
)
|
236 |
+
mixer_kwargs = {"cu_seqlens": cu_seqlens, "max_seqlen": max_seqlen_in_batch, "task_type": task_type, "cu_adapter_mask": cu_adapter_mask}
|
237 |
if subset_mask is None:
|
238 |
for layer in self.layers:
|
239 |
if self._grad_checkpointing:
|
|
|
649 |
masked_tokens_mask: (batch, seqlen), dtype=torch.bool
|
650 |
"""
|
651 |
task_type = kwargs.pop('task_type', None)
|
652 |
+
adapter_mask = kwargs.pop('adapter_mask', None)
|
653 |
if kwargs:
|
654 |
for key, value in kwargs.items():
|
655 |
if value is not None:
|
|
|
663 |
)
|
664 |
|
665 |
hidden_states = self.embeddings(
|
666 |
+
input_ids, position_ids=position_ids, token_type_ids=token_type_ids, task_type=task_type, adapter_mask=adapter_mask
|
667 |
)
|
668 |
# TD [2022-12:18]: Don't need to force residual in fp32
|
669 |
# BERT puts embedding LayerNorm before embedding dropout.
|
|
|
687 |
subset_mask = None
|
688 |
|
689 |
sequence_output = self.encoder(
|
690 |
+
hidden_states, key_padding_mask=attention_mask, subset_mask=subset_mask, task_type=task_type, adapter_mask=adapter_mask
|
691 |
)
|
692 |
|
693 |
if masked_tokens_mask is None:
|
xlm_padding.py
CHANGED
@@ -98,7 +98,7 @@ class IndexFirstAxisResidual(torch.autograd.Function):
|
|
98 |
index_first_axis_residual = IndexFirstAxisResidual.apply
|
99 |
|
100 |
|
101 |
-
def unpad_input(hidden_states, attention_mask):
|
102 |
"""
|
103 |
Arguments:
|
104 |
hidden_states: (batch, seqlen, ...)
|
@@ -113,6 +113,13 @@ def unpad_input(hidden_states, attention_mask):
|
|
113 |
indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
|
114 |
max_seqlen_in_batch = seqlens_in_batch.max().item()
|
115 |
cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
116 |
# TD [2022-03-04] We don't want to index with a bool mask, because Pytorch will expand the
|
117 |
# bool mask, then call nonzero to get the indices, then index with those. The indices is @dim
|
118 |
# times larger than it needs to be, wasting memory. It's faster and more memory-efficient to
|
@@ -123,6 +130,7 @@ def unpad_input(hidden_states, attention_mask):
|
|
123 |
indices,
|
124 |
cu_seqlens,
|
125 |
max_seqlen_in_batch,
|
|
|
126 |
)
|
127 |
|
128 |
|
|
|
98 |
index_first_axis_residual = IndexFirstAxisResidual.apply
|
99 |
|
100 |
|
101 |
+
def unpad_input(hidden_states, attention_mask, adapter_mask):
|
102 |
"""
|
103 |
Arguments:
|
104 |
hidden_states: (batch, seqlen, ...)
|
|
|
113 |
indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
|
114 |
max_seqlen_in_batch = seqlens_in_batch.max().item()
|
115 |
cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0))
|
116 |
+
|
117 |
+
cu_adapter_mask = torch.empty(cu_seqlens[-1], dtype=torch.int32)
|
118 |
+
for i in range(len(adapter_mask)):
|
119 |
+
start_idx = cu_seqlens[i]
|
120 |
+
end_idx = cu_seqlens[i + 1]
|
121 |
+
cu_adapter_mask[start_idx:end_idx] = adapter_mask[i]
|
122 |
+
|
123 |
# TD [2022-03-04] We don't want to index with a bool mask, because Pytorch will expand the
|
124 |
# bool mask, then call nonzero to get the indices, then index with those. The indices is @dim
|
125 |
# times larger than it needs to be, wasting memory. It's faster and more memory-efficient to
|
|
|
130 |
indices,
|
131 |
cu_seqlens,
|
132 |
max_seqlen_in_batch,
|
133 |
+
cu_adapter_mask,
|
134 |
)
|
135 |
|
136 |
|