File size: 2,237 Bytes
583456e
 
 
 
 
 
b92a792
583456e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0e0710e
 
583456e
 
 
 
 
 
b92a792
583456e
 
b92a792
583456e
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
# Copyright (c) Facebook, Inc. and its affiliates.
# Copyright (c) Meta Platforms, Inc. All Rights Reserved

from typing import Tuple
import numpy as np
import torch
from .clip import load as clip_load
from detectron2.utils.comm import get_local_rank, synchronize


def expand_box(
    x1: float,
    y1: float,
    x2: float,
    y2: float,
    expand_ratio: float = 1.0,
    max_h: int = None,
    max_w: int = None,
):
    """Scale a box about its center and return integer corner coordinates.

    Args:
        x1, y1: One corner of the box.
        x2, y2: The opposite corner of the box.
        expand_ratio: Multiplier applied to the box width and height.
        max_h: If given, y-coordinates are clamped to ``[0, max_h - 1]``.
        max_w: If given, x-coordinates are clamped to ``[0, max_w - 1]``.

    Returns:
        ``[x1, y1, x2, y2]`` of the scaled box, each converted with ``int()``
        (i.e. truncated toward zero).
    """
    center_x = 0.5 * (x1 + x2)
    center_y = 0.5 * (y1 + y2)
    half_w = 0.5 * ((x2 - x1) * expand_ratio)
    half_h = 0.5 * ((y2 - y1) * expand_ratio)

    left, top = center_x - half_w, center_y - half_h
    right, bottom = center_x + half_w, center_y + half_h

    # Clamping is opt-in per axis: without a bound the box may leave the image.
    if max_h is not None:
        top = max(0, top)
        bottom = min(max_h - 1, bottom)
    if max_w is not None:
        left = max(0, left)
        right = min(max_w - 1, right)

    return [int(left), int(top), int(right), int(bottom)]


def mask2box(mask: torch.Tensor):
    """Return the tight bounding box around the non-zero region of a mask.

    Written for a 2-D mask indexed as ``mask[y, x]``.

    Args:
        mask: Tensor whose non-zero entries mark the region of interest.

    Returns:
        ``(x1, y1, x2, y2)`` as 0-d tensors, where ``x2`` and ``y2`` are
        exclusive (last occupied index + 1), or ``None`` when no column of
        the mask has a non-zero sum.
    """
    # Summing over dim 0 collapses the rows, leaving one total per column,
    # so the non-zero positions are the occupied x indices.
    xs = torch.nonzero(mask.sum(dim=0))[:, 0]
    if xs.numel() == 0:
        return None
    # Use torch.nonzero here as well: calling np.nonzero on a tensor only
    # works because NumPy forwards to the tensor's own ``nonzero`` method.
    ys = torch.nonzero(mask.sum(dim=1))[:, 0]
    return xs.min(), ys.min(), xs.max() + 1, ys.max() + 1


def crop_with_mask(
    image: torch.Tensor,
    mask: torch.Tensor,
    bbox: torch.Tensor,
    fill: Tuple[float, float, float] = (0, 0, 0),
    expand_ratio: float = 1.0,
):
    """Crop an image to an (optionally expanded) box and fill outside the mask.

    Args:
        image: 3-D tensor whose last two dimensions are height and width.
        mask: Tensor indexed as ``mask[y, x]``; multiplied into the image crop,
            and its boolean negation selects where ``fill`` is applied.
        bbox: Four values ``(x1, y1, x2, y2)`` passed to ``expand_box``.
        fill: One fill value per output plane; the planes are concatenated
            along dim 0.
        expand_ratio: Scale factor forwarded to ``expand_box``.

    Returns:
        A pair ``(cropped_image, cropped_mask)`` where ``cropped_mask`` has a
        leading singleton dimension.
    """
    left, top, right, bottom = expand_box(*bbox, expand_ratio)
    _, img_h, img_w = image.shape

    # Keep the expanded box inside the image extent.
    left, top = max(left, 0), max(top, 0)
    right, bottom = min(right, img_w), min(bottom, img_h)

    # One constant-valued plane per fill entry, stacked along dim 0.
    planes = [
        image.new_full((1, bottom - top, right - left), fill_value=plane_value)
        for plane_value in fill
    ]
    background = torch.cat(planes)

    image_crop = image[:, top:bottom, left:right]
    mask_crop = mask[None, top:bottom, left:right]
    outside = ~mask.bool()[None, top:bottom, left:right]

    return image_crop * mask_crop + outside * background, mask_crop


def build_clip_model(model: str, mask_prompt_depth: int = 0, frozen: bool = True):
    """Load a CLIP model on CPU, letting local rank 0 load before the others.

    Args:
        model: Model identifier forwarded to ``clip_load``.
        mask_prompt_depth: Forwarded to ``clip_load`` unchanged.
        frozen: When true, ``requires_grad`` is set to ``False`` on every
            parameter of the loaded model.

    Returns:
        The model object returned by ``clip_load`` (its second return value
        is discarded).
    """
    is_first_rank = get_local_rank() == 0

    # Two phases, each followed by a barrier: local rank 0 loads first (so
    # that only it triggers the download), then every other rank loads.
    for first_rank_turn in (True, False):
        if is_first_rank == first_rank_turn:
            clip_model, _ = clip_load(
                model, mask_prompt_depth=mask_prompt_depth, device="cpu"
            )
        synchronize()

    if frozen:
        for weight in clip_model.parameters():
            weight.requires_grad = False
    return clip_model