Spaces:
Running
on
L4
Running
on
L4
FrozenBurning
commited on
Commit
·
692682d
1
Parent(s):
02e04ed
fix app.py
Browse files
app.py
CHANGED
@@ -60,6 +60,7 @@ if "latent_mean" in config.model:
|
|
60 |
latent_std = torch.Tensor(config.model.latent_std)[None, None, :].to(device)
|
61 |
assert latent_mean.shape[-1] == config.model.generator.in_channels
|
62 |
perchannel_norm = True
|
|
|
63 |
|
64 |
config.diffusion.pop("timestep_respacing")
|
65 |
config.model.pop("vae")
|
@@ -114,14 +115,14 @@ def process(input_image, input_num_steps=25, input_seed=42, input_cfg=6.0):
|
|
114 |
final_samples = samples
|
115 |
recon_param = final_samples["sample"].reshape(inf_bs, config.model.num_prims, -1)
|
116 |
if perchannel_norm:
|
117 |
-
recon_param = recon_param /
|
118 |
recon_srt_param = recon_param[:, :, 0:4]
|
119 |
recon_feat_param = recon_param[:, :, 4:] # [8, 2048, 64]
|
120 |
recon_feat_param_list = []
|
121 |
# one-by-one to avoid oom
|
122 |
for inf_bidx in range(inf_bs):
|
123 |
if not perchannel_norm:
|
124 |
-
decoded = vae.decode(recon_feat_param[inf_bidx, ...].reshape(1*config.model.num_prims, *latent.shape[-4:]) /
|
125 |
else:
|
126 |
decoded = vae.decode(recon_feat_param[inf_bidx, ...].reshape(1*config.model.num_prims, *latent.shape[-4:]))
|
127 |
recon_feat_param_list.append(decoded.detach())
|
|
|
60 |
latent_std = torch.Tensor(config.model.latent_std)[None, None, :].to(device)
|
61 |
assert latent_mean.shape[-1] == config.model.generator.in_channels
|
62 |
perchannel_norm = True
|
63 |
+
latent_nf = config.model.latent_nf
|
64 |
|
65 |
config.diffusion.pop("timestep_respacing")
|
66 |
config.model.pop("vae")
|
|
|
115 |
final_samples = samples
|
116 |
recon_param = final_samples["sample"].reshape(inf_bs, config.model.num_prims, -1)
|
117 |
if perchannel_norm:
|
118 |
+
recon_param = recon_param / latent_nf * latent_std + latent_mean
|
119 |
recon_srt_param = recon_param[:, :, 0:4]
|
120 |
recon_feat_param = recon_param[:, :, 4:] # [8, 2048, 64]
|
121 |
recon_feat_param_list = []
|
122 |
# one-by-one to avoid oom
|
123 |
for inf_bidx in range(inf_bs):
|
124 |
if not perchannel_norm:
|
125 |
+
decoded = vae.decode(recon_feat_param[inf_bidx, ...].reshape(1*config.model.num_prims, *latent.shape[-4:]) / latent_nf)
|
126 |
else:
|
127 |
decoded = vae.decode(recon_feat_param[inf_bidx, ...].reshape(1*config.model.num_prims, *latent.shape[-4:]))
|
128 |
recon_feat_param_list.append(decoded.detach())
|