|
from typing import Sequence |
|
|
|
import flax.linen as nn |
|
|
|
|
|
class MLP(nn.Module):
    """Multi-layer perceptron built from a stack of ``nn.Dense`` layers.

    Attributes:
        features: Output width of each dense layer, in order. Every layer
            except the last is followed by a ReLU activation; the last layer
            is a plain linear projection with no activation.
    """

    # Widths of the dense layers; the final entry is the output width.
    features: Sequence[int]

    @nn.compact
    def __call__(self, x):
        """Run *x* through the hidden ReLU layers and the final linear layer.

        Note: an empty ``features`` raises ``IndexError`` when the final layer
        width is looked up.
        """
        hidden_widths = self.features[:-1]
        for width in hidden_widths:
            # Layers are constructed in the same order as before so that the
            # auto-assigned submodule names (and thus parameter trees) match.
            hidden_layer = nn.Dense(width)
            x = nn.relu(hidden_layer(x))
        output_layer = nn.Dense(self.features[-1])
        return output_layer(x)
|
|
|
|
|
def assertEqual(actual, expected, msg, first="Got", second="Expected"):
    """Raise ``ValueError`` when *actual* and *expected* compare unequal.

    Args:
        actual: The value that was produced.
        expected: The value that was required.
        msg: Prefix for the error message.
        first: Label printed before the actual value.
        second: Label printed before the expected value.

    Raises:
        ValueError: If ``actual != expected`` is truthy. The message is *msg*
            followed by both labelled values, each wrapped in double quotes.
    """
    mismatch = actual != expected
    if not mismatch:
        return
    detail = f' {first}: "{actual}" {second}: "{expected}"'
    raise ValueError(msg + detail)
|
|
|
|
|
def assertIn(actual, expected, msg, first="Got", second="Expected one of"):
    """Raise ``ValueError`` unless *actual* is a member of *expected*.

    Args:
        actual: The value that was produced.
        expected: A container of allowed values (any object supporting ``in``).
        msg: Prefix for the error message.
        first: Label printed before the actual value.
        second: Label printed before the collection of allowed values.

    Raises:
        ValueError: If ``actual not in expected``. The message is *msg*
            followed by the quoted actual value and the unquoted container.
    """
    found = actual in expected
    if found:
        return
    detail = f' {first}: "{actual}" {second}: {expected}'
    raise ValueError(msg + detail)
|
|