JMalott committed on
Commit
96a62e8
·
1 Parent(s): 520aeff

Update min_dalle/min_dalle.py

Browse files
Files changed (1) hide show
  1. min_dalle/min_dalle.py +12 -26
min_dalle/min_dalle.py CHANGED
@@ -10,9 +10,6 @@ from typing import Iterator
10
  from .text_tokenizer import TextTokenizer
11
  from .models import DalleBartEncoder, DalleBartDecoder, VQGanDetokenizer
12
  import streamlit as st
13
- import time
14
-
15
- import tracemalloc
16
 
17
  torch.set_grad_enabled(False)
18
  torch.set_num_threads(os.cpu_count())
@@ -24,7 +21,6 @@ IMAGE_TOKEN_COUNT = 256
24
 
25
 
26
  class MinDalle:
27
- @st.cache
28
  def __init__(
29
  self,
30
  models_root: str = 'pretrained',
@@ -67,6 +63,7 @@ class MinDalle:
67
  self.init_decoder()
68
  self.init_detokenizer()
69
 
 
70
  def download_tokenizer(self):
71
  if self.is_verbose: print("downloading tokenizer params")
72
  suffix = '' if self.is_mega else '_mini'
@@ -76,23 +73,27 @@ class MinDalle:
76
  with open(self.vocab_path, 'wb') as f: f.write(vocab.content)
77
  with open(self.merges_path, 'wb') as f: f.write(merges.content)
78
 
 
79
  def download_encoder(self):
80
  if self.is_verbose: print("downloading encoder params")
81
  suffix = '' if self.is_mega else '_mini'
82
  params = requests.get(MIN_DALLE_REPO + 'encoder{}.pt'.format(suffix))
83
  with open(self.encoder_params_path, 'wb') as f: f.write(params.content)
84
 
 
85
  def download_decoder(self):
86
  if self.is_verbose: print("downloading decoder params")
87
  suffix = '' if self.is_mega else '_mini'
88
  params = requests.get(MIN_DALLE_REPO + 'decoder{}.pt'.format(suffix))
89
  with open(self.decoder_params_path, 'wb') as f: f.write(params.content)
90
 
 
91
  def download_detokenizer(self):
92
  if self.is_verbose: print("downloading detokenizer params")
93
  params = requests.get(MIN_DALLE_REPO + 'detoker.pt')
94
  with open(self.detoker_params_path, 'wb') as f: f.write(params.content)
95
 
 
96
  def init_tokenizer(self):
97
  is_downloaded = os.path.exists(self.vocab_path)
98
  is_downloaded &= os.path.exists(self.merges_path)
@@ -104,6 +105,7 @@ class MinDalle:
104
  merges = f.read().split("\n")[1:-1]
105
  self.tokenizer = TextTokenizer(vocab, merges)
106
 
 
107
  def init_encoder(self):
108
  is_downloaded = os.path.exists(self.encoder_params_path)
109
  if not is_downloaded: self.download_encoder()
@@ -122,6 +124,7 @@ class MinDalle:
122
  del params
123
  self.encoder = self.encoder.to(device=self.device)
124
 
 
125
  def init_decoder(self):
126
  is_downloaded = os.path.exists(self.decoder_params_path)
127
  if not is_downloaded: self.download_decoder()
@@ -138,7 +141,8 @@ class MinDalle:
138
  self.decoder.load_state_dict(params, strict=False)
139
  del params
140
  self.decoder = self.decoder.to(device=self.device)
141
-
 
142
  def init_detokenizer(self):
143
  is_downloaded = os.path.exists(self.detoker_params_path)
144
  if not is_downloaded: self.download_detokenizer()
@@ -230,17 +234,12 @@ class MinDalle:
230
  dtype=torch.float32,
231
  device=self.device
232
  )
233
-
234
- tracemalloc.start()
235
-
236
- for i in range( IMAGE_TOKEN_COUNT ):
237
-
238
  if(st.session_state.page != 0):
239
  break
240
  st.session_state.bar.progress(i/IMAGE_TOKEN_COUNT)
241
 
242
  torch.cuda.empty_cache()
243
-
244
  with torch.cuda.amp.autocast(dtype=self.dtype):
245
  image_tokens[i + 1], attention_state = self.decoder.forward(
246
  settings=settings,
@@ -250,27 +249,14 @@ class MinDalle:
250
  prev_tokens=image_tokens[i],
251
  token_index=token_indices[[i]]
252
  )
253
-
254
-
255
-
256
 
257
- with torch.cuda.amp.autocast(dtype=torch.float16):
258
- if ((i + 1) % 16 == 0 and progressive_outputs) or i + 1 == 256:
259
  yield self.image_grid_from_tokens(
260
  image_tokens=image_tokens[1:].T,
261
  is_seamless=is_seamless,
262
  is_verbose=is_verbose
263
  )
264
-
265
- # displaying the memory
266
- print(tracemalloc.get_traced_memory())
267
-
268
- # stopping the library
269
- tracemalloc.stop()
270
-
271
-
272
-
273
-
274
 
275
  def generate_image_stream(self, *args, **kwargs) -> Iterator[Image.Image]:
276
  image_stream = self.generate_raw_image_stream(*args, **kwargs)
 
10
  from .text_tokenizer import TextTokenizer
11
  from .models import DalleBartEncoder, DalleBartDecoder, VQGanDetokenizer
12
  import streamlit as st
 
 
 
13
 
14
  torch.set_grad_enabled(False)
15
  torch.set_num_threads(os.cpu_count())
 
21
 
22
 
23
  class MinDalle:
 
24
  def __init__(
25
  self,
26
  models_root: str = 'pretrained',
 
63
  self.init_decoder()
64
  self.init_detokenizer()
65
 
66
+
67
  def download_tokenizer(self):
68
  if self.is_verbose: print("downloading tokenizer params")
69
  suffix = '' if self.is_mega else '_mini'
 
73
  with open(self.vocab_path, 'wb') as f: f.write(vocab.content)
74
  with open(self.merges_path, 'wb') as f: f.write(merges.content)
75
 
76
+
77
  def download_encoder(self):
78
  if self.is_verbose: print("downloading encoder params")
79
  suffix = '' if self.is_mega else '_mini'
80
  params = requests.get(MIN_DALLE_REPO + 'encoder{}.pt'.format(suffix))
81
  with open(self.encoder_params_path, 'wb') as f: f.write(params.content)
82
 
83
+
84
  def download_decoder(self):
85
  if self.is_verbose: print("downloading decoder params")
86
  suffix = '' if self.is_mega else '_mini'
87
  params = requests.get(MIN_DALLE_REPO + 'decoder{}.pt'.format(suffix))
88
  with open(self.decoder_params_path, 'wb') as f: f.write(params.content)
89
 
90
+
91
  def download_detokenizer(self):
92
  if self.is_verbose: print("downloading detokenizer params")
93
  params = requests.get(MIN_DALLE_REPO + 'detoker.pt')
94
  with open(self.detoker_params_path, 'wb') as f: f.write(params.content)
95
 
96
+
97
  def init_tokenizer(self):
98
  is_downloaded = os.path.exists(self.vocab_path)
99
  is_downloaded &= os.path.exists(self.merges_path)
 
105
  merges = f.read().split("\n")[1:-1]
106
  self.tokenizer = TextTokenizer(vocab, merges)
107
 
108
+
109
  def init_encoder(self):
110
  is_downloaded = os.path.exists(self.encoder_params_path)
111
  if not is_downloaded: self.download_encoder()
 
124
  del params
125
  self.encoder = self.encoder.to(device=self.device)
126
 
127
+
128
  def init_decoder(self):
129
  is_downloaded = os.path.exists(self.decoder_params_path)
130
  if not is_downloaded: self.download_decoder()
 
141
  self.decoder.load_state_dict(params, strict=False)
142
  del params
143
  self.decoder = self.decoder.to(device=self.device)
144
+
145
+
146
  def init_detokenizer(self):
147
  is_downloaded = os.path.exists(self.detoker_params_path)
148
  if not is_downloaded: self.download_detokenizer()
 
234
  dtype=torch.float32,
235
  device=self.device
236
  )
237
+ for i in range(IMAGE_TOKEN_COUNT):
 
 
 
 
238
  if(st.session_state.page != 0):
239
  break
240
  st.session_state.bar.progress(i/IMAGE_TOKEN_COUNT)
241
 
242
  torch.cuda.empty_cache()
 
243
  with torch.cuda.amp.autocast(dtype=self.dtype):
244
  image_tokens[i + 1], attention_state = self.decoder.forward(
245
  settings=settings,
 
249
  prev_tokens=image_tokens[i],
250
  token_index=token_indices[[i]]
251
  )
 
 
 
252
 
253
+ with torch.cuda.amp.autocast(dtype=torch.float32):
254
+ if ((i + 1) % 32 == 0 and progressive_outputs) or i + 1 == 256:
255
  yield self.image_grid_from_tokens(
256
  image_tokens=image_tokens[1:].T,
257
  is_seamless=is_seamless,
258
  is_verbose=is_verbose
259
  )
 
 
 
 
 
 
 
 
 
 
260
 
261
  def generate_image_stream(self, *args, **kwargs) -> Iterator[Image.Image]:
262
  image_stream = self.generate_raw_image_stream(*args, **kwargs)