boris commited on
Commit
772415c
·
1 Parent(s): 5c84978

feat: allow abstract_init

Browse files
dalle_mini/model/modeling.py CHANGED
@@ -16,7 +16,7 @@
16
 
17
  import math
18
  from functools import partial
19
- from typing import Optional
20
 
21
  import flax.linen as nn
22
  import jax
@@ -298,10 +298,51 @@ class FlaxBartPreTrainedModel(FlaxBartPreTrainedModel):
298
  Edits:
299
  - added num_params property
300
  - config_class replaced to DalleBartConfig
 
301
  """
302
 
303
  config_class = DalleBartConfig
304
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
305
  @property
306
  def num_params(self):
307
  num_params = jax.tree_map(
 
16
 
17
  import math
18
  from functools import partial
19
+ from typing import Optional, Tuple
20
 
21
  import flax.linen as nn
22
  import jax
 
298
  Edits:
299
  - added num_params property
300
  - config_class replaced to DalleBartConfig
301
+ - __init__ accepts abstract_init which does uses parameter shape to initialize the model
302
  """
303
 
304
  config_class = DalleBartConfig
305
 
306
+ def __init__(
307
+ self,
308
+ config: DalleBartConfig,
309
+ input_shape: Tuple[int] = (1, 1),
310
+ seed: int = 0,
311
+ dtype: jnp.dtype = jnp.float32,
312
+ abstract_init: bool = False,
313
+ **kwargs,
314
+ ):
315
+ module = self.module_class(config=config, dtype=dtype, **kwargs)
316
+
317
+ # adapted from HuggingFace FlaxPreTrainedModel
318
+ if config is None:
319
+ raise ValueError("config cannot be None")
320
+
321
+ if module is None:
322
+ raise ValueError("module cannot be None")
323
+
324
+ # Those are private to be exposed as typed property on derived classes.
325
+ self._config = config
326
+ self._module = module
327
+
328
+ # Those are public as their type is generic to every derived classes.
329
+ self.key = PRNGKey(seed)
330
+ self.dtype = dtype
331
+
332
+ # randomly initialized parameters
333
+ if abstract_init:
334
+ # init the model weights only abstractly, eval_shape will return a pytree
335
+ # with the structure as weights but without any actual values, this will just contain
336
+ # the shape information. Weights need to be loaded later.
337
+ init_fn = partial(self.init_weights, input_shape=input_shape)
338
+ random_params = jax.eval_shape(init_fn, self.key)
339
+ else:
340
+ random_params = self.init_weights(self.key, input_shape)
341
+
342
+ # save required_params as set
343
+ self._required_params = set(flatten_dict(unfreeze(random_params)).keys())
344
+ self.params = random_params
345
+
346
  @property
347
  def num_params(self):
348
  num_params = jax.tree_map(
tools/train/train.py CHANGED
@@ -434,7 +434,9 @@ def main():
434
  artifact_dir = artifact.download()
435
 
436
  # load model
437
- model = DalleBart.from_pretrained(artifact_dir)
 
 
438
  # avoid OOM on TPU: see https://github.com/google/flax/issues/1658
439
  print(model.params)
440
 
@@ -458,6 +460,7 @@ def main():
458
  config=config,
459
  seed=training_args.seed_model,
460
  dtype=getattr(jnp, model_args.dtype),
 
461
  )
462
  # avoid OOM on TPU: see https://github.com/google/flax/issues/1658
463
  print(model.params)
 
434
  artifact_dir = artifact.download()
435
 
436
  # load model
437
+ model = DalleBart.from_pretrained(
438
+ artifact_dir, dtype=getattr(jnp, model_args.dtype), abstract_init=True
439
+ )
440
  # avoid OOM on TPU: see https://github.com/google/flax/issues/1658
441
  print(model.params)
442
 
 
460
  config=config,
461
  seed=training_args.seed_model,
462
  dtype=getattr(jnp, model_args.dtype),
463
+ abstract_init=True,
464
  )
465
  # avoid OOM on TPU: see https://github.com/google/flax/issues/1658
466
  print(model.params)