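"""Merge a LoRA adapter into its base model and save the result.

Example invocation (the script name and paths are illustrative):

    python merge.py \
        --base_model meta-llama/Llama-2-7b-chat-hf \
        --lora_model ./lora-checkpoint \
        --output ./merged-model \
        --custom_tokens ./custom_tokens.txt
"""
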
import argparse
import os

from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer

from utils import get_logger

logger = get_logger("merge", "info")

def smart_tokenizer_and_embedding_resize(tokenizer, model, custom_tokens_path=None):
    """Add special (and optional custom) tokens, then resize the model embeddings to match."""
    special_tokens_dict = {
        "pad_token": "[PAD]",
        "eos_token": "</s>",
        "bos_token": "<s>",
        "unk_token": "<unk>",
    }

    # Optionally read extra tokens from a file, one token per line, skipping blanks.
    custom_tokens = []
    if custom_tokens_path is not None:
        with open(custom_tokens_path, "r") as file:
            custom_tokens = [line.strip() for line in file if line.strip()]

    num_added_toks = tokenizer.add_special_tokens(special_tokens_dict)
    if custom_tokens:
        num_added_toks += tokenizer.add_tokens(custom_tokens, special_tokens=True)

    # The embedding matrix must grow to cover the enlarged vocabulary.
    model.resize_token_embeddings(len(tokenizer))
    logger.info(f"Resized tokenizer and model embeddings. Added {num_added_toks} tokens.")
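
# Expected --custom_tokens file format: plain text, one token per line.
# The token strings below are purely hypothetical examples:
#
#     <|user|>
#     <|assistant|>
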
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("-bm", "--base_model", type=str, default="meta-llama/Llama-2-7b-chat-hf", help="Base model name or path")
    parser.add_argument("-lm", "--lora_model", type=str, required=True, help="Path to the LoRA adapter directory")
    parser.add_argument("-o", "--output", type=str, required=True, help="Output directory for the merged model")
    parser.add_argument("--custom_tokens", type=str, default=None, help="Path to a file containing custom tokens, one per line")
    args = parser.parse_args()

    if not os.path.exists(args.lora_model):
        raise FileNotFoundError(f"LoRA adapter directory {args.lora_model} not found.")

    os.makedirs(args.output, exist_ok=True)

    model = AutoModelForCausalLM.from_pretrained(args.base_model)
    tokenizer = AutoTokenizer.from_pretrained(args.base_model)

    # Resize embeddings before attaching the adapter so the vocabulary size
    # matches what the adapter was trained against.
    smart_tokenizer_and_embedding_resize(tokenizer, model, args.custom_tokens)

    logger.info("Loading and merging the LoRA model...")
    # Attach the adapter, then fold its weights into the base model so the
    # result can be saved and loaded as a plain transformers checkpoint.
    lora_model = PeftModel.from_pretrained(model, args.lora_model)
    merged_model = lora_model.merge_and_unload()

    merged_model.save_pretrained(args.output)
    tokenizer.save_pretrained(args.output)
    logger.info(f"Merged model saved to {args.output}")


if __name__ == "__main__":
    main()
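
# A quick sanity check after merging might look like this (paths are
# illustrative, not part of this repo):
#
#     from transformers import AutoModelForCausalLM, AutoTokenizer
#     model = AutoModelForCausalLM.from_pretrained("./merged-model")
#     tokenizer = AutoTokenizer.from_pretrained("./merged-model")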