Spaces:
Build error
Build error
dillonlaird
commited on
Commit
·
d9697ef
1
Parent(s):
199c85f
added postprocessing
Browse files- app/sam/postprocess.py +21 -0
app/sam/postprocess.py
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import cv2
|
2 |
+
import torch
|
3 |
+
import numpy as np
|
4 |
+
import numpy.typing as npt
|
5 |
+
|
6 |
+
from torch import Tensor
|
7 |
+
from kornia.morphology import erosion, dilation
|
8 |
+
|
9 |
+
|
10 |
+
def clean_mask_torch(mask: Tensor) -> Tensor:
|
11 |
+
kernel = torch.ones(2, 2).to(mask.device)
|
12 |
+
if len(mask.shape) == 2:
|
13 |
+
mask = mask[None, None, :, :]
|
14 |
+
if mask.dtype == torch.bool:
|
15 |
+
mask = mask.int()
|
16 |
+
return dilation(erosion(mask, kernel), kernel)
|
17 |
+
|
18 |
+
|
19 |
+
def clean_mask_np(mask: npt.NDArray) -> npt.NDArray:
|
20 |
+
kernel = np.ones((2, 2), np.uint8)
|
21 |
+
return cv2.dilate(cv2.erode(mask, kernel, iterations=1), kernel, iterations=1)
|