yuntaozh committed
Commit ab192f6 · 1 Parent(s): 11b63b5

Upload custom_llama

__init__.py ADDED
File without changes
config.json ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:763a8005c449434c842c55dd816386dd424f819cc65111c33e121ec5198bfa45
+ size 1107
configuration_custom_llama.py ADDED
@@ -0,0 +1,6 @@
+ from transformers.models.llama.configuration_llama import LlamaConfig
+
+ class MyLlamaConfig(LlamaConfig):
+     model_type = "custom_llama"
+     def __init__(self, **kwargs):
+         super().__init__(**kwargs)
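
For reference, a minimal sketch of how this config class can be exercised locally, assuming configuration_custom_llama.py is importable from the working directory; the directory ./custom_llama_cfg below is made up for illustration. Registering the "custom_llama" model type with AutoConfig lets a serialized config.json round-trip back into MyLlamaConfig:

from transformers import AutoConfig
from configuration_custom_llama import MyLlamaConfig

# Map model_type "custom_llama" to the config class (hypothetical local setup).
AutoConfig.register("custom_llama", MyLlamaConfig)

cfg = MyLlamaConfig(hidden_size=2048, num_key_value_heads=8)
cfg.save_pretrained("./custom_llama_cfg")  # writes config.json with "model_type": "custom_llama"
restored = AutoConfig.from_pretrained("./custom_llama_cfg")
assert isinstance(restored, MyLlamaConfig)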
model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:a5a338759f688af91eb687b7facdd90e7668acc77f4d99a49ca31e81ce3433ee
+ size 4943444632
modeling_custom_llama.py ADDED
@@ -0,0 +1,96 @@
+ from transformers import LlamaForCausalLM, LlamaPreTrainedModel, LlamaModel
+ import torch.nn as nn
+ from typing import Type, Optional, Tuple
+ from .configuration_custom_llama import MyLlamaConfig
+
+ def apply_to_all_named_modules(module: nn.Module, fn, parent_name: str = ""):
+     '''Recursively applies a function to all named modules in a PyTorch module.'''
+     # Recurse through children with their instance names
+     for name, child in module.named_children():
+         # Construct the full name path for the current module
+         full_name = parent_name + ("." if parent_name else "") + name
+         # Apply the function to the current child (the parent is passed so it can be mutated)
+         fn(full_name, module, name, child)
+         # Recurse into the child module
+         apply_to_all_named_modules(child, fn, full_name)
+
+ def print_model_layers(model: nn.Module):
+     '''Recursively prints the variable names of all layers in a PyTorch model and their types.'''
+     apply_to_all_named_modules(
+         model,
+         lambda full_name, module, name, child: print(f"{full_name}: {child.__class__.__name__}")
+     )
+
+ def replace_module_by_class_and_name(module: nn.Module,
+                                      target_class: str,
+                                      target_name: str,
+                                      replacement_class: Type[nn.Module],
+                                      other_init_args: Tuple = ()):
+     '''
+     Replace every submodule whose class name is target_class and whose instance name is target_name.
+     '''
+
+     # Helper function used to replace the target module with the replacement module
+     def replace_module_by_class_and_name_fn(full_name, module, name, child):
+         # print(f"{full_name}: {child.__class__.__name__}")
+         # If the current child matches the target class and instance name, replace it
+         if name == target_name and child.__class__.__name__ == target_class:
+             print("Replacing: ", target_class, replacement_class)
+             # Initialize the replacement from the original projection layer
+             setattr(module, name, replacement_class(child, *other_init_args))
+
+     # Recursively apply the replacement function to all named modules
+     apply_to_all_named_modules(
+         module,
+         replace_module_by_class_and_name_fn,
+     )
+
+ class MyLinear(nn.Linear):
+     # Inherit from nn.Linear because _create_new_module in peft's tuners/ia3/model.py only accepts a target_base_layer that is a torch.nn.Linear.
+     # The parent is initialized as (1, 1) so that (almost) no extra parameters are added; although we inherit from nn.Linear, its own weights are never used.
+     def __init__(self, old_linear: nn.Linear, out_features):
+         super().__init__(1, 1, bias=False)
+         self.linear = old_linear
+         self.rms_norm = nn.RMSNorm(out_features, eps=1e-6)  # nn.RMSNorm requires PyTorch >= 2.4
+
+     def forward(self, x):
+         return self.rms_norm(self.linear(x))
+
+ class CustomLlamaModel(LlamaModel):
+     config_class = MyLlamaConfig
+
+     def __init__(self, config):
+         super().__init__(config)
+         # Replace 'q_proj' and 'k_proj' layers with 'MyLinear'
+         replace_module_by_class_and_name(self.layers, 'Linear', 'q_proj', MyLinear, (2048,))
+         replace_module_by_class_and_name(self.layers, 'Linear', 'k_proj', MyLinear, (512,))
+         # Initialize weights and apply final processing
+         self.post_init()
+
+     def apply_custom_modifications(self):
+         # Note: this local helper shadows the module-level function of the same name and is never called here.
+         def replace_module_by_class_and_name(module: nn.Module,
+                                              target_class: str,
+                                              target_name: str,
+                                              replacement_class: Type[nn.Module],
+                                              other_init_args: Tuple = ()):
+             def replace_module_by_class_and_name_fn(full_name, module, name, child):
+                 if name == target_name and child.__class__.__name__ == target_class:
+                     setattr(module, name, replacement_class(child, *other_init_args))
+             apply_to_all_named_modules(module, replace_module_by_class_and_name_fn)
+
+ class CustomLlamaForCausalLM(LlamaForCausalLM):
+     config_class = MyLlamaConfig
+
+     def __init__(self, config):
+         super().__init__(config)
+         self.model = CustomLlamaModel(config)
+         self.post_init()
+
+     def save_checkpoint(self, dir):
+         # To bypass the code at line 2291 in transformers.trainer
+         pass
+
+
+
+
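
A minimal usage sketch for the classes above, under these assumptions: the repo is cloned locally as a package directory named custom_llama (it ships an __init__.py, which also keeps the relative import in modeling_custom_llama.py working), and the classes are registered with the Auto factories by hand rather than resolved through trust_remote_code. The path ./custom_llama is hypothetical.

import torch
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
from custom_llama.configuration_custom_llama import MyLlamaConfig
from custom_llama.modeling_custom_llama import CustomLlamaForCausalLM

# Tie model_type "custom_llama" to the custom config and model classes.
AutoConfig.register("custom_llama", MyLlamaConfig)
AutoModelForCausalLM.register(MyLlamaConfig, CustomLlamaForCausalLM)

# "./custom_llama" stands for a local clone of this repository (with LFS files fetched).
tokenizer = AutoTokenizer.from_pretrained("./custom_llama")
model = AutoModelForCausalLM.from_pretrained("./custom_llama", torch_dtype=torch.bfloat16)

inputs = tokenizer("Hello", return_tensors="pt")
with torch.no_grad():
    logits = model(**inputs).logits
print(logits.shape)  # (1, sequence_length, vocab_size)

The weights in model.safetensors then load by parameter name, including the extra rms_norm tensors inside each wrapped q_proj and k_proj. If config.json additionally carries an auto_map entry, the same classes could instead be resolved from the Hub with trust_remote_code=True; that cannot be confirmed here because config.json is stored as an LFS pointer.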
special_tokens_map.json ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:cc2e013b7545f183ef03e079a3c91c6f364fa37e4068c512d7dd843e59024535
+ size 301
tokenizer.json ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:6b9e4e7fb171f92fd137b777cc2714bf87d11576700a1dcd7a399e7bbe39537b
+ size 17209920
tokenizer_config.json ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:8004530facf809ac432114de2a4dcc65fcb632da5ec16d666091aeb6a2ee444a
+ size 50500