FrozenBurning commited on
Commit
692682d
·
1 Parent(s): 02e04ed

fix app.py

Browse files
Files changed (1) hide show
  1. app.py +3 -2
app.py CHANGED
@@ -60,6 +60,7 @@ if "latent_mean" in config.model:
60
  latent_std = torch.Tensor(config.model.latent_std)[None, None, :].to(device)
61
  assert latent_mean.shape[-1] == config.model.generator.in_channels
62
  perchannel_norm = True
 
63
 
64
  config.diffusion.pop("timestep_respacing")
65
  config.model.pop("vae")
@@ -114,14 +115,14 @@ def process(input_image, input_num_steps=25, input_seed=42, input_cfg=6.0):
114
  final_samples = samples
115
  recon_param = final_samples["sample"].reshape(inf_bs, config.model.num_prims, -1)
116
  if perchannel_norm:
117
- recon_param = recon_param / config.model.latent_nf * latent_std + latent_mean
118
  recon_srt_param = recon_param[:, :, 0:4]
119
  recon_feat_param = recon_param[:, :, 4:] # [8, 2048, 64]
120
  recon_feat_param_list = []
121
  # one-by-one to avoid oom
122
  for inf_bidx in range(inf_bs):
123
  if not perchannel_norm:
124
- decoded = vae.decode(recon_feat_param[inf_bidx, ...].reshape(1*config.model.num_prims, *latent.shape[-4:]) / config.model.latent_nf)
125
  else:
126
  decoded = vae.decode(recon_feat_param[inf_bidx, ...].reshape(1*config.model.num_prims, *latent.shape[-4:]))
127
  recon_feat_param_list.append(decoded.detach())
 
60
  latent_std = torch.Tensor(config.model.latent_std)[None, None, :].to(device)
61
  assert latent_mean.shape[-1] == config.model.generator.in_channels
62
  perchannel_norm = True
63
+ latent_nf = config.model.latent_nf
64
 
65
  config.diffusion.pop("timestep_respacing")
66
  config.model.pop("vae")
 
115
  final_samples = samples
116
  recon_param = final_samples["sample"].reshape(inf_bs, config.model.num_prims, -1)
117
  if perchannel_norm:
118
+ recon_param = recon_param / latent_nf * latent_std + latent_mean
119
  recon_srt_param = recon_param[:, :, 0:4]
120
  recon_feat_param = recon_param[:, :, 4:] # [8, 2048, 64]
121
  recon_feat_param_list = []
122
  # one-by-one to avoid oom
123
  for inf_bidx in range(inf_bs):
124
  if not perchannel_norm:
125
+ decoded = vae.decode(recon_feat_param[inf_bidx, ...].reshape(1*config.model.num_prims, *latent.shape[-4:]) / latent_nf)
126
  else:
127
  decoded = vae.decode(recon_feat_param[inf_bidx, ...].reshape(1*config.model.num_prims, *latent.shape[-4:]))
128
  recon_feat_param_list.append(decoded.detach())