import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image


def preprocess(mask_values, pil_img, scale, is_mask):
    """Resize an image or a mask and convert it to a network-ready array.

    Args:
        mask_values: Sequence of pixel values (scalars, or per-channel
            sequences for multi-channel masks); the position of each value
            becomes its class index. Only read when ``is_mask`` is true.
        pil_img: Array accepted by ``PIL.Image.fromarray`` (despite the
            parameter name, this is passed to ``fromarray`` and is therefore
            not an already-constructed PIL image).
        scale: Multiplier applied to both width and height.
        is_mask: When true, build an index mask; otherwise build a
            channel-first image array.

    Returns:
        For masks, an ``int64`` array of shape ``(newH, newW)`` holding class
        indices. For images, an array of shape ``(C, newH, newW)`` (``C`` is 1
        for single-channel input), divided by 255 if any value exceeds 1.

    Raises:
        ValueError: If the scaled width or height is not positive.
    """
    image = Image.fromarray(pil_img)
    w, h = image.size
    new_w, new_h = int(scale * w), int(scale * h)
    if new_w <= 0 or new_h <= 0:
        raise ValueError(
            f"Scale {scale} is too small, resized image would have no pixels "
            f"({new_w}x{new_h})"
        )

    # Masks must be resized with nearest-neighbour sampling: an interpolating
    # filter blends neighbouring labels into pixel values that are not in
    # ``mask_values``, and those pixels would silently stay at class 0.
    resample = Image.NEAREST if is_mask else Image.BICUBIC
    image = image.resize((new_w, new_h), resample=resample)
    img = np.asarray(image)

    if is_mask:
        mask = np.zeros((new_h, new_w), dtype=np.int64)
        for index, value in enumerate(mask_values):
            if img.ndim == 2:
                mask[img == value] = index
            else:
                # Multi-channel mask: a pixel matches only if every channel does.
                mask[(img == value).all(-1)] = index
        return mask

    if img.ndim == 2:
        img = img[np.newaxis, ...]  # add a leading channel axis
    else:
        img = img.transpose((2, 0, 1))  # HWC -> CHW

    if (img > 1).any():
        img = img / 255.0

    return img


def predict_img(net, full_img, device, scale_factor=1, out_threshold=0.5):
    """Run ``net`` on one image and return the predicted mask as an array.

    Args:
        net: Model that is callable on a batched float32 tensor and exposes
            ``eval()`` and an ``n_classes`` attribute.
        full_img: Input image array, forwarded to ``preprocess``.
        device: Device the input tensor is moved to before inference.
        scale_factor: Resize factor forwarded to ``preprocess``.
        out_threshold: Threshold applied to the sigmoid output when
            ``net.n_classes`` is not greater than 1.

    Returns:
        Integer NumPy array for the first (only) batch element, with
        singleton dimensions squeezed out. It has the resolution of the
        network output, i.e. it is not resized back to ``full_img``.
    """
    net.eval()
    img = torch.from_numpy(
        preprocess(None, full_img, scale_factor, is_mask=False)
    )
    img = img.unsqueeze(0)  # add the batch dimension
    img = img.to(device=device, dtype=torch.float32)

    with torch.no_grad():
        output = net(img).cpu()
        if net.n_classes > 1:
            mask = output.argmax(dim=1)
        else:
            mask = torch.sigmoid(output) > out_threshold

    return mask[0].long().squeeze().numpy()


def mask_to_image(mask: np.ndarray, mask_values):
    """Convert a class-index mask into a PIL image.

    Args:
        mask: Either a 2-D array of class indices or a 3-D array that is
            reduced with ``argmax`` over axis 0.
        mask_values: Pixel value to paint for each class index. Entries may
            be scalars or per-channel lists/tuples.

    Returns:
        ``PIL.Image.Image`` built from a ``bool`` array when ``mask_values``
        is exactly ``[0, 1]``, otherwise from a ``uint8`` array (with a
        trailing channel axis when the entries are sequences).
    """
    height, width = mask.shape[-2], mask.shape[-1]

    # Tuples are accepted alongside lists so that colour entries such as
    # (255, 0, 0) get a multi-channel canvas as well.
    if isinstance(mask_values[0], (list, tuple)):
        out = np.zeros((height, width, len(mask_values[0])), dtype=np.uint8)
    elif mask_values == [0, 1]:
        out = np.zeros((height, width), dtype=bool)
    else:
        out = np.zeros((height, width), dtype=np.uint8)

    if mask.ndim == 3:
        mask = np.argmax(mask, axis=0)

    for index, value in enumerate(mask_values):
        out[mask == index] = value

    return Image.fromarray(out)