# Copyright 2022 Google.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Sequence to sequence model."""
from typing import Any, Callable, Dict, Tuple
from absl import logging
from flax import linen as nn
from flax.training import common_utils
import gin
import jax
import jax.numpy as jnp
import metrics_summary
from transformer import decoder_stack
from transformer import metric_utils
from transformer import text_dataset
import numpy as np
import seqio
Array = jnp.ndarray
MetricsSummary = metrics_summary.MetricsSummary
# TODO(mrabe): Remove this function and find a better way to turn text metrics
# into text on tensorboard.
def process_summaries(vocab: seqio.Vocabulary,
met_summary: MetricsSummary,
mode: str) -> MetricsSummary:
"""Compute some additional summaries, and convert tokens to text.
Args:
vocab: The vocabulary to detokenize generated text.
met_summary: The summary object to process.
mode: The mode of the summary (e.g. "test", "train")
Returns:
The modified summary dictionary.
"""
mdict = met_summary.current_metric_dict()
# Calculate perplexity from the average nats_per_token over all replicas.
# This has to be done here, because the perplexities themselves can't be
# averaged in the usual way.
if "nats_per_token" in mdict:
nats_per_token = mdict["nats_per_token"].to_value()
met_summary.add({"perplexity": np.exp(nats_per_token)})
if mode == "generate" and "gen_tokens" in mdict:
# Convert output tokens to example output text.
# Write text to both the summary, and pretty-print to the log file.
gen_toks = mdict["gen_tokens"].to_value()
if np.ndim(gen_toks) != 2:
raise ValueError("Unsupported shape for gen_tokens: %s" % gen_toks.shape)
ntoks = gen_toks.shape[-1]
gen_text = text_dataset.decode_tokens(gen_toks, vocab, max_length=ntoks)
logging.info("Generated text = %s", gen_text)
met_summary.add_text({"gen_text": gen_text})
del mdict["gen_tokens"] # Otherwise it will turn into a histogram.
return met_summary
@gin.configurable
def process_summaries_function(vocab: seqio.Vocabulary) -> Callable[
[MetricsSummary, str], MetricsSummary]:
"""Return a function that processes summaries with the given vocabulary."""
# For use with training_loop.process_summaries_function
def process_fn(met_summary: MetricsSummary, mode: str):
return process_summaries(vocab, met_summary, mode)
return process_fn
@gin.configurable
class DecoderOnlyLanguageModel(nn.Module):
"""Decoder only language modeling."""
mode: str
task_config: decoder_stack.TransformerTaskConfig = gin.REQUIRED
decoder_factory: Callable[[], Any] = gin.REQUIRED
sample_method: str = "sample" # Can be {"sample", "greedy"}
output_token_losses: bool = False
def get_fake_input(self):
"""Returns a fake input for initialization of the appropriate shape."""
b = self.task_config.batch_size
fake_input_dict = {
"targets": jnp.ones([b, self.task_config.sequence_length],
dtype=jnp.int32),
"start_of_sequence": jnp.ones([b], dtype=jnp.bool_),
"epoch": jnp.ones([b], dtype=jnp.int32),
}
if text_dataset.get_loss_mask_tokens(split=self.mode) != (None, None):
      # The loss mask is only added to the dummy input when it will actually be
      # used, since adding it unconditionally can slow down evaluation and
      # possibly inference.
fake_input_dict["loss_mask"] = jnp.ones(
[b, self.task_config.sequence_length], dtype=jnp.bool_)
return fake_input_dict
def metrics_summary_operations(self, aggregate_over: str) -> Dict[str, str]:
"""Summary operation to use for recorded metrics."""
metric_ops = {
"loss": "mean",
"nats_per_token": "mean",
"bits_per_token": "mean",
"bits_per_char": "mean",
"accuracy": "mean",
"num_tokens": "mean",
"num_chars_per_device": "mean",
"num_chars_per_batch": "mean",
"nonzero_tokens": "mean",
"num_tokens_per_device": "mean",
"num_tokens_per_batch": "mean",
"epoch": "mean",
}
if aggregate_over == "steps":
return metric_ops
elif aggregate_over == "devices":
# Ensure that statistics that refer to the total batch size stay constant
# as TPU topologies change. For those we have to sum over devices, but
# compute the mean over steps.
metric_ops.update({
"num_tokens_per_batch": "sum",
"num_chars_per_batch": "sum",
"loss": "sum"})
return metric_ops
else:
raise ValueError("Don't know how to aggregate over: %s" % aggregate_over)
def setup(self):
self.decoder = self.decoder_factory(mode=self.mode,
task_config=self.task_config) # pytype: disable=wrong-keyword-args # trace-all-classes
def __call__(self, inputs: ...):
task_config = self.task_config
input_tokens = inputs["targets"] # [b, seq_len]
start_of_sequence = inputs["start_of_sequence"] # [b]
epochs = inputs["epoch"] # [b]
if "loss_mask" in inputs:
loss_mask = inputs["loss_mask"] # [b, seq_len]
else:
loss_mask = jnp.ones((1, 1), dtype=jnp.bool_)
input_tokens = jnp.asarray(input_tokens)
assert input_tokens.ndim == 2
assert input_tokens.shape[0] == task_config.batch_size
assert input_tokens.shape[1] == task_config.sequence_length
assert start_of_sequence.shape[0] == task_config.batch_size
# Sanity check to avoid out-of-bounds on token lookup.
input_tokens = input_tokens % task_config.vocab_size
logging.info("langmodel: Compiling model for mode %s", self.mode)
logging.info("langmodel: input_tokens = %r", input_tokens)
logging.info("langmodel: start_of_sequece = %r", start_of_sequence)
logging.info("langmodel: epochs = %r", epochs)
# The target outputs are the next character in each sequence.
# Shift tokens left and pad with a zero at the end.
# TODO(delesley): We don't predict the first token of each sequence.
target_tokens = jnp.pad(input_tokens[:, 1:], [(0, 0), (0, 1)])
logging.info("langmodel: target_tokens = %r", target_tokens)
# Invoke the decoder stack.
# The decoder will return pre-softmax logits for the predicted targets.
(logits, _, d_metrics) = self.decoder(input_tokens=input_tokens,
target_tokens=target_tokens,
start_of_sequence=start_of_sequence)
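    # logits has shape (batch_size, seq_len, vocab_size); d_metrics is a dict
    # of additional metrics/visualizations produced by the decoder stack.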
# Softmax cross-entropy loss on target tokens.
logits = nn.log_softmax(logits, axis=-1) # (b, seq_len, vocab_size)
logging.info("langmodel: logits = %r", logits)
soft_targets = common_utils.onehot(target_tokens, task_config.vocab_size)
logging.info("langmodel: soft_targets = %r", soft_targets)
losses = -jnp.sum(soft_targets * logits, axis=-1) # (b, seq_len)
logging.info("langmodel: losses = %r", losses)
# Don't predict null tokens which are past the end-of-sequence.
# Also don't predict the 0 at the end of the sequence.
# TODO(delesley): Predict the final end-of-sequence marker.
loss_mask = jnp.logical_and(
loss_mask,
input_tokens > 0)
loss_mask = jnp.logical_and(
loss_mask,
target_tokens > 0)
logging.info("langmodel: loss_mask = %r", loss_mask)
losses = jnp.where(loss_mask, losses, 0.0) # (batch_size, seq_len)
loss = jnp.sum(losses) # total loss on device
token_count = jnp.sum(loss_mask) # tokens on device
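    # Add a small epsilon so the divisions below are safe even if the loss
    # mask happens to be empty.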
token_count_nz = token_count + 1.0e-6
loss_per_token = loss / token_count_nz
    bits_per_token = loss_per_token * 1.442695  # log2(e) = 1/ln(2): nats to bits.
accuracy = metric_utils.compute_accuracy_sum(logits, target_tokens,
loss_mask)
    accuracy = accuracy / token_count_nz  # Fraction of correctly predicted tokens.
epoch = jnp.mean(epochs)
if self.mode == "generate" and self.decoder.supports_generate():
# Generate example text.
logging.info("lang_model: text inference.")
gen_tokens = self.generate(inputs, task_config.sequence_length)
      # Return generated text, along with visualizations and histograms.
metrics = {"gen_tokens": gen_tokens, **d_metrics}
return (loss, metrics)
# Just return metrics related to the loss.
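    # Tuple-valued entries pair a metric with its token count, which the
    # summary machinery can use as a weight when averaging over steps.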
metrics = {
"loss": loss, # will be summed over devices
"nats_per_token": (loss_per_token, token_count),
"bits_per_token": (bits_per_token, token_count),
"accuracy": (accuracy, token_count),
"num_tokens_per_device": token_count,
"num_tokens_per_batch": token_count, # will be summed over devices
"epoch": epoch,
}
# Compute bits per character if we have the number of characters.
if "num_chars" in inputs:
num_chars = jnp.sum(inputs["num_chars"])
bits_per_char = loss / (num_chars + 1e-6) * 1.442695
metrics["num_chars_per_device"] = num_chars
metrics["num_chars_per_batch"] = num_chars # will be summed over devices
metrics["bits_per_char"] = (bits_per_char, num_chars)
    # Provided to make sure that the data pipeline and the model agree
    # on the number of tokens with a loss.
if "nonzero_tokens" in inputs:
nonzero_tokens = jnp.sum(inputs["nonzero_tokens"])
metrics["nonzero_tokens"] = nonzero_tokens
if self.output_token_losses:
metrics["token_losses"] = losses
return (loss, metrics)
def generate(self, inputs: ..., sequence_length: int) -> Array:
"""Generate an output sequence.
Args:
      inputs: the same as the argument to __call__.
sequence_length: the length of sequence to generate.
Returns:
An array of generated tokens of shape (batch_size, sequence_length).
"""
# TODO(delesley): Add support for passing the prefix as an argument.
# TODO(delesley): Add support for temperature, gumbel softmax, beam search.
batch_size = self.task_config.batch_size
input_tokens = inputs["targets"] # [b,seq_len]
start_of_sequence = inputs["start_of_sequence"] # [b]
# Initialize decoder.
dstate = self.decoder.init_decoder_state(sequence_length,
start_of_sequence)
# TODO(delesley): Handle start-of-sequence in a better way.
# There is no special token for start of sequence, so we grab the first
# one from the ground-truth input data.
first_token = input_tokens[:, 0:1]
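    # first_token has shape (batch_size, 1), matching the single-step inputs
    # that the decoder is fed inside the sampling loop.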
no_start_of_seq = jnp.array([False] * batch_size, dtype=jnp.bool_)
sample_method = self.sample_method
sample_prng = self.make_rng("sample")
    # Autoregressive decoding loop: sample or take the argmax per sample_method.
def loop_fn(scan_state: Any, i: Array) -> Tuple[Any, Array]:
prng = jax.random.fold_in(sample_prng, i)
(dstate, input_token) = scan_state
del i
(logits, dstate, _) = self.decoder(input_tokens=input_token,
target_tokens=None,
start_of_sequence=no_start_of_seq,
decoder_state=dstate)
if sample_method == "sample":
logging.info("Using categorical sampling.")
output_token = jax.random.categorical(prng, logits, axis=-1)
elif sample_method == "greedy":
logging.info("Using greedy sampling.")
output_token = jnp.argmax(logits, axis=-1)
else:
raise ValueError(f"Invalid sampling method: {sample_method}")
logging.info("generate_loop_fn: output_token = %r", output_token)
return ((dstate, output_token), output_token)
# Scan over the sequence length.
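    # lax.scan threads (dstate, token) through loop_fn as the carry and stacks
    # the per-step outputs along a new leading axis.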
iterations = jnp.arange(sequence_length)
initial_scan_state = (dstate, first_token)
(_, output_tokens) = jax.lax.scan(loop_fn, initial_scan_state, iterations)
logging.info("generate: output_tokens = %r", output_tokens)
# Output_tokens has shape (sequence_length, batch_size, 1)
assert output_tokens.shape == (sequence_length, batch_size, 1)
output_tokens = jnp.reshape(
output_tokens, (sequence_length, self.task_config.batch_size))
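    # Reorder to (batch_size, sequence_length) to match the documented
    # return shape.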
output_tokens = output_tokens.transpose([1, 0])
return output_tokens