# tomofi's picture
# Add application file
# Copyright (c) OpenMMLab. All rights reserved.
import cv2
import numpy as np
import pyclipper
from mmdet.core import BitmapMasks
from mmdet.datasets.builder import PIPELINES
from shapely.geometry import Polygon
from . import BaseTextDetTargets
# NOTE(review): the registration decorator below is restored; the visible
# source lost it, but the otherwise-unused ``PIPELINES`` import implies it.
# Confirm against the upstream file before merging.
@PIPELINES.register_module()
class DBNetTargets(BaseTextDetTargets):
    """Generate gt shrunk text, gt threshold map, and their effective region
    masks to learn DBNet: Real-time Scene Text Detection with Differentiable
    Binarization [https://arxiv.org/abs/1911.08947]. This was partially adapted
    from https://github.com/MhLiao/DB.

    NOTE(review): both URLs above were stripped from the recovered text and
    have been restored from the paper/repository names; please verify.

    Args:
        shrink_ratio (float): The area shrunk ratio between text
            kernels and their text masks.
        thr_min (float): The minimum value of the threshold map.
        thr_max (float): The maximum value of the threshold map.
        min_short_size (int): The minimum size of polygon below which
            the polygon is invalid.
    """

    # NOTE(review): the parameter list (names taken from the attribute
    # assignments, defaults 0.4 / 0.3 / 0.7 / 8) and the ``super().__init__()``
    # call were missing from the recovered text and are reconstructed here.
    def __init__(self,
                 shrink_ratio=0.4,
                 thr_min=0.3,
                 thr_max=0.7,
                 min_short_size=8):
        super().__init__()
        self.shrink_ratio = shrink_ratio
        self.thr_min = thr_min
        self.thr_max = thr_max
        self.min_short_size = min_short_size

    def find_invalid(self, results):
        """Find invalid polygons.

        Args:
            results (dict): The dict containing gt_mask.

        Returns:
            ignore_tags (list[bool]): The indicators for ignoring polygons.
        """
        texts = results['gt_masks'].masks
        ignore_tags = [False] * len(texts)

        # Only the first polygon of every instance is inspected.
        for idx, text in enumerate(texts):
            if self.invalid_polygon(text[0]):
                ignore_tags[idx] = True
        return ignore_tags

    def invalid_polygon(self, poly):
        """Judge the input polygon is invalid or not. It is invalid if its area
        smaller than 1 or the shorter side of its minimum bounding box smaller
        than min_short_size.

        Args:
            poly (ndarray): The polygon boundary point sequence.

        Returns:
            True/False (bool): Whether the polygon is invalid.
        """
        # ``polygon_area`` / ``polygon_size`` are not defined in this class;
        # presumably they are inherited from ``BaseTextDetTargets``.
        area = self.polygon_area(poly)
        if abs(area) < 1:
            return True
        short_size = min(self.polygon_size(poly))
        if short_size < self.min_short_size:
            return True

        return False

    def ignore_texts(self, results, ignore_tags):
        """Ignore gt masks and gt_labels while padding gt_masks_ignore in
        results given ignore_tags.

        Args:
            results (dict): Result for one image.
            ignore_tags (list[int]): Indicate whether to ignore its
                corresponding ground truth text.

        Returns:
            tuple: ``(results, new_ignore_tags)`` where ``results`` is the
            same dict after filtering and ``new_ignore_tags`` holds one
            falsy entry per surviving ground truth text.
        """
        flag_len = len(ignore_tags)
        assert flag_len == len(results['gt_masks'].masks)
        assert flag_len == len(results['gt_labels'])

        # Move the ignored masks over to ``gt_masks_ignore`` first, while
        # ``gt_masks`` still lines up index-for-index with ``ignore_tags``.
        results['gt_masks_ignore'].masks += [
            mask for i, mask in enumerate(results['gt_masks'].masks)
            if ignore_tags[i]
        ]
        results['gt_masks'].masks = [
            mask for i, mask in enumerate(results['gt_masks'].masks)
            if not ignore_tags[i]
        ]
        results['gt_labels'] = np.array([
            label for i, label in enumerate(results['gt_labels'])
            if not ignore_tags[i]
        ])
        new_ignore_tags = [ignore for ignore in ignore_tags if not ignore]

        return results, new_ignore_tags

    def generate_thr_map(self, img_size, polygons):
        """Generate threshold map.

        Args:
            img_size (tuple(int)): The image size (h,w)
            polygons (list(ndarray)): The polygon list.

        Returns:
            thr_map (ndarray): The generated threshold map.
            thr_mask (ndarray): The effective mask of threshold map.
        """
        thr_map = np.zeros(img_size, dtype=np.float32)
        thr_mask = np.zeros(img_size, dtype=np.uint8)

        for polygon in polygons:
            self.draw_border_map(polygon[0], thr_map, mask=thr_mask)
        # Rescale the accumulated [0, 1] border response to
        # [thr_min, thr_max].
        thr_map = thr_map * (self.thr_max - self.thr_min) + self.thr_min

        return thr_map, thr_mask

    def draw_border_map(self, polygon, canvas, mask):
        """Generate threshold map for one polygon.

        ``canvas`` and ``mask`` are updated in place; nothing is returned.
        The input ``polygon`` is left untouched.

        Args:
            polygon(ndarray): The polygon boundary ndarray.
            canvas(ndarray): The generated threshold map.
            mask(ndarray): The generated threshold mask.
        """
        # Work on a copy: the coordinates are shifted in place further down,
        # and a bare ``reshape`` would hand back a view that aliases the
        # caller's array.
        polygon = np.array(polygon).reshape(-1, 2)
        assert polygon.ndim == 2
        assert polygon.shape[1] == 2

        polygon_shape = Polygon(polygon)
        # NOTE(review): the divisor (``polygon_shape.length``) and the
        # ``ET_CLOSEDPOLYGON`` end type below were cut from the recovered
        # text and are reconstructed.
        distance = (
            polygon_shape.area * (1 - np.power(self.shrink_ratio, 2)) /
            polygon_shape.length)
        subject = [tuple(p) for p in polygon]
        padding = pyclipper.PyclipperOffset()
        padding.AddPath(subject, pyclipper.JT_ROUND,
                        pyclipper.ET_CLOSEDPOLYGON)
        padded_polygon = padding.Execute(distance)
        if len(padded_polygon) > 0:
            padded_polygon = np.array(padded_polygon[0])
        else:
            # Offsetting produced nothing: report it and fall back to the
            # original outline.
            print(f'padding {polygon} with {distance} gets {padded_polygon}')
            padded_polygon = polygon.copy().astype(np.int32)

        x_min = padded_polygon[:, 0].min()
        x_max = padded_polygon[:, 0].max()
        y_min = padded_polygon[:, 1].min()
        y_max = padded_polygon[:, 1].max()
        width = x_max - x_min + 1
        height = y_max - y_min + 1

        # Express the polygon in the local frame of the padded bounding box.
        polygon[:, 0] = polygon[:, 0] - x_min
        polygon[:, 1] = polygon[:, 1] - y_min

        xs = np.broadcast_to(
            np.linspace(0, width - 1, num=width).reshape(1, width),
            (height, width))
        ys = np.broadcast_to(
            np.linspace(0, height - 1, num=height).reshape(height, 1),
            (height, width))

        # One normalised distance plane per polygon edge; the per-pixel
        # minimum over the planes is kept.
        # NOTE(review): ``dtype=np.float32`` is reconstructed.
        distance_map = np.zeros((polygon.shape[0], height, width),
                                dtype=np.float32)
        for i in range(polygon.shape[0]):
            j = (i + 1) % polygon.shape[0]
            absolute_distance = self.point2line(xs, ys, polygon[i], polygon[j])
            distance_map[i] = np.clip(absolute_distance / distance, 0, 1)
        distance_map = distance_map.min(axis=0)

        # Clamp the padded bounding box to the canvas extent.
        x_min_valid = min(max(0, x_min), canvas.shape[1] - 1)
        x_max_valid = min(max(0, x_max), canvas.shape[1] - 1)
        y_min_valid = min(max(0, y_min), canvas.shape[0] - 1)
        y_max_valid = min(max(0, y_max), canvas.shape[0] - 1)

        # NOTE(review): the body of this guard was missing from the
        # recovered text; an early ``return`` is reconstructed.
        if x_min_valid - x_min >= width or y_min_valid - y_min >= height:
            return

        cv2.fillPoly(mask, [padded_polygon.astype(np.int32)], 1.0)
        # Merge with what is already on the canvas by element-wise maximum.
        canvas[y_min_valid:y_max_valid + 1,
               x_min_valid:x_max_valid + 1] = np.fmax(
                   1 - distance_map[y_min_valid - y_min:y_max_valid - y_max +
                                    height, x_min_valid - x_min:x_max_valid -
                                    x_max + width],
                   canvas[y_min_valid:y_max_valid + 1,
                          x_min_valid:x_max_valid + 1])

    def generate_targets(self, results):
        """Generate the gt targets for DBNet.

        Args:
            results (dict): The input result dictionary.

        Returns:
            results (dict): The output result dictionary.
        """
        assert isinstance(results, dict)

        # NOTE(review): the body of this ``if`` was missing from the
        # recovered text; clearing ``bbox_fields`` is reconstructed.
        if 'bbox_fields' in results:
            results['bbox_fields'].clear()

        ignore_tags = self.find_invalid(results)
        results, ignore_tags = self.ignore_texts(results, ignore_tags)

        h, w, _ = results['img_shape']
        polygons = results['gt_masks'].masks

        # generate gt_shrink_kernel
        # NOTE(review): every argument after ``(h, w)`` is reconstructed;
        # ``generate_kernels`` is not defined in this class, so check its
        # signature where it is defined.
        gt_shrink, ignore_tags = self.generate_kernels((h, w),
                                                       polygons,
                                                       self.shrink_ratio,
                                                       ignore_tags=ignore_tags)

        results, ignore_tags = self.ignore_texts(results, ignore_tags)
        # generate gt_shrink_mask
        polygons_ignore = results['gt_masks_ignore'].masks
        gt_shrink_mask = self.generate_effective_mask((h, w), polygons_ignore)

        # generate gt_threshold and gt_threshold_mask
        polygons = results['gt_masks'].masks
        gt_thr, gt_thr_mask = self.generate_thr_map((h, w), polygons)

        results['mask_fields'].clear()  # rm gt_masks encoded by polygons
        results.pop('gt_labels', None)
        results.pop('gt_masks', None)
        results.pop('gt_bboxes', None)
        results.pop('gt_bboxes_ignore', None)

        mapping = {
            'gt_shrink': gt_shrink,
            'gt_shrink_mask': gt_shrink_mask,
            'gt_thr': gt_thr,
            'gt_thr_mask': gt_thr_mask
        }
        for key, value in mapping.items():
            value = value if isinstance(value, list) else [value]
            results[key] = BitmapMasks(value, h, w)
            # NOTE(review): re-registering the key in ``mask_fields`` is
            # reconstructed; the recovered text ends the loop body above.
            results['mask_fields'].append(key)

        return results