Adarsh Patel commited on
Commit
4baad62
·
1 Parent(s): 93adcf7

files added

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ examples/cartoon_panda.png filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1 @@
 
 
1
+ __pycache__
LICENSE ADDED
@@ -0,0 +1,201 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Apache License
2
+ Version 2.0, January 2004
3
+ http://www.apache.org/licenses/
4
+
5
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6
+
7
+ 1. Definitions.
8
+
9
+ "License" shall mean the terms and conditions for use, reproduction,
10
+ and distribution as defined by Sections 1 through 9 of this document.
11
+
12
+ "Licensor" shall mean the copyright owner or entity authorized by
13
+ the copyright owner that is granting the License.
14
+
15
+ "Legal Entity" shall mean the union of the acting entity and all
16
+ other entities that control, are controlled by, or are under common
17
+ control with that entity. For the purposes of this definition,
18
+ "control" means (i) the power, direct or indirect, to cause the
19
+ direction or management of such entity, whether by contract or
20
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
21
+ outstanding shares, or (iii) beneficial ownership of such entity.
22
+
23
+ "You" (or "Your") shall mean an individual or Legal Entity
24
+ exercising permissions granted by this License.
25
+
26
+ "Source" form shall mean the preferred form for making modifications,
27
+ including but not limited to software source code, documentation
28
+ source, and configuration files.
29
+
30
+ "Object" form shall mean any form resulting from mechanical
31
+ transformation or translation of a Source form, including but
32
+ not limited to compiled object code, generated documentation,
33
+ and conversions to other media types.
34
+
35
+ "Work" shall mean the work of authorship, whether in Source or
36
+ Object form, made available under the License, as indicated by a
37
+ copyright notice that is included in or attached to the work
38
+ (an example is provided in the Appendix below).
39
+
40
+ "Derivative Works" shall mean any work, whether in Source or Object
41
+ form, that is based on (or derived from) the Work and for which the
42
+ editorial revisions, annotations, elaborations, or other modifications
43
+ represent, as a whole, an original work of authorship. For the purposes
44
+ of this License, Derivative Works shall not include works that remain
45
+ separable from, or merely link (or bind by name) to the interfaces of,
46
+ the Work and Derivative Works thereof.
47
+
48
+ "Contribution" shall mean any work of authorship, including
49
+ the original version of the Work and any modifications or additions
50
+ to that Work or Derivative Works thereof, that is intentionally
51
+ submitted to Licensor for inclusion in the Work by the copyright owner
52
+ or by an individual or Legal Entity authorized to submit on behalf of
53
+ the copyright owner. For the purposes of this definition, "submitted"
54
+ means any form of electronic, verbal, or written communication sent
55
+ to the Licensor or its representatives, including but not limited to
56
+ communication on electronic mailing lists, source code control systems,
57
+ and issue tracking systems that are managed by, or on behalf of, the
58
+ Licensor for the purpose of discussing and improving the Work, but
59
+ excluding communication that is conspicuously marked or otherwise
60
+ designated in writing by the copyright owner as "Not a Contribution."
61
+
62
+ "Contributor" shall mean Licensor and any individual or Legal Entity
63
+ on behalf of whom a Contribution has been received by Licensor and
64
+ subsequently incorporated within the Work.
65
+
66
+ 2. Grant of Copyright License. Subject to the terms and conditions of
67
+ this License, each Contributor hereby grants to You a perpetual,
68
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69
+ copyright license to reproduce, prepare Derivative Works of,
70
+ publicly display, publicly perform, sublicense, and distribute the
71
+ Work and such Derivative Works in Source or Object form.
72
+
73
+ 3. Grant of Patent License. Subject to the terms and conditions of
74
+ this License, each Contributor hereby grants to You a perpetual,
75
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
+ (except as stated in this section) patent license to make, have made,
77
+ use, offer to sell, sell, import, and otherwise transfer the Work,
78
+ where such license applies only to those patent claims licensable
79
+ by such Contributor that are necessarily infringed by their
80
+ Contribution(s) alone or by combination of their Contribution(s)
81
+ with the Work to which such Contribution(s) was submitted. If You
82
+ institute patent litigation against any entity (including a
83
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
84
+ or a Contribution incorporated within the Work constitutes direct
85
+ or contributory patent infringement, then any patent licenses
86
+ granted to You under this License for that Work shall terminate
87
+ as of the date such litigation is filed.
88
+
89
+ 4. Redistribution. You may reproduce and distribute copies of the
90
+ Work or Derivative Works thereof in any medium, with or without
91
+ modifications, and in Source or Object form, provided that You
92
+ meet the following conditions:
93
+
94
+ (a) You must give any other recipients of the Work or
95
+ Derivative Works a copy of this License; and
96
+
97
+ (b) You must cause any modified files to carry prominent notices
98
+ stating that You changed the files; and
99
+
100
+ (c) You must retain, in the Source form of any Derivative Works
101
+ that You distribute, all copyright, patent, trademark, and
102
+ attribution notices from the Source form of the Work,
103
+ excluding those notices that do not pertain to any part of
104
+ the Derivative Works; and
105
+
106
+ (d) If the Work includes a "NOTICE" text file as part of its
107
+ distribution, then any Derivative Works that You distribute must
108
+ include a readable copy of the attribution notices contained
109
+ within such NOTICE file, excluding those notices that do not
110
+ pertain to any part of the Derivative Works, in at least one
111
+ of the following places: within a NOTICE text file distributed
112
+ as part of the Derivative Works; within the Source form or
113
+ documentation, if provided along with the Derivative Works; or,
114
+ within a display generated by the Derivative Works, if and
115
+ wherever such third-party notices normally appear. The contents
116
+ of the NOTICE file are for informational purposes only and
117
+ do not modify the License. You may add Your own attribution
118
+ notices within Derivative Works that You distribute, alongside
119
+ or as an addendum to the NOTICE text from the Work, provided
120
+ that such additional attribution notices cannot be construed
121
+ as modifying the License.
122
+
123
+ You may add Your own copyright statement to Your modifications and
124
+ may provide additional or different license terms and conditions
125
+ for use, reproduction, or distribution of Your modifications, or
126
+ for any such Derivative Works as a whole, provided Your use,
127
+ reproduction, and distribution of the Work otherwise complies with
128
+ the conditions stated in this License.
129
+
130
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
131
+ any Contribution intentionally submitted for inclusion in the Work
132
+ by You to the Licensor shall be under the terms and conditions of
133
+ this License, without any additional terms or conditions.
134
+ Notwithstanding the above, nothing herein shall supersede or modify
135
+ the terms of any separate license agreement you may have executed
136
+ with Licensor regarding such Contributions.
137
+
138
+ 6. Trademarks. This License does not grant permission to use the trade
139
+ names, trademarks, service marks, or product names of the Licensor,
140
+ except as required for reasonable and customary use in describing the
141
+ origin of the Work and reproducing the content of the NOTICE file.
142
+
143
+ 7. Disclaimer of Warranty. Unless required by applicable law or
144
+ agreed to in writing, Licensor provides the Work (and each
145
+ Contributor provides its Contributions) on an "AS IS" BASIS,
146
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147
+ implied, including, without limitation, any warranties or conditions
148
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149
+ PARTICULAR PURPOSE. You are solely responsible for determining the
150
+ appropriateness of using or redistributing the Work and assume any
151
+ risks associated with Your exercise of permissions under this License.
152
+
153
+ 8. Limitation of Liability. In no event and under no legal theory,
154
+ whether in tort (including negligence), contract, or otherwise,
155
+ unless required by applicable law (such as deliberate and grossly
156
+ negligent acts) or agreed to in writing, shall any Contributor be
157
+ liable to You for damages, including any direct, indirect, special,
158
+ incidental, or consequential damages of any character arising as a
159
+ result of this License or out of the use or inability to use the
160
+ Work (including but not limited to damages for loss of goodwill,
161
+ work stoppage, computer failure or malfunction, or any and all
162
+ other commercial damages or losses), even if such Contributor
163
+ has been advised of the possibility of such damages.
164
+
165
+ 9. Accepting Warranty or Additional Liability. While redistributing
166
+ the Work or Derivative Works thereof, You may choose to offer,
167
+ and charge a fee for, acceptance of support, warranty, indemnity,
168
+ or other liability obligations and/or rights consistent with this
169
+ License. However, in accepting such obligations, You may act only
170
+ on Your own behalf and on Your sole responsibility, not on behalf
171
+ of any other Contributor, and only if You agree to indemnify,
172
+ defend, and hold each Contributor harmless for any liability
173
+ incurred by, or claims asserted against, such Contributor by reason
174
+ of your accepting any such warranty or additional liability.
175
+
176
+ END OF TERMS AND CONDITIONS
177
+
178
+ APPENDIX: How to apply the Apache License to your work.
179
+
180
+ To apply the Apache License to your work, attach the following
181
+ boilerplate notice, with the fields enclosed by brackets "[]"
182
+ replaced with your own identifying information. (Don't include
183
+ the brackets!) The text should be enclosed in the appropriate
184
+ comment syntax for the file format. We also recommend that a
185
+ file or class name and description of purpose be included on the
186
+ same "printed page" as the copyright notice for easier
187
+ identification within third-party archives.
188
+
189
+ Copyright [yyyy] [name of copyright owner]
190
+
191
+ Licensed under the Apache License, Version 2.0 (the "License");
192
+ you may not use this file except in compliance with the License.
193
+ You may obtain a copy of the License at
194
+
195
+ http://www.apache.org/licenses/LICENSE-2.0
196
+
197
+ Unless required by applicable law or agreed to in writing, software
198
+ distributed under the License is distributed on an "AS IS" BASIS,
199
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200
+ See the License for the specific language governing permissions and
201
+ limitations under the License.
README.md CHANGED
@@ -1,12 +1,14 @@
1
  ---
2
- title: Duplicate Repo
3
- emoji: 🐨
4
- colorFrom: blue
5
- colorTo: gray
6
  sdk: gradio
7
- sdk_version: 5.5.0
8
  app_file: app.py
9
  pinned: false
 
 
10
  ---
11
 
12
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
1
  ---
2
+ title: InstantMesh
3
+ emoji: 📚
4
+ colorFrom: indigo
5
+ colorTo: green
6
  sdk: gradio
7
+ sdk_version: 4.26.0
8
  app_file: app.py
9
  pinned: false
10
+ short_description: Create a 3D model from an image in 10 seconds!
11
+ license: apache-2.0
12
  ---
13
 
14
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,391 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import spaces
2
+
3
+ import os
4
+ import imageio
5
+ import numpy as np
6
+ import torch
7
+ import rembg
8
+ from PIL import Image
9
+ from torchvision.transforms import v2
10
+ from pytorch_lightning import seed_everything
11
+ from omegaconf import OmegaConf
12
+ from einops import rearrange, repeat
13
+ from tqdm import tqdm
14
+ from diffusers import DiffusionPipeline, EulerAncestralDiscreteScheduler
15
+
16
+ from src.utils.train_util import instantiate_from_config
17
+ from src.utils.camera_util import (
18
+ FOV_to_intrinsics,
19
+ get_zero123plus_input_cameras,
20
+ get_circular_camera_poses,
21
+ )
22
+ from src.utils.mesh_util import save_obj, save_glb
23
+ from src.utils.infer_util import remove_background, resize_foreground, images_to_video
24
+
25
+ import tempfile
26
+ from functools import partial
27
+
28
+ from huggingface_hub import hf_hub_download
29
+
30
+ import gradio as gr
31
+
32
+
33
+ def get_render_cameras(batch_size=1, M=120, radius=2.5, elevation=10.0, is_flexicubes=False):
34
+ """
35
+ Get the rendering camera parameters.
36
+ """
37
+ c2ws = get_circular_camera_poses(M=M, radius=radius, elevation=elevation)
38
+ if is_flexicubes:
39
+ cameras = torch.linalg.inv(c2ws)
40
+ cameras = cameras.unsqueeze(0).repeat(batch_size, 1, 1, 1)
41
+ else:
42
+ extrinsics = c2ws.flatten(-2)
43
+ intrinsics = FOV_to_intrinsics(50.0).unsqueeze(0).repeat(M, 1, 1).float().flatten(-2)
44
+ cameras = torch.cat([extrinsics, intrinsics], dim=-1)
45
+ cameras = cameras.unsqueeze(0).repeat(batch_size, 1, 1)
46
+ return cameras
47
+
48
+
49
+ def images_to_video(images, output_path, fps=30):
50
+ # images: (N, C, H, W)
51
+ os.makedirs(os.path.dirname(output_path), exist_ok=True)
52
+ frames = []
53
+ for i in range(images.shape[0]):
54
+ frame = (images[i].permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8).clip(0, 255)
55
+ assert frame.shape[0] == images.shape[2] and frame.shape[1] == images.shape[3], \
56
+ f"Frame shape mismatch: {frame.shape} vs {images.shape}"
57
+ assert frame.min() >= 0 and frame.max() <= 255, \
58
+ f"Frame value out of range: {frame.min()} ~ {frame.max()}"
59
+ frames.append(frame)
60
+ imageio.mimwrite(output_path, np.stack(frames), fps=fps, codec='h264')
61
+
62
+
63
+ ###############################################################################
64
+ # Configuration.
65
+ ###############################################################################
66
+
67
+ import shutil
68
+
69
+ def find_cuda():
70
+ # Check if CUDA_HOME or CUDA_PATH environment variables are set
71
+ cuda_home = os.environ.get('CUDA_HOME') or os.environ.get('CUDA_PATH')
72
+
73
+ if cuda_home and os.path.exists(cuda_home):
74
+ return cuda_home
75
+
76
+ # Search for the nvcc executable in the system's PATH
77
+ nvcc_path = shutil.which('nvcc')
78
+
79
+ if nvcc_path:
80
+ # Remove the 'bin/nvcc' part to get the CUDA installation path
81
+ cuda_path = os.path.dirname(os.path.dirname(nvcc_path))
82
+ return cuda_path
83
+
84
+ return None
85
+
86
+ cuda_path = find_cuda()
87
+
88
+ if cuda_path:
89
+ print(f"CUDA installation found at: {cuda_path}")
90
+ else:
91
+ print("CUDA installation not found")
92
+
93
+ config_path = 'configs/instant-mesh-large.yaml'
94
+ config = OmegaConf.load(config_path)
95
+ config_name = os.path.basename(config_path).replace('.yaml', '')
96
+ model_config = config.model_config
97
+ infer_config = config.infer_config
98
+
99
+ IS_FLEXICUBES = True if config_name.startswith('instant-mesh') else False
100
+
101
+ device = torch.device('cuda')
102
+
103
+ # load diffusion model
104
+ print('Loading diffusion model ...')
105
+ pipeline = DiffusionPipeline.from_pretrained(
106
+ "sudo-ai/zero123plus-v1.2",
107
+ custom_pipeline="zero123plus",
108
+ torch_dtype=torch.float16,
109
+ )
110
+ pipeline.scheduler = EulerAncestralDiscreteScheduler.from_config(
111
+ pipeline.scheduler.config, timestep_spacing='trailing'
112
+ )
113
+
114
+ # load custom white-background UNet
115
+ unet_ckpt_path = hf_hub_download(repo_id="TencentARC/InstantMesh", filename="diffusion_pytorch_model.bin", repo_type="model")
116
+ state_dict = torch.load(unet_ckpt_path, map_location='cpu')
117
+ pipeline.unet.load_state_dict(state_dict, strict=True)
118
+
119
+ pipeline = pipeline.to(device)
120
+
121
+ # load reconstruction model
122
+ print('Loading reconstruction model ...')
123
+ model_ckpt_path = hf_hub_download(repo_id="TencentARC/InstantMesh", filename="instant_mesh_large.ckpt", repo_type="model")
124
+ model = instantiate_from_config(model_config)
125
+ state_dict = torch.load(model_ckpt_path, map_location='cpu')['state_dict']
126
+ state_dict = {k[14:]: v for k, v in state_dict.items() if k.startswith('lrm_generator.') and 'source_camera' not in k}
127
+ model.load_state_dict(state_dict, strict=True)
128
+
129
+ model = model.to(device)
130
+
131
+ print('Loading Finished!')
132
+
133
+
134
+ def check_input_image(input_image):
135
+ if input_image is None:
136
+ raise gr.Error("No image uploaded!")
137
+
138
+
139
+ def preprocess(input_image, do_remove_background):
140
+
141
+ rembg_session = rembg.new_session() if do_remove_background else None
142
+
143
+ if do_remove_background:
144
+ input_image = remove_background(input_image, rembg_session)
145
+ input_image = resize_foreground(input_image, 0.85)
146
+
147
+ return input_image
148
+
149
+
150
+ @spaces.GPU
151
+ def generate_mvs(input_image, sample_steps, sample_seed):
152
+
153
+ seed_everything(sample_seed)
154
+
155
+ # sampling
156
+ z123_image = pipeline(
157
+ input_image,
158
+ num_inference_steps=sample_steps
159
+ ).images[0]
160
+
161
+ show_image = np.asarray(z123_image, dtype=np.uint8)
162
+ show_image = torch.from_numpy(show_image) # (960, 640, 3)
163
+ show_image = rearrange(show_image, '(n h) (m w) c -> (n m) h w c', n=3, m=2)
164
+ show_image = rearrange(show_image, '(n m) h w c -> (n h) (m w) c', n=2, m=3)
165
+ show_image = Image.fromarray(show_image.numpy())
166
+
167
+ return z123_image, show_image
168
+
169
+
170
+ @spaces.GPU
171
+ def make3d(images):
172
+
173
+ global model
174
+ if IS_FLEXICUBES:
175
+ model.init_flexicubes_geometry(device, use_renderer=False)
176
+ model = model.eval()
177
+
178
+ images = np.asarray(images, dtype=np.float32) / 255.0
179
+ images = torch.from_numpy(images).permute(2, 0, 1).contiguous().float() # (3, 960, 640)
180
+ images = rearrange(images, 'c (n h) (m w) -> (n m) c h w', n=3, m=2) # (6, 3, 320, 320)
181
+
182
+ input_cameras = get_zero123plus_input_cameras(batch_size=1, radius=4.0).to(device)
183
+ render_cameras = get_render_cameras(batch_size=1, radius=2.5, is_flexicubes=IS_FLEXICUBES).to(device)
184
+
185
+ images = images.unsqueeze(0).to(device)
186
+ images = v2.functional.resize(images, (320, 320), interpolation=3, antialias=True).clamp(0, 1)
187
+
188
+ mesh_fpath = tempfile.NamedTemporaryFile(suffix=f".obj", delete=False).name
189
+ print(mesh_fpath)
190
+ mesh_basename = os.path.basename(mesh_fpath).split('.')[0]
191
+ mesh_dirname = os.path.dirname(mesh_fpath)
192
+ video_fpath = os.path.join(mesh_dirname, f"{mesh_basename}.mp4")
193
+ mesh_glb_fpath = os.path.join(mesh_dirname, f"{mesh_basename}.glb")
194
+
195
+ with torch.no_grad():
196
+ # get triplane
197
+ planes = model.forward_planes(images, input_cameras)
198
+
199
+ # # get video
200
+ # chunk_size = 20 if IS_FLEXICUBES else 1
201
+ # render_size = 384
202
+
203
+ # frames = []
204
+ # for i in tqdm(range(0, render_cameras.shape[1], chunk_size)):
205
+ # if IS_FLEXICUBES:
206
+ # frame = model.forward_geometry(
207
+ # planes,
208
+ # render_cameras[:, i:i+chunk_size],
209
+ # render_size=render_size,
210
+ # )['img']
211
+ # else:
212
+ # frame = model.synthesizer(
213
+ # planes,
214
+ # cameras=render_cameras[:, i:i+chunk_size],
215
+ # render_size=render_size,
216
+ # )['images_rgb']
217
+ # frames.append(frame)
218
+ # frames = torch.cat(frames, dim=1)
219
+
220
+ # images_to_video(
221
+ # frames[0],
222
+ # video_fpath,
223
+ # fps=30,
224
+ # )
225
+
226
+ # print(f"Video saved to {video_fpath}")
227
+
228
+ # get mesh
229
+ mesh_out = model.extract_mesh(
230
+ planes,
231
+ use_texture_map=False,
232
+ **infer_config,
233
+ )
234
+
235
+ vertices, faces, vertex_colors = mesh_out
236
+ vertices = vertices[:, [1, 2, 0]]
237
+
238
+ save_glb(vertices, faces, vertex_colors, mesh_glb_fpath)
239
+ save_obj(vertices, faces, vertex_colors, mesh_fpath)
240
+
241
+ print(f"Mesh saved to {mesh_fpath}")
242
+
243
+ return mesh_fpath, mesh_glb_fpath
244
+
245
+
246
+ _HEADER_ = '''
247
+ <h2><b>Welcome to 3DFusion!</b></h2>
248
+ <h2><a href='https://github.com/TencentARC/InstantMesh' target='_blank'><b>3D Mesh Generation from Single Images with 3DFusion</b></a></h2>
249
+
250
+ 3DFusion is a cutting-edge, efficient 3D mesh generation tool based on the powerful LRM/Instant3D architecture.
251
+
252
+ Code and Original Framework: <a href='https://github.com/TencentARC/InstantMesh' target='_blank'>InstantMesh GitHub</a>. Technical report: <a href='https://arxiv.org/abs/2404.07191' target='_blank'>ArXiv</a>.
253
+
254
+ ❗️**Important Notes:**
255
+ - This demo exports both `.obj` and `.glb` meshes, including vertex colors.
256
+ - The 3D mesh generation depends on the quality of generated multi-view images, so try different seed values (default: 42) for optimal results.
257
+ '''
258
+
259
+ _CITE_ = r"""
260
+ If you find **3DFusion** helpful, please give a ⭐ to the original <a href='https://github.com/TencentARC/InstantMesh' target='_blank'>InstantMesh repository</a>. We appreciate the work of the TencentARC team! [![GitHub Stars](https://img.shields.io/github/stars/TencentARC/InstantMesh?style=social)](https://github.com/TencentARC/InstantMesh)
261
+ ---
262
+ 📝 **Citation**
263
+
264
+ If you use this work for research or applications, cite it as follows:
265
+ ```bibtex
266
+ @article{xu2024instantmesh,
267
+ title={InstantMesh: Efficient 3D Mesh Generation from a Single Image with Sparse-view Large Reconstruction Models},
268
+ author={Xu, Jiale and Cheng, Weihao and Gao, Yiming and Wang, Xintao and Gao, Shenghua and Shan, Ying},
269
+ journal={arXiv preprint arXiv:2404.07191},
270
+ year={2024}
271
+ }
272
+ ```
273
+
274
+ 📋 **License**
275
+
276
+ Apache-2.0 LICENSE. Please refer to the [LICENSE file](https://huggingface.co/spaces/TencentARC/InstantMesh/blob/main/LICENSE) for details.
277
+
278
+ 📧 **Contact**
279
+
280
+ If you have any questions, feel free to open a discussion or contact us at <b>[email protected]</b>.
281
+ """
282
+
283
+
284
+ with gr.Blocks() as demo:
285
+ gr.Markdown(_HEADER_)
286
+ with gr.Row(variant="panel"):
287
+ with gr.Column():
288
+ with gr.Row():
289
+ input_image = gr.Image(
290
+ label="Input Image",
291
+ image_mode="RGBA",
292
+ sources="upload",
293
+ #width=256,
294
+ #height=256,
295
+ type="pil",
296
+ elem_id="content_image",
297
+ )
298
+ processed_image = gr.Image(
299
+ label="Processed Image",
300
+ image_mode="RGBA",
301
+ #width=256,
302
+ #height=256,
303
+ type="pil",
304
+ interactive=False
305
+ )
306
+ with gr.Row():
307
+ with gr.Group():
308
+ do_remove_background = gr.Checkbox(
309
+ label="Remove Background", value=True
310
+ )
311
+ sample_seed = gr.Number(value=42, label="Seed Value", precision=0)
312
+
313
+ sample_steps = gr.Slider(
314
+ label="Sample Steps",
315
+ minimum=30,
316
+ maximum=75,
317
+ value=75,
318
+ step=5
319
+ )
320
+
321
+ with gr.Row():
322
+ submit = gr.Button("Generate", elem_id="generate", variant="primary")
323
+
324
+ with gr.Row(variant="panel"):
325
+ gr.Examples(
326
+ examples=[
327
+ os.path.join("examples", img_name) for img_name in sorted(os.listdir("examples"))
328
+ ],
329
+ inputs=[input_image],
330
+ label="Examples",
331
+ cache_examples=False,
332
+ examples_per_page=16
333
+ )
334
+
335
+ with gr.Column():
336
+
337
+ with gr.Row():
338
+
339
+ with gr.Column():
340
+ mv_show_images = gr.Image(
341
+ label="Generated Multi-views",
342
+ type="pil",
343
+ width=379,
344
+ interactive=False
345
+ )
346
+
347
+ # with gr.Column():
348
+ # output_video = gr.Video(
349
+ # label="video", format="mp4",
350
+ # width=379,
351
+ # autoplay=True,
352
+ # interactive=False
353
+ # )
354
+
355
+ with gr.Row():
356
+ with gr.Tab("OBJ"):
357
+ output_model_obj = gr.Model3D(
358
+ label="Output Model (OBJ Format)",
359
+ interactive=False,
360
+ )
361
+ gr.Markdown("Note: Downloaded .obj model will be flipped. Export .glb instead or manually flip it before usage.")
362
+ with gr.Tab("GLB"):
363
+ output_model_glb = gr.Model3D(
364
+ label="Output Model (GLB Format)",
365
+ interactive=False,
366
+ )
367
+ gr.Markdown("Note: The model shown here has a darker appearance. Download to get correct results.")
368
+
369
+ with gr.Row():
370
+ gr.Markdown('''Try a different <b>seed value</b> if the result is unsatisfying (Default: 42).''')
371
+
372
+ gr.Markdown(_CITE_)
373
+
374
+ mv_images = gr.State()
375
+
376
+ submit.click(fn=check_input_image, inputs=[input_image]).success(
377
+ fn=preprocess,
378
+ inputs=[input_image, do_remove_background],
379
+ outputs=[processed_image],
380
+ ).success(
381
+ fn=generate_mvs,
382
+ inputs=[processed_image, sample_steps, sample_seed],
383
+ outputs=[mv_images, mv_show_images]
384
+
385
+ ).success(
386
+ fn=make3d,
387
+ inputs=[mv_images],
388
+ outputs=[output_model_obj, output_model_glb]
389
+ )
390
+
391
+ demo.launch()
configs/instant-mesh-base.yaml ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ model_config:
2
+ target: src.models.lrm_mesh.InstantMesh
3
+ params:
4
+ encoder_feat_dim: 768
5
+ encoder_freeze: false
6
+ encoder_model_name: facebook/dino-vitb16
7
+ transformer_dim: 1024
8
+ transformer_layers: 12
9
+ transformer_heads: 16
10
+ triplane_low_res: 32
11
+ triplane_high_res: 64
12
+ triplane_dim: 40
13
+ rendering_samples_per_ray: 96
14
+ grid_res: 128
15
+ grid_scale: 2.1
16
+
17
+
18
+ infer_config:
19
+ unet_path: ckpts/diffusion_pytorch_model.bin
20
+ model_path: ckpts/instant_mesh_base.ckpt
21
+ texture_resolution: 1024
22
+ render_resolution: 512
configs/instant-mesh-large.yaml ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ model_config:
2
+ target: src.models.lrm_mesh.InstantMesh
3
+ params:
4
+ encoder_feat_dim: 768
5
+ encoder_freeze: false
6
+ encoder_model_name: facebook/dino-vitb16
7
+ transformer_dim: 1024
8
+ transformer_layers: 16
9
+ transformer_heads: 16
10
+ triplane_low_res: 32
11
+ triplane_high_res: 64
12
+ triplane_dim: 80
13
+ rendering_samples_per_ray: 128
14
+ grid_res: 128
15
+ grid_scale: 2.1
16
+
17
+
18
+ infer_config:
19
+ unet_path: ckpts/diffusion_pytorch_model.bin
20
+ model_path: ckpts/instant_mesh_large.ckpt
21
+ texture_resolution: 1024
22
+ render_resolution: 512
configs/instant-nerf-base.yaml ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ model_config:
2
+ target: src.models.lrm.InstantNeRF
3
+ params:
4
+ encoder_feat_dim: 768
5
+ encoder_freeze: false
6
+ encoder_model_name: facebook/dino-vitb16
7
+ transformer_dim: 1024
8
+ transformer_layers: 12
9
+ transformer_heads: 16
10
+ triplane_low_res: 32
11
+ triplane_high_res: 64
12
+ triplane_dim: 40
13
+ rendering_samples_per_ray: 96
14
+
15
+
16
+ infer_config:
17
+ unet_path: ckpts/diffusion_pytorch_model.bin
18
+ model_path: ckpts/instant_nerf_base.ckpt
19
+ mesh_threshold: 10.0
20
+ mesh_resolution: 256
21
+ render_resolution: 384
configs/instant-nerf-large.yaml ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ model_config:
2
+ target: src.models.lrm.InstantNeRF
3
+ params:
4
+ encoder_feat_dim: 768
5
+ encoder_freeze: false
6
+ encoder_model_name: facebook/dino-vitb16
7
+ transformer_dim: 1024
8
+ transformer_layers: 16
9
+ transformer_heads: 16
10
+ triplane_low_res: 32
11
+ triplane_high_res: 64
12
+ triplane_dim: 80
13
+ rendering_samples_per_ray: 128
14
+
15
+
16
+ infer_config:
17
+ unet_path: ckpts/diffusion_pytorch_model.bin
18
+ model_path: ckpts/instant_nerf_large.ckpt
19
+ mesh_threshold: 10.0
20
+ mesh_resolution: 256
21
+ render_resolution: 384
examples/bird.jpg ADDED
examples/bubble_mart_blue.png ADDED
examples/cake.jpg ADDED
examples/cartoon_dinosaur.png ADDED
examples/cartoon_panda.png ADDED

Git LFS Details

  • SHA256: c82fea6ac66b782b2aa1c6bd133447b5f54f688c7eb44998c4b00f190d47b2b7
  • Pointer size: 132 Bytes
  • Size of remote file: 1.52 MB
examples/chair_armed.png ADDED
examples/chair_comfort.jpg ADDED
examples/chair_wood.jpg ADDED
examples/chest.jpg ADDED
examples/cute_horse.jpg ADDED
examples/cute_tiger.jpg ADDED
examples/earphone.jpg ADDED
examples/fox.jpg ADDED
examples/fruit.jpg ADDED
examples/fruit_elephant.jpg ADDED
examples/genshin_building.png ADDED
examples/genshin_teapot.png ADDED
examples/hatsune_miku.png ADDED
examples/house2.jpg ADDED
examples/mushroom_teapot.jpg ADDED
examples/pikachu.png ADDED
examples/plant.jpg ADDED
examples/robot.jpg ADDED
examples/sea_turtle.png ADDED
examples/skating_shoe.jpg ADDED
examples/sorting_board.png ADDED
examples/sword.png ADDED
examples/toy_car.jpg ADDED
examples/watermelon.png ADDED
examples/whitedog.png ADDED
examples/x_teapot.jpg ADDED
examples/x_toyduck.jpg ADDED
requirements.txt ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ torch==2.1.0
2
+ torchvision==0.16.0
3
+ torchaudio==2.1.0
4
+ pytorch-lightning==2.1.2
5
+ einops
6
+ omegaconf
7
+ deepspeed
8
+ torchmetrics
9
+ webdataset
10
+ accelerate
11
+ tensorboard
12
+ PyMCubes
13
+ trimesh
14
+ rembg
15
+ transformers==4.34.1
16
+ diffusers==0.19.3
17
+ bitsandbytes
18
+ imageio[ffmpeg]
19
+ xatlas
20
+ plyfile
21
+ xformers==0.0.22.post7
22
+ git+https://github.com/NVlabs/nvdiffrast/
23
+ huggingface-hub
src/__init__.py ADDED
File without changes
src/data/__init__.py ADDED
File without changes
src/data/objaverse.py ADDED
@@ -0,0 +1,329 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os, sys
2
+ import math
3
+ import json
4
+ import importlib
5
+ from pathlib import Path
6
+
7
+ import cv2
8
+ import random
9
+ import numpy as np
10
+ from PIL import Image
11
+ import webdataset as wds
12
+ import pytorch_lightning as pl
13
+
14
+ import torch
15
+ import torch.nn.functional as F
16
+ from torch.utils.data import Dataset
17
+ from torch.utils.data import DataLoader
18
+ from torch.utils.data.distributed import DistributedSampler
19
+ from torchvision import transforms
20
+
21
+ from src.utils.train_util import instantiate_from_config
22
+ from src.utils.camera_util import (
23
+ FOV_to_intrinsics,
24
+ center_looking_at_camera_pose,
25
+ get_surrounding_views,
26
+ )
27
+
28
+
29
+ class DataModuleFromConfig(pl.LightningDataModule):
30
+ def __init__(
31
+ self,
32
+ batch_size=8,
33
+ num_workers=4,
34
+ train=None,
35
+ validation=None,
36
+ test=None,
37
+ **kwargs,
38
+ ):
39
+ super().__init__()
40
+
41
+ self.batch_size = batch_size
42
+ self.num_workers = num_workers
43
+
44
+ self.dataset_configs = dict()
45
+ if train is not None:
46
+ self.dataset_configs['train'] = train
47
+ if validation is not None:
48
+ self.dataset_configs['validation'] = validation
49
+ if test is not None:
50
+ self.dataset_configs['test'] = test
51
+
52
+ def setup(self, stage):
53
+
54
+ if stage in ['fit']:
55
+ self.datasets = dict((k, instantiate_from_config(self.dataset_configs[k])) for k in self.dataset_configs)
56
+ else:
57
+ raise NotImplementedError
58
+
59
+ def train_dataloader(self):
60
+
61
+ sampler = DistributedSampler(self.datasets['train'])
62
+ return wds.WebLoader(self.datasets['train'], batch_size=self.batch_size, num_workers=self.num_workers, shuffle=False, sampler=sampler)
63
+
64
+ def val_dataloader(self):
65
+
66
+ sampler = DistributedSampler(self.datasets['validation'])
67
+ return wds.WebLoader(self.datasets['validation'], batch_size=1, num_workers=self.num_workers, shuffle=False, sampler=sampler)
68
+
69
+ def test_dataloader(self):
70
+
71
+ return wds.WebLoader(self.datasets['test'], batch_size=self.batch_size, num_workers=self.num_workers, shuffle=False)
72
+
73
+
74
+ class ObjaverseData(Dataset):
75
+ def __init__(self,
76
+ root_dir='objaverse/',
77
+ meta_fname='valid_paths.json',
78
+ input_image_dir='rendering_random_32views',
79
+ target_image_dir='rendering_random_32views',
80
+ input_view_num=6,
81
+ target_view_num=2,
82
+ total_view_n=32,
83
+ fov=50,
84
+ camera_rotation=True,
85
+ validation=False,
86
+ ):
87
+ self.root_dir = Path(root_dir)
88
+ self.input_image_dir = input_image_dir
89
+ self.target_image_dir = target_image_dir
90
+
91
+ self.input_view_num = input_view_num
92
+ self.target_view_num = target_view_num
93
+ self.total_view_n = total_view_n
94
+ self.fov = fov
95
+ self.camera_rotation = camera_rotation
96
+
97
+ with open(os.path.join(root_dir, meta_fname)) as f:
98
+ filtered_dict = json.load(f)
99
+ paths = filtered_dict['good_objs']
100
+ self.paths = paths
101
+
102
+ self.depth_scale = 4.0
103
+
104
+ total_objects = len(self.paths)
105
+ print('============= length of dataset %d =============' % len(self.paths))
106
+
107
+ def __len__(self):
108
+ return len(self.paths)
109
+
110
+ def load_im(self, path, color):
111
+ '''
112
+ replace background pixel with random color in rendering
113
+ '''
114
+ pil_img = Image.open(path)
115
+
116
+ image = np.asarray(pil_img, dtype=np.float32) / 255.
117
+ alpha = image[:, :, 3:]
118
+ image = image[:, :, :3] * alpha + color * (1 - alpha)
119
+
120
+ image = torch.from_numpy(image).permute(2, 0, 1).contiguous().float()
121
+ alpha = torch.from_numpy(alpha).permute(2, 0, 1).contiguous().float()
122
+ return image, alpha
123
+
124
+ def __getitem__(self, index):
125
+ # load data
126
+ while True:
127
+ input_image_path = os.path.join(self.root_dir, self.input_image_dir, self.paths[index])
128
+ target_image_path = os.path.join(self.root_dir, self.target_image_dir, self.paths[index])
129
+
130
+ indices = np.random.choice(range(self.total_view_n), self.input_view_num + self.target_view_num, replace=False)
131
+ input_indices = indices[:self.input_view_num]
132
+ target_indices = indices[self.input_view_num:]
133
+
134
+ '''background color, default: white'''
135
+ bg_white = [1., 1., 1.]
136
+ bg_black = [0., 0., 0.]
137
+
138
+ image_list = []
139
+ alpha_list = []
140
+ depth_list = []
141
+ normal_list = []
142
+ pose_list = []
143
+
144
+ try:
145
+ input_cameras = np.load(os.path.join(input_image_path, 'cameras.npz'))['cam_poses']
146
+ for idx in input_indices:
147
+ image, alpha = self.load_im(os.path.join(input_image_path, '%03d.png' % idx), bg_white)
148
+ normal, _ = self.load_im(os.path.join(input_image_path, '%03d_normal.png' % idx), bg_black)
149
+ depth = cv2.imread(os.path.join(input_image_path, '%03d_depth.png' % idx), cv2.IMREAD_UNCHANGED) / 255.0 * self.depth_scale
150
+ depth = torch.from_numpy(depth).unsqueeze(0)
151
+ pose = input_cameras[idx]
152
+ pose = np.concatenate([pose, np.array([[0, 0, 0, 1]])], axis=0)
153
+
154
+ image_list.append(image)
155
+ alpha_list.append(alpha)
156
+ depth_list.append(depth)
157
+ normal_list.append(normal)
158
+ pose_list.append(pose)
159
+
160
+ target_cameras = np.load(os.path.join(target_image_path, 'cameras.npz'))['cam_poses']
161
+ for idx in target_indices:
162
+ image, alpha = self.load_im(os.path.join(target_image_path, '%03d.png' % idx), bg_white)
163
+ normal, _ = self.load_im(os.path.join(target_image_path, '%03d_normal.png' % idx), bg_black)
164
+ depth = cv2.imread(os.path.join(target_image_path, '%03d_depth.png' % idx), cv2.IMREAD_UNCHANGED) / 255.0 * self.depth_scale
165
+ depth = torch.from_numpy(depth).unsqueeze(0)
166
+ pose = target_cameras[idx]
167
+ pose = np.concatenate([pose, np.array([[0, 0, 0, 1]])], axis=0)
168
+
169
+ image_list.append(image)
170
+ alpha_list.append(alpha)
171
+ depth_list.append(depth)
172
+ normal_list.append(normal)
173
+ pose_list.append(pose)
174
+
175
+ except Exception as e:
176
+ print(e)
177
+ index = np.random.randint(0, len(self.paths))
178
+ continue
179
+
180
+ break
181
+
182
+ images = torch.stack(image_list, dim=0).float() # (6+V, 3, H, W)
183
+ alphas = torch.stack(alpha_list, dim=0).float() # (6+V, 1, H, W)
184
+ depths = torch.stack(depth_list, dim=0).float() # (6+V, 1, H, W)
185
+ normals = torch.stack(normal_list, dim=0).float() # (6+V, 3, H, W)
186
+ w2cs = torch.from_numpy(np.stack(pose_list, axis=0)).float() # (6+V, 4, 4)
187
+ c2ws = torch.linalg.inv(w2cs).float()
188
+
189
+ normals = normals * 2.0 - 1.0
190
+ normals = F.normalize(normals, dim=1)
191
+ normals = (normals + 1.0) / 2.0
192
+ normals = torch.lerp(torch.zeros_like(normals), normals, alphas)
193
+
194
+ # random rotation along z axis
195
+ if self.camera_rotation:
196
+ degree = np.random.uniform(0, math.pi * 2)
197
+ rot = torch.tensor([
198
+ [np.cos(degree), -np.sin(degree), 0, 0],
199
+ [np.sin(degree), np.cos(degree), 0, 0],
200
+ [0, 0, 1, 0],
201
+ [0, 0, 0, 1],
202
+ ]).unsqueeze(0).float()
203
+ c2ws = torch.matmul(rot, c2ws)
204
+
205
+ # rotate normals
206
+ N, _, H, W = normals.shape
207
+ normals = normals * 2.0 - 1.0
208
+ normals = torch.matmul(rot[:, :3, :3], normals.view(N, 3, -1)).view(N, 3, H, W)
209
+ normals = F.normalize(normals, dim=1)
210
+ normals = (normals + 1.0) / 2.0
211
+ normals = torch.lerp(torch.zeros_like(normals), normals, alphas)
212
+
213
+ # random scaling
214
+ if np.random.rand() < 0.5:
215
+ scale = np.random.uniform(0.8, 1.0)
216
+ c2ws[:, :3, 3] *= scale
217
+ depths *= scale
218
+
219
+ # instrinsics of perspective cameras
220
+ K = FOV_to_intrinsics(self.fov)
221
+ Ks = K.unsqueeze(0).repeat(self.input_view_num + self.target_view_num, 1, 1).float()
222
+
223
+ data = {
224
+ 'input_images': images[:self.input_view_num], # (6, 3, H, W)
225
+ 'input_alphas': alphas[:self.input_view_num], # (6, 1, H, W)
226
+ 'input_depths': depths[:self.input_view_num], # (6, 1, H, W)
227
+ 'input_normals': normals[:self.input_view_num], # (6, 3, H, W)
228
+ 'input_c2ws': c2ws_input[:self.input_view_num], # (6, 4, 4)
229
+ 'input_Ks': Ks[:self.input_view_num], # (6, 3, 3)
230
+
231
+ # lrm generator input and supervision
232
+ 'target_images': images[self.input_view_num:], # (V, 3, H, W)
233
+ 'target_alphas': alphas[self.input_view_num:], # (V, 1, H, W)
234
+ 'target_depths': depths[self.input_view_num:], # (V, 1, H, W)
235
+ 'target_normals': normals[self.input_view_num:], # (V, 3, H, W)
236
+ 'target_c2ws': c2ws[self.input_view_num:], # (V, 4, 4)
237
+ 'target_Ks': Ks[self.input_view_num:], # (V, 3, 3)
238
+
239
+ 'depth_available': 1,
240
+ }
241
+ return data
242
+
243
+
244
+ class ValidationData(Dataset):
245
+ def __init__(self,
246
+ root_dir='objaverse/',
247
+ input_view_num=6,
248
+ input_image_size=256,
249
+ fov=50,
250
+ ):
251
+ self.root_dir = Path(root_dir)
252
+ self.input_view_num = input_view_num
253
+ self.input_image_size = input_image_size
254
+ self.fov = fov
255
+
256
+ self.paths = sorted(os.listdir(self.root_dir))
257
+ print('============= length of dataset %d =============' % len(self.paths))
258
+
259
+ cam_distance = 2.5
260
+ azimuths = np.array([30, 90, 150, 210, 270, 330])
261
+ elevations = np.array([30, -20, 30, -20, 30, -20])
262
+ azimuths = np.deg2rad(azimuths)
263
+ elevations = np.deg2rad(elevations)
264
+
265
+ x = cam_distance * np.cos(elevations) * np.cos(azimuths)
266
+ y = cam_distance * np.cos(elevations) * np.sin(azimuths)
267
+ z = cam_distance * np.sin(elevations)
268
+
269
+ cam_locations = np.stack([x, y, z], axis=-1)
270
+ cam_locations = torch.from_numpy(cam_locations).float()
271
+ c2ws = center_looking_at_camera_pose(cam_locations)
272
+ self.c2ws = c2ws.float()
273
+ self.Ks = FOV_to_intrinsics(self.fov).unsqueeze(0).repeat(6, 1, 1).float()
274
+
275
+ render_c2ws = get_surrounding_views(M=8, radius=cam_distance)
276
+ render_Ks = FOV_to_intrinsics(self.fov).unsqueeze(0).repeat(render_c2ws.shape[0], 1, 1)
277
+ self.render_c2ws = render_c2ws.float()
278
+ self.render_Ks = render_Ks.float()
279
+
280
+ def __len__(self):
281
+ return len(self.paths)
282
+
283
+ def load_im(self, path, color):
284
+ '''
285
+ replace background pixel with random color in rendering
286
+ '''
287
+ pil_img = Image.open(path)
288
+ pil_img = pil_img.resize((self.input_image_size, self.input_image_size), resample=Image.BICUBIC)
289
+
290
+ image = np.asarray(pil_img, dtype=np.float32) / 255.
291
+ if image.shape[-1] == 4:
292
+ alpha = image[:, :, 3:]
293
+ image = image[:, :, :3] * alpha + color * (1 - alpha)
294
+ else:
295
+ alpha = np.ones_like(image[:, :, :1])
296
+
297
+ image = torch.from_numpy(image).permute(2, 0, 1).contiguous().float()
298
+ alpha = torch.from_numpy(alpha).permute(2, 0, 1).contiguous().float()
299
+ return image, alpha
300
+
301
+ def __getitem__(self, index):
302
+ # load data
303
+ input_image_path = os.path.join(self.root_dir, self.paths[index])
304
+
305
+ '''background color, default: white'''
306
+ # color = np.random.uniform(0.48, 0.52)
307
+ bkg_color = [1.0, 1.0, 1.0]
308
+
309
+ image_list = []
310
+ alpha_list = []
311
+
312
+ for idx in range(self.input_view_num):
313
+ image, alpha = self.load_im(os.path.join(input_image_path, f'{idx:03d}.png'), bkg_color)
314
+ image_list.append(image)
315
+ alpha_list.append(alpha)
316
+
317
+ images = torch.stack(image_list, dim=0).float() # (6+V, 3, H, W)
318
+ alphas = torch.stack(alpha_list, dim=0).float() # (6+V, 1, H, W)
319
+
320
+ data = {
321
+ 'input_images': images, # (6, 3, H, W)
322
+ 'input_alphas': alphas, # (6, 1, H, W)
323
+ 'input_c2ws': self.c2ws, # (6, 4, 4)
324
+ 'input_Ks': self.Ks, # (6, 3, 3)
325
+
326
+ 'render_c2ws': self.render_c2ws,
327
+ 'render_Ks': self.render_Ks,
328
+ }
329
+ return data
src/model.py ADDED
@@ -0,0 +1,310 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import numpy as np
3
+ import torch
4
+ import torch.nn.functional as F
5
+ from torchvision.transforms import v2
6
+ from torchvision.utils import make_grid, save_image
7
+ from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity
8
+ import pytorch_lightning as pl
9
+ from einops import rearrange, repeat
10
+
11
+ from src.utils.train_util import instantiate_from_config
12
+
13
+
14
+ class MVRecon(pl.LightningModule):
15
+ def __init__(
16
+ self,
17
+ lrm_generator_config,
18
+ lrm_path=None,
19
+ input_size=256,
20
+ render_size=192,
21
+ ):
22
+ super(MVRecon, self).__init__()
23
+
24
+ self.input_size = input_size
25
+ self.render_size = render_size
26
+
27
+ # init modules
28
+ self.lrm_generator = instantiate_from_config(lrm_generator_config)
29
+ if lrm_path is not None:
30
+ lrm_ckpt = torch.load(lrm_path)
31
+ self.lrm_generator.load_state_dict(lrm_ckpt['weights'], strict=False)
32
+
33
+ self.lpips = LearnedPerceptualImagePatchSimilarity(net_type='vgg')
34
+
35
+ self.validation_step_outputs = []
36
+
37
+ def on_fit_start(self):
38
+ if self.global_rank == 0:
39
+ os.makedirs(os.path.join(self.logdir, 'images'), exist_ok=True)
40
+ os.makedirs(os.path.join(self.logdir, 'images_val'), exist_ok=True)
41
+
42
+ def prepare_batch_data(self, batch):
43
+ lrm_generator_input = {}
44
+ render_gt = {} # for supervision
45
+
46
+ # input images
47
+ images = batch['input_images']
48
+ images = v2.functional.resize(
49
+ images, self.input_size, interpolation=3, antialias=True).clamp(0, 1)
50
+
51
+ lrm_generator_input['images'] = images.to(self.device)
52
+
53
+ # input cameras and render cameras
54
+ input_c2ws = batch['input_c2ws'].flatten(-2)
55
+ input_Ks = batch['input_Ks'].flatten(-2)
56
+ target_c2ws = batch['target_c2ws'].flatten(-2)
57
+ target_Ks = batch['target_Ks'].flatten(-2)
58
+ render_cameras_input = torch.cat([input_c2ws, input_Ks], dim=-1)
59
+ render_cameras_target = torch.cat([target_c2ws, target_Ks], dim=-1)
60
+ render_cameras = torch.cat([render_cameras_input, render_cameras_target], dim=1)
61
+
62
+ input_extrinsics = input_c2ws[:, :, :12]
63
+ input_intrinsics = torch.stack([
64
+ input_Ks[:, :, 0], input_Ks[:, :, 4],
65
+ input_Ks[:, :, 2], input_Ks[:, :, 5],
66
+ ], dim=-1)
67
+ cameras = torch.cat([input_extrinsics, input_intrinsics], dim=-1)
68
+
69
+ # add noise to input cameras
70
+ cameras = cameras + torch.rand_like(cameras) * 0.04 - 0.02
71
+
72
+ lrm_generator_input['cameras'] = cameras.to(self.device)
73
+ lrm_generator_input['render_cameras'] = render_cameras.to(self.device)
74
+
75
+ # target images
76
+ target_images = torch.cat([batch['input_images'], batch['target_images']], dim=1)
77
+ target_depths = torch.cat([batch['input_depths'], batch['target_depths']], dim=1)
78
+ target_alphas = torch.cat([batch['input_alphas'], batch['target_alphas']], dim=1)
79
+
80
+ # random crop
81
+ render_size = np.random.randint(self.render_size, 513)
82
+ target_images = v2.functional.resize(
83
+ target_images, render_size, interpolation=3, antialias=True).clamp(0, 1)
84
+ target_depths = v2.functional.resize(
85
+ target_depths, render_size, interpolation=0, antialias=True)
86
+ target_alphas = v2.functional.resize(
87
+ target_alphas, render_size, interpolation=0, antialias=True)
88
+
89
+ crop_params = v2.RandomCrop.get_params(
90
+ target_images, output_size=(self.render_size, self.render_size))
91
+ target_images = v2.functional.crop(target_images, *crop_params)
92
+ target_depths = v2.functional.crop(target_depths, *crop_params)[:, :, 0:1]
93
+ target_alphas = v2.functional.crop(target_alphas, *crop_params)[:, :, 0:1]
94
+
95
+ lrm_generator_input['render_size'] = render_size
96
+ lrm_generator_input['crop_params'] = crop_params
97
+
98
+ render_gt['target_images'] = target_images.to(self.device)
99
+ render_gt['target_depths'] = target_depths.to(self.device)
100
+ render_gt['target_alphas'] = target_alphas.to(self.device)
101
+
102
+ return lrm_generator_input, render_gt
103
+
104
+ def prepare_validation_batch_data(self, batch):
105
+ lrm_generator_input = {}
106
+
107
+ # input images
108
+ images = batch['input_images']
109
+ images = v2.functional.resize(
110
+ images, self.input_size, interpolation=3, antialias=True).clamp(0, 1)
111
+
112
+ lrm_generator_input['images'] = images.to(self.device)
113
+
114
+ input_c2ws = batch['input_c2ws'].flatten(-2)
115
+ input_Ks = batch['input_Ks'].flatten(-2)
116
+
117
+ input_extrinsics = input_c2ws[:, :, :12]
118
+ input_intrinsics = torch.stack([
119
+ input_Ks[:, :, 0], input_Ks[:, :, 4],
120
+ input_Ks[:, :, 2], input_Ks[:, :, 5],
121
+ ], dim=-1)
122
+ cameras = torch.cat([input_extrinsics, input_intrinsics], dim=-1)
123
+
124
+ lrm_generator_input['cameras'] = cameras.to(self.device)
125
+
126
+ render_c2ws = batch['render_c2ws'].flatten(-2)
127
+ render_Ks = batch['render_Ks'].flatten(-2)
128
+ render_cameras = torch.cat([render_c2ws, render_Ks], dim=-1)
129
+
130
+ lrm_generator_input['render_cameras'] = render_cameras.to(self.device)
131
+ lrm_generator_input['render_size'] = 384
132
+ lrm_generator_input['crop_params'] = None
133
+
134
+ return lrm_generator_input
135
+
136
+ def forward_lrm_generator(
137
+ self,
138
+ images,
139
+ cameras,
140
+ render_cameras,
141
+ render_size=192,
142
+ crop_params=None,
143
+ chunk_size=1,
144
+ ):
145
+ planes = torch.utils.checkpoint.checkpoint(
146
+ self.lrm_generator.forward_planes,
147
+ images,
148
+ cameras,
149
+ use_reentrant=False,
150
+ )
151
+ frames = []
152
+ for i in range(0, render_cameras.shape[1], chunk_size):
153
+ frames.append(
154
+ torch.utils.checkpoint.checkpoint(
155
+ self.lrm_generator.synthesizer,
156
+ planes,
157
+ cameras=render_cameras[:, i:i+chunk_size],
158
+ render_size=render_size,
159
+ crop_params=crop_params,
160
+ use_reentrant=False
161
+ )
162
+ )
163
+ frames = {
164
+ k: torch.cat([r[k] for r in frames], dim=1)
165
+ for k in frames[0].keys()
166
+ }
167
+ return frames
168
+
169
+ def forward(self, lrm_generator_input):
170
+ images = lrm_generator_input['images']
171
+ cameras = lrm_generator_input['cameras']
172
+ render_cameras = lrm_generator_input['render_cameras']
173
+ render_size = lrm_generator_input['render_size']
174
+ crop_params = lrm_generator_input['crop_params']
175
+
176
+ out = self.forward_lrm_generator(
177
+ images,
178
+ cameras,
179
+ render_cameras,
180
+ render_size=render_size,
181
+ crop_params=crop_params,
182
+ chunk_size=1,
183
+ )
184
+ render_images = torch.clamp(out['images_rgb'], 0.0, 1.0)
185
+ render_depths = out['images_depth']
186
+ render_alphas = torch.clamp(out['images_weight'], 0.0, 1.0)
187
+
188
+ out = {
189
+ 'render_images': render_images,
190
+ 'render_depths': render_depths,
191
+ 'render_alphas': render_alphas,
192
+ }
193
+ return out
194
+
195
+ def training_step(self, batch, batch_idx):
196
+ lrm_generator_input, render_gt = self.prepare_batch_data(batch)
197
+
198
+ render_out = self.forward(lrm_generator_input)
199
+
200
+ loss, loss_dict = self.compute_loss(render_out, render_gt)
201
+
202
+ self.log_dict(loss_dict, prog_bar=True, logger=True, on_step=True, on_epoch=True)
203
+
204
+ if self.global_step % 1000 == 0 and self.global_rank == 0:
205
+ B, N, C, H, W = render_gt['target_images'].shape
206
+ N_in = lrm_generator_input['images'].shape[1]
207
+
208
+ input_images = v2.functional.resize(
209
+ lrm_generator_input['images'], (H, W), interpolation=3, antialias=True).clamp(0, 1)
210
+ input_images = torch.cat(
211
+ [input_images, torch.ones(B, N-N_in, C, H, W).to(input_images)], dim=1)
212
+
213
+ input_images = rearrange(
214
+ input_images, 'b n c h w -> b c h (n w)')
215
+ target_images = rearrange(
216
+ render_gt['target_images'], 'b n c h w -> b c h (n w)')
217
+ render_images = rearrange(
218
+ render_out['render_images'], 'b n c h w -> b c h (n w)')
219
+ target_alphas = rearrange(
220
+ repeat(render_gt['target_alphas'], 'b n 1 h w -> b n 3 h w'), 'b n c h w -> b c h (n w)')
221
+ render_alphas = rearrange(
222
+ repeat(render_out['render_alphas'], 'b n 1 h w -> b n 3 h w'), 'b n c h w -> b c h (n w)')
223
+ target_depths = rearrange(
224
+ repeat(render_gt['target_depths'], 'b n 1 h w -> b n 3 h w'), 'b n c h w -> b c h (n w)')
225
+ render_depths = rearrange(
226
+ repeat(render_out['render_depths'], 'b n 1 h w -> b n 3 h w'), 'b n c h w -> b c h (n w)')
227
+ MAX_DEPTH = torch.max(target_depths)
228
+ target_depths = target_depths / MAX_DEPTH * target_alphas
229
+ render_depths = render_depths / MAX_DEPTH
230
+
231
+ grid = torch.cat([
232
+ input_images,
233
+ target_images, render_images,
234
+ target_alphas, render_alphas,
235
+ target_depths, render_depths,
236
+ ], dim=-2)
237
+ grid = make_grid(grid, nrow=target_images.shape[0], normalize=True, value_range=(0, 1))
238
+
239
+ save_image(grid, os.path.join(self.logdir, 'images', f'train_{self.global_step:07d}.png'))
240
+
241
+ return loss
242
+
243
+ def compute_loss(self, render_out, render_gt):
244
+ # NOTE: the rgb value range of OpenLRM is [0, 1]
245
+ render_images = render_out['render_images']
246
+ target_images = render_gt['target_images'].to(render_images)
247
+ render_images = rearrange(render_images, 'b n ... -> (b n) ...') * 2.0 - 1.0
248
+ target_images = rearrange(target_images, 'b n ... -> (b n) ...') * 2.0 - 1.0
249
+
250
+ loss_mse = F.mse_loss(render_images, target_images)
251
+ loss_lpips = 2.0 * self.lpips(render_images, target_images)
252
+
253
+ render_alphas = render_out['render_alphas']
254
+ target_alphas = render_gt['target_alphas']
255
+ loss_mask = F.mse_loss(render_alphas, target_alphas)
256
+
257
+ loss = loss_mse + loss_lpips + loss_mask
258
+
259
+ prefix = 'train'
260
+ loss_dict = {}
261
+ loss_dict.update({f'{prefix}/loss_mse': loss_mse})
262
+ loss_dict.update({f'{prefix}/loss_lpips': loss_lpips})
263
+ loss_dict.update({f'{prefix}/loss_mask': loss_mask})
264
+ loss_dict.update({f'{prefix}/loss': loss})
265
+
266
+ return loss, loss_dict
267
+
268
+ @torch.no_grad()
269
+ def validation_step(self, batch, batch_idx):
270
+ lrm_generator_input = self.prepare_validation_batch_data(batch)
271
+
272
+ render_out = self.forward(lrm_generator_input)
273
+ render_images = render_out['render_images']
274
+ render_images = rearrange(render_images, 'b n c h w -> b c h (n w)')
275
+
276
+ self.validation_step_outputs.append(render_images)
277
+
278
+ def on_validation_epoch_end(self):
279
+ images = torch.cat(self.validation_step_outputs, dim=-1)
280
+
281
+ all_images = self.all_gather(images)
282
+ all_images = rearrange(all_images, 'r b c h w -> (r b) c h w')
283
+
284
+ if self.global_rank == 0:
285
+ image_path = os.path.join(self.logdir, 'images_val', f'val_{self.global_step:07d}.png')
286
+
287
+ grid = make_grid(all_images, nrow=1, normalize=True, value_range=(0, 1))
288
+ save_image(grid, image_path)
289
+ print(f"Saved image to {image_path}")
290
+
291
+ self.validation_step_outputs.clear()
292
+
293
+ def configure_optimizers(self):
294
+ lr = self.learning_rate
295
+
296
+ params = []
297
+
298
+ lrm_params_fast, lrm_params_slow = [], []
299
+ for n, p in self.lrm_generator.named_parameters():
300
+ if 'adaLN_modulation' in n or 'camera_embedder' in n:
301
+ lrm_params_fast.append(p)
302
+ else:
303
+ lrm_params_slow.append(p)
304
+ params.append({"params": lrm_params_fast, "lr": lr, "weight_decay": 0.01 })
305
+ params.append({"params": lrm_params_slow, "lr": lr / 10.0, "weight_decay": 0.01 })
306
+
307
+ optimizer = torch.optim.AdamW(params, lr=lr, betas=(0.90, 0.95))
308
+ scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, 3000, eta_min=lr/4)
309
+
310
+ return {'optimizer': optimizer, 'lr_scheduler': scheduler}
src/model_mesh.py ADDED
@@ -0,0 +1,325 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import numpy as np
3
+ import torch
4
+ import torch.nn.functional as F
5
+ from torchvision.transforms import v2
6
+ from torchvision.utils import make_grid, save_image
7
+ from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity
8
+ import pytorch_lightning as pl
9
+ from einops import rearrange, repeat
10
+
11
+ from src.utils.train_util import instantiate_from_config
12
+
13
+
14
+ # Regulrarization loss for FlexiCubes
15
+ def sdf_reg_loss_batch(sdf, all_edges):
16
+ sdf_f1x6x2 = sdf[:, all_edges.reshape(-1)].reshape(sdf.shape[0], -1, 2)
17
+ mask = torch.sign(sdf_f1x6x2[..., 0]) != torch.sign(sdf_f1x6x2[..., 1])
18
+ sdf_f1x6x2 = sdf_f1x6x2[mask]
19
+ sdf_diff = F.binary_cross_entropy_with_logits(
20
+ sdf_f1x6x2[..., 0], (sdf_f1x6x2[..., 1] > 0).float()) + \
21
+ F.binary_cross_entropy_with_logits(
22
+ sdf_f1x6x2[..., 1], (sdf_f1x6x2[..., 0] > 0).float())
23
+ return sdf_diff
24
+
25
+
26
+ class MVRecon(pl.LightningModule):
27
+ def __init__(
28
+ self,
29
+ lrm_generator_config,
30
+ input_size=256,
31
+ render_size=512,
32
+ init_ckpt=None,
33
+ ):
34
+ super(MVRecon, self).__init__()
35
+
36
+ self.input_size = input_size
37
+ self.render_size = render_size
38
+
39
+ # init modules
40
+ self.lrm_generator = instantiate_from_config(lrm_generator_config)
41
+
42
+ self.lpips = LearnedPerceptualImagePatchSimilarity(net_type='vgg')
43
+
44
+ # Load weights from pretrained MVRecon model, and use the mlp
45
+ # weights to initialize the weights of sdf and rgb mlps.
46
+ if init_ckpt is not None:
47
+ sd = torch.load(init_ckpt, map_location='cpu')['state_dict']
48
+ sd = {k: v for k, v in sd.items() if k.startswith('lrm_generator')}
49
+ sd_fc = {}
50
+ for k, v in sd.items():
51
+ if k.startswith('lrm_generator.synthesizer.decoder.net.'):
52
+ if k.startswith('lrm_generator.synthesizer.decoder.net.6.'): # last layer
53
+ # Here we assume the density filed's isosurface threshold is t,
54
+ # we reverse the sign of density filed to initialize SDF field.
55
+ # -(w*x + b - t) = (-w)*x + (t - b)
56
+ if 'weight' in k:
57
+ sd_fc[k.replace('net.', 'net_sdf.')] = -v[0:1]
58
+ else:
59
+ sd_fc[k.replace('net.', 'net_sdf.')] = 3.0 - v[0:1]
60
+ sd_fc[k.replace('net.', 'net_rgb.')] = v[1:4]
61
+ else:
62
+ sd_fc[k.replace('net.', 'net_sdf.')] = v
63
+ sd_fc[k.replace('net.', 'net_rgb.')] = v
64
+ else:
65
+ sd_fc[k] = v
66
+ sd_fc = {k.replace('lrm_generator.', ''): v for k, v in sd_fc.items()}
67
+ # missing `net_deformation` and `net_weight` parameters
68
+ self.lrm_generator.load_state_dict(sd_fc, strict=False)
69
+ print(f'Loaded weights from {init_ckpt}')
70
+
71
+ self.validation_step_outputs = []
72
+
73
+ def on_fit_start(self):
74
+ device = torch.device(f'cuda:{self.global_rank}')
75
+ self.lrm_generator.init_flexicubes_geometry(device)
76
+ if self.global_rank == 0:
77
+ os.makedirs(os.path.join(self.logdir, 'images'), exist_ok=True)
78
+ os.makedirs(os.path.join(self.logdir, 'images_val'), exist_ok=True)
79
+
80
+ def prepare_batch_data(self, batch):
81
+ lrm_generator_input = {}
82
+ render_gt = {}
83
+
84
+ # input images
85
+ images = batch['input_images']
86
+ images = v2.functional.resize(
87
+ images, self.input_size, interpolation=3, antialias=True).clamp(0, 1)
88
+
89
+ lrm_generator_input['images'] = images.to(self.device)
90
+
91
+ # input cameras and render cameras
92
+ input_c2ws = batch['input_c2ws']
93
+ input_Ks = batch['input_Ks']
94
+ target_c2ws = batch['target_c2ws']
95
+
96
+ render_c2ws = torch.cat([input_c2ws, target_c2ws], dim=1)
97
+ render_w2cs = torch.linalg.inv(render_c2ws)
98
+
99
+ input_extrinsics = input_c2ws.flatten(-2)
100
+ input_extrinsics = input_extrinsics[:, :, :12]
101
+ input_intrinsics = input_Ks.flatten(-2)
102
+ input_intrinsics = torch.stack([
103
+ input_intrinsics[:, :, 0], input_intrinsics[:, :, 4],
104
+ input_intrinsics[:, :, 2], input_intrinsics[:, :, 5],
105
+ ], dim=-1)
106
+ cameras = torch.cat([input_extrinsics, input_intrinsics], dim=-1)
107
+
108
+ # add noise to input_cameras
109
+ cameras = cameras + torch.rand_like(cameras) * 0.04 - 0.02
110
+
111
+ lrm_generator_input['cameras'] = cameras.to(self.device)
112
+ lrm_generator_input['render_cameras'] = render_w2cs.to(self.device)
113
+
114
+ # target images
115
+ target_images = torch.cat([batch['input_images'], batch['target_images']], dim=1)
116
+ target_depths = torch.cat([batch['input_depths'], batch['target_depths']], dim=1)
117
+ target_alphas = torch.cat([batch['input_alphas'], batch['target_alphas']], dim=1)
118
+ target_normals = torch.cat([batch['input_normals'], batch['target_normals']], dim=1)
119
+
120
+ render_size = self.render_size
121
+ target_images = v2.functional.resize(
122
+ target_images, render_size, interpolation=3, antialias=True).clamp(0, 1)
123
+ target_depths = v2.functional.resize(
124
+ target_depths, render_size, interpolation=0, antialias=True)
125
+ target_alphas = v2.functional.resize(
126
+ target_alphas, render_size, interpolation=0, antialias=True)
127
+ target_normals = v2.functional.resize(
128
+ target_normals, render_size, interpolation=3, antialias=True)
129
+
130
+ lrm_generator_input['render_size'] = render_size
131
+
132
+ render_gt['target_images'] = target_images.to(self.device)
133
+ render_gt['target_depths'] = target_depths.to(self.device)
134
+ render_gt['target_alphas'] = target_alphas.to(self.device)
135
+ render_gt['target_normals'] = target_normals.to(self.device)
136
+
137
+ return lrm_generator_input, render_gt
138
+
139
+ def prepare_validation_batch_data(self, batch):
140
+ lrm_generator_input = {}
141
+
142
+ # input images
143
+ images = batch['input_images']
144
+ images = v2.functional.resize(
145
+ images, self.input_size, interpolation=3, antialias=True).clamp(0, 1)
146
+
147
+ lrm_generator_input['images'] = images.to(self.device)
148
+
149
+ # input cameras
150
+ input_c2ws = batch['input_c2ws'].flatten(-2)
151
+ input_Ks = batch['input_Ks'].flatten(-2)
152
+
153
+ input_extrinsics = input_c2ws[:, :, :12]
154
+ input_intrinsics = torch.stack([
155
+ input_Ks[:, :, 0], input_Ks[:, :, 4],
156
+ input_Ks[:, :, 2], input_Ks[:, :, 5],
157
+ ], dim=-1)
158
+ cameras = torch.cat([input_extrinsics, input_intrinsics], dim=-1)
159
+
160
+ lrm_generator_input['cameras'] = cameras.to(self.device)
161
+
162
+ # render cameras
163
+ render_c2ws = batch['render_c2ws']
164
+ render_w2cs = torch.linalg.inv(render_c2ws)
165
+
166
+ lrm_generator_input['render_cameras'] = render_w2cs.to(self.device)
167
+ lrm_generator_input['render_size'] = 384
168
+
169
+ return lrm_generator_input
170
+
171
+ def forward_lrm_generator(self, images, cameras, render_cameras, render_size=512):
172
+ planes = torch.utils.checkpoint.checkpoint(
173
+ self.lrm_generator.forward_planes,
174
+ images,
175
+ cameras,
176
+ use_reentrant=False,
177
+ )
178
+ out = self.lrm_generator.forward_geometry(
179
+ planes,
180
+ render_cameras,
181
+ render_size,
182
+ )
183
+ return out
184
+
185
+ def forward(self, lrm_generator_input):
186
+ images = lrm_generator_input['images']
187
+ cameras = lrm_generator_input['cameras']
188
+ render_cameras = lrm_generator_input['render_cameras']
189
+ render_size = lrm_generator_input['render_size']
190
+
191
+ out = self.forward_lrm_generator(
192
+ images, cameras, render_cameras, render_size=render_size)
193
+
194
+ return out
195
+
196
+ def training_step(self, batch, batch_idx):
197
+ lrm_generator_input, render_gt = self.prepare_batch_data(batch)
198
+
199
+ render_out = self.forward(lrm_generator_input)
200
+
201
+ loss, loss_dict = self.compute_loss(render_out, render_gt)
202
+
203
+ self.log_dict(loss_dict, prog_bar=True, logger=True, on_step=True, on_epoch=True)
204
+
205
+ if self.global_step % 1000 == 0 and self.global_rank == 0:
206
+ B, N, C, H, W = render_gt['target_images'].shape
207
+ N_in = lrm_generator_input['images'].shape[1]
208
+
209
+ target_images = rearrange(
210
+ render_gt['target_images'], 'b n c h w -> b c h (n w)')
211
+ render_images = rearrange(
212
+ render_out['img'], 'b n c h w -> b c h (n w)')
213
+ target_alphas = rearrange(
214
+ repeat(render_gt['target_alphas'], 'b n 1 h w -> b n 3 h w'), 'b n c h w -> b c h (n w)')
215
+ render_alphas = rearrange(
216
+ repeat(render_out['mask'], 'b n 1 h w -> b n 3 h w'), 'b n c h w -> b c h (n w)')
217
+ target_depths = rearrange(
218
+ repeat(render_gt['target_depths'], 'b n 1 h w -> b n 3 h w'), 'b n c h w -> b c h (n w)')
219
+ render_depths = rearrange(
220
+ repeat(render_out['depth'], 'b n 1 h w -> b n 3 h w'), 'b n c h w -> b c h (n w)')
221
+ target_normals = rearrange(
222
+ render_gt['target_normals'], 'b n c h w -> b c h (n w)')
223
+ render_normals = rearrange(
224
+ render_out['normal'], 'b n c h w -> b c h (n w)')
225
+ MAX_DEPTH = torch.max(target_depths)
226
+ target_depths = target_depths / MAX_DEPTH * target_alphas
227
+ render_depths = render_depths / MAX_DEPTH
228
+
229
+ grid = torch.cat([
230
+ target_images, render_images,
231
+ target_alphas, render_alphas,
232
+ target_depths, render_depths,
233
+ target_normals, render_normals,
234
+ ], dim=-2)
235
+ grid = make_grid(grid, nrow=target_images.shape[0], normalize=True, value_range=(0, 1))
236
+
237
+ image_path = os.path.join(self.logdir, 'images', f'train_{self.global_step:07d}.png')
238
+ save_image(grid, image_path)
239
+ print(f"Saved image to {image_path}")
240
+
241
+ return loss
242
+
243
+ def compute_loss(self, render_out, render_gt):
244
+ # NOTE: the rgb value range of OpenLRM is [0, 1]
245
+ render_images = render_out['img']
246
+ target_images = render_gt['target_images'].to(render_images)
247
+ render_images = rearrange(render_images, 'b n ... -> (b n) ...') * 2.0 - 1.0
248
+ target_images = rearrange(target_images, 'b n ... -> (b n) ...') * 2.0 - 1.0
249
+ loss_mse = F.mse_loss(render_images, target_images)
250
+ loss_lpips = 2.0 * self.lpips(render_images, target_images)
251
+
252
+ render_alphas = render_out['mask']
253
+ target_alphas = render_gt['target_alphas']
254
+ loss_mask = F.mse_loss(render_alphas, target_alphas)
255
+
256
+ render_depths = render_out['depth']
257
+ target_depths = render_gt['target_depths']
258
+ loss_depth = 0.5 * F.l1_loss(render_depths[target_alphas>0], target_depths[target_alphas>0])
259
+
260
+ render_normals = render_out['normal'] * 2.0 - 1.0
261
+ target_normals = render_gt['target_normals'] * 2.0 - 1.0
262
+ similarity = (render_normals * target_normals).sum(dim=-3).abs()
263
+ normal_mask = target_alphas.squeeze(-3)
264
+ loss_normal = 1 - similarity[normal_mask>0].mean()
265
+ loss_normal = 0.2 * loss_normal
266
+
267
+ # flexicubes regularization loss
268
+ sdf = render_out['sdf']
269
+ sdf_reg_loss = render_out['sdf_reg_loss']
270
+ sdf_reg_loss_entropy = sdf_reg_loss_batch(sdf, self.lrm_generator.geometry.all_edges).mean() * 0.01
271
+ _, flexicubes_surface_reg, flexicubes_weights_reg = sdf_reg_loss
272
+ flexicubes_surface_reg = flexicubes_surface_reg.mean() * 0.5
273
+ flexicubes_weights_reg = flexicubes_weights_reg.mean() * 0.1
274
+
275
+ loss_reg = sdf_reg_loss_entropy + flexicubes_surface_reg + flexicubes_weights_reg
276
+
277
+ loss = loss_mse + loss_lpips + loss_mask + loss_normal + loss_reg
278
+
279
+ prefix = 'train'
280
+ loss_dict = {}
281
+ loss_dict.update({f'{prefix}/loss_mse': loss_mse})
282
+ loss_dict.update({f'{prefix}/loss_lpips': loss_lpips})
283
+ loss_dict.update({f'{prefix}/loss_mask': loss_mask})
284
+ loss_dict.update({f'{prefix}/loss_normal': loss_normal})
285
+ loss_dict.update({f'{prefix}/loss_depth': loss_depth})
286
+ loss_dict.update({f'{prefix}/loss_reg_sdf': sdf_reg_loss_entropy})
287
+ loss_dict.update({f'{prefix}/loss_reg_surface': flexicubes_surface_reg})
288
+ loss_dict.update({f'{prefix}/loss_reg_weights': flexicubes_weights_reg})
289
+ loss_dict.update({f'{prefix}/loss': loss})
290
+
291
+ return loss, loss_dict
292
+
293
+ @torch.no_grad()
294
+ def validation_step(self, batch, batch_idx):
295
+ lrm_generator_input = self.prepare_validation_batch_data(batch)
296
+
297
+ render_out = self.forward(lrm_generator_input)
298
+ render_images = render_out['img']
299
+ render_images = rearrange(render_images, 'b n c h w -> b c h (n w)')
300
+
301
+ self.validation_step_outputs.append(render_images)
302
+
303
+ def on_validation_epoch_end(self):
304
+ images = torch.cat(self.validation_step_outputs, dim=-1)
305
+
306
+ all_images = self.all_gather(images)
307
+ all_images = rearrange(all_images, 'r b c h w -> (r b) c h w')
308
+
309
+ if self.global_rank == 0:
310
+ image_path = os.path.join(self.logdir, 'images_val', f'val_{self.global_step:07d}.png')
311
+
312
+ grid = make_grid(all_images, nrow=1, normalize=True, value_range=(0, 1))
313
+ save_image(grid, image_path)
314
+ print(f"Saved image to {image_path}")
315
+
316
+ self.validation_step_outputs.clear()
317
+
318
+ def configure_optimizers(self):
319
+ lr = self.learning_rate
320
+
321
+ optimizer = torch.optim.AdamW(
322
+ self.lrm_generator.parameters(), lr=lr, betas=(0.90, 0.95), weight_decay=0.01)
323
+ scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, 100000, eta_min=0)
324
+
325
+ return {'optimizer': optimizer, 'lr_scheduler': scheduler}
src/models/__init__.py ADDED
File without changes
src/models/decoder/__init__.py ADDED
File without changes
src/models/decoder/transformer.py ADDED
@@ -0,0 +1,123 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2023, Zexin He
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # https://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+
16
+ import torch
17
+ import torch.nn as nn
18
+
19
+
20
+ class BasicTransformerBlock(nn.Module):
21
+ """
22
+ Transformer block that takes in a cross-attention condition and another modulation vector applied to sub-blocks.
23
+ """
24
+ # use attention from torch.nn.MultiHeadAttention
25
+ # Block contains a cross-attention layer, a self-attention layer, and a MLP
26
+ def __init__(
27
+ self,
28
+ inner_dim: int,
29
+ cond_dim: int,
30
+ num_heads: int,
31
+ eps: float,
32
+ attn_drop: float = 0.,
33
+ attn_bias: bool = False,
34
+ mlp_ratio: float = 4.,
35
+ mlp_drop: float = 0.,
36
+ ):
37
+ super().__init__()
38
+
39
+ self.norm1 = nn.LayerNorm(inner_dim)
40
+ self.cross_attn = nn.MultiheadAttention(
41
+ embed_dim=inner_dim, num_heads=num_heads, kdim=cond_dim, vdim=cond_dim,
42
+ dropout=attn_drop, bias=attn_bias, batch_first=True)
43
+ self.norm2 = nn.LayerNorm(inner_dim)
44
+ self.self_attn = nn.MultiheadAttention(
45
+ embed_dim=inner_dim, num_heads=num_heads,
46
+ dropout=attn_drop, bias=attn_bias, batch_first=True)
47
+ self.norm3 = nn.LayerNorm(inner_dim)
48
+ self.mlp = nn.Sequential(
49
+ nn.Linear(inner_dim, int(inner_dim * mlp_ratio)),
50
+ nn.GELU(),
51
+ nn.Dropout(mlp_drop),
52
+ nn.Linear(int(inner_dim * mlp_ratio), inner_dim),
53
+ nn.Dropout(mlp_drop),
54
+ )
55
+
56
+ def forward(self, x, cond):
57
+ # x: [N, L, D]
58
+ # cond: [N, L_cond, D_cond]
59
+ x = x + self.cross_attn(self.norm1(x), cond, cond)[0]
60
+ before_sa = self.norm2(x)
61
+ x = x + self.self_attn(before_sa, before_sa, before_sa)[0]
62
+ x = x + self.mlp(self.norm3(x))
63
+ return x
64
+
65
+
66
+ class TriplaneTransformer(nn.Module):
67
+ """
68
+ Transformer with condition that generates a triplane representation.
69
+
70
+ Reference:
71
+ Timm: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L486
72
+ """
73
+ def __init__(
74
+ self,
75
+ inner_dim: int,
76
+ image_feat_dim: int,
77
+ triplane_low_res: int,
78
+ triplane_high_res: int,
79
+ triplane_dim: int,
80
+ num_layers: int,
81
+ num_heads: int,
82
+ eps: float = 1e-6,
83
+ ):
84
+ super().__init__()
85
+
86
+ # attributes
87
+ self.triplane_low_res = triplane_low_res
88
+ self.triplane_high_res = triplane_high_res
89
+ self.triplane_dim = triplane_dim
90
+
91
+ # modules
92
+ # initialize pos_embed with 1/sqrt(dim) * N(0, 1)
93
+ self.pos_embed = nn.Parameter(torch.randn(1, 3*triplane_low_res**2, inner_dim) * (1. / inner_dim) ** 0.5)
94
+ self.layers = nn.ModuleList([
95
+ BasicTransformerBlock(
96
+ inner_dim=inner_dim, cond_dim=image_feat_dim, num_heads=num_heads, eps=eps)
97
+ for _ in range(num_layers)
98
+ ])
99
+ self.norm = nn.LayerNorm(inner_dim, eps=eps)
100
+ self.deconv = nn.ConvTranspose2d(inner_dim, triplane_dim, kernel_size=2, stride=2, padding=0)
101
+
102
+ def forward(self, image_feats):
103
+ # image_feats: [N, L_cond, D_cond]
104
+
105
+ N = image_feats.shape[0]
106
+ H = W = self.triplane_low_res
107
+ L = 3 * H * W
108
+
109
+ x = self.pos_embed.repeat(N, 1, 1) # [N, L, D]
110
+ for layer in self.layers:
111
+ x = layer(x, image_feats)
112
+ x = self.norm(x)
113
+
114
+ # separate each plane and apply deconv
115
+ x = x.view(N, 3, H, W, -1)
116
+ x = torch.einsum('nihwd->indhw', x) # [3, N, D, H, W]
117
+ x = x.contiguous().view(3*N, -1, H, W) # [3*N, D, H, W]
118
+ x = self.deconv(x) # [3*N, D', H', W']
119
+ x = x.view(3, N, *x.shape[-3:]) # [3, N, D', H', W']
120
+ x = torch.einsum('indhw->nidhw', x) # [N, 3, D', H', W']
121
+ x = x.contiguous()
122
+
123
+ return x