boris commited on
Commit
231a81a
·
1 Parent(s): cc3e85c

chore: file not needed

Browse files

Former-commit-id: 01a923196312ced88a2f0ca2010e793c26c84855

dalle_mini/vqgan_jax/convert_pt_model_to_jax.py DELETED
@@ -1,109 +0,0 @@
1
- import re
2
-
3
- import jax.numpy as jnp
4
- from flax.traverse_util import flatten_dict, unflatten_dict
5
-
6
- import torch
7
-
8
- from modeling_flax_vqgan import VQModel
9
- from configuration_vqgan import VQGANConfig
10
-
11
-
12
- regex = r"\w+[.]\d+"
13
-
14
-
15
- def rename_key(key):
16
- pats = re.findall(regex, key)
17
- for pat in pats:
18
- key = key.replace(pat, "_".join(pat.split(".")))
19
- return key
20
-
21
-
22
- # Adapted from https://github.com/huggingface/transformers/blob/ff5cdc086be1e0c3e2bbad8e3469b34cffb55a85/src/transformers/modeling_flax_pytorch_utils.py#L61
23
- def convert_pytorch_state_dict_to_flax(pt_state_dict, flax_model):
24
- # convert pytorch tensor to numpy
25
- pt_state_dict = {k: v.numpy() for k, v in pt_state_dict.items()}
26
-
27
- random_flax_state_dict = flatten_dict(flax_model.params)
28
- flax_state_dict = {}
29
-
30
- remove_base_model_prefix = (flax_model.base_model_prefix not in flax_model.params) and (
31
- flax_model.base_model_prefix in set([k.split(".")[0] for k in pt_state_dict.keys()])
32
- )
33
- add_base_model_prefix = (flax_model.base_model_prefix in flax_model.params) and (
34
- flax_model.base_model_prefix not in set([k.split(".")[0] for k in pt_state_dict.keys()])
35
- )
36
-
37
- # Need to change some parameters name to match Flax names so that we don't have to fork any layer
38
- for pt_key, pt_tensor in pt_state_dict.items():
39
- pt_tuple_key = tuple(pt_key.split("."))
40
-
41
- has_base_model_prefix = pt_tuple_key[0] == flax_model.base_model_prefix
42
- require_base_model_prefix = (flax_model.base_model_prefix,) + pt_tuple_key in random_flax_state_dict
43
-
44
- if remove_base_model_prefix and has_base_model_prefix:
45
- pt_tuple_key = pt_tuple_key[1:]
46
- elif add_base_model_prefix and require_base_model_prefix:
47
- pt_tuple_key = (flax_model.base_model_prefix,) + pt_tuple_key
48
-
49
- # Correctly rename weight parameters
50
- if (
51
- "norm" in pt_key
52
- and (pt_tuple_key[-1] == "bias")
53
- and (pt_tuple_key[:-1] + ("bias",) in random_flax_state_dict)
54
- ):
55
- pt_tensor = pt_tensor[None, None, None, :]
56
- elif (
57
- "norm" in pt_key
58
- and (pt_tuple_key[-1] == "bias")
59
- and (pt_tuple_key[:-1] + ("scale",) in random_flax_state_dict)
60
- ):
61
- pt_tuple_key = pt_tuple_key[:-1] + ("scale",)
62
- pt_tensor = pt_tensor[None, None, None, :]
63
- elif pt_tuple_key[-1] in ["weight", "gamma"] and pt_tuple_key[:-1] + ("scale",) in random_flax_state_dict:
64
- pt_tuple_key = pt_tuple_key[:-1] + ("scale",)
65
- pt_tensor = pt_tensor[None, None, None, :]
66
- if pt_tuple_key[-1] == "weight" and pt_tuple_key[:-1] + ("embedding",) in random_flax_state_dict:
67
- pt_tuple_key = pt_tuple_key[:-1] + ("embedding",)
68
- elif pt_tuple_key[-1] == "weight" and pt_tensor.ndim == 4 and pt_tuple_key not in random_flax_state_dict:
69
- # conv layer
70
- pt_tuple_key = pt_tuple_key[:-1] + ("kernel",)
71
- pt_tensor = pt_tensor.transpose(2, 3, 1, 0)
72
- elif pt_tuple_key[-1] == "weight" and pt_tuple_key not in random_flax_state_dict:
73
- # linear layer
74
- pt_tuple_key = pt_tuple_key[:-1] + ("kernel",)
75
- pt_tensor = pt_tensor.T
76
- elif pt_tuple_key[-1] == "gamma":
77
- pt_tuple_key = pt_tuple_key[:-1] + ("weight",)
78
- elif pt_tuple_key[-1] == "beta":
79
- pt_tuple_key = pt_tuple_key[:-1] + ("bias",)
80
-
81
- if pt_tuple_key in random_flax_state_dict:
82
- if pt_tensor.shape != random_flax_state_dict[pt_tuple_key].shape:
83
- raise ValueError(
84
- f"PyTorch checkpoint seems to be incorrect. Weight {pt_key} was expected to be of shape "
85
- f"{random_flax_state_dict[pt_tuple_key].shape}, but is {pt_tensor.shape}."
86
- )
87
-
88
- # also add unexpected weight so that warning is thrown
89
- flax_state_dict[pt_tuple_key] = jnp.asarray(pt_tensor)
90
-
91
- return unflatten_dict(flax_state_dict)
92
-
93
-
94
- def convert_model(config_path, pt_state_dict_path, save_path):
95
- config = VQGANConfig.from_pretrained(config_path)
96
- model = VQModel(config)
97
-
98
- state_dict = torch.load(pt_state_dict_path, map_location="cpu")["state_dict"]
99
- keys = list(state_dict.keys())
100
- for key in keys:
101
- if key.startswith("loss"):
102
- state_dict.pop(key)
103
- continue
104
- renamed_key = rename_key(key)
105
- state_dict[renamed_key] = state_dict.pop(key)
106
-
107
- state = convert_pytorch_state_dict_to_flax(state_dict, model)
108
- model.params = unflatten_dict(state)
109
- model.save_pretrained(save_path)