"""Gradio demo for MeshFormer: single-image to textured 3D mesh.

The app segments the uploaded image (automatically, or interactively with
SAM), runs the mesh-generation command described by the downloaded
``api.json`` and shows the resulting mesh both with its colors and with a
normal-map coloring.
"""

import os
import json
import subprocess
from functools import partial
from datetime import datetime
import uuid

import gradio as gr
import numpy as np
from PIL import Image
import cv2
import torch
import trimesh
from huggingface_hub import snapshot_download
from gradio_model3dcolor import Model3DColor

from sam_inference import get_sam_predictor, sam_seg
from utils import blend_seg, blend_seg_pure

# from gradio_model3dnormal import Model3DNormal

# The shell command templates live in a separate (token-protected) repository.
code_dir = snapshot_download("sudo-ai/MeshFormer-API", token=os.environ['HF_TOKEN'])

with open(f"{code_dir}/api.json", "r", encoding="utf-8") as file:
    api_dict = json.load(file)

SEG_CMD = api_dict["SEG_CMD"]
MESH_CMD = api_dict["MESH_CMD"]

STYLE = """
<link href="https://cdn.jsdelivr.net/npm/bootstrap@5.3.2/dist/css/bootstrap.min.css" rel="stylesheet" integrity="sha384-T3c6CoIi6uLrA9TneNEoa7RxnatzjcDSCmG1MXxSR1GAsXEV/Dwwykc2MPK8M2HN" crossorigin="anonymous">
<style>
    .alert, .alert div, .alert b {
        color: black !important;
    }
</style>
"""

# info (info-circle-fill), cursor (hand-index-thumb), wait (hourglass-split), done (check-circle)
ICONS = {
    "info": """<svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" fill="#0d6efd" class="bi bi-info-circle-fill flex-shrink-0 me-2" viewBox="0 0 16 16">
    <path d="M8 16A8 8 0 1 0 8 0a8 8 0 0 0 0 16zm.93-9.412-1 4.705c-.07.34.029.533.304.533.194 0 .487-.07.686-.246l-.088.416c-.287.346-.92.598-1.465.598-.703 0-1.002-.422-.808-1.319l.738-3.468c.064-.293.006-.399-.287-.47l-.451-.081.082-.381 2.29-.287zM8 5.5a1 1 0 1 1 0-2 1 1 0 0 1 0 2z"/>
    </svg>""",
    "cursor": """<svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" fill="#0dcaf0" class="bi bi-hand-index-thumb-fill flex-shrink-0 me-2" viewBox="0 0 16 16">
    <path d="M8.5 1.75v2.716l.047-.002c.312-.012.742-.016 1.051.046.28.056.543.18.738.288.273.152.456.385.56.642l.132-.012c.312-.024.794-.038 1.158.108.37.148.689.487.88.716.075.09.141.175.195.248h.582a2 2 0 0 1 1.99 2.199l-.272 2.715a3.5 3.5 0 0 1-.444 1.389l-1.395 2.441A1.5 1.5 0 0 1 12.42 16H6.118a1.5 1.5 0 0 1-1.342-.83l-1.215-2.43L1.07 8.589a1.517 1.517 0 0 1 2.373-1.852L5 8.293V1.75a1.75 1.75 0 0 1 3.5 0z"/>
    </svg>""",
    "wait": """<svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" fill="#6c757d" class="bi bi-hourglass-split flex-shrink-0 me-2" viewBox="0 0 16 16">
    <path d="M2.5 15a.5.5 0 1 1 0-1h1v-1a4.5 4.5 0 0 1 2.557-4.06c.29-.139.443-.377.443-.59v-.7c0-.213-.154-.451-.443-.59A4.5 4.5 0 0 1 3.5 3V2h-1a.5.5 0 0 1 0-1h11a.5.5 0 0 1 0 1h-1v1a4.5 4.5 0 0 1-2.557 4.06c-.29.139-.443.377-.443.59v.7c0 .213.154.451.443.59A4.5 4.5 0 0 1 12.5 13v1h1a.5.5 0 0 1 0 1h-11zm2-13v1c0 .537.12 1.045.337 1.5h6.326c.216-.455.337-.963.337-1.5V2h-7zm3 6.35c0 .701-.478 1.236-1.011 1.492A3.5 3.5 0 0 0 4.5 13s.866-1.299 3-1.48V8.35zm1 0v3.17c2.134.181 3 1.48 3 1.48a3.5 3.5 0 0 0-1.989-3.158C8.978 9.586 8.5 9.052 8.5 8.351z"/>
    </svg>""",
    "done": """<svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" fill="#198754" class="bi bi-check-circle-fill flex-shrink-0 me-2" viewBox="0 0 16 16">
    <path d="M16 8A8 8 0 1 1 0 8a8 8 0 0 1 16 0zm-3.97-3.03a.75.75 0 0 0-1.08.022L7.477 9.417 5.384 7.323a.75.75 0 0 0-1.06 1.06L6.97 11.03a.75.75 0 0 0 1.079-.02l3.992-4.99a.75.75 0 0 0-.01-1.05z"/>
    </svg>""",
}

# Maps an icon key to the Bootstrap alert flavor used to render it.
icons2alert = {
    "info": "primary",  # blue
    "cursor": "info",  # light blue
    "wait": "secondary",  # gray
    "done": "success",  # green
}


def message(text, icon_type="info"):
    """Return the HTML for a Bootstrap alert box showing *text*.

    Args:
        text: HTML/text placed inside the alert body.
        icon_type: Key of ``ICONS`` / ``icons2alert`` selecting the icon and
            the alert color.

    Raises:
        KeyError: If *icon_type* is not a known icon key.
    """
    return f"""{STYLE}
<div class="alert alert-{icons2alert[icon_type]} d-flex align-items-center" role="alert">
    {ICONS[icon_type]}
    <div>
        {text}
    </div>
</div>"""


def _run_shell(cmd):
    """Run *cmd* through the shell and raise if it exits with a non-zero status.

    Raises:
        subprocess.CalledProcessError: If the command fails.
    """
    # The command templates are full shell command lines (from api.json), so a
    # shell is required; checking the status keeps failures from being
    # mistaken for success further down the pipeline.
    subprocess.run(cmd, shell=True, check=True)


def preprocess(tmp_dir, input_img, idx=None):
    """Segment the input image with the external segmentation command.

    Saves the image as ``{tmp_dir}/input.png``, runs ``SEG_CMD`` and loads the
    ``{tmp_dir}/seg.png`` it is expected to produce.

    Args:
        tmp_dir: Working directory of the current session.
        input_img: A PIL image, or (when *idx* is given) a sequence of dicts
            whose ``"name"`` entry is an image path.
        idx: Optional index into *input_img* selecting which file to open.

    Returns:
        The segmented image resized to 320x320.

    Raises:
        subprocess.CalledProcessError: If the segmentation command fails.
    """
    if idx is not None:
        print("image idx:", int(idx))
        input_img = Image.open(input_img[int(idx)]["name"])
    input_img.save(f"{tmp_dir}/input.png")
    _run_shell(SEG_CMD.format(tmp_dir=tmp_dir))
    processed_img = Image.open(f"{tmp_dir}/seg.png")
    return processed_img.resize((320, 320), Image.Resampling.LANCZOS)


def ply_to_glb(ply_path):
    """Convert a ``.ply`` mesh to ``.glb`` with the ``ply2glb.py`` script.

    Args:
        ply_path: Path of the mesh to convert.

    Returns:
        The path of the ``.glb`` file (``.ply`` replaced by ``.glb``).

    Raises:
        RuntimeError: If the conversion script exits with a non-zero status.
    """
    result = subprocess.run(
        ["python", "ply2glb.py", "--", ply_path],
        capture_output=True,
        text=True,
    )
    print("Output of blender script:")
    print(result.stdout)
    if result.returncode != 0:
        # Without this check a failed conversion would hand a missing (or
        # stale) .glb file to the viewer.
        print(result.stderr)
        raise RuntimeError(
            f"ply2glb.py failed with exit code {result.returncode} for {ply_path}"
        )
    glb_path = ply_path.replace(".ply", ".glb")
    return glb_path


def mesh_gen(tmp_dir, simplify, num_inference_steps):
    """Generate the mesh for the current session and export both viewers' files.

    Runs ``MESH_CMD`` (expected to write ``{tmp_dir}/mesh.ply``), derives a
    second mesh whose vertex colors encode the rotated vertex normals, and
    converts both to ``.glb``.

    Args:
        tmp_dir: Working directory of the current session.
        simplify: Currently unused; kept for interface compatibility.
        num_inference_steps: Number of inference steps passed to ``MESH_CMD``.

    Returns:
        ``(color_path, normal_path)``: paths of the colored and the
        normal-colored ``.glb`` files.

    Raises:
        ValueError: If *num_inference_steps* is not an integer value.
        subprocess.CalledProcessError: If the mesh command fails.
        RuntimeError: If a ``.ply`` to ``.glb`` conversion fails.
    """
    # The value comes from the client and ends up in a shell command line:
    # force it to an integer so nothing else can be smuggled in.
    num_inference_steps = int(num_inference_steps)
    _run_shell(
        MESH_CMD.format(tmp_dir=tmp_dir, num_inference_steps=num_inference_steps)
    )

    mesh = trimesh.load_mesh(f"{tmp_dir}/mesh.ply")
    vertex_normals = mesh.vertex_normals

    # Rotate the normals by -90 degrees about the z axis before coloring.
    theta = np.radians(-90)
    cos_theta = np.cos(theta)
    sin_theta = np.sin(theta)
    rotation_matrix = np.array(
        [
            [cos_theta, -sin_theta, 0],
            [sin_theta, cos_theta, 0],
            [0, 0, 1],
        ]
    )
    rotated_normal = np.dot(vertex_normals, rotation_matrix.T)

    # Map normal components from [-1, 1] to 8-bit colors.
    colors = (-rotated_normal + 1) / 2.0
    colors = (colors * 255).clip(0, 255).astype(np.uint8)
    mesh.visual.vertex_colors = colors[..., [2, 1, 0]]  # RGB -> BGR
    mesh.export(f"{tmp_dir}/mesh_normal.ply", file_type="ply")

    color_path = ply_to_glb(f"{tmp_dir}/mesh.ply")
    normal_path = ply_to_glb(f"{tmp_dir}/mesh_normal.ply")
    return color_path, normal_path


def create_tmp_dir():
    """Create and return a fresh working directory under ``demo_exp/``.

    The name combines the current timestamp with four characters of a random
    UUID.
    """
    tmp_dir = (
        "demo_exp/"
        + datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
        + "_"
        + str(uuid.uuid4())[:4]
    )
    os.makedirs(tmp_dir, exist_ok=True)
    print("create tmp_exp_dir", tmp_dir)
    return tmp_dir


def vis_seg(checkbox):
    """Show or hide the manual-segmentation widgets.

    Args:
        checkbox: State of the "manual segmentation" checkbox.

    Returns:
        Updates for the two segmentation images, the point-label radio, the
        list of selected points (reset to empty) and the mask-option checkbox.
    """
    if checkbox:
        print("Show manual seg windows")
        return (
            [gr.Image(value=None, visible=True)] * 2
            + [gr.Radio(visible=True)]
            + [[], gr.Checkbox(visible=True)]
        )
    else:
        print("Clear manual seg")
        return (
            [gr.Image(visible=False)] * 2
            + [gr.Radio(visible=False)]
            + [[], gr.Checkbox(visible=False)]
        )


def calc_feat(checkbox, predictor, input_image, idx=None):
    """Prepare the square SAM input image and compute its features.

    The image is shrunk to fit 512x512, centered on a white square canvas and
    handed to ``predictor.set_image``.

    Args:
        checkbox: State of the "manual segmentation" checkbox.
        predictor: SAM predictor on which ``set_image`` is called.
        input_image: A PIL image, or (when *idx* is given) a sequence of dicts
            whose ``"name"`` entry is an image path.
        idx: Optional index into *input_image* selecting which file to open.

    Returns:
        An update making the padded image visible in the click-to-segment
        widget.

    Raises:
        ValueError: If manual segmentation is switched off (this deliberately
            stops the event chain) or if there is no input image.
    """
    if checkbox:
        if input_image is None:
            raise ValueError("No input image for manual segmentation.")
        if idx is not None:
            print("image idx:", int(idx))
            input_image = Image.open(input_image[int(idx)]["name"])
        input_image.thumbnail([512, 512], Image.Resampling.LANCZOS)
        w, h = input_image.size
        print("image size:", w, h)
        # Built-in max keeps these plain ints for PIL.
        side_len = max(w, h)
        seg_in = Image.new(input_image.mode, (side_len, side_len), (255, 255, 255))
        seg_in.paste(input_image, (max(0, (h - w) // 2), max(0, (w - h) // 2)))
        print("Calculating image SAM feature...")
        predictor.set_image(np.array(seg_in.convert("RGB")))
        torch.cuda.empty_cache()
        return gr.Image(value=seg_in, visible=True)
    else:
        print("Quit manual seg")
        raise ValueError("Quit manual seg")


def manual_seg(
    predictor,
    seg_in,
    selected_points,
    fg_bg_radio,
    tmp_dir,
    seg_mask_opt,
    evt: gr.SelectData,
):
    """Re-run SAM with the newly clicked point and save the recentered cut-out.

    Args:
        predictor: SAM predictor passed to ``sam_seg``.
        seg_in: The padded image the user clicks on.
        selected_points: Session list of clicked points; the new point is
            appended to it in place.
        fg_bg_radio: Current point label; ``"+ (add mask)"`` marks foreground.
        tmp_dir: Working directory; the cut-out is saved as ``seg.png`` there.
        seg_mask_opt: Selects ``blend_seg_pure`` instead of ``blend_seg`` for
            the preview.
        evt: Gradio select event carrying the clicked coordinates.

    Returns:
        ``(preview, cutout)``: the blended preview resized to 380x380 and the
        recentered RGBA cut-out resized to 320x320.

    Raises:
        ValueError: If the segmentation mask is empty.
    """
    print("Start segmentation")
    selected_points.append(
        {"coord": evt.index, "add_del": fg_bg_radio == "+ (add mask)"}
    )
    input_points = np.array([point["coord"] for point in selected_points])
    input_labels = np.array([point["add_del"] for point in selected_points])
    out_image = sam_seg(
        predictor, np.array(seg_in.convert("RGB")), input_points, input_labels
    )

    blend = blend_seg_pure if seg_mask_opt else blend_seg
    segmentation = blend(seg_in.convert("RGB"), out_image, input_points, input_labels)

    # Recenter and rescale: crop to the mask's bounding box and place it in
    # the middle of a square canvas so that it fills `ratio` of the side.
    image_arr = np.array(out_image)
    _, mask = cv2.threshold(
        np.array(out_image.split()[-1]), 0, 255, cv2.THRESH_BINARY
    )
    x, y, w, h = cv2.boundingRect(mask)
    if w == 0 or h == 0:
        # An empty mask would otherwise yield a zero-sized image below.
        raise ValueError("Empty segmentation mask")
    max_size = max(w, h)
    ratio = 0.75
    side_len = int(max_size / ratio)
    padded_image = np.zeros((side_len, side_len, 4), dtype=np.uint8)
    center = side_len // 2
    top = center - h // 2
    left = center - w // 2
    padded_image[top : top + h, left : left + w] = image_arr[y : y + h, x : x + w]
    rgba = Image.fromarray(padded_image)
    rgba.save(f"{tmp_dir}/seg.png")
    torch.cuda.empty_cache()
    return (
        segmentation.resize((380, 380), Image.Resampling.LANCZOS),
        rgba.resize((320, 320), Image.Resampling.LANCZOS),
    )


custom_theme = gr.themes.Soft(primary_hue="blue").set(
    button_secondary_background_fill="*neutral_100",
    button_secondary_background_fill_hover="*neutral_200",
)

# NOTE(review): the nesting of the layout containers below was reconstructed
# from the statement order and inline comments — confirm against the rendered UI.
with gr.Blocks(title="MeshFormer Demo", css="style.css", theme=custom_theme) as demo:
    with gr.Row():
        gr.Markdown(
            "# MeshFormer: High-Quality Mesh Generation with 3D-Guided Reconstruction Model"
        )
    with gr.Row():
        gr.Markdown(
            "[Project Page](https://meshformer3d.github.io/) | [arXiv](https://arxiv.org/pdf/2408.10198)"
        )
    with gr.Row():
        gr.Markdown(
            "<div>\n"
            '<b><em>Check out <a href="https://www.sudo.ai/3dgen">Hillbot (sudoAI)</a> '
            "for more details and advanced features.</em></b>\n"
            "</div>"
        )
    with gr.Row():
        guide_text_i2m = gr.HTML(message("Please input an image!"), visible=True)

    tmp_dir_img = gr.State("./demo_exp/placeholder")
    tmp_dir_txt = gr.State("./demo_exp/placeholder")
    tmp_dir_3t3 = gr.State("./demo_exp/placeholder")

    example_folder = os.path.join(os.path.dirname(__file__), "demo_examples")
    example_fns = os.listdir(example_folder)
    example_fns.sort()
    img_examples = [os.path.join(example_folder, x) for x in example_fns]

    with gr.Row(variant="panel"):
        with gr.Row():
            with gr.Column(scale=8):
                input_image = gr.Image(
                    type="pil",
                    image_mode="RGBA",
                    height=320,
                    label="Input Image",
                    interactive=True,
                )
                gr.Examples(
                    examples=img_examples,
                    inputs=[input_image],
                    outputs=[input_image],
                    cache_examples=False,
                    label="Image Examples (Click one of the images below to start)",
                    examples_per_page=27,
                )
                with gr.Accordion("Options", open=False):
                    img_simplify = gr.Checkbox(
                        False, label="simplify the generated mesh", visible=False
                    )
                    n_steps_img = gr.Slider(
                        value=28,
                        minimum=15,
                        maximum=100,
                        step=1,
                        label="number of inference steps",
                    )

                # manual segmentation
                checkbox_manual_seg = gr.Checkbox(False, label="manual segmentation")
                with gr.Row():
                    with gr.Column(scale=1):
                        seg_in = gr.Image(
                            type="pil",
                            image_mode="RGBA",
                            label="Click to segment",
                            visible=False,
                            show_download_button=False,
                            height=380,
                        )
                    with gr.Column(scale=1):
                        seg_out = gr.Image(
                            type="pil",
                            image_mode="RGBA",
                            label="Segmentation",
                            interactive=False,
                            visible=False,
                            show_download_button=False,
                            height=380,
                            elem_id="disp_image",
                        )
                fg_bg_radio = gr.Radio(
                    ["+ (add mask)", "- (remove area)"],
                    value="+ (add mask)",
                    info="Select foreground (+) or background (-) point",
                    label="Point label",
                    visible=False,
                    interactive=True,
                )
                seg_mask_opt = gr.Checkbox(
                    True,
                    label="show foreground mask in manual segmentation",
                    visible=False,
                )

                # run
                img_run_btn = gr.Button(
                    "Generate", variant="primary", interactive=False
                )

            with gr.Column(scale=6):
                processed_image = gr.Image(
                    type="pil",
                    label="Processed Image",
                    interactive=False,
                    height=320,
                    image_mode="RGBA",
                    elem_id="disp_image",
                )
                mesh_output_normal = Model3DColor(
                    label="Generated Mesh (normal)",
                    elem_id="mesh-normal-out",
                    height=400,
                )
                mesh_output = Model3DColor(
                    label="Generated Mesh (color)",
                    elem_id="mesh-out",
                    height=400,
                )

    predictor = gr.State(value=get_sam_predictor())
    selected_points = gr.State(value=[])
    selected_points_t2i = gr.State(value=[])

    def disable_checkbox():
        """Return an update that unticks a checkbox."""
        return gr.Checkbox(value=False)

    def disable_button():
        """Return an update that makes a button non-interactive."""
        return gr.Button(interactive=False)

    def enable_button():
        """Return an update that makes a button interactive."""
        return gr.Button(interactive=True)

    def update_guide(GUIDE_TEXT, icon_type="info"):
        """Return an update showing *GUIDE_TEXT* in the guide alert box."""
        return gr.HTML(value=message(GUIDE_TEXT, icon_type))

    def update_md(GUIDE_TEXT):
        """Return an update replacing a Markdown component's text."""
        return gr.Markdown(value=GUIDE_TEXT)

    def is_img_clear(input_image):
        """Stop the event chain (by raising) when the input image was cleared."""
        if input_image is None:
            raise ValueError("Input image cleared.")

    # Toggling manual segmentation: show/hide widgets, compute SAM features,
    # then start a fresh working directory.
    checkbox_manual_seg.change(
        vis_seg,
        inputs=[checkbox_manual_seg],
        outputs=[seg_in, seg_out, fg_bg_radio, selected_points, seg_mask_opt],
        queue=False,
    ).success(
        calc_feat,
        inputs=[checkbox_manual_seg, predictor, input_image],
        outputs=[seg_in],
        queue=True,
    ).success(
        fn=create_tmp_dir, outputs=[tmp_dir_img], queue=False
    )

    # Each click on the padded image adds a point and re-segments.
    seg_in.select(
        manual_seg,
        [predictor, seg_in, selected_points, fg_bg_radio, tmp_dir_img, seg_mask_opt],
        [seg_out, processed_image],
        queue=True,
    )

    # New input image: lock the UI, preprocess, then unlock "Generate".
    input_image.change(disable_button, outputs=img_run_btn, queue=False).success(
        disable_checkbox, outputs=checkbox_manual_seg, queue=False
    ).success(fn=is_img_clear, inputs=input_image, queue=False).success(
        fn=create_tmp_dir, outputs=tmp_dir_img, queue=False
    ).success(
        fn=partial(update_guide, "Preprocessing the image!", "wait"),
        outputs=[guide_text_i2m],
        queue=False,
    ).success(
        fn=preprocess,
        inputs=[tmp_dir_img, input_image],
        outputs=[processed_image],
        queue=True,
    ).success(
        fn=partial(
            update_guide,
            "Click <b>Generate</b> to generate mesh! If the input image was not "
            "segmented accurately, please adjust it using <b>manual segmentation</b>.",
            "cursor",
        ),
        outputs=[guide_text_i2m],
        queue=False,
    ).success(
        enable_button, outputs=img_run_btn, queue=False
    )

    img_run_btn.click(
        fn=partial(update_guide, "Generating the mesh!", "wait"),
        outputs=[guide_text_i2m],
        queue=False,
    ).success(
        fn=mesh_gen,
        inputs=[tmp_dir_img, img_simplify, n_steps_img],
        outputs=[mesh_output, mesh_output_normal],
        queue=True,
    ).success(
        fn=partial(
            update_guide,
            "Successfully generated the mesh. (It might take a few seconds to load the mesh)",
            "done",
        ),
        outputs=[guide_text_i2m],
        queue=False,
    )

demo.queue().launch(
    debug=False, share=False, inline=False, show_api=False, server_name="0.0.0.0"
)