Spaces:
Sleeping
Sleeping
# Copyright 2023 DeepMind Technologies Limited | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
# ============================================================================== | |
"""The decoder stack in inference mode.""" | |
from typing import Any, Tuple | |
import gin | |
from transformer import decoder_stack | |
import transformer_layer as tl | |
struct = decoder_stack.struct | |
nn_components = decoder_stack.nn_components | |
position = decoder_stack.position | |
jnp = decoder_stack.jnp | |
attention = decoder_stack.attention | |
DStackWindowState = decoder_stack.DStackWindowState | |
Array = Any | |
TransformerTaskConfig = decoder_stack.TransformerTaskConfig | |
DStackDecoderState = Tuple[tl.DecoderState, ...] | |
class DecoderStackGenerate(decoder_stack.DecoderStack): | |
"""Stack of transformer decoder layers.""" | |
layer_factory = tl.TransformerLayerGenerate | |
def init_decoder_state_vanilla( | |
self, sequence_length: int, start_of_sequence: Array | |
) -> DStackDecoderState: | |
"""Return initial state for autoregressive generation.""" | |
return tuple( | |
[ | |
layer.init_decoder_state_vanilla(sequence_length, start_of_sequence) | |
for layer in self.transformer_layers | |
] | |
) | |