MatthiasC committed on
Commit
ac75587
·
1 Parent(s): 7274f8b

Add usage for both ckpts

Browse files
Files changed (2) hide show
  1. dalle/models/__init__.py +6 -2
  2. server.py +11 -11
dalle/models/__init__.py CHANGED
@@ -6,6 +6,7 @@
6
 
7
  import os
8
  import torch
 
9
  import torch.nn as nn
10
  import pytorch_lightning as pl
11
  from typing import Optional, Tuple
@@ -55,8 +56,11 @@ class Dalle(nn.Module):
55
  context_length=model.config_dataset.context_length,
56
  lowercase=True,
57
  dropout=None)
58
- model.stage1.from_ckpt(os.path.join(path, 'stage1_last.ckpt'))
59
- model.stage2.from_ckpt(os.path.join(path, 'stage2_last.ckpt'))
 
 
 
60
  return model
61
 
62
  @torch.no_grad()
 
6
 
7
  import os
8
  import torch
9
+ import logging
10
  import torch.nn as nn
11
  import pytorch_lightning as pl
12
  from typing import Optional, Tuple
 
56
  context_length=model.config_dataset.context_length,
57
  lowercase=True,
58
  dropout=None)
59
+ logging.info("Loading first stage")
60
+ model.stage1.from_ckpt('last.ckpt')
61
+ logging.info("Loading second stage")
62
+ #model.stage2.from_ckpt(os.path.join(path, 'dalle_last.ckpt'))
63
+ model.stage2.from_ckpt('dalle_last.ckpt')
64
  return model
65
 
66
  @torch.no_grad()
server.py CHANGED
@@ -34,17 +34,17 @@ model = Dalle.from_pretrained("minDALL-E/1.3B") # This will automatically downl
34
  model.to(device=device)
35
 
36
  # -----------------------------------------------------------
37
- state_dict_ = torch.load('last.ckpt', map_location='cpu')
38
- vqgan_stage_dict = model.stage1.state_dict()
39
-
40
- for name, param in state_dict_['state_dict'].items():
41
- if name not in model.stage1.state_dict().keys():
42
- continue
43
- if isinstance(param, nn.parameter.Parameter):
44
- param = param.data
45
- vqgan_stage_dict[name].copy_(param)
46
-
47
- model.stage1.load_state_dict(vqgan_stage_dict)
48
  # ---------------------------------------------------------
49
  # state_dict_dalle = torch.load('dalle_last.ckpt', map_location='cpu')
50
  # dalle_stage_dict = model.stage2.state_dict()
 
34
  model.to(device=device)
35
 
36
  # -----------------------------------------------------------
37
+ # state_dict_ = torch.load('last.ckpt', map_location='cpu')
38
+ # vqgan_stage_dict = model.stage1.state_dict()
39
+ #
40
+ # for name, param in state_dict_['state_dict'].items():
41
+ # if name not in model.stage1.state_dict().keys():
42
+ # continue
43
+ # if isinstance(param, nn.parameter.Parameter):
44
+ # param = param.data
45
+ # vqgan_stage_dict[name].copy_(param)
46
+ #
47
+ # model.stage1.load_state_dict(vqgan_stage_dict)
48
  # ---------------------------------------------------------
49
  # state_dict_dalle = torch.load('dalle_last.ckpt', map_location='cpu')
50
  # dalle_stage_dict = model.stage2.state_dict()