File size: 3,984 Bytes
c97a8b1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f43f7c7
 
 
c97a8b1
 
 
 
 
 
 
 
 
 
 
 
 
f43f7c7
c97a8b1
 
 
 
 
f43f7c7
c97a8b1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f43f7c7
c97a8b1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
import gradio as gr
import cv2
import numpy as np
import mediapipe as mp
from mediapipe.tasks import python
from mediapipe.tasks.python import vision
from mediapipe.python._framework_bindings import image as image_module
_Image = image_module.Image
from mediapipe.python._framework_bindings import image_frame
_ImageFormat = image_frame.ImageFormat

import torch
from diffusers import StableDiffusionPipeline, StableDiffusionControlNetInpaintPipeline, ControlNetModel
from PIL import Image
from compel import Compel

# Device configuration
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Half precision is only used on GPU. On the CPU fallback the weights are
# loaded in float32, because float16 inference is not dependable there.
TORCH_DTYPE = torch.float16 if device.type == "cuda" else torch.float32

# Constants for colors (RGBA)
BG_COLOR = (0, 0, 0, 255)  # black with full opacity
MASK_COLOR = (255, 255, 255, 255)  # white with full opacity

# Create the options that will be used for ImageSegmenter
base_options = python.BaseOptions(model_asset_path='emirhan.tflite')
options = vision.ImageSegmenterOptions(base_options=base_options,
                                       output_category_mask=True)

# Initialize ControlNet inpainting pipeline. Both models are loaded with the
# same dtype and moved to the same device so they can be combined.
controlnet = ControlNetModel.from_pretrained(
    'lllyasviel/control_v11p_sd15_inpaint',
    torch_dtype=TORCH_DTYPE,
).to(device)

pipe = StableDiffusionControlNetInpaintPipeline.from_pretrained(
    'runwayml/stable-diffusion-v1-5',
    controlnet=controlnet,
    torch_dtype=TORCH_DTYPE,
).to(device)

def segment_hair(image):
    """Segment the image and return a black/white RGBA mask.

    Args:
        image: ``uint8`` numpy array of shape ``(height, width, 3)`` in RGB
            channel order, which is what the Gradio ``type="numpy"`` image
            input of this app delivers.

    Returns:
        ``uint8`` numpy array with four channels. Pixels whose category-mask
        value exceeds 0.2 are ``MASK_COLOR``; all others are ``BG_COLOR``.
    """
    # The input is RGB, so convert from RGB (not BGR); otherwise the
    # segmenter sees the image with red and blue swapped.
    rgba_image = cv2.cvtColor(image, cv2.COLOR_RGB2RGBA)
    rgba_image[:, :, 3] = 0  # Set alpha channel to empty

    # Create MP Image object from numpy array
    mp_image = _Image(image_format=_ImageFormat.SRGBA, data=rgba_image)

    # uint8 colour vectors keep the np.where result uint8 instead of int64.
    mask_color = np.array(MASK_COLOR, dtype=np.uint8)
    bg_color = np.array(BG_COLOR, dtype=np.uint8)

    # Create the image segmenter
    with vision.ImageSegmenter.create_from_options(options) as segmenter:
        # Retrieve the masks for the segmented image
        segmentation_result = segmenter.segment(mp_image)
        category_mask = segmentation_result.category_mask

        # Evaluate the comparison inside the `with` block so the result is an
        # independent array and not a view into segmenter-owned memory.
        selected = category_mask.numpy_view() > 0.2

    # Broadcast the 2-D boolean mask against the 4-element colour vectors
    # rather than materialising two full-size solid-colour images.
    return np.where(selected[..., np.newaxis], mask_color, bg_color)

def inpaint_hair(image, prompt):
    """Inpaint the segmented region of ``image`` according to ``prompt``.

    Args:
        image: numpy image as delivered by the Gradio ``type="numpy"`` input
            (RGB channel order).
        prompt: text prompt passed to the inpainting pipeline.

    Returns:
        The first generated image as a numpy array.

    Raises:
        ValueError: if ``image`` is ``None`` (nothing was uploaded).
    """
    if image is None:
        raise ValueError("No input image was provided.")

    # The Gradio image is already RGB, so no BGR->RGB swap is applied; a swap
    # here would feed the pipeline an image with red and blue exchanged.
    image_pil = Image.fromarray(image).convert("RGB")
    rgb_image = np.array(image_pil)

    # Segment hair to get the mask
    mask = segment_hair(rgb_image)
    mask_pil = Image.fromarray(cv2.cvtColor(mask, cv2.COLOR_RGBA2RGB))

    # Prepare the inpainting condition: pixels in [0, 1], with masked pixels
    # marked by -1.0.
    image_np = rgb_image.astype(np.float32) / 255.0
    mask_np = np.array(mask_pil.convert("L")).astype(np.float32) / 255.0
    image_np[mask_np > 0.5] = -1.0  # Set as masked pixel
    # (H, W, C) -> (1, C, H, W)
    inpaint_condition = torch.from_numpy(
        np.expand_dims(image_np, 0).transpose(0, 3, 1, 2)
    ).to(device)

    # Fixed seed for repeatable output. The generator is created on the
    # configured device so the CPU fallback does not crash on a hard-coded
    # "cuda" generator.
    generator = torch.Generator(device=device).manual_seed(42)
    output = pipe(
        prompt=prompt,
        image=image_pil,
        mask_image=mask_pil,
        control_image=inpaint_condition,
        num_inference_steps=50,
        guidance_scale=7.5,
        generator=generator,
    ).images[0]

    return np.array(output)

# Gradio interface: an uploaded image plus a text prompt go in, and the
# inpainted image comes out.
_image_input = gr.Image(type="numpy")
_prompt_input = gr.Textbox(
    label="Prompt",
    placeholder="Describe the desired inpainting result...",
)
_result_output = gr.Image(type="numpy")

# Each example row supplies one value per input component, in order.
_example_rows = [["example.jpeg", "dreadlocks"]]

iface = gr.Interface(
    fn=inpaint_hair,
    inputs=[_image_input, _prompt_input],
    outputs=_result_output,
    title="Hair Inpainting with ControlNet",
    description="Upload an image, and provide a prompt to inpaint the hair area using ControlNet.",
    examples=_example_rows,
)

# Start the web UI only when this file is run as a script.
if __name__ == "__main__":
    iface.launch()