# custom_llama/modeling_custom_llama.py

from typing import Tuple, Type

import torch.nn as nn
from transformers import LlamaForCausalLM, LlamaModel, LlamaPreTrainedModel

from .configuration_custom_llama import MyLlamaConfig

def apply_to_all_named_modules(module: nn.Module, fn, parent_name: str = ""):
    '''Recursively applies a function to all named modules in a PyTorch module.'''
    # Recurse through children with their instance names
    for name, child in module.named_children():
        # Construct the full dotted name path for the current child
        full_name = parent_name + ("." if parent_name else "") + name
        # Call fn with the parent module, the child's attribute name and the child itself,
        # so that fn can replace the child via setattr(parent, name, ...)
        fn(full_name, module, name, child)
        # Recurse into the child module
        apply_to_all_named_modules(child, fn, full_name)

def print_model_layers(model: nn.Module):
    '''Recursively prints the variable names of all layers in a PyTorch model and their types.'''
    apply_to_all_named_modules(
        model,
        lambda full_name, module, name, child: print(f"{full_name}: {child.__class__.__name__}")
    )

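# Example usage (hypothetical; assumes a local Llama checkpoint is available):
#   model = LlamaForCausalLM.from_pretrained("path/to/llama-checkpoint")
#   print_model_layers(model)   # e.g. prints "model.layers.0.self_attn.q_proj: Linear"
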
def replace_module_by_class_and_name(module: nn.Module,
                                     target_class: str,
                                     target_name: str,
                                     replacement_class: Type[nn.Module],
                                     other_init_args: Tuple = ()):
    '''
    Replace every submodule whose class name is `target_class` and whose
    attribute (instance) name is `target_name` with `replacement_class`.
    '''
    # Callback used to swap a matching module for the replacement module
    def replace_module_by_class_and_name_fn(full_name, module, name, child):
        # print(f"{full_name}: {child.__class__.__name__}")
        # If the child matches both the class name and the attribute name, replace it
        if name == target_name and child.__class__.__name__ == target_class:
            print("Replacing:", target_class, replacement_class)
            # Initialize the replacement from the original layer (e.g. the original attention projection)
            setattr(module, name, replacement_class(child, *other_init_args))

    # Recursively apply the replacement callback to all named modules
    apply_to_all_named_modules(
        module,
        replace_module_by_class_and_name_fn,
    )

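# Example (hypothetical): wrap every decoder block's 'v_proj' Linear with MyLinear
# (defined below), passing the projection's output width for the RMSNorm:
#   replace_module_by_class_and_name(model, 'Linear', 'v_proj', MyLinear, (512,))
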
class MyLinear(nn.Linear):
    # We subclass nn.Linear because _create_new_module in tuners/ia3/model.py (PEFT)
    # is hard-coded to only accept a target_base_layer that is a torch.nn.Linear.
    # Initializing the parent as a tiny (1, 1) Linear (and taking out_features as a
    # separate argument) adds (almost) no parameters: although we inherit from
    # nn.Linear, its own weight is never used.
    def __init__(self, old_linear: nn.Linear, out_features):
        super().__init__(1, 1, bias=False)
        self.linear = old_linear
        # nn.RMSNorm requires PyTorch >= 2.4
        self.rms_norm = nn.RMSNorm(out_features, eps=1e-6)

    def forward(self, x):
        # Apply the wrapped projection, then RMS-normalize its output
        return self.rms_norm(self.linear(x))

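# Quick sanity check (hypothetical shapes; requires `import torch`):
#   proj = nn.Linear(2048, 512, bias=False)
#   wrapped = MyLinear(proj, 512)
#   y = wrapped(torch.randn(1, 8, 2048))   # -> shape (1, 8, 512), RMS-normalized over the last dim
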
class CustomLlamaModel(LlamaModel):
    config_class = MyLlamaConfig

    def __init__(self, config):
        super().__init__(config)
        # Replace the 'q_proj' and 'k_proj' Linear layers with MyLinear. The RMSNorm widths
        # are hard-coded to the projection output sizes of the target config
        # (q_proj -> 2048 features, k_proj -> 512 features).
        replace_module_by_class_and_name(self.layers, 'Linear', 'q_proj', MyLinear, (2048,))
        replace_module_by_class_and_name(self.layers, 'Linear', 'k_proj', MyLinear, (512,))
        # Initialize weights and apply final processing
        self.post_init()

    def apply_custom_modifications(self):
        # NOTE: as written this method is a no-op. It only defines a local copy of the
        # module-level replace_module_by_class_and_name and never calls it; the actual
        # replacement happens in __init__ above.
        def replace_module_by_class_and_name(module: nn.Module,
                                             target_class: str,
                                             target_name: str,
                                             replacement_class: Type[nn.Module],
                                             other_init_args: Tuple = ()):
            def replace_module_by_class_and_name_fn(full_name, module, name, child):
                if name == target_name and child.__class__.__name__ == target_class:
                    setattr(module, name, replacement_class(child, *other_init_args))

            apply_to_all_named_modules(module, replace_module_by_class_and_name_fn)

class CustomLlamaForCausalLM(LlamaForCausalLM):
    config_class = MyLlamaConfig

    def __init__(self, config):
        super().__init__(config)
        self.model = CustomLlamaModel(config)
        self.post_init()

    def save_checkpoint(self, output_dir):
        # No-op override to bypass the save_checkpoint call around line 2291 of
        # transformers.trainer (in the transformers version this was written against).
        pass
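
# Hypothetical loading sketch (the model_type string and checkpoint path below are
# assumptions, not defined in this file):
#   from transformers import AutoConfig, AutoModelForCausalLM
#   AutoConfig.register("my_llama", MyLlamaConfig)            # must match MyLlamaConfig.model_type
#   AutoModelForCausalLM.register(MyLlamaConfig, CustomLlamaForCausalLM)
#   model = AutoModelForCausalLM.from_pretrained("path/to/custom_llama_checkpoint")
# Alternatively, a Hub checkpoint that ships this file can be loaded with the standard
# `trust_remote_code=True` path instead of manual registration.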