liang.zhao committed on
Commit
681a0e3
·
1 Parent(s): 07466c6

update model and config

Browse files
config.json CHANGED
@@ -33,7 +33,7 @@
33
  "rms_norm_eps": 1e-06,
34
  "tie_word_embeddings": false,
35
  "torch_dtype": "bfloat16",
36
- "transformers_version": "4.34.0",
37
  "use_cache": true,
38
  "vocab_size": 65536
39
- }
 
33
  "rms_norm_eps": 1e-06,
34
  "tie_word_embeddings": false,
35
  "torch_dtype": "bfloat16",
36
+ "transformers_version": "4.33.1",
37
  "use_cache": true,
38
  "vocab_size": 65536
39
+ }
configuration_skywork.py CHANGED
@@ -1,13 +1,14 @@
1
  # Copyright (c) SkyworkAI and the HuggingFace Inc. team. All rights reserved.
2
  # This code is built upon Huggingface's transformers repository.
3
 
 
4
  from transformers.configuration_utils import PretrainedConfig
5
  from transformers.utils import logging
6
 
7
 
8
  logger = logging.get_logger(__name__)
9
 
10
- Skywork_PRETRAINED_CONFIG_ARCHIVE_MAP = {}
11
 
12
 
13
  class SkyworkConfig(PretrainedConfig):
@@ -28,15 +29,13 @@ class SkyworkConfig(PretrainedConfig):
28
  initializer_range=0.02,
29
  rms_norm_eps=1e-6,
30
  use_cache=True,
31
- pad_token_id=0,
32
  bos_token_id=1,
33
  eos_token_id=2,
34
  pretraining_tp=1,
35
  tie_word_embeddings=False,
36
- rope_scaling=None,
37
  rope_theta=10000.0,
38
- attention_bias=False,
39
- use_flash_attention=False,
40
  **kwargs,
41
  ):
42
  self.vocab_size = vocab_size
@@ -56,16 +55,9 @@ class SkyworkConfig(PretrainedConfig):
56
  self.rms_norm_eps = rms_norm_eps
57
  self.pretraining_tp = pretraining_tp
58
  self.use_cache = use_cache
59
- self.rope_scaling = rope_scaling
60
  self.rope_theta = rope_theta
61
- self.attention_bias = attention_bias
62
- self.use_flash_attention = use_flash_attention
63
- if self.use_flash_attention:
64
- try:
65
- from flash_attn.flash_attn_interface import flash_attn_varlen_func
66
- from einops import rearrange
67
- except:
68
- raise ValueError("`use_flash_attention` requires Flash Attention 2+ and einops.\nTry `pip install einops` and installing Flash Attention from from https://github.com/Dao-AILab/flash-attention")
69
 
70
  super().__init__(
71
  pad_token_id=pad_token_id,
@@ -74,3 +66,24 @@ class SkyworkConfig(PretrainedConfig):
74
  tie_word_embeddings=tie_word_embeddings,
75
  **kwargs,
76
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  # Copyright (c) SkyworkAI and the HuggingFace Inc. team. All rights reserved.
2
  # This code is built upon Huggingface's transformers repository.
3
 
4
+
5
  from transformers.configuration_utils import PretrainedConfig
6
  from transformers.utils import logging
7
 
8
 
9
  logger = logging.get_logger(__name__)
10
 
11
+ LLAMA_PRETRAINED_CONFIG_ARCHIVE_MAP = {}
12
 
13
 
14
  class SkyworkConfig(PretrainedConfig):
 
29
  initializer_range=0.02,
30
  rms_norm_eps=1e-6,
31
  use_cache=True,
32
+ pad_token_id=None,
33
  bos_token_id=1,
34
  eos_token_id=2,
35
  pretraining_tp=1,
36
  tie_word_embeddings=False,
 
37
  rope_theta=10000.0,
38
+ rope_scaling=None,
 
39
  **kwargs,
40
  ):
41
  self.vocab_size = vocab_size
 
55
  self.rms_norm_eps = rms_norm_eps
56
  self.pretraining_tp = pretraining_tp
57
  self.use_cache = use_cache
 
58
  self.rope_theta = rope_theta
59
+ self.rope_scaling = rope_scaling
60
+ self._rope_scaling_validation()
 
 
 
 
 
 
61
 
62
  super().__init__(
63
  pad_token_id=pad_token_id,
 
66
  tie_word_embeddings=tie_word_embeddings,
67
  **kwargs,
68
  )
69
+
70
+ def _rope_scaling_validation(self):
71
+ """
72
+ Validate the `rope_scaling` configuration.
73
+ """
74
+ if self.rope_scaling is None:
75
+ return
76
+
77
+ if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2:
78
+ raise ValueError(
79
+ "`rope_scaling` must be a dictionary with with two fields, `type` and `factor`, "
80
+ f"got {self.rope_scaling}"
81
+ )
82
+ rope_scaling_type = self.rope_scaling.get("type", None)
83
+ rope_scaling_factor = self.rope_scaling.get("factor", None)
84
+ if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic", "ntk"]:
85
+ raise ValueError(
86
+ f"`rope_scaling`'s type field must be one of ['linear', 'dynamic'], got {rope_scaling_type}"
87
+ )
88
+ if rope_scaling_factor is None or not isinstance(rope_scaling_factor, float) or rope_scaling_factor <= 1.0:
89
+ raise ValueError(f"`rope_scaling`'s factor field must be an float > 1, got {rope_scaling_factor}")
generation_config.json CHANGED
@@ -6,5 +6,5 @@
6
  "pad_token_id": 0,
7
  "temperature": 0.6,
8
  "top_p": 0.9,
9
- "transformers_version": "4.34.0"
10
  }
 
6
  "pad_token_id": 0,
7
  "temperature": 0.6,
8
  "top_p": 0.9,
9
+ "transformers_version": "4.33.1"
10
  }
modeling_skywork.py CHANGED
@@ -1,5 +1,6 @@
1
  # Copyright (c) SkyworkAI and the HuggingFace Inc. team. All rights reserved.
2
  # This code is built upon Huggingface's transformers repository.
 
3
  import math
4
  from typing import List, Optional, Tuple, Union
5
 
@@ -12,39 +13,15 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
12
  from transformers.activations import ACT2FN
13
  from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast
14
  from transformers.modeling_utils import PreTrainedModel
15
- from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS
16
- from transformers.utils import (
17
- add_start_docstrings,
18
- add_start_docstrings_to_model_forward,
19
- is_flash_attn_available,
20
- logging,
21
- replace_return_docstrings,
22
- )
23
  from .configuration_skywork import SkyworkConfig
24
 
25
 
26
- if is_flash_attn_available():
27
- from flash_attn import flash_attn_func, flash_attn_varlen_func
28
- from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
29
-
30
-
31
  logger = logging.get_logger(__name__)
32
 
33
  _CONFIG_FOR_DOC = "SkyworkConfig"
34
 
35
 
36
- def _get_unpad_data(padding_mask):
37
- seqlens_in_batch = padding_mask.sum(dim=-1, dtype=torch.int32)
38
- indices = torch.nonzero(padding_mask.flatten(), as_tuple=False).flatten()
39
- max_seqlen_in_batch = seqlens_in_batch.max().item()
40
- cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0))
41
- return (
42
- indices,
43
- cu_seqlens,
44
- max_seqlen_in_batch,
45
- )
46
-
47
-
48
  # Copied from transformers.models.bart.modeling_bart._make_causal_mask
49
  def _make_causal_mask(
50
  input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0
@@ -95,10 +72,7 @@ class SkyworkRMSNorm(nn.Module):
95
  return self.weight * hidden_states.to(input_dtype)
96
 
97
 
98
- ALL_LAYERNORM_LAYERS.append(SkyworkRMSNorm)
99
-
100
-
101
- class SkyworkRotaryEmbedding(nn.Module):
102
  def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
103
  super().__init__()
104
 
@@ -120,8 +94,8 @@ class SkyworkRotaryEmbedding(nn.Module):
120
  freqs = torch.einsum("i,j->ij", t, self.inv_freq)
121
  # Different from paper, but it uses a different permutation in order to obtain the same calculation
122
  emb = torch.cat((freqs, freqs), dim=-1)
123
- self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
124
- self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
125
 
126
  def forward(self, x, seq_len=None):
127
  # x: [bs, num_attention_heads, seq_len, head_size]
@@ -129,8 +103,8 @@ class SkyworkRotaryEmbedding(nn.Module):
129
  self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
130
 
131
  return (
132
- self.cos_cached[:seq_len].to(dtype=x.dtype),
133
- self.sin_cached[:seq_len].to(dtype=x.dtype),
134
  )
135
 
136
 
@@ -149,8 +123,8 @@ class SkyworkLinearScalingRotaryEmbedding(SkyworkRotaryEmbedding):
149
  freqs = torch.einsum("i,j->ij", t, self.inv_freq)
150
  # Different from paper, but it uses a different permutation in order to obtain the same calculation
151
  emb = torch.cat((freqs, freqs), dim=-1)
152
- self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
153
- self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
154
 
155
 
156
  class SkyworkDynamicNTKScalingRotaryEmbedding(SkyworkRotaryEmbedding):
@@ -175,30 +149,42 @@ class SkyworkDynamicNTKScalingRotaryEmbedding(SkyworkRotaryEmbedding):
175
  freqs = torch.einsum("i,j->ij", t, self.inv_freq)
176
  # Different from paper, but it uses a different permutation in order to obtain the same calculation
177
  emb = torch.cat((freqs, freqs), dim=-1)
178
- self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
179
- self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
180
 
181
 
182
- class SkyworkNTKScalingRotaryEmbedding(SkyworkRotaryEmbedding):
183
- def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
184
- self.scaling_factor = scaling_factor
185
- super().__init__(dim, max_position_embeddings, base, device)
186
-
187
- def _set_cos_sin_cache(self, seq_len, device, dtype):
188
- self.max_seq_len_cached = seq_len
189
 
190
- base = (self.base * self.scaling_factor) ** (self.dim / (self.dim - 2))
191
- inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
 
 
 
 
 
 
192
  self.register_buffer("inv_freq", inv_freq, persistent=False)
193
 
194
- t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
 
 
 
195
 
 
 
 
196
  freqs = torch.einsum("i,j->ij", t, self.inv_freq)
197
- # Different from paper, but it uses a different permutation in order to obtain the same calculation
198
  emb = torch.cat((freqs, freqs), dim=-1)
199
- self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
200
- self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
201
 
 
 
 
 
 
 
 
 
202
 
203
  def rotate_half(x):
204
  """Rotates half the hidden dims of the input."""
@@ -207,10 +193,12 @@ def rotate_half(x):
207
  return torch.cat((-x2, x1), dim=-1)
208
 
209
 
210
- # Copied from transformers.models.gpt_neox.modeling_gpt_neox.apply_rotary_pos_emb
211
  def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
212
- cos = cos[position_ids].unsqueeze(1) # [seq_len, dim] -> [batch_size, 1, seq_len, head_dim]
213
- sin = sin[position_ids].unsqueeze(1) #
 
 
 
214
  q_embed = (q * cos) + (rotate_half(q) * sin)
215
  k_embed = (k * cos) + (rotate_half(k) * sin)
216
  return q_embed, k_embed
@@ -281,10 +269,10 @@ class SkyworkAttention(nn.Module):
281
  f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
282
  f" and `num_heads`: {self.num_heads})."
283
  )
284
- self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias)
285
- self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
286
- self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
287
- self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias)
288
  self._init_rope()
289
 
290
  def _init_rope(self):
@@ -320,7 +308,9 @@ class SkyworkAttention(nn.Module):
320
  )
321
  else:
322
  raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
323
-
 
 
324
  def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
325
  return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
326
 
@@ -332,7 +322,6 @@ class SkyworkAttention(nn.Module):
332
  past_key_value: Optional[Tuple[torch.Tensor]] = None,
333
  output_attentions: bool = False,
334
  use_cache: bool = False,
335
- padding_mask: Optional[torch.LongTensor] = None,
336
  ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
337
  bsz, q_len, _ = hidden_states.size()
338
 
@@ -375,6 +364,7 @@ class SkyworkAttention(nn.Module):
375
 
376
  past_key_value = (key_states, value_states) if use_cache else None
377
 
 
378
  key_states = repeat_kv(key_states, self.num_key_value_groups)
379
  value_states = repeat_kv(value_states, self.num_key_value_groups)
380
 
@@ -404,7 +394,6 @@ class SkyworkAttention(nn.Module):
404
  )
405
 
406
  attn_output = attn_output.transpose(1, 2).contiguous()
407
-
408
  attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
409
 
410
  if self.config.pretraining_tp > 1:
@@ -420,193 +409,11 @@ class SkyworkAttention(nn.Module):
420
  return attn_output, attn_weights, past_key_value
421
 
422
 
423
- class SkyworkFlashAttention2(SkyworkAttention):
424
- """
425
- Skywork flash attention module. This module inherits from `SkyworkAttention` as the weights of the module stays
426
- untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
427
- flash attention and deal with padding tokens in case the input contains any of them.
428
- """
429
-
430
- def forward(
431
- self,
432
- hidden_states: torch.Tensor,
433
- attention_mask: Optional[torch.Tensor] = None,
434
- position_ids: Optional[torch.LongTensor] = None,
435
- past_key_value: Optional[Tuple[torch.Tensor]] = None,
436
- output_attentions: bool = False,
437
- use_cache: bool = False,
438
- padding_mask: Optional[torch.LongTensor] = None,
439
- ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
440
- # SkyworkFlashAttention2 attention does not support output_attentions
441
- output_attentions = False
442
-
443
- bsz, q_len, _ = hidden_states.size()
444
-
445
- query_states = self.q_proj(hidden_states)
446
- key_states = self.k_proj(hidden_states)
447
- value_states = self.v_proj(hidden_states)
448
-
449
- # Flash attention requires the input to have the shape
450
- # batch_size x seq_length x head_dime x hidden_dim
451
- # therefore we just need to keep the original shape
452
- query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
453
- key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
454
- value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
455
-
456
- kv_seq_len = key_states.shape[-2]
457
- if past_key_value is not None:
458
- kv_seq_len += past_key_value[0].shape[-2]
459
-
460
- cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
461
-
462
- query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
463
-
464
- if past_key_value is not None:
465
- # reuse k, v, self_attention
466
- key_states = torch.cat([past_key_value[0], key_states], dim=2)
467
- value_states = torch.cat([past_key_value[1], value_states], dim=2)
468
-
469
- past_key_value = (key_states, value_states) if use_cache else None
470
-
471
- query_states = query_states.transpose(1, 2)
472
- key_states = key_states.transpose(1, 2)
473
- value_states = value_states.transpose(1, 2)
474
-
475
- # TODO: skywork does not have dropout in the config??
476
- # It is recommended to use dropout with FA according to the docs
477
- # when training.
478
- dropout_rate = 0.0 # if not self.training else self.attn_dropout
479
-
480
- # In PEFT, usually we cast the layer norms in float32 for training stability reasons
481
- # therefore the input hidden states gets silently casted in float32. Hence, we need
482
- # cast them back in float16 just to be sure everything works as expected.
483
- # This might slowdown training & inference so it is recommended to not cast the LayerNorms
484
- # in fp32. (SkyworkRMSNorm handles it correctly)
485
- input_dtype = query_states.dtype
486
- if input_dtype == torch.float32:
487
- logger.warning_once(
488
- "The input hidden states seems to be silently casted in float32, this might be related to"
489
- " the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
490
- " float16."
491
- )
492
-
493
- query_states = query_states.to(torch.float16)
494
- key_states = key_states.to(torch.float16)
495
- value_states = value_states.to(torch.float16)
496
-
497
- attn_output = self._flash_attention_forward(
498
- query_states, key_states, value_states, padding_mask, q_len, dropout=dropout_rate
499
- )
500
-
501
- attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
502
- attn_output = self.o_proj(attn_output)
503
-
504
- if not output_attentions:
505
- attn_weights = None
506
-
507
- return attn_output, attn_weights, past_key_value
508
-
509
- def _flash_attention_forward(
510
- self, query_states, key_states, value_states, padding_mask, query_length, dropout=0.0, softmax_scale=None
511
- ):
512
- """
513
- Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
514
- first unpad the input, then computes the attention scores and pad the final attention scores.
515
-
516
- Args:
517
- query_states (`torch.Tensor`):
518
- Input query states to be passed to Flash Attention API
519
- key_states (`torch.Tensor`):
520
- Input key states to be passed to Flash Attention API
521
- value_states (`torch.Tensor`):
522
- Input value states to be passed to Flash Attention API
523
- padding_mask (`torch.Tensor`):
524
- The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
525
- position of padding tokens and 1 for the position of non-padding tokens.
526
- dropout (`int`, *optional*):
527
- Attention dropout
528
- softmax_scale (`float`, *optional*):
529
- The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
530
- """
531
- # Contains at least one padding token in the sequence
532
- if padding_mask is not None:
533
- batch_size = query_states.shape[0]
534
- query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
535
- query_states, key_states, value_states, padding_mask, query_length
536
- )
537
-
538
- cu_seqlens_q, cu_seqlens_k = cu_seq_lens
539
- max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
540
-
541
- attn_output_unpad = flash_attn_varlen_func(
542
- query_states,
543
- key_states,
544
- value_states,
545
- cu_seqlens_q=cu_seqlens_q,
546
- cu_seqlens_k=cu_seqlens_k,
547
- max_seqlen_q=max_seqlen_in_batch_q,
548
- max_seqlen_k=max_seqlen_in_batch_k,
549
- dropout_p=dropout,
550
- softmax_scale=softmax_scale,
551
- causal=True,
552
- )
553
-
554
- attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
555
- else:
556
- attn_output = flash_attn_func(
557
- query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=True
558
- )
559
-
560
- return attn_output
561
-
562
- def _upad_input(self, query_layer, key_layer, value_layer, padding_mask, query_length):
563
- indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(padding_mask)
564
- batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
565
-
566
- key_layer = index_first_axis(
567
- key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
568
- )
569
- value_layer = index_first_axis(
570
- value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
571
- )
572
- if query_length == kv_seq_len:
573
- query_layer = index_first_axis(
574
- query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k
575
- )
576
- cu_seqlens_q = cu_seqlens_k
577
- max_seqlen_in_batch_q = max_seqlen_in_batch_k
578
- indices_q = indices_k
579
- elif query_length == 1:
580
- max_seqlen_in_batch_q = 1
581
- cu_seqlens_q = torch.arange(
582
- batch_size + 1, dtype=torch.int32, device=query_layer.device
583
- ) # There is a memcpy here, that is very bad.
584
- indices_q = cu_seqlens_q[:-1]
585
- query_layer = query_layer.squeeze(1)
586
- else:
587
- # The -q_len: slice assumes left padding.
588
- padding_mask = padding_mask[:, -query_length:]
589
- query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, padding_mask)
590
-
591
- return (
592
- query_layer,
593
- key_layer,
594
- value_layer,
595
- indices_q,
596
- (cu_seqlens_q, cu_seqlens_k),
597
- (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
598
- )
599
-
600
-
601
  class SkyworkDecoderLayer(nn.Module):
602
  def __init__(self, config: SkyworkConfig):
603
  super().__init__()
604
  self.hidden_size = config.hidden_size
605
- self.self_attn = (
606
- SkyworkAttention(config=config)
607
- if not getattr(config, "_flash_attn_2_enabled", False)
608
- else SkyworkFlashAttention2(config=config)
609
- )
610
  self.mlp = SkyworkMLP(config)
611
  self.input_layernorm = SkyworkRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
612
  self.post_attention_layernorm = SkyworkRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
@@ -619,7 +426,6 @@ class SkyworkDecoderLayer(nn.Module):
619
  past_key_value: Optional[Tuple[torch.Tensor]] = None,
620
  output_attentions: Optional[bool] = False,
621
  use_cache: Optional[bool] = False,
622
- padding_mask: Optional[torch.LongTensor] = None,
623
  ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
624
  """
625
  Args:
@@ -647,7 +453,6 @@ class SkyworkDecoderLayer(nn.Module):
647
  past_key_value=past_key_value,
648
  output_attentions=output_attentions,
649
  use_cache=use_cache,
650
- padding_mask=padding_mask,
651
  )
652
  hidden_states = residual + hidden_states
653
 
@@ -673,7 +478,6 @@ class SkyworkPreTrainedModel(PreTrainedModel):
673
  supports_gradient_checkpointing = True
674
  _no_split_modules = ["SkyworkDecoderLayer"]
675
  _skip_keys_device_placement = "past_key_values"
676
- _supports_flash_attn_2 = True
677
 
678
  def _init_weights(self, module):
679
  std = self.config.initializer_range
@@ -763,13 +567,13 @@ class SkyworkModel(SkyworkPreTrainedModel):
763
 
764
  # retrieve input_ids and inputs_embeds
765
  if input_ids is not None and inputs_embeds is not None:
766
- raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
767
  elif input_ids is not None:
768
  batch_size, seq_length = input_ids.shape
769
  elif inputs_embeds is not None:
770
  batch_size, seq_length, _ = inputs_embeds.shape
771
  else:
772
- raise ValueError("You have to specify either input_ids or inputs_embeds")
773
 
774
  seq_length_with_past = seq_length
775
  past_key_values_length = 0
@@ -783,7 +587,9 @@ class SkyworkModel(SkyworkPreTrainedModel):
783
  position_ids = torch.arange(
784
  past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
785
  )
786
- position_ids = position_ids.unsqueeze(0)
 
 
787
 
788
  if inputs_embeds is None:
789
  inputs_embeds = self.embed_tokens(input_ids)
@@ -792,13 +598,6 @@ class SkyworkModel(SkyworkPreTrainedModel):
792
  attention_mask = torch.ones(
793
  (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device
794
  )
795
- padding_mask = None
796
- else:
797
- if 0 in attention_mask:
798
- padding_mask = attention_mask
799
- else:
800
- padding_mask = None
801
-
802
  attention_mask = self._prepare_decoder_attention_mask(
803
  attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
804
  )
@@ -828,12 +627,15 @@ class SkyworkModel(SkyworkPreTrainedModel):
828
  def create_custom_forward(module):
829
  def custom_forward(*inputs):
830
  # None for past_key_value
831
- return module(*inputs, past_key_value, output_attentions, padding_mask=padding_mask)
832
 
833
  return custom_forward
834
 
835
  layer_outputs = torch.utils.checkpoint.checkpoint(
836
- create_custom_forward(decoder_layer), hidden_states, attention_mask, position_ids
 
 
 
837
  )
838
  else:
839
  layer_outputs = decoder_layer(
@@ -843,7 +645,6 @@ class SkyworkModel(SkyworkPreTrainedModel):
843
  past_key_value=past_key_value,
844
  output_attentions=output_attentions,
845
  use_cache=use_cache,
846
- padding_mask=padding_mask,
847
  )
848
 
849
  hidden_states = layer_outputs[0]
@@ -901,7 +702,6 @@ class SkyworkForCausalLM(SkyworkPreTrainedModel):
901
  def get_decoder(self):
902
  return self.model
903
 
904
- @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
905
  def forward(
906
  self,
907
  input_ids: torch.LongTensor = None,
@@ -915,31 +715,6 @@ class SkyworkForCausalLM(SkyworkPreTrainedModel):
915
  output_hidden_states: Optional[bool] = None,
916
  return_dict: Optional[bool] = None,
917
  ) -> Union[Tuple, CausalLMOutputWithPast]:
918
- r"""
919
- Args:
920
- labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
921
- Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
922
- config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
923
- (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
924
-
925
- Returns:
926
-
927
- Example:
928
-
929
- ```python
930
- >>> from transformers import AutoTokenizer, SkyworkForCausalLM
931
-
932
- >>> model = SkyworkForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
933
- >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
934
-
935
- >>> prompt = "Hey, are you conscious? Can you talk to me?"
936
- >>> inputs = tokenizer(prompt, return_tensors="pt")
937
-
938
- >>> # Generate
939
- >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
940
- >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
941
- "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
942
- ```"""
943
 
944
  output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
945
  output_hidden_states = (
@@ -1033,6 +808,7 @@ class SkyworkForCausalLM(SkyworkPreTrainedModel):
1033
  )
1034
  return reordered_past
1035
 
 
1036
  class SkyworkForSequenceClassification(SkyworkPreTrainedModel):
1037
  def __init__(self, config):
1038
  super().__init__(config)
@@ -1062,12 +838,8 @@ class SkyworkForSequenceClassification(SkyworkPreTrainedModel):
1062
  output_hidden_states: Optional[bool] = None,
1063
  return_dict: Optional[bool] = None,
1064
  ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
1065
- r"""
1066
- labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1067
- Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
1068
- config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
1069
- `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
1070
- """
1071
  return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1072
 
1073
  transformer_outputs = self.model(
@@ -1136,4 +908,4 @@ class SkyworkForSequenceClassification(SkyworkPreTrainedModel):
1136
  past_key_values=transformer_outputs.past_key_values,
1137
  hidden_states=transformer_outputs.hidden_states,
1138
  attentions=transformer_outputs.attentions,
1139
- )
 
1
  # Copyright (c) SkyworkAI and the HuggingFace Inc. team. All rights reserved.
2
  # This code is built upon Huggingface's transformers repository.
3
+
4
  import math
5
  from typing import List, Optional, Tuple, Union
6
 
 
13
  from transformers.activations import ACT2FN
14
  from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast
15
  from transformers.modeling_utils import PreTrainedModel
16
+ from transformers.utils import logging
 
 
 
 
 
 
 
17
  from .configuration_skywork import SkyworkConfig
18
 
19
 
 
 
 
 
 
20
  logger = logging.get_logger(__name__)
21
 
22
  _CONFIG_FOR_DOC = "SkyworkConfig"
23
 
24
 
 
 
 
 
 
 
 
 
 
 
 
 
25
  # Copied from transformers.models.bart.modeling_bart._make_causal_mask
26
  def _make_causal_mask(
27
  input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0
 
72
  return self.weight * hidden_states.to(input_dtype)
73
 
74
 
75
+ class SkyworkRotaryEmbedding(torch.nn.Module):
 
 
 
76
  def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
77
  super().__init__()
78
 
 
94
  freqs = torch.einsum("i,j->ij", t, self.inv_freq)
95
  # Different from paper, but it uses a different permutation in order to obtain the same calculation
96
  emb = torch.cat((freqs, freqs), dim=-1)
97
+ self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)
98
+ self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)
99
 
100
  def forward(self, x, seq_len=None):
101
  # x: [bs, num_attention_heads, seq_len, head_size]
 
103
  self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
104
 
105
  return (
106
+ self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
107
+ self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
108
  )
109
 
110
 
 
123
  freqs = torch.einsum("i,j->ij", t, self.inv_freq)
124
  # Different from paper, but it uses a different permutation in order to obtain the same calculation
125
  emb = torch.cat((freqs, freqs), dim=-1)
126
+ self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)
127
+ self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)
128
 
129
 
130
  class SkyworkDynamicNTKScalingRotaryEmbedding(SkyworkRotaryEmbedding):
 
149
  freqs = torch.einsum("i,j->ij", t, self.inv_freq)
150
  # Different from paper, but it uses a different permutation in order to obtain the same calculation
151
  emb = torch.cat((freqs, freqs), dim=-1)
152
+ self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)
153
+ self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)
154
 
155
 
 
 
 
 
 
 
 
156
 
157
class SkyworkNTKScalingRotaryEmbedding(torch.nn.Module):
    """Rotary position embedding whose frequency base is multiplied by a fixed factor.

    The effective base is ``base * scaling_factor``; the inverse frequencies and
    the cos/sin lookup tables are derived from that enlarged base. The tables
    are stored as non-persistent buffers (``cos_cached`` / ``sin_cached``), so
    they are not written to the state dict.

    Args:
        dim: Size of the rotary dimension; frequencies are built for every
            second index in ``range(0, dim, 2)``.
        max_position_embeddings: Number of positions precomputed at construction.
        base: Rotary base before scaling.
        scaling_factor: Multiplier applied to ``base``.
        device: Device on which the inverse frequencies are created.
    """

    def __init__(self, dim, max_position_embeddings=2048, base=10000, scaling_factor=100, device=None):
        super().__init__()

        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base * scaling_factor
        inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)

        # Build here to make `torch.jit.trace` work.
        self._set_cos_sin_cache(
            seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()
        )

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        """(Re)build the cos/sin tables for positions ``0 .. seq_len - 1``."""
        self.max_seq_len_cached = seq_len
        t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
        # Each frequency is duplicated so the table spans the full rotary dim.
        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)
        self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)

    def forward(self, x, seq_len=None):
        """Return the ``(cos, sin)`` tables truncated to ``seq_len`` positions.

        Args:
            x: Tensor whose device and dtype the returned tables follow.
            seq_len: Number of positions required. When omitted it is taken
                from ``x.shape[-2]`` (assumes ``x`` is laid out as
                ``[bs, num_heads, seq_len, head_dim]``, as the sibling rotary
                class documents -- TODO confirm for other callers).

        Returns:
            Tuple ``(cos, sin)``, each of shape ``[1, 1, seq_len, dim]`` and
            cast to ``x.dtype``.
        """
        if seq_len is None:
            # The signature advertises None as the default, but comparing None
            # with an int raises TypeError; fall back to the length of `x`.
            seq_len = x.shape[-2]

        # Grow the cache lazily when a longer sequence than ever seen arrives.
        if seq_len > self.max_seq_len_cached:
            self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)

        return (
            self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
            self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
        )
188
 
189
  def rotate_half(x):
190
  """Rotates half the hidden dims of the input."""
 
193
  return torch.cat((-x2, x1), dim=-1)
194
 
195
 
 
196
def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
    """Rotate the query and key states with position-indexed cos/sin tables.

    ``cos`` and ``sin`` arrive with two leading singleton axes; these are
    squeezed away so that ``position_ids`` can index the position axis
    directly, and a head axis of size one is re-inserted for broadcasting.

    Returns:
        Tuple ``(q_embed, k_embed)`` holding the rotated query and key states.
    """
    # Drop the two leading singleton axes -> [seq_len, dim].
    cos_table = cos.squeeze(1).squeeze(0)
    sin_table = sin.squeeze(1).squeeze(0)

    # Select one row per position, then add a broadcastable head axis
    # -> [bs, 1, seq_len, dim].
    cos_at_pos = cos_table[position_ids].unsqueeze(1)
    sin_at_pos = sin_table[position_ids].unsqueeze(1)

    q_rotated = rotate_half(q) * sin_at_pos
    k_rotated = rotate_half(k) * sin_at_pos
    q_embed = (q * cos_at_pos) + q_rotated
    k_embed = (k * cos_at_pos) + k_rotated
    return q_embed, k_embed
 
269
  f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
270
  f" and `num_heads`: {self.num_heads})."
271
  )
272
+ self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
273
+ self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
274
+ self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
275
+ self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
276
  self._init_rope()
277
 
278
  def _init_rope(self):
 
308
  )
309
  else:
310
  raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
311
+ print('-'*80)
312
+ print(f"USING COSTOM MODELING, scaling_type is {scaling_type}, scaling_factor is {scaling_factor}")
313
+
314
  def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
315
  return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
316
 
 
322
  past_key_value: Optional[Tuple[torch.Tensor]] = None,
323
  output_attentions: bool = False,
324
  use_cache: bool = False,
 
325
  ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
326
  bsz, q_len, _ = hidden_states.size()
327
 
 
364
 
365
  past_key_value = (key_states, value_states) if use_cache else None
366
 
367
+ # repeat k/v heads if n_kv_heads < n_heads
368
  key_states = repeat_kv(key_states, self.num_key_value_groups)
369
  value_states = repeat_kv(value_states, self.num_key_value_groups)
370
 
 
394
  )
395
 
396
  attn_output = attn_output.transpose(1, 2).contiguous()
 
397
  attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
398
 
399
  if self.config.pretraining_tp > 1:
 
409
  return attn_output, attn_weights, past_key_value
410
 
411
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
412
  class SkyworkDecoderLayer(nn.Module):
413
  def __init__(self, config: SkyworkConfig):
414
  super().__init__()
415
  self.hidden_size = config.hidden_size
416
+ self.self_attn = SkyworkAttention(config=config)
 
 
 
 
417
  self.mlp = SkyworkMLP(config)
418
  self.input_layernorm = SkyworkRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
419
  self.post_attention_layernorm = SkyworkRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
 
426
  past_key_value: Optional[Tuple[torch.Tensor]] = None,
427
  output_attentions: Optional[bool] = False,
428
  use_cache: Optional[bool] = False,
 
429
  ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
430
  """
431
  Args:
 
453
  past_key_value=past_key_value,
454
  output_attentions=output_attentions,
455
  use_cache=use_cache,
 
456
  )
457
  hidden_states = residual + hidden_states
458
 
 
478
  supports_gradient_checkpointing = True
479
  _no_split_modules = ["SkyworkDecoderLayer"]
480
  _skip_keys_device_placement = "past_key_values"
 
481
 
482
  def _init_weights(self, module):
483
  std = self.config.initializer_range
 
567
 
568
  # retrieve input_ids and inputs_embeds
569
  if input_ids is not None and inputs_embeds is not None:
570
+ raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
571
  elif input_ids is not None:
572
  batch_size, seq_length = input_ids.shape
573
  elif inputs_embeds is not None:
574
  batch_size, seq_length, _ = inputs_embeds.shape
575
  else:
576
+ raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
577
 
578
  seq_length_with_past = seq_length
579
  past_key_values_length = 0
 
587
  position_ids = torch.arange(
588
  past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
589
  )
590
+ position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
591
+ else:
592
+ position_ids = position_ids.view(-1, seq_length).long()
593
 
594
  if inputs_embeds is None:
595
  inputs_embeds = self.embed_tokens(input_ids)
 
598
  attention_mask = torch.ones(
599
  (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device
600
  )
 
 
 
 
 
 
 
601
  attention_mask = self._prepare_decoder_attention_mask(
602
  attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
603
  )
 
627
  def create_custom_forward(module):
628
  def custom_forward(*inputs):
629
  # None for past_key_value
630
+ return module(*inputs, past_key_value, output_attentions)
631
 
632
  return custom_forward
633
 
634
  layer_outputs = torch.utils.checkpoint.checkpoint(
635
+ create_custom_forward(decoder_layer),
636
+ hidden_states,
637
+ attention_mask,
638
+ position_ids,
639
  )
640
  else:
641
  layer_outputs = decoder_layer(
 
645
  past_key_value=past_key_value,
646
  output_attentions=output_attentions,
647
  use_cache=use_cache,
 
648
  )
649
 
650
  hidden_states = layer_outputs[0]
 
702
  def get_decoder(self):
703
  return self.model
704
 
 
705
  def forward(
706
  self,
707
  input_ids: torch.LongTensor = None,
 
715
  output_hidden_states: Optional[bool] = None,
716
  return_dict: Optional[bool] = None,
717
  ) -> Union[Tuple, CausalLMOutputWithPast]:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
718
 
719
  output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
720
  output_hidden_states = (
 
808
  )
809
  return reordered_past
810
 
811
+
812
  class SkyworkForSequenceClassification(SkyworkPreTrainedModel):
813
  def __init__(self, config):
814
  super().__init__(config)
 
838
  output_hidden_states: Optional[bool] = None,
839
  return_dict: Optional[bool] = None,
840
  ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
841
+
842
+
 
 
 
 
843
  return_dict = return_dict if return_dict is not None else self.config.use_return_dict
844
 
845
  transformer_outputs = self.model(
 
908
  past_key_values=transformer_outputs.past_key_values,
909
  hidden_states=transformer_outputs.hidden_states,
910
  attentions=transformer_outputs.attentions,
911
+ )
tokenization_skywork.py CHANGED
@@ -1,22 +1,5 @@
1
- # coding=utf-8
2
- # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
3
- #
4
- # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
5
- # and OPT implementations in this library. It has been modified from its
6
- # original forms to accommodate minor architectural differences compared
7
- # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
8
- #
9
- # Licensed under the Apache License, Version 2.0 (the "License");
10
- # you may not use this file except in compliance with the License.
11
- # You may obtain a copy of the License at
12
- #
13
- # http://www.apache.org/licenses/LICENSE-2.0
14
- #
15
- # Unless required by applicable law or agreed to in writing, software
16
- # distributed under the License is distributed on an "AS IS" BASIS,
17
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18
- # See the License for the specific language governing permissions and
19
- # limitations under the License.
20
 
21
  """Tokenization classes for Skywork."""
22
  import os
 
1
+ # Copyright (c) SkyworkAI and the HuggingFace Inc. team. All rights reserved.
2
+ # This code is built upon Huggingface's transformers repository.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3
 
4
  """Tokenization classes for Skywork."""
5
  import os