# Copyright 2022 The T5X Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
r"""Precompile and generates HLO from TPU metadata backend. | |
TPU Metadata backend is a TPU backend without real TPU devices while supporting | |
any TPU topologies, to allow work that doesn't require real TPUs to run as if | |
it is, e.g., compiling/lowering a HLO graph with the backend. | |
Ideally, the precompile defaults to cpu backend for default device array | |
placement since metadata backend does not have memory allocation. | |
The pjit function is pinned to use available TPU Metadata backend, for getting | |
a proper lowering under TPU mesh. | |
""" | |

import os
from typing import Iterator, Optional

import jax
from jax import random
import numpy as np
import t5.data.mixtures  # pylint:disable=unused-import
from t5x import models
from t5x import partitioning
from t5x import trainer as trainer_lib
from t5x import utils
import tensorflow as tf
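
# Illustrative usage (a sketch; the names below are hypothetical placeholders
# for objects typically constructed from gin configs, not part of this module):
#
#   precompile(
#       model=my_model,
#       train_dataset_cfg=my_train_cfg,
#       partitioner=my_partitioner,
#       model_dir='/tmp/precompiled_hlo',
#       random_seed=0)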


def precompile(*,
               model: models.BaseTransformerModel,
               train_dataset_cfg: utils.DatasetConfig,
               partitioner: partitioning.BasePartitioner,
               model_dir: str,
               random_seed: Optional[int],
               get_dataset_fn: utils.GetDatasetCallable = utils.get_dataset):
"""Compiles and dump the HLO to model dir, with HLO text dumps.""" | |
rng = random.PRNGKey(random_seed or 42) | |
_, trainer_rng = random.split(rng, 2) | |

  # TODO(hthu): Find a better way of getting dataset shapes instead of actually
  # reading the dataset and iterating over it.
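  # The data layout determines which shard of the input data this process
  # reads: shard `ds_shard_id` out of `num_ds_shards` total shards.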
  data_layout = partitioner.get_data_layout(train_dataset_cfg.batch_size)
  ds_shard_id = data_layout.shard_id
  num_ds_shards = data_layout.num_shards

  def _verify_matching_vocabs(cfg: utils.DatasetConfig):
    ds_vocabs = utils.get_vocabulary(cfg)
    if (ds_vocabs[0] != model.input_vocabulary or
        ds_vocabs[1] != model.output_vocabulary):
      raise ValueError(f'Model and Task vocabularies do not match:\n'
                       f'  task={cfg.mixture_or_task_name}\n'
                       f'  ds_vocabs=({ds_vocabs[0]}, {ds_vocabs[1]})\n'
                       f'  model.input_vocabulary={model.input_vocabulary}\n'
                       f'  model.output_vocabulary={model.output_vocabulary}\n')

  _verify_matching_vocabs(train_dataset_cfg)

  train_ds = get_dataset_fn(train_dataset_cfg, ds_shard_id, num_ds_shards,
                            model.FEATURE_CONVERTER_CLS)

  # Need to use full batch size.
  input_shapes = {
      k: (data_layout.batch_size, *v.shape[1:])
      for k, v in train_ds.element_spec.items()
  }
  input_types = {
      k: v.dtype.as_numpy_dtype() for k, v in train_ds.element_spec.items()
  }
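
  # Materialize one concrete batch to drive lowering; mapping
  # jax.tree_map(np.array, ...) over the iterator converts each batch of TF
  # tensors into host numpy arrays.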
  checkpointable_train_iter = iter(train_ds)
  train_iter: Iterator[trainer_lib.BatchType] = map(
      lambda x: jax.tree_map(np.array, x), checkpointable_train_iter)
  batch = next(train_iter)

  # Compiling does not care about loading real weights.
  train_state_initializer = utils.TrainStateInitializer(
      optimizer_def=model.optimizer_def,
      init_fn=model.get_initial_variables,
      input_shapes=input_shapes,
      input_types=input_types,
      partitioner=partitioner)
  train_state_shape = train_state_initializer.global_train_state_shape
  train_state_axes = train_state_initializer.train_state_axes
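
  # `train_state_shape` carries only abstract shapes/dtypes of the train state
  # and `train_state_axes` the matching partition specs; no parameters are
  # materialized for compilation. The fixed learning rate below is a dummy
  # value, since only the lowered graph matters, not the step's numerics.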

  def train_step(train_state, batch):
    return trainer_lib.train_with_lr(
        train_state,
        batch,
        learning_rate=1e-3,
        dropout_rng=trainer_rng,
        model=model,
        num_microbatches=None,
        weight_metrics_computer=None)

  partitioned_step = partitioner.partition(
      train_step,
      in_axis_resources=(train_state_axes, partitioning.PartitionSpec('data',)),
      out_axis_resources=(train_state_axes, None),
      donate_argnums=(0,))
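  # donate_argnums=(0,) marks the input train state as donated, so its buffers
  # may be reused for the output train state in the compiled step.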

  # PartitionedTrainCallable has lower() defined but isn't exposed in pytype.
  # TODO(hthu): Explicitly expose the lower() interface.
  # pytype: disable=attribute-error
  lowered = partitioned_step.lower(train_state_shape, batch)
  # pytype: enable=attribute-error
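
  # Two HLO artifacts are written below: the unoptimized module from lowering
  # and the post-optimization module from XLA compilation, each as serialized
  # HloModuleProto bytes.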
  # TODO(hthu): Make this a proper library without writing files by default.
  tf.io.gfile.makedirs(model_dir)
  with tf.io.gfile.GFile(
      os.path.join(model_dir, 'lowered_hlo_pre_optimization'), 'w') as f:
    f.write(lowered.compiler_ir(dialect='hlo').as_serialized_hlo_module_proto())
  compiled = lowered.compile()
  output_path = os.path.join(model_dir, 'lowered_hlo_post_optimization')
  with tf.io.gfile.GFile(output_path, 'w') as f:
    f.write(compiled.compiler_ir()[0].as_serialized_hlo_module_proto())
  with tf.io.gfile.GFile(os.path.join(model_dir, 'assignment'), 'wb') as f:
    np.save(f, partitioner.mesh.device_ids)
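
  # The saved device assignment can be reloaded later, e.g.:
  #   with tf.io.gfile.GFile(os.path.join(model_dir, 'assignment'), 'rb') as f:
  #     device_ids = np.load(f)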