Spaces:
Runtime error
Runtime error
device
Browse files
app.py
CHANGED
@@ -19,14 +19,14 @@ np.random.seed(0)
|
|
19 |
|
20 |
def get_pil_im(im, resolution=64):
|
21 |
im = imresize(im, (resolution, resolution))[:, :, :3]
|
22 |
-
im = th.Tensor(im).permute(2, 0, 1)[None, :, :, :].contiguous()
|
23 |
return im
|
24 |
|
25 |
|
26 |
# generate image components and reconstruction
|
27 |
def gen_image_and_components(model, gd, im, num_components=4, sample_method='ddim', batch_size=1, image_size=64, device='cuda', num_images=1):
|
28 |
"""Generate row of orig image, individual components, and reconstructed image"""
|
29 |
-
orig_img = get_pil_im(im, resolution=image_size)
|
30 |
latent = model.encode_latent(orig_img)
|
31 |
model_kwargs = {'latent': latent}
|
32 |
|
@@ -74,7 +74,7 @@ def gen_image_and_components(model, gd, im, num_components=4, sample_method='ddi
|
|
74 |
|
75 |
def decompose_image(im):
|
76 |
sample_method = 'ddim'
|
77 |
-
result = gen_image_and_components(clevr_model, GD[sample_method], im, sample_method=sample_method, num_images=1)
|
78 |
return result.permute(1, 2, 0).numpy()
|
79 |
|
80 |
|
|
|
19 |
|
20 |
def get_pil_im(im, resolution=64):
|
21 |
im = imresize(im, (resolution, resolution))[:, :, :3]
|
22 |
+
im = th.Tensor(im).permute(2, 0, 1)[None, :, :, :].contiguous()
|
23 |
return im
|
24 |
|
25 |
|
26 |
# generate image components and reconstruction
|
27 |
def gen_image_and_components(model, gd, im, num_components=4, sample_method='ddim', batch_size=1, image_size=64, device='cuda', num_images=1):
|
28 |
"""Generate row of orig image, individual components, and reconstructed image"""
|
29 |
+
orig_img = get_pil_im(im, resolution=image_size).to(device)
|
30 |
latent = model.encode_latent(orig_img)
|
31 |
model_kwargs = {'latent': latent}
|
32 |
|
|
|
74 |
|
75 |
def decompose_image(im):
|
76 |
sample_method = 'ddim'
|
77 |
+
result = gen_image_and_components(clevr_model, GD[sample_method], im, sample_method=sample_method, num_images=1, device=device)
|
78 |
return result.permute(1, 2, 0).numpy()
|
79 |
|
80 |
|