Jackmin108 commited on
Commit
4ee2970
·
1 Parent(s): d2c9d06

Signed-off-by: Meow <[email protected]>

Files changed (2) hide show
  1. mlp.py +1 -1
  2. xlm_padding.py +1 -7
mlp.py CHANGED
@@ -74,7 +74,7 @@ class Mlp(nn.Module):
74
  task_out = self.fc2(task_tensor, task_id=task_id)
75
  out[task_indices] = task_out
76
  else:
77
- out = self.fc1(y)
78
 
79
  return out if not self.return_residual else (out, x)
80
 
 
74
  task_out = self.fc2(task_tensor, task_id=task_id)
75
  out[task_indices] = task_out
76
  else:
77
+ out = self.fc2(y)
78
 
79
  return out if not self.return_residual else (out, x)
80
 
xlm_padding.py CHANGED
@@ -114,13 +114,7 @@ def unpad_input(hidden_states, attention_mask, adapter_mask=None):
114
  max_seqlen_in_batch = seqlens_in_batch.max().item()
115
  cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0))
116
 
117
- cu_adapter_mask = None
118
- if adapter_mask:
119
- cu_adapter_mask = torch.empty(cu_seqlens[-1], dtype=torch.int32)
120
- for i in range(len(adapter_mask)):
121
- start_idx = cu_seqlens[i]
122
- end_idx = cu_seqlens[i + 1]
123
- cu_adapter_mask[start_idx:end_idx] = adapter_mask[i]
124
 
125
  # TD [2022-03-04] We don't want to index with a bool mask, because Pytorch will expand the
126
  # bool mask, then call nonzero to get the indices, then index with those. The indices is @dim
 
114
  max_seqlen_in_batch = seqlens_in_batch.max().item()
115
  cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0))
116
 
117
+ cu_adapter_mask = torch.repeat_interleave(adapter_mask, cu_seqlens[1:] - cu_seqlens[:-1]) if adapter_mask is not None else None
 
 
 
 
 
 
118
 
119
  # TD [2022-03-04] We don't want to index with a bool mask, because Pytorch will expand the
120
  # bool mask, then call nonzero to get the indices, then index with those. The indices is @dim