gyrojeff commited on
Commit
bc0f7fc
·
1 Parent(s): 0693434

feat: double dataset

Browse files
Files changed (2) hide show
  1. detector/data.py +37 -14
  2. train.py +11 -0
detector/data.py CHANGED
@@ -11,7 +11,7 @@ import torch
11
  import torchvision.transforms as transforms
12
  import torchvision.transforms.functional as TF
13
  from typing import List, Dict, Tuple
14
- from torch.utils.data import Dataset, DataLoader
15
  from pytorch_lightning import LightningDataModule
16
  from PIL import Image
17
 
@@ -262,9 +262,9 @@ class FontDataModule(LightningDataModule):
262
  def __init__(
263
  self,
264
  config_path: str = "configs/font.yml",
265
- train_path: str = "./dataset/font_img/train",
266
- val_path: str = "./dataset/font_img/val",
267
- test_path: str = "./dataset/font_img/test",
268
  train_shuffle: bool = True,
269
  val_shuffle: bool = False,
270
  test_shuffle: bool = False,
@@ -280,18 +280,41 @@ class FontDataModule(LightningDataModule):
280
  self.train_shuffle = train_shuffle
281
  self.val_shuffle = val_shuffle
282
  self.test_shuffle = test_shuffle
283
- self.train_dataset = FontDataset(
284
- train_path,
285
- config_path,
286
- regression_use_tanh,
287
- train_transforms,
288
- crop_roi_bbox,
 
 
 
 
 
289
  )
290
- self.val_dataset = FontDataset(
291
- val_path, config_path, regression_use_tanh, val_transforms, crop_roi_bbox
 
 
 
 
 
 
 
 
 
292
  )
293
- self.test_dataset = FontDataset(
294
- test_path, config_path, regression_use_tanh, test_transforms, crop_roi_bbox
 
 
 
 
 
 
 
 
 
295
  )
296
 
297
  def get_train_num_iter(self, num_device: int) -> int:
 
11
  import torchvision.transforms as transforms
12
  import torchvision.transforms.functional as TF
13
  from typing import List, Dict, Tuple
14
+ from torch.utils.data import Dataset, DataLoader, ConcatDataset
15
  from pytorch_lightning import LightningDataModule
16
  from PIL import Image
17
 
 
262
  def __init__(
263
  self,
264
  config_path: str = "configs/font.yml",
265
+ train_paths: List[str] = ["./dataset/font_img/train"],
266
+ val_paths: List[str] = ["./dataset/font_img/val"],
267
+ test_paths: List[str] = ["./dataset/font_img/test"],
268
  train_shuffle: bool = True,
269
  val_shuffle: bool = False,
270
  test_shuffle: bool = False,
 
280
  self.train_shuffle = train_shuffle
281
  self.val_shuffle = val_shuffle
282
  self.test_shuffle = test_shuffle
283
+ self.train_dataset = ConcatDataset(
284
+ [
285
+ FontDataset(
286
+ train_path,
287
+ config_path,
288
+ regression_use_tanh,
289
+ train_transforms,
290
+ crop_roi_bbox,
291
+ )
292
+ for train_path in train_paths
293
+ ]
294
  )
295
+ self.val_dataset = ConcatDataset(
296
+ [
297
+ FontDataset(
298
+ val_path,
299
+ config_path,
300
+ regression_use_tanh,
301
+ val_transforms,
302
+ crop_roi_bbox,
303
+ )
304
+ for val_path in val_paths
305
+ ]
306
  )
307
+ self.test_dataset = ConcatDataset(
308
+ [
309
+ FontDataset(
310
+ test_path,
311
+ config_path,
312
+ regression_use_tanh,
313
+ test_transforms,
314
+ crop_roi_bbox,
315
+ )
316
+ for test_path in test_paths
317
+ ]
318
  )
319
 
320
  def get_train_num_iter(self, num_device: int) -> int:
train.py CHANGED
@@ -69,6 +69,14 @@ parser.add_argument(
69
  default=0.0001,
70
  help="Learning rate (default: 0.0001)",
71
  )
 
 
 
 
 
 
 
 
72
 
73
  args = parser.parse_args()
74
 
@@ -97,6 +105,9 @@ log_every_n_steps = 100
97
  num_device = len(devices)
98
 
99
  data_module = FontDataModule(
 
 
 
100
  batch_size=single_batch_size,
101
  num_workers=single_device_num_workers,
102
  pin_memory=True,
 
69
  default=0.0001,
70
  help="Learning rate (default: 0.0001)",
71
  )
72
+ parser.add_argument(
73
+ "-s",
74
+ "--datasets",
75
+ nargs="*",
76
+ type=str,
77
+ default=["./dataset/font_img"],
78
+ help="Datasets paths, seperated by space (default: ['./dataset/font_img'])",
79
+ )
80
 
81
  args = parser.parse_args()
82
 
 
105
  num_device = len(devices)
106
 
107
  data_module = FontDataModule(
108
+ train_paths=[os.path.join(path, "train") for path in args.datasets],
109
+ val_paths=[os.path.join(path, "val") for path in args.datasets],
110
+ test_paths=[os.path.join(path, "test") for path in args.datasets],
111
  batch_size=single_batch_size,
112
  num_workers=single_device_num_workers,
113
  pin_memory=True,