import gradio as gr
import numpy as np
import os
import json
import subprocess
from PIL import Image
from functools import partial
from datetime import datetime
from sam_inference import get_sam_predictor, sam_seg
from utils import blend_seg, blend_seg_pure
import cv2
import uuid
import torch
import trimesh
from huggingface_hub import snapshot_download
from gradio_model3dcolor import Model3DColor
from gradio_model3dnormal import Model3DNormal
# Download (or reuse the local cache of) the API bundle from the Hugging Face
# Hub. HF_TOKEN must be set in the environment; a missing variable raises
# KeyError here, at import time.
code_dir = snapshot_download("sudo-ai/MeshFormer-API", token=os.environ['HF_TOKEN'])

# api.json supplies the command templates used by preprocess() and mesh_gen().
# JSON is decoded explicitly as UTF-8 so parsing does not depend on the
# host's locale default encoding.
with open(f'{code_dir}/api.json', 'r', encoding="utf-8") as file:
    api_dict = json.load(file)

# Command templates; each is filled in with str.format() before being run.
SEG_CMD = api_dict["SEG_CMD"]
MESH_CMD = api_dict["MESH_CMD"]
# Markup prepended to every guide message built by message(); at present it
# is a single newline.
STYLE = """
"""

# Icon markup per message kind. Intended glyphs:
# info (info-circle-fill), cursor (hand-index-thumb), wait (hourglass-split),
# done (check-circle). All entries are currently empty strings.
ICONS = {
    "info": "",
    "cursor": "",
    "wait": "",
    "done": "",
}

# Alert style name associated with each message kind.
icons2alert = dict(
    info="primary",  # blue
    cursor="info",  # light blue
    wait="secondary",  # gray
    done="success",  # green
)
def message(text, icon_type="info"):
    """Return the guide-banner markup for ``text``.

    The result is ``STYLE``, the icon markup for ``icon_type`` and ``text``,
    each followed by a newline.

    Args:
        text: Message to show; formatted with normal f-string rules.
        icon_type: Key into ``ICONS`` ("info", "cursor", "wait" or "done").

    Raises:
        KeyError: If ``icon_type`` is not a key of ``ICONS``.
    """
    icon = ICONS[icon_type]
    return f"{STYLE}\n{icon}\n{text}\n"
def preprocess(tmp_dir, input_img, idx=None):
    """Segment the input image with the external SEG_CMD and return a preview.

    The image is written to ``{tmp_dir}/input.png``, ``SEG_CMD`` is run with
    ``tmp_dir`` substituted into it, and ``{tmp_dir}/seg.png`` is then read
    back (the command is presumably what writes it -- confirm against
    api.json).

    Args:
        tmp_dir: Working directory for this request.
        input_img: A PIL image, or -- when ``idx`` is given -- an indexable
            collection of dicts whose ``"name"`` entry is an image file path.
        idx: Optional index into ``input_img`` selecting the image to open.

    Returns:
        The segmented image resized to 320x320 with Lanczos resampling.

    Raises:
        subprocess.CalledProcessError: If the segmentation command exits
            with a non-zero status.
    """
    if idx is not None:
        print("image idx:", int(idx))
        input_img = Image.open(input_img[int(idx)]["name"])
    input_img.save(f"{tmp_dir}/input.png")

    # Same shell invocation os.system() performed, but check=True turns a
    # failed segmentation into an immediate error instead of continuing and
    # failing later (or not at all) when seg.png is read.
    subprocess.run(SEG_CMD.format(tmp_dir=tmp_dir), shell=True, check=True)

    # resize() returns a new, fully loaded image, so the file handle can be
    # released as soon as it has been produced.
    with Image.open(f"{tmp_dir}/seg.png") as processed_img:
        return processed_img.resize((320, 320), Image.Resampling.LANCZOS)
def ply_to_glb(ply_path):
    """Convert a PLY mesh to GLB by running ``ply2glb.py`` in a subprocess.

    Args:
        ply_path: Path of the ``.ply`` file to convert.

    Returns:
        ``ply_path`` with ``".ply"`` replaced by ``".glb"``. The converter is
        presumably expected to write its output there -- confirm against
        ply2glb.py.

    Raises:
        subprocess.CalledProcessError: If the converter exits with a
            non-zero status.
    """
    result = subprocess.run(
        ["python", "ply2glb.py", "--", ply_path],
        capture_output=True,
        text=True,
    )
    print("Output of blender script:")
    print(result.stdout)
    if result.returncode != 0:
        # stderr is captured above, so it has to be printed explicitly or the
        # reason for the failure is lost. Raising stops callers from handing
        # a missing or stale .glb path to the viewer.
        print(result.stderr)
        result.check_returncode()
    # NOTE(review): str.replace swaps every ".ply" occurrence in the path,
    # not only the extension; kept as-is to stay in step with the converter.
    return ply_path.replace(".ply", ".glb")
def mesh_gen(tmp_dir, simplify, num_inference_steps):
    """Generate the mesh with MESH_CMD and export colour and normal GLBs.

    After the external command has run, ``{tmp_dir}/mesh.ply`` is loaded, a
    copy whose vertex colours encode the vertex normals is written to
    ``{tmp_dir}/mesh_normal.ply``, and both PLY files are converted to GLB.

    Args:
        tmp_dir: Working directory for this request.
        simplify: Accepted for interface compatibility; not used here.
        num_inference_steps: Substituted into ``MESH_CMD``.

    Returns:
        Tuple ``(color_path, normal_path)`` of the two ``.glb`` paths.

    Raises:
        subprocess.CalledProcessError: If the mesh generation command exits
            with a non-zero status.
    """
    # Same shell invocation os.system() performed, but check=True aborts on
    # failure; otherwise a failed run would go on to load whatever mesh.ply
    # already happens to be in tmp_dir.
    subprocess.run(
        MESH_CMD.format(tmp_dir=tmp_dir, num_inference_steps=num_inference_steps),
        shell=True,
        check=True,
    )
    mesh = trimesh.load_mesh(f"{tmp_dir}/mesh.ply")

    # Map the negated normal components from [-1, 1] to [0, 1], then to 8-bit.
    colors = (-mesh.vertex_normals + 1) / 2.0
    colors = (colors * 255).astype(np.uint8)  # Convert to 8-bit color
    print(colors.shape)
    mesh.visual.vertex_colors = colors[..., [2, 1, 0]]  # RGB -> BGR
    mesh.export(f"{tmp_dir}/mesh_normal.ply", file_type="ply")

    color_path = ply_to_glb(f"{tmp_dir}/mesh.ply")
    normal_path = ply_to_glb(f"{tmp_dir}/mesh_normal.ply")
    return color_path, normal_path
def create_tmp_dir():
    """Create and return a fresh working directory under ``demo_exp/``.

    The directory name is the current local time formatted as
    ``YYYY-MM-DD_HH-MM-SS``, an underscore, and the first four characters of
    a random UUID. Parent directories are created as needed.

    Returns:
        The relative path of the created directory.
    """
    timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    suffix = str(uuid.uuid4())[:4]
    tmp_dir = f"demo_exp/{timestamp}_{suffix}"
    os.makedirs(tmp_dir, exist_ok=True)
    print("create tmp_exp_dir", tmp_dir)
    return tmp_dir
def vis_seg(checkbox):
    """Show or hide the manual-segmentation widgets.

    Args:
        checkbox: Truthy to show the widgets, falsy to hide them.

    Returns:
        A five-element list: the same image update twice, a radio update,
        an empty list, and a checkbox update.
    """
    show = bool(checkbox)
    print("Show manual seg windows" if show else "Clear manual seg")
    if show:
        # value=None is passed only when showing, as in the hidden case the
        # update carries no value at all.
        image_update = gr.Image(value=None, visible=True)
    else:
        image_update = gr.Image(visible=False)
    return [
        image_update,
        image_update,
        gr.Radio(visible=show),
        [],
        gr.Checkbox(visible=show),
    ]
def calc_feat(checkbox, predictor, input_image, idx=None):
    """Prepare a square image for manual segmentation and load it into SAM.

    The input is shrunk in place to fit within 512x512, pasted onto a white
    square canvas whose side is the longer image edge, and handed to
    ``predictor.set_image`` as an RGB array.

    Args:
        checkbox: Manual-segmentation toggle; must be truthy to proceed.
        predictor: Object exposing ``set_image``.
        input_image: A PIL image, or -- when ``idx`` is given -- an indexable
            collection of dicts whose ``"name"`` entry is an image file path.
        idx: Optional index into ``input_image`` selecting the image to open.

    Returns:
        A visible ``gr.Image`` update holding the square canvas.

    Raises:
        ValueError: If ``checkbox`` is falsy ("Quit manual seg").
    """
    if not checkbox:
        print("Quit manual seg")
        raise ValueError("Quit manual seg")

    if idx is not None:
        print("image idx:", int(idx))
        input_image = Image.open(input_image[int(idx)]["name"])

    input_image.thumbnail([512, 512], Image.Resampling.LANCZOS)
    w, h = input_image.size
    print("image size:", w, h)

    side_len = max(w, h)
    seg_in = Image.new(input_image.mode, (side_len, side_len), (255, 255, 255))
    # Offsets that centre the image along whichever axis is the shorter one.
    left = max(0, (h - w) // 2)
    top = max(0, (w - h) // 2)
    seg_in.paste(input_image, (left, top))

    print("Calculating image SAM feature...")
    predictor.set_image(np.array(seg_in.convert("RGB")))
    torch.cuda.empty_cache()
    return gr.Image(value=seg_in, visible=True)
def manual_seg(
    predictor,
    seg_in,
    selected_points,
    fg_bg_radio,
    tmp_dir,
    seg_mask_opt,
    evt: gr.SelectData,
):
    """Re-run point-prompted segmentation after a click and save the cut-out.

    The clicked point is appended to ``selected_points`` (mutated in place),
    ``sam_seg`` is run with every point collected so far, and the segmented
    object is cropped to its bounding box, centred on a transparent square
    canvas and written to ``{tmp_dir}/seg.png``.

    Args:
        predictor: Predictor object passed through to ``sam_seg``.
        seg_in: PIL image the user clicked on.
        selected_points: List of ``{"coord", "add_del"}`` dicts accumulated
            across clicks.
        fg_bg_radio: Radio value; ``"+ (add mask)"`` marks the point as one
            that adds to the mask.
        tmp_dir: Working directory for this request.
        seg_mask_opt: If truthy, the preview is rendered with
            ``blend_seg_pure``, otherwise with ``blend_seg``.
        evt: Gradio select event; ``evt.index`` is stored as the point
            coordinate.

    Returns:
        Tuple of the blended preview resized to 380x380 and the re-centred
        RGBA cut-out resized to 320x320.

    Raises:
        ValueError: If the resulting mask is empty, so there is nothing to
            crop.
    """
    print("Start segmentation")
    selected_points.append(
        {"coord": evt.index, "add_del": fg_bg_radio == "+ (add mask)"}
    )
    input_points = np.array([point["coord"] for point in selected_points])
    input_labels = np.array([point["add_del"] for point in selected_points])

    seg_rgb = seg_in.convert("RGB")
    out_image = sam_seg(predictor, np.array(seg_rgb), input_points, input_labels)

    blend = blend_seg_pure if seg_mask_opt else blend_seg
    segmentation = blend(seg_rgb, out_image, input_points, input_labels)

    # recenter and rescale
    image_arr = np.array(out_image)
    # The last channel of out_image is used as the mask (assumed to be alpha
    # of an RGBA image -- the canvas below is 4-channel); any non-zero value
    # counts as foreground.
    _, mask = cv2.threshold(
        np.array(out_image.split()[-1]), 0, 255, cv2.THRESH_BINARY
    )
    x, y, w, h = cv2.boundingRect(mask)
    if w == 0 or h == 0:
        # An all-zero mask yields a zero-sized box, which would produce a
        # 0x0 canvas and an unusable seg.png. Fail with a clear message and
        # leave any previously saved seg.png untouched.
        torch.cuda.empty_cache()
        raise ValueError("Segmentation mask is empty; please add a foreground point.")

    # The longer box edge occupies `ratio` of the square canvas side.
    ratio = 0.75
    side_len = int(max(w, h) / ratio)
    padded_image = np.zeros((side_len, side_len, 4), dtype=np.uint8)
    center = side_len // 2
    top = center - h // 2
    left = center - w // 2
    padded_image[top : top + h, left : left + w] = image_arr[y : y + h, x : x + w]

    rgba = Image.fromarray(padded_image)
    rgba.save(f"{tmp_dir}/seg.png")
    torch.cuda.empty_cache()
    return segmentation.resize((380, 380), Image.Resampling.LANCZOS), rgba.resize(
        (320, 320), Image.Resampling.LANCZOS
    )
# Soft theme with a blue primary hue and neutral fills for secondary buttons.
custom_theme = gr.themes.Soft(primary_hue="blue").set(
    button_secondary_background_fill="*neutral_100",
    button_secondary_background_fill_hover="*neutral_200",
)

# UI definition: page header, input/option column, output column, then the
# event wiring. Components are laid out in the order they are created.
with gr.Blocks(title="MeshFormer Demo", css="style.css", theme=custom_theme) as demo:
    with gr.Row():
        gr.Markdown(
            "# MeshFormer: High-Quality Mesh Generation with 3D-Guided Reconstruction Model"
        )
    with gr.Row():
        gr.Markdown(
            "[Project Page](https://meshformer3d.github.io/) | [arXiv](https://arxiv.org/abs/TBD)"
        )
    with gr.Row():
        # The literal below is a single newline; its closing quotes stay at
        # column 0 so the string content is unchanged.
        gr.Markdown(
            """
"""
        )
    with gr.Row():
        # Status banner; rewritten by the update_guide steps in the chains below.
        guide_text_i2m = gr.HTML(message("Please input an image!"), visible=True)

    # Per-session working directory. tmp_dir_img is replaced by create_tmp_dir()
    # in the chains below; tmp_dir_txt and tmp_dir_3t3 are not referenced again
    # in this section.
    tmp_dir_img = gr.State("./demo_exp/placeholder")
    tmp_dir_txt = gr.State("./demo_exp/placeholder")
    tmp_dir_3t3 = gr.State("./demo_exp/placeholder")

    # Every file in ./demo_examples (next to this script), sorted by name, is
    # offered as a clickable example.
    example_folder = os.path.join(os.path.dirname(__file__), "demo_examples")
    example_fns = os.listdir(example_folder)
    example_fns.sort()
    img_examples = [
        os.path.join(example_folder, x) for x in example_fns
    ]  # if x.endswith('.png') or x.endswith('.')

    with gr.Row(variant="panel"):
        with gr.Row():
            # Left column: input image, examples, options, manual segmentation.
            with gr.Column(scale=8):
                input_image = gr.Image(
                    type="pil",
                    image_mode="RGBA",
                    height=320,
                    label="Input Image",
                    interactive=True,
                )
                gr.Examples(
                    examples=img_examples,
                    inputs=[input_image],
                    outputs=[input_image],
                    cache_examples=False,
                    label="Image Examples (Click one of the images below to start)",
                    examples_per_page=27,
                )
                with gr.Accordion("Options", open=False):
                    # Hidden checkbox; its value is still passed to mesh_gen.
                    img_simplify = gr.Checkbox(
                        False, label="simplify the generated mesh", visible=False
                    )
                    n_steps_img = gr.Slider(
                        value=28,
                        minimum=15,
                        maximum=100,
                        step=1,
                        label="number of inference steps",
                    )
                # manual segmentation
                # The widgets below start hidden and are toggled by vis_seg().
                checkbox_manual_seg = gr.Checkbox(False, label="manual segmentation")
                with gr.Row():
                    with gr.Column(scale=1):
                        seg_in = gr.Image(
                            type="pil",
                            image_mode="RGBA",
                            label="Click to segment",
                            visible=False,
                            show_download_button=False,
                            height=380,
                        )
                    with gr.Column(scale=1):
                        seg_out = gr.Image(
                            type="pil",
                            image_mode="RGBA",
                            label="Segmentation",
                            interactive=False,
                            visible=False,
                            show_download_button=False,
                            height=380,
                            elem_id="disp_image",
                        )
                fg_bg_radio = gr.Radio(
                    ["+ (add mask)", "- (remove area)"],
                    value="+ (add mask)",
                    info="Select foreground (+) or background (-) point",
                    label="Point label",
                    visible=False,
                    interactive=True,
                )
                seg_mask_opt = gr.Checkbox(
                    True,
                    label="show foreground mask in manual segmentation",
                    visible=False,
                )
                # run
                # Starts disabled; enabled at the end of the input_image.change chain.
                img_run_btn = gr.Button(
                    "Generate", variant="primary", interactive=False
                )
            # Right column: segmented preview and the two mesh viewers.
            with gr.Column(scale=6):
                processed_image = gr.Image(
                    type="pil",
                    label="Processed Image",
                    interactive=False,
                    height=320,
                    image_mode="RGBA",
                    elem_id="disp_image",
                )
                # with gr.Row():
                # mesh_output = gr.Model3D(label="Generated Mesh", elem_id="model-3d-out")
                # NOTE(review): Model3DNormal is imported at the top of the
                # file, yet the normal viewer is built with Model3DColor --
                # confirm this is intended.
                mesh_output_normal = Model3DColor(
                    label="Generated Mesh (normal)",
                    elem_id="mesh-normal-out",
                    height=400,
                )
                mesh_output = Model3DColor(
                    label="Generated Mesh (color)",
                    elem_id="mesh-out",
                    height=400,
                )

    # State: the SAM predictor is built once, when this script builds the UI.
    predictor = gr.State(value=get_sam_predictor())
    # Points clicked so far for manual segmentation; reset to [] by vis_seg().
    selected_points = gr.State(value=[])
    # Not referenced again in this section.
    selected_points_t2i = gr.State(value=[])

    # Small update helpers used as event callbacks.
    disable_checkbox = lambda: gr.Checkbox(value=False)
    disable_button = lambda: gr.Button(interactive=False)
    enable_button = lambda: gr.Button(interactive=True)
    update_guide = lambda GUIDE_TEXT, icon_type="info": gr.HTML(
        value=message(GUIDE_TEXT, icon_type)
    )
    # Not referenced again in this section.
    update_md = lambda GUIDE_TEXT: gr.Markdown(value=GUIDE_TEXT)

    def is_img_clear(input_image):
        """Raise ValueError when the input image is empty.

        Used as a step in the input_image.change chain: the later steps are
        attached with .success(), so raising here keeps them from running
        once the image has been cleared.
        """
        if not input_image:
            raise ValueError("Input image cleared.")

    # Toggling manual segmentation: show/hide the widgets and reset the points,
    # then compute SAM features (calc_feat raises when the box was unticked,
    # which ends the chain), then allocate a fresh working directory.
    checkbox_manual_seg.change(
        vis_seg,
        inputs=[checkbox_manual_seg],
        outputs=[seg_in, seg_out, fg_bg_radio, selected_points, seg_mask_opt],
        queue=False,
    ).success(
        calc_feat,
        inputs=[checkbox_manual_seg, predictor, input_image],
        outputs=[seg_in],
    ).success(
        fn=create_tmp_dir, outputs=[tmp_dir_img], queue=False
    )

    # Each click on seg_in re-runs segmentation with all points collected so far.
    seg_in.select(
        manual_seg,
        [predictor, seg_in, selected_points, fg_bg_radio, tmp_dir_img, seg_mask_opt],
        [seg_out, processed_image],
    )

    # New input image: disable Generate, untick manual segmentation, stop if
    # the image was cleared, allocate a working directory, run the automatic
    # segmentation, then re-enable Generate.
    input_image.change(disable_button, outputs=img_run_btn, queue=False).success(
        disable_checkbox, outputs=checkbox_manual_seg, queue=False
    ).success(fn=is_img_clear, inputs=input_image, queue=False).success(
        fn=create_tmp_dir, outputs=tmp_dir_img, queue=False
    ).success(
        fn=partial(update_guide, "Preprocessing the image!", "wait"),
        outputs=[guide_text_i2m],
        queue=False,
    ).success(
        fn=preprocess,
        inputs=[tmp_dir_img, input_image],
        outputs=[processed_image],
        queue=True,
    ).success(
        fn=partial(
            update_guide,
            "Click Generate to generate mesh! If the input image was not segmented accurately, please adjust it using manual segmentation.",
            "cursor",
        ),
        outputs=[guide_text_i2m],
        queue=False,
    ).success(
        enable_button, outputs=img_run_btn, queue=False
    )

    # Generate: mesh_gen returns (color_path, normal_path), matching the
    # order of the outputs list below.
    img_run_btn.click(
        fn=partial(update_guide, "Generating the mesh!", "wait"),
        outputs=[guide_text_i2m],
        queue=False,
    ).success(
        fn=mesh_gen,
        inputs=[tmp_dir_img, img_simplify, n_steps_img],
        outputs=[mesh_output, mesh_output_normal],
        queue=True,
    ).success(
        fn=partial(
            update_guide,
            "Successfully generated the mesh. (It might take a few seconds to load the mesh)",
            "done",
        ),
        outputs=[guide_text_i2m],
        queue=False,
    )

# Enable the request queue and start the server on all network interfaces.
demo.queue().launch(
    debug=True, share=False, inline=False, show_api=False, server_name="0.0.0.0"
)