feat: add crop roi bbox
Browse files- detector/data.py +16 -3
- train.py +7 -0
detector/data.py
CHANGED
@@ -96,11 +96,13 @@ class FontDataset(Dataset):
|
|
96 |
config_path: str = "configs/font.yml",
|
97 |
regression_use_tanh: bool = False,
|
98 |
transforms: bool = False,
|
|
|
99 |
):
|
100 |
self.path = path
|
101 |
self.fonts = load_font_with_exclusion(config_path)
|
102 |
self.regression_use_tanh = regression_use_tanh
|
103 |
self.transforms = transforms
|
|
|
104 |
|
105 |
self.images = [
|
106 |
os.path.join(path, f) for f in os.listdir(path) if f.endswith(".jpg")
|
@@ -146,6 +148,12 @@ class FontDataset(Dataset):
|
|
146 |
with open(label_path, "rb") as f:
|
147 |
label: FontLabel = pickle.load(f)
|
148 |
|
|
|
|
|
|
|
|
|
|
|
|
|
149 |
# encode label
|
150 |
label = self.fontlabel2tensor(label, label_path)
|
151 |
|
@@ -188,6 +196,7 @@ class FontDataModule(LightningDataModule):
|
|
188 |
train_transforms: bool = False,
|
189 |
val_transforms: bool = False,
|
190 |
test_transforms: bool = False,
|
|
|
191 |
regression_use_tanh: bool = False,
|
192 |
**kwargs,
|
193 |
):
|
@@ -197,13 +206,17 @@ class FontDataModule(LightningDataModule):
|
|
197 |
self.val_shuffle = val_shuffle
|
198 |
self.test_shuffle = test_shuffle
|
199 |
self.train_dataset = FontDataset(
|
200 |
-
train_path,
|
|
|
|
|
|
|
|
|
201 |
)
|
202 |
self.val_dataset = FontDataset(
|
203 |
-
val_path, config_path, regression_use_tanh, val_transforms
|
204 |
)
|
205 |
self.test_dataset = FontDataset(
|
206 |
-
test_path, config_path, regression_use_tanh, test_transforms
|
207 |
)
|
208 |
|
209 |
def get_train_num_iter(self, num_device: int) -> int:
|
|
|
96 |
config_path: str = "configs/font.yml",
|
97 |
regression_use_tanh: bool = False,
|
98 |
transforms: bool = False,
|
99 |
+
crop_roi_bbox: bool = False,
|
100 |
):
|
101 |
self.path = path
|
102 |
self.fonts = load_font_with_exclusion(config_path)
|
103 |
self.regression_use_tanh = regression_use_tanh
|
104 |
self.transforms = transforms
|
105 |
+
self.crop_roi_bbox = crop_roi_bbox
|
106 |
|
107 |
self.images = [
|
108 |
os.path.join(path, f) for f in os.listdir(path) if f.endswith(".jpg")
|
|
|
148 |
with open(label_path, "rb") as f:
|
149 |
label: FontLabel = pickle.load(f)
|
150 |
|
151 |
+
if self.crop_roi_bbox:
|
152 |
+
left, top, width, height = label.bbox
|
153 |
+
image = TF.crop(image, top, left, height, width)
|
154 |
+
label.image_width = width
|
155 |
+
label.image_height = height
|
156 |
+
|
157 |
# encode label
|
158 |
label = self.fontlabel2tensor(label, label_path)
|
159 |
|
|
|
196 |
train_transforms: bool = False,
|
197 |
val_transforms: bool = False,
|
198 |
test_transforms: bool = False,
|
199 |
+
crop_roi_bbox: bool = False,
|
200 |
regression_use_tanh: bool = False,
|
201 |
**kwargs,
|
202 |
):
|
|
|
206 |
self.val_shuffle = val_shuffle
|
207 |
self.test_shuffle = test_shuffle
|
208 |
self.train_dataset = FontDataset(
|
209 |
+
train_path,
|
210 |
+
config_path,
|
211 |
+
regression_use_tanh,
|
212 |
+
train_transforms,
|
213 |
+
crop_roi_bbox,
|
214 |
)
|
215 |
self.val_dataset = FontDataset(
|
216 |
+
val_path, config_path, regression_use_tanh, val_transforms, crop_roi_bbox
|
217 |
)
|
218 |
self.test_dataset = FontDataset(
|
219 |
+
test_path, config_path, regression_use_tanh, test_transforms, crop_roi_bbox
|
220 |
)
|
221 |
|
222 |
def get_train_num_iter(self, num_device: int) -> int:
|
train.py
CHANGED
@@ -48,6 +48,12 @@ parser.add_argument(
|
|
48 |
action="store_true",
|
49 |
help="Use pretrained model for ResNet (default: False)",
|
50 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
51 |
|
52 |
args = parser.parse_args()
|
53 |
|
@@ -85,6 +91,7 @@ data_module = FontDataModule(
|
|
85 |
test_shuffle=False,
|
86 |
regression_use_tanh=regression_use_tanh,
|
87 |
train_transforms=augmentation,
|
|
|
88 |
)
|
89 |
|
90 |
num_iters = data_module.get_train_num_iter(num_device) * num_epochs
|
|
|
48 |
action="store_true",
|
49 |
help="Use pretrained model for ResNet (default: False)",
|
50 |
)
|
51 |
+
parser.add_argument(
|
52 |
+
"-i",
|
53 |
+
"--crop-roi-bbox",
|
54 |
+
action="store_true",
|
55 |
+
help="Crop ROI bounding box (default: False)",
|
56 |
+
)
|
57 |
|
58 |
args = parser.parse_args()
|
59 |
|
|
|
91 |
test_shuffle=False,
|
92 |
regression_use_tanh=regression_use_tanh,
|
93 |
train_transforms=augmentation,
|
94 |
+
crop_roi_bbox=args.crop_roi_bbox,
|
95 |
)
|
96 |
|
97 |
num_iters = data_module.get_train_num_iter(num_device) * num_epochs
|