{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "ewer-Q-0w2xA" }, "source": [ "# Installation" ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "NpsF9ipLLl2s", "outputId": "10bf54aa-b89d-4e42-9777-bc97b00a5f32" }, "outputs": [], "source": [ "#!pip install git+https://github.com/huggingface/transformers/\n", "#!pip install git+https://github.com/google/flax" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "id": "M1wVkrpjU6zO" }, "outputs": [], "source": [ "%load_ext autoreload\n", "%autoreload 2" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "/home/tmabraham/vqgan-jax\n" ] } ], "source": [ "%cd ../../vqgan-jax" ] }, { "cell_type": "markdown", "metadata": { "id": "t47CH1H_IOT8" }, "source": [ "# Custom BART Model" ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "id": "9jQnM6S2vCpn" }, "outputs": [], "source": [
"# TODO: set those args in a config file\n",
"OUTPUT_VOCAB_SIZE = 16384 + 1  # encoded image token space + 1 for bos\n",
"OUTPUT_LENGTH = 256 + 1  # number of encoded tokens + 1 for bos\n",
"BOS_TOKEN_ID = 16384\n",
"BASE_MODEL = 'facebook/bart-large'"
] }, { "cell_type": "code", "execution_count": 5, "metadata": { "id": "_eEaJVxAKpV5" }, "outputs": [], "source": [
"import jax\n",
"import flax.linen as nn\n",
"\n",
"from transformers.models.bart.modeling_flax_bart import *\n",
"from transformers import BartTokenizer, FlaxBartForConditionalGeneration\n",
"\n",
"\n",
"class CustomFlaxBartModule(FlaxBartModule):\n",
"    \"\"\"BART module with a text encoder and a decoder over the image-token vocabulary.\n",
"\n",
"    The encoder uses the `shared` embedding sized by `config.vocab_size`; the decoder\n",
"    gets its own embedding with OUTPUT_VOCAB_SIZE rows and a config whose\n",
"    `max_position_embeddings` is OUTPUT_LENGTH.\n",
"    \"\"\"\n",
"\n",
"    def setup(self):\n",
"        # we keep shared to easily load pre-trained weights\n",
"        self.shared = nn.Embed(\n",
"            self.config.vocab_size,\n",
"            self.config.d_model,\n",
"            embedding_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),\n",
"            dtype=self.dtype,\n",
"        )\n",
"        # a separate embedding is used for the decoder\n",
"        self.decoder_embed = nn.Embed(\n",
"            OUTPUT_VOCAB_SIZE,\n",
"            self.config.d_model,\n",
"            embedding_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),\n",
"            dtype=self.dtype,\n",
"        )\n",
"        self.encoder = FlaxBartEncoder(self.config, dtype=self.dtype, embed_tokens=self.shared)\n",
"\n",
"        # the decoder has a different config: copy every field of the base config,\n",
"        # then override the two that differ. `from_dict` is needed here because\n",
"        # `BartConfig(some_dict)` binds the dict to the first positional parameter\n",
"        # (`vocab_size`) and leaves all other fields at the BartConfig defaults.\n",
"        decoder_config = BartConfig.from_dict(self.config.to_dict())\n",
"        decoder_config.max_position_embeddings = OUTPUT_LENGTH\n",
"        decoder_config.vocab_size = OUTPUT_VOCAB_SIZE\n",
"        self.decoder = FlaxBartDecoder(decoder_config, dtype=self.dtype, embed_tokens=self.decoder_embed)\n",
"\n",
"\n",
"class CustomFlaxBartForConditionalGenerationModule(FlaxBartForConditionalGenerationModule):\n",
"    \"\"\"Conditional-generation module whose output layer has OUTPUT_VOCAB_SIZE logits.\"\"\"\n",
"\n",
"    def setup(self):\n",
"        self.model = CustomFlaxBartModule(config=self.config, dtype=self.dtype)\n",
"        # both the LM head and its bias are sized by OUTPUT_VOCAB_SIZE, not config.vocab_size\n",
"        self.lm_head = nn.Dense(\n",
"            OUTPUT_VOCAB_SIZE,\n",
"            use_bias=False,\n",
"            dtype=self.dtype,\n",
"            kernel_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),\n",
"        )\n",
"        self.final_logits_bias = self.param(\"final_logits_bias\", self.bias_init, (1, OUTPUT_VOCAB_SIZE))\n",
"\n",
"\n",
"class CustomFlaxBartForConditionalGeneration(FlaxBartForConditionalGeneration):\n",
"    \"\"\"`FlaxBartForConditionalGeneration` that builds the custom module defined in this cell.\"\"\"\n",
"\n",
"    module_class = CustomFlaxBartForConditionalGenerationModule"
] }, { "cell_type": "code", "execution_count": 6, "metadata": { "scrolled": true }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "\u001b[34m\u001b[1mwandb\u001b[0m: Currently logged in as: \u001b[33mtmabraham\u001b[0m (use `wandb login --relogin` to force relogin)\n" ] }, { "data": { "text/html": [ "\n", " Tracking run with wandb version 0.10.33
\n", " Syncing run serene-resonance-1 to Weights & Biases (Documentation).
\n", " Project page: https://wandb.ai/tmabraham/vqgan-jax
\n", " Run page: https://wandb.ai/tmabraham/vqgan-jax/runs/1cm35ims
\n", " Run data is saved locally in /home/tmabraham/vqgan-jax/wandb/run-20210715_030616-1cm35ims

\n", " " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stderr", "output_type": "stream", "text": [ "\u001b[34m\u001b[1mwandb\u001b[0m: Downloading large artifact model-1ef8yxby:v1, 1674.97MB. 2 files... Done. 0:0:0\n" ] } ], "source": [
"import wandb\n",
"\n",
"# start a W&B run and use it to fetch the BART model artifact\n",
"run = wandb.init()\n",
"artifact = run.use_artifact(\n",
"    'wandb/hf-flax-dalle-mini/model-1ef8yxby:v1',\n",
"    type='bart_model',\n",
")\n",
"# `artifact_dir` is handed to `from_pretrained` in the next cell\n",
"artifact_dir = artifact.download()"
] }, { "cell_type": "code", "execution_count": 7, "metadata": { "id": "_6-XKK40oEfP", "scrolled": true }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/home/tmabraham/dalle-mini/src/transformers/src/transformers/models/bart/configuration_bart.py:180: UserWarning: Please make sure the config includes `forced_bos_token_id=16384` in future versions.The config can simply be saved and uploaded again to be fixed.\n", " warnings.warn(\n", "INFO:absl:Starting the local TPU driver.\n", "INFO:absl:Unable to initialize backend 'tpu_driver': Not found: Unable to find driver in registry given worker: local://\n", "INFO:absl:Unable to initialize backend 'gpu': Not found: Could not find registered platform with name: \"cuda\". 
Available platform names are: TPU Interpreter Host\n" ] } ], "source": [
"# load our custom model with the weights stored in the downloaded artifact\n",
"model = CustomFlaxBartForConditionalGeneration.from_pretrained(artifact_dir)"
] }, { "cell_type": "code", "execution_count": 8, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "Jz032w73nHEf", "outputId": "994d8e85-bff7-480b-8b69-f69dedc15c49" }, "outputs": [ { "data": { "text/plain": [ "(1, 16385)" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [
"# we verify that the shape has not been modified\n",
"model.params['final_logits_bias'].shape"
] }, { "cell_type": "markdown", "metadata": { "id": "zLl24Ez5t7x1" }, "source": [ "## Inference" ] }, { "cell_type": "code", "execution_count": 9, "metadata": { "id": "XLLA2NK3uDQr" }, "outputs": [], "source": [ "tokenizer = BartTokenizer.from_pretrained(BASE_MODEL)" ] }, { "cell_type": "code", "execution_count": 10, "metadata": { "id": "P32mJJSbrU1F" }, "outputs": [], "source": [ "input_ids_test = tokenizer.encode('I enjoy walking with my cute dog', return_tensors='jax')" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "DeviceArray([[ 0, 100, 2254, 3051, 19, 127, 11962, 2335,\n", " 2]], dtype=int32)" ] }, "execution_count": 11, "metadata": {}, "output_type": "execute_result" } ], "source": [ "input_ids_test" ] }, { "cell_type": "code", "execution_count": 12, "metadata": { "id": "C7cHbIHruELT" }, "outputs": [], "source": [
"# cap generation at OUTPUT_LENGTH (number of encoded image tokens + 1 for bos)\n",
"greedy_output = model.generate(input_ids_test, max_length=OUTPUT_LENGTH)"
] }, { "cell_type": "code", "execution_count": 13, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "jYugh9cOuwc9", "outputId": "19c3a2ee-e7bc-4f1f-9c86-06bd7337b537" }, "outputs": [ { "data": { "text/plain": [ "DeviceArray([[16384, 16384, 10042, 10042, 10042, 10042, 10042, 10042,\n", " 10042, 10042, 10042, 10042, 10042, 10042, 10042, 10042,\n", " 10042, 10042, 10042, 10042, 
10042, 10042, 10042, 10042,\n",
" 10042, 10042, 10042, 10042, 10042, 10042, 10042, 10042,\n",
" 10042, 10042, 10042, 10042, 10042, 10042, 10042, 10042,\n",
" 10042, 10042, 10042, 10042, 10042, 10042, 10042, 10042,\n",
" 10042, 10042, 10042, 10042, 10042, 10042, 10042, 10042,\n",
" 10042, 10042, 10042, 10042, 10042, 10042, 10042, 10042,\n",
" 10042, 10042, 10042, 10042, 10042, 10042, 10042, 10042,\n",
" 10042, 10042, 10042, 10042, 10042, 10042, 10042, 10042,\n",
" 10042, 10042, 10042, 10042, 10042, 10042, 10042, 10042,\n",
" 10042, 10042, 10042, 10042, 10042, 10042, 10042, 10042,\n",
" 10042, 10042, 10042, 10042, 10042, 10042, 10042, 10042,\n",
" 10042, 10042, 10042, 10042, 10042, 10042, 10042, 10042,\n",
" 10042, 10042, 10042, 10042, 10042, 10042, 10042, 10042,\n",
" 10042, 10042, 10042, 10042, 10042, 10042, 10042, 10042,\n",
" 10042, 10042, 10042, 10042, 10042, 10042, 10042, 10042,\n",
" 10042, 10042, 10042, 10042, 10042, 10042, 10042, 10042,\n",
" 10042, 10042, 10042, 10042, 10042, 10042, 10042, 10042,\n",
" 10042, 10042, 10042, 10042, 10042, 10042, 10042, 10042,\n",
" 10042, 10042, 10042, 10042, 10042, 10042, 10042, 10042,\n",
" 10042, 10042, 10042, 10042, 10042, 10042, 10042, 10042,\n",
" 10042, 10042, 10042, 10042, 10042, 10042, 10042, 10042,\n",
" 10042, 10042, 10042, 10042, 10042, 10042, 10042, 10042,\n",
" 10042, 10042, 10042, 10042, 10042, 10042, 10042, 10042,\n",
" 10042, 10042, 10042, 10042, 10042, 10042, 10042, 10042,\n",
" 10042, 10042, 10042, 10042, 10042, 10042, 10042, 10042,\n",
" 10042, 10042, 10042, 10042, 10042, 10042, 10042, 10042,\n",
" 10042, 10042, 10042, 10042, 10042, 10042, 10042, 10042,\n",
" 10042, 10042, 10042, 10042, 10042, 10042, 10042, 10042,\n",
" 10042, 10042, 10042, 10042, 10042, 10042, 10042, 10042,\n",
" 10042, 10042, 10042, 10042, 10042, 10042, 10042, 10042,\n",
" 10042]], dtype=int32)" ] }, "execution_count": 13, "metadata": {}, "output_type": "execute_result" } ], "source": [ "greedy_output[0]" ] }, { "cell_type": "markdown", 
"metadata": {}, "source": [ "# VQGAN Jax" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [], "source": [
"import io\n",
"\n",
"import requests\n",
"from PIL import Image\n",
"import numpy as np\n",
"\n",
"import torch\n",
"import torchvision.transforms as T\n",
"import torchvision.transforms.functional as TF\n",
"from torchvision.transforms import InterpolationMode"
] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [], "source": [ "from modeling_flax_vqgan import VQModel" ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [], "source": [
"def custom_to_pil(x):\n",
"    \"\"\"Convert an array of values in [0, 1] to an RGB PIL image.\n",
"\n",
"    Values outside [0, 1] are clipped, the result is scaled by 255 and cast to\n",
"    uint8, and the PIL image is converted to RGB when it has another mode.\n",
"    \"\"\"\n",
"    pixels = (255 * np.clip(x, 0., 1.)).astype(np.uint8)\n",
"    image = Image.fromarray(pixels)\n",
"    if image.mode != \"RGB\":\n",
"        image = image.convert(\"RGB\")\n",
"    return image"
] }, { "cell_type": "code", "execution_count": 17, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "Jz032w73nHEf", "outputId": "994d8e85-bff7-480b-8b69-f69dedc15c49" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Working with z of shape (1, 256, 16, 16) = 65536 dimensions.\n" ] } ], "source": [
"# kept under its own name so the BART `model` used by the generate cell is not overwritten\n",
"# NOTE(review): the checkpoint name suggests a 1024-entry codebook while the BART decoder\n",
"# vocabulary is OUTPUT_VOCAB_SIZE = 16385 -- confirm the two are meant to be used together\n",
"vqgan_model = VQModel.from_pretrained(\"valhalla/vqgan-imagenet-f16-1024\")"
] }, { "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [], "source": [
"def get_images(indices, model):\n",
"    \"\"\"Decode generated token ids into images with the VQGAN `model`.\n",
"\n",
"    Args:\n",
"        indices: 2-D array of generated ids; the first column (the bos token)\n",
"            is dropped before decoding.\n",
"        model: object exposing `decode_code(codes)`.\n",
"\n",
"    Returns:\n",
"        The result of `model.decode_code` for the remaining codes.\n",
"    \"\"\"\n",
"    codes = indices[:, 1:]\n",
"    # return the decoded images; returning the codes would make the caller render raw token ids\n",
"    return model.decode_code(codes)"
] }, { "cell_type": "code", "execution_count": 19, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Working with z of shape (1, 256, 16, 16) = 65536 dimensions.\n" ] }, { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAAEAAAEACAIAAAD9XIvPAAAAF0lEQVR4nGP4//8/EwMDwygexaN45GEA7ucE/J1FRrMAAAAASUVORK5CYII=\n", "text/plain": [ "" ] }, "execution_count": 19, "metadata": {}, "output_type": "execute_result" } ], "source": [
"# greedy_output[0] holds the generated sequences; decode them and show the first image\n",
"custom_to_pil(np.asarray(get_images(greedy_output[0], vqgan_model)[0]))"
] } ], "metadata": { 
"accelerator": "TPU", "colab": { "collapsed_sections": [], "machine_shape": "hm", "name": "CustomBARTv4b-model-generate.ipynb", "provenance": [] }, "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.8" } }, "nbformat": 4, "nbformat_minor": 1 }