hyx21 commited on
Commit
f4a3ba4
·
verified ·
1 Parent(s): 98ea6c0

Upload 2 files

Browse files
Files changed (2) hide show
  1. configuration_minicpm.py +5 -0
  2. modeling_minicpm.py +4 -4
configuration_minicpm.py CHANGED
@@ -174,6 +174,11 @@ class MiniCPMConfig(PretrainedConfig):
174
  tie_word_embeddings=tie_word_embeddings,
175
  **kwargs,
176
  )
 
 
 
 
 
177
 
178
  def _rope_scaling_validation(self):
179
  """
 
174
  tie_word_embeddings=tie_word_embeddings,
175
  **kwargs,
176
  )
177
+ try:
178
+ import flash_attn
179
+ self._attn_implementation = "flash_attention_2"
180
+ except:
181
+ pass
182
 
183
  def _rope_scaling_validation(self):
184
  """
modeling_minicpm.py CHANGED
@@ -51,10 +51,11 @@ from transformers.utils.import_utils import is_torch_fx_available
51
  from .configuration_minicpm import MiniCPMConfig
52
  import re
53
 
54
-
55
- if is_flash_attn_2_available():
56
  from flash_attn import flash_attn_func, flash_attn_varlen_func
57
  from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
 
 
58
 
59
 
60
  # This makes `_prepare_4d_causal_attention_mask` a leaf function in the FX graph.
@@ -125,7 +126,7 @@ ALL_LAYERNORM_LAYERS.append(MiniCPMRMSNorm)
125
 
126
 
127
  class MiniCPMRotaryEmbedding(nn.Module):
128
- def __init__(self, dim, max_position_embeddings=2048, base=10000, device="cuda"):
129
  super().__init__()
130
 
131
  self.dim = dim
@@ -763,7 +764,6 @@ class MiniCPMDecoderLayer(nn.Module):
763
  def __init__(self, config: MiniCPMConfig, layer_idx: int):
764
  super().__init__()
765
  self.hidden_size = config.hidden_size
766
-
767
  self.self_attn = MINICPM_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx)
768
 
769
  self.mlp = MiniCPMMLP(config)
 
51
  from .configuration_minicpm import MiniCPMConfig
52
  import re
53
 
54
+ try:
 
55
  from flash_attn import flash_attn_func, flash_attn_varlen_func
56
  from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
57
+ except:
58
+ pass
59
 
60
 
61
  # This makes `_prepare_4d_causal_attention_mask` a leaf function in the FX graph.
 
126
 
127
 
128
  class MiniCPMRotaryEmbedding(nn.Module):
129
+ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
130
  super().__init__()
131
 
132
  self.dim = dim
 
764
  def __init__(self, config: MiniCPMConfig, layer_idx: int):
765
  super().__init__()
766
  self.hidden_size = config.hidden_size
 
767
  self.self_attn = MINICPM_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx)
768
 
769
  self.mlp = MiniCPMMLP(config)