Shitao commited on
Commit
92fd472
·
verified ·
1 Parent(s): 9bc9a08

Upload folder using huggingface_hub

Browse files
Files changed (1) hide show
  1. gemma_model.py +31 -4
gemma_model.py CHANGED
@@ -54,7 +54,7 @@ from transformers.utils import (
54
  from .gemma_config import CostWiseGemmaConfig
55
  from transformers.models.gemma2.modeling_gemma2 import Gemma2RMSNorm, Gemma2RotaryEmbedding, rotate_half, apply_rotary_pos_emb
56
  from transformers.models.gemma2.modeling_gemma2 import Gemma2MLP, repeat_kv, Gemma2Attention, Gemma2FlashAttention2, Gemma2SdpaAttention, GEMMA2_ATTENTION_CLASSES, Gemma2DecoderLayer, GEMMA2_START_DOCSTRING
57
- from transformers.models.gemma2.modeling_gemma2 import Gemma2PreTrainedModel, GEMMA2_INPUTS_DOCSTRING
58
 
59
  if is_flash_attn_2_available():
60
  from flash_attn import flash_attn_func, flash_attn_varlen_func
@@ -77,6 +77,33 @@ def _get_unpad_data(attention_mask):
77
  max_seqlen_in_batch,
78
  )
79
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
80
 
81
  GEMMA2_ATTENTION_CLASSES = {
82
  "eager": Gemma2Attention,
@@ -213,7 +240,7 @@ def token_compress(compress_ratio,
213
  "The bare Gemma2 Model outputting raw hidden-states without any specific head on top.",
214
  GEMMA2_START_DOCSTRING,
215
  )
216
- class CostWiseGemmaModel(Gemma2PreTrainedModel):
217
  """
218
  Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`GemmaDecoderLayer`]
219
 
@@ -466,10 +493,10 @@ class CostWiseHead(nn.Module):
466
  return self.linear_head(**kwargs)
467
 
468
 
469
- class CostWiseGemmaForCausalLM(Gemma2PreTrainedModel):
470
  _tied_weights_keys = ["lm_head.weight"]
471
 
472
- def __init__(self, config):
473
  super().__init__(config)
474
  self.model = CostWiseGemmaModel(config)
475
  self.vocab_size = config.vocab_size
 
54
  from .gemma_config import CostWiseGemmaConfig
55
  from transformers.models.gemma2.modeling_gemma2 import Gemma2RMSNorm, Gemma2RotaryEmbedding, rotate_half, apply_rotary_pos_emb
56
  from transformers.models.gemma2.modeling_gemma2 import Gemma2MLP, repeat_kv, Gemma2Attention, Gemma2FlashAttention2, Gemma2SdpaAttention, GEMMA2_ATTENTION_CLASSES, Gemma2DecoderLayer, GEMMA2_START_DOCSTRING
57
+ from transformers.models.gemma2.modeling_gemma2 import GEMMA2_INPUTS_DOCSTRING
58
 
59
  if is_flash_attn_2_available():
60
  from flash_attn import flash_attn_func, flash_attn_varlen_func
 
77
  max_seqlen_in_batch,
78
  )
79
 
80
+ @add_start_docstrings(
81
+ "The bare Gemma2 Model outputting raw hidden-states without any specific head on top.",
82
+ GEMMA2_START_DOCSTRING,
83
+ )
84
+ class CostWiseGemma2PreTrainedModel(PreTrainedModel):
85
+ config_class = CostWiseGemmaConfig
86
+ base_model_prefix = "model"
87
+ supports_gradient_checkpointing = True
88
+ _no_split_modules = ["Gemma2DecoderLayer"]
89
+ _skip_keys_device_placement = ["past_key_values"]
90
+ _supports_flash_attn_2 = True
91
+ _supports_sdpa = True
92
+ _supports_cache_class = False
93
+ _supports_quantized_cache = False
94
+ _supports_static_cache = True
95
+ _is_stateful = True
96
+
97
+ def _init_weights(self, module):
98
+ std = self.config.initializer_range
99
+ if isinstance(module, nn.Linear):
100
+ module.weight.data.normal_(mean=0.0, std=std)
101
+ if module.bias is not None:
102
+ module.bias.data.zero_()
103
+ elif isinstance(module, nn.Embedding):
104
+ module.weight.data.normal_(mean=0.0, std=std)
105
+ if module.padding_idx is not None:
106
+ module.weight.data[module.padding_idx].zero_()
107
 
108
  GEMMA2_ATTENTION_CLASSES = {
109
  "eager": Gemma2Attention,
 
240
  "The bare Gemma2 Model outputting raw hidden-states without any specific head on top.",
241
  GEMMA2_START_DOCSTRING,
242
  )
243
+ class CostWiseGemmaModel(CostWiseGemma2PreTrainedModel):
244
  """
245
  Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`GemmaDecoderLayer`]
246
 
 
493
  return self.linear_head(**kwargs)
494
 
495
 
496
+ class CostWiseGemmaForCausalLM(CostWiseGemma2PreTrainedModel):
497
  _tied_weights_keys = ["lm_head.weight"]
498
 
499
+ def __init__(self, config: CostWiseGemmaConfig):
500
  super().__init__(config)
501
  self.model = CostWiseGemmaModel(config)
502
  self.vocab_size = config.vocab_size