Pedro Cuenca commited on
Commit
adfe05e
·
1 Parent(s): cb2ac60

Add dalle_mini directory module.

Browse files

It hosts a copy of VQGAN-JAX.


Former-commit-id: b859c49e7e9d8728c93882ce11ffdb137630de33

app/dalle_mini/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ __version__ = "0.0.1"
app/dalle_mini/dataset.py ADDED
@@ -0,0 +1,122 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ An image-caption dataset dataloader.
3
+ Luke Melas-Kyriazi, 2021
4
+ """
5
+ import warnings
6
+ from typing import Optional, Callable
7
+ from pathlib import Path
8
+ import numpy as np
9
+ import torch
10
+ import pandas as pd
11
+ from torch.utils.data import Dataset
12
+ from torchvision.datasets.folder import default_loader
13
+ from PIL import ImageFile
14
+ from PIL.Image import DecompressionBombWarning
15
+ ImageFile.LOAD_TRUNCATED_IMAGES = True
16
+ warnings.filterwarnings("ignore", category=UserWarning)
17
+ warnings.filterwarnings("ignore", category=DecompressionBombWarning)
18
+
19
+
20
+ class CaptionDataset(Dataset):
21
+ """
22
+ A PyTorch Dataset class for (image, texts) tasks. Note that this dataset
23
+ returns the raw text rather than tokens. This is done on purpose, because
24
+ it's easy to tokenize a batch of text after loading it from this dataset.
25
+ """
26
+
27
+ def __init__(self, *, images_root: str, captions_path: str, text_transform: Optional[Callable] = None,
28
+ image_transform: Optional[Callable] = None, image_transform_type: str = 'torchvision',
29
+ include_captions: bool = True):
30
+ """
31
+ :param images_root: folder where images are stored
32
+ :param captions_path: path to csv that maps image filenames to captions
33
+ :param image_transform: image transform pipeline
34
+ :param text_transform: image transform pipeline
35
+ :param image_transform_type: image transform type, either `torchvision` or `albumentations`
36
+ :param include_captions: Returns a dictionary with `image`, `text` if `true`; otherwise returns just the images.
37
+ """
38
+
39
+ # Base path for images
40
+ self.images_root = Path(images_root)
41
+
42
+ # Load captions as DataFrame
43
+ self.captions = pd.read_csv(captions_path, delimiter='\t', header=0)
44
+ self.captions['image_file'] = self.captions['image_file'].astype(str)
45
+
46
+ # PyTorch transformation pipeline for the image (normalizing, etc.)
47
+ self.text_transform = text_transform
48
+ self.image_transform = image_transform
49
+ self.image_transform_type = image_transform_type.lower()
50
+ assert self.image_transform_type in ['torchvision', 'albumentations']
51
+
52
+ # Total number of datapoints
53
+ self.size = len(self.captions)
54
+
55
+ # Return image+captions or just images
56
+ self.include_captions = include_captions
57
+
58
+ def verify_that_all_images_exist(self):
59
+ for image_file in self.captions['image_file']:
60
+ p = self.images_root / image_file
61
+ if not p.is_file():
62
+ print(f'file does not exist: {p}')
63
+
64
+ def _get_raw_image(self, i):
65
+ image_file = self.captions.iloc[i]['image_file']
66
+ image_path = self.images_root / image_file
67
+ image = default_loader(image_path)
68
+ return image
69
+
70
+ def _get_raw_text(self, i):
71
+ return self.captions.iloc[i]['caption']
72
+
73
+ def __getitem__(self, i):
74
+ image = self._get_raw_image(i)
75
+ caption = self._get_raw_text(i)
76
+ if self.image_transform is not None:
77
+ if self.image_transform_type == 'torchvision':
78
+ image = self.image_transform(image)
79
+ elif self.image_transform_type == 'albumentations':
80
+ image = self.image_transform(image=np.array(image))['image']
81
+ else:
82
+ raise NotImplementedError(f"{self.image_transform_type=}")
83
+ return {'image': image, 'text': caption} if self.include_captions else image
84
+
85
+ def __len__(self):
86
+ return self.size
87
+
88
+
89
+ if __name__ == "__main__":
90
+ import albumentations as A
91
+ from albumentations.pytorch import ToTensorV2
92
+ from transformers import AutoTokenizer
93
+
94
+ # Paths
95
+ images_root = './images'
96
+ captions_path = './images-list-clean.tsv'
97
+
98
+ # Create transforms
99
+ tokenizer = AutoTokenizer.from_pretrained('distilroberta-base')
100
+ def tokenize(text):
101
+ return tokenizer(text, max_length=32, truncation=True, return_tensors='pt', padding='max_length')
102
+ image_transform = A.Compose([
103
+ A.Resize(256, 256), A.CenterCrop(256, 256),
104
+ A.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]), ToTensorV2()])
105
+
106
+ # Create dataset
107
+ dataset = CaptionDataset(
108
+ images_root=images_root,
109
+ captions_path=captions_path,
110
+ image_transform=image_transform,
111
+ text_transform=tokenize,
112
+ image_transform_type='albumentations')
113
+
114
+ # Create dataloader
115
+ dataloader = torch.utils.data.DataLoader(dataset, batch_size=2)
116
+ batch = next(iter(dataloader))
117
+ print({k: (v.shape if isinstance(v, torch.Tensor) else v) for k, v in batch.items()})
118
+
119
+ # # (Optional) Check that all the images exist
120
+ # dataset = CaptionDataset(images_root=images_root, captions_path=captions_path)
121
+ # dataset.verify_that_all_images_exist()
122
+ # print('Done')
app/dalle_mini/vqgan_jax/__init__.py ADDED
File without changes
app/dalle_mini/vqgan_jax/configuration_vqgan.py ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Tuple
2
+
3
+ from transformers import PretrainedConfig
4
+
5
+
6
+ class VQGANConfig(PretrainedConfig):
7
+ def __init__(
8
+ self,
9
+ ch: int = 128,
10
+ out_ch: int = 3,
11
+ in_channels: int = 3,
12
+ num_res_blocks: int = 2,
13
+ resolution: int = 256,
14
+ z_channels: int = 256,
15
+ ch_mult: Tuple = (1, 1, 2, 2, 4),
16
+ attn_resolutions: int = (16,),
17
+ n_embed: int = 1024,
18
+ embed_dim: int = 256,
19
+ dropout: float = 0.0,
20
+ double_z: bool = False,
21
+ resamp_with_conv: bool = True,
22
+ give_pre_end: bool = False,
23
+ **kwargs,
24
+ ):
25
+ super().__init__(**kwargs)
26
+ self.ch = ch
27
+ self.out_ch = out_ch
28
+ self.in_channels = in_channels
29
+ self.num_res_blocks = num_res_blocks
30
+ self.resolution = resolution
31
+ self.z_channels = z_channels
32
+ self.ch_mult = list(ch_mult)
33
+ self.attn_resolutions = list(attn_resolutions)
34
+ self.n_embed = n_embed
35
+ self.embed_dim = embed_dim
36
+ self.dropout = dropout
37
+ self.double_z = double_z
38
+ self.resamp_with_conv = resamp_with_conv
39
+ self.give_pre_end = give_pre_end
40
+ self.num_resolutions = len(ch_mult)
app/dalle_mini/vqgan_jax/convert_pt_model_to_jax.py ADDED
@@ -0,0 +1,109 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+
3
+ import jax.numpy as jnp
4
+ from flax.traverse_util import flatten_dict, unflatten_dict
5
+
6
+ import torch
7
+
8
+ from modeling_flax_vqgan import VQModel
9
+ from configuration_vqgan import VQGANConfig
10
+
11
+
12
+ regex = r"\w+[.]\d+"
13
+
14
+
15
+ def rename_key(key):
16
+ pats = re.findall(regex, key)
17
+ for pat in pats:
18
+ key = key.replace(pat, "_".join(pat.split(".")))
19
+ return key
20
+
21
+
22
+ # Adapted from https://github.com/huggingface/transformers/blob/ff5cdc086be1e0c3e2bbad8e3469b34cffb55a85/src/transformers/modeling_flax_pytorch_utils.py#L61
23
+ def convert_pytorch_state_dict_to_flax(pt_state_dict, flax_model):
24
+ # convert pytorch tensor to numpy
25
+ pt_state_dict = {k: v.numpy() for k, v in pt_state_dict.items()}
26
+
27
+ random_flax_state_dict = flatten_dict(flax_model.params)
28
+ flax_state_dict = {}
29
+
30
+ remove_base_model_prefix = (flax_model.base_model_prefix not in flax_model.params) and (
31
+ flax_model.base_model_prefix in set([k.split(".")[0] for k in pt_state_dict.keys()])
32
+ )
33
+ add_base_model_prefix = (flax_model.base_model_prefix in flax_model.params) and (
34
+ flax_model.base_model_prefix not in set([k.split(".")[0] for k in pt_state_dict.keys()])
35
+ )
36
+
37
+ # Need to change some parameters name to match Flax names so that we don't have to fork any layer
38
+ for pt_key, pt_tensor in pt_state_dict.items():
39
+ pt_tuple_key = tuple(pt_key.split("."))
40
+
41
+ has_base_model_prefix = pt_tuple_key[0] == flax_model.base_model_prefix
42
+ require_base_model_prefix = (flax_model.base_model_prefix,) + pt_tuple_key in random_flax_state_dict
43
+
44
+ if remove_base_model_prefix and has_base_model_prefix:
45
+ pt_tuple_key = pt_tuple_key[1:]
46
+ elif add_base_model_prefix and require_base_model_prefix:
47
+ pt_tuple_key = (flax_model.base_model_prefix,) + pt_tuple_key
48
+
49
+ # Correctly rename weight parameters
50
+ if (
51
+ "norm" in pt_key
52
+ and (pt_tuple_key[-1] == "bias")
53
+ and (pt_tuple_key[:-1] + ("bias",) in random_flax_state_dict)
54
+ ):
55
+ pt_tensor = pt_tensor[None, None, None, :]
56
+ elif (
57
+ "norm" in pt_key
58
+ and (pt_tuple_key[-1] == "bias")
59
+ and (pt_tuple_key[:-1] + ("scale",) in random_flax_state_dict)
60
+ ):
61
+ pt_tuple_key = pt_tuple_key[:-1] + ("scale",)
62
+ pt_tensor = pt_tensor[None, None, None, :]
63
+ elif pt_tuple_key[-1] in ["weight", "gamma"] and pt_tuple_key[:-1] + ("scale",) in random_flax_state_dict:
64
+ pt_tuple_key = pt_tuple_key[:-1] + ("scale",)
65
+ pt_tensor = pt_tensor[None, None, None, :]
66
+ if pt_tuple_key[-1] == "weight" and pt_tuple_key[:-1] + ("embedding",) in random_flax_state_dict:
67
+ pt_tuple_key = pt_tuple_key[:-1] + ("embedding",)
68
+ elif pt_tuple_key[-1] == "weight" and pt_tensor.ndim == 4 and pt_tuple_key not in random_flax_state_dict:
69
+ # conv layer
70
+ pt_tuple_key = pt_tuple_key[:-1] + ("kernel",)
71
+ pt_tensor = pt_tensor.transpose(2, 3, 1, 0)
72
+ elif pt_tuple_key[-1] == "weight" and pt_tuple_key not in random_flax_state_dict:
73
+ # linear layer
74
+ pt_tuple_key = pt_tuple_key[:-1] + ("kernel",)
75
+ pt_tensor = pt_tensor.T
76
+ elif pt_tuple_key[-1] == "gamma":
77
+ pt_tuple_key = pt_tuple_key[:-1] + ("weight",)
78
+ elif pt_tuple_key[-1] == "beta":
79
+ pt_tuple_key = pt_tuple_key[:-1] + ("bias",)
80
+
81
+ if pt_tuple_key in random_flax_state_dict:
82
+ if pt_tensor.shape != random_flax_state_dict[pt_tuple_key].shape:
83
+ raise ValueError(
84
+ f"PyTorch checkpoint seems to be incorrect. Weight {pt_key} was expected to be of shape "
85
+ f"{random_flax_state_dict[pt_tuple_key].shape}, but is {pt_tensor.shape}."
86
+ )
87
+
88
+ # also add unexpected weight so that warning is thrown
89
+ flax_state_dict[pt_tuple_key] = jnp.asarray(pt_tensor)
90
+
91
+ return unflatten_dict(flax_state_dict)
92
+
93
+
94
+ def convert_model(config_path, pt_state_dict_path, save_path):
95
+ config = VQGANConfig.from_pretrained(config_path)
96
+ model = VQModel(config)
97
+
98
+ state_dict = torch.load(pt_state_dict_path, map_location="cpu")["state_dict"]
99
+ keys = list(state_dict.keys())
100
+ for key in keys:
101
+ if key.startswith("loss"):
102
+ state_dict.pop(key)
103
+ continue
104
+ renamed_key = rename_key(key)
105
+ state_dict[renamed_key] = state_dict.pop(key)
106
+
107
+ state = convert_pytorch_state_dict_to_flax(state_dict, model)
108
+ model.params = unflatten_dict(state)
109
+ model.save_pretrained(save_path)
app/dalle_mini/vqgan_jax/modeling_flax_vqgan.py ADDED
@@ -0,0 +1,609 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # JAX implementation of VQGAN from taming-transformers https://github.com/CompVis/taming-transformers
2
+
3
+ from functools import partial
4
+ from typing import Tuple
5
+ import math
6
+
7
+ import jax
8
+ import jax.numpy as jnp
9
+ import numpy as np
10
+ import flax.linen as nn
11
+ from flax.core.frozen_dict import FrozenDict
12
+
13
+ from transformers.modeling_flax_utils import FlaxPreTrainedModel
14
+
15
+ from .configuration_vqgan import VQGANConfig
16
+
17
+
18
+ class Upsample(nn.Module):
19
+ in_channels: int
20
+ with_conv: bool
21
+ dtype: jnp.dtype = jnp.float32
22
+
23
+ def setup(self):
24
+ if self.with_conv:
25
+ self.conv = nn.Conv(
26
+ self.in_channels,
27
+ kernel_size=(3, 3),
28
+ strides=(1, 1),
29
+ padding=((1, 1), (1, 1)),
30
+ dtype=self.dtype,
31
+ )
32
+
33
+ def __call__(self, hidden_states):
34
+ batch, height, width, channels = hidden_states.shape
35
+ hidden_states = jax.image.resize(
36
+ hidden_states,
37
+ shape=(batch, height * 2, width * 2, channels),
38
+ method="nearest",
39
+ )
40
+ if self.with_conv:
41
+ hidden_states = self.conv(hidden_states)
42
+ return hidden_states
43
+
44
+
45
+ class Downsample(nn.Module):
46
+ in_channels: int
47
+ with_conv: bool
48
+ dtype: jnp.dtype = jnp.float32
49
+
50
+ def setup(self):
51
+ if self.with_conv:
52
+ self.conv = nn.Conv(
53
+ self.in_channels,
54
+ kernel_size=(3, 3),
55
+ strides=(2, 2),
56
+ padding="VALID",
57
+ dtype=self.dtype,
58
+ )
59
+
60
+ def __call__(self, hidden_states):
61
+ if self.with_conv:
62
+ pad = ((0, 0), (0, 1), (0, 1), (0, 0)) # pad height and width dim
63
+ hidden_states = jnp.pad(hidden_states, pad_width=pad)
64
+ hidden_states = self.conv(hidden_states)
65
+ else:
66
+ hidden_states = nn.avg_pool(hidden_states, window_shape=(2, 2), strides=(2, 2), padding="VALID")
67
+ return hidden_states
68
+
69
+
70
+ class ResnetBlock(nn.Module):
71
+ in_channels: int
72
+ out_channels: int = None
73
+ use_conv_shortcut: bool = False
74
+ temb_channels: int = 512
75
+ dropout_prob: float = 0.0
76
+ dtype: jnp.dtype = jnp.float32
77
+
78
+ def setup(self):
79
+ self.out_channels_ = self.in_channels if self.out_channels is None else self.out_channels
80
+
81
+ self.norm1 = nn.GroupNorm(num_groups=32, epsilon=1e-6)
82
+ self.conv1 = nn.Conv(
83
+ self.out_channels_,
84
+ kernel_size=(3, 3),
85
+ strides=(1, 1),
86
+ padding=((1, 1), (1, 1)),
87
+ dtype=self.dtype,
88
+ )
89
+
90
+ if self.temb_channels:
91
+ self.temb_proj = nn.Dense(self.out_channels_, dtype=self.dtype)
92
+
93
+ self.norm2 = nn.GroupNorm(num_groups=32, epsilon=1e-6)
94
+ self.dropout = nn.Dropout(self.dropout_prob)
95
+ self.conv2 = nn.Conv(
96
+ self.out_channels_,
97
+ kernel_size=(3, 3),
98
+ strides=(1, 1),
99
+ padding=((1, 1), (1, 1)),
100
+ dtype=self.dtype,
101
+ )
102
+
103
+ if self.in_channels != self.out_channels_:
104
+ if self.use_conv_shortcut:
105
+ self.conv_shortcut = nn.Conv(
106
+ self.out_channels_,
107
+ kernel_size=(3, 3),
108
+ strides=(1, 1),
109
+ padding=((1, 1), (1, 1)),
110
+ dtype=self.dtype,
111
+ )
112
+ else:
113
+ self.nin_shortcut = nn.Conv(
114
+ self.out_channels_,
115
+ kernel_size=(1, 1),
116
+ strides=(1, 1),
117
+ padding="VALID",
118
+ dtype=self.dtype,
119
+ )
120
+
121
+ def __call__(self, hidden_states, temb=None, deterministic: bool = True):
122
+ residual = hidden_states
123
+ hidden_states = self.norm1(hidden_states)
124
+ hidden_states = nn.swish(hidden_states)
125
+ hidden_states = self.conv1(hidden_states)
126
+
127
+ if temb is not None:
128
+ hidden_states = hidden_states + self.temb_proj(nn.swish(temb))[:, :, None, None] # TODO: check shapes
129
+
130
+ hidden_states = self.norm2(hidden_states)
131
+ hidden_states = nn.swish(hidden_states)
132
+ hidden_states = self.dropout(hidden_states, deterministic)
133
+ hidden_states = self.conv2(hidden_states)
134
+
135
+ if self.in_channels != self.out_channels_:
136
+ if self.use_conv_shortcut:
137
+ residual = self.conv_shortcut(residual)
138
+ else:
139
+ residual = self.nin_shortcut(residual)
140
+
141
+ return hidden_states + residual
142
+
143
+
144
+ class AttnBlock(nn.Module):
145
+ in_channels: int
146
+ dtype: jnp.dtype = jnp.float32
147
+
148
+ def setup(self):
149
+ conv = partial(
150
+ nn.Conv, self.in_channels, kernel_size=(1, 1), strides=(1, 1), padding="VALID", dtype=self.dtype
151
+ )
152
+
153
+ self.norm = nn.GroupNorm(num_groups=32, epsilon=1e-6)
154
+ self.q, self.k, self.v = conv(), conv(), conv()
155
+ self.proj_out = conv()
156
+
157
+ def __call__(self, hidden_states):
158
+ residual = hidden_states
159
+ hidden_states = self.norm(hidden_states)
160
+
161
+ query = self.q(hidden_states)
162
+ key = self.k(hidden_states)
163
+ value = self.v(hidden_states)
164
+
165
+ # compute attentions
166
+ batch, height, width, channels = query.shape
167
+ query = query.reshape((batch, height * width, channels))
168
+ key = key.reshape((batch, height * width, channels))
169
+ attn_weights = jnp.einsum("...qc,...kc->...qk", query, key)
170
+ attn_weights = attn_weights * (int(channels) ** -0.5)
171
+ attn_weights = nn.softmax(attn_weights, axis=2)
172
+
173
+ ## attend to values
174
+ value = value.reshape((batch, height * width, channels))
175
+ hidden_states = jnp.einsum("...kc,...qk->...qc", value, attn_weights)
176
+ hidden_states = hidden_states.reshape((batch, height, width, channels))
177
+
178
+ hidden_states = self.proj_out(hidden_states)
179
+ hidden_states = hidden_states + residual
180
+ return hidden_states
181
+
182
+
183
+ class UpsamplingBlock(nn.Module):
184
+ config: VQGANConfig
185
+ curr_res: int
186
+ block_idx: int
187
+ dtype: jnp.dtype = jnp.float32
188
+
189
+ def setup(self):
190
+ if self.block_idx == self.config.num_resolutions - 1:
191
+ block_in = self.config.ch * self.config.ch_mult[-1]
192
+ else:
193
+ block_in = self.config.ch * self.config.ch_mult[self.block_idx + 1]
194
+
195
+ block_out = self.config.ch * self.config.ch_mult[self.block_idx]
196
+ self.temb_ch = 0
197
+
198
+ res_blocks = []
199
+ attn_blocks = []
200
+ for _ in range(self.config.num_res_blocks + 1):
201
+ res_blocks.append(
202
+ ResnetBlock(
203
+ block_in, block_out, temb_channels=self.temb_ch, dropout_prob=self.config.dropout, dtype=self.dtype
204
+ )
205
+ )
206
+ block_in = block_out
207
+ if self.curr_res in self.config.attn_resolutions:
208
+ attn_blocks.append(AttnBlock(block_in, dtype=self.dtype))
209
+
210
+ self.block = res_blocks
211
+ self.attn = attn_blocks
212
+
213
+ self.upsample = None
214
+ if self.block_idx != 0:
215
+ self.upsample = Upsample(block_in, self.config.resamp_with_conv, dtype=self.dtype)
216
+
217
+ def __call__(self, hidden_states, temb=None, deterministic: bool = True):
218
+ for res_block in self.block:
219
+ hidden_states = res_block(hidden_states, temb, deterministic=deterministic)
220
+ for attn_block in self.attn:
221
+ hidden_states = attn_block(hidden_states)
222
+
223
+ if self.upsample is not None:
224
+ hidden_states = self.upsample(hidden_states)
225
+
226
+ return hidden_states
227
+
228
+
229
+ class DownsamplingBlock(nn.Module):
230
+ config: VQGANConfig
231
+ curr_res: int
232
+ block_idx: int
233
+ dtype: jnp.dtype = jnp.float32
234
+
235
+ def setup(self):
236
+ in_ch_mult = (1,) + tuple(self.config.ch_mult)
237
+ block_in = self.config.ch * in_ch_mult[self.block_idx]
238
+ block_out = self.config.ch * self.config.ch_mult[self.block_idx]
239
+ self.temb_ch = 0
240
+
241
+ res_blocks = []
242
+ attn_blocks = []
243
+ for _ in range(self.config.num_res_blocks):
244
+ res_blocks.append(
245
+ ResnetBlock(
246
+ block_in, block_out, temb_channels=self.temb_ch, dropout_prob=self.config.dropout, dtype=self.dtype
247
+ )
248
+ )
249
+ block_in = block_out
250
+ if self.curr_res in self.config.attn_resolutions:
251
+ attn_blocks.append(AttnBlock(block_in, dtype=self.dtype))
252
+
253
+ self.block = res_blocks
254
+ self.attn = attn_blocks
255
+
256
+ self.downsample = None
257
+ if self.block_idx != self.config.num_resolutions - 1:
258
+ self.downsample = Downsample(block_in, self.config.resamp_with_conv, dtype=self.dtype)
259
+
260
+ def __call__(self, hidden_states, temb=None, deterministic: bool = True):
261
+ for res_block in self.block:
262
+ hidden_states = res_block(hidden_states, temb, deterministic=deterministic)
263
+ for attn_block in self.attn:
264
+ hidden_states = attn_block(hidden_states)
265
+
266
+ if self.downsample is not None:
267
+ hidden_states = self.downsample(hidden_states)
268
+
269
+ return hidden_states
270
+
271
+
272
+ class MidBlock(nn.Module):
273
+ in_channels: int
274
+ temb_channels: int
275
+ dropout: float
276
+ dtype: jnp.dtype = jnp.float32
277
+
278
+ def setup(self):
279
+ self.block_1 = ResnetBlock(
280
+ self.in_channels,
281
+ self.in_channels,
282
+ temb_channels=self.temb_channels,
283
+ dropout_prob=self.dropout,
284
+ dtype=self.dtype,
285
+ )
286
+ self.attn_1 = AttnBlock(self.in_channels, dtype=self.dtype)
287
+ self.block_2 = ResnetBlock(
288
+ self.in_channels,
289
+ self.in_channels,
290
+ temb_channels=self.temb_channels,
291
+ dropout_prob=self.dropout,
292
+ dtype=self.dtype,
293
+ )
294
+
295
+ def __call__(self, hidden_states, temb=None, deterministic: bool = True):
296
+ hidden_states = self.block_1(hidden_states, temb, deterministic=deterministic)
297
+ hidden_states = self.attn_1(hidden_states)
298
+ hidden_states = self.block_2(hidden_states, temb, deterministic=deterministic)
299
+ return hidden_states
300
+
301
+
302
+ class Encoder(nn.Module):
303
+ config: VQGANConfig
304
+ dtype: jnp.dtype = jnp.float32
305
+
306
+ def setup(self):
307
+ self.temb_ch = 0
308
+
309
+ # downsampling
310
+ self.conv_in = nn.Conv(
311
+ self.config.ch,
312
+ kernel_size=(3, 3),
313
+ strides=(1, 1),
314
+ padding=((1, 1), (1, 1)),
315
+ dtype=self.dtype,
316
+ )
317
+
318
+ curr_res = self.config.resolution
319
+ downsample_blocks = []
320
+ for i_level in range(self.config.num_resolutions):
321
+ downsample_blocks.append(DownsamplingBlock(self.config, curr_res, block_idx=i_level, dtype=self.dtype))
322
+
323
+ if i_level != self.config.num_resolutions - 1:
324
+ curr_res = curr_res // 2
325
+ self.down = downsample_blocks
326
+
327
+ # middle
328
+ mid_channels = self.config.ch * self.config.ch_mult[-1]
329
+ self.mid = MidBlock(mid_channels, self.temb_ch, self.config.dropout, dtype=self.dtype)
330
+
331
+ # end
332
+ self.norm_out = nn.GroupNorm(num_groups=32, epsilon=1e-6)
333
+ self.conv_out = nn.Conv(
334
+ 2 * self.config.z_channels if self.config.double_z else self.config.z_channels,
335
+ kernel_size=(3, 3),
336
+ strides=(1, 1),
337
+ padding=((1, 1), (1, 1)),
338
+ dtype=self.dtype,
339
+ )
340
+
341
+ def __call__(self, pixel_values, deterministic: bool = True):
342
+ # timestep embedding
343
+ temb = None
344
+
345
+ # downsampling
346
+ hidden_states = self.conv_in(pixel_values)
347
+ for block in self.down:
348
+ hidden_states = block(hidden_states, temb, deterministic=deterministic)
349
+
350
+ # middle
351
+ hidden_states = self.mid(hidden_states, temb, deterministic=deterministic)
352
+
353
+ # end
354
+ hidden_states = self.norm_out(hidden_states)
355
+ hidden_states = nn.swish(hidden_states)
356
+ hidden_states = self.conv_out(hidden_states)
357
+
358
+ return hidden_states
359
+
360
+
361
+ class Decoder(nn.Module):
362
+ config: VQGANConfig
363
+ dtype: jnp.dtype = jnp.float32
364
+
365
+ def setup(self):
366
+ self.temb_ch = 0
367
+
368
+ # compute in_ch_mult, block_in and curr_res at lowest res
369
+ block_in = self.config.ch * self.config.ch_mult[self.config.num_resolutions - 1]
370
+ curr_res = self.config.resolution // 2 ** (self.config.num_resolutions - 1)
371
+ self.z_shape = (1, self.config.z_channels, curr_res, curr_res)
372
+ print("Working with z of shape {} = {} dimensions.".format(self.z_shape, np.prod(self.z_shape)))
373
+
374
+ # z to block_in
375
+ self.conv_in = nn.Conv(
376
+ block_in,
377
+ kernel_size=(3, 3),
378
+ strides=(1, 1),
379
+ padding=((1, 1), (1, 1)),
380
+ dtype=self.dtype,
381
+ )
382
+
383
+ # middle
384
+ self.mid = MidBlock(block_in, self.temb_ch, self.config.dropout, dtype=self.dtype)
385
+
386
+ # upsampling
387
+ upsample_blocks = []
388
+ for i_level in reversed(range(self.config.num_resolutions)):
389
+ upsample_blocks.append(UpsamplingBlock(self.config, curr_res, block_idx=i_level, dtype=self.dtype))
390
+ if i_level != 0:
391
+ curr_res = curr_res * 2
392
+ self.up = list(reversed(upsample_blocks)) # reverse to get consistent order
393
+
394
+ # end
395
+ self.norm_out = nn.GroupNorm(num_groups=32, epsilon=1e-6)
396
+ self.conv_out = nn.Conv(
397
+ self.config.out_ch,
398
+ kernel_size=(3, 3),
399
+ strides=(1, 1),
400
+ padding=((1, 1), (1, 1)),
401
+ dtype=self.dtype,
402
+ )
403
+
404
+ def __call__(self, hidden_states, deterministic: bool = True):
405
+ # timestep embedding
406
+ temb = None
407
+
408
+ # z to block_in
409
+ hidden_states = self.conv_in(hidden_states)
410
+
411
+ # middle
412
+ hidden_states = self.mid(hidden_states, temb, deterministic=deterministic)
413
+
414
+ # upsampling
415
+ for block in reversed(self.up):
416
+ hidden_states = block(hidden_states, temb, deterministic=deterministic)
417
+
418
+ # end
419
+ if self.config.give_pre_end:
420
+ return hidden_states
421
+
422
+ hidden_states = self.norm_out(hidden_states)
423
+ hidden_states = nn.swish(hidden_states)
424
+ hidden_states = self.conv_out(hidden_states)
425
+
426
+ return hidden_states
427
+
428
+
429
+ class VectorQuantizer(nn.Module):
430
+ """
431
+ see https://github.com/MishaLaskin/vqvae/blob/d761a999e2267766400dc646d82d3ac3657771d4/models/quantizer.py
432
+ ____________________________________________
433
+ Discretization bottleneck part of the VQ-VAE.
434
+ Inputs:
435
+ - n_e : number of embeddings
436
+ - e_dim : dimension of embedding
437
+ - beta : commitment cost used in loss term, beta * ||z_e(x)-sg[e]||^2
438
+ _____________________________________________
439
+ """
440
+
441
+ config: VQGANConfig
442
+ dtype: jnp.dtype = jnp.float32
443
+
444
+ def setup(self):
445
+ self.embedding = nn.Embed(self.config.n_embed, self.config.embed_dim, dtype=self.dtype) # TODO: init
446
+
447
+ def __call__(self, hidden_states):
448
+ """
449
+ Inputs the output of the encoder network z and maps it to a discrete
450
+ one-hot vector that is the index of the closest embedding vector e_j
451
+ z (continuous) -> z_q (discrete)
452
+ z.shape = (batch, channel, height, width)
453
+ quantization pipeline:
454
+ 1. get encoder input (B,C,H,W)
455
+ 2. flatten input to (B*H*W,C)
456
+ """
457
+ # flatten
458
+ hidden_states_flattended = hidden_states.reshape((-1, self.config.embed_dim))
459
+
460
+ # dummy op to init the weights, so we can access them below
461
+ self.embedding(jnp.ones((1, 1), dtype="i4"))
462
+
463
+ # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
464
+ emb_weights = self.variables["params"]["embedding"]["embedding"]
465
+ distance = (
466
+ jnp.sum(hidden_states_flattended ** 2, axis=1, keepdims=True)
467
+ + jnp.sum(emb_weights ** 2, axis=1)
468
+ - 2 * jnp.dot(hidden_states_flattended, emb_weights.T)
469
+ )
470
+
471
+ # get quantized latent vectors
472
+ min_encoding_indices = jnp.argmin(distance, axis=1)
473
+ z_q = self.embedding(min_encoding_indices).reshape(hidden_states.shape)
474
+
475
+ # reshape to (batch, num_tokens)
476
+ min_encoding_indices = min_encoding_indices.reshape(hidden_states.shape[0], -1)
477
+
478
+ # compute the codebook_loss (q_loss) outside the model
479
+ # here we return the embeddings and indices
480
+ return z_q, min_encoding_indices
481
+
482
+ def get_codebook_entry(self, indices, shape=None):
483
+ # indices are expected to be of shape (batch, num_tokens)
484
+ # get quantized latent vectors
485
+ batch, num_tokens = indices.shape
486
+ z_q = self.embedding(indices)
487
+ z_q = z_q.reshape(batch, int(math.sqrt(num_tokens)), int(math.sqrt(num_tokens)), -1)
488
+ return z_q
489
+
490
+
491
+ class VQModule(nn.Module):
492
+ config: VQGANConfig
493
+ dtype: jnp.dtype = jnp.float32
494
+
495
+ def setup(self):
496
+ self.encoder = Encoder(self.config, dtype=self.dtype)
497
+ self.decoder = Decoder(self.config, dtype=self.dtype)
498
+ self.quantize = VectorQuantizer(self.config, dtype=self.dtype)
499
+ self.quant_conv = nn.Conv(
500
+ self.config.embed_dim,
501
+ kernel_size=(1, 1),
502
+ strides=(1, 1),
503
+ padding="VALID",
504
+ dtype=self.dtype,
505
+ )
506
+ self.post_quant_conv = nn.Conv(
507
+ self.config.z_channels,
508
+ kernel_size=(1, 1),
509
+ strides=(1, 1),
510
+ padding="VALID",
511
+ dtype=self.dtype,
512
+ )
513
+
514
+ def encode(self, pixel_values, deterministic: bool = True):
515
+ hidden_states = self.encoder(pixel_values, deterministic=deterministic)
516
+ hidden_states = self.quant_conv(hidden_states)
517
+ quant_states, indices = self.quantize(hidden_states)
518
+ return quant_states, indices
519
+
520
+ def decode(self, hidden_states, deterministic: bool = True):
521
+ hidden_states = self.post_quant_conv(hidden_states)
522
+ hidden_states = self.decoder(hidden_states, deterministic=deterministic)
523
+ return hidden_states
524
+
525
+ def decode_code(self, code_b):
526
+ hidden_states = self.quantize.get_codebook_entry(code_b)
527
+ hidden_states = self.decode(hidden_states)
528
+ return hidden_states
529
+
530
+ def __call__(self, pixel_values, deterministic: bool = True):
531
+ quant_states, indices = self.encode(pixel_values, deterministic)
532
+ hidden_states = self.decode(quant_states, deterministic)
533
+ return hidden_states, indices
534
+
535
+
536
+ class VQGANPreTrainedModel(FlaxPreTrainedModel):
537
+ """
538
+ An abstract class to handle weights initialization and a simple interface
539
+ for downloading and loading pretrained models.
540
+ """
541
+
542
+ config_class = VQGANConfig
543
+ base_model_prefix = "model"
544
+ module_class: nn.Module = None
545
+
546
+ def __init__(
547
+ self,
548
+ config: VQGANConfig,
549
+ input_shape: Tuple = (1, 256, 256, 3),
550
+ seed: int = 0,
551
+ dtype: jnp.dtype = jnp.float32,
552
+ **kwargs,
553
+ ):
554
+ module = self.module_class(config=config, dtype=dtype, **kwargs)
555
+ super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)
556
+
557
+ def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:
558
+ # init input tensors
559
+ pixel_values = jnp.zeros(input_shape, dtype=jnp.float32)
560
+ params_rng, dropout_rng = jax.random.split(rng)
561
+ rngs = {"params": params_rng, "dropout": dropout_rng}
562
+
563
+ return self.module.init(rngs, pixel_values)["params"]
564
+
565
+ def encode(self, pixel_values, params: dict = None, dropout_rng: jax.random.PRNGKey = None, train: bool = False):
566
+ # Handle any PRNG if needed
567
+ rngs = {"dropout": dropout_rng} if dropout_rng is not None else {}
568
+
569
+ return self.module.apply(
570
+ {"params": params or self.params}, jnp.array(pixel_values), not train, rngs=rngs, method=self.module.encode
571
+ )
572
+
573
+ def decode(self, hidden_states, params: dict = None, dropout_rng: jax.random.PRNGKey = None, train: bool = False):
574
+ # Handle any PRNG if needed
575
+ rngs = {"dropout": dropout_rng} if dropout_rng is not None else {}
576
+
577
+ return self.module.apply(
578
+ {"params": params or self.params},
579
+ jnp.array(hidden_states),
580
+ not train,
581
+ rngs=rngs,
582
+ method=self.module.decode,
583
+ )
584
+
585
+ def decode_code(self, indices, params: dict = None):
586
+ return self.module.apply(
587
+ {"params": params or self.params}, jnp.array(indices, dtype="i4"), method=self.module.decode_code
588
+ )
589
+
590
+ def __call__(
591
+ self,
592
+ pixel_values,
593
+ params: dict = None,
594
+ dropout_rng: jax.random.PRNGKey = None,
595
+ train: bool = False,
596
+ ):
597
+ # Handle any PRNG if needed
598
+ rngs = {"dropout": dropout_rng} if dropout_rng is not None else {}
599
+
600
+ return self.module.apply(
601
+ {"params": params or self.params},
602
+ jnp.array(pixel_values),
603
+ not train,
604
+ rngs=rngs,
605
+ )
606
+
607
+
608
+ class VQModel(VQGANPreTrainedModel):
609
+ module_class = VQModule