Upload modeling_cocom.py
Browse files- modeling_cocom.py +1 -1
modeling_cocom.py
CHANGED
@@ -226,7 +226,7 @@ class COCOM(PreTrainedModel):
|
|
226 |
self.sep = cfg.sep
|
227 |
self.compr_rate = cfg.compr_rate
|
228 |
self.local_rank = os.getenv('LOCAL_RANK', '0')
|
229 |
-
for layer in self.decoder.
|
230 |
layer.attention.self = CustomFlashAttention(
|
231 |
embed_dim=cfg.hidden_size,
|
232 |
num_heads=cfg.num_attention_heads,
|
|
|
226 |
self.sep = cfg.sep
|
227 |
self.compr_rate = cfg.compr_rate
|
228 |
self.local_rank = os.getenv('LOCAL_RANK', '0')
|
229 |
+
for layer in self.decoder.encoder.layer:
|
230 |
layer.attention.self = CustomFlashAttention(
|
231 |
embed_dim=cfg.hidden_size,
|
232 |
num_heads=cfg.num_attention_heads,
|