JMalott commited on
Commit
03dd743
·
1 Parent(s): aee01c4

Update min_dalle/min_dalle.py

Browse files
Files changed (1) hide show
  1. min_dalle/min_dalle.py +3 -3
min_dalle/min_dalle.py CHANGED
@@ -177,7 +177,7 @@ class MinDalle:
177
  progressive_outputs: bool = False,
178
  is_seamless: bool = False,
179
  temperature: float = 1,
180
- top_k: int = 256,
181
  supercondition_factor: int = 16,
182
  is_verbose: bool = False
183
  ) -> Iterator[FloatTensor]:
@@ -252,7 +252,7 @@ class MinDalle:
252
  )
253
 
254
  with torch.cuda.amp.autocast(dtype=torch.float32):
255
- if ((i + 1) % 32 == 0 and progressive_outputs) or i + 1 == 256:
256
  yield self.image_grid_from_tokens(
257
  image_tokens=image_tokens[1:].T,
258
  is_seamless=is_seamless,
@@ -270,7 +270,7 @@ class MinDalle:
270
  image_stream = self.generate_raw_image_stream(*args, **kwargs)
271
  for image in image_stream:
272
  grid_size = kwargs['grid_size']
273
- image = image.view([grid_size * 256, grid_size, 256, 3])
274
  image = image.transpose(1, 0)
275
  image = image.reshape([grid_size ** 2, 2 ** 8, 2 ** 8, 3])
276
  yield image
 
177
  progressive_outputs: bool = False,
178
  is_seamless: bool = False,
179
  temperature: float = 1,
180
+ top_k: int = 128,
181
  supercondition_factor: int = 16,
182
  is_verbose: bool = False
183
  ) -> Iterator[FloatTensor]:
 
252
  )
253
 
254
  with torch.cuda.amp.autocast(dtype=torch.float32):
255
+ if ((i + 1) % 32 == 0 and progressive_outputs) or i + 1 == 128:
256
  yield self.image_grid_from_tokens(
257
  image_tokens=image_tokens[1:].T,
258
  is_seamless=is_seamless,
 
270
  image_stream = self.generate_raw_image_stream(*args, **kwargs)
271
  for image in image_stream:
272
  grid_size = kwargs['grid_size']
273
+ image = image.view([grid_size * 128, grid_size, 128, 3])
274
  image = image.transpose(1, 0)
275
  image = image.reshape([grid_size ** 2, 2 ** 8, 2 ** 8, 3])
276
  yield image