gyrojeff commited on
Commit
855e240
·
1 Parent(s): 416c7bb

feat: add crop roi bbox

Browse files
Files changed (2) hide show
  1. detector/data.py +16 -3
  2. train.py +7 -0
detector/data.py CHANGED
@@ -96,11 +96,13 @@ class FontDataset(Dataset):
96
  config_path: str = "configs/font.yml",
97
  regression_use_tanh: bool = False,
98
  transforms: bool = False,
 
99
  ):
100
  self.path = path
101
  self.fonts = load_font_with_exclusion(config_path)
102
  self.regression_use_tanh = regression_use_tanh
103
  self.transforms = transforms
 
104
 
105
  self.images = [
106
  os.path.join(path, f) for f in os.listdir(path) if f.endswith(".jpg")
@@ -146,6 +148,12 @@ class FontDataset(Dataset):
146
  with open(label_path, "rb") as f:
147
  label: FontLabel = pickle.load(f)
148
 
 
 
 
 
 
 
149
  # encode label
150
  label = self.fontlabel2tensor(label, label_path)
151
 
@@ -188,6 +196,7 @@ class FontDataModule(LightningDataModule):
188
  train_transforms: bool = False,
189
  val_transforms: bool = False,
190
  test_transforms: bool = False,
 
191
  regression_use_tanh: bool = False,
192
  **kwargs,
193
  ):
@@ -197,13 +206,17 @@ class FontDataModule(LightningDataModule):
197
  self.val_shuffle = val_shuffle
198
  self.test_shuffle = test_shuffle
199
  self.train_dataset = FontDataset(
200
- train_path, config_path, regression_use_tanh, train_transforms
 
 
 
 
201
  )
202
  self.val_dataset = FontDataset(
203
- val_path, config_path, regression_use_tanh, val_transforms
204
  )
205
  self.test_dataset = FontDataset(
206
- test_path, config_path, regression_use_tanh, test_transforms
207
  )
208
 
209
  def get_train_num_iter(self, num_device: int) -> int:
 
96
  config_path: str = "configs/font.yml",
97
  regression_use_tanh: bool = False,
98
  transforms: bool = False,
99
+ crop_roi_bbox: bool = False,
100
  ):
101
  self.path = path
102
  self.fonts = load_font_with_exclusion(config_path)
103
  self.regression_use_tanh = regression_use_tanh
104
  self.transforms = transforms
105
+ self.crop_roi_bbox = crop_roi_bbox
106
 
107
  self.images = [
108
  os.path.join(path, f) for f in os.listdir(path) if f.endswith(".jpg")
 
148
  with open(label_path, "rb") as f:
149
  label: FontLabel = pickle.load(f)
150
 
151
+ if self.crop_roi_bbox:
152
+ left, top, width, height = label.bbox
153
+ image = TF.crop(image, top, left, height, width)
154
+ label.image_width = width
155
+ label.image_height = height
156
+
157
  # encode label
158
  label = self.fontlabel2tensor(label, label_path)
159
 
 
196
  train_transforms: bool = False,
197
  val_transforms: bool = False,
198
  test_transforms: bool = False,
199
+ crop_roi_bbox: bool = False,
200
  regression_use_tanh: bool = False,
201
  **kwargs,
202
  ):
 
206
  self.val_shuffle = val_shuffle
207
  self.test_shuffle = test_shuffle
208
  self.train_dataset = FontDataset(
209
+ train_path,
210
+ config_path,
211
+ regression_use_tanh,
212
+ train_transforms,
213
+ crop_roi_bbox,
214
  )
215
  self.val_dataset = FontDataset(
216
+ val_path, config_path, regression_use_tanh, val_transforms, crop_roi_bbox
217
  )
218
  self.test_dataset = FontDataset(
219
+ test_path, config_path, regression_use_tanh, test_transforms, crop_roi_bbox
220
  )
221
 
222
  def get_train_num_iter(self, num_device: int) -> int:
train.py CHANGED
@@ -48,6 +48,12 @@ parser.add_argument(
48
  action="store_true",
49
  help="Use pretrained model for ResNet (default: False)",
50
  )
 
 
 
 
 
 
51
 
52
  args = parser.parse_args()
53
 
@@ -85,6 +91,7 @@ data_module = FontDataModule(
85
  test_shuffle=False,
86
  regression_use_tanh=regression_use_tanh,
87
  train_transforms=augmentation,
 
88
  )
89
 
90
  num_iters = data_module.get_train_num_iter(num_device) * num_epochs
 
48
  action="store_true",
49
  help="Use pretrained model for ResNet (default: False)",
50
  )
51
+ parser.add_argument(
52
+ "-i",
53
+ "--crop-roi-bbox",
54
+ action="store_true",
55
+ help="Crop ROI bounding box (default: False)",
56
+ )
57
 
58
  args = parser.parse_args()
59
 
 
91
  test_shuffle=False,
92
  regression_use_tanh=regression_use_tanh,
93
  train_transforms=augmentation,
94
+ crop_roi_bbox=args.crop_roi_bbox,
95
  )
96
 
97
  num_iters = data_module.get_train_num_iter(num_device) * num_epochs