from typing import Optional, Tuple, Type

import torch.nn as nn
from transformers import LlamaForCausalLM, LlamaModel, LlamaPreTrainedModel

from .configuration_custom_llama import MyLlamaConfig


def apply_to_all_named_modules(module: nn.Module, fn, parent_name: str = ""):
    '''Recursively applies fn(full_name, parent_module, child_name, child) to every
    named submodule of a PyTorch module.'''
    for name, child in module.named_children():
        full_name = parent_name + ("." if parent_name else "") + name
        # Pass the parent module so the callback can rebind the child via setattr.
        fn(full_name, module, name, child)
        apply_to_all_named_modules(child, fn, full_name)


def print_model_layers(model: nn.Module):
    '''Recursively prints the qualified name and class of every layer in a PyTorch model.'''
    apply_to_all_named_modules(
        model,
        lambda full_name, module, name, child: print(f"{full_name}: {child.__class__.__name__}")
    )


def replace_module_by_class_and_name(module: nn.Module,
                                     target_class: str,
                                     target_name: str,
                                     replacement_class: Type[nn.Module],
                                     other_init_args: Tuple = ()):
    '''Replaces every submodule whose class name is target_class and whose attribute
    name is target_name with replacement_class(child, *other_init_args).'''

    def replace_module_by_class_and_name_fn(full_name, module, name, child):
        if name == target_name and child.__class__.__name__ == target_class:
            print("Replacing:", target_class, replacement_class)
            # Rebind the attribute on the parent module, wrapping the old child.
            setattr(module, name, replacement_class(child, *other_init_args))

    apply_to_all_named_modules(
        module,
        replace_module_by_class_and_name_fn,
    )


class MyLinear(nn.Linear):
    '''Wraps an existing nn.Linear and applies RMSNorm to its output.'''

    def __init__(self, old_linear: nn.Linear, out_features):
        # Dummy 1x1 parent init: subclassing nn.Linear keeps the wrapper compatible
        # with isinstance(nn.Linear) checks; the real projection is the wrapped
        # old_linear stored below.
        super().__init__(1, 1, bias=False)
        self.linear = old_linear
        self.rms_norm = nn.RMSNorm(out_features, eps=1e-6)

    def forward(self, x):
        return self.rms_norm(self.linear(x))


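# Minimal sketch of how the helpers above fit together: _demo_replace_on_toy_block
# is an illustrative function (not part of the original model code) that swaps a
# named Linear inside a toy module for MyLinear, mirroring what CustomLlamaModel
# does to q_proj/k_proj below. The module layout and sizes here are arbitrary.
def _demo_replace_on_toy_block():
    block = nn.ModuleDict({'q_proj': nn.Linear(32, 32, bias=False)})
    replace_module_by_class_and_name(block, 'Linear', 'q_proj', MyLinear, (32,))
    # Expected printout: q_proj: MyLinear, q_proj.linear: Linear, q_proj.rms_norm: RMSNorm
    print_model_layers(block)

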
class CustomLlamaModel(LlamaModel):
    config_class = MyLlamaConfig

    def __init__(self, config):
        super().__init__(config)

        # Swap each decoder layer's q_proj/k_proj Linear for a MyLinear wrapper that
        # RMS-normalizes the projection output; the hard-coded sizes are the
        # projections' output widths used to build each RMSNorm.
        replace_module_by_class_and_name(self.layers, 'Linear', 'q_proj', MyLinear, (2048,))
        replace_module_by_class_and_name(self.layers, 'Linear', 'k_proj', MyLinear, (512,))

        self.post_init()

    def apply_custom_modifications(self):
        # Applies the same projection swaps as __init__, reusing the module-level
        # replace_module_by_class_and_name helper. Already-wrapped projections are
        # left untouched, since their class name is 'MyLinear' rather than 'Linear'.
        replace_module_by_class_and_name(self.layers, 'Linear', 'q_proj', MyLinear, (2048,))
        replace_module_by_class_and_name(self.layers, 'Linear', 'k_proj', MyLinear, (512,))


class CustomLlamaForCausalLM(LlamaForCausalLM):
    config_class = MyLlamaConfig

    def __init__(self, config):
        super().__init__(config)
        # Replace the stock LlamaModel backbone with the customized one.
        self.model = CustomLlamaModel(config)
        self.post_init()

    def save_checkpoint(self, dir):
        # Placeholder: custom checkpointing logic is not implemented yet.
        pass
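

# Minimal end-to-end sketch (illustrative, not part of the original code):
# _demo_build_custom_model builds a small CustomLlamaForCausalLM and prints one
# decoder layer's attention block to show the swapped projections. It assumes
# MyLlamaConfig accepts LlamaConfig-style keyword arguments; the sizes are chosen
# so the query (2048) and key/value (512) projection widths match the hard-coded
# RMSNorm sizes in CustomLlamaModel.__init__.
def _demo_build_custom_model():
    config = MyLlamaConfig(
        hidden_size=2048,
        intermediate_size=4096,
        num_hidden_layers=2,
        num_attention_heads=32,
        num_key_value_heads=8,
        vocab_size=32000,
    )
    model = CustomLlamaForCausalLM(config)
    # q_proj and k_proj should now print as MyLinear; v_proj and o_proj stay Linear.
    print_model_layers(model.model.layers[0].self_attn)
    return model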