jsu27 commited on
Commit
90bc5e7
·
1 Parent(s): 28d6d09

celeb combination demo

Browse files
app.py CHANGED
@@ -98,40 +98,136 @@ gd = SpacedDiffusion(spaced_ts, rescale_timesteps=True, original_num_steps=num_t
98
  GD['ddim'] = gd
99
 
100
 
101
- # !wget https://www.dropbox.com/s/bqpc3ymstz9m05z/clevr_model.pt
102
- # load model
103
 
104
- ckpt_path = download_model('clevr') # 'clevr_model.pt'
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
105
 
106
  model_kwargs = unet_model_defaults()
107
  # model parameters
108
  model_kwargs.update(dict(
109
- emb_dim=64,
110
  enc_channels=128
111
  ))
112
- clevr_model = create_diffusion_model(**model_kwargs)
113
- clevr_model.eval()
114
 
115
  device = 'cuda' if th.cuda.is_available() else 'cpu'
116
- clevr_model.to(device)
117
 
118
  print(f'loading from {ckpt_path}')
119
  checkpoint = th.load(ckpt_path, map_location='cpu')
120
 
121
- clevr_model.load_state_dict(checkpoint)
122
 
 
123
 
124
 
125
  img_input = gr.inputs.Image(type="numpy", label="Input")
 
 
126
  img_output = gr.outputs.Image(type="numpy", label="Output")
127
 
128
  gr.Interface(
129
- decompose_image,
130
- inputs=img_input,
131
  outputs=img_output,
132
  examples=[
133
- os.path.join(os.path.dirname(__file__), "sample_images/clevr_im_10.png"),
134
- os.path.join(os.path.dirname(__file__), "sample_images/clevr_im_25.png"),
135
- ],
136
-
137
  ).launch()
 
 
98
  GD['ddim'] = gd
99
 
100
 
 
 
101
 
102
+ # ckpt_path = download_model('clevr') # 'clevr_model.pt'
103
+
104
+ # model_kwargs = unet_model_defaults()
105
+ # # model parameters
106
+ # model_kwargs.update(dict(
107
+ # emb_dim=64,
108
+ # enc_channels=128
109
+ # ))
110
+ # clevr_model = create_diffusion_model(**model_kwargs)
111
+ # clevr_model.eval()
112
+
113
+ # device = 'cuda' if th.cuda.is_available() else 'cpu'
114
+ # clevr_model.to(device)
115
+
116
+ # print(f'loading from {ckpt_path}')
117
+ # checkpoint = th.load(ckpt_path, map_location='cpu')
118
+
119
+ # clevr_model.load_state_dict(checkpoint)
120
+
121
+
122
+
123
+ # img_input = gr.inputs.Image(type="numpy", label="Input")
124
+ # img_output = gr.outputs.Image(type="numpy", label="Output")
125
+
126
+ # gr.Interface(
127
+ # decompose_image,
128
+ # inputs=img_input,
129
+ # outputs=img_output,
130
+ # examples=[
131
+ # "sample_images/clevr_im_10.png",
132
+ # "sample_images/clevr_im_25.png",
133
+ # ],
134
+
135
+ # ).launch()
136
+
137
+
138
+
139
+
140
+
141
+ def combine_components_slice(model, gd, im1, im2, indices=None, sample_method='ddim', device='cuda', num_images=4, model_kwargs={}, desc='', save_dir='', dataset='clevr', image_size=64):
142
+ """Combine by adding components together
143
+ """
144
+ assert sample_method in ('ddpm', 'ddim')
145
+
146
+ im1 = get_pil_im(im1, resolution=image_size).to(device)
147
+ im2 = get_pil_im(im2, resolution=image_size).to(device)
148
+
149
+ latent1 = model.encode_latent(im1)
150
+ latent2 = model.encode_latent(im2)
151
+
152
+ num_comps = model.num_components
153
+
154
+ # get latent slices
155
+ if indices == None:
156
+ half = num_comps // 2
157
+ indices = [1] * half + [0] * half # first half 1, second half 0
158
+ indices = th.Tensor(indices) == 1
159
+ indices = indices.reshape(num_comps, 1)
160
+ elif type(indices) == str:
161
+ indices = indices.split(',')
162
+ indices = [int(ind) for ind in indices]
163
+ indices = th.Tensor(indices).reshape(-1, 1) == 1
164
+ assert len(indices) == num_comps
165
+ indices = indices.to(device)
166
+
167
+ latent1 = latent1.reshape(num_comps, -1).to(device)
168
+ latent2 = latent2.reshape(num_comps, -1).to(device)
169
+
170
+ combined_latent = th.where(indices, latent1, latent2)
171
+ combined_latent = combined_latent.reshape(1, -1)
172
+ model_kwargs['latent'] = combined_latent
173
+
174
+ sample_loop_func = gd.p_sample_loop if sample_method == 'ddpm' else gd.ddim_sample_loop
175
+ if sample_method == 'ddim':
176
+ model = gd._wrap_model(model)
177
+
178
+ # sampling loop
179
+ sample = sample_loop_func(
180
+ model,
181
+ (1, 3, image_size, image_size),
182
+ device=device,
183
+ clip_denoised=True,
184
+ progress=True,
185
+ model_kwargs=model_kwargs,
186
+ cond_fn=None,
187
+ )[:1]
188
+
189
+ return sample[0].cpu()
190
+
191
+ def combine_images(im1, im2):
192
+ sample_method = 'ddim'
193
+ result = combine_components_slice(clevr_model, GD[sample_method], im1, im2, indices='1,0,1,0', sample_method=sample_method, num_images=1)
194
+ return result.permute(1, 2, 0).numpy()
195
+
196
+
197
+
198
+ ckpt_path = download_model('celebahq') # 'celeb_model.pt'
199
 
200
  model_kwargs = unet_model_defaults()
201
  # model parameters
202
  model_kwargs.update(dict(
 
203
  enc_channels=128
204
  ))
205
+ celeb_model = create_diffusion_model(**model_kwargs)
206
+ celeb_model.eval()
207
 
208
  device = 'cuda' if th.cuda.is_available() else 'cpu'
209
+ celeb_model.to(device)
210
 
211
  print(f'loading from {ckpt_path}')
212
  checkpoint = th.load(ckpt_path, map_location='cpu')
213
 
214
+ celeb_model.load_state_dict(checkpoint)
215
 
216
+ # Recombination
217
 
218
 
219
  img_input = gr.inputs.Image(type="numpy", label="Input")
220
+ img_input2 = gr.inputs.Image(type="numpy", label="Input")
221
+
222
  img_output = gr.outputs.Image(type="numpy", label="Output")
223
 
224
  gr.Interface(
225
+ combine_images,
226
+ inputs=[img_input, img_input2],
227
  outputs=img_output,
228
  examples=[
229
+ ["sample_images/celebahq_im_15.jpg",
230
+ "sample_images/celebahq_im_21.jpg"]
231
+ ]
 
232
  ).launch()
233
+
download.py CHANGED
@@ -7,7 +7,7 @@ from tqdm.auto import tqdm
7
 
8
  MODEL_PATHS = {
9
  "clevr": "https://www.dropbox.com/s/bqpc3ymstz9m05z/clevr_model.pt",
10
- "celebahq": ""
11
  }
12
 
13
  DATA_PATHS = {
 
7
 
8
  MODEL_PATHS = {
9
  "clevr": "https://www.dropbox.com/s/bqpc3ymstz9m05z/clevr_model.pt",
10
+ "celebahq": "https://www.dropbox.com/s/687wuamoud4cs9x/celeb_model.pt"
11
  }
12
 
13
  DATA_PATHS = {
sample_images/celebahq_im_15.jpg ADDED
sample_images/celebahq_im_21.jpg ADDED