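"""Data loading for the image-captioning pipeline.

Provides a thin ``Dataset`` wrapper over a HuggingFace caption dataset whose
images were downloaded to disk beforehand, a collate function that stacks
transformed images and tokenizes captions, and a builder for the train/valid
``DataLoader`` pair.
"""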
import multiprocessing as mp
import pathlib
from typing import Any

import datasets
from PIL import Image
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms

from src import config
from src import tokenizer as tk


class CaptionDataset(Dataset):
    """Wraps a HuggingFace dataset of (image URL, caption) records."""

    def __init__(self, dataset: datasets.Dataset, img_path: pathlib.Path) -> None:
        self.dataset = dataset
        self.img_path = img_path

    def __len__(self) -> int:
        return len(self.dataset)

    def __getitem__(self, idx: int) -> dict[str, Any]:
        item = self.dataset[idx]
        # Images were fetched ahead of time; the local filename is the last
        # path component of the record's URL.
        image = Image.open(self.img_path / item["url"].rsplit("/", 1)[-1]).convert("RGB")
        return {"image": image, "caption": item["short_caption"]}


class CollateFn:
    """Batches samples: stacks transformed images and tokenizes captions."""

    def __init__(self, tokenizer: tk.Tokenizer, transform: transforms.Compose) -> None:
        self.tokenizer = tokenizer
        self.transform = transform

    def __call__(self, batch: list[dict[str, Any]]) -> dict[str, torch.Tensor]:
        stacked_images = torch.stack([self.transform(item["image"]) for item in batch])
        tokenized_text = self.tokenizer([item["caption"] for item in batch])
        # The tokenizer returns a mapping of tensors (e.g. input_ids and
        # attention_mask); merge it with the image tensor into one flat dict.
        return {
            "image": stacked_images,
            **tokenized_text,
        }
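
# A minimal standalone sketch of the collate path (hypothetical names: `img`
# is any PIL image; `tokenizer` and `transform` are built as in __main__
# below). In training the callable is handed to DataLoader via `collate_fn`
# instead of being invoked directly:
#
#   collate = CollateFn(tokenizer, transform)
#   batch = collate([{"image": img, "caption": "a cat on a mat"}])
#   batch["image"].shape  # -> torch.Size([1, 3, H, W]) for a single sample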


def _get_dataloaders(
    train_ds: Dataset,
    valid_ds: Dataset,
    training_config: config.TrainerConfig,
    collate_fn: CollateFn,
) -> tuple[DataLoader, DataLoader]:
    common_params = {
        "batch_size": training_config.batch_size,
        "pin_memory": True,
        # Leave roughly two thirds of the cores free for the training process.
        "num_workers": mp.cpu_count() // 3,
        "collate_fn": collate_fn,
    }
    # Shuffle and drop the ragged last batch only for training; keep the
    # validation set fixed and complete.
    train_loader = DataLoader(
        train_ds,
        shuffle=True,
        drop_last=True,
        **common_params,
    )
    valid_loader = DataLoader(
        valid_ds,
        shuffle=False,
        drop_last=False,
        **common_params,
    )
    return train_loader, valid_loader


def get_dataset(
    transform: transforms.Compose,
    tokenizer: tk.Tokenizer,
    hyper_parameters: config.TrainerConfig,
) -> tuple[DataLoader, DataLoader]:
    dataset: datasets.Dataset = datasets.load_dataset(
        hyper_parameters._data_config.dataset, split="train"
    )  # type: ignore
    # Fixed seed so the 90/10 train/valid split is reproducible across runs.
    train_test_dataset = dataset.train_test_split(seed=42, test_size=0.1)
    train_ds = CaptionDataset(train_test_dataset["train"], config.IMAGE_DOWNLOAD_PATH)
    valid_ds = CaptionDataset(train_test_dataset["test"], config.IMAGE_DOWNLOAD_PATH)
    collate_fn = CollateFn(tokenizer, transform)
    return _get_dataloaders(
        train_ds=train_ds,
        valid_ds=valid_ds,
        training_config=hyper_parameters,
        collate_fn=collate_fn,
    )


if __name__ == "__main__":
    # Keep these imports out of module scope; they are only needed for this
    # smoke test.
    import os

    from tqdm.auto import tqdm

    # HuggingFace tokenizer threads do not mix well with DataLoader worker
    # processes, so disable tokenizer parallelism up front.
    os.environ["TOKENIZERS_PARALLELISM"] = "false"

    hyper_parameters = config.TrainerConfig()
    transform = transforms.Compose([transforms.Resize((128, 128)), transforms.ToTensor()])
    tokenizer = tk.Tokenizer(
        hyper_parameters._model_config.text_model, hyper_parameters._model_config.max_len
    )
    train_dl, valid_dl = get_dataset(transform, tokenizer, hyper_parameters)
    batch = next(iter(train_dl))
    # e.g. image -> torch.Size([batch_size, 3, 128, 128]) plus tokenizer tensors
    print({k: v.shape for k, v in batch.items()})
    # Iterate once over the full loader to surface any bad image files early.
    for batch in tqdm(train_dl):
        continue