MatthiasC commited on
Commit
9b4f999
·
1 Parent(s): 5cb75cc

Load only our own weights and include necessary config and tokenizer files

Browse files
cog.yaml ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ build:
2
+ python_version: "3.8"
3
+
4
+ gpu: true
5
+
6
+ python_packages:
7
+ - 'huggingface_hub==0.8.1'
8
+ - 'torch==1.7.1'
9
+ - 'torchvision==0.8.2'
10
+ - 'tokenizers==0.10.2'
11
+ - 'pyflakes==2.2.0'
12
+ - 'tqdm==4.46.0'
13
+ - 'pytorch-lightning==1.5'
14
+ - 'einops==0.4.1'
15
+ - 'omegaconf==2.2.2'
16
+ - 'git+https://github.com/openai/CLIP.git'
17
+ - 'numpy==1.23.1'
18
+ - 'pillow==9.2.0'
19
+ - 'python-dotenv==0.20.0'
20
+
21
+ predict: "predict.py:Predictor"
dalle/models/__init__.py CHANGED
@@ -44,25 +44,15 @@ class Dalle(nn.Module):
44
  @classmethod
45
  def from_pretrained(cls,
46
  path: str) -> nn.Module:
47
- path = _MODELS[path] if path in _MODELS else path
48
- path = utils.realpath_url_or_path(path, root=os.path.expanduser("~/.cache/minDALL-E"))
49
-
50
  config_base = get_base_config()
51
- config_new = OmegaConf.load(os.path.join(path, 'config.yaml'))
52
  config_update = OmegaConf.merge(config_base, config_new)
53
 
54
  model = cls(config_update)
55
- model.tokenizer = build_tokenizer(os.path.join(path, 'tokenizer'),
56
  context_length=model.config_dataset.context_length,
57
  lowercase=True,
58
  dropout=None)
59
- logging.info("Loading first stage")
60
- #model.stage1.from_ckpt('last.ckpt')
61
- model.stage1.from_ckpt(os.path.join(path, 'stage1_last.ckpt'))
62
- logging.info("Loading second stage")
63
- model.stage2.from_ckpt(os.path.join(path, 'stage2_last.ckpt'))
64
- #model.stage2.from_ckpt(os.path.join(path, 'dalle_last.ckpt'))
65
- #model.stage2.from_ckpt('dalle_last.ckpt')
66
  return model
67
 
68
  @torch.no_grad()
 
44
  @classmethod
45
  def from_pretrained(cls,
46
  path: str) -> nn.Module:
 
 
 
47
  config_base = get_base_config()
48
+ config_new = OmegaConf.load('config.yaml')
49
  config_update = OmegaConf.merge(config_base, config_new)
50
 
51
  model = cls(config_update)
52
+ model.tokenizer = build_tokenizer('tokenizer',
53
  context_length=model.config_dataset.context_length,
54
  lowercase=True,
55
  dropout=None)
 
 
 
 
 
 
 
56
  return model
57
 
58
  @torch.no_grad()
server.py CHANGED
@@ -25,50 +25,9 @@ logging.info("Start downloading")
25
  full_dict_path = hf_hub_download(repo_id="MatthiasC/dall-e-logo", filename="full_dict_new.ckpt",
26
  use_auth_token=st.secrets["model_hub"])
27
  logging.info("End downloading")
28
- logging.info(full_dict_path)
29
-
30
-
31
- # url = "https://arena.kakaocdn.net/brainrepo/models/minDALL-E/57b008f02ceaa02b779c8b7463143315/1.3B.tar.gz"
32
- # root = os.path.expanduser("~/.cache/minDALLE")
33
- # filename = os.path.basename(url)
34
- # pathname = filename[: -len(".tar.gz")]
35
- # download_target = os.path.join(root, filename)
36
- # result_path = os.path.join(root, pathname)
37
- # if not os.path.exists(result_path):
38
- # result_path = download(url, root)
39
-
40
 
41
  device = "cuda" if torch.cuda.is_available() else "cpu"
42
- model = Dalle.from_pretrained("minDALL-E/1.3B") # This will automatically download the pretrained model.
43
- #model.to(device=device)
44
-
45
-
46
- # OLD CODE
47
- # -----------------------------------------------------------
48
- # state_dict_ = torch.load('last.ckpt', map_location='cpu')
49
- # vqgan_stage_dict = model.stage1.state_dict()
50
- #
51
- # for name, param in state_dict_['state_dict'].items():
52
- # if name not in model.stage1.state_dict().keys():
53
- # continue
54
- # if isinstance(param, nn.parameter.Parameter):
55
- # param = param.data
56
- # vqgan_stage_dict[name].copy_(param)
57
- #
58
- # model.stage1.load_state_dict(vqgan_stage_dict)
59
- # #---------------------------------------------------------
60
- # state_dict_dalle = torch.load('dalle_last.ckpt', map_location='cpu')
61
- # dalle_stage_dict = model.stage2.state_dict()
62
- #
63
- # for name, param in state_dict_dalle['state_dict'].items():
64
- # if name[6:] not in model.stage2.state_dict().keys():
65
- # print(name)
66
- # continue
67
- # if isinstance(param, nn.parameter.Parameter):
68
- # param = param.data
69
- # dalle_stage_dict[name[6:]].copy_(param)
70
- #
71
- # model.stage2.load_state_dict(dalle_stage_dict)
72
 
73
  # NEW METHOD
74
  model.load_state_dict(torch.load(full_dict_path, map_location=torch.device('cpu')))
 
25
  full_dict_path = hf_hub_download(repo_id="MatthiasC/dall-e-logo", filename="full_dict_new.ckpt",
26
  use_auth_token=st.secrets["model_hub"])
27
  logging.info("End downloading")
 
 
 
 
 
 
 
 
 
 
 
 
28
 
29
  device = "cuda" if torch.cuda.is_available() else "cpu"
30
+ model = Dalle.from_pretrained("minDALL-E/1.3B")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31
 
32
  # NEW METHOD
33
  model.load_state_dict(torch.load(full_dict_path, map_location=torch.device('cpu')))
tokenizer/bpe-16k-merges.txt ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer/bpe-16k-vocab.json ADDED
The diff for this file is too large to render. See raw diff