feat: double dataset
Browse files- detector/data.py +37 -14
- train.py +11 -0
detector/data.py
CHANGED
@@ -11,7 +11,7 @@ import torch
|
|
11 |
import torchvision.transforms as transforms
|
12 |
import torchvision.transforms.functional as TF
|
13 |
from typing import List, Dict, Tuple
|
14 |
-
from torch.utils.data import Dataset, DataLoader
|
15 |
from pytorch_lightning import LightningDataModule
|
16 |
from PIL import Image
|
17 |
|
@@ -262,9 +262,9 @@ class FontDataModule(LightningDataModule):
|
|
262 |
def __init__(
|
263 |
self,
|
264 |
config_path: str = "configs/font.yml",
|
265 |
-
|
266 |
-
|
267 |
-
|
268 |
train_shuffle: bool = True,
|
269 |
val_shuffle: bool = False,
|
270 |
test_shuffle: bool = False,
|
@@ -280,18 +280,41 @@ class FontDataModule(LightningDataModule):
|
|
280 |
self.train_shuffle = train_shuffle
|
281 |
self.val_shuffle = val_shuffle
|
282 |
self.test_shuffle = test_shuffle
|
283 |
-
self.train_dataset =
|
284 |
-
|
285 |
-
|
286 |
-
|
287 |
-
|
288 |
-
|
|
|
|
|
|
|
|
|
|
|
289 |
)
|
290 |
-
self.val_dataset =
|
291 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
292 |
)
|
293 |
-
self.test_dataset =
|
294 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
295 |
)
|
296 |
|
297 |
def get_train_num_iter(self, num_device: int) -> int:
|
|
|
11 |
import torchvision.transforms as transforms
|
12 |
import torchvision.transforms.functional as TF
|
13 |
from typing import List, Dict, Tuple
|
14 |
+
from torch.utils.data import Dataset, DataLoader, ConcatDataset
|
15 |
from pytorch_lightning import LightningDataModule
|
16 |
from PIL import Image
|
17 |
|
|
|
262 |
def __init__(
|
263 |
self,
|
264 |
config_path: str = "configs/font.yml",
|
265 |
+
train_paths: List[str] = ["./dataset/font_img/train"],
|
266 |
+
val_paths: List[str] = ["./dataset/font_img/val"],
|
267 |
+
test_paths: List[str] = ["./dataset/font_img/test"],
|
268 |
train_shuffle: bool = True,
|
269 |
val_shuffle: bool = False,
|
270 |
test_shuffle: bool = False,
|
|
|
280 |
self.train_shuffle = train_shuffle
|
281 |
self.val_shuffle = val_shuffle
|
282 |
self.test_shuffle = test_shuffle
|
283 |
+
self.train_dataset = ConcatDataset(
|
284 |
+
[
|
285 |
+
FontDataset(
|
286 |
+
train_path,
|
287 |
+
config_path,
|
288 |
+
regression_use_tanh,
|
289 |
+
train_transforms,
|
290 |
+
crop_roi_bbox,
|
291 |
+
)
|
292 |
+
for train_path in train_paths
|
293 |
+
]
|
294 |
)
|
295 |
+
self.val_dataset = ConcatDataset(
|
296 |
+
[
|
297 |
+
FontDataset(
|
298 |
+
val_path,
|
299 |
+
config_path,
|
300 |
+
regression_use_tanh,
|
301 |
+
val_transforms,
|
302 |
+
crop_roi_bbox,
|
303 |
+
)
|
304 |
+
for val_path in val_paths
|
305 |
+
]
|
306 |
)
|
307 |
+
self.test_dataset = ConcatDataset(
|
308 |
+
[
|
309 |
+
FontDataset(
|
310 |
+
test_path,
|
311 |
+
config_path,
|
312 |
+
regression_use_tanh,
|
313 |
+
test_transforms,
|
314 |
+
crop_roi_bbox,
|
315 |
+
)
|
316 |
+
for test_path in test_paths
|
317 |
+
]
|
318 |
)
|
319 |
|
320 |
def get_train_num_iter(self, num_device: int) -> int:
|
train.py
CHANGED
@@ -69,6 +69,14 @@ parser.add_argument(
|
|
69 |
default=0.0001,
|
70 |
help="Learning rate (default: 0.0001)",
|
71 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
72 |
|
73 |
args = parser.parse_args()
|
74 |
|
@@ -97,6 +105,9 @@ log_every_n_steps = 100
|
|
97 |
num_device = len(devices)
|
98 |
|
99 |
data_module = FontDataModule(
|
|
|
|
|
|
|
100 |
batch_size=single_batch_size,
|
101 |
num_workers=single_device_num_workers,
|
102 |
pin_memory=True,
|
|
|
69 |
default=0.0001,
|
70 |
help="Learning rate (default: 0.0001)",
|
71 |
)
|
72 |
+
parser.add_argument(
|
73 |
+
"-s",
|
74 |
+
"--datasets",
|
75 |
+
nargs="*",
|
76 |
+
type=str,
|
77 |
+
default=["./dataset/font_img"],
|
78 |
+
help="Datasets paths, seperated by space (default: ['./dataset/font_img'])",
|
79 |
+
)
|
80 |
|
81 |
args = parser.parse_args()
|
82 |
|
|
|
105 |
num_device = len(devices)
|
106 |
|
107 |
data_module = FontDataModule(
|
108 |
+
train_paths=[os.path.join(path, "train") for path in args.datasets],
|
109 |
+
val_paths=[os.path.join(path, "val") for path in args.datasets],
|
110 |
+
test_paths=[os.path.join(path, "test") for path in args.datasets],
|
111 |
batch_size=single_batch_size,
|
112 |
num_workers=single_device_num_workers,
|
113 |
pin_memory=True,
|