abyildirim committed · Commit 94a0cd2 · Parent: 005f2dd

initial commit
Browse files
- .gitignore +1 -0
- app.py +58 -0
- constants.py +25 -0
- examples/birds.png +0 -0
- examples/bus-tree.jpg +0 -0
- examples/cat-car.jpg +0 -0
- examples/clock.png +0 -0
- examples/cups.webp +0 -0
- examples/kite-boy.png +0 -0
- examples/men.png +0 -0
- examples/tree.png +0 -0
- examples/woman-fantasy.jpg +0 -0
- examples/woman.png +0 -0
- requirements.txt +13 -0
- utils.py +78 -0
.gitignore
ADDED
@@ -0,0 +1 @@
__pycache__
app.py
ADDED
@@ -0,0 +1,58 @@
import gradio as gr
import numpy as np
import torch
from PIL import Image

import constants
import utils

PREDICTOR = None


def inference(image: np.ndarray, text: str, center_crop: bool):
    num_steps = 10
    if not text.lower().startswith("remove the"):
        raise gr.Error("Instruction should start with 'Remove the'!")

    image = Image.fromarray(image)
    cropped_image, image = utils.preprocess_image(image, center_crop=center_crop)

    utils.seed_everything()
    prediction = PREDICTOR.predict(image, text, num_steps)

    print("Num steps:", num_steps)

    return cropped_image, prediction


if __name__ == "__main__":
    utils.setup_environment()

    if not PREDICTOR:
        PREDICTOR = utils.get_predictor()

    # Each EXAMPLES row is (image_path, instruction, center_crop); the third
    # value is unpacked here but unused, since the checkbox is fixed to True.
    sample_image, sample_instruction, sample_step = constants.EXAMPLES[3]

    gr.Interface(
        fn=inference,
        inputs=[
            gr.Image(type="numpy", value=sample_image, label="Source Image").style(
                height=256
            ),
            gr.Textbox(
                label="Instruction",
                lines=1,
                value=sample_instruction,
            ),
            gr.Checkbox(value=True, label="Center Crop", interactive=False),
        ],
        outputs=[
            gr.Image(type="pil", label="Cropped Image").style(height=256),
            gr.Image(type="pil", label="Output Image").style(height=256),
        ],
        allow_flagging="never",
        examples=constants.EXAMPLES,
        cache_examples=True,
        title=constants.TITLE,
        description=constants.DESCRIPTION,
    ).launch()
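For context, `inference` is also what example caching calls at startup. Below is a minimal sketch of driving the same pipeline without the Gradio UI — a sketch under assumptions: `GDRIVE_ID` is set so the artifacts can be fetched, and the unpickled predictor returns a PIL image from `predict` (which is what the `type="pil"` output component implies). The script itself is hypothetical and not part of this commit.

# sketch.py - hypothetical driver script, not part of this commit.
from PIL import Image

import utils

utils.setup_environment()           # downloads and extracts model.pkl (needs GDRIVE_ID)
predictor = utils.get_predictor()   # unpickles the predictor with dill

source = Image.open("examples/cat-car.jpg").convert("RGB")
cropped, tensor = utils.preprocess_image(source, center_crop=True)

utils.seed_everything()
result = predictor.predict(tensor, "Remove the car", 10)  # assumed to return a PIL image
result.save("output.png")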
constants.py
ADDED
@@ -0,0 +1,25 @@
TITLE = "Inst-Inpaint: Instructing to Remove Objects with Diffusion Models"

DESCRIPTION = """
<p style='text-align: center'>
<a href='http://instinpaint.abyildirim.com' target='_blank'>Project Page</a> |
<a href='https://arxiv.org/abs/2304.03246' target='_blank'>Paper</a> |
<a href='https://github.com/abyildirim/inst-inpaint' target='_blank'>GitHub Repo</a>
</p>
<p style='text-align: center'>
This demo demonstrates Inst-Inpaint's abilities for instruction-based image inpainting.
</p>
"""

EXAMPLES = [
    ["examples/kite-boy.png", "Remove the colorful kite", True],
    ["examples/cat-car.jpg", "Remove the car", True],
    ["examples/bus-tree.jpg", "Remove the bus", True],
    ["examples/cups.webp", "Remove the cup at the left", True],
    ["examples/woman-fantasy.jpg", "Remove the woman", True],
    ["examples/clock.png", "Remove the round clock at the center", True],
    ["examples/woman.png", "Remove the woman at the left", True],
    ["examples/men.png", "Remove the man at the right", True],
    ["examples/tree.png", "Remove the tree", True],
    ["examples/birds.png", "Remove the bird at the right of the bird", True]
]
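Each row maps positionally onto the Interface inputs in app.py: source image, instruction, center-crop flag. Because app.py passes cache_examples=True, every row is run through `inference` once at startup, so each instruction must survive the "Remove the" prefix guard. A quick local sanity check one could run (hypothetical snippet, not part of the commit):

# Verify every cached example would pass the prefix guard in app.py.
import constants

for image_path, instruction, center_crop in constants.EXAMPLES:
    assert instruction.lower().startswith("remove the"), instruction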
examples/birds.png
ADDED
examples/bus-tree.jpg
ADDED
examples/cat-car.jpg
ADDED
examples/clock.png
ADDED
examples/cups.webp
ADDED
examples/kite-boy.png
ADDED
examples/men.png
ADDED
examples/tree.png
ADDED
examples/woman-fantasy.jpg
ADDED
examples/woman.png
ADDED
requirements.txt
ADDED
@@ -0,0 +1,13 @@
-f https://download.pytorch.org/whl/torch_stable.html
git+https://github.com/openai/CLIP.git
torch==1.13.1+cpu
torchvision==0.14.1+cpu
pytorch-lightning==1.6.5
taming-transformers-rom1504==0.0.6
einops==0.6.0
kornia==0.6.11
transformers==4.27.4
dill==0.3.6
gradio==3.24.1
gdown==4.7.1
torchmetrics==0.11.4
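The leading `-f` line adds PyTorch's wheel index as an extra find-links source, which is what lets pip resolve the CPU-only builds `torch==1.13.1+cpu` and `torchvision==0.14.1+cpu`; CLIP is pulled straight from GitHub since OpenAI does not publish it on PyPI. A plain `pip install -r requirements.txt` installs everything in one pass.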
utils.py
ADDED
@@ -0,0 +1,78 @@
import logging
import os
import random
import tarfile
from typing import Tuple

import dill
import gdown
import numpy as np
import torch
from PIL import Image
from torchvision.transforms import ToTensor

logger = logging.getLogger(__file__)

to_tensor = ToTensor()


def preprocess_image(
    image: Image.Image, resize_shape: Tuple[int, int] = (256, 256), center_crop=True
):
    processed_image = image

    if center_crop:
        width, height = image.size
        crop_size = min(width, height)

        # Crop a centered square whose side is the shorter image dimension.
        left = (width - crop_size) // 2
        top = (height - crop_size) // 2
        right = (width + crop_size) // 2
        bottom = (height + crop_size) // 2

        processed_image = image.crop((left, top, right, bottom))

    processed_image = processed_image.resize(resize_shape)

    # Add a batch dimension and rescale pixel values from [0, 1] to [-1, 1].
    image = to_tensor(processed_image)
    image = image.unsqueeze(0) * 2 - 1

    return processed_image, image


def download_artifacts(output_path: str):
    # Status messages use ERROR level so they are visible without any logging setup.
    logger.error("Downloading the model artifacts...")
    if not os.path.exists(output_path):
        gdown.download(id=os.environ["GDRIVE_ID"], output=output_path, quiet=True)


def extract_artifacts(path: str):
    logger.error("Extracting the model artifacts...")
    if not os.path.exists("model.pkl"):
        with tarfile.open(path) as tar:
            tar.extractall()


def setup_environment():
    os.environ["PYTHONPATH"] = os.getcwd()

    artifacts_path = "artifacts.tar.gz"

    download_artifacts(output_path=artifacts_path)

    extract_artifacts(path=artifacts_path)


def get_predictor():
    logger.error("Loading the predictor...")
    with open("model.pkl", "rb") as fp:
        return dill.load(fp)


def seed_everything(seed: int = 0):
    random.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
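To make the preprocess_image contract concrete: the first return value is the cropped-and-resized PIL preview shown in the UI, the second is the model-ready tensor. A self-contained sketch (assumes only Pillow and torch are installed, no model artifacts needed):

from PIL import Image

import utils

# A solid gray 640x480 test image; center crop keeps the middle 480x480 square.
img = Image.new("RGB", (640, 480), color=(128, 128, 128))
preview, tensor = utils.preprocess_image(img)

assert preview.size == (256, 256)                 # PIL preview, resized
assert tensor.shape == (1, 3, 256, 256)           # batched CHW tensor
assert tensor.min() >= -1 and tensor.max() <= 1   # rescaled to [-1, 1]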