sachin commited on
Commit
16d5d78
·
1 Parent(s): 8783046

Got data batch running locally

Browse files
Files changed (4) hide show
  1. clip_config.json +7 -1
  2. requirements.txt +8 -0
  3. src/data.py +51 -52
  4. src/tokenizer.py +21 -0
clip_config.json CHANGED
@@ -1 +1,7 @@
1
- {"cls_token": true, "n_projection_layers": 3, "embed_dims": 512, "vision_model": "edgenext_small", "text_model": "microsoft/xtremedistil-l6-h256-uncased"}
 
 
 
 
 
 
 
1
+ {
2
+ "cls_token": true,
3
+ "n_projection_layers": 3,
4
+ "embed_dims": 512,
5
+ "vision_model": "edgenext_small",
6
+ "text_model": "microsoft/xtremedistil-l6-h256-uncased"
7
+ }
requirements.txt ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ datasets==2.18.0
2
+ Pillow==10.2.0
3
+ pydantic==2.6.4
4
+ Requests==2.31.0
5
+ timm==0.9.16
6
+ torch==2.2.2
7
+ torchvision==0.17.2
8
+ transformers==4.39.2
src/data.py CHANGED
@@ -1,59 +1,44 @@
1
- import io
2
  import multiprocessing as mp
3
- from typing import Optional, Union
 
4
 
5
  import datasets
6
  from PIL import Image
7
- import requests
8
  import torch
9
  from torch.utils.data import Dataset, DataLoader
10
  from torchvision import transforms
11
- from transformers import AutoTokenizer
12
 
13
  from src import config
 
14
 
15
 
16
- class Tokenizer:
17
- def __init__(self, model_name: str, max_len: int) -> None:
18
- self.tokenizer = AutoTokenizer.from_pretrained(model_name)
19
- self.max_len = max_len
20
 
21
- def __call__(self, x: Union[str, list[str]]) -> dict[str, torch.LongTensor]:
22
- return self.tokenizer(
23
- x, max_length=self.max_len, truncation=True, padding=True, return_tensors="pt"
24
- )
25
 
26
- def decode(self, x: dict[str, torch.LongTensor]) -> list[str]:
27
- return [
28
- self.tokenizer.decode(sentence[:sentence_len])
29
- for sentence, sentence_len in zip(x["input_ids"], x["attention_mask"].sum(axis=-1))
30
- ]
31
-
32
-
33
- def _get_image_and_caption(item: dict[str, str]) -> Optional[tuple[Image.Image, str]]:
34
- image_url = item["url"]
35
- caption = item["caption"]
36
- try:
37
- response = requests.get(image_url, timeout=1)
38
- response.raise_for_status() # Raise HTTPError for bad responses (4xx and 5xx)
39
- image = Image.open(io.BytesIO(response.content))
40
- return image, caption
41
- except (requests.RequestException, IOError):
42
- return None
43
 
44
 
45
  class CollateFn:
46
- def __init__(self, tokenizer: Tokenizer, transform: transforms.Compose):
47
  self.tokenizer = tokenizer
48
  self.transform = transform
49
 
50
- def __call__(
51
- self, batch: list[Optional[tuple[str, torch.FloatTensor]]]
52
- ) -> tuple[dict[str, torch.LongTensor], torch.FloatTensor]:
53
- filtered_batch = [data for data in map(_get_image_and_caption, batch) if data is not None]
54
- x, y = zip(*filtered_batch)
55
- tokenized_text = self.tokenizer(list(x))
56
- return tokenized_text, torch.stack([self.transform(image) for image in y])
 
57
 
58
 
59
  def _get_dataloaders(
@@ -65,7 +50,7 @@ def _get_dataloaders(
65
  common_params = {
66
  "batch_size": training_config.batch_size,
67
  "pin_memory": True,
68
- "num_workers": mp.cpu_count(),
69
  "collate_fn": collate_fn,
70
  }
71
  train_loader = DataLoader(
@@ -85,25 +70,39 @@ def _get_dataloaders(
85
 
86
  def get_dataset(
87
  transform: transforms.Compose,
88
- tokenizer: Tokenizer,
89
  hyper_parameters: config.TrainerConfig,
90
- num_workers: int,
91
  ) -> tuple[DataLoader, DataLoader]:
92
- dataset = datasets.load_dataset(
93
- hyper_parameters.data_config.dataset, split="train", streaming=True
94
- )
95
- full_dataset = dataset.shuffle(
96
- seed=42, buffer_size=hyper_parameters.data_config.buffer_size
97
- ).take(hyper_parameters.data_config.data_len)
98
- train_dataset = full_dataset.take(hyper_parameters.data_config.train_len)
99
- valid_dataset = full_dataset.skip(hyper_parameters.data_config.train_len)
100
-
101
  collate_fn = CollateFn(tokenizer, transform)
102
 
103
  return _get_dataloaders(
104
- train_ds=train_dataset,
105
- valid_ds=valid_dataset,
106
  training_config=hyper_parameters,
107
  collate_fn=collate_fn,
108
- num_workers=num_workers,
109
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import multiprocessing as mp
2
+ import pathlib
3
+ from typing import Any
4
 
5
  import datasets
6
  from PIL import Image
 
7
  import torch
8
  from torch.utils.data import Dataset, DataLoader
9
  from torchvision import transforms
 
10
 
11
  from src import config
12
+ from src import tokenizer as tk
13
 
14
 
15
+ class CaptionDatset(Dataset):
16
+ def __init__(self, dataset: datasets.Dataset, img_path: pathlib.Path) -> None:
17
+ self.dataset = dataset
18
+ self.img_path = img_path
19
 
20
+ def __len__(self) -> int:
21
+ return len(self.dataset)
 
 
22
 
23
+ def __getitem__(self, idx: int) -> dict[str, Any]:
24
+ item = self.dataset[idx]
25
+ image = Image.open(self.img_path / item["url"].rsplit("/", 1)[-1]).convert("RGB")
26
+ return {"image": image, "caption": item["short_caption"]}
 
 
 
 
 
 
 
 
 
 
 
 
 
27
 
28
 
29
  class CollateFn:
30
+ def __init__(self, tokenizer: tk.Tokenizer, transform: transforms.Compose):
31
  self.tokenizer = tokenizer
32
  self.transform = transform
33
 
34
+ def __call__(self, batch: list[dict[str, Any]]) -> dict[str, torch.Tensor]:
35
+ stacked_images = torch.stack([self.transform(item["image"]) for item in batch])
36
+ tokenized_text = self.tokenizer([item["caption"] for item in batch])
37
+
38
+ return {
39
+ "image": stacked_images,
40
+ **tokenized_text,
41
+ }
42
 
43
 
44
  def _get_dataloaders(
 
50
  common_params = {
51
  "batch_size": training_config.batch_size,
52
  "pin_memory": True,
53
+ "num_workers": mp.cpu_count() // 3,
54
  "collate_fn": collate_fn,
55
  }
56
  train_loader = DataLoader(
 
70
 
71
  def get_dataset(
72
  transform: transforms.Compose,
73
+ tokenizer: tk.Tokenizer,
74
  hyper_parameters: config.TrainerConfig,
 
75
  ) -> tuple[DataLoader, DataLoader]:
76
+ dataset: datasets.Dataset = datasets.load_dataset(
77
+ hyper_parameters._data_config.dataset, split="train"
78
+ ) # type: ignore
79
+ train_test_dataset = dataset.train_test_split(seed=42, test_size=0.1)
80
+ train_ds = CaptionDatset(train_test_dataset["train"], config.IMAGE_DOWNLOAD_PATH)
81
+ valid_ds = CaptionDatset(train_test_dataset["test"], config.IMAGE_DOWNLOAD_PATH)
 
 
 
82
  collate_fn = CollateFn(tokenizer, transform)
83
 
84
  return _get_dataloaders(
85
+ train_ds=train_ds,
86
+ valid_ds=valid_ds,
87
  training_config=hyper_parameters,
88
  collate_fn=collate_fn,
 
89
  )
90
+
91
+
92
+ if __name__ == "__main__":
93
+ # do not want to do these imports in general
94
+ import os
95
+ from tqdm.auto import tqdm
96
+
97
+ os.environ["TOKENIZERS_PARALLELISM"] = "false"
98
+ hyper_parameters = config.TrainerConfig()
99
+ transform = transforms.Compose([transforms.Resize((128, 128)), transforms.ToTensor()])
100
+ tokenizer = tk.Tokenizer(
101
+ hyper_parameters._model_config.text_model, hyper_parameters._model_config.max_len
102
+ )
103
+ train_dl, valid_dl = get_dataset(transform, tokenizer, hyper_parameters)
104
+
105
+ for batch in tqdm(train_dl):
106
+ continue
107
+
108
+ print("hellow")
src/tokenizer.py ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Union
2
+
3
+ import torch
4
+ from transformers import AutoTokenizer
5
+
6
+
7
+ class Tokenizer:
8
+ def __init__(self, model_name: str, max_len: int) -> None:
9
+ self.tokenizer = AutoTokenizer.from_pretrained(model_name)
10
+ self.max_len = max_len
11
+
12
+ def __call__(self, x: Union[str, list[str]]) -> dict[str, torch.LongTensor]:
13
+ return self.tokenizer(
14
+ x, max_length=self.max_len, truncation=True, padding=True, return_tensors="pt"
15
+ ) # type: ignore
16
+
17
+ def decode(self, x: dict[str, torch.LongTensor]) -> list[str]:
18
+ return [
19
+ self.tokenizer.decode(sentence[:sentence_len])
20
+ for sentence, sentence_len in zip(x["input_ids"], x["attention_mask"].sum(axis=-1))
21
+ ]