# Hugging Face Space status: "Running on Zero" (ZeroGPU hardware).
# Kept as the very first import; on ZeroGPU Spaces `spaces` is presumably
# expected to be imported before CUDA is initialised by other packages.
import spaces

import io
import os
import subprocess
import sys
import time

# Install flash attention, skipping its CUDA build (FLASH_ATTENTION_SKIP_CUDA_BUILD)
# so nothing is compiled at start-up. The variable is added to the *current*
# environment: passing a dict with only this key would replace the whole
# environment and drop PATH, HOME, etc. for the pip subprocess.
# `sys.executable -m pip` targets the interpreter running this app, and the
# list form avoids going through a shell.
subprocess.run(
    [sys.executable, "-m", "pip", "install", "flash-attn", "--no-build-isolation"],
    env={**os.environ, "FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
)

# Third-party and project imports come after the install above (presumably so
# that `main` can import flash-attn if it needs it).
import torch
import trimesh
from accelerate.utils import set_seed
from accelerate import Accelerator
import numpy as np
import gradio as gr
from main import load_v2
from mesh_to_pc import process_mesh_to_pc
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d.art3d import Poly3DCollection
from PIL import Image
# Load the model once at start-up; every request handled by do_inference
# reuses these module-level objects (model, device, accelerator).
model = load_v2()
# Input tensors are created directly on this device inside do_inference.
device = torch.device("cuda")
# Wrap the model with Accelerate so forward passes can run under fp16 autocast.
accelerator = Accelerator(mixed_precision="fp16")
model = accelerator.prepare(model)
model.eval()  # inference only: switch the module to evaluation mode
print("Model loaded to device")
def wireframe_render(mesh):
    """Render ``mesh`` from two viewpoints and save the views side by side as a PNG.

    The y and z axes are swapped for display only, so the mesh's y axis is drawn
    as matplotlib's vertical (z) axis. ``mesh`` itself is not modified.

    Args:
        mesh: a ``trimesh.Trimesh`` (anything exposing ``vertices``, ``faces``,
            ``bounds`` and ``centroid``).

    Returns:
        Path, relative to the working directory, of the written PNG file.
    """
    import uuid  # local import: only needed for collision-free file names

    # (azimuth, elevation) in degrees; one rendered view per entry.
    views = [(90, 20), (270, 20)]

    # Swap y and z on local copies instead of writing back to mesh.vertices.
    axis_order = [0, 2, 1]
    vertices = np.asarray(mesh.vertices)[:, axis_order]
    triangles = vertices[np.asarray(mesh.faces)]
    center = np.asarray(mesh.centroid)[axis_order]
    # Largest extent along any axis (permutation-invariant), used to build
    # equal, centred limits on all three axes.
    scale = np.ptp(mesh.bounds, axis=0).max()

    fig = plt.figure(figsize=(10, 10))

    def render_view(azimuth, elevation):
        """Draw the mesh on a fresh 3D axes of ``fig`` and return it as a PIL image."""
        ax = fig.add_subplot(111, projection='3d')
        ax.set_axis_off()
        ax.add_collection3d(Poly3DCollection(
            triangles,
            facecolors=(0.8, 0.5, 0.2, 1.0),  # Brownish yellow
            edgecolors='k',
            linewidths=0.5,
        ))
        # Set limits and center the view on the object
        ax.set_xlim(center[0] - scale / 2, center[0] + scale / 2)
        ax.set_ylim(center[1] - scale / 2, center[1] + scale / 2)
        ax.set_zlim(center[2] - scale / 2, center[2] + scale / 2)
        ax.view_init(elev=elevation, azim=azimuth)
        buf = io.BytesIO()
        # Save/clear through ``fig`` rather than pyplot's implicit "current figure".
        fig.savefig(buf, format='png', bbox_inches='tight', pad_inches=0, dpi=300)
        fig.clf()
        buf.seek(0)
        return Image.open(buf)

    try:
        images = [render_view(az, el) for az, el in views]
    finally:
        plt.close(fig)  # release the figure even if a render fails

    # Combine the views horizontally.
    total_width = sum(img.width for img in images)
    max_height = max(img.height for img in images)
    combined_image = Image.new('RGBA', (total_width, max_height))
    x_offset = 0
    for img in images:
        combined_image.paste(img, (x_offset, 0))
        x_offset += img.width

    # The time stamp keeps names cache-busting; the random suffix stops two renders
    # in the same second (input + output, or concurrent users) overwriting each other.
    save_path = f"combined_mesh_view_{int(time.time())}_{uuid.uuid4().hex[:8]}.png"
    combined_image.save(save_path)
    return save_path
# RGBA colours used to tint the exported meshes.
_ORANGE_RGBA = np.array([255, 165, 0, 255], dtype=np.uint8)
_BROWN_RGBA = np.array([165, 42, 42, 255], dtype=np.uint8)


def _clean_mesh(mesh):
    """Merge duplicate vertices, drop degenerate/duplicate faces and unused vertices, fix normals (in place)."""
    mesh.merge_vertices()
    mesh.update_faces(mesh.nondegenerate_faces())
    mesh.update_faces(mesh.unique_faces())
    mesh.remove_unreferenced_vertices()
    mesh.fix_normals()


def _paint_vertices(mesh, rgba):
    """Give every vertex of ``mesh`` the same ``rgba`` colour."""
    mesh.visual.vertex_colors = np.tile(rgba, (mesh.vertices.shape[0], 1))


def _smooth_shaded(mesh):
    """Return a smooth-shaded copy of ``mesh`` using trimesh's built-in helper.

    Newer trimesh exposes this as the ``Trimesh.smooth_shaded`` property; older
    releases provide ``Trimesh.smoothed()`` instead.
    """
    if hasattr(type(mesh), "smooth_shaded"):
        return mesh.smooth_shaded
    return mesh.smoothed()


# NOTE(review): ZeroGPU's default GPU slot is 60 s; if long generations time out,
# raise it with @spaces.GPU(duration=...). Outside ZeroGPU the decorator is a no-op.
@spaces.GPU
def do_inference(input_3d, sample_seed=0, do_sampling=False, do_marching_cubes=False, do_smooth_shading=False):
    """Convert an uploaded mesh into a MeshAnythingV2 "artist" mesh.

    Args:
        input_3d: value of the ``gr.Model3D`` input, passed to ``trimesh.load``
            (presumably a file path).
        sample_seed: seed for ``accelerate.utils.set_seed``; ``None`` falls back to 0.
        do_sampling: forwarded to the model (the UI labels it "Random Sampling").
        do_marching_cubes: forwarded to ``process_mesh_to_pc`` as ``marching_cubes``.
        do_smooth_shading: smooth-shade the generated mesh before export.

    Returns:
        ``(processed_input_obj, processed_input_render_png, output_obj, output_render_png)``
        file paths, in the order of the Gradio outputs.

    Raises:
        gr.Error: if no input mesh was provided.
        ValueError: if the sampled point-cloud normals are not unit length.
    """
    import uuid  # local import: only needed for collision-free file names

    if input_3d is None:
        raise gr.Error("Please upload an input mesh first.")
    if sample_seed is None:  # e.g. the Seed Value field was cleared
        sample_seed = 0
    sample_seed = int(sample_seed)
    set_seed(sample_seed)
    print("Seed value:", sample_seed)

    input_mesh = trimesh.load(input_3d)
    pc_list, mesh_list = process_mesh_to_pc([input_mesh], marching_cubes=do_marching_cubes)
    pc_normal = pc_list[0]  # 4096, 6: xyz followed by normals
    mesh = mesh_list[0]
    pc_coor = pc_normal[:, :3]
    normals = pc_normal[:, 3:]

    # Centre the mesh on its bounding-box midpoint and scale its longest side to 1;
    # the point cloud gets the identical transform so the two stay aligned.
    vertices = mesh.vertices
    bounds = np.array([vertices.min(axis=0), vertices.max(axis=0)])
    offset = (bounds[0] + bounds[1])[None, :] / 2
    extent = (bounds[1] - bounds[0]).max()
    mesh.vertices = (vertices - offset) / extent
    pc_coor = (pc_coor - offset) / extent

    _clean_mesh(mesh)
    try:
        _paint_vertices(mesh, _ORANGE_RGBA)
    except Exception as e:  # best effort: recolouring is cosmetic only
        print(e)

    input_save_name = f"processed_input_{int(time.time())}_{uuid.uuid4().hex[:8]}.obj"
    mesh.export(input_save_name)
    input_render_res = wireframe_render(mesh)

    pc_coor = pc_coor / np.abs(pc_coor).max() * 0.99  # input should be from -1 to 1
    if not (np.linalg.norm(normals, axis=-1) > 0.99).all():
        # Explicit raise instead of ``assert`` so the check survives ``python -O``.
        raise ValueError("normals should be unit vectors, something wrong")
    normalized_pc_normal = np.concatenate([pc_coor, normals], axis=-1, dtype=np.float16)
    pc_tensor = torch.tensor(normalized_pc_normal, dtype=torch.float16, device=device)[None]  # add batch dim
    print("Data loaded")

    # Inference only: no_grad avoids keeping autograd state for the forward pass.
    with torch.no_grad(), accelerator.autocast():
        outputs = model(pc_tensor, do_sampling)
    print("Model inference done")

    recon_mesh = outputs[0]
    # Drop every face that has a NaN in any of its 9 coordinates.
    valid_mask = torch.all(~torch.isnan(recon_mesh.reshape((-1, 9))), dim=1)
    recon_mesh = recon_mesh[valid_mask]  # nvalid_face x 3 x 3
    vertices = recon_mesh.reshape(-1, 3).float().cpu().numpy()  # float32 CPU array for trimesh
    # Each face owns its own three vertices: faces are [0, 1, 2], [3, 4, 5], ...
    triangles = np.arange(len(vertices)).reshape(-1, 3)
    artist_mesh = trimesh.Trimesh(vertices=vertices, faces=triangles, force="mesh",
                                  merge_primitives=True)
    _clean_mesh(artist_mesh)
    if do_smooth_shading:
        artist_mesh = _smooth_shaded(artist_mesh)

    _paint_vertices(artist_mesh, _ORANGE_RGBA)
    # NOTE(review): in trimesh's ColorVisuals, assigning face colours likely discards
    # the vertex colours set above — confirm whether the orange pass is still wanted.
    artist_mesh.visual.face_colors = np.tile(_BROWN_RGBA, (len(artist_mesh.faces), 1))

    # Time stamp avoids stale caches; the random suffix avoids same-second collisions.
    save_name = f"output_{int(time.time())}_{uuid.uuid4().hex[:8]}.obj"
    artist_mesh.export(save_name)
    output_render = wireframe_render(artist_mesh)
    return input_save_name, input_render_res, save_name, output_render
_HEADER_ = """ | |
## (Optional) Transform your high poly mesh into a low poly mesh | |
➡️ You can optimize your high poly mesh, here, to the drawback is that you'll need to create a new material on Roblox. | |
- To optimize your high poly mesh, we use a tool called [MeshAnythingV2](https://huggingface.co/Yiwen-ntu/MeshAnythingV2). | |
### The Process: | |
1. Import the OBJ model generated with the high poly mesh generator tool above. | |
2. Check on Preprocess with marching Cubes. | |
3. If you want the look of your object smooth, check "Apply Smooth Shading". | |
<img src="https://huggingface.co/spaces/ThomasSimonini/MeshAnythingV2ForRoblox/resolve/main/assets/img/smooth-shading.png" alt="With or without smooth shading applied"/> | |
4. Click on generate | |
5. The 3D mesh is generated, and you can download the file (it's OBJ format) using the ⬇️ | |
6. Open Roblox Studio | |
7. In your Roblox Project, click on Import 3D and select the downloaded file. | |
8. You can now drag and drop your generated 3D file in your scene 🎉. | |
9. You can change the material and color by clicking on Color and Material in Roblox studio. | |
""" | |
# Output components are created up front and placed into the layout below via
# ``.render()``; their order in ``outputs=`` must match do_inference's return tuple.
output_model_obj = gr.Model3D(
    label="Generated Mesh (OBJ Format)",
    display_mode="wireframe",
    clear_color=[1, 1, 1, 1],
)
preprocess_model_obj = gr.Model3D(
    label="Processed Input Mesh (OBJ Format)",
    display_mode="wireframe",
    clear_color=[1, 1, 1, 1],
)
input_image_render = gr.Image(
    label="Wireframe Render of Processed Input Mesh",
)
output_image_render = gr.Image(
    label="Wireframe Render of Generated Mesh",
)

with gr.Blocks() as demo:
    gr.Markdown(_HEADER_)
    with gr.Row(variant="panel"):
        # Left column: input mesh, generation options and the Generate button.
        with gr.Column():
            with gr.Row():
                input_3d = gr.Model3D(
                    label="Input Mesh",
                    display_mode="wireframe",
                    clear_color=[1, 1, 1, 1],
                )
            with gr.Row():
                with gr.Group():
                    do_marching_cubes = gr.Checkbox(label="Preprocess with Marching Cubes", value=False)
                    do_smooth_shading = gr.Checkbox(label="Apply Smooth Shading", value=False)
                    do_sampling = gr.Checkbox(label="Random Sampling", value=False)
                    sample_seed = gr.Number(value=0, label="Seed Value", precision=0)
            with gr.Row():
                submit = gr.Button("Generate", elem_id="generate", variant="primary")
        # Right column: renders and downloadable OBJ files for the input and the output.
        with gr.Column():
            with gr.Row():
                input_image_render.render()
            with gr.Row():
                with gr.Tab("OBJ"):
                    preprocess_model_obj.render()
            with gr.Row():
                output_image_render.render()
            with gr.Row():
                with gr.Tab("OBJ"):
                    output_model_obj.render()
    with gr.Row():
        gr.Markdown('''Try enabling <b>Random Sampling</b> and a different <b>Seed Value</b> if the result is unsatisfying.''')
    submit.click(
        fn=do_inference,
        inputs=[input_3d, sample_seed, do_sampling, do_marching_cubes, do_smooth_shading],
        outputs=[preprocess_model_obj, input_image_render, output_model_obj, output_image_render],
    )

if __name__ == "__main__":
    demo.launch(share=True)