Crystalcareai committed
Update modeling_gemmoe.py

modeling_gemmoe.py  CHANGED  (+14 -26)
@@ -1112,28 +1112,6 @@ class GemmoeModel(GemmoePreTrainedModel):
         return causal_mask
 
 class GemmoeForCausalLM(GemmoePreTrainedModel):
-    r"""
-    The Gemmoe Model transformer with a language modeling head on top for causal language modeling (CLM).
-
-    Args:
-        config (GemmoeConfig): The configuration object for the Gemmoe model.
-
-    Example usage:
-    ```python
-    >>> from transformers import AutoTokenizer, GemmoeForCausalLM
-
-    >>> model = GemmoeForCausalLM.from_pretrained("google/gemmoe-7b")
-    >>> tokenizer = AutoTokenizer.from_pretrained("google/gemmoe-7b")
-
-    >>> prompt = "What is your favorite condiment?"
-    >>> inputs = tokenizer(prompt, return_tensors="pt")
-
-    >>> # Generate
-    >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
-    >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
-    "What is your favorite condiment?"
-    ```
-    """
     _tied_weights_keys = ["lm_head.weight"]
 
     def __init__(self, config):
@@ -1237,10 +1215,11 @@ class GemmoeForCausalLM(GemmoePreTrainedModel):
 
         # Handle unused parameters
         if self.training:
-            for
-            for
-
-            param.grad
+            for layer in self.model.layers:
+                for expert in layer.block_sparse_moe.experts:
+                    for param in expert.parameters():
+                        if param.requires_grad and param.grad is None:
+                            param.grad = torch.zeros_like(param)
 
         loss = None
         if labels is not None:
@@ -1349,6 +1328,15 @@ class GemmoeForCausalLM(GemmoePreTrainedModel):
                 tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
             )
         return reordered_past
+
+    @staticmethod
+    def _reorder_cache(past_key_values, beam_idx):
+        reordered_past = ()
+        for layer_past in past_key_values:
+            reordered_past += (
+                tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
+            )
+        return reordered_past
 @add_start_docstrings(
     """
     The Gemmoe Model transformer with a sequence classification head on top (linear layer).
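The `_reorder_cache` staticmethod added in the last hunk (its body mirrors the context lines directly above the insertion point) is the hook beam search uses to permute each layer's cached key/value tensors along the batch dimension so the cache follows the surviving beams. A standalone sketch with dummy tensors; the shapes and number of layers are illustrative assumptions, not values from the model config.

```python
import torch

def reorder_cache(past_key_values, beam_idx):
    # Same logic as the added staticmethod: select the beam_idx rows of every
    # cached key/value tensor so the cache tracks the beams kept at this step.
    reordered_past = ()
    for layer_past in past_key_values:
        reordered_past += (
            tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
        )
    return reordered_past

def make_layer():
    # (key, value) cache for one layer: [num_beams, num_heads, seq_len, head_dim]
    return tuple(torch.randn(4, 2, 5, 8) for _ in range(2))

past = tuple(make_layer() for _ in range(3))   # three decoder layers
beam_idx = torch.tensor([2, 2, 0, 1])          # hypotheses 2, 2, 0, 1 survive

reordered = reorder_cache(past, beam_idx)
assert reordered[0][0].shape == past[0][0].shape
assert torch.equal(reordered[0][0][0], past[0][0][2])  # slot 0 now holds beam 2's cache
```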