oguzakif commited on
Commit
946e58c
·
1 Parent(s): 6c89e77

cpu support added

Browse files
Files changed (1) hide show
  1. FGT_codes/tool/video_inpainting.py +20 -7
FGT_codes/tool/video_inpainting.py CHANGED
@@ -181,7 +181,10 @@ def create_dir(dir):
181
  def initialize_RAFT(args, device):
182
  """Initializes the RAFT model."""
183
  model = torch.nn.DataParallel(RAFT(args))
184
- model.load_state_dict(torch.load(args.raft_model))
 
 
 
185
 
186
  model = model.module
187
  model.to(device)
@@ -202,9 +205,14 @@ def initialize_LAFC(args, device):
202
  model = configs["model"]
203
  pkg = import_module("LAFC.models.{}".format(model))
204
  model = pkg.Model(configs)
205
- state = torch.load(
206
- checkpoint, map_location=lambda storage, loc: storage.cuda(device)
207
- )
 
 
 
 
 
208
  model.load_state_dict(state["model_state_dict"])
209
  model = model.to(device)
210
  return model, configs
@@ -221,9 +229,14 @@ def initialize_FGT(args, device):
221
  model = configs["model"]
222
  net = import_module("FGT.models.{}".format(model))
223
  model = net.Model(configs).to(device)
224
- state = torch.load(
225
- checkpoint, map_location=lambda storage, loc: storage.cuda(device)
226
- )
 
 
 
 
 
227
  model.load_state_dict(state["model_state_dict"])
228
  return model, configs
229
 
 
181
  def initialize_RAFT(args, device):
182
  """Initializes the RAFT model."""
183
  model = torch.nn.DataParallel(RAFT(args))
184
+ if not torch.cuda.is_available():
185
+ model.load_state_dict(torch.load(args.raft_model, map_location=lambda storage, loc: storage))
186
+ else:
187
+ model.load_state_dict(torch.load(args.raft_model))
188
 
189
  model = model.module
190
  model.to(device)
 
205
  model = configs["model"]
206
  pkg = import_module("LAFC.models.{}".format(model))
207
  model = pkg.Model(configs)
208
+ if not torch.cuda.is_available():
209
+ state = torch.load(
210
+ checkpoint, map_location=lambda storage, loc: storage
211
+ )
212
+ else:
213
+ state = torch.load(
214
+ checkpoint, map_location=lambda storage, loc: storage.cuda(device)
215
+ )
216
  model.load_state_dict(state["model_state_dict"])
217
  model = model.to(device)
218
  return model, configs
 
229
  model = configs["model"]
230
  net = import_module("FGT.models.{}".format(model))
231
  model = net.Model(configs).to(device)
232
+ if not torch.cuda.is_available():
233
+ state = torch.load(
234
+ checkpoint, map_location=lambda storage, loc: storage
235
+ )
236
+ else:
237
+ state = torch.load(
238
+ checkpoint, map_location=lambda storage, loc: storage.cuda(device)
239
+ )
240
  model.load_state_dict(state["model_state_dict"])
241
  return model, configs
242