"""Gradio demo for MeshFormer: single image -> segmented image -> 3D mesh.

The script downloads the private API description from the Hugging Face Hub,
builds the UI (automatic + manual SAM segmentation, mesh generation) and
launches the server when executed.
"""
import json
import os
import subprocess
import uuid
from datetime import datetime
from functools import partial

import cv2
import gradio as gr
import numpy as np
import torch
import trimesh
from gradio_model3dcolor import Model3DColor
from huggingface_hub import snapshot_download
from PIL import Image

from sam_inference import get_sam_predictor, sam_seg
from utils import blend_seg, blend_seg_pure

# from gradio_model3dnormal import Model3DNormal

# The shell command templates are read from api.json in the downloaded repo.
code_dir = snapshot_download("sudo-ai/MeshFormer-API", token=os.environ['HF_TOKEN'])

with open(f'{code_dir}/api.json', 'r') as file:
    api_dict = json.load(file)
    SEG_CMD = api_dict["SEG_CMD"]
    MESH_CMD = api_dict["MESH_CMD"]

STYLE = """ """

# info (info-circle-fill), cursor (hand-index-thumb), wait (hourglass-split), done (check-circle)
ICONS = {
    "info": """ """,
    "cursor": """ """,
    "wait": """ """,
    "done": """ """,
}

icons2alert = {
    "info": "primary",  # blue
    "cursor": "info",  # light blue
    "wait": "secondary",  # gray
    "done": "success",  # green
}


def message(text, icon_type="info"):
    """Return the HTML snippet shown in the guide banner.

    Args:
        text: Message to display (inserted verbatim, so it may contain markup).
        icon_type: Key of ``ICONS`` / ``icons2alert`` selecting icon and colour.

    Returns:
        An HTML string made of ``STYLE``, the icon and the message text.

    Raises:
        KeyError: If ``icon_type`` is not a known icon key.
    """
    # Previously the template dropped both ``text`` and ``icon_type`` and the
    # banner was always empty.
    # NOTE(review): the class names follow the Bootstrap alert naming implied
    # by ``icons2alert`` -- confirm against style.css / STYLE.
    return (
        f"""{STYLE}<div class="alert alert-{icons2alert[icon_type]} """
        f"""d-flex align-items-center" role="alert">"""
        f"""{ICONS[icon_type]}<div>{text}</div></div>"""
    )


def _run_shell(command):
    """Run a shell command line and raise if it exits with a non-zero status.

    The command templates are complete shell strings taken from api.json, so
    they have to go through the shell; only internally generated values are
    formatted into them.
    """
    subprocess.run(command, shell=True, check=True)


def preprocess(tmp_dir, input_img, idx=None):
    """Segment the input image with the external SEG_CMD tool.

    Args:
        tmp_dir: Working directory; ``input.png`` is written and ``seg.png``
            is read there.
        input_img: A PIL image, or (when ``idx`` is given) a sequence of dicts
            whose ``"name"`` entry is an image path.
        idx: Optional index into ``input_img``.

    Returns:
        The segmented image resized to 320x320.

    Raises:
        subprocess.CalledProcessError: If the segmentation command fails.
    """
    if idx is not None:
        print("image idx:", int(idx))
        input_img = Image.open(input_img[int(idx)]["name"])
    input_img.save(f"{tmp_dir}/input.png")
    # print(SEG_CMD.format(tmp_dir=tmp_dir))
    _run_shell(SEG_CMD.format(tmp_dir=tmp_dir))
    processed_img = Image.open(f"{tmp_dir}/seg.png")
    return processed_img.resize((320, 320), Image.Resampling.LANCZOS)


def ply_to_glb(ply_path):
    """Convert a ``.ply`` mesh to ``.glb`` by running ``ply2glb.py``.

    Args:
        ply_path: Path of the input ``.ply`` file.

    Returns:
        The path of the ``.glb`` file (same location, extension swapped).

    Raises:
        subprocess.CalledProcessError: If the converter exits non-zero.
    """
    result = subprocess.run(
        ["python", "ply2glb.py", "--", ply_path],
        capture_output=True,
        text=True,
    )
    print("Output of blender script:")
    print(result.stdout)
    if result.stderr:
        print(result.stderr)
    # Fail loudly instead of handing a missing/stale .glb to the viewer.
    result.check_returncode()
    # Swap only the extension; str.replace would also rewrite ".ply"
    # occurring elsewhere in the path.
    return os.path.splitext(ply_path)[0] + ".glb"


def mesh_gen(tmp_dir, simplify, num_inference_steps):
    """Generate the mesh with MESH_CMD and export colour and normal GLBs.

    Args:
        tmp_dir: Working directory containing the segmented input; the
            command is expected to write ``mesh.ply`` there.
        simplify: Currently unused; kept for the UI callback signature.
        num_inference_steps: Value formatted into MESH_CMD.

    Returns:
        ``(color_path, normal_path)``: paths to the two ``.glb`` files.

    Raises:
        subprocess.CalledProcessError: If mesh generation or conversion fails.
    """
    # print(MESH_CMD.format(tmp_dir=tmp_dir, num_inference_steps=num_inference_steps))
    # check=True matters here: tmp_dir is reused between clicks, so a failed
    # run would otherwise silently reload the previous mesh.ply.
    _run_shell(
        MESH_CMD.format(tmp_dir=tmp_dir, num_inference_steps=num_inference_steps)
    )
    mesh = trimesh.load_mesh(f"{tmp_dir}/mesh.ply")

    # Map the (negated) vertex normals from [-1, 1] to 8-bit colours.
    vertex_normals = mesh.vertex_normals
    colors = (-vertex_normals + 1) / 2.0
    colors = (colors * 255).astype(np.uint8)  # Convert to 8-bit color
    print(colors.shape)
    mesh.visual.vertex_colors = colors[..., [0, 2, 1]]  # RGB -> RBG
    mesh.export(f"{tmp_dir}/mesh_normal.ply", file_type="ply")

    color_path = ply_to_glb(f"{tmp_dir}/mesh.ply")
    normal_path = ply_to_glb(f"{tmp_dir}/mesh_normal.ply")
    return color_path, normal_path


def create_tmp_dir():
    """Create and return a fresh ``demo_exp/<timestamp>_<4 hex chars>`` directory."""
    tmp_dir = (
        "demo_exp/"
        + datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
        + "_"
        + str(uuid.uuid4())[:4]
    )
    os.makedirs(tmp_dir, exist_ok=True)
    print("create tmp_exp_dir", tmp_dir)
    return tmp_dir


def vis_seg(checkbox):
    """Show or hide the manual segmentation widgets.

    Returns updates for ``[seg_in, seg_out, fg_bg_radio, selected_points,
    seg_mask_opt]``; the list of selected points is always reset.
    """
    if checkbox:
        print("Show manual seg windows")
        return (
            [gr.Image(value=None, visible=True)] * 2
            + [gr.Radio(visible=True)]
            + [[], gr.Checkbox(visible=True)]
        )
    print("Clear manual seg")
    return (
        [gr.Image(visible=False)] * 2
        + [gr.Radio(visible=False)]
        + [[], gr.Checkbox(visible=False)]
    )


def calc_feat(checkbox, predictor, input_image, idx=None):
    """Pad the input to a white square and compute its SAM image embedding.

    Args:
        checkbox: State of the "manual segmentation" checkbox.
        predictor: SAM predictor; ``set_image`` is called on it.
        input_image: A PIL image, or (when ``idx`` is given) a sequence of
            dicts whose ``"name"`` entry is an image path.
        idx: Optional index into ``input_image``.

    Returns:
        An update for ``seg_in`` showing the padded square image.

    Raises:
        ValueError: If the checkbox is off; this deliberately stops the
            ``.success`` chain.
        gr.Error: If no input image is available.
    """
    if not checkbox:
        print("Quit manual seg")
        raise ValueError("Quit manual seg")

    if idx is not None:
        print("image idx:", int(idx))
        input_image = Image.open(input_image[int(idx)]["name"])
    if input_image is None:
        raise gr.Error("Please input an image before enabling manual segmentation.")

    input_image.thumbnail([512, 512], Image.Resampling.LANCZOS)
    w, h = input_image.size
    print("image size:", w, h)

    # Builtin max keeps these plain ints; PIL sizes/offsets should not be
    # NumPy scalars.
    side_len = max(w, h)
    seg_in = Image.new(input_image.mode, (side_len, side_len), (255, 255, 255))
    seg_in.paste(input_image, (max(0, (h - w) // 2), max(0, (w - h) // 2)))

    print("Calculating image SAM feature...")
    predictor.set_image(np.array(seg_in.convert("RGB")))
    torch.cuda.empty_cache()
    return gr.Image(value=seg_in, visible=True)


def manual_seg(
    predictor,
    seg_in,
    selected_points,
    fg_bg_radio,
    tmp_dir,
    seg_mask_opt,
    evt: gr.SelectData,
):
    """Run SAM for the clicked points and save the re-centred cut-out.

    Args:
        predictor: SAM predictor on which ``set_image`` was already called.
        seg_in: The padded square image being clicked on.
        selected_points: Session list of previous clicks; the new click is
            appended in place.
        fg_bg_radio: Radio value; ``"+ (add mask)"`` marks a foreground point.
        tmp_dir: Directory where ``seg.png`` is written.
        seg_mask_opt: Use ``blend_seg_pure`` if true, else ``blend_seg``.
        evt: Gradio select event carrying the click coordinates.

    Returns:
        ``(preview, rgba)``: the blended preview at 380x380 and the
        re-centred RGBA cut-out at 320x320.

    Raises:
        gr.Error: If the resulting mask is empty.
    """
    print("Start segmentation")
    selected_points.append(
        {"coord": evt.index, "add_del": fg_bg_radio == "+ (add mask)"}
    )
    input_points = np.array([point["coord"] for point in selected_points])
    input_labels = np.array([point["add_del"] for point in selected_points])
    out_image = sam_seg(
        predictor, np.array(seg_in.convert("RGB")), input_points, input_labels
    )
    # seg_in.save(f"{tmp_dir}/in.png")
    # out_image.save(f"{tmp_dir}/out.png")
    blend = blend_seg_pure if seg_mask_opt else blend_seg
    segmentation = blend(seg_in.convert("RGB"), out_image, input_points, input_labels)

    # recenter and rescale
    image_arr = np.array(out_image)
    _, mask = cv2.threshold(
        np.array(out_image.split()[-1]), 0, 255, cv2.THRESH_BINARY
    )
    x, y, w, h = cv2.boundingRect(mask)
    if w == 0 or h == 0:
        # An empty mask would produce a zero-sized canvas below.
        raise gr.Error("The segmentation is empty, please click on the object again.")

    # The object occupies `ratio` of the side of the square output canvas.
    ratio = 0.75
    side_len = int(max(w, h) / ratio)
    padded_image = np.zeros((side_len, side_len, 4), dtype=np.uint8)
    center = side_len // 2
    top = center - h // 2
    left = center - w // 2
    padded_image[top : top + h, left : left + w] = image_arr[y : y + h, x : x + w]

    rgba = Image.fromarray(padded_image)
    rgba.save(f"{tmp_dir}/seg.png")
    torch.cuda.empty_cache()
    return (
        segmentation.resize((380, 380), Image.Resampling.LANCZOS),
        rgba.resize((320, 320), Image.Resampling.LANCZOS),
    )


def disable_checkbox():
    """Return an update that unticks a checkbox."""
    return gr.Checkbox(value=False)


def disable_button():
    """Return an update that makes a button non-interactive."""
    return gr.Button(interactive=False)


def enable_button():
    """Return an update that makes a button interactive."""
    return gr.Button(interactive=True)


def update_guide(GUIDE_TEXT, icon_type="info"):
    """Return an update that shows ``GUIDE_TEXT`` in the guide banner."""
    return gr.HTML(value=message(GUIDE_TEXT, icon_type))


def update_md(GUIDE_TEXT):
    """Return an update that sets a Markdown component to ``GUIDE_TEXT``."""
    return gr.Markdown(value=GUIDE_TEXT)


def is_img_clear(input_image):
    """Raise ``ValueError`` when the image was cleared, stopping the chain."""
    if input_image is None:
        raise ValueError("Input image cleared.")


custom_theme = gr.themes.Soft(primary_hue="blue").set(
    button_secondary_background_fill="*neutral_100",
    button_secondary_background_fill_hover="*neutral_200",
)

# NOTE(review): the nesting of the layout below was reconstructed from a
# whitespace-mangled source -- confirm that the manual segmentation widgets
# belong outside the "Options" accordion.
with gr.Blocks(title="MeshFormer Demo", css="style.css", theme=custom_theme) as demo:
    with gr.Row():
        gr.Markdown(
            "# MeshFormer: High-Quality Mesh Generation with 3D-Guided Reconstruction Model"
        )
    with gr.Row():
        gr.Markdown(
            "[Project Page](https://meshformer3d.github.io/) | [arXiv](https://arxiv.org/abs/TBD)"
        )
    with gr.Row():
        gr.Markdown(
            "\nCheck out Hillbot (sudoAI) for more details and advanced features.\n"
        )
    with gr.Row():
        guide_text_i2m = gr.HTML(message("Please input an image!"), visible=True)

    tmp_dir_img = gr.State("./demo_exp/placeholder")
    tmp_dir_txt = gr.State("./demo_exp/placeholder")
    tmp_dir_3t3 = gr.State("./demo_exp/placeholder")

    example_folder = os.path.join(os.path.dirname(__file__), "demo_examples")
    example_fns = os.listdir(example_folder)
    example_fns.sort()
    img_examples = [
        os.path.join(example_folder, x) for x in example_fns
    ]  # if x.endswith('.png') or x.endswith('.')

    with gr.Row(variant="panel"):
        with gr.Row():
            with gr.Column(scale=8):
                input_image = gr.Image(
                    type="pil",
                    image_mode="RGBA",
                    height=320,
                    label="Input Image",
                    interactive=True,
                )
                gr.Examples(
                    examples=img_examples,
                    inputs=[input_image],
                    outputs=[input_image],
                    cache_examples=False,
                    label="Image Examples (Click one of the images below to start)",
                    examples_per_page=27,
                )
                with gr.Accordion("Options", open=False):
                    img_simplify = gr.Checkbox(
                        False, label="simplify the generated mesh", visible=False
                    )
                    n_steps_img = gr.Slider(
                        value=28,
                        minimum=15,
                        maximum=100,
                        step=1,
                        label="number of inference steps",
                    )

                # manual segmentation
                checkbox_manual_seg = gr.Checkbox(False, label="manual segmentation")
                with gr.Row():
                    with gr.Column(scale=1):
                        seg_in = gr.Image(
                            type="pil",
                            image_mode="RGBA",
                            label="Click to segment",
                            visible=False,
                            show_download_button=False,
                            height=380,
                        )
                    with gr.Column(scale=1):
                        seg_out = gr.Image(
                            type="pil",
                            image_mode="RGBA",
                            label="Segmentation",
                            interactive=False,
                            visible=False,
                            show_download_button=False,
                            height=380,
                            elem_id="disp_image",
                        )
                fg_bg_radio = gr.Radio(
                    ["+ (add mask)", "- (remove area)"],
                    value="+ (add mask)",
                    info="Select foreground (+) or background (-) point",
                    label="Point label",
                    visible=False,
                    interactive=True,
                )
                seg_mask_opt = gr.Checkbox(
                    True,
                    label="show foreground mask in manual segmentation",
                    visible=False,
                )

                # run
                img_run_btn = gr.Button(
                    "Generate", variant="primary", interactive=False
                )

            with gr.Column(scale=6):
                processed_image = gr.Image(
                    type="pil",
                    label="Processed Image",
                    interactive=False,
                    height=320,
                    image_mode="RGBA",
                    elem_id="disp_image",
                )
                # with gr.Row():
                # mesh_output = gr.Model3D(label="Generated Mesh", elem_id="model-3d-out")
                mesh_output_normal = Model3DColor(
                    label="Generated Mesh (normal)",
                    elem_id="mesh-normal-out",
                    height=400,
                )
                mesh_output = Model3DColor(
                    label="Generated Mesh (color)",
                    elem_id="mesh-out",
                    height=400,
                )

    predictor = gr.State(value=get_sam_predictor())
    selected_points = gr.State(value=[])
    selected_points_t2i = gr.State(value=[])

    # Toggling manual segmentation: show/hide widgets, compute the SAM
    # features (raises when unticked, which ends the chain), then start a
    # fresh working directory.
    checkbox_manual_seg.change(
        vis_seg,
        inputs=[checkbox_manual_seg],
        outputs=[seg_in, seg_out, fg_bg_radio, selected_points, seg_mask_opt],
        queue=False,
    ).success(
        calc_feat,
        inputs=[checkbox_manual_seg, predictor, input_image],
        outputs=[seg_in],
    ).success(
        fn=create_tmp_dir, outputs=[tmp_dir_img], queue=False
    )

    seg_in.select(
        manual_seg,
        [predictor, seg_in, selected_points, fg_bg_radio, tmp_dir_img, seg_mask_opt],
        [seg_out, processed_image],
    )

    # New input image: lock the UI, run automatic segmentation, then unlock.
    input_image.change(disable_button, outputs=img_run_btn, queue=False).success(
        disable_checkbox, outputs=checkbox_manual_seg, queue=False
    ).success(fn=is_img_clear, inputs=input_image, queue=False).success(
        fn=create_tmp_dir, outputs=tmp_dir_img, queue=False
    ).success(
        fn=partial(update_guide, "Preprocessing the image!", "wait"),
        outputs=[guide_text_i2m],
        queue=False,
    ).success(
        fn=preprocess,
        inputs=[tmp_dir_img, input_image],
        outputs=[processed_image],
        queue=True,
    ).success(
        fn=partial(
            update_guide,
            "Click Generate to generate mesh! \nIf the input image was not segmented accurately, please adjust it using manual segmentation.",
            "cursor",
        ),
        outputs=[guide_text_i2m],
        queue=False,
    ).success(
        enable_button, outputs=img_run_btn, queue=False
    )

    img_run_btn.click(
        fn=partial(update_guide, "Generating the mesh!", "wait"),
        outputs=[guide_text_i2m],
        queue=False,
    ).success(
        fn=mesh_gen,
        inputs=[tmp_dir_img, img_simplify, n_steps_img],
        outputs=[mesh_output, mesh_output_normal],
        queue=True,
    ).success(
        fn=partial(
            update_guide,
            "Successfully generated the mesh. (It might take a few seconds to load the mesh)",
            "done",
        ),
        outputs=[guide_text_i2m],
        queue=False,
    )

demo.queue().launch(
    debug=True, share=False, inline=False, show_api=False, server_name="0.0.0.0"
)