File size: 3,349 Bytes
35352c6
16d5d78
 
35352c6
 
 
 
 
 
 
 
16d5d78
35352c6
 
16d5d78
 
 
 
35352c6
16d5d78
 
35352c6
16d5d78
 
 
 
35352c6
 
 
16d5d78
35352c6
 
 
16d5d78
 
 
 
 
 
 
 
35352c6
 
 
 
 
 
 
 
 
 
 
16d5d78
35352c6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16d5d78
35352c6
 
16d5d78
 
 
 
 
 
35352c6
 
 
16d5d78
 
35352c6
 
 
16d5d78
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a8c8fe0
 
 
16d5d78
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
import multiprocessing as mp
import pathlib
from typing import Any

import datasets
from PIL import Image
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms

from src import config
from src import tokenizer as tk


class CaptionDatset(Dataset):
    """Map-style dataset of (PIL image, short caption) pairs.

    Each record of the underlying HF dataset must provide:
      * ``url``           — only the final path component is used as the local
                            filename under ``img_path``;
      * ``short_caption`` — the caption string returned alongside the image.

    NOTE(review): the class name is missing an "a" ("Datset"). It is kept
    as-is for backward compatibility; a correctly spelled alias is defined
    below for new code.
    """

    def __init__(self, dataset: "datasets.Dataset", img_path: pathlib.Path) -> None:
        self.dataset = dataset
        # Directory where images were previously downloaded, keyed by the
        # basename of each record's URL.
        self.img_path = img_path

    def __len__(self) -> int:
        return len(self.dataset)

    def __getitem__(self, idx: int) -> dict[str, Any]:
        item = self.dataset[idx]
        # Force RGB so downstream transforms always see 3 channels
        # (some downloads may be grayscale/RGBA).
        image = Image.open(self.img_path / item["url"].rsplit("/", 1)[-1]).convert("RGB")
        return {"image": image, "caption": item["short_caption"]}


# Correctly spelled alias so new code does not propagate the typo.
CaptionDataset = CaptionDatset


class CollateFn:
    """Batch collator for the caption DataLoader.

    Applies ``transform`` to every image and stacks the results into one
    tensor, then tokenizes all captions in a single call. The returned dict
    holds the ``"image"`` tensor plus whatever tensor fields the tokenizer
    produces (e.g. input ids / attention mask).
    """

    def __init__(self, tokenizer: tk.Tokenizer, transform: transforms.Compose):
        self.tokenizer = tokenizer
        self.transform = transform

    def __call__(self, batch: list[dict[str, Any]]) -> dict[str, torch.Tensor]:
        images = []
        captions = []
        for sample in batch:
            images.append(self.transform(sample["image"]))
            captions.append(sample["caption"])

        collated = {"image": torch.stack(images)}
        # Tokenizer output keys are merged after "image", as in a dict splat.
        collated.update(self.tokenizer(captions))
        return collated


def _get_dataloaders(
    train_ds: Dataset,
    valid_ds: Dataset,
    training_config: config.TrainerConfig,
    collate_fn: CollateFn,
) -> tuple[DataLoader, DataLoader]:
    common_params = {
        "batch_size": training_config.batch_size,
        "pin_memory": True,
        "num_workers": mp.cpu_count() // 3,
        "collate_fn": collate_fn,
    }
    train_loader = DataLoader(
        train_ds,
        shuffle=True,
        drop_last=True,
        **common_params,
    )
    valid_loader = DataLoader(
        valid_ds,
        shuffle=False,
        drop_last=False,
        **common_params,
    )
    return train_loader, valid_loader


def get_dataset(
    transform: transforms.Compose,
    tokenizer: tk.Tokenizer,
    hyper_parameters: config.TrainerConfig,
) -> tuple[DataLoader, DataLoader]:
    """Load the captioning dataset and return train/valid DataLoaders.

    Downloads/loads the HF dataset named in the config, splits it 90/10
    with a fixed seed, wraps both splits as image+caption datasets, and
    hands them to the loader factory.
    """
    raw: datasets.Dataset = datasets.load_dataset(
        hyper_parameters._data_config.dataset, split="train"
    )  # type: ignore
    # Deterministic 90/10 train/valid split.
    split = raw.train_test_split(seed=42, test_size=0.1)

    return _get_dataloaders(
        train_ds=CaptionDatset(split["train"], config.IMAGE_DOWNLOAD_PATH),
        valid_ds=CaptionDatset(split["test"], config.IMAGE_DOWNLOAD_PATH),
        training_config=hyper_parameters,
        collate_fn=CollateFn(tokenizer, transform),
    )


if __name__ == "__main__":
    # do not want to do these imports in general
    import os
    from tqdm.auto import tqdm

    os.environ["TOKENIZERS_PARALLELISM"] = "false"
    hyper_parameters = config.TrainerConfig()
    transform = transforms.Compose([transforms.Resize((128, 128)), transforms.ToTensor()])
    tokenizer = tk.Tokenizer(
        hyper_parameters._model_config.text_model, hyper_parameters._model_config.max_len
    )
    train_dl, valid_dl = get_dataset(transform, tokenizer, hyper_parameters)

    batch = next(iter(train_dl))
    print({k: v.shape for k, v in batch.items()})  # torch.Size([1, 3, 128, 128])

    for batch in tqdm(train_dl):
        continue