diff --git a/brainstate/interop/_common.py b/brainstate/interop/_common.py index 4e44141..4547da5 100644 --- a/brainstate/interop/_common.py +++ b/brainstate/interop/_common.py @@ -240,6 +240,8 @@ def bst_get_norm(layer, attr, has_offset): def bst_set_norm(layer, attr, scale, offset): """Write the affine parameters of a normalization layer.""" + if scale is None and offset is None: + return d = {} if scale is not None: d['scale'] = scale @@ -274,7 +276,12 @@ def bst_get_batchnorm(layer): def bst_set_batchnorm(layer, scale, offset, running_mean, running_var): """Write affine + running statistics of a brainstate ``BatchNorm``.""" if layer.weight is not None: - layer.weight.value = {'scale': scale, 'bias': offset} + d = {} + if scale is not None: + d['scale'] = scale + if offset is not None: + d['bias'] = offset + layer.weight.value = d if running_mean is not None: layer.running_mean.value = running_mean if running_var is not None: diff --git a/brainstate/interop/_conversion_test.py b/brainstate/interop/_conversion_test.py index 857a426..730b5e6 100644 --- a/brainstate/interop/_conversion_test.py +++ b/brainstate/interop/_conversion_test.py @@ -669,5 +669,107 @@ def test_missing_dependency_error(self): lazy_import('a_framework_that_does_not_exist_xyz') +# =========================================================================== +# Edge-case tests for audit-discovered bugs +# =========================================================================== + +class NnxNormNoAffineTest(absltest.TestCase): + """Norm layers with all affine params disabled must not crash.""" + + @classmethod + def setUpClass(cls): + cls.nnx = pytest.importorskip('flax.nnx') + + def _rngs(self): + return self.nnx.Rngs(brainstate.random.split_key()) + + def test_layernorm_no_affine_import_export(self): + nnx = self.nnx + x = brainstate.random.randn(4, 6) + src = nnx.LayerNorm(6, use_scale=False, use_bias=False, rngs=self._rngs()) + dst = interop.from_nnx(src) + assert_close(src(x), bst_forward(dst, x), 'layernorm no-affine import') + back = interop.to_nnx(dst) + assert_close(src(x), back(x), 'layernorm no-affine export') + + def test_groupnorm_no_affine_import_export(self): + nnx = self.nnx + x = brainstate.random.randn(4, 8) + src = nnx.GroupNorm(8, num_groups=2, use_scale=False, use_bias=False, rngs=self._rngs()) + dst = interop.from_nnx(src) + assert_close(src(x), bst_forward(dst, x), 'groupnorm no-affine import') + back = interop.to_nnx(dst) + assert_close(src(x), back(x), 'groupnorm no-affine export') + + def test_layernorm_scale_only(self): + nnx = self.nnx + x = brainstate.random.randn(4, 6) + src = nnx.LayerNorm(6, use_scale=True, use_bias=False, rngs=self._rngs()) + src.scale[...] = brainstate.random.randn(6) + dst = interop.from_nnx(src) + assert_close(src(x), bst_forward(dst, x), 'layernorm scale-only import') + back = interop.to_nnx(dst) + assert_close(src(x), back(x), 'layernorm scale-only export') + + def test_layernorm_bias_only(self): + nnx = self.nnx + x = brainstate.random.randn(4, 6) + src = nnx.LayerNorm(6, use_scale=False, use_bias=True, rngs=self._rngs()) + src.bias[...] = brainstate.random.randn(6) + dst = interop.from_nnx(src) + assert_close(src(x), bst_forward(dst, x), 'layernorm bias-only import') + back = interop.to_nnx(dst) + assert_close(src(x), back(x), 'layernorm bias-only export') + + +class EquinoxNormNoAffineTest(absltest.TestCase): + """Norm layers with no affine params in equinox.""" + + @classmethod + def setUpClass(cls): + cls.eqx = pytest.importorskip('equinox') + + def test_layernorm_no_affine_import_export(self): + eqx = self.eqx + x = brainstate.random.randn(4, 6) + src = eqx.nn.LayerNorm(6, use_weight=False, use_bias=False) + dst = interop.from_equinox(src) + assert_close(jax.vmap(src)(x), bst_forward(dst, x), 'eqx layernorm no-affine import') + back = interop.to_equinox(dst) + assert_close(jax.vmap(src)(x), jax.vmap(back)(x), 'eqx layernorm no-affine export') + + def test_groupnorm_no_affine_import_export(self): + eqx = self.eqx + x = brainstate.random.randn(4, 8) + src = eqx.nn.GroupNorm(groups=2, channels=8, channelwise_affine=False) + dst = interop.from_equinox(src) + assert_close(jax.vmap(src)(x), bst_forward(dst, x), 'eqx groupnorm no-affine import') + back = interop.to_equinox(dst) + assert_close(jax.vmap(src)(x), jax.vmap(back)(x), 'eqx groupnorm no-affine export') + + def test_rmsnorm_no_scale_import_export(self): + eqx = self.eqx + x = brainstate.random.randn(4, 6) + src = eqx.nn.RMSNorm(6, use_weight=False, use_bias=False) + dst = interop.from_equinox(src) + assert_close(jax.vmap(src)(x), bst_forward(dst, x), 'eqx rmsnorm no-scale import') + back = interop.to_equinox(dst) + assert_close(jax.vmap(src)(x), jax.vmap(back)(x), 'eqx rmsnorm no-scale export') + + +class NnxConvInputDilationTest(absltest.TestCase): + """Conv with input_dilation != 1 should raise ConversionError, not silently succeed.""" + + @classmethod + def setUpClass(cls): + cls.nnx = pytest.importorskip('flax.nnx') + + def test_input_dilation_raises(self): + nnx = self.nnx + src = nnx.Conv(3, 4, (3, 3), input_dilation=(2, 2), rngs=nnx.Rngs(0)) + with self.assertRaises(ConversionError): + interop.from_nnx(src, sample_input=(8, 8, 3)) + + if __name__ == '__main__': absltest.main() diff --git a/brainstate/interop/_frameworks/_equinox.py b/brainstate/interop/_frameworks/_equinox.py index 11acf88..0dba520 100644 --- a/brainstate/interop/_frameworks/_equinox.py +++ b/brainstate/interop/_frameworks/_equinox.py @@ -103,7 +103,7 @@ def _embed_to_foreign(layer, ctx): def _layernorm_to_bst(m, ctx): scale = m.weight bias = m.bias - num = (scale if scale is not None else bias).shape[0] + num = int(m.shape[0]) layer = C.build_layernorm((num,), scale is not None, bias is not None, float(m.eps)) C.bst_set_norm(layer, 'weight', scale, bias) return layer @@ -111,7 +111,7 @@ def _layernorm_to_bst(m, ctx): def _layernorm_to_foreign(layer, ctx): scale, offset = C.bst_get_norm(layer, 'weight', has_offset=True) - num = (scale if scale is not None else offset).shape[0] + num = int(layer.in_size[-1]) m = eqx.nn.LayerNorm(num, eps=float(layer.epsilon), use_weight=scale is not None, use_bias=offset is not None) if scale is not None: @@ -123,7 +123,7 @@ def _layernorm_to_foreign(layer, ctx): def _rmsnorm_to_bst(m, ctx): scale = m.weight - num = scale.shape[0] + num = int(m.shape[0]) layer = C.build_rmsnorm((num,), scale is not None, float(m.eps)) C.bst_set_norm(layer, 'scale', scale, None) return layer @@ -131,16 +131,18 @@ def _rmsnorm_to_bst(m, ctx): def _rmsnorm_to_foreign(layer, ctx): scale, _ = C.bst_get_norm(layer, 'scale', has_offset=False) - num = scale.shape[0] - # brainstate RMSNorm has no offset -> use_bias=False - m = eqx.nn.RMSNorm(num, eps=float(layer.epsilon), use_weight=True, use_bias=False) - return _set(m, weight=scale) + num = int(layer.in_size[-1]) + m = eqx.nn.RMSNorm(num, eps=float(layer.epsilon), + use_weight=scale is not None, use_bias=False) + if scale is not None: + m = _set(m, weight=scale) + return m def _groupnorm_to_bst(m, ctx): scale = m.weight bias = m.bias - num = m.channels + num = int(m.channels) layer = C.build_groupnorm((num,), int(m.groups), scale is not None, bias is not None, float(m.eps)) C.bst_set_norm(layer, 'weight', scale, bias) @@ -149,9 +151,9 @@ def _groupnorm_to_bst(m, ctx): def _groupnorm_to_foreign(layer, ctx): scale, offset = C.bst_get_norm(layer, 'weight', has_offset=True) - num = (scale if scale is not None else offset).shape[0] + num = int(layer.in_size[-1]) m = eqx.nn.GroupNorm(int(layer.num_groups), num, eps=float(layer.epsilon), - channelwise_affine=scale is not None) + channelwise_affine=(scale is not None or offset is not None)) if scale is not None: m = _set(m, weight=scale) if offset is not None: diff --git a/brainstate/interop/_frameworks/_linen.py b/brainstate/interop/_frameworks/_linen.py index e08f0d9..fcc56c3 100644 --- a/brainstate/interop/_frameworks/_linen.py +++ b/brainstate/interop/_frameworks/_linen.py @@ -114,11 +114,24 @@ def _embed_to_foreign(layer, ctx): # LayerNorm / RMSNorm / GroupNorm (1-D affine vectors -> identity) # --------------------------------------------------------------------------- +def _norm_num_from_params(*arrays): + """Extract the feature count from the first non-None affine parameter.""" + for a in arrays: + if a is not None: + return int(a.shape[0]) + return None + + def _layernorm_to_bst(node, ctx): p = _params(node) scale = p.get('scale') bias = p.get('bias') - num = (scale if scale is not None else bias).shape[0] + num = _norm_num_from_params(scale, bias) + if num is None: + raise ConversionError( + "Cannot determine feature count for linen LayerNorm with no affine parameters " + "and no explicit feature size." + ) layer = C.build_layernorm((num,), scale is not None, bias is not None, float(node.module.epsilon)) C.bst_set_norm(layer, 'weight', scale, bias) @@ -138,7 +151,11 @@ def _layernorm_to_foreign(layer, ctx): def _rmsnorm_to_bst(node, ctx): - scale = _params(node)['scale'] + scale = _params(node).get('scale') + if scale is None: + raise ConversionError( + "Cannot determine feature count for linen RMSNorm with no scale parameter." + ) layer = C.build_rmsnorm((scale.shape[0],), True, float(node.module.epsilon)) C.bst_set_norm(layer, 'scale', scale, None) return layer @@ -146,15 +163,23 @@ def _rmsnorm_to_bst(node, ctx): def _rmsnorm_to_foreign(layer, ctx): scale, _ = C.bst_get_norm(layer, 'scale', has_offset=False) - module = nn.RMSNorm(epsilon=float(layer.epsilon), use_scale=True) - return _LinenNode(module, {'params': {'scale': scale}}) + has_scale = scale is not None + module = nn.RMSNorm(epsilon=float(layer.epsilon), use_scale=has_scale) + params = {} + if scale is not None: + params['scale'] = scale + return _LinenNode(module, {'params': params}) def _groupnorm_to_bst(node, ctx): p = _params(node) scale = p.get('scale') bias = p.get('bias') - num = (scale if scale is not None else bias).shape[0] + num = _norm_num_from_params(scale, bias) + if num is None: + raise ConversionError( + "Cannot determine feature count for linen GroupNorm with no affine parameters." + ) layer = C.build_groupnorm((num,), int(node.module.num_groups), scale is not None, bias is not None, float(node.module.epsilon)) C.bst_set_norm(layer, 'weight', scale, bias) @@ -163,6 +188,7 @@ def _groupnorm_to_bst(node, ctx): def _groupnorm_to_foreign(layer, ctx): scale, offset = C.bst_get_norm(layer, 'weight', has_offset=True) + num = int(layer.in_size[-1]) module = nn.GroupNorm(num_groups=int(layer.num_groups), epsilon=float(layer.epsilon), use_scale=scale is not None, use_bias=offset is not None) params = {} diff --git a/brainstate/interop/_frameworks/_nnx.py b/brainstate/interop/_frameworks/_nnx.py index b50c482..7f02303 100644 --- a/brainstate/interop/_frameworks/_nnx.py +++ b/brainstate/interop/_frameworks/_nnx.py @@ -26,6 +26,7 @@ from .. import _common as C from .._common import FrameworkAdapter, lazy_import, new_key +from .._errors import ConversionError from .._registry import (LayerMapping, register_layer_mapping, register_unsupported_bst, register_unsupported_foreign) @@ -105,7 +106,7 @@ def _embed_to_foreign(layer, ctx): def _layernorm_to_bst(m, ctx): scale = None if m.scale is None else _get(m.scale) bias = None if m.bias is None else _get(m.bias) - num = (scale if scale is not None else bias).shape[0] + num = int(m.num_features) layer = C.build_layernorm((num,), scale is not None, bias is not None, float(m.epsilon)) C.bst_set_norm(layer, 'weight', scale, bias) return layer @@ -113,7 +114,7 @@ def _layernorm_to_bst(m, ctx): def _layernorm_to_foreign(layer, ctx): scale, offset = C.bst_get_norm(layer, 'weight', has_offset=True) - num = (scale if scale is not None else offset).shape[0] + num = int(layer.in_size[-1]) m = nnx.LayerNorm(num, epsilon=float(layer.epsilon), use_scale=scale is not None, use_bias=offset is not None, rngs=_rngs(ctx)) if scale is not None: @@ -125,7 +126,7 @@ def _layernorm_to_foreign(layer, ctx): def _rmsnorm_to_bst(m, ctx): scale = None if m.scale is None else _get(m.scale) - num = scale.shape[0] + num = int(m.num_features) layer = C.build_rmsnorm((num,), scale is not None, float(m.epsilon)) C.bst_set_norm(layer, 'scale', scale, None) return layer @@ -133,16 +134,18 @@ def _rmsnorm_to_bst(m, ctx): def _rmsnorm_to_foreign(layer, ctx): scale, _ = C.bst_get_norm(layer, 'scale', has_offset=False) - num = scale.shape[0] - m = nnx.RMSNorm(num, epsilon=float(layer.epsilon), use_scale=True, rngs=_rngs(ctx)) - _set(m.scale, scale) + num = int(layer.in_size[-1]) + m = nnx.RMSNorm(num, epsilon=float(layer.epsilon), + use_scale=scale is not None, rngs=_rngs(ctx)) + if scale is not None: + _set(m.scale, scale) return m def _groupnorm_to_bst(m, ctx): scale = None if m.scale is None else _get(m.scale) bias = None if m.bias is None else _get(m.bias) - num = (scale if scale is not None else bias).shape[0] + num = int(m.num_groups) * int(m.group_size) layer = C.build_groupnorm((num,), int(m.num_groups), scale is not None, bias is not None, float(m.epsilon)) C.bst_set_norm(layer, 'weight', scale, bias) @@ -151,7 +154,7 @@ def _groupnorm_to_bst(m, ctx): def _groupnorm_to_foreign(layer, ctx): scale, offset = C.bst_get_norm(layer, 'weight', has_offset=True) - num = (scale if scale is not None else offset).shape[0] + num = int(layer.in_size[-1]) m = nnx.GroupNorm(num, num_groups=int(layer.num_groups), epsilon=float(layer.epsilon), use_scale=scale is not None, use_bias=offset is not None, rngs=_rngs(ctx)) if scale is not None: @@ -187,10 +190,21 @@ def _conv_bias_reshape_to_foreign(b): return u.math.reshape(b, (b.shape[-1],)) +def _as_tuple(v, nd): + if isinstance(v, (tuple, list)): + return tuple(v) + return (v,) * nd + + def _conv_to_bst(m, ctx): w = _get(m.kernel) # (*k, in//g, out) kernel_size = tuple(w.shape[:-2]) nd = len(kernel_size) + if any(d != 1 for d in _as_tuple(getattr(m, 'input_dilation', 1) or 1, nd)): + raise ConversionError( + "nnx Conv with `input_dilation` != 1 (transposed convolution) is not supported " + "by this converter." + ) out_channels = w.shape[-1] in_size = ctx.require_size('Conv') has_bias = m.bias is not None diff --git a/brainstate/interop/_registry.py b/brainstate/interop/_registry.py index d1f2cea..0e24d87 100644 --- a/brainstate/interop/_registry.py +++ b/brainstate/interop/_registry.py @@ -124,8 +124,12 @@ def lookup_import(framework: str, foreign_type: type) -> Optional[LayerMapping]: def lookup_export(bst_type: type, framework: str) -> Optional[LayerMapping]: """Return the export mapping for a brainstate type + framework, or ``None``.""" - table = {bt: m for (bt, fw), m in _EXPORT.items() if fw == framework} - return _lookup_by_mro(table, bst_type) + if (bst_type, framework) in _EXPORT: + return _EXPORT[(bst_type, framework)] + for base in bst_type.__mro__: + if (base, framework) in _EXPORT: + return _EXPORT[(base, framework)] + return None def unsupported_bst_reason(bst_type: type, framework: Optional[str] = None) -> Optional[str]: