Crystalcareai committed · verified
Commit 198cba7 · Parent: 5f967cc

Update modeling_gemmoe.py

Files changed (1)
  1. modeling_gemmoe.py  +14  -26
modeling_gemmoe.py CHANGED
@@ -1112,28 +1112,6 @@ class GemmoeModel(GemmoePreTrainedModel):
         return causal_mask
 
 class GemmoeForCausalLM(GemmoePreTrainedModel):
-    r"""
-    The Gemmoe Model transformer with a language modeling head on top for causal language modeling (CLM).
-
-    Args:
-        config (GemmoeConfig): The configuration object for the Gemmoe model.
-
-    Example usage:
-    ```python
-    >>> from transformers import AutoTokenizer, GemmoeForCausalLM
-
-    >>> model = GemmoeForCausalLM.from_pretrained("google/gemmoe-7b")
-    >>> tokenizer = AutoTokenizer.from_pretrained("google/gemmoe-7b")
-
-    >>> prompt = "What is your favorite condiment?"
-    >>> inputs = tokenizer(prompt, return_tensors="pt")
-
-    >>> # Generate
-    >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
-    >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
-    "What is your favorite condiment?"
-    ```
-    """
     _tied_weights_keys = ["lm_head.weight"]
 
     def __init__(self, config):
@@ -1237,10 +1215,11 @@ class GemmoeForCausalLM(GemmoePreTrainedModel):
 
         # Handle unused parameters
         if self.training:
-            for expert in self.model.layers[-1].block_sparse_moe.experts:
-                for param in expert.parameters():
-                    if param.requires_grad and param.grad is None:
-                        param.grad = torch.zeros_like(param)
+            for layer in self.model.layers:
+                for expert in layer.block_sparse_moe.experts:
+                    for param in expert.parameters():
+                        if param.requires_grad and param.grad is None:
+                            param.grad = torch.zeros_like(param)
 
         loss = None
         if labels is not None:
@@ -1349,6 +1328,15 @@ class GemmoeForCausalLM(GemmoePreTrainedModel):
                 tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
             )
         return reordered_past
+
+    @staticmethod
+    def _reorder_cache(past_key_values, beam_idx):
+        reordered_past = ()
+        for layer_past in past_key_values:
+            reordered_past += (
+                tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
+            )
+        return reordered_past
 
 @add_start_docstrings(
     """
     The Gemmoe Model transformer with a sequence classification head on top (linear layer).
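
For readers skimming the second hunk: during a training step, any MoE expert whose parameters received no routed tokens ends up with `param.grad is None` after `backward()`, and the commit now zero-fills those gradients for experts in every layer rather than only the last one (typically so that distributed wrappers that expect a gradient on each trainable parameter do not fail). A minimal, self-contained sketch of the same pattern, using a hypothetical `ToyMoE` module rather than the real Gemmoe classes:

```python
# Sketch only: a toy stand-in for a block-sparse MoE layer, not the Gemmoe code.
import torch
import torch.nn as nn

class ToyMoE(nn.Module):
    def __init__(self, hidden=8, num_experts=4):
        super().__init__()
        self.experts = nn.ModuleList([nn.Linear(hidden, hidden) for _ in range(num_experts)])

    def forward(self, x):
        # Route every token to expert 0 so the other experts stay unused this step.
        return self.experts[0](x)

layers = nn.ModuleList([ToyMoE() for _ in range(2)])

out = torch.randn(3, 8)
for layer in layers:
    out = layer(out)
out.sum().backward()

# The commit's pattern, applied to *all* layers: give unused expert parameters
# an explicit zero gradient so every trainable parameter ends up with a .grad.
for layer in layers:
    for expert in layer.experts:
        for param in expert.parameters():
            if param.requires_grad and param.grad is None:
                param.grad = torch.zeros_like(param)

assert all(p.grad is not None for p in layers.parameters())
```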
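For the third hunk, `_reorder_cache` is the hook beam search in `generate()` uses to permute the batch dimension of the cached key/value tensors so each cache row keeps following its (possibly re-ranked) beam. A small sketch with toy shapes, assuming the legacy tuple-of-tuples cache layout shown in the diff:

```python
# Sketch only: toy cache shapes, not Gemmoe's actual dimensions.
import torch

num_beams, num_heads, seq_len, head_dim = 4, 2, 5, 8
past_key_values = (
    (
        torch.randn(num_beams, num_heads, seq_len, head_dim),  # layer 0 keys
        torch.randn(num_beams, num_heads, seq_len, head_dim),  # layer 0 values
    ),
)
beam_idx = torch.tensor([2, 2, 0, 1])  # beam 2 survived twice, beam 3 was dropped

# Same reordering as the method body: select cache rows along dim 0 per beam_idx.
reordered_past = tuple(
    tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past)
    for layer_past in past_key_values
)
assert torch.equal(reordered_past[0][0][0], past_key_values[0][0][2])
```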