GeoGenSolve / aglib /meliad /transformer /memory_factory.py
HugoVoxx's picture
Upload 20 files
15bcbe6 verified
raw
history blame
4.33 kB
# Copyright 2022 Google.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Flax modules and functions for using external memory."""
from typing import Any, Optional, Tuple
from absl import logging
from flax import linen
import gin
import jax
from transformer import memory_layer
# Loose type aliases used for annotations in this module.
PRNGKey = Any
# A shape is a tuple of ints of any length; `Tuple[int]` would instead
# describe a tuple holding exactly one int.
Shape = Tuple[int, ...]
Dtype = Any
Array = Any
# Handle to an off-device memory backend; left opaque here.
MemoryResource = Any
class MemoryManager:
  """Manages any external resources that may be required by external memory.

  MemoryManager also functions as a factory, to create Flax modules that will
  read and write to whatever external memory has been configured.
  """

  def __init__(self,
               batch_size: int,
               mode: str,
               num_heads: int,
               key_size: int,
               value_size: int,
               database_size: Optional[int] = None,
               dtype: Dtype = "float32",
               off_device_memory: Optional[MemoryResource] = None):
    """Create a MemoryManager object.

    A MemoryManager configures external memory, and is used as a factory to
    construct flax modules that read or write to the memory.

    Args:
      batch_size: The number of separate documents in a batch.
      mode: e.g. ("train", or "test")
      num_heads: The number of transformer heads.
      key_size: The length of the key vectors.
      value_size: The length of the value vectors.
      database_size: The total number of tokens in the database.  Must be set
        before create_memory_layer() is called for on-device memory.
      dtype: The datatype used for keys and values.
      off_device_memory: An object which manages underlying SCAM memory.
        If None, then the model will use on-device memory.
    """
    self.batch_size = batch_size
    self.mode = mode
    self.num_heads = num_heads
    self.key_size = key_size
    self.value_size = value_size
    self.database_size = database_size
    self.dtype = dtype
    self.off_device_memory = off_device_memory

  def create_memory_layer(self) -> linen.Module:
    """Create a flax Module that implements external memory.

    Returns:
      A memory_layer.BatchedMemory wrapping a memory_layer.MemoryOnTpu that
      holds one dataset per (batch element, head) pair.

    Raises:
      ValueError: If off_device_memory was supplied (off-device memory is not
        supported at this time), or if database_size is None.
    """
    # Off-device memory has no implementation here.  Such a path would use one
    # dataset per head (num_datasets == num_heads) and split queries on the
    # head dimension only (split_dimensions=(-2,)).
    if self.off_device_memory is not None:
      raise ValueError("Off-device memory is not supported at this time.")

    # Validate explicitly rather than with `assert`, which is stripped when
    # Python runs with -O.
    if self.database_size is None:
      raise ValueError(
          "database_size must be specified to create on-device memory.")

    # On-device memory keeps a separate dataset for every document in the
    # batch and every attention head.
    num_datasets = self.batch_size * self.num_heads
    mem_layer = memory_layer.MemoryOnTpu(num_datasets=num_datasets,
                                         key_features=self.key_size,
                                         value_features=self.value_size,
                                         database_size=self.database_size,
                                         dtype=self.dtype)
    # Handle queries of shape [batch_size, seq_len, num_heads, kv_features]
    return memory_layer.BatchedMemory(mem_layer,
                                      split_dimensions=(0, -2))
@gin.configurable
def memory_on_tpu_factory(batch_size: int,
                          mode: str,
                          num_heads: int = gin.REQUIRED,
                          key_size: int = gin.REQUIRED,
                          value_size: int = gin.REQUIRED,
                          database_size: int = gin.REQUIRED,
                          dtype: Dtype = gin.REQUIRED) -> MemoryManager:
  """Build a MemoryManager that implements SCAM memory on device.

  Every argument is forwarded unchanged to MemoryManager, and
  off_device_memory is fixed to None so the manager uses on-device memory.

  Args:
    batch_size: The number of separate documents in a batch.
    mode: e.g. ("train", or "test")
    num_heads: The number of transformer heads.
    key_size: The length of the key vectors.
    value_size: The length of the value vectors.
    database_size: The total number of tokens in the database.
    dtype: The datatype used for keys and values.

  Returns:
    The configured MemoryManager.
  """
  # The parameters defaulting to gin.REQUIRED are expected to be bound through
  # gin configuration rather than passed by the caller.
  manager_settings = dict(
      batch_size=batch_size,
      mode=mode,
      num_heads=num_heads,
      key_size=key_size,
      value_size=value_size,
      database_size=database_size,
      dtype=dtype,
  )
  return MemoryManager(off_device_memory=None, **manager_settings)