{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "bf8fb38a",
   "metadata": {},
   "source": [
    "# Data Pipeline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "9b83dcb9",
   "metadata": {},
   "outputs": [],
   "source": [
    "import ast\n",
    "from dataclasses import dataclass, field\n",
    "from pathlib import Path\n",
    "\n",
    "import datasets\n",
    "import jax\n",
    "import jax.numpy as jnp\n",
    "import numpy as np\n",
    "from datasets import Dataset, load_dataset\n",
    "from flax.training.common_utils import shard\n",
    "from tqdm import tqdm\n",
    "from transformers import BartTokenizer"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a661a89e",
   "metadata": {},
   "source": [
    "File containing image paths, captions and VQGAN-encoded indices."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "0e84e889",
   "metadata": {},
   "outputs": [],
   "source": [
    "datafile = '/data/CC12M/images-encoded-10000.tsv'   # 9999 encoded images from CC12M"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7fdc640b",
   "metadata": {},
   "source": [
    "TODO: generate train/test splits if necessary."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "cc6789b4",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Using custom data configuration default-91833df78e844785\n",
      "Reusing dataset csv (/home/pedro/.cache/huggingface/datasets/csv/default-91833df78e844785/0.0.0/e138af468cb14e747fb46a19c787ffcfa5170c821476d20d5304287ce12bbc23)\n"
     ]
    }
   ],
   "source": [
    "dataset = load_dataset('csv', delimiter='\\t', data_files=[datafile])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "f3ed4919",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "DatasetDict({\n",
       "    train: Dataset({\n",
       "        features: ['image_file', 'caption', 'encoding'],\n",
       "        num_rows: 9999\n",
       "    })\n",
       "})"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "a70c7354",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Dataset({\n",
       "    features: ['image_file', 'caption', 'encoding'],\n",
       "    num_rows: 9999\n",
       "})"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "dataset = dataset[\"train\"]\n",
    "dataset"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a73454cf",
   "metadata": {},
   "source": [
    "We don't need the `image_file` field for training. We'll drop it during preprocessing: a file path can't be converted to a `jnp.array`, which JAX requires for its inputs."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7c0fa992",
   "metadata": {},
   "source": [
    "## Preprocessing"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a0e36582",
   "metadata": {},
   "source": [
    "The `encoding` field contains a string representation of the encoded indices. We'll convert them to numbers. We also need to tokenize the captions."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "d46f6ac5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Captions are padded/truncated to max_length (padding=\"max_length\") because jitted functions need fixed-length inputs\n",
    "max_length = 256   # Read from data_args.max_source_length\n",
    "tokenizer = BartTokenizer.from_pretrained('facebook/bart-large-cnn')\n",
    "image_bos = 16384   # Max token is 16383 in our VQGAN configuration"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "4cac6643",
   "metadata": {},
   "outputs": [],
   "source": [
    "def preprocess_function(examples):\n",
    "    \"\"\"Tokenize a batch of captions and parse the image encodings into label sequences.\n",
    "\n",
    "    Args:\n",
    "        examples: batch (dict of columns) with a `caption` column and an `encoding`\n",
    "            column holding the string representation of the encoded image indices.\n",
    "\n",
    "    Returns:\n",
    "        The tokenizer output for the captions (padded/truncated to `max_length`), plus a\n",
    "        `labels` entry: one list per example, made of `image_bos` followed by the parsed indices.\n",
    "    \"\"\"\n",
    "    # Fixed-length inputs (padding=\"max_length\") are needed for jitted functions.\n",
    "    # No task prefix is prepended to the captions.\n",
    "    model_inputs = tokenizer(\n",
    "        examples[\"caption\"], max_length=max_length, padding=\"max_length\", truncation=True, return_tensors=\"np\"\n",
    "    )\n",
    "\n",
    "    # `ast.literal_eval` only accepts Python literals, so a malformed or malicious field in the\n",
    "    # data file raises an error instead of executing arbitrary code (which `eval` would do).\n",
    "    model_inputs[\"labels\"] = [\n",
    "        [image_bos] + list(ast.literal_eval(indices)) for indices in examples[\"encoding\"]\n",
    "    ]\n",
    "\n",
    "    return model_inputs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "e6a4cb91",
   "metadata": {},
   "outputs": [],
   "source": [
    "num_workers = 48     # We have 96 processors in the TPU\n",
    "# Every original column is dropped; only the fields returned by `preprocess_function` remain.\n",
    "column_names = dataset.column_names\n",
    "input_dataset = dataset.map(\n",
    "    preprocess_function,\n",
    "    remove_columns=column_names,\n",
    "    batched=True,\n",
    "    num_proc=num_workers,   # Use the setting above instead of a second hard-coded 48\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "a9b1b467",
   "metadata": {},
   "outputs": [],
   "source": [
    "def data_loader(rng: jax.random.PRNGKey, dataset: Dataset, batch_size: int, shuffle: bool = False):\n",
    "    \"\"\"\n",
    "    Yield batches of `batch_size` examples from `dataset`, sharded over all local devices.\n",
    "\n",
    "    Args:\n",
    "        rng: JAX PRNG key; only used to permute the example order when `shuffle` is `True`.\n",
    "        dataset: dataset to read from; every column must be convertible with `jnp.array`.\n",
    "        batch_size: number of examples per yielded batch, before sharding.\n",
    "        shuffle: if `True`, visit the examples in a random order.\n",
    "\n",
    "    Yields:\n",
    "        A dict mapping each column name to a `jnp` array, passed through `shard`.\n",
    "\n",
    "    Trailing examples that do not fill a whole batch are skipped.\n",
    "    \"\"\"\n",
    "    steps_per_epoch = len(dataset) // batch_size\n",
    "\n",
    "    # Keep the index array on the host (NumPy): it is only used to index `dataset`,\n",
    "    # so there is no reason to slice a device array at every step.\n",
    "    if shuffle:\n",
    "        batch_idx = np.asarray(jax.random.permutation(rng, len(dataset)))\n",
    "    else:\n",
    "        batch_idx = np.arange(len(dataset))\n",
    "\n",
    "    batch_idx = batch_idx[: steps_per_epoch * batch_size]  # Skip incomplete batch.\n",
    "    batch_idx = batch_idx.reshape((steps_per_epoch, batch_size))\n",
    "\n",
    "    for idx in batch_idx:\n",
    "        batch = dataset[idx]\n",
    "        batch = {k: jnp.array(v) for k, v in batch.items()}\n",
    "        yield shard(batch)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "0a628505",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:absl:Starting the local TPU driver.\n",
      "INFO:absl:Unable to initialize backend 'tpu_driver': Not found: Unable to find driver in registry given worker: local://\n",
      "INFO:absl:Unable to initialize backend 'gpu': Not found: Could not find registered platform with name: \"cuda\". Available platform names are: Host TPU Interpreter\n"
     ]
    }
   ],
   "source": [
    "rng = jax.random.PRNGKey(23)  # Use training_args.seed\n",
    "batch_size = 64    # Per device\n",
    "super_batch_size = batch_size * jax.device_count()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "b3a5ce7d",
   "metadata": {},
   "outputs": [],
   "source": [
    "loader = data_loader(rng, input_dataset, batch_size=super_batch_size)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "67aa8f9c",
   "metadata": {},
   "outputs": [],
   "source": [
    "superbatch = next(iter(loader))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "7cd99402",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "dict_keys(['attention_mask', 'input_ids', 'labels'])"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "superbatch.keys()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "652a4a9e",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "8"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "len(superbatch[\"labels\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "de7de4e8",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(8, 64, 257)"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "superbatch[\"labels\"].shape"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6800153b",
   "metadata": {},
   "source": [
    "Any image sequence should begin with `image_bos`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "cfe23a71",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Check the first token of every sequence on every device, not just one sample.\n",
    "assert (superbatch[\"labels\"][..., 0] == image_bos).all()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0fb899b4",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}