qubvel-hf's picture
qubvel-hf HF staff
Add links
415bb30
raw
history blame
13.4 kB
import albumentations as A
import base64
import cv2
import gradio as gr
import inspect
import io
import numpy as np
from copy import deepcopy
from functools import wraps
from PIL import Image, ImageDraw
from typing import get_type_hints
from utils import is_not_supported_transform
HEADER = f"""
<div align="center">
<p>
<img src="https://avatars.githubusercontent.com/u/57894582?s=200&v=4" alt="A" width="50" height="50" style="display:inline;">
<span style="font-size: 30px; vertical-align: bottom;"> lbumentations Demo ({A.__version__})</span>
</p>
<p style="margin-top: -15px;">
<a href="https://albumentations.ai/docs/" target="_blank" style="color: grey;">Documentation</a>
&nbsp;
<a href="https://github.com/albumentations-team/albumentations" target="_blank" style="color: grey;">GitHub Repository</a>
</p>
</div>
"""
DEFAULT_TRANSFORM = "Rotate"
DEFAULT_IMAGE_PATH = "images/doctor.webp"
DEFAULT_IMAGE = np.array(Image.open(DEFAULT_IMAGE_PATH))
DEFAULT_IMAGE_HEIGHT = DEFAULT_IMAGE.shape[0]
DEFAULT_IMAGE_WIDTH = DEFAULT_IMAGE.shape[1]
DEFAULT_BOXES = [
[265, 121, 326, 177], # Mask
[192, 169, 401, 395], # Coverall
]
mask_keypoints = [[270, 123], [320, 130], [270, 151], [321, 158]]
pocket_keypoints = [[226, 379], [272, 386], [307, 388], [364, 380]]
arm_keypoints = [[215, 194], [372, 192], [214, 322], [378, 330]]
DEFAULT_KEYPOINTS = mask_keypoints + pocket_keypoints + arm_keypoints
BASE64_DEFAULT_MASKS = [
{
"label": "Coverall",
# light green color
"color": (144, 238, 144),
"mask": "iVBORw0KGgoAAAANSUhEUgAAAlgAAAGQCAAAAABXXkFEAAAF+ElEQVR4nO3dwXLjNhBFUSg1///LziLj1Iwt26KkFhuvz9kkWVhFAJcAqXGStQAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAACe5+3sCyCUsq745+wLSKCsz4T1DMr6RFiUENZT2LI+EhYlhPWnt7t3nvt/MtSvsy+gkcfaeFuXJ11HBJPx7r+s7piPP3o0m/9zFP729tdfjv/gnT8dy1G41npeEc7Dd3astR7q6uOP2rT+4wb7mMLBGfkckildyyw8HMa1HWr8pK7pz1hF55YnrdE315dVHZmTb9IcPLVr7Oi/36oOTMpPe97Q+R16FL7wzW3sqThw2DdkdfOs3JTowDmeN+gbN6tbp+XJHxdk1pAPnIE3TczNnzdrmteaNeJjj1Y3zMyRD5w00WuNGu/RR/afpubZn5dlymjvexH8Znbu+cApk73WlLE+8v3C1Rm69wNnTPdaM0ba6hcOJkz4WhPG2SqrtSZM+Vr5o2yX1Vr5k75W+hhbZrVW+rSvlT3Ctlmt7Hlfa0X/anLnrnpf3DPkhtV86Zpf3sNyw+ouvKzYsPqvW/8rfERsWJxLWOeJ3rKERQlhnSh5y0oNK3nNtpAa1h6C8w8NK3jFNpEZlq5OlxnWNnLvgMiwcpdrH5FhbST2HkgMK3axdpIY1lZS7wJhUSIwrM32gM0u91aBYdFBXlihO8Bu4sLSVQ9xYe0n81ZICytzlTaUFtaOIm8GYVEiLKzIm39LYWHRhbAoIawGEg9wYVEiK6zEW39TWWHtKvCGEBYlhEUJYbWQdxYKixLCokRUWHkHyr6iwqIPYVEiKSwnYSNJYdGIsCghLEoIixLCokRQWF4KOwkKa2txd4WwmkgrS1iUEFYXYVuWsCiRE1bYHb+7nLBoRViUEBYlhEUJYVFCWG1kvdYKixIxYWXd7/tLCUtXzaSEFeBy9gU8VUhYNqxuMsLSVTsZYdFORFgZG1bGKN5FhEU/wqJEQlhZZ0iIhLBoSFiUEBYlhEUJYVEiIaysP70NkRAWDQmLEsLqI+qLXmFRQliUEBYlhEWJX2dfwK4ua4U9bj+XsA66/P0P0vqCo/CQy8dv+X3r/wVhHXElI2VdJ6wDiiOKalRYlBDWo6L2mecRVhtZhQrrUb5wuEpYlBBWF1knYUZYYWsSISKsM8vyiHVdRlivYWM8ICQsa95NSFinleUk/EJKWDQjrCbSDvOYsM5ZGCfhV2LCohdhPcKG9SVh3e5TRk/sKu0RKyis1y+N/eobOWG9nK6+ExTWa7esN119y79XeBdV/URYd5DVz4R1wNtlqepGUa+5+6551DKstaIe3ulEWJQQFiWERQlhUSIqrLx3q31FhUUfwqJEVljOwjaywqINYXUQuNOGhRW4QpsKC4su0sKyZTWRFtaWZe14zT+JCytylTYUuQyb/cJf5Brk7Vhrt5Xa62pvFRnWVmu107UekBlW6mptJDQsZZ0tNaxtpN4BqeN6fzW8/PH3LaUuQOq4PuqaVuz8OwopEXvHfNRyywqe/eCh/aVjV9Fz7z8KcpborIR1jvCo1hLWGQZk5a3wBCO6EtbLzehKWNQQFiWERQlhUUJYlBAWJYRFCWG9Wsc/Di8gLEoMCWvINtHIkLB4NWFRQliUENbLzXjeExYlhEUJYVFiRlgzHmtamREWLycsSgjr9UYczMKihLAoISxKCIsSwqKEsCghLEoI6wQTvsgSFiWERQlhUUJYZxjwkCUsSgiLEiPCGnDytDMirH7yU58QVv4qNjQhLE4gLEoIixLCooSwzhH/QjEhrMuQ/31NKxPCktYJBs14s9MnfOZn7FhrLdvWaw0KS1qvNG+qu5yI4TMfPryrmqSVPfWjjsLfnIgvMHWOO+xa0XM/ccdaK3xRO5gaFsWERQlhUWJqWB0e3qNNDauD6LiFRYmhYUVvFi3MDEtX5UaGpat6I8Oi3r8KSpCuwVpGmQAAAABJRU5ErkJggg==",
},
{
"label": "Mask",
# light blue color
"color": (173, 216, 230),
"mask": "iVBORw0KGgoAAAANSUhEUgAAAlgAAAGQCAAAAABXXkFEAAAB4ElEQVR4nO3csQ6CMBSG0avv/864OFhoobW9UeM5i4ML+fOFkmCMAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAIbcPn0BX257ftppkMEqtuY35uplqYN2VhFhsU5m2rnIKsJmXYy00xFWRBjuin1KvV2F6c7dP30Bv2ugwT8krPcp64SwJmzSahJWYbQUZbUIqzD8QK6sBmEVdLKKsEghLFIIa5LDs05YpBAWKYQ1y1lYJSxSCIsUwiKFsF55XlpGWLP83q9KWKQQ1qt37j6OzyphkUJYpBBWwZP4KsIqKWsRYZFCWDtuWWsIixTC2nPLWkJYB8paQVhHY2XpsMosdV0vaozXZpsG/+s3x0BN9bQM1sdOJ45pmauXpU4VadlqgLEubRF2AgAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAGCZBwKLGEVAl/J/AAAAAElFTkSuQmCC",
},
]
# Get all the transforms from the albumentations library
transforms_map = {
name: cls
for name, cls in vars(A).items()
if (
inspect.isclass(cls)
and issubclass(cls, (A.DualTransform, A.ImageOnlyTransform))
and not is_not_supported_transform(cls)
)
}
transforms_map.pop("DualTransform", None)
transforms_map.pop("ImageOnlyTransform", None)
transforms_map.pop("ReferenceBasedTransform", None)
transforms_keys = list(sorted(transforms_map.keys()))
# Decode the masks
for mask in BASE64_DEFAULT_MASKS:
mask["mask"] = np.array(
Image.open(io.BytesIO(base64.b64decode(mask["mask"]))).convert("L")
)
def run_with_retry(compose):
@wraps(compose)
def wrapper(*args, **kwargs):
processors = deepcopy(compose.processors)
for _ in range(4):
try:
result = compose(*args, **kwargs)
break
except NotImplementedError as e:
print(f"Caught NotImplementedError: {e}")
if "bbox" in str(e):
kwargs.pop("bboxes", None)
kwargs.pop("category_id", None)
compose.processors.pop("bboxes")
if "keypoint" in str(e):
kwargs.pop("keypoints", None)
compose.processors.pop("keypoints")
if "mask" in str(e):
kwargs.pop("mask", None)
except Exception as e:
compose.processors = processors
raise e
compose.processors = processors
return result
return wrapper
def draw_boxes(image, boxes, color=(255, 0, 0), thickness=2) -> np.ndarray:
"""Draw boxes with PIL."""
pil_image = Image.fromarray(image)
draw = ImageDraw.Draw(pil_image)
for box in boxes:
x_min, y_min, x_max, y_max = box
draw.rectangle([x_min, y_min, x_max, y_max], outline=color, width=thickness)
return np.array(pil_image)
def draw_keypoints(image, keypoints, color=(255, 0, 0), radius=2):
"""Draw keypoints with PIL."""
pil_image = Image.fromarray(image)
draw = ImageDraw.Draw(pil_image)
for keypoint in keypoints:
x, y = keypoint
draw.ellipse([x - radius, y - radius, x + radius, y + radius], fill=color)
return np.array(pil_image)
def get_rgb_mask(masks):
"""Get the RGB mask from the binary mask."""
rgb_mask = np.zeros((DEFAULT_IMAGE_HEIGHT, DEFAULT_IMAGE_WIDTH, 3), dtype=np.uint8)
for data in masks:
mask = data["mask"]
rgb_mask[mask > 0] = np.array(data["color"])
return rgb_mask
def draw_mask(image, mask):
"""Draw the mask on the image."""
image_with_mask = cv2.addWeighted(image, 0.5, mask, 0.5, 0)
return image_with_mask
def draw_not_implemented_image(image: np.ndarray, annotation_type: str):
"""Draw the image with a text. In the middle."""
pil_image = Image.fromarray(image)
draw = ImageDraw.Draw(pil_image)
# align in the centerm, and make bigger font
text = f'Transform NOT working with "{annotation_type.upper()}" annotations.'
length = draw.textlength(text)
draw.text(
(DEFAULT_IMAGE_WIDTH // 2 - length // 2, DEFAULT_IMAGE_HEIGHT // 2),
text,
fill=(255, 0, 0),
align="center",
)
return np.array(pil_image)
def get_formatted_signature(function_or_class, indentation=4):
signature = inspect.signature(function_or_class)
type_hints = get_type_hints(function_or_class)
args = []
for param in signature.parameters.values():
if param.name == "p":
str_param = "p=1.0,"
elif param.default == inspect.Parameter.empty:
str_param = f"{param.name}=,"
else:
if isinstance(param.default, str):
str_param = f'{param.name}="{param.default}",'
else:
str_param = f"{param.name}={param.default},"
annotation = type_hints.get(param.name, param.annotation)
if isinstance(param.annotation, type):
str_param += f" # {param.annotation.__name__}"
else:
str_annotation = str(annotation).replace("typing.", "")
str_param += f" # {str_annotation}"
str_param = "\n" + " " * indentation + str_param
args.append(str_param)
result = "(" + "".join(args) + "\n" + " " * (indentation - 4) + ")"
return result
def get_formatted_transform(transform_number):
transform_name = transforms_keys[transform_number]
transform = transforms_map[transform_name]
return f"A.{transform.__name__}{get_formatted_signature(transform)}"
def get_formatted_transform_docs(transform_number):
transform_name = transforms_keys[transform_number]
transform = transforms_map[transform_name]
return transform.__doc__.strip("\n")
def update_augmented_images(image, code):
augmentation = eval(code)
compose = A.Compose(
[augmentation],
bbox_params=A.BboxParams(format="pascal_voc", label_fields=["category_id"]),
keypoint_params=A.KeypointParams(format="xy"),
)
compose = run_with_retry(compose) # to prevent NotImplementedError
keypoints = DEFAULT_KEYPOINTS
bboxes = DEFAULT_BOXES
mask = get_rgb_mask(BASE64_DEFAULT_MASKS)
augmented = compose(
image=image,
not_implemented_image=image.copy(),
mask=mask,
keypoints=keypoints,
bboxes=bboxes,
category_id=range(len(bboxes)),
)
image = augmented["image"]
mask = augmented.get("mask", None)
bboxes = augmented.get("bboxes", None)
keypoints = augmented.get("keypoints", None)
# Draw the augmented images (or replace by placeholder if not implemented)
if mask is not None:
image_with_mask = draw_mask(image.copy(), mask)
else:
image_with_mask = draw_not_implemented_image(image.copy(), "mask")
if bboxes is not None:
image_with_bboxes = draw_boxes(image.copy(), bboxes)
else:
image_with_bboxes = draw_not_implemented_image(image.copy(), "boxes")
if keypoints is not None:
image_with_keypoints = draw_keypoints(image.copy(), keypoints)
else:
image_with_keypoints = draw_not_implemented_image(image.copy(), "keypoints")
return [
(image_with_mask, "Mask"),
(image_with_bboxes, "Boxes"),
(image_with_keypoints, "Keypoints"),
]
def update_image_info(image):
h, w = image.shape[:2]
dtype = image.dtype
max_, min_ = image.max(), image.min()
return f"Image info:\n\t - shape: {h}x{w}\n\t - dtype: {dtype}\n\t - min/max: {min_}/{max_}"
def update_code_and_docs(select):
code = get_formatted_transform(select)
docs = get_formatted_transform_docs(select)
return code, docs
with gr.Blocks() as demo:
gr.Markdown(HEADER)
with gr.Row():
with gr.Column():
with gr.Group():
select = gr.Dropdown(
label="Select a transformation",
choices=transforms_keys,
value=DEFAULT_TRANSFORM,
type="index",
interactive=True,
)
with gr.Accordion("Documentation (click to expand)", open=False):
docs = gr.TextArea(
get_formatted_transform_docs(
transforms_keys.index(DEFAULT_TRANSFORM)
),
show_label=False,
interactive=False,
)
code = gr.Code(
label="Code",
language="python",
value=get_formatted_transform(
transforms_keys.index(DEFAULT_TRANSFORM)
),
interactive=True,
lines=5,
)
info = gr.TextArea(
value=f"Image size: {DEFAULT_IMAGE_HEIGHT} x {DEFAULT_IMAGE_WIDTH} (height x width)",
show_label=False,
lines=1,
max_lines=1,
)
button = gr.Button("Apply!")
image = gr.Image(
value=DEFAULT_IMAGE_PATH,
type="numpy",
height=500,
width=300,
sources=[],
)
with gr.Row():
augmented_image = gr.Gallery(
value=update_augmented_images(DEFAULT_IMAGE, "A.NoOp()"),
rows=1,
columns=3,
show_label=False,
)
select.change(fn=update_code_and_docs, inputs=[select], outputs=[code, docs])
button.click(
fn=update_augmented_images, inputs=[image, code], outputs=[augmented_image]
)
if __name__ == "__main__":
demo.launch()