diff --git a/models/hierarchy_inference_model.py b/models/hierarchy_inference_model.py
index 3116307..5de661d 100644
--- a/models/hierarchy_inference_model.py
+++ b/models/hierarchy_inference_model.py
@@ -21,7 +21,7 @@ class VQGANTextureAwareSpatialHierarchyInferenceModel():
 
     def __init__(self, opt):
         self.opt = opt
-        self.device = torch.device('cuda')
+        self.device = torch.device(opt['device'])
         self.is_train = opt['is_train']
 
         self.top_encoder = Encoder(
diff --git a/models/hierarchy_vqgan_model.py b/models/hierarchy_vqgan_model.py
index 4b0d657..0bf4712 100644
--- a/models/hierarchy_vqgan_model.py
+++ b/models/hierarchy_vqgan_model.py
@@ -20,7 +20,7 @@ class HierarchyVQSpatialTextureAwareModel():
 
     def __init__(self, opt):
         self.opt = opt
-        self.device = torch.device('cuda')
+        self.device = torch.device(opt['device'])
         self.top_encoder = Encoder(
             ch=opt['top_ch'],
             num_res_blocks=opt['top_num_res_blocks'],
diff --git a/models/parsing_gen_model.py b/models/parsing_gen_model.py
index 9440345..15a1ecb 100644
--- a/models/parsing_gen_model.py
+++ b/models/parsing_gen_model.py
@@ -22,7 +22,7 @@ class ParsingGenModel():
 
     def __init__(self, opt):
         self.opt = opt
-        self.device = torch.device('cuda')
+        self.device = torch.device(opt['device'])
         self.is_train = opt['is_train']
 
         self.attr_embedder = ShapeAttrEmbedding(
diff --git a/models/sample_model.py b/models/sample_model.py
index 4c60e3f..5265cd0 100644
--- a/models/sample_model.py
+++ b/models/sample_model.py
@@ -23,7 +23,7 @@ class BaseSampleModel():
 
     def __init__(self, opt):
         self.opt = opt
-        self.device = torch.device('cuda')
+        self.device = torch.device(opt['device'])
 
         # hierarchical VQVAE
         self.decoder = Decoder(
@@ -123,7 +123,7 @@ class BaseSampleModel():
 
     def load_top_pretrain_models(self):
         # load pretrained vqgan
-        top_vae_checkpoint = torch.load(self.opt['top_vae_path'])
+        top_vae_checkpoint = torch.load(self.opt['top_vae_path'], map_location=self.device)
 
         self.decoder.load_state_dict(
             top_vae_checkpoint['decoder'], strict=True)
@@ -137,7 +137,7 @@ class BaseSampleModel():
         self.top_post_quant_conv.eval()
 
     def load_bot_pretrain_network(self):
-        checkpoint = torch.load(self.opt['bot_vae_path'])
+        checkpoint = torch.load(self.opt['bot_vae_path'], map_location=self.device)
         self.bot_decoder_res.load_state_dict(
             checkpoint['bot_decoder_res'], strict=True)
         self.decoder.load_state_dict(checkpoint['decoder'], strict=True)
@@ -153,7 +153,7 @@ class BaseSampleModel():
 
     def load_pretrained_segm_token(self):
         # load pretrained vqgan for segmentation mask
-        segm_token_checkpoint = torch.load(self.opt['segm_token_path'])
+        segm_token_checkpoint = torch.load(self.opt['segm_token_path'], map_location=self.device)
         self.segm_encoder.load_state_dict(
             segm_token_checkpoint['encoder'], strict=True)
         self.segm_quantizer.load_state_dict(
@@ -166,7 +166,7 @@ class BaseSampleModel():
         self.segm_quant_conv.eval()
 
     def load_index_pred_network(self):
-        checkpoint = torch.load(self.opt['pretrained_index_network'])
+        checkpoint = torch.load(self.opt['pretrained_index_network'], map_location=self.device)
         self.index_pred_guidance_encoder.load_state_dict(
             checkpoint['guidance_encoder'], strict=True)
         self.index_pred_decoder.load_state_dict(
@@ -176,7 +176,7 @@ class BaseSampleModel():
         self.index_pred_decoder.eval()
 
     def load_sampler_pretrained_network(self):
-        checkpoint = torch.load(self.opt['pretrained_sampler'])
+        checkpoint = torch.load(self.opt['pretrained_sampler'], map_location=self.device)
         self.sampler_fn.load_state_dict(checkpoint, strict=True)
         self.sampler_fn.eval()
 
@@ -397,7 +397,7 @@ class SampleFromPoseModel(BaseSampleModel):
                         [185, 210, 205], [130, 165, 180], [225, 141, 151]]
 
     def load_shape_generation_models(self):
-        checkpoint = torch.load(self.opt['pretrained_parsing_gen'])
+        checkpoint = torch.load(self.opt['pretrained_parsing_gen'], map_location=self.device)
 
         self.shape_attr_embedder.load_state_dict(
             checkpoint['embedder'], strict=True)
diff --git a/models/transformer_model.py b/models/transformer_model.py
index 7db0f3e..4523d17 100644
--- a/models/transformer_model.py
+++ b/models/transformer_model.py
@@ -21,7 +21,7 @@ class TransformerTextureAwareModel():
 
     def __init__(self, opt):
         self.opt = opt
-        self.device = torch.device('cuda')
+        self.device = torch.device(opt['device'])
         self.is_train = opt['is_train']
 
         # VQVAE for image
@@ -317,10 +317,10 @@ class TransformerTextureAwareModel():
     def sample_fn(self, temp=1.0, sample_steps=None):
         self._denoise_fn.eval()
 
-        b, device = self.image.size(0), 'cuda'
+        b = self.image.size(0)
         x_t = torch.ones(
-            (b, np.prod(self.shape)), device=device).long() * self.mask_id
-        unmasked = torch.zeros_like(x_t, device=device).bool()
+            (b, np.prod(self.shape)), device=self.device).long() * self.mask_id
+        unmasked = torch.zeros_like(x_t, device=self.device).bool()
         sample_steps = list(range(1, sample_steps + 1))
 
         texture_mask_flatten = self.texture_tokens.view(-1)
@@ -336,11 +336,11 @@ class TransformerTextureAwareModel():
 
         for t in reversed(sample_steps):
             print(f'Sample timestep {t:4d}', end='\r')
-            t = torch.full((b, ), t, device=device, dtype=torch.long)
+            t = torch.full((b, ), t, device=self.device, dtype=torch.long)
 
             # where to unmask
             changes = torch.rand(
-                x_t.shape, device=device) < 1 / t.float().unsqueeze(-1)
+                x_t.shape, device=self.device) < 1 / t.float().unsqueeze(-1)
             # don't unmask somewhere already unmasked
             changes = torch.bitwise_xor(changes,
                                         torch.bitwise_and(changes, unmasked))
diff --git a/models/vqgan_model.py b/models/vqgan_model.py
index 13a2e70..9c840f1 100644
--- a/models/vqgan_model.py
+++ b/models/vqgan_model.py
@@ -20,7 +20,7 @@ class VQModel():
     def __init__(self, opt):
         super().__init__()
         self.opt = opt
-        self.device = torch.device('cuda')
+        self.device = torch.device(opt['device'])
         self.encoder = Encoder(
             ch=opt['ch'],
             num_res_blocks=opt['num_res_blocks'],
@@ -390,7 +390,7 @@ class VQImageSegmTextureModel(VQImageModel):
 
     def __init__(self, opt):
         self.opt = opt
-        self.device = torch.device('cuda')
+        self.device = torch.device(opt['device'])
         self.encoder = Encoder(
             ch=opt['ch'],
             num_res_blocks=opt['num_res_blocks'],