Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion brainstate/interop/_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,6 +240,8 @@ def bst_get_norm(layer, attr, has_offset):

def bst_set_norm(layer, attr, scale, offset):
"""Write the affine parameters of a normalization layer."""
if scale is None and offset is None:
return
d = {}
if scale is not None:
d['scale'] = scale
Expand Down Expand Up @@ -274,7 +276,12 @@ def bst_get_batchnorm(layer):
def bst_set_batchnorm(layer, scale, offset, running_mean, running_var):
"""Write affine + running statistics of a brainstate ``BatchNorm``."""
if layer.weight is not None:
layer.weight.value = {'scale': scale, 'bias': offset}
d = {}
if scale is not None:
d['scale'] = scale
if offset is not None:
d['bias'] = offset
layer.weight.value = d
if running_mean is not None:
layer.running_mean.value = running_mean
if running_var is not None:
Expand Down
102 changes: 102 additions & 0 deletions brainstate/interop/_conversion_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -669,5 +669,107 @@ def test_missing_dependency_error(self):
lazy_import('a_framework_that_does_not_exist_xyz')


# ===========================================================================
# Edge-case tests for audit-discovered bugs
# ===========================================================================

class NnxNormNoAffineTest(absltest.TestCase):
    """Round-trip flax.nnx norm layers with some or all affine params disabled.

    Each case converts an nnx layer to brainstate with ``interop.from_nnx``,
    converts it back with ``interop.to_nnx``, and checks that both results
    reproduce the source layer's output on the same input. These conversions
    must not crash when a layer has no affine parameters at all.
    """

    @classmethod
    def setUpClass(cls):
        # Skip the whole class when flax.nnx is not installed.
        cls.nnx = pytest.importorskip('flax.nnx')

    def _rngs(self):
        """Return a fresh ``nnx.Rngs`` built from a brainstate random key."""
        return self.nnx.Rngs(brainstate.random.split_key())

    def _check_round_trip(self, src, x, label):
        """Assert that importing and re-exporting ``src`` preserves ``src(x)``."""
        dst = interop.from_nnx(src)
        assert_close(src(x), bst_forward(dst, x), f'{label} import')
        back = interop.to_nnx(dst)
        assert_close(src(x), back(x), f'{label} export')

    def test_layernorm_no_affine_import_export(self):
        x = brainstate.random.randn(4, 6)
        src = self.nnx.LayerNorm(6, use_scale=False, use_bias=False, rngs=self._rngs())
        self._check_round_trip(src, x, 'layernorm no-affine')

    def test_groupnorm_no_affine_import_export(self):
        x = brainstate.random.randn(4, 8)
        src = self.nnx.GroupNorm(8, num_groups=2, use_scale=False, use_bias=False,
                                 rngs=self._rngs())
        self._check_round_trip(src, x, 'groupnorm no-affine')

    def test_layernorm_scale_only(self):
        x = brainstate.random.randn(4, 6)
        src = self.nnx.LayerNorm(6, use_scale=True, use_bias=False, rngs=self._rngs())
        # Overwrite the scale with random values before converting.
        src.scale[...] = brainstate.random.randn(6)
        self._check_round_trip(src, x, 'layernorm scale-only')

    def test_layernorm_bias_only(self):
        x = brainstate.random.randn(4, 6)
        src = self.nnx.LayerNorm(6, use_scale=False, use_bias=True, rngs=self._rngs())
        # Overwrite the bias with random values before converting.
        src.bias[...] = brainstate.random.randn(6)
        self._check_round_trip(src, x, 'layernorm bias-only')


class EquinoxNormNoAffineTest(absltest.TestCase):
    """Round-trip equinox norm layers that carry no affine parameters.

    Each case converts an equinox layer to brainstate with
    ``interop.from_equinox``, converts it back with ``interop.to_equinox``,
    and checks that both results reproduce the source layer's output.
    """

    @classmethod
    def setUpClass(cls):
        # Skip the whole class when equinox is not installed.
        cls.eqx = pytest.importorskip('equinox')

    def _check_round_trip(self, src, x, label):
        """Assert that importing and re-exporting ``src`` preserves its output.

        The equinox layers are wrapped in ``jax.vmap`` so they are applied
        along the leading axis of ``x``.
        """
        dst = interop.from_equinox(src)
        assert_close(jax.vmap(src)(x), bst_forward(dst, x), f'{label} import')
        back = interop.to_equinox(dst)
        assert_close(jax.vmap(src)(x), jax.vmap(back)(x), f'{label} export')

    def test_layernorm_no_affine_import_export(self):
        x = brainstate.random.randn(4, 6)
        src = self.eqx.nn.LayerNorm(6, use_weight=False, use_bias=False)
        self._check_round_trip(src, x, 'eqx layernorm no-affine')

    def test_groupnorm_no_affine_import_export(self):
        x = brainstate.random.randn(4, 8)
        src = self.eqx.nn.GroupNorm(groups=2, channels=8, channelwise_affine=False)
        self._check_round_trip(src, x, 'eqx groupnorm no-affine')

    def test_rmsnorm_no_scale_import_export(self):
        x = brainstate.random.randn(4, 6)
        src = self.eqx.nn.RMSNorm(6, use_weight=False, use_bias=False)
        self._check_round_trip(src, x, 'eqx rmsnorm no-scale')


class NnxConvInputDilationTest(absltest.TestCase):
    """Conv with ``input_dilation != 1`` must raise ``ConversionError``.

    The conversion should fail loudly rather than silently succeed.
    """

    @classmethod
    def setUpClass(cls):
        # Skip the whole class when flax.nnx is not installed.
        cls.nnx = pytest.importorskip('flax.nnx')

    def test_input_dilation_raises(self):
        nnx = self.nnx
        src = nnx.Conv(3, 4, (3, 3), input_dilation=(2, 2), rngs=nnx.Rngs(0))
        with self.assertRaises(ConversionError):
            interop.from_nnx(src, sample_input=(8, 8, 3))


# Run the absltest suite when this file is executed as a script.
if __name__ == '__main__':
    absltest.main()
22 changes: 12 additions & 10 deletions brainstate/interop/_frameworks/_equinox.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,15 +103,15 @@ def _embed_to_foreign(layer, ctx):
def _layernorm_to_bst(m, ctx):
    """Convert an equinox ``LayerNorm`` into a brainstate layer norm.

    The feature count is read from ``m.shape`` rather than from an affine
    parameter, so a layer built with neither weight nor bias still converts.
    ``ctx`` is accepted for the converter signature and is not used here.
    """
    scale = m.weight
    bias = m.bias
    num = int(m.shape[0])
    layer = C.build_layernorm((num,), scale is not None, bias is not None, float(m.eps))
    C.bst_set_norm(layer, 'weight', scale, bias)
    return layer


def _layernorm_to_foreign(layer, ctx):
scale, offset = C.bst_get_norm(layer, 'weight', has_offset=True)
num = (scale if scale is not None else offset).shape[0]
num = int(layer.in_size[-1])
m = eqx.nn.LayerNorm(num, eps=float(layer.epsilon),
use_weight=scale is not None, use_bias=offset is not None)
if scale is not None:
Expand All @@ -123,24 +123,26 @@ def _layernorm_to_foreign(layer, ctx):

def _rmsnorm_to_bst(m, ctx):
    """Convert an equinox ``RMSNorm`` into a brainstate RMS norm.

    The feature count comes from ``m.shape`` so a layer without a weight
    (``m.weight is None``) still converts. ``ctx`` is not used here.
    """
    # NOTE(review): ``m.bias`` is not read here; confirm that an equinox
    # RMSNorm created with ``use_bias=True`` is rejected or handled elsewhere.
    scale = m.weight
    num = int(m.shape[0])
    layer = C.build_rmsnorm((num,), scale is not None, float(m.eps))
    C.bst_set_norm(layer, 'scale', scale, None)
    return layer


def _rmsnorm_to_foreign(layer, ctx):
    """Convert a brainstate RMS norm into an equinox ``RMSNorm``.

    The feature count is taken from the last entry of ``layer.in_size``.
    The equinox layer gets a weight only when the brainstate layer has a
    scale. ``ctx`` is not used here.
    """
    scale, _ = C.bst_get_norm(layer, 'scale', has_offset=False)
    num = int(layer.in_size[-1])
    # brainstate RMSNorm has no offset -> use_bias=False
    m = eqx.nn.RMSNorm(num, eps=float(layer.epsilon),
                       use_weight=scale is not None, use_bias=False)
    if scale is not None:
        m = _set(m, weight=scale)
    return m


def _groupnorm_to_bst(m, ctx):
scale = m.weight
bias = m.bias
num = m.channels
num = int(m.channels)
layer = C.build_groupnorm((num,), int(m.groups), scale is not None, bias is not None,
float(m.eps))
C.bst_set_norm(layer, 'weight', scale, bias)
Expand All @@ -149,9 +151,9 @@ def _groupnorm_to_bst(m, ctx):

def _groupnorm_to_foreign(layer, ctx):
scale, offset = C.bst_get_norm(layer, 'weight', has_offset=True)
num = (scale if scale is not None else offset).shape[0]
num = int(layer.in_size[-1])
m = eqx.nn.GroupNorm(int(layer.num_groups), num, eps=float(layer.epsilon),
channelwise_affine=scale is not None)
channelwise_affine=(scale is not None or offset is not None))
if scale is not None:
m = _set(m, weight=scale)
if offset is not None:
Expand Down
36 changes: 31 additions & 5 deletions brainstate/interop/_frameworks/_linen.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,11 +114,24 @@ def _embed_to_foreign(layer, ctx):
# LayerNorm / RMSNorm / GroupNorm (1-D affine vectors -> identity)
# ---------------------------------------------------------------------------

def _norm_num_from_params(*arrays):
"""Extract the feature count from the first non-None affine parameter."""
for a in arrays:
if a is not None:
return int(a.shape[0])
return None


def _layernorm_to_bst(node, ctx):
p = _params(node)
scale = p.get('scale')
bias = p.get('bias')
num = (scale if scale is not None else bias).shape[0]
num = _norm_num_from_params(scale, bias)
if num is None:
raise ConversionError(
"Cannot determine feature count for linen LayerNorm with no affine parameters "
"and no explicit feature size."
)
layer = C.build_layernorm((num,), scale is not None, bias is not None,
float(node.module.epsilon))
C.bst_set_norm(layer, 'weight', scale, bias)
Expand All @@ -138,23 +151,35 @@ def _layernorm_to_foreign(layer, ctx):


def _rmsnorm_to_bst(node, ctx):
    """Convert a linen ``RMSNorm`` node into a brainstate RMS norm.

    The feature count is taken from the ``scale`` parameter's leading
    dimension. ``ctx`` is not used here.

    Raises
    ------
    ConversionError
        If the node's parameters contain no ``scale`` entry, because the
        feature count cannot then be determined.
    """
    # ``.get`` (not indexing) so a missing scale reaches the explicit
    # ConversionError below instead of surfacing as a bare KeyError.
    scale = _params(node).get('scale')
    if scale is None:
        raise ConversionError(
            "Cannot determine feature count for linen RMSNorm with no scale parameter."
        )
    layer = C.build_rmsnorm((scale.shape[0],), True, float(node.module.epsilon))
    C.bst_set_norm(layer, 'scale', scale, None)
    return layer


def _rmsnorm_to_foreign(layer, ctx):
    """Convert a brainstate RMS norm into a linen ``RMSNorm`` node.

    The linen module enables ``use_scale`` only when the brainstate layer
    has a scale, and the ``params`` dict carries a ``'scale'`` entry only in
    that case. ``ctx`` is not used here.
    """
    scale, _ = C.bst_get_norm(layer, 'scale', has_offset=False)
    has_scale = scale is not None
    module = nn.RMSNorm(epsilon=float(layer.epsilon), use_scale=has_scale)
    # Never store a ``None`` scale in the params dict.
    params = {'scale': scale} if has_scale else {}
    return _LinenNode(module, {'params': params})


def _groupnorm_to_bst(node, ctx):
p = _params(node)
scale = p.get('scale')
bias = p.get('bias')
num = (scale if scale is not None else bias).shape[0]
num = _norm_num_from_params(scale, bias)
if num is None:
raise ConversionError(
"Cannot determine feature count for linen GroupNorm with no affine parameters."
)
layer = C.build_groupnorm((num,), int(node.module.num_groups),
scale is not None, bias is not None, float(node.module.epsilon))
C.bst_set_norm(layer, 'weight', scale, bias)
Expand All @@ -163,6 +188,7 @@ def _groupnorm_to_bst(node, ctx):

def _groupnorm_to_foreign(layer, ctx):
scale, offset = C.bst_get_norm(layer, 'weight', has_offset=True)
num = int(layer.in_size[-1])
module = nn.GroupNorm(num_groups=int(layer.num_groups), epsilon=float(layer.epsilon),
use_scale=scale is not None, use_bias=offset is not None)
params = {}
Expand Down
30 changes: 22 additions & 8 deletions brainstate/interop/_frameworks/_nnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@

from .. import _common as C
from .._common import FrameworkAdapter, lazy_import, new_key
from .._errors import ConversionError
from .._registry import (LayerMapping, register_layer_mapping,
register_unsupported_bst, register_unsupported_foreign)

Expand Down Expand Up @@ -105,15 +106,15 @@ def _embed_to_foreign(layer, ctx):
def _layernorm_to_bst(m, ctx):
    """Convert an nnx ``LayerNorm`` into a brainstate layer norm.

    The feature count is read from ``m.num_features`` rather than from an
    affine parameter, so a layer with neither scale nor bias still converts.
    ``ctx`` is not used here.
    """
    scale = None if m.scale is None else _get(m.scale)
    bias = None if m.bias is None else _get(m.bias)
    num = int(m.num_features)
    layer = C.build_layernorm((num,), scale is not None, bias is not None, float(m.epsilon))
    C.bst_set_norm(layer, 'weight', scale, bias)
    return layer


def _layernorm_to_foreign(layer, ctx):
scale, offset = C.bst_get_norm(layer, 'weight', has_offset=True)
num = (scale if scale is not None else offset).shape[0]
num = int(layer.in_size[-1])
m = nnx.LayerNorm(num, epsilon=float(layer.epsilon),
use_scale=scale is not None, use_bias=offset is not None, rngs=_rngs(ctx))
if scale is not None:
Expand All @@ -125,24 +126,26 @@ def _layernorm_to_foreign(layer, ctx):

def _rmsnorm_to_bst(m, ctx):
    """Convert an nnx ``RMSNorm`` into a brainstate RMS norm.

    The feature count is read from ``m.num_features`` so a layer whose
    ``scale`` is ``None`` still converts. ``ctx`` is not used here.
    """
    scale = None if m.scale is None else _get(m.scale)
    num = int(m.num_features)
    layer = C.build_rmsnorm((num,), scale is not None, float(m.epsilon))
    C.bst_set_norm(layer, 'scale', scale, None)
    return layer


def _rmsnorm_to_foreign(layer, ctx):
    """Convert a brainstate RMS norm into an nnx ``RMSNorm``.

    The feature count is taken from the last entry of ``layer.in_size``.
    The nnx layer enables ``use_scale`` only when the brainstate layer has a
    scale, and the scale is copied over only in that case.
    """
    scale, _ = C.bst_get_norm(layer, 'scale', has_offset=False)
    num = int(layer.in_size[-1])
    m = nnx.RMSNorm(num, epsilon=float(layer.epsilon),
                    use_scale=scale is not None, rngs=_rngs(ctx))
    if scale is not None:
        _set(m.scale, scale)
    return m


def _groupnorm_to_bst(m, ctx):
scale = None if m.scale is None else _get(m.scale)
bias = None if m.bias is None else _get(m.bias)
num = (scale if scale is not None else bias).shape[0]
num = int(m.num_groups) * int(m.group_size)
layer = C.build_groupnorm((num,), int(m.num_groups), scale is not None, bias is not None,
float(m.epsilon))
C.bst_set_norm(layer, 'weight', scale, bias)
Expand All @@ -151,7 +154,7 @@ def _groupnorm_to_bst(m, ctx):

def _groupnorm_to_foreign(layer, ctx):
scale, offset = C.bst_get_norm(layer, 'weight', has_offset=True)
num = (scale if scale is not None else offset).shape[0]
num = int(layer.in_size[-1])
m = nnx.GroupNorm(num, num_groups=int(layer.num_groups), epsilon=float(layer.epsilon),
use_scale=scale is not None, use_bias=offset is not None, rngs=_rngs(ctx))
if scale is not None:
Expand Down Expand Up @@ -187,10 +190,21 @@ def _conv_bias_reshape_to_foreign(b):
return u.math.reshape(b, (b.shape[-1],))


def _as_tuple(v, nd):
if isinstance(v, (tuple, list)):
return tuple(v)
return (v,) * nd


def _conv_to_bst(m, ctx):
w = _get(m.kernel) # (*k, in//g, out)
kernel_size = tuple(w.shape[:-2])
nd = len(kernel_size)
if any(d != 1 for d in _as_tuple(getattr(m, 'input_dilation', 1) or 1, nd)):
raise ConversionError(
"nnx Conv with `input_dilation` != 1 (transposed convolution) is not supported "
"by this converter."
)
out_channels = w.shape[-1]
in_size = ctx.require_size('Conv')
has_bias = m.bias is not None
Expand Down
8 changes: 6 additions & 2 deletions brainstate/interop/_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,8 +124,12 @@ def lookup_import(framework: str, foreign_type: type) -> Optional[LayerMapping]:

def lookup_export(bst_type: type, framework: str) -> Optional[LayerMapping]:
    """Return the export mapping for a brainstate type + framework, or ``None``.

    ``_EXPORT`` is keyed by ``(brainstate type, framework)``. The type's MRO
    is walked in order and the first class with an entry for ``framework``
    wins. ``__mro__`` starts with ``bst_type`` itself, so an exact match
    takes precedence over any base-class mapping.
    """
    for base in bst_type.__mro__:
        key = (base, framework)
        if key in _EXPORT:
            return _EXPORT[key]
    return None


def unsupported_bst_reason(bst_type: type, framework: Optional[str] = None) -> Optional[str]:
Expand Down