{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "ewer-Q-0w2xA"
},
"source": [
"# Installation"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "NpsF9ipLLl2s",
"outputId": "10bf54aa-b89d-4e42-9777-bc97b00a5f32"
},
"outputs": [],
"source": [
"#!pip install git+https://github.com/huggingface/transformers/\n",
"#!pip install git+https://github.com/google/flax"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"id": "M1wVkrpjU6zO"
},
"outputs": [],
"source": [
"%load_ext autoreload\n",
"%autoreload 2"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"/home/tmabraham/vqgan-jax\n"
]
}
],
"source": [
"%cd ../../vqgan-jax"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "t47CH1H_IOT8"
},
"source": [
"# Custom BART Model"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"id": "9jQnM6S2vCpn"
},
"outputs": [],
"source": [
"# TODO: set those args in a config file\n",
"OUTPUT_VOCAB_SIZE = 16384 + 1 # encoded image token space + 1 for bos\n",
"OUTPUT_LENGTH = 256 + 1 # number of encoded tokens + 1 for bos\n",
"BOS_TOKEN_ID = 16384\n",
"BASE_MODEL = 'facebook/bart-large'"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"id": "_eEaJVxAKpV5"
},
"outputs": [],
"source": [
"import jax\n",
"import flax.linen as nn\n",
"\n",
"from transformers.models.bart.modeling_flax_bart import *\n",
"from transformers import BartTokenizer, FlaxBartForConditionalGeneration\n",
"\n",
"class CustomFlaxBartModule(FlaxBartModule):\n",
" def setup(self):\n",
" # we keep shared to easily load pre-trained weights\n",
" self.shared = nn.Embed(\n",
" self.config.vocab_size,\n",
" self.config.d_model,\n",
" embedding_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),\n",
" dtype=self.dtype,\n",
" )\n",
" # a separate embedding is used for the decoder\n",
" self.decoder_embed = nn.Embed(\n",
" OUTPUT_VOCAB_SIZE,\n",
" self.config.d_model,\n",
" embedding_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),\n",
" dtype=self.dtype,\n",
" )\n",
" self.encoder = FlaxBartEncoder(self.config, dtype=self.dtype, embed_tokens=self.shared)\n",
"\n",
" # the decoder has a different config\n",
" decoder_config = BartConfig(self.config.to_dict())\n",
" decoder_config.max_position_embeddings = OUTPUT_LENGTH\n",
" decoder_config.vocab_size = OUTPUT_VOCAB_SIZE\n",
" self.decoder = FlaxBartDecoder(decoder_config, dtype=self.dtype, embed_tokens=self.decoder_embed)\n",
"\n",
"class CustomFlaxBartForConditionalGenerationModule(FlaxBartForConditionalGenerationModule):\n",
" def setup(self):\n",
" self.model = CustomFlaxBartModule(config=self.config, dtype=self.dtype)\n",
" self.lm_head = nn.Dense(\n",
" OUTPUT_VOCAB_SIZE,\n",
" use_bias=False,\n",
" dtype=self.dtype,\n",
" kernel_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),\n",
" )\n",
" self.final_logits_bias = self.param(\"final_logits_bias\", self.bias_init, (1, OUTPUT_VOCAB_SIZE))\n",
"\n",
"class CustomFlaxBartForConditionalGeneration(FlaxBartForConditionalGeneration):\n",
" module_class = CustomFlaxBartForConditionalGenerationModule"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"scrolled": true
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"\u001b[34m\u001b[1mwandb\u001b[0m: Currently logged in as: \u001b[33mtmabraham\u001b[0m (use `wandb login --relogin` to force relogin)\n"
]
},
{
"data": {
"text/html": [
"\n",
" Tracking run with wandb version 0.10.33
\n",
" Syncing run serene-resonance-1 to Weights & Biases (Documentation).
\n",
" Project page: https://wandb.ai/tmabraham/vqgan-jax
\n",
" Run page: https://wandb.ai/tmabraham/vqgan-jax/runs/1cm35ims
\n",
" Run data is saved locally in /home/tmabraham/vqgan-jax/wandb/run-20210715_030616-1cm35ims
\n",
" "
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\u001b[34m\u001b[1mwandb\u001b[0m: Downloading large artifact model-1ef8yxby:v1, 1674.97MB. 2 files... Done. 0:0:0\n"
]
}
],
"source": [
"import wandb\n",
"run = wandb.init()\n",
"artifact = run.use_artifact('wandb/hf-flax-dalle-mini/model-1ef8yxby:v1', type='bart_model')\n",
"artifact_dir = artifact.download()"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {
"id": "_6-XKK40oEfP",
"scrolled": true
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/tmabraham/dalle-mini/src/transformers/src/transformers/models/bart/configuration_bart.py:180: UserWarning: Please make sure the config includes `forced_bos_token_id=16384` in future versions.The config can simply be saved and uploaded again to be fixed.\n",
" warnings.warn(\n",
"INFO:absl:Starting the local TPU driver.\n",
"INFO:absl:Unable to initialize backend 'tpu_driver': Not found: Unable to find driver in registry given worker: local://\n",
"INFO:absl:Unable to initialize backend 'gpu': Not found: Could not find registered platform with name: \"cuda\". Available platform names are: TPU Interpreter Host\n"
]
}
],
"source": [
"# create our model and initialize it randomly\n",
"model = CustomFlaxBartForConditionalGeneration.from_pretrained(artifact_dir)"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "Jz032w73nHEf",
"outputId": "994d8e85-bff7-480b-8b69-f69dedc15c49"
},
"outputs": [
{
"data": {
"text/plain": [
"(1, 16385)"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# we verify that the shape has not been modified\n",
"model.params['final_logits_bias'].shape"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "zLl24Ez5t7x1"
},
"source": [
"## Inference"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {
"id": "XLLA2NK3uDQr"
},
"outputs": [],
"source": [
"tokenizer = BartTokenizer.from_pretrained(BASE_MODEL)"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {
"id": "P32mJJSbrU1F"
},
"outputs": [],
"source": [
"input_ids_test = tokenizer.encode('I enjoy walking with my cute dog', return_tensors='jax')"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"DeviceArray([[ 0, 100, 2254, 3051, 19, 127, 11962, 2335,\n",
" 2]], dtype=int32)"
]
},
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"input_ids_test"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {
"id": "C7cHbIHruELT"
},
"outputs": [],
"source": [
"greedy_output = model.generate(input_ids_test, max_length=257)"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "jYugh9cOuwc9",
"outputId": "19c3a2ee-e7bc-4f1f-9c86-06bd7337b537"
},
"outputs": [
{
"data": {
"text/plain": [
"DeviceArray([[16384, 16384, 10042, 10042, 10042, 10042, 10042, 10042,\n",
" 10042, 10042, 10042, 10042, 10042, 10042, 10042, 10042,\n",
" 10042, 10042, 10042, 10042, 10042, 10042, 10042, 10042,\n",
" 10042, 10042, 10042, 10042, 10042, 10042, 10042, 10042,\n",
" 10042, 10042, 10042, 10042, 10042, 10042, 10042, 10042,\n",
" 10042, 10042, 10042, 10042, 10042, 10042, 10042, 10042,\n",
" 10042, 10042, 10042, 10042, 10042, 10042, 10042, 10042,\n",
" 10042, 10042, 10042, 10042, 10042, 10042, 10042, 10042,\n",
" 10042, 10042, 10042, 10042, 10042, 10042, 10042, 10042,\n",
" 10042, 10042, 10042, 10042, 10042, 10042, 10042, 10042,\n",
" 10042, 10042, 10042, 10042, 10042, 10042, 10042, 10042,\n",
" 10042, 10042, 10042, 10042, 10042, 10042, 10042, 10042,\n",
" 10042, 10042, 10042, 10042, 10042, 10042, 10042, 10042,\n",
" 10042, 10042, 10042, 10042, 10042, 10042, 10042, 10042,\n",
" 10042, 10042, 10042, 10042, 10042, 10042, 10042, 10042,\n",
" 10042, 10042, 10042, 10042, 10042, 10042, 10042, 10042,\n",
" 10042, 10042, 10042, 10042, 10042, 10042, 10042, 10042,\n",
" 10042, 10042, 10042, 10042, 10042, 10042, 10042, 10042,\n",
" 10042, 10042, 10042, 10042, 10042, 10042, 10042, 10042,\n",
" 10042, 10042, 10042, 10042, 10042, 10042, 10042, 10042,\n",
" 10042, 10042, 10042, 10042, 10042, 10042, 10042, 10042,\n",
" 10042, 10042, 10042, 10042, 10042, 10042, 10042, 10042,\n",
" 10042, 10042, 10042, 10042, 10042, 10042, 10042, 10042,\n",
" 10042, 10042, 10042, 10042, 10042, 10042, 10042, 10042,\n",
" 10042, 10042, 10042, 10042, 10042, 10042, 10042, 10042,\n",
" 10042, 10042, 10042, 10042, 10042, 10042, 10042, 10042,\n",
" 10042, 10042, 10042, 10042, 10042, 10042, 10042, 10042,\n",
" 10042, 10042, 10042, 10042, 10042, 10042, 10042, 10042,\n",
" 10042, 10042, 10042, 10042, 10042, 10042, 10042, 10042,\n",
" 10042, 10042, 10042, 10042, 10042, 10042, 10042, 10042,\n",
" 10042, 10042, 10042, 10042, 10042, 10042, 10042, 10042,\n",
" 10042, 10042, 10042, 10042, 10042, 10042, 10042, 10042,\n",
" 10042]], dtype=int32)"
]
},
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"greedy_output[0]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# VGAN Jax"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [],
"source": [
"import io\n",
"\n",
"import requests\n",
"from PIL import Image\n",
"import numpy as np\n",
"\n",
"import torch\n",
"import torchvision.transforms as T\n",
"import torchvision.transforms.functional as TF\n",
"from torchvision.transforms import InterpolationMode"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [],
"source": [
"from modeling_flax_vqgan import VQModel"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [],
"source": [
"def custom_to_pil(x):\n",
" x = np.clip(x, 0., 1.)\n",
" x = (255*x).astype(np.uint8)\n",
" x = Image.fromarray(x)\n",
" if not x.mode == \"RGB\":\n",
" x = x.convert(\"RGB\")\n",
" return x"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "Jz032w73nHEf",
"outputId": "994d8e85-bff7-480b-8b69-f69dedc15c49"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Working with z of shape (1, 256, 16, 16) = 65536 dimensions.\n"
]
}
],
"source": [
"model = VQModel.from_pretrained(\"valhalla/vqgan-imagenet-f16-1024\")"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [],
"source": [
"def get_images(indices, model):\n",
" indices = indices[:, 1:]\n",
" model.decode_code(indices)\n",
" return indices"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Working with z of shape (1, 256, 16, 16) = 65536 dimensions.\n"
]
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAAEAAAEACAIAAAD9XIvPAAAAF0lEQVR4nGP4//8/EwMDwygexaN45GEA7ucE/J1FRrMAAAAASUVORK5CYII=\n",
"text/plain": [
""
]
},
"execution_count": 19,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"custom_to_pil(np.asarray(get_images(greedy_output[0], model)[0]))"
]
}
],
"metadata": {
"accelerator": "TPU",
"colab": {
"collapsed_sections": [],
"machine_shape": "hm",
"name": "CustomBARTv4b-model-generate.ipynb",
"provenance": []
},
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.8"
}
},
"nbformat": 4,
"nbformat_minor": 1
}