MatthiasC committed on
Commit
5fc9f83
·
1 Parent(s): ac75587

Re-enable full loading

Browse files
Files changed (2) hide show
  1. dalle/models/__init__.py +4 -2
  2. server.py +24 -24
dalle/models/__init__.py CHANGED
@@ -57,10 +57,12 @@ class Dalle(nn.Module):
57
  lowercase=True,
58
  dropout=None)
59
  logging.info("Loading first stage")
60
- model.stage1.from_ckpt('last.ckpt')
 
61
  logging.info("Loading second stage")
 
62
  #model.stage2.from_ckpt(os.path.join(path, 'dalle_last.ckpt'))
63
- model.stage2.from_ckpt('dalle_last.ckpt')
64
  return model
65
 
66
  @torch.no_grad()
 
57
  lowercase=True,
58
  dropout=None)
59
  logging.info("Loading first stage")
60
+ #model.stage1.from_ckpt('last.ckpt')
61
+ model.stage1.from_ckpt(os.path.join(path, 'stage1_last.ckpt'))
62
  logging.info("Loading second stage")
63
+ model.stage2.from_ckpt(os.path.join(path, 'stage2_last.ckpt'))
64
  #model.stage2.from_ckpt(os.path.join(path, 'dalle_last.ckpt'))
65
+ #model.stage2.from_ckpt('dalle_last.ckpt')
66
  return model
67
 
68
  @torch.no_grad()
server.py CHANGED
@@ -34,30 +34,30 @@ model = Dalle.from_pretrained("minDALL-E/1.3B") # This will automatically downl
34
  model.to(device=device)
35
 
36
  # -----------------------------------------------------------
37
- # state_dict_ = torch.load('last.ckpt', map_location='cpu')
38
- # vqgan_stage_dict = model.stage1.state_dict()
39
- #
40
- # for name, param in state_dict_['state_dict'].items():
41
- # if name not in model.stage1.state_dict().keys():
42
- # continue
43
- # if isinstance(param, nn.parameter.Parameter):
44
- # param = param.data
45
- # vqgan_stage_dict[name].copy_(param)
46
- #
47
- # model.stage1.load_state_dict(vqgan_stage_dict)
48
- # ---------------------------------------------------------
49
- # state_dict_dalle = torch.load('dalle_last.ckpt', map_location='cpu')
50
- # dalle_stage_dict = model.stage2.state_dict()
51
- #
52
- # for name, param in state_dict_dalle['state_dict'].items():
53
- # if name[6:] not in model.stage2.state_dict().keys():
54
- # print(name)
55
- # continue
56
- # if isinstance(param, nn.parameter.Parameter):
57
- # param = param.data
58
- # dalle_stage_dict[name[6:]].copy_(param)
59
- #
60
- # model.stage2.load_state_dict(dalle_stage_dict)
61
 
62
  # model_clip, preprocess_clip = clip.load("ViT-B/32", device=device)
63
  # model_clip.to(device=device)
 
34
  model.to(device=device)
35
 
36
  # -----------------------------------------------------------
37
+ state_dict_ = torch.load('last.ckpt', map_location='cpu')
38
+ vqgan_stage_dict = model.stage1.state_dict()
39
+
40
+ for name, param in state_dict_['state_dict'].items():
41
+ if name not in model.stage1.state_dict().keys():
42
+ continue
43
+ if isinstance(param, nn.parameter.Parameter):
44
+ param = param.data
45
+ vqgan_stage_dict[name].copy_(param)
46
+
47
+ model.stage1.load_state_dict(vqgan_stage_dict)
48
+ #---------------------------------------------------------
49
+ state_dict_dalle = torch.load('dalle_last.ckpt', map_location='cpu')
50
+ dalle_stage_dict = model.stage2.state_dict()
51
+
52
+ for name, param in state_dict_dalle['state_dict'].items():
53
+ if name[6:] not in model.stage2.state_dict().keys():
54
+ print(name)
55
+ continue
56
+ if isinstance(param, nn.parameter.Parameter):
57
+ param = param.data
58
+ dalle_stage_dict[name[6:]].copy_(param)
59
+
60
+ model.stage2.load_state_dict(dalle_stage_dict)
61
 
62
  # model_clip, preprocess_clip = clip.load("ViT-B/32", device=device)
63
  # model_clip.to(device=device)