Spaces:
Sleeping
Sleeping
# Copyright 2022 DeepMind Technologies Limited. All Rights Reserved. | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
# ============================================================================== | |
"""Integration tests for the full RASP -> transformer compilation.""" | |
from absl.testing import absltest | |
from absl.testing import parameterized | |
import jax | |
import numpy as np | |
from tracr.compiler import compiling | |
from tracr.compiler import lib | |
from tracr.compiler import test_cases | |
from tracr.craft import tests_common | |
from tracr.rasp import rasp | |
# Special tokens handed to the compiler. The long, test-specific strings make
# an accidental clash with a real vocabulary value unlikely.
_COMPILER_BOS = "rasp_to_transformer_integration_test_BOS"
_COMPILER_PAD = "rasp_to_transformer_integration_test_PAD"

# Force float32 matmul precision. On TPU, JAX otherwise uses a lower-precision
# default (bfloat16 passes), which is too coarse for the tolerances used by
# the numerical comparisons in this file.
jax.config.update("jax_default_matmul_precision", "float32")
class CompilerIntegrationTest(tests_common.VectorFnTestCase):
  """Checks that compiled transformers reproduce the RASP program outputs.

  NOTE(review): every test method below takes arguments beyond `self`
  (`program`, `vocab`, ...), but no parameterization decorators are visible in
  this copy of the file. Presumably they were supplied through
  `absl.testing.parameterized` -- confirm and restore them, otherwise the test
  runner cannot call these methods.
  """

  def assertSequenceEqualWhenExpectedIsNotNone(self, actual_seq, expected_seq):
    """Fails if any non-None expected entry differs from the actual entry.

    Positions where `expected_seq` holds None are skipped. The sequences are
    walked with `zip`, so only the common prefix is compared.

    Args:
      actual_seq: Sequence of values produced by the model.
      expected_seq: Sequence of expected values; None means "don't care".
    """
    for actual, expected in zip(actual_seq, expected_seq):
      if expected is not None and actual != expected:
        self.fail(f"{actual_seq} does not match (ignoring Nones) "
                  f"{expected_seq=}")

  def _assert_decoded_matches(self, decoded, expected_output, *, ignore_nones):
    """Compares decoded model output against the expected output.

    Numerical outputs (detected from the type of the first expected element)
    are compared with a tolerance; everything else is compared exactly.

    Args:
      decoded: Decoded model output with the BOS position already removed.
      expected_output: Expected values, one per input position.
      ignore_nones: If True, None entries in `expected_output` are skipped in
        the exact comparison; if False, the sequences must be equal.
    """
    if isinstance(expected_output[0], (int, float)):
      np.testing.assert_allclose(
          decoded, expected_output, atol=1e-7, rtol=0.005)
    elif ignore_nones:
      self.assertSequenceEqualWhenExpectedIsNotNone(decoded, expected_output)
    else:
      self.assertEqual(decoded, expected_output)

  def test_rasp_program_and_transformer_produce_same_output(self, program):
    """Compiled model and RASP program agree on every single-token input."""
    vocab = {0, 1, 2}
    max_seq_len = 3
    assembled_model = compiling.compile_rasp_to_model(
        program, vocab, max_seq_len, compiler_bos=_COMPILER_BOS)

    # Evaluate everything first so that a crash in either implementation
    # surfaces before any sub-test is reported.
    test_outputs = {}
    rasp_outputs = {}
    for val in vocab:
      model_output = assembled_model.apply([_COMPILER_BOS, val])
      # Position 0 holds the BOS token; position 1 is the actual input.
      test_outputs[val] = model_output.decoded[1]
      rasp_outputs[val] = program([val])[0]

    for val in sorted(vocab):
      with self.subTest(val=val):
        self.assertEqual(test_outputs[val], rasp_outputs[val])

  def test_compiled_models_produce_expected_output(self, program, vocab,
                                                   test_input, expected_output,
                                                   max_seq_len, **kwargs):
    """Compiled (non-causal) model yields the expected output sequence."""
    del kwargs
    assembled_model = compiling.compile_rasp_to_model(
        program, vocab, max_seq_len, compiler_bos=_COMPILER_BOS)
    test_output = assembled_model.apply([_COMPILER_BOS] + test_input)
    # Drop position 0, which corresponds to the prepended BOS token.
    self._assert_decoded_matches(
        test_output.decoded[1:], expected_output, ignore_nones=True)

  def test_compiled_causal_models_produce_expected_output(
      self, program, vocab, test_input, expected_output, max_seq_len, **kwargs):
    """Model compiled with `causal=True` yields the expected output."""
    del kwargs
    assembled_model = compiling.compile_rasp_to_model(
        program,
        vocab,
        max_seq_len,
        causal=True,
        compiler_bos=_COMPILER_BOS,
        compiler_pad=_COMPILER_PAD)
    test_output = assembled_model.apply([_COMPILER_BOS] + test_input)
    # Drop position 0, which corresponds to the prepended BOS token.
    self._assert_decoded_matches(
        test_output.decoded[1:], expected_output, ignore_nones=True)

  def test_compiled_models_produce_expected_output_with_padding(
      self, program, vocab, test_input, expected_output, max_seq_len, **kwargs):
    """Padding the input up to `max_seq_len` leaves the real outputs intact."""
    del kwargs
    assembled_model = compiling.compile_rasp_to_model(
        program,
        vocab,
        max_seq_len,
        compiler_bos=_COMPILER_BOS,
        compiler_pad=_COMPILER_PAD)

    pad_len = max_seq_len - len(test_input)
    padded_input = [_COMPILER_BOS] + test_input + [_COMPILER_PAD] * pad_len
    output = assembled_model.apply(padded_input).decoded

    # Use an explicit end index rather than `-pad_len`: with pad_len == 0 the
    # slice `[1:-0]` would be empty instead of covering the whole output.
    output_stripped = output[1:len(output) - pad_len]

    self.assertEqual(output[0], _COMPILER_BOS)
    # Unlike the unpadded tests, the exact comparison here does not skip None
    # entries in `expected_output`.
    self._assert_decoded_matches(
        output_stripped, expected_output, ignore_nones=False)
# Run the tests through absl's test runner when executed as a script.
if __name__ == "__main__":
  absltest.main()