diff --git a/brainstate/util/__init__.py b/brainstate/util/__init__.py index 0f20224..5efc208 100644 --- a/brainstate/util/__init__.py +++ b/brainstate/util/__init__.py @@ -80,6 +80,11 @@ pretty_repr, ) + +def breakpoint_if(*args, **kwargs): + from brainstate.transform._debug import breakpoint_if as _breakpoint_if + return _breakpoint_if(*args, **kwargs) + __all__ = [ # Tracer utilities 'StateJaxTracer', diff --git a/brainstate/util/_cache.py b/brainstate/util/_cache.py index 5b0658d..8ab88c3 100644 --- a/brainstate/util/_cache.py +++ b/brainstate/util/_cache.py @@ -35,6 +35,10 @@ class BoundedCache: """ def __init__(self, maxsize: int = 128): + if not isinstance(maxsize, int): + raise TypeError(f"maxsize must be an integer, got {type(maxsize).__name__}.") + if maxsize < 0: + raise ValueError(f"maxsize must be non-negative, got {maxsize}.") self._cache = OrderedDict() self._maxsize = maxsize self._lock = threading.RLock() @@ -92,7 +96,7 @@ def get( f"Requested key:", f" {key}", f"", - f"Available {{len(available_keys)}} keys:", + f"Available {len(available_keys)} keys:", ] if available_keys: for i, k in enumerate(available_keys, 1): @@ -128,6 +132,8 @@ def set(self, key: Any, value: Any) -> None: f"Cannot overwrite existing cached value. " f"Clear the cache first if you need to recompile." ) + if self._maxsize == 0: + return if len(self._cache) >= self._maxsize: self._cache.popitem(last=False) self._cache[key] = value diff --git a/brainstate/util/_cache_test.py b/brainstate/util/_cache_test.py index c5208b5..9d2b45a 100644 --- a/brainstate/util/_cache_test.py +++ b/brainstate/util/_cache_test.py @@ -25,6 +25,21 @@ class TestBoundedCache(unittest.TestCase): """Test the BoundedCache class.""" + def test_cache_rejects_negative_maxsize(self): + """Reject negative cache sizes before eviction logic can corrupt state.""" + with pytest.raises(ValueError, match="maxsize"): + BoundedCache(maxsize=-1) + + def test_cache_zero_maxsize_does_not_store_items(self): + """A zero-sized cache behaves as permanently empty instead of raising.""" + cache = BoundedCache(maxsize=0) + + cache.set('key1', 'value1') + + self.assertEqual(len(cache), 0) + self.assertNotIn('key1', cache) + self.assertIsNone(cache.get('key1')) + def test_cache_basic_operations(self): """Test basic get and set operations.""" cache = BoundedCache(maxsize=3) @@ -235,6 +250,9 @@ def test_cache_detailed_error_message(self): error_msg = str(exc_info.value) # Should show requested key self.assertIn('nonexistent', error_msg) + # Should show the formatted number of available keys + self.assertIn('Available 2 keys:', error_msg) + self.assertNotIn('{len(available_keys)}', error_msg) # Should show available keys self.assertIn('key1', error_msg) self.assertIn('key2', error_msg) diff --git a/brainstate/util/_init_test.py b/brainstate/util/_init_test.py new file mode 100644 index 0000000..f0990c8 --- /dev/null +++ b/brainstate/util/_init_test.py @@ -0,0 +1,28 @@ +"""Tests for the public ``brainstate.util`` package surface.""" + +import unittest + + +class TestUtilPackageExports(unittest.TestCase): + """Validate the package-level export table.""" + + def test_all_names_are_bound(self): + """Every name listed in ``__all__`` is importable as a package attribute.""" + import brainstate.util as util + + for name in util.__all__: + with self.subTest(name=name): + self.assertTrue(hasattr(util, name), name) + + def test_star_import_includes_all_names(self): + """Wildcard import should not fail due to a stale ``__all__`` entry.""" + namespace = {} + + exec("from brainstate.util import *", namespace) + + self.assertIn('BoundedCache', namespace) + self.assertIn('breakpoint_if', namespace) + + +if __name__ == "__main__": + unittest.main() diff --git a/brainstate/util/_others.py b/brainstate/util/_others.py index d1d0265..0b3f7ed 100644 --- a/brainstate/util/_others.py +++ b/brainstate/util/_others.py @@ -289,8 +289,13 @@ def add_unique_value(self, key: K, val: V) -> bool: True if the value was added (was unique), False otherwise. """ self._check_elem(val) - if not hasattr(self, '_val_id_to_key'): - self._val_id_to_key = {id(v): k for k, v in self.items()} + self._val_id_to_key = {id(v): k for k, v in self.items()} + + if key in self and id(self[key]) != id(val): + raise ValueError( + f"Key '{key}' already exists with a different value. " + f"Existing: {self[key]}, New: {val}" + ) val_id = id(val) if val_id not in self._val_id_to_key: @@ -681,6 +686,7 @@ def __setattr__(self, name: str, value: Any) -> None: def __setitem__(self, name: str, value: Any) -> None: """Set item and update parent if nested.""" + value = self._hook(value) super().__setitem__(name, value) try: parent = object.__getattribute__(self, '__parent') @@ -908,6 +914,8 @@ def flatten_dict( >>> flatten_dict(d) {'a': 1, 'b.c': 2, 'b.d.e': 3} """ + if not isinstance(d, dict): + raise TypeError(f"d must be a dict, got {type(d).__name__}.") items = [] for k, v in d.items(): new_key = f"{parent_key}{sep}{k}" if parent_key else k @@ -944,17 +952,25 @@ def unflatten_dict( >>> unflatten_dict(d) {'a': 1, 'b': {'c': 2, 'd': {'e': 3}}} """ + if not isinstance(d, dict): + raise TypeError(f"d must be a dict, got {type(d).__name__}.") result = {} for key, value in d.items(): parts = key.split(sep) + if not parts or any(part == '' for part in parts): + raise ValueError(f"Invalid flattened key {key!r}.") current = result for part in parts[:-1]: if part not in current: current[part] = {} + elif not isinstance(current[part], dict): + raise ValueError(f"Cannot expand scalar key prefix {part!r} in {key!r}.") current = current[part] + if parts[-1] in current and isinstance(current[parts[-1]], dict): + raise ValueError(f"Cannot overwrite mapping at key {key!r} with a scalar value.") current[parts[-1]] = value return result diff --git a/brainstate/util/_others_test.py b/brainstate/util/_others_test.py index 116c58e..cc8e0dd 100644 --- a/brainstate/util/_others_test.py +++ b/brainstate/util/_others_test.py @@ -283,6 +283,40 @@ def test_add_unique_value(self): self.assertTrue(result3) self.assertIs(dm['key2'], obj2) + def test_add_unique_value_allows_readd_after_delete(self): + """Keep the identity index in sync after direct item deletion.""" + dm = DictManager() + obj = object() + self.assertTrue(dm.add_unique_value('key1', obj)) + + del dm['key1'] + + self.assertTrue(dm.add_unique_value('key2', obj)) + self.assertIs(dm['key2'], obj) + + def test_add_unique_value_indexes_existing_direct_mutations(self): + """Rebuild stale identity indexes when values were added without helper APIs.""" + dm = DictManager() + obj = object() + self.assertTrue(dm.add_unique_value('key1', object())) + + dm['direct'] = obj + + self.assertFalse(dm.add_unique_value('duplicate', obj)) + self.assertNotIn('duplicate', dm) + + def test_add_unique_value_rejects_existing_key_with_different_value(self): + """Do not overwrite an existing key when the new value is unique.""" + dm = DictManager() + obj1 = object() + obj2 = object() + self.assertTrue(dm.add_unique_value('key1', obj1)) + + with self.assertRaises(ValueError): + dm.add_unique_value('key1', obj2) + + self.assertIs(dm['key1'], obj1) + def test_unique(self): """Test getting unique values.""" obj1 = object() @@ -562,6 +596,16 @@ def test_dot_access(self): dd.e = 5 self.assertEqual(dd['e'], 5) + # Direct item assignment should still apply nested DotDict conversion. + dd['nested'] = {'x': 1} + self.assertIsInstance(dd.nested, DotDict) + self.assertEqual(dd.nested.x, 1) + + # Attribute assignment should use the same conversion path. + dd.another = {'y': 2} + self.assertIsInstance(dd.another, DotDict) + self.assertEqual(dd.another.y, 2) + def test_nested_dict_conversion(self): """Test automatic nested dict conversion.""" dd = DotDict({ @@ -711,6 +755,10 @@ def test_setdefault(self): self.assertIsNone(result) self.assertIsNone(dd.c) + result = dd.setdefault('nested', {'value': 3}) + self.assertIsInstance(result, DotDict) + self.assertEqual(result.value, 3) + def test_pickling(self): """Test pickling/unpickling.""" dd1 = DotDict({'a': 1, 'b': {'c': 2}}) @@ -847,6 +895,19 @@ def test_unflatten_dict(self): 'b': {'c': 2, 'd': 3} }) + def test_unflatten_dict_rejects_prefix_conflicts(self): + """Reject flattened inputs that would overwrite a scalar with a mapping.""" + with self.assertRaises(ValueError): + unflatten_dict({'a': 1, 'a.b': 2}) + + with self.assertRaises(ValueError): + unflatten_dict({'a.b': 2, 'a': 1}) + + def test_flatten_dict_rejects_non_mapping(self): + """Raise a clear TypeError for non-dictionary inputs.""" + with self.assertRaises(TypeError): + flatten_dict([('a', 1)]) + def test_flatten_unflatten_roundtrip(self): """Test that flatten/unflatten is reversible.""" original = { diff --git a/brainstate/util/_pretty_pytree.py b/brainstate/util/_pretty_pytree.py index d3b0d5b..af1c089 100644 --- a/brainstate/util/_pretty_pytree.py +++ b/brainstate/util/_pretty_pytree.py @@ -280,13 +280,20 @@ def nest_mapping( for path, value in xs.items(): if sep is not None: path = path.split(sep) + path = tuple(path) + if not path: + raise ValueError('Cannot nest an empty path.') if value is empty_node: value = {} cursor = result for key in path[:-1]: if key not in cursor: cursor[key] = {} + elif not isinstance(cursor[key], abc.Mapping): + raise ValueError(f'Cannot expand scalar key prefix {key!r} in path {path!r}.') cursor = cursor[key] + if path[-1] in cursor and isinstance(cursor[path[-1]], abc.Mapping): + raise ValueError(f'Cannot overwrite mapping at path {path!r} with a scalar value.') cursor[path[-1]] = value return NestedDict(result) diff --git a/brainstate/util/_pretty_pytree_test.py b/brainstate/util/_pretty_pytree_test.py index c36aa14..8694f77 100644 --- a/brainstate/util/_pretty_pytree_test.py +++ b/brainstate/util/_pretty_pytree_test.py @@ -721,6 +721,19 @@ def test_flat_nest_roundtrip(self): nested = {'a': 1, 'b': {'c': 2, 'd': {'e': 3}}} self.assertEqual(nest_mapping(flat_mapping(nested)).to_dict(), nested) + def test_nest_mapping_rejects_prefix_conflicts(self): + """Reject flat mappings where one path is both a leaf and branch.""" + with self.assertRaises(ValueError): + nest_mapping({('a',): 1, ('a', 'b'): 2}) + + with self.assertRaises(ValueError): + nest_mapping({('a', 'b'): 2, ('a',): 1}) + + def test_nest_mapping_rejects_empty_paths(self): + """Reject empty paths because they cannot be nested as mapping keys.""" + with self.assertRaises(ValueError): + nest_mapping({(): 1}) + def test_nesteddict_flatten_unflatten_consistent(self): """Flatten and unflatten a NestedDict to an equal JAX structure.""" nd = NestedDict({'a': brainstate.ParamState(1), 'b': {'c': brainstate.ParamState(2)}}) diff --git a/brainstate/util/filter.py b/brainstate/util/filter.py index 651e273..9e9a330 100644 --- a/brainstate/util/filter.py +++ b/brainstate/util/filter.py @@ -399,6 +399,13 @@ class PathContains: key: Key + @staticmethod + def _path_key(path_part: typing.Any) -> typing.Any: + for attr in ('key', 'name', 'idx'): + if hasattr(path_part, attr): + return getattr(path_part, attr) + return path_part + def __call__(self, path: PathParts, x: typing.Any) -> bool: """ Check if the key is present in the path. @@ -415,7 +422,7 @@ def __call__(self, path: PathParts, x: typing.Any) -> bool: bool True if the key is present in the path, False otherwise. """ - return self.key in path + return any(self._path_key(part) == self.key for part in path) def __repr__(self) -> str: return f'PathContains({self.key!r})' @@ -504,8 +511,9 @@ def __call__(self, path: PathParts, x: typing.Any): True if the object is an instance of the specified type or has a 'type' attribute that is a subclass of the specified type. """ + x_type = getattr(x, 'type', None) return isinstance(x, self.type) or ( - hasattr(x, 'type') and issubclass(x.type, self.type) + isinstance(x_type, type) and issubclass(x_type, self.type) ) def __repr__(self): diff --git a/brainstate/util/filter_test.py b/brainstate/util/filter_test.py index f3a47ac..b4a5fc7 100644 --- a/brainstate/util/filter_test.py +++ b/brainstate/util/filter_test.py @@ -19,6 +19,7 @@ import unittest from typing import Any import numpy as np +import jax.tree_util as jtu from brainstate.util.filter import ( to_predicate, @@ -206,6 +207,13 @@ def test_basic_functionality(self): self.assertFalse(filter_weight([], None)) self.assertFalse(filter_weight(['other', 'path'], None)) + def test_jax_dict_key_paths_match_by_underlying_key(self): + """Match JAX key-path entries such as DictKey by their underlying key.""" + filter_weight = PathContains('weight') + path = jtu.tree_leaves_with_path({'layer': {'weight': 1}})[0][0] + + self.assertTrue(filter_weight(path, None)) + def test_numeric_keys(self): """Test with numeric keys in path.""" filter_num = PathContains(0) @@ -272,6 +280,13 @@ def test_type_attribute_check(self): typed_obj2 = MockTypedObject(dict) self.assertFalse(filter_list([], typed_obj2)) + def test_non_type_type_attribute_does_not_raise(self): + """A proxy object's non-class ``type`` attribute should simply not match.""" + filter_list = OfType(list) + typed_obj = MockTypedObject("not-a-type") + + self.assertFalse(filter_list([], typed_obj)) + def test_repr(self): """Test string representation.""" filter_type = OfType(str) @@ -909,4 +924,4 @@ def test_recursive_filter_structures(self): if __name__ == '__main__': - unittest.main() \ No newline at end of file + unittest.main() diff --git a/brainstate/util/struct.py b/brainstate/util/struct.py index dd4837c..9c67aeb 100644 --- a/brainstate/util/struct.py +++ b/brainstate/util/struct.py @@ -64,9 +64,7 @@ def is_dataclass(cls: Any) -> bool: ``True`` if ``cls`` carries the ``_brainstate_dataclass`` marker set by :func:`dataclass`, ``False`` otherwise. """ - if hasattr(cls, '_brainstate_dataclass'): - return True - return False + return dataclasses.is_dataclass(cls) and getattr(cls, '_brainstate_dataclass', False) is True def field(pytree_node: bool = True, **kwargs) -> dataclasses.Field: @@ -101,13 +99,18 @@ def field(pytree_node: bool = True, **kwargs) -> dataclasses.Field: ... # This field won't be affected by JAX transformations ... name: str = field(pytree_node=False, default="model") """ - metadata = kwargs.pop('metadata', {}) + metadata = dict(kwargs.pop('metadata', {})) + if 'pytree_node' in metadata and metadata['pytree_node'] != pytree_node: + raise ValueError( + f"Conflicting pytree_node metadata: metadata has {metadata['pytree_node']!r}, " + f"argument has {pytree_node!r}." + ) metadata['pytree_node'] = pytree_node return dataclasses.field(metadata=metadata, **kwargs) @dataclass_transform(field_specifiers=(field,)) -def dataclass(cls: type[T], **kwargs) -> type[T]: +def dataclass(cls: type[T] | None = None, **kwargs) -> type[T] | Any: """ Create a dataclass that works with JAX transformations. @@ -161,6 +164,9 @@ def dataclass(cls: type[T], **kwargs) -> type[T]: >>> # Use replace to create modified copies >>> model2 = model.replace(weights=jnp.ones((3, 3)) * 2) """ + if cls is None: + return lambda cls_: dataclass(cls_, **kwargs) + # Check if already converted if is_dataclass(cls): return cls @@ -416,12 +422,7 @@ def __repr__(self) -> str: def __hash__(self) -> int: """Return a hash of the dictionary.""" if self._hash is None: - items = [] - for key, value in self.items(): - if isinstance(value, dict): - value = FrozenDict(value) - items.append((key, value)) - self._hash = hash(tuple(sorted(items))) + self._hash = hash(frozenset(self.items())) return self._hash def __eq__(self, other: object) -> bool: @@ -613,7 +614,7 @@ def format_value(v, level): def tree_flatten_with_keys(self) -> tuple[list[tuple[Any, Any]], tuple[Any, ...]]: """Flatten for JAX pytree with keys.""" - sorted_keys = sorted(self._data.keys()) + sorted_keys = sorted(self._data.keys(), key=lambda k: (type(k).__module__, type(k).__qualname__, repr(k))) values_with_keys = [ (jax.tree_util.DictKey(k), self._data[k]) for k in sorted_keys diff --git a/brainstate/util/struct_test.py b/brainstate/util/struct_test.py index 7f8b5eb..30b04d4 100644 --- a/brainstate/util/struct_test.py +++ b/brainstate/util/struct_test.py @@ -14,6 +14,7 @@ # Import the modules to test from brainstate.util import ( field, + is_dataclass, dataclass, PyTreeNode, FrozenDict, @@ -176,6 +177,36 @@ class Point: assert p.x == 1.0 assert hasattr(Point, '_brainstate_dataclass') + def test_dataclass_supports_bare_parentheses(self): + """Allow @dataclass() usage like the stdlib decorator.""" + @dataclass() + class Point: + x: float + + p = Point(1.0) + + assert p.x == 1.0 + + def test_field_does_not_mutate_caller_metadata(self): + """Copy provided metadata before adding the pytree marker.""" + metadata = {'custom': 'data'} + + field(pytree_node=False, metadata=metadata) + + assert metadata == {'custom': 'data'} + + def test_field_rejects_conflicting_pytree_metadata(self): + """Reject metadata that disagrees with the explicit pytree_node flag.""" + with pytest.raises(ValueError, match='pytree_node'): + field(pytree_node=False, metadata={'pytree_node': True}) + + def test_is_dataclass_rejects_spoofed_marker(self): + """Do not treat arbitrary marker attributes as BrainState dataclasses.""" + class Spoofed: + _brainstate_dataclass = True + + assert not is_dataclass(Spoofed) + class TestPyTreeNode: """Test the PyTreeNode base class.""" @@ -727,6 +758,22 @@ def test_hash_with_nested_dict_value(self): fd2 = FrozenDict({'a': {'b': 1}}) self.assertEqual(hash(fd), hash(fd2)) + def test_hash_with_mixed_key_types(self): + """Hash mappings with valid but mutually incomparable key types.""" + fd1 = FrozenDict({1: 'int', '1': 'str'}) + fd2 = FrozenDict({'1': 'str', 1: 'int'}) + + self.assertEqual(hash(fd1), hash(fd2)) + self.assertEqual(fd1, fd2) + + def test_jax_flatten_with_mixed_key_types(self): + """Flatten mappings with valid but mutually incomparable key types.""" + fd = FrozenDict({1: jnp.array(1), '1': jnp.array(2)}) + + leaves = jax.tree_util.tree_leaves(fd) + + self.assertEqual(len(leaves), 2) + def test_pretty_repr_empty_frozendict(self): """Render an empty FrozenDict as FrozenDict({}).""" fd = FrozenDict({})