import random
from typing import List, Tuple

import spaces
import gradio as gr
import lpips
import numpy as np
import pandas as pd
import torch
import torchvision.transforms as transforms
from diffusers import StableDiffusionInpaintPipeline
from diffusers.utils import load_image
from PIL import Image, ImageOps

# Constants
TARGET_SIZE = (512, 512)
DEVICE = torch.device("cuda")
LPIPS_MODELS = ['alex', 'vgg', 'squeeze']
MASK_SIZES = {"64x64": 64, "128x128": 128, "256x256": 256}
DEFAULT_MASK_SIZE = "256x256"
MIN_ITERATIONS = 2
MAX_ITERATIONS = 5
DEFAULT_ITERATIONS = 2
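# NOTE: DEVICE assumes a CUDA-capable runtime (this demo targets a GPU-backed
# Hugging Face Space); a CPU-only machine would also need to drop the float16
# dtype used for the inpainting pipeline below.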

# HTML Content
TITLE = """
<h1 style='text-align: center; font-size: 3.2em; margin: 20px; font-family: Arial, sans-serif;'>
    How Stable is Stable Diffusion under Recursive InPainting (RIP)?🧟
</h1>
"""

AUTHORS = """
<div align="center" style="font-size: 1.4em; margin-bottom: 0.5em;">
    Javier Conde<sup>1</sup>,
    Miguel González<sup>1</sup>,
    Gonzalo Martínez<sup>2</sup>,
    Fernando Moral<sup>3</sup>,
    Elena Merino-Gómez<sup>4</sup>,
    Pedro Reviriego<sup>1</sup>
</div>
<div align="center" style="font-size: 1.3em; font-style: italic;">
    <sup>1</sup>Universidad Politécnica de Madrid, <sup>2</sup>Universidad Carlos III de Madrid, <sup>3</sup>Universidad Antonio de Nebrija, <sup>4</sup>Universidad de Valladolid
</div>
"""

BUTTONS = """
<head>
    <link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/font-awesome/6.0.0-beta3/css/all.min.css">
    <style>
        .button-container {
            display: flex;
            justify-content: center;
            gap: 10px;
            margin-top: 10px;
        }
        .button-container a {
            display: inline-flex;
            align-items: center;
            padding: 10px 20px;
            border-radius: 30px;
            border: 1px solid #ccc;
            color: #333 !important;
            font-size: 16px;
            text-decoration: none !important;
        }
        .button-container a i {
            margin-right: 8px;
        }
    </style>
</head>
<div class="button-container">
    <a href="https://arxiv.org/abs/2407.09549" class="btn btn-outline-primary">
        <i class="fa-solid fa-file-pdf"></i> Paper
    </a>
    <a href="https://zenodo.org/records/11574941" class="btn btn-outline-secondary">
        <i class="fa-regular fa-folder-open"></i> Zenodo
    </a>
</div>
"""

DESCRIPTION = """
# 🌟 Official Demo: GenAI Evaluation KDD2024 🌟
Welcome to the official demo of our [research paper](https://arxiv.org/abs/2407.09549), presented at the KDD 2024 workshop on [Evaluation and Trustworthiness of Generative AI Models](https://genai-evaluation-kdd2024.github.io/genai-evalution-kdd2024/).
## 🚀 How to Use
1. 📤 Upload an image, or pick one of the examples from the [WikiArt dataset](https://huggingface.co/datasets/huggan/wikiart) used in our paper.
2. 🎭 Select the mask size.
3. 🔄 Choose the number of iterations (more iterations = longer processing time).
4. 🖱️ Click "Submit" and wait for the results!
## 📊 Results
The gallery on the right shows the image after each inpainting iteration, together with the [LPIPS (Learned Perceptual Image Patch Similarity)](https://github.com/richzhang/PerceptualSimilarity) distance between each result and the original image.
"""

ARTICLE = """
## 🎨✨ To cite our work
```bibtex
@misc{conde2024stablestablediffusionrecursive,
      title={How Stable is Stable Diffusion under Recursive InPainting (RIP)?},
      author={Javier Conde and Miguel González and Gonzalo Martínez and Fernando Moral and Elena Merino-Gómez and Pedro Reviriego},
      year={2024},
      eprint={2407.09549},
      archivePrefix={arXiv},
      primaryClass={cs.CV},
      url={https://arxiv.org/abs/2407.09549},
}
```
"""

CUSTOM_CSS = """
#centered {
    display: flex;
    justify-content: center;
    width: 60%;
    margin: 0 auto;
}
"""
_LPIPS_FNS = None  # lazily-built cache of LPIPS networks, one per backbone


def lpips_distance(img1: Image.Image, img2: Image.Image) -> Tuple[float, float, float]:
    def preprocess(img: Image.Image) -> torch.Tensor:
        if isinstance(img, torch.Tensor):
            # Add a batch dimension to unbatched (C, H, W) tensors.
            return img.float() if img.dim() == 4 else img.unsqueeze(0).float()
        return transforms.ToTensor()(img).unsqueeze(0)

    tensor_img1, tensor_img2 = map(preprocess, (img1, img2))
    resize = transforms.Resize(TARGET_SIZE)
    tensor_img1, tensor_img2 = map(lambda x: resize(x).to(DEVICE), (tensor_img1, tensor_img2))

    # Build the LPIPS networks once and reuse them across calls instead of
    # reloading their weights on every iteration.
    global _LPIPS_FNS
    if _LPIPS_FNS is None:
        _LPIPS_FNS = {model: lpips.LPIPS(net=model, verbose=False).to(DEVICE) for model in LPIPS_MODELS}

    with torch.no_grad():
        distances = [_LPIPS_FNS[model](tensor_img1, tensor_img2).item() for model in LPIPS_MODELS]
    return tuple(distances)
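

# Build a binary inpainting mask: a single white square (the region the
# pipeline will repaint) placed at a random position on a black background.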
def create_square_mask(image: Image.Image, square_size: int = 256) -> Image.Image:
    img_array = np.array(image)
    height, width = img_array.shape[:2]
    mask = np.zeros((height, width), dtype=np.uint8)
    max_y, max_x = max(0, height - square_size), max(0, width - square_size)
    start_y, start_x = random.randint(0, max_y), random.randint(0, max_x)
    end_y, end_x = min(start_y + square_size, height), min(start_x + square_size, width)
    mask[start_y:end_y, start_x:end_x] = 255
    return Image.fromarray(mask)
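

# Letterbox the input to the 512x512 resolution expected by the pipeline.
# The returned mask/nmask canvases mirror the padding layout; only the padded
# image itself is consumed by execute_experiment below.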
def adjust_size(image: Image.Image) -> Tuple[Image.Image, Image.Image, Image.Image]:
    mask_image = Image.new("RGB", image.size, (255, 255, 255))
    nmask_image = Image.new("RGB", image.size, (0, 0, 0))
    new_image = ImageOps.pad(image, TARGET_SIZE, Image.LANCZOS, (255, 255, 255), (0.5, 0.5))
    mask_image = ImageOps.pad(mask_image, TARGET_SIZE, Image.LANCZOS, (100, 100, 100), (0.5, 0.5))
    nmask_image = ImageOps.pad(nmask_image, TARGET_SIZE, Image.LANCZOS, (100, 100, 100), (0.5, 0.5))
    return new_image, mask_image, nmask_image
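

# Recursive inpainting (RIP) loop: each iteration masks a random square of the
# latest result and has the model repaint it with an empty prompt, then records
# the LPIPS distance between the current result and the original image. The
# @spaces.GPU decorator requests a GPU slot on ZeroGPU Spaces (and is a no-op
# elsewhere); it is the reason for the `import spaces` above.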
@spaces.GPU
def execute_experiment(image: Image.Image, iterations: int, mask_size: str) -> Tuple[List[Image.Image], pd.DataFrame]:
    mask_size = MASK_SIZES[mask_size]
    image = adjust_size(load_image(image))[0]
    results = [image]
    lpips_distance_dict = {model: [] for model in LPIPS_MODELS}
    lpips_distance_dict["iteration"] = []
    for iteration in range(iterations):
        results.append(inpaint_image("", results[-1], create_square_mask(results[-1], square_size=mask_size)))
        distances = lpips_distance(results[0], results[-1])
        for model, distance in zip(LPIPS_MODELS, distances):
            lpips_distance_dict[model].append(distance)
        lpips_distance_dict["iteration"].append(iteration + 1)
    lpips_df = pd.DataFrame(lpips_distance_dict)
    lpips_df = lpips_df.melt(id_vars="iteration", var_name="model", value_name="lpips")
    lpips_df["iteration"] = lpips_df["iteration"].astype(str)
    return results, lpips_df
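

# Single inpainting step with stable-diffusion-2-inpainting. The pipeline is
# loaded lazily on first use and cached in a module-level global, so repeated
# iterations do not re-initialize the weights on every call.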
_PIPE = None  # lazily-loaded inpainting pipeline


def inpaint_image(prompt: str, image: Image.Image, mask_image: Image.Image) -> Image.Image:
    global _PIPE
    if _PIPE is None:
        _PIPE = StableDiffusionInpaintPipeline.from_pretrained(
            "stabilityai/stable-diffusion-2-inpainting",
            torch_dtype=torch.float16,
        ).to(DEVICE)
    return _PIPE(prompt=prompt, image=image, mask_image=mask_image).images[0]
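

# Assemble the Gradio UI: an upload/slider/radio column on the left, a gallery
# plus LPIPS line plot on the right, wired to execute_experiment on submit.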
def create_gradio_interface():
    with gr.Blocks(css=CUSTOM_CSS, theme=gr.themes.Default(primary_hue="red", secondary_hue="blue")) as demo:
        gr.Markdown(TITLE)
        gr.Markdown(AUTHORS)
        gr.HTML(BUTTONS)
        gr.Markdown(DESCRIPTION)
        with gr.Row():
            with gr.Column():
                files = gr.Image(
                    elem_id="image_upload",
                    type="pil",
                    height=500,
                    sources=["upload", "clipboard"],
                    label="Upload",
                )
                iterations = gr.Slider(MIN_ITERATIONS, MAX_ITERATIONS, value=DEFAULT_ITERATIONS, label="Iterations", step=1)
                mask_size = gr.Radio(list(MASK_SIZES.keys()), value=DEFAULT_MASK_SIZE, label="Mask Size")
                submit = gr.Button("Submit")
            with gr.Column():
                gallery = gr.Gallery(label="Generated Images")
                lineplot = gr.LinePlot(
                    label="LPIPS Distance",
                    x="iteration",
                    y="lpips",
                    color="model",
                    overlay_point=True,
                    width=500,
                    height=500,
                )
        submit.click(
            fn=execute_experiment,
            inputs=[files, iterations, mask_size],
            outputs=[gallery, lineplot],
        )
        gr.Examples(
            examples=[
                ["./examples/example_1.jpg"],
                ["./examples/example_2.jpg"],
                ["./examples/example_3.jpeg"],
                ["./examples/example_4.jpg"],
                ["./examples/example_5.jpg"],
                ["./examples/example_6.jpg"],
                ["./examples/example_7.jpg"],
                ["./examples/example_8.jpg"],
            ],
            inputs=[files],
            cache_examples=False,
        )
        gr.Markdown(ARTICLE)
    return demo


if __name__ == "__main__":
    demo = create_gradio_interface()
    demo.launch()