# Copyright 2022 Google.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Sequence to sequence model."""

from typing import Any, Callable, Dict, Tuple

from absl import logging
from flax import linen as nn
from flax.training import common_utils
import gin
import jax
import jax.numpy as jnp
import metrics_summary
from transformer import decoder_stack
from transformer import metric_utils
from transformer import text_dataset
import numpy as np
import seqio


Array = jnp.ndarray
MetricsSummary = metrics_summary.MetricsSummary


# TODO(mrabe): Remove this function and find a better way to turn text metrics
# into text on tensorboard.
def process_summaries(vocab: seqio.Vocabulary,
                      met_summary: MetricsSummary,
                      mode: str) -> MetricsSummary:
  """Compute some additional summaries, and convert tokens to text.

  Args:
    vocab: The vocabulary to detokenize generated text.
    met_summary: The summary object to process.
    mode: The mode of the summary (e.g. "test", "train").

  Returns:
    The modified summary dictionary.
  """
  mdict = met_summary.current_metric_dict()

  # Calculate perplexity from the average nats_per_token over all replicas.
  # This has to be done here, because the perplexities themselves can't be
  # averaged in the usual way.
  if "nats_per_token" in mdict:
    nats_per_token = mdict["nats_per_token"].to_value()
    met_summary.add({"perplexity": np.exp(nats_per_token)})
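    # As a worked example (numbers are illustrative only): an average of
    # 3.0 nats/token corresponds to a perplexity of np.exp(3.0) ~= 20.1.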

  if mode == "generate" and "gen_tokens" in mdict:
    # Convert output tokens to example output text.
    # Write text to both the summary, and pretty-print to the log file.
    gen_toks = mdict["gen_tokens"].to_value()
    if np.ndim(gen_toks) != 2:
      raise ValueError("Unsupported shape for gen_tokens: %s" % gen_toks.shape)
    ntoks = gen_toks.shape[-1]
    gen_text = text_dataset.decode_tokens(gen_toks, vocab, max_length=ntoks)
    logging.info("Generated text = %s", gen_text)
    met_summary.add_text({"gen_text": gen_text})
    del mdict["gen_tokens"]  # Otherwise it will turn into a histogram.
  return met_summary


def process_summaries_function(vocab: seqio.Vocabulary) -> Callable[
    [MetricsSummary, str], MetricsSummary]:
  """Return a function that processes summaries with the given vocabulary."""
  # For use with training_loop.process_summaries_function.
  def process_fn(met_summary: MetricsSummary, mode: str):
    return process_summaries(vocab, met_summary, mode)
  return process_fn
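
# A rough usage sketch (the surrounding training-loop wiring is an assumption,
# not part of this module): the returned closure is handed to the training
# loop, which calls it on each summary before writing it out, e.g.
#
#   vocab = ...  # a seqio.Vocabulary matching the dataset
#   summarize_fn = process_summaries_function(vocab)
#   summary = summarize_fn(summary, "generate")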


class DecoderOnlyLanguageModel(nn.Module):
  """Decoder only language modeling."""

  mode: str
  task_config: decoder_stack.TransformerTaskConfig = gin.REQUIRED
  decoder_factory: Callable[[], Any] = gin.REQUIRED
  sample_method: str = "sample"  # Can be {"sample", "greedy"}.
  output_token_losses: bool = False

  def get_fake_input(self):
    """Returns a fake input of the appropriate shape for initialization."""
    b = self.task_config.batch_size
    fake_input_dict = {
        "targets": jnp.ones([b, self.task_config.sequence_length],
                            dtype=jnp.int32),
        "start_of_sequence": jnp.ones([b], dtype=jnp.bool_),
        "epoch": jnp.ones([b], dtype=jnp.int32),
    }
    if text_dataset.get_loss_mask_tokens(split=self.mode) != (None, None):
      # The loss mask is only added to the dummy input when loss-mask tokens
      # are actually configured, since carrying it around can cause a slowdown
      # during evaluation and perhaps inference.
      fake_input_dict["loss_mask"] = jnp.ones(
          [b, self.task_config.sequence_length], dtype=jnp.bool_)
    return fake_input_dict
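
  # A rough sketch of how the fake input is typically consumed (the training
  # loop and the exact rng streams are assumptions, not part of this file):
  #
  #   model = DecoderOnlyLanguageModel(mode="train", task_config=..., ...)
  #   rngs = {"params": jax.random.PRNGKey(0)}  # plus e.g. "dropout" if needed
  #   variables = model.init(rngs, model.get_fake_input())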

  def metrics_summary_operations(self, aggregate_over: str) -> Dict[str, str]:
    """Summary operation to use for recorded metrics."""
    metric_ops = {
        "loss": "mean",
        "nats_per_token": "mean",
        "bits_per_token": "mean",
        "bits_per_char": "mean",
        "accuracy": "mean",
        "num_tokens": "mean",
        "num_chars_per_device": "mean",
        "num_chars_per_batch": "mean",
        "nonzero_tokens": "mean",
        "num_tokens_per_device": "mean",
        "num_tokens_per_batch": "mean",
        "epoch": "mean",
    }
    if aggregate_over == "steps":
      return metric_ops
    elif aggregate_over == "devices":
      # Ensure that statistics that refer to the total batch size stay constant
      # as TPU topologies change. For those we have to sum over devices, but
      # compute the mean over steps.
      metric_ops.update({
          "num_tokens_per_batch": "sum",
          "num_chars_per_batch": "sum",
          "loss": "sum"})
      return metric_ops
    else:
      raise ValueError("Don't know how to aggregate over: %s" % aggregate_over)
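
  # Worked example (numbers are illustrative only): with 8 devices that each
  # see 512 unmasked tokens per step, "num_tokens_per_device" averages to 512,
  # while "num_tokens_per_batch" sums to 8 * 512 = 4096 per step. If the same
  # global batch is later split over 16 devices (256 tokens each), the sum is
  # still 4096, so the per-batch statistic is comparable across topologies.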

  def setup(self):
    self.decoder = self.decoder_factory(mode=self.mode,
                                        task_config=self.task_config)  # pytype: disable=wrong-keyword-args  # trace-all-classes

  def __call__(self, inputs: ...):
    task_config = self.task_config

    input_tokens = inputs["targets"]                 # [b, seq_len]
    start_of_sequence = inputs["start_of_sequence"]  # [b]
    epochs = inputs["epoch"]                         # [b]
    if "loss_mask" in inputs:
      loss_mask = inputs["loss_mask"]                # [b, seq_len]
    else:
      loss_mask = jnp.ones((1, 1), dtype=jnp.bool_)

    input_tokens = jnp.asarray(input_tokens)
    assert input_tokens.ndim == 2
    assert input_tokens.shape[0] == task_config.batch_size
    assert input_tokens.shape[1] == task_config.sequence_length
    assert start_of_sequence.shape[0] == task_config.batch_size

    # Sanity check to avoid out-of-bounds on token lookup.
    input_tokens = input_tokens % task_config.vocab_size

    logging.info("langmodel: Compiling model for mode %s", self.mode)
    logging.info("langmodel: input_tokens = %r", input_tokens)
    logging.info("langmodel: start_of_sequence = %r", start_of_sequence)
    logging.info("langmodel: epochs = %r", epochs)

    # The target outputs are the next character in each sequence.
    # Shift tokens left and pad with a zero at the end.
    # TODO(delesley): We don't predict the first token of each sequence.
    target_tokens = jnp.pad(input_tokens[:, 1:], [(0, 0), (0, 1)])
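    # For example (values are illustrative), an input row [12, 7, 5, 9] yields
    # target row [7, 5, 9, 0]: position i is trained to predict token i + 1,
    # and the trailing 0 is masked out of the loss below.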
    logging.info("langmodel: target_tokens = %r", target_tokens)

    # Invoke the decoder stack.
    # The decoder will return pre-softmax logits for the predicted targets.
    (logits, _, d_metrics) = self.decoder(input_tokens=input_tokens,
                                          target_tokens=target_tokens,
                                          start_of_sequence=start_of_sequence)

    # Softmax cross-entropy loss on target tokens.
    logits = nn.log_softmax(logits, axis=-1)  # (b, seq_len, vocab_size)
    logging.info("langmodel: logits = %r", logits)
    soft_targets = common_utils.onehot(target_tokens, task_config.vocab_size)
    logging.info("langmodel: soft_targets = %r", soft_targets)
    losses = -jnp.sum(soft_targets * logits, axis=-1)  # (b, seq_len)
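    # Since soft_targets is one-hot, the sum above just selects the
    # log-probability of the true token at each position. An equivalent
    # formulation (a sketch only; not what this module uses) would be:
    #
    #   losses = -jnp.take_along_axis(
    #       logits, target_tokens[..., None], axis=-1).squeeze(-1)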
    logging.info("langmodel: losses = %r", losses)

    # Don't predict null tokens which are past the end-of-sequence.
    # Also don't predict the 0 at the end of the sequence.
    # TODO(delesley): Predict the final end-of-sequence marker.
    loss_mask = jnp.logical_and(loss_mask, input_tokens > 0)
    loss_mask = jnp.logical_and(loss_mask, target_tokens > 0)
    logging.info("langmodel: loss_mask = %r", loss_mask)

    losses = jnp.where(loss_mask, losses, 0.0)  # (batch_size, seq_len)
    loss = jnp.sum(losses)            # total loss on device
    token_count = jnp.sum(loss_mask)  # tokens on device
    token_count_nz = token_count + 1.0e-6
    loss_per_token = loss / token_count_nz
    bits_per_token = loss_per_token * 1.442695  # 1 / ln(2): convert nats to bits.
    accuracy = metric_utils.compute_accuracy_sum(logits, target_tokens,
                                                 loss_mask)
    accuracy = accuracy / token_count_nz  # Fraction of tokens predicted correctly.
    epoch = jnp.mean(epochs)

    if self.mode == "generate" and self.decoder.supports_generate():
      # Generate example text.
      logging.info("lang_model: text inference.")
      gen_tokens = self.generate(inputs, task_config.sequence_length)
      # Return generated text, along with visualizations and histograms.
      metrics = {"gen_tokens": gen_tokens, **d_metrics}
      return (loss, metrics)

    # Just return metrics related to the loss.
    metrics = {
        "loss": loss,  # will be summed over devices
        "nats_per_token": (loss_per_token, token_count),
        "bits_per_token": (bits_per_token, token_count),
        "accuracy": (accuracy, token_count),
        "num_tokens_per_device": token_count,
        "num_tokens_per_batch": token_count,  # will be summed over devices
        "epoch": epoch,
    }

    # Compute bits per character if we have the number of characters.
    if "num_chars" in inputs:
      num_chars = jnp.sum(inputs["num_chars"])
      bits_per_char = loss / (num_chars + 1e-6) * 1.442695
      metrics["num_chars_per_device"] = num_chars
      metrics["num_chars_per_batch"] = num_chars  # will be summed over devices
      metrics["bits_per_char"] = (bits_per_char, num_chars)

    # Provided to make sure that the data pipeline and the model agree
    # on the number of tokens that contribute to the loss.
    if "nonzero_tokens" in inputs:
      nonzero_tokens = jnp.sum(inputs["nonzero_tokens"])
      metrics["nonzero_tokens"] = nonzero_tokens

    if self.output_token_losses:
      metrics["token_losses"] = losses
    return (loss, metrics)

  def generate(self, inputs: ..., sequence_length: int) -> Array:
    """Generate an output sequence.

    Args:
      inputs: the same as the argument to __call__.
      sequence_length: the length of the sequence to generate.

    Returns:
      An array of generated tokens of shape (batch_size, sequence_length).
    """
    # TODO(delesley): Add support for passing the prefix as an argument.
    # TODO(delesley): Add support for temperature, gumbel softmax, beam search.
    batch_size = self.task_config.batch_size
    input_tokens = inputs["targets"]                 # [b, seq_len]
    start_of_sequence = inputs["start_of_sequence"]  # [b]

    # Initialize decoder.
    dstate = self.decoder.init_decoder_state(sequence_length,
                                             start_of_sequence)

    # TODO(delesley): Handle start-of-sequence in a better way.
    # There is no special token for start of sequence, so we grab the first
    # one from the ground-truth input data.
    first_token = input_tokens[:, 0:1]
    no_start_of_seq = jnp.array([False] * batch_size, dtype=jnp.bool_)
    sample_method = self.sample_method
    sample_prng = self.make_rng("sample")

    # Autoregressive decoder function: generate one token per step, feeding
    # each output token back in as the input for the next step. Sampling is
    # either categorical or greedy, depending on self.sample_method.
    def loop_fn(scan_state: Any, i: Array) -> Tuple[Any, Array]:
      prng = jax.random.fold_in(sample_prng, i)
      (dstate, input_token) = scan_state
      del i
      (logits, dstate, _) = self.decoder(input_tokens=input_token,
                                         target_tokens=None,
                                         start_of_sequence=no_start_of_seq,
                                         decoder_state=dstate)
      if sample_method == "sample":
        logging.info("Using categorical sampling.")
        output_token = jax.random.categorical(prng, logits, axis=-1)
      elif sample_method == "greedy":
        logging.info("Using greedy sampling.")
        output_token = jnp.argmax(logits, axis=-1)
      else:
        raise ValueError(f"Invalid sampling method: {sample_method}")
      logging.info("generate_loop_fn: output_token = %r", output_token)
      return ((dstate, output_token), output_token)

    # Scan over the sequence length.
    iterations = jnp.arange(sequence_length)
    initial_scan_state = (dstate, first_token)
    (_, output_tokens) = jax.lax.scan(loop_fn, initial_scan_state, iterations)
    logging.info("generate: output_tokens = %r", output_tokens)
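
    # jax.lax.scan threads (dstate, previous_token) through loop_fn and stacks
    # the per-step outputs. A rough Python-level equivalent (a sketch of the
    # scan semantics only, not of the compiled behavior):
    #
    #   carry, outputs = initial_scan_state, []
    #   for i in iterations:
    #     carry, out = loop_fn(carry, i)
    #     outputs.append(out)
    #   output_tokens = jnp.stack(outputs)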
    # output_tokens has shape (sequence_length, batch_size, 1).
    assert output_tokens.shape == (sequence_length, batch_size, 1)
    output_tokens = jnp.reshape(
        output_tokens, (sequence_length, self.task_config.batch_size))
    output_tokens = output_tokens.transpose([1, 0])
    return output_tokens
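

# A rough end-to-end usage sketch for generation (parameter loading, the
# gin-configured decoder_factory, and any mutable state collections required
# by the decoder cache are assumptions and are not shown here):
#
#   model = DecoderOnlyLanguageModel(mode="generate", task_config=..., ...)
#   batch = next(test_iterator)  # must contain "targets", "start_of_sequence",
#                                # and "epoch", as in get_fake_input()
#   rngs = {"sample": jax.random.PRNGKey(42)}
#   (loss, metrics) = model.apply(variables, batch, rngs=rngs)
#   gen_tokens = metrics["gen_tokens"]  # (batch_size, sequence_length)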