from PIL import Image
import numpy as np
import torch
import torchvision.transforms.functional as TF


def tensor2pil(image):
    """Convert a [0, 1] float tensor to a PIL image."""
    return Image.fromarray(np.clip(255. * image.cpu().numpy().squeeze(), 0, 255).astype(np.uint8))


def pil2tensor(image):
    """Convert a PIL image to a [0, 1] float tensor with a leading batch dim."""
    return torch.from_numpy(np.array(image).astype(np.float32) / 255.0).unsqueeze(0)
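
# Round-trip example (illustrative; shapes follow ComfyUI's IMAGE convention of
# (batch, height, width, channels) with float values in [0, 1]):
#   t = torch.rand(1, 64, 64, 3)
#   img = tensor2pil(t)      # -> 64x64 RGB PIL image
#   t2 = pil2tensor(img)     # -> back to shape (1, 64, 64, 3)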


def tensor2mask(t: torch.Tensor) -> torch.Tensor:
    size = t.size()
    if len(size) < 4:
        return t
    if size[3] == 1:
        # Single-channel image: use it directly as the mask.
        return t[:, :, :, 0]
    elif size[3] == 4:
        # RGBA: use the alpha channel unless it is fully opaque.
        if torch.min(t[:, :, :, 3]).item() != 1.:
            return t[:, :, :, 3]
    # Fall back to a grayscale conversion of the color channels.
    return TF.rgb_to_grayscale(t.permute(0, 3, 1, 2), num_output_channels=1)[:, 0, :, :]
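
# Example (illustrative): for an RGBA batch, the alpha channel becomes the mask.
#   rgba = torch.rand(2, 32, 32, 4)     # (batch, H, W, 4)
#   m = tensor2mask(rgba)               # -> shape (2, 32, 32)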


class image_concat_mask:
    def __init__(self):
        pass

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "image1": ("IMAGE",),
            },
            "optional": {
                "image2": ("IMAGE",),
                "mask": ("MASK",),
            }
        }

    RETURN_TYPES = ("IMAGE", "MASK",)
    RETURN_NAMES = ("image", "mask")
    FUNCTION = "image_concat_mask"
    CATEGORY = "hhy"

    def image_concat_mask(self, image1, image2=None, mask=None):
        processed_images = []
        masks = []

        for idx, img1 in enumerate(image1):
            pil_image1 = tensor2pil(img1)
            width1, height1 = pil_image1.size

            if image2 is not None and idx < len(image2):
                pil_image2 = tensor2pil(image2[idx])
                width2, height2 = pil_image2.size
                # Scale image2 to image1's height, preserving its aspect ratio.
                new_width2 = int(width2 * (height1 / height2))
                pil_image2 = pil_image2.resize((new_width2, height1), Image.Resampling.LANCZOS)
            else:
                # No second image: use a white canvas the same size as image1.
                pil_image2 = Image.new('RGB', (width1, height1), 'white')
                new_width2 = width1

            # Paste the two images side by side, image1 on the left.
            combined_image = Image.new('RGB', (width1 + new_width2, height1))
            combined_image.paste(pil_image1, (0, 0))
            combined_image.paste(pil_image2, (width1, 0))

            combined_tensor = pil2tensor(combined_image)
            processed_images.append(combined_tensor)

            # Mask covers the image2 region (1.0) and leaves image1 untouched (0.0).
            final_mask = torch.zeros((1, height1, width1 + new_width2))
            final_mask[:, :, width1:] = 1.0

            if mask is not None and idx < len(mask):
                input_mask = mask[idx]
                # Resize the supplied mask to the image2 region and cut it out of
                # the output mask: where the input mask is 1, the output becomes 0.
                pil_input_mask = tensor2pil(input_mask)
                pil_input_mask = pil_input_mask.resize((new_width2, height1), Image.Resampling.LANCZOS)
                resized_input_mask = pil2tensor(pil_input_mask)
                final_mask[:, :, width1:] *= (1.0 - resized_input_mask)

            masks.append(final_mask)

        processed_images = torch.cat(processed_images, dim=0)
        masks = torch.cat(masks, dim=0)

        return (processed_images, masks)


NODE_CLASS_MAPPINGS = {
    "image concat mask": image_concat_mask
}

NODE_DISPLAY_NAME_MAPPINGS = {
    "image concat mask": "Image Concat with Mask"
}
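

# A minimal smoke test, assuming this file is run standalone outside ComfyUI.
# The random tensors below are illustrative stand-ins for real IMAGE inputs.
if __name__ == "__main__":
    node = image_concat_mask()
    img1 = torch.rand(1, 64, 64, 3)   # one 64x64 RGB image, values in [0, 1]
    img2 = torch.rand(1, 64, 96, 3)   # a second image with a different width
    images, out_mask = node.image_concat_mask(img1, img2)
    print(images.shape)    # expected: torch.Size([1, 64, 160, 3])
    print(out_mask.shape)  # expected: torch.Size([1, 64, 160])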