Spaces:
Runtime error
Runtime error
feat: allow abstract_init
Browse files- dalle_mini/model/modeling.py +42 -1
- tools/train/train.py +4 -1
dalle_mini/model/modeling.py
CHANGED
@@ -16,7 +16,7 @@
|
|
16 |
|
17 |
import math
|
18 |
from functools import partial
|
19 |
-
from typing import Optional
|
20 |
|
21 |
import flax.linen as nn
|
22 |
import jax
|
@@ -298,10 +298,51 @@ class FlaxBartPreTrainedModel(FlaxBartPreTrainedModel):
|
|
298 |
Edits:
|
299 |
- added num_params property
|
300 |
- config_class replaced to DalleBartConfig
|
|
|
301 |
"""
|
302 |
|
303 |
config_class = DalleBartConfig
|
304 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
305 |
@property
|
306 |
def num_params(self):
|
307 |
num_params = jax.tree_map(
|
|
|
16 |
|
17 |
import math
|
18 |
from functools import partial
|
19 |
+
from typing import Optional, Tuple
|
20 |
|
21 |
import flax.linen as nn
|
22 |
import jax
|
|
|
298 |
Edits:
|
299 |
- added num_params property
|
300 |
- config_class replaced to DalleBartConfig
|
301 |
+
- __init__ accepts abstract_init which does uses parameter shape to initialize the model
|
302 |
"""
|
303 |
|
304 |
config_class = DalleBartConfig
|
305 |
|
306 |
+
def __init__(
|
307 |
+
self,
|
308 |
+
config: DalleBartConfig,
|
309 |
+
input_shape: Tuple[int] = (1, 1),
|
310 |
+
seed: int = 0,
|
311 |
+
dtype: jnp.dtype = jnp.float32,
|
312 |
+
abstract_init: bool = False,
|
313 |
+
**kwargs,
|
314 |
+
):
|
315 |
+
module = self.module_class(config=config, dtype=dtype, **kwargs)
|
316 |
+
|
317 |
+
# adapted from HuggingFace FlaxPreTrainedModel
|
318 |
+
if config is None:
|
319 |
+
raise ValueError("config cannot be None")
|
320 |
+
|
321 |
+
if module is None:
|
322 |
+
raise ValueError("module cannot be None")
|
323 |
+
|
324 |
+
# Those are private to be exposed as typed property on derived classes.
|
325 |
+
self._config = config
|
326 |
+
self._module = module
|
327 |
+
|
328 |
+
# Those are public as their type is generic to every derived classes.
|
329 |
+
self.key = PRNGKey(seed)
|
330 |
+
self.dtype = dtype
|
331 |
+
|
332 |
+
# randomly initialized parameters
|
333 |
+
if abstract_init:
|
334 |
+
# init the model weights only abstractly, eval_shape will return a pytree
|
335 |
+
# with the structure as weights but without any actual values, this will just contain
|
336 |
+
# the shape information. Weights need to be loaded later.
|
337 |
+
init_fn = partial(self.init_weights, input_shape=input_shape)
|
338 |
+
random_params = jax.eval_shape(init_fn, self.key)
|
339 |
+
else:
|
340 |
+
random_params = self.init_weights(self.key, input_shape)
|
341 |
+
|
342 |
+
# save required_params as set
|
343 |
+
self._required_params = set(flatten_dict(unfreeze(random_params)).keys())
|
344 |
+
self.params = random_params
|
345 |
+
|
346 |
@property
|
347 |
def num_params(self):
|
348 |
num_params = jax.tree_map(
|
tools/train/train.py
CHANGED
@@ -434,7 +434,9 @@ def main():
|
|
434 |
artifact_dir = artifact.download()
|
435 |
|
436 |
# load model
|
437 |
-
model = DalleBart.from_pretrained(
|
|
|
|
|
438 |
# avoid OOM on TPU: see https://github.com/google/flax/issues/1658
|
439 |
print(model.params)
|
440 |
|
@@ -458,6 +460,7 @@ def main():
|
|
458 |
config=config,
|
459 |
seed=training_args.seed_model,
|
460 |
dtype=getattr(jnp, model_args.dtype),
|
|
|
461 |
)
|
462 |
# avoid OOM on TPU: see https://github.com/google/flax/issues/1658
|
463 |
print(model.params)
|
|
|
434 |
artifact_dir = artifact.download()
|
435 |
|
436 |
# load model
|
437 |
+
model = DalleBart.from_pretrained(
|
438 |
+
artifact_dir, dtype=getattr(jnp, model_args.dtype), abstract_init=True
|
439 |
+
)
|
440 |
# avoid OOM on TPU: see https://github.com/google/flax/issues/1658
|
441 |
print(model.params)
|
442 |
|
|
|
460 |
config=config,
|
461 |
seed=training_args.seed_model,
|
462 |
dtype=getattr(jnp, model_args.dtype),
|
463 |
+
abstract_init=True,
|
464 |
)
|
465 |
# avoid OOM on TPU: see https://github.com/google/flax/issues/1658
|
466 |
print(model.params)
|