# Hosting-page header residue (not part of the program): angtrim's picture · Update app.py · 0a34a56 verified · raw · history blame · 8.26 kB
import os
import sys
import pdb
import random
import numpy as np
from PIL import Image
import base64
from io import BytesIO
import torch
from torchvision import transforms
import torchvision.transforms.functional as TF
import gradio as gr
from src.model import make_1step_sched
from src.pix2pix_turbo import Pix2Pix_Turbo
# Build the sketch-to-image model once at import time; run() reuses this
# single module-level instance for every request.
model = Pix2Pix_Turbo("sketch_to_image_stochastic")
# Choices offered by the "What do you want to design?" dropdown. The selected
# string (emoji included) is prepended verbatim to the generation prompt in
# run(). The first four emoji were stored as mis-decoded UTF-8 bytes; they are
# restored here so every entry uses a real emoji like the last three did.
ITEMS_NAMES = [
    "💡 Lamp",
    "👜 Bag",
    "🛋️ Sofa",
    "🪑 Chair",
    "🏎️ Car",
    "🏍️ Motorbike",
    "🏠 Building",
]
# Largest value of a signed 32-bit integer.
MAX_SEED = np.iinfo(np.int32).max
# Initial dropdown selection; derived from the list so the two cannot drift apart.
DEFAULT_ITEM_NAME = ITEMS_NAMES[0]
def pil_image_to_data_uri(img, format='PNG'):
    """Serialize an image into a base64 ``data:`` URI.

    Args:
        img: Object exposing ``save(fp, format=...)`` (a PIL image in this app).
        format: Image format name handed to ``img.save``; its lower-cased
            form is used as the MIME subtype in the URI.

    Returns:
        A string of the form ``data:image/<format>;base64,<payload>``.
    """
    stream = BytesIO()
    img.save(stream, format=format)
    encoded = base64.b64encode(stream.getvalue())
    payload = encoded.decode()
    mime_subtype = format.lower()
    return "data:image/" + mime_subtype + ";base64," + payload
def run(image, item_name):
    """Render the current sketch as an image of the selected item.

    Args:
        image: PIL image coming from the sketch canvas, or ``None`` when the
            canvas holds nothing.
        item_name: Selected dropdown entry; prepended verbatim to the prompt.

    Returns:
        A 3-tuple ``(output image, sketch-button update, output-button
        update)``. Each update sets the button's ``link`` to a PNG data URI
        so the page scripts can trigger a download from it.
    """
    print("sketch updated")
    if image is None:
        # Empty canvas: show a blank white image and point both links at it.
        ones = Image.new("L", (512, 512), 255)
        temp_uri = pil_image_to_data_uri(ones)
        return ones, gr.update(link=temp_uri), gr.update(link=temp_uri)

    prompt = item_name + " professional 3d model. octane render, highly detailed, volumetric, dramatic lighting"
    image = image.convert("RGB")
    # Threshold at 0.5 so the conditioning input is a strict two-level mask.
    image_t = TF.to_tensor(image) > 0.5
    with torch.no_grad():
        c_t = image_t.unsqueeze(0).cuda().float()
        # Re-seed on every call so the noise map is the same for every request.
        torch.manual_seed(42)
        H, W = c_t.shape[-2:]
        # Noise map: 4 channels at 1/8 of the input's spatial resolution.
        noise = torch.randn((1, 4, H // 8, W // 8), device=c_t.device)
        output_image = model(c_t, prompt, deterministic=False, r=0.4, noise_map=noise)
    # NOTE(review): *0.5+0.5 presumably maps a [-1, 1] model output to [0, 1]
    # for to_pil_image — confirm against Pix2Pix_Turbo.
    output_pil = TF.to_pil_image(output_image[0].cpu() * 0.5 + 0.5)
    # The downloadable sketch is the pixel-wise inverse of the canvas image.
    input_sketch_uri = pil_image_to_data_uri(Image.fromarray(255 - np.array(image)))
    output_image_uri = pil_image_to_data_uri(output_pil)
    return output_pil, gr.update(link=input_sketch_uri), gr.update(link=output_image_uri)
def update_canvas(use_line, use_eraser):
    """Build the canvas update matching the selected drawing tool.

    Args:
        use_line: Truthy when the pencil tool is selected.
        use_eraser: Truthy when the eraser tool is selected.

    Returns:
        A ``gr.update`` setting ``brush_radius``, ``brush_color`` and
        ``interactive=True`` on the sketch canvas.
    """
    # Start from the pencil settings so the names are always bound: the
    # original raised UnboundLocalError when both flags were false.
    _color = "#000000"
    brush_size = 4
    # The pencil wins when both flags are set, as in the original ordering.
    if use_eraser and not use_line:
        _color = "#ffffff"
        brush_size = 20
    return gr.update(brush_radius=brush_size, brush_color=_color, interactive=True)
def upload_sketch(file):
    """Load an uploaded file as a grayscale image for the sketch component.

    Args:
        file: Upload object whose ``name`` attribute is the path to open.

    Returns:
        A ``gr.update`` carrying the mode-"L" image, with ``source="upload"``
        and ``interactive=True``.
    """
    grayscale = Image.open(file.name).convert("L")
    return gr.update(interactive=True, source="upload", value=grayscale)
# JavaScript handed to demo.load(..., _js=scripts). It registers these helpers
# on globalThis:
#   - theSketchDownloadFunction: downloads the href of #download_sketch as
#     "sketch.png", then also calls theOutputDownloadFunction.
#   - theOutputDownloadFunction: downloads the href of #download_output as
#     "output.png".
#   - DELETE_SKETCH_FUNCTION: dispatches a click on a button located inside
#     #input_image via a CSS selector.
#     NOTE(review): that selector contains what look like generated class
#     names (svelte-...); confirm it still matches after a Gradio upgrade.
#   - togglePencil / toggleEraser: toggle the 'clicked' class on the custom
#     HTML buttons, dispatch a click on the #cb-line / #cb-eraser checkbox
#     inputs, and swap the gray/white backgrounds of the two button wrappers.
scripts = """
async () => {
globalThis.theSketchDownloadFunction = () => {
console.log("test")
var link = document.createElement("a");
dataUri = document.getElementById('download_sketch').href
link.setAttribute("href", dataUri)
link.setAttribute("download", "sketch.png")
document.body.appendChild(link); // Required for Firefox
link.click();
document.body.removeChild(link); // Clean up
// also call the output download function
theOutputDownloadFunction();
return false
}
globalThis.theOutputDownloadFunction = () => {
console.log("test output download function")
var link = document.createElement("a");
dataUri = document.getElementById('download_output').href
link.setAttribute("href", dataUri);
link.setAttribute("download", "output.png");
document.body.appendChild(link); // Required for Firefox
link.click();
document.body.removeChild(link); // Clean up
return false
}
globalThis.DELETE_SKETCH_FUNCTION = () => {
console.log("delete sketch function")
var button_del = document.querySelector('#input_image > div.image-container.svelte-p3y7hu > div.svelte-s6ybro > button:nth-child(2)');
// Create a new 'click' event
var event = new MouseEvent('click', {
'view': window,
'bubbles': true,
'cancelable': true
});
button_del.dispatchEvent(event);
}
globalThis.togglePencil = () => {
el_pencil = document.getElementById('my-toggle-pencil');
el_pencil.classList.toggle('clicked');
// simulate a click on the gradio button
btn_gradio = document.querySelector("#cb-line > label > input");
var event = new MouseEvent('click', {
'view': window,
'bubbles': true,
'cancelable': true
});
btn_gradio.dispatchEvent(event);
if (el_pencil.classList.contains('clicked')) {
document.getElementById('my-toggle-eraser').classList.remove('clicked');
document.getElementById('my-div-pencil').style.backgroundColor = "gray";
document.getElementById('my-div-eraser').style.backgroundColor = "white";
}
else {
document.getElementById('my-toggle-eraser').classList.add('clicked');
document.getElementById('my-div-pencil').style.backgroundColor = "white";
document.getElementById('my-div-eraser').style.backgroundColor = "gray";
}
}
globalThis.toggleEraser = () => {
element = document.getElementById('my-toggle-eraser');
element.classList.toggle('clicked');
// simulate a click on the gradio button
btn_gradio = document.querySelector("#cb-eraser > label > input");
var event = new MouseEvent('click', {
'view': window,
'bubbles': true,
'cancelable': true
});
btn_gradio.dispatchEvent(event);
if (element.classList.contains('clicked')) {
document.getElementById('my-toggle-pencil').classList.remove('clicked');
document.getElementById('my-div-pencil').style.backgroundColor = "white";
document.getElementById('my-div-eraser').style.backgroundColor = "gray";
}
else {
document.getElementById('my-toggle-pencil').classList.add('clicked');
document.getElementById('my-div-pencil').style.backgroundColor = "gray";
document.getElementById('my-div-eraser').style.backgroundColor = "white";
}
}
}
"""
# Assemble the Gradio page and wire its events: a sketch canvas with tool
# buttons and an item dropdown on the input side, the generated image on the
# output side. Any change to the sketch or the item re-runs run().
with gr.Blocks(css="style.css") as demo:
    # these are hidden buttons that are used to trigger the canvas changes
    line = gr.Checkbox(label="line", value=False, elem_id="cb-line")
    eraser = gr.Checkbox(label="eraser", value=False, elem_id="cb-eraser")
    with gr.Row(elem_id="main_row"):
        with gr.Column(elem_id="column_input"):
            image = gr.Image(
                source="canvas", tool="color-sketch", type="pil", image_mode="L",
                invert_colors=True, shape=(512, 512), brush_radius=4, height=440, width=440,
                brush_color="#000000", interactive=True, show_download_button=False, elem_id="input_image", show_label=False)
            download_sketch = gr.Button("Download sketch", scale=1, elem_id="download_sketch")
            # Custom tool buttons; their onclick handlers are the globals
            # registered by `scripts`. The markup is kept at column 0 so the
            # string's bytes are unchanged.
            gr.HTML("""
<div class="button-row">
<div id="my-div-pencil" class="pad2"> <button id="my-toggle-pencil" onclick="return togglePencil(this)"></button> </div>
<div id="my-div-eraser" class="pad2"> <button id="my-toggle-eraser" onclick="return toggleEraser(this)"></button> </div>
<div class="pad2"> <button id="my-button-clear" onclick="return DELETE_SKETCH_FUNCTION(this)"></button> </div>
</div>
""")
            with gr.Row():
                # The label ends with the artist emoji (person + ZWJ + palette),
                # written as escapes so the invisible joiner cannot get lost.
                # The source held the mis-decoded UTF-8 bytes of this sequence.
                item = gr.Dropdown(label="What do you want to design? \U0001F9D1\u200D\U0001F3A8 ", choices=ITEMS_NAMES, value=DEFAULT_ITEM_NAME, scale=1)
        with gr.Column(elem_id="column_output"):
            result = gr.Image(label="Result", height=440, width=440, elem_id="output_image", show_label=False, show_download_button=True)
            download_output = gr.Button("Download output", elem_id="download_output")
    # Changing one checkbox sets the other to its negation, then the canvas
    # brush is refreshed from the resulting pair of values.
    eraser.change(fn=lambda x: gr.update(value=not x), inputs=[eraser], outputs=[line]).then(update_canvas, [line, eraser], [image])
    line.change(fn=lambda x: gr.update(value=not x), inputs=[line], outputs=[eraser]).then(update_canvas, [line, eraser], [image])
    # Register the page-side JavaScript helpers on load.
    demo.load(None, None, None, _js=scripts)
    inputs = [image, item]
    outputs = [result, download_sketch, download_output]
    item.change(fn=run, inputs=inputs, outputs=outputs)
    image.change(fn=run, inputs=inputs, outputs=outputs)
# Script entry point: enable the request queue, then start the app with
# debug=True.
if __name__ == "__main__":
    demo.queue().launch(debug=True)