ai-hairstyler / app.py
emirhanbilgic's picture
Update app.py
f43f7c7 verified
raw
history blame
3.98 kB
import gradio as gr
import cv2
import numpy as np
import mediapipe as mp
from mediapipe.tasks import python
from mediapipe.tasks.python import vision
from mediapipe.python._framework_bindings import image as image_module
_Image = image_module.Image
from mediapipe.python._framework_bindings import image_frame
_ImageFormat = image_frame.ImageFormat
import torch
from diffusers import StableDiffusionPipeline, StableDiffusionControlNetInpaintPipeline, ControlNetModel
from PIL import Image
from compel import Compel
# Device configuration
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Half precision is only used on CUDA: many CPU kernels have no float16
# implementation, so loading the diffusion models in float16 on a CPU-only
# host fails at inference time. Fall back to float32 there.
dtype = torch.float16 if device.type == "cuda" else torch.float32
# Constants for colors (RGBA)
BG_COLOR = (0, 0, 0, 255)  # black with full opacity
MASK_COLOR = (255, 255, 255, 255)  # white with full opacity
# Options for the MediaPipe ImageSegmenter: load the local 'emirhan.tflite'
# model and request a per-pixel category mask.
base_options = python.BaseOptions(model_asset_path='emirhan.tflite')
options = vision.ImageSegmenterOptions(
    base_options=base_options,
    output_category_mask=True,
)
# Initialize the ControlNet inpainting pipeline (SD 1.5 + inpaint ControlNet),
# both in the precision chosen above for the active device.
controlnet = ControlNetModel.from_pretrained(
    'lllyasviel/control_v11p_sd15_inpaint',
    torch_dtype=dtype,
).to(device)
pipe = StableDiffusionControlNetInpaintPipeline.from_pretrained(
    'runwayml/stable-diffusion-v1-5',
    controlnet=controlnet,
    torch_dtype=dtype,
).to(device)
# Function to segment hair and generate mask
def segment_hair(image):
    """Segment the hair in ``image`` and return it as an RGBA mask.

    Args:
        image: uint8 RGB array of shape (H, W, 3), as delivered by
            ``gr.Image(type="numpy")``.

    Returns:
        uint8 array of shape (H, W, 4) holding MASK_COLOR on pixels whose
        segmentation category value is above 0.2 and BG_COLOR elsewhere.
    """
    # The input is already RGB, so convert RGB -> RGBA. Treating it as BGR
    # swapped the red and blue channels before the model saw the image.
    rgba_image = cv2.cvtColor(image, cv2.COLOR_RGB2RGBA)
    rgba_image[:, :, 3] = 0  # Alpha channel is zeroed before segmentation.
    mp_image = _Image(image_format=_ImageFormat.SRGBA, data=rgba_image)
    # A fresh segmenter per call; the context manager releases it afterwards.
    with vision.ImageSegmenter.create_from_options(options) as segmenter:
        category_mask = segmenter.segment(mp_image).category_mask
        # Read the mask while the segmenter is still open; add a trailing
        # axis so the per-pixel flag broadcasts over the 4 RGBA channels.
        is_hair = (category_mask.numpy_view() > 0.2)[..., np.newaxis]
    return np.where(
        is_hair,
        np.asarray(MASK_COLOR, dtype=np.uint8),
        np.asarray(BG_COLOR, dtype=np.uint8),
    )
# Function to inpaint the hair area using ControlNet
def inpaint_hair(image, prompt):
    """Repaint the hair region of ``image`` according to ``prompt``.

    Args:
        image: uint8 array as delivered by ``gr.Image(type="numpy")`` (RGB).
            Grayscale or RGBA arrays are normalized to RGB first.
        prompt: Text description of the desired hairstyle.

    Returns:
        The inpainted image as an RGB numpy array.

    Raises:
        gr.Error: If no image was provided (Gradio passes ``None``).
    """
    if image is None:
        raise gr.Error("Please upload an image.")
    # Gradio already supplies RGB, so no channel swap is needed. The old
    # BGR->RGB conversion fed the pipeline a red/blue-swapped photo.
    image_pil = Image.fromarray(image).convert("RGB")
    rgb_image = np.array(image_pil)
    # Segment hair to get the (white = hair) RGBA mask.
    mask = segment_hair(rgb_image)
    mask_pil = Image.fromarray(cv2.cvtColor(mask, cv2.COLOR_RGBA2RGB))
    # ControlNet inpaint condition: image scaled to [0, 1] with the pixels
    # to be repainted set to -1.
    image_np = rgb_image.astype(np.float32) / 255.0
    mask_np = np.array(mask_pil.convert("L"), dtype=np.float32) / 255.0
    image_np[mask_np > 0.5] = -1.0
    inpaint_condition = torch.from_numpy(
        np.expand_dims(image_np, 0).transpose(0, 3, 1, 2)
    ).to(device)
    # Seed on the active device; a hard-coded "cuda" generator crashes on
    # CPU-only hosts.
    generator = torch.Generator(device=device).manual_seed(42)
    output = pipe(
        prompt=prompt,
        image=image_pil,
        mask_image=mask_pil,
        control_image=inpaint_condition,
        num_inference_steps=50,
        guidance_scale=7.5,
        generator=generator,
    ).images[0]
    return np.array(output)
# Gradio UI: an uploaded image plus a text prompt go in, and the hair-inpainted
# image comes out. A single bundled example is offered to users.
iface = gr.Interface(
    fn=inpaint_hair,
    title="Hair Inpainting with ControlNet",
    description="Upload an image, and provide a prompt to inpaint the hair area using ControlNet.",
    inputs=[
        gr.Image(type="numpy"),
        gr.Textbox(
            label="Prompt",
            placeholder="Describe the desired inpainting result...",
        ),
    ],
    outputs=gr.Image(type="numpy"),
    examples=[["example.jpeg", "dreadlocks"]],
)


if __name__ == "__main__":
    # Only start the web server when run as a script, not on import.
    iface.launch()