# Spaces: Build error
import numpy as np | |
import torch | |
import torch.nn.functional as F | |
from PIL import Image | |
def preprocess(mask_values, pil_img, scale, is_mask):
    """Resize an image or a label mask and convert it to a NumPy array.

    Args:
        mask_values: Sequence of the distinct pixel values that may occur in a
            mask; the position of each value in the sequence becomes its class
            index. Only read when ``is_mask`` is true.
        pil_img: The picture to process, either a ``PIL.Image.Image`` or any
            object accepted by ``Image.fromarray`` (e.g. a NumPy array).
        scale: Factor applied to both the width and the height.
        is_mask: When true the input is treated as a label mask, otherwise as
            an image.

    Returns:
        For masks, an ``int64`` array of shape ``(newH, newW)`` holding class
        indices; pixels matching no entry of ``mask_values`` stay at 0.
        For images, a channel-first array: a leading axis is added to 2-D
        input, 3-D input is transposed from ``(H, W, C)`` to ``(C, H, W)``,
        and the result is divided by 255 if any value is greater than 1.

    Raises:
        ValueError: If ``scale`` would shrink the picture to zero pixels along
            either side.
    """
    # Accept ready-made PIL images as well as raw arrays.
    if not isinstance(pil_img, Image.Image):
        pil_img = Image.fromarray(pil_img)

    w, h = pil_img.size
    newW, newH = int(scale * w), int(scale * h)
    if newW <= 0 or newH <= 0:
        raise ValueError('Scale is too small, resized images would have no pixel')

    if is_mask:
        # Masks must not be interpolated: blended pixel values would match no
        # entry of mask_values and would silently end up as class 0 below.
        pil_img = pil_img.resize((newW, newH), resample=Image.NEAREST)
    else:
        pil_img = pil_img.resize((newW, newH))
    img = np.asarray(pil_img)

    if is_mask:
        mask = np.zeros((newH, newW), dtype=np.int64)
        for i, v in enumerate(mask_values):
            if img.ndim == 2:
                mask[img == v] = i
            else:
                # Multi-channel mask: every channel has to equal the value.
                mask[(img == v).all(-1)] = i
        return mask

    if img.ndim == 2:
        img = img[np.newaxis, ...]
    else:
        img = img.transpose((2, 0, 1))

    # Heuristic: any value above 1 means the data is not yet normalised.
    if (img > 1).any():
        img = img / 255.0

    return img
def predict_img(net,
                full_img,
                device,
                scale_factor=1,
                out_threshold=0.5):
    """Run ``net`` on a single picture and return the predicted class mask.

    Args:
        net: Model object; it is switched to evaluation mode, called on the
            input batch, and its ``n_classes`` attribute selects how the
            output is decoded.
        full_img: Input picture, forwarded unchanged to ``preprocess``.
        device: Device the input batch is moved to before the forward pass.
        scale_factor: Scale passed on to ``preprocess``.
        out_threshold: Cut-off applied to the sigmoid of the output when
            ``net.n_classes`` is not greater than 1.

    Returns:
        An ``int64`` NumPy array for the first (only) batch element, with
        singleton dimensions removed. No resizing back to the resolution of
        ``full_img`` is performed here.
    """
    net.eval()

    prepared = preprocess(None, full_img, scale_factor, is_mask=False)
    # Add the batch axis, then move to the target device as float32.
    batch = torch.from_numpy(prepared).unsqueeze(0)
    batch = batch.to(device=device, dtype=torch.float32)

    with torch.no_grad():
        scores = net(batch).cpu()
        multi_class = net.n_classes > 1
        # Several classes: pick the best-scoring one along dim 1.
        # Otherwise: threshold the sigmoid of the output.
        prediction = (
            scores.argmax(dim=1)
            if multi_class
            else torch.sigmoid(scores) > out_threshold
        )

    return prediction[0].long().squeeze().numpy()
def mask_to_image(mask: np.ndarray, mask_values):
    """Turn a class-index mask into a PIL image using ``mask_values``.

    Args:
        mask: Array whose last two axes give the output height and width.
            A 3-D array is first reduced with ``argmax`` over axis 0
            (presumably per-class scores with classes first -- confirm with
            the caller).
        mask_values: Sequence mapping class index to output pixel value.
            If its first entry is a list, the output gets one channel per
            element of that entry; if it equals ``[0, 1]`` the output is
            boolean; otherwise it is single-channel ``uint8``.

    Returns:
        The image produced by ``Image.fromarray`` from the filled array.
    """
    first_value = mask_values[0]
    if isinstance(first_value, list):
        trailing_dims, pixel_dtype = (len(first_value),), np.uint8
    elif mask_values == [0, 1]:
        trailing_dims, pixel_dtype = (), bool
    else:
        trailing_dims, pixel_dtype = (), np.uint8

    canvas_shape = (mask.shape[-2], mask.shape[-1]) + trailing_dims
    canvas = np.zeros(canvas_shape, dtype=pixel_dtype)

    class_map = np.argmax(mask, axis=0) if mask.ndim == 3 else mask

    # Paint every class index with its associated output value.
    for class_index, pixel_value in enumerate(mask_values):
        canvas[class_map == class_index] = pixel_value

    return Image.fromarray(canvas)