import argparse
import os
import shutil
from zipfile import ZipFile

import gradio as gr
import numpy as np
import requests
import torch
from omegaconf import OmegaConf
from PIL import Image

import constants
import utils
from ldm.util import instantiate_from_config
def download_model(url):
    """Download the zipped model weights into models/ and extract them."""
    os.makedirs("models", exist_ok=True)
    local_filename = url.split("/")[-1]
    zip_path = os.path.join("models", local_filename)
    with requests.get(url, stream=True) as r:
        r.raise_for_status()  # fail early on a bad download instead of unzipping garbage
        with open(zip_path, "wb") as file:
            shutil.copyfileobj(r.raw, file)
    with ZipFile(zip_path, "r") as zObject:
        zObject.extractall(path="models/")
    os.remove(zip_path)
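# A minimal sketch of calling the helper directly, reusing the same
# Hugging Face URL that __main__ passes in below:
#
#   download_model("https://huggingface.co/abyildirim/inst-inpaint-models/resolve/main/gqa_inpaint.zip")
#   # -> the archive is saved to models/, extracted there, then deleted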
MODEL = None

def inference(image: np.ndarray, instruction: str, center_crop: bool):
    if not instruction.lower().startswith("remove the"):
        raise gr.Error("Instruction should start with 'Remove the'!")
    image = Image.fromarray(image)
    cropped_image, image = utils.preprocess_image(image, center_crop=center_crop)
    # The step count and seed are read from the environment, so they can be
    # tuned on the Space without changing the code.
    output_image = MODEL.inpaint(
        image, instruction,
        num_steps=int(os.environ["NUM_STEPS"]),
        device="cuda", return_pil=True,
        seed=int(os.environ["SEED"]),
    )
    return cropped_image, output_image
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--config",
        type=str,
        default="configs/latent-diffusion/gqa-inpaint-ldm-vq-f8-256x256.yaml",
        help="Path of the model config file",
    )
    parser.add_argument(
        "--checkpoint",
        type=str,
        default="models/gqa_inpaint/ldm/model.ckpt",
        help="Path of the model checkpoint file",
    )
    args = parser.parse_args()

    print("## Downloading the model file")
    download_model("https://huggingface.co/abyildirim/inst-inpaint-models/resolve/main/gqa_inpaint.zip")
    print("## Download is completed")

    print("## Running the demo")
    # Build the model from the config, load the checkpoint weights on CPU,
    # then move the model to the GPU for inference.
    parsed_config = OmegaConf.load(args.config)
    MODEL = instantiate_from_config(parsed_config["model"])
    model_state_dict = torch.load(args.checkpoint, map_location="cpu")["state_dict"]
    MODEL.load_state_dict(model_state_dict)
    MODEL.eval()
    MODEL.to("cuda")

    sample_image, sample_instruction, _ = constants.EXAMPLES[3]
    gr.Interface(
        fn=inference,
        inputs=[
            gr.Image(type="numpy", value=sample_image, label="Source Image").style(
                height=256
            ),
            gr.Textbox(
                label="Instruction",
                lines=1,
                value=sample_instruction,
            ),
            gr.Checkbox(value=True, label="Center Crop", interactive=False),
        ],
        outputs=[
            gr.Image(type="pil", label="Cropped Image").style(height=256),
            gr.Image(type="pil", label="Output Image").style(height=256),
        ],
        allow_flagging="never",
        examples=constants.EXAMPLES,
        cache_examples=True,
        title=constants.TITLE,
        description=constants.DESCRIPTION,
    ).launch()
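
# Example invocation (assuming this file is saved as app.py; NUM_STEPS and SEED
# must be set, since inference() reads them from the environment — the values
# below are illustrative, not from the original code):
#
#   NUM_STEPS=50 SEED=0 python app.py \
#       --config configs/latent-diffusion/gqa-inpaint-ldm-vq-f8-256x256.yaml \
#       --checkpoint models/gqa_inpaint/ldm/model.ckpt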