Conversion script
May I ask if there is some kind of script you used to convert the model into msgpack format?
If so, could you provide it so we can also convert models like Llama 3?
I have a ~60 line Python script that does its best to match the weights from the PyTorch version, then confirms output agreement with the PyTorch version, then saves the model in msgpack.
I'll try to get back to you in a day or two with the code itself.
That would be highly appreciated. Thanks!
Hi, here's the conversion script, let me know if you run into any problems:
from __future__ import annotations
from typing import Any
from jax import Array, tree, dlpack
from torch import Tensor
def _get_key(
    key: str, params: dict[str, Any], delimiter: str = "."
) -> tuple[str, Any] | Array | None:
    """Recursively try to find a key in a nested dictionary."""
    if params is None:
        return None
    if delimiter in key:
        first, rest = key.split(delimiter, 1)
        return _get_key(rest, params[first]) if first in params else None
    else:
        return params[key] if key in params else (key, params)

def _set_key(key: str, params: dict[str, Any], weight: Array, delimiter: str = ".") -> None:
    """Recursively set a key in a nested dictionary."""
    if delimiter in key:
        first, rest = key.split(delimiter, 1)
        return _set_key(rest, params[first], weight)
    else:
        params[key] = weight
def match_params(
    jax_params: dict[str, Any],
    torch_named_parameters: dict[str, Tensor],
    dtype=None,
    transpose_square_weights: bool = True,
    verbose: bool = False,
    warn: bool = True,
) -> dict[str, Any]:
    """Match parameters of a PyTorch model into a JAX model.
    Args:
        jax_params (dict[str, Any]): JAX Flax parameter dictionary.
        torch_named_parameters (dict[str, Tensor]): PyTorch model.named_parameters()
        dtype (optional): Dtype to cast the weights to. Defaults to None.
        transpose_square_weights (bool, optional): Whether to transpose square weights. Defaults to True.
        verbose (bool, optional): Print shapes when a weight cannot be matched. Defaults to False.
        warn (bool, optional): Print a warning for keys that cannot be found or matched. Defaults to True.
    Returns:
        dict[str, Any]: Copy of jax_params with the matched PyTorch weights filled in.
    """
    new_params = tree.map(lambda x: x, jax_params)  # make a shallow copy of the parameter tree
    for key, val in dict(torch_named_parameters).items():
        found = _get_key(key, jax_params)
        if isinstance(found, tuple):
            # inexact match, but maybe there is just one child weight
            if len(found[1]) == 1:
                new_suffix_key, weight = list(found[1].items())[0]
                key = ".".join(key.split(".")[:-1] + [new_suffix_key])
            else:
                if warn:
                    print(f"Could not find {key}")
                continue
        else:
            weight = found
        # convert the Torch tensor to a JAX array via DLPack
        val = dlpack.from_dlpack(val.detach())
        if dtype is not None:
            val = val.astype(dtype)
        # weight matching works as follows:
        # 1. check if the weight is square, if so, JAX is likely to have the Torch weight transposed
        # 2. check if the Torch weight shape is the JAX weight transposed
        # 3. check if the Torch weight has the same shape as the JAX weight
        if weight.ndim == 2 and tuple(weight.shape) == tuple(reversed(weight.shape)):
            _set_key(key, new_params, val.T if transpose_square_weights else val)
        elif weight.ndim == 2 == val.ndim and weight.shape == tuple(reversed(val.shape)):
            _set_key(key, new_params, val.T)
        elif weight.shape == val.shape:
            _set_key(key, new_params, val)
        else:
            if warn:
                print(f"Could not match {key}")
            if verbose:
                print(f"key = {key}, val = {val.shape}, found = {weight.shape}")
    return new_params
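The msgpack step isn't in the snippet above; for that part something like this should work (a minimal sketch using flax.serialization, where matched_params stands for the dictionary returned by match_params and the path is just a placeholder):

from flax import serialization

def save_params_msgpack(params: dict, path: str) -> None:
    # serialize the nested parameter dictionary to msgpack bytes and write them to disk
    with open(path, "wb") as f:
        f.write(serialization.msgpack_serialize(params))

def load_params_msgpack(path: str) -> dict:
    # restore the nested parameter dictionary from the msgpack file
    with open(path, "rb") as f:
        return serialization.msgpack_restore(f.read())

save_params_msgpack(matched_params, "model_params.msgpack")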
Thanks, I will try it out in the coming days and let you know :)
Hey sorry, but I am not 100% sure how to kick off the script.
I tried
model = AutoModelForCausalLM.from_pretrained("mistralai/Mistral-7B-Instruct-v0.1", low_cpu_mem_usage=True, device_map="cpu")
### convert the model to JAX
### get the model parameters
params_torch = dict(model.named_parameters())
### JAX model parameters (empty)
params_jax = dict()
### match the parameters
params_jax = match_params(params_jax, params_torch)
but this didn't work. Could you maybe provide me a little hint on how to initialize the jax dict? thanks :)
Oh, I see!
The main problem is that I do not attempt to construct an equivalent JAX model (as far as I can tell that's somewhat involved).
We can match the weights between two existing implementations of the same model, one in PyTorch and one in (for example) Flax.
So something like this should work:
from transformers import AutoModelForCausalLM
from transformers import FlaxMistralForCausalLM
jax_model = FlaxMistralForCausalLM.from_pretrained(flax_model_name)
torch_model = AutoModelForCausalLM.from_pretrained(torch_model_name)
matched_params_jax = match_params(jax_model.params, torch_model.named_parameters()) # we're NOT overwriting jax_model.params
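To confirm the match, you can compare the logits of the two models on the same input. A rough sketch, assuming the jax_model, torch_model, and matched_params_jax objects from the snippet above (the tokenizer name and prompt are just placeholders):

import numpy as np
import torch
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(torch_model_name)
inputs = tokenizer("Hello, world!", return_tensors="np")

# Flax forward pass using the matched parameters
jax_logits = jax_model(input_ids=inputs["input_ids"], params=matched_params_jax).logits

# PyTorch forward pass
with torch.no_grad():
    torch_logits = torch_model(input_ids=torch.from_numpy(inputs["input_ids"])).logits

# maximum absolute difference between the two sets of logits
err = np.abs(np.asarray(jax_logits) - torch_logits.float().numpy()).max()
print(f"max |logits_jax - logits_torch| = {err:.4e}")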
This is a bit of a shameless plug, but I have a Python library for calling PyTorch code from JAX with minimal overhead if that's another path you'd like to explore: https://github.com/rdyro/torch2jax/
The following transformation worked, thanks for the help and this nice script:
from transformers import AutoModelForCausalLM
from transformers import FlaxAutoModelForCausalLM
from transformers import AutoConfig
from flax.core.frozen_dict import freeze
# Pytorch model
model = AutoModelForCausalLM.from_pretrained("mistralai/Mistral-7B-Instruct-v0.1", low_cpu_mem_usage=True)#, device_map="cpu")
# Same architecture in Flax
config = AutoConfig.from_pretrained("mistralai/Mistral-7B-Instruct-v0.1")
flax_model = FlaxAutoModelForCausalLM.from_config(config=config)
# Match using the script
matched_params_jax = match_params(flax_model.params, model.named_parameters())
# Set to new params
flax_model.params = freeze(matched_params_jax)
If we check the numerical precision error as suggested in the repo, it's on the order of 1e-7.
Will try later whether we can also use this for Llama 3.
Update: Unfortunately, the script didn't work for Llama 3. We obtain
Error is numerical precision level: 7.9338e-02.
The numerical precision level might depend on the dtype.
Can you try the option of not transposing square weights and see whether the precision is better? The relevant argument is transpose_square_weights; pass transpose_square_weights=False.
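Concretely, something like this (reusing flax_model, model, and freeze from your earlier snippet):

# re-run the matching without transposing square weights
matched_params_jax = match_params(
    flax_model.params, model.named_parameters(), transpose_square_weights=False
)
flax_model.params = freeze(matched_params_jax)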
The numerical precision might also be affected by the way the network is built in PyTorch and JAX. I found that if the weights are actually mismatched, the error is usually much worse than ~1e-1.
Tried it out; that increased the error to ~1e0.
I will try out Llama 3 qualitatively at inference time in the next few days and report if there is anything strange.
Hmm, depending on the dtype, the 7.9338e-02 error might simply be due to compute-graph optimization differences.
I don't have any other ideas on how to fix it at the moment, let me know if there's anything else!