Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions brainstate/util/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,11 @@
pretty_repr,
)


def breakpoint_if(*args, **kwargs):
from brainstate.transform._debug import breakpoint_if as _breakpoint_if
return _breakpoint_if(*args, **kwargs)

__all__ = [
# Tracer utilities
'StateJaxTracer',
Expand Down
8 changes: 7 additions & 1 deletion brainstate/util/_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,10 @@ class BoundedCache:
"""

def __init__(self, maxsize: int = 128):
if not isinstance(maxsize, int):
raise TypeError(f"maxsize must be an integer, got {type(maxsize).__name__}.")
if maxsize < 0:
raise ValueError(f"maxsize must be non-negative, got {maxsize}.")
self._cache = OrderedDict()
self._maxsize = maxsize
self._lock = threading.RLock()
Expand Down Expand Up @@ -92,7 +96,7 @@ def get(
f"Requested key:",
f" {key}",
f"",
f"Available {{len(available_keys)}} keys:",
f"Available {len(available_keys)} keys:",
]
if available_keys:
for i, k in enumerate(available_keys, 1):
Expand Down Expand Up @@ -128,6 +132,8 @@ def set(self, key: Any, value: Any) -> None:
f"Cannot overwrite existing cached value. "
f"Clear the cache first if you need to recompile."
)
if self._maxsize == 0:
return
if len(self._cache) >= self._maxsize:
self._cache.popitem(last=False)
self._cache[key] = value
Expand Down
18 changes: 18 additions & 0 deletions brainstate/util/_cache_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,21 @@
class TestBoundedCache(unittest.TestCase):
"""Test the BoundedCache class."""

def test_cache_rejects_negative_maxsize(self):
"""Reject negative cache sizes before eviction logic can corrupt state."""
with pytest.raises(ValueError, match="maxsize"):
BoundedCache(maxsize=-1)

def test_cache_zero_maxsize_does_not_store_items(self):
"""A zero-sized cache behaves as permanently empty instead of raising."""
cache = BoundedCache(maxsize=0)

cache.set('key1', 'value1')

self.assertEqual(len(cache), 0)
self.assertNotIn('key1', cache)
self.assertIsNone(cache.get('key1'))

def test_cache_basic_operations(self):
"""Test basic get and set operations."""
cache = BoundedCache(maxsize=3)
Expand Down Expand Up @@ -235,6 +250,9 @@ def test_cache_detailed_error_message(self):
error_msg = str(exc_info.value)
# Should show requested key
self.assertIn('nonexistent', error_msg)
# Should show the formatted number of available keys
self.assertIn('Available 2 keys:', error_msg)
self.assertNotIn('{len(available_keys)}', error_msg)
# Should show available keys
self.assertIn('key1', error_msg)
self.assertIn('key2', error_msg)
Expand Down
28 changes: 28 additions & 0 deletions brainstate/util/_init_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
"""Tests for the public ``brainstate.util`` package surface."""

import unittest


class TestUtilPackageExports(unittest.TestCase):
"""Validate the package-level export table."""

def test_all_names_are_bound(self):
"""Every name listed in ``__all__`` is importable as a package attribute."""
import brainstate.util as util

for name in util.__all__:
with self.subTest(name=name):
self.assertTrue(hasattr(util, name), name)

def test_star_import_includes_all_names(self):
"""Wildcard import should not fail due to a stale ``__all__`` entry."""
namespace = {}

exec("from brainstate.util import *", namespace)

self.assertIn('BoundedCache', namespace)
self.assertIn('breakpoint_if', namespace)


if __name__ == "__main__":
unittest.main()
20 changes: 18 additions & 2 deletions brainstate/util/_others.py
Original file line number Diff line number Diff line change
Expand Up @@ -289,8 +289,13 @@ def add_unique_value(self, key: K, val: V) -> bool:
True if the value was added (was unique), False otherwise.
"""
self._check_elem(val)
if not hasattr(self, '_val_id_to_key'):
self._val_id_to_key = {id(v): k for k, v in self.items()}
self._val_id_to_key = {id(v): k for k, v in self.items()}

if key in self and id(self[key]) != id(val):
raise ValueError(
f"Key '{key}' already exists with a different value. "
f"Existing: {self[key]}, New: {val}"
)

val_id = id(val)
if val_id not in self._val_id_to_key:
Expand Down Expand Up @@ -681,6 +686,7 @@ def __setattr__(self, name: str, value: Any) -> None:

def __setitem__(self, name: str, value: Any) -> None:
"""Set item and update parent if nested."""
value = self._hook(value)
super().__setitem__(name, value)
try:
parent = object.__getattribute__(self, '__parent')
Expand Down Expand Up @@ -908,6 +914,8 @@ def flatten_dict(
>>> flatten_dict(d)
{'a': 1, 'b.c': 2, 'b.d.e': 3}
"""
if not isinstance(d, dict):
raise TypeError(f"d must be a dict, got {type(d).__name__}.")
items = []
for k, v in d.items():
new_key = f"{parent_key}{sep}{k}" if parent_key else k
Expand Down Expand Up @@ -944,17 +952,25 @@ def unflatten_dict(
>>> unflatten_dict(d)
{'a': 1, 'b': {'c': 2, 'd': {'e': 3}}}
"""
if not isinstance(d, dict):
raise TypeError(f"d must be a dict, got {type(d).__name__}.")
result = {}

for key, value in d.items():
parts = key.split(sep)
if not parts or any(part == '' for part in parts):
raise ValueError(f"Invalid flattened key {key!r}.")
current = result

for part in parts[:-1]:
if part not in current:
current[part] = {}
elif not isinstance(current[part], dict):
raise ValueError(f"Cannot expand scalar key prefix {part!r} in {key!r}.")
current = current[part]

if parts[-1] in current and isinstance(current[parts[-1]], dict):
raise ValueError(f"Cannot overwrite mapping at key {key!r} with a scalar value.")
current[parts[-1]] = value

return result
Expand Down
61 changes: 61 additions & 0 deletions brainstate/util/_others_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,6 +283,40 @@ def test_add_unique_value(self):
self.assertTrue(result3)
self.assertIs(dm['key2'], obj2)

def test_add_unique_value_allows_readd_after_delete(self):
"""Keep the identity index in sync after direct item deletion."""
dm = DictManager()
obj = object()
self.assertTrue(dm.add_unique_value('key1', obj))

del dm['key1']

self.assertTrue(dm.add_unique_value('key2', obj))
self.assertIs(dm['key2'], obj)

def test_add_unique_value_indexes_existing_direct_mutations(self):
"""Rebuild stale identity indexes when values were added without helper APIs."""
dm = DictManager()
obj = object()
self.assertTrue(dm.add_unique_value('key1', object()))

dm['direct'] = obj

self.assertFalse(dm.add_unique_value('duplicate', obj))
self.assertNotIn('duplicate', dm)

def test_add_unique_value_rejects_existing_key_with_different_value(self):
"""Do not overwrite an existing key when the new value is unique."""
dm = DictManager()
obj1 = object()
obj2 = object()
self.assertTrue(dm.add_unique_value('key1', obj1))

with self.assertRaises(ValueError):
dm.add_unique_value('key1', obj2)

self.assertIs(dm['key1'], obj1)

def test_unique(self):
"""Test getting unique values."""
obj1 = object()
Expand Down Expand Up @@ -562,6 +596,16 @@ def test_dot_access(self):
dd.e = 5
self.assertEqual(dd['e'], 5)

# Direct item assignment should still apply nested DotDict conversion.
dd['nested'] = {'x': 1}
self.assertIsInstance(dd.nested, DotDict)
self.assertEqual(dd.nested.x, 1)

# Attribute assignment should use the same conversion path.
dd.another = {'y': 2}
self.assertIsInstance(dd.another, DotDict)
self.assertEqual(dd.another.y, 2)

def test_nested_dict_conversion(self):
"""Test automatic nested dict conversion."""
dd = DotDict({
Expand Down Expand Up @@ -711,6 +755,10 @@ def test_setdefault(self):
self.assertIsNone(result)
self.assertIsNone(dd.c)

result = dd.setdefault('nested', {'value': 3})
self.assertIsInstance(result, DotDict)
self.assertEqual(result.value, 3)

def test_pickling(self):
"""Test pickling/unpickling."""
dd1 = DotDict({'a': 1, 'b': {'c': 2}})
Expand Down Expand Up @@ -847,6 +895,19 @@ def test_unflatten_dict(self):
'b': {'c': 2, 'd': 3}
})

def test_unflatten_dict_rejects_prefix_conflicts(self):
"""Reject flattened inputs that would overwrite a scalar with a mapping."""
with self.assertRaises(ValueError):
unflatten_dict({'a': 1, 'a.b': 2})

with self.assertRaises(ValueError):
unflatten_dict({'a.b': 2, 'a': 1})

def test_flatten_dict_rejects_non_mapping(self):
"""Raise a clear TypeError for non-dictionary inputs."""
with self.assertRaises(TypeError):
flatten_dict([('a', 1)])

def test_flatten_unflatten_roundtrip(self):
"""Test that flatten/unflatten is reversible."""
original = {
Expand Down
7 changes: 7 additions & 0 deletions brainstate/util/_pretty_pytree.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,13 +280,20 @@ def nest_mapping(
for path, value in xs.items():
if sep is not None:
path = path.split(sep)
path = tuple(path)
if not path:
raise ValueError('Cannot nest an empty path.')
if value is empty_node:
value = {}
cursor = result
for key in path[:-1]:
if key not in cursor:
cursor[key] = {}
elif not isinstance(cursor[key], abc.Mapping):
raise ValueError(f'Cannot expand scalar key prefix {key!r} in path {path!r}.')
cursor = cursor[key]
if path[-1] in cursor and isinstance(cursor[path[-1]], abc.Mapping):
raise ValueError(f'Cannot overwrite mapping at path {path!r} with a scalar value.')
cursor[path[-1]] = value
return NestedDict(result)

Expand Down
13 changes: 13 additions & 0 deletions brainstate/util/_pretty_pytree_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -721,6 +721,19 @@ def test_flat_nest_roundtrip(self):
nested = {'a': 1, 'b': {'c': 2, 'd': {'e': 3}}}
self.assertEqual(nest_mapping(flat_mapping(nested)).to_dict(), nested)

def test_nest_mapping_rejects_prefix_conflicts(self):
"""Reject flat mappings where one path is both a leaf and branch."""
with self.assertRaises(ValueError):
nest_mapping({('a',): 1, ('a', 'b'): 2})

with self.assertRaises(ValueError):
nest_mapping({('a', 'b'): 2, ('a',): 1})

def test_nest_mapping_rejects_empty_paths(self):
"""Reject empty paths because they cannot be nested as mapping keys."""
with self.assertRaises(ValueError):
nest_mapping({(): 1})

def test_nesteddict_flatten_unflatten_consistent(self):
"""Flatten and unflatten a NestedDict to an equal JAX structure."""
nd = NestedDict({'a': brainstate.ParamState(1), 'b': {'c': brainstate.ParamState(2)}})
Expand Down
12 changes: 10 additions & 2 deletions brainstate/util/filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -399,6 +399,13 @@ class PathContains:

key: Key

@staticmethod
def _path_key(path_part: typing.Any) -> typing.Any:
for attr in ('key', 'name', 'idx'):
if hasattr(path_part, attr):
return getattr(path_part, attr)
return path_part

def __call__(self, path: PathParts, x: typing.Any) -> bool:
"""
Check if the key is present in the path.
Expand All @@ -415,7 +422,7 @@ def __call__(self, path: PathParts, x: typing.Any) -> bool:
bool
True if the key is present in the path, False otherwise.
"""
return self.key in path
return any(self._path_key(part) == self.key for part in path)

def __repr__(self) -> str:
return f'PathContains({self.key!r})'
Expand Down Expand Up @@ -504,8 +511,9 @@ def __call__(self, path: PathParts, x: typing.Any):
True if the object is an instance of the specified type or
has a 'type' attribute that is a subclass of the specified type.
"""
x_type = getattr(x, 'type', None)
return isinstance(x, self.type) or (
hasattr(x, 'type') and issubclass(x.type, self.type)
isinstance(x_type, type) and issubclass(x_type, self.type)
)

def __repr__(self):
Expand Down
17 changes: 16 additions & 1 deletion brainstate/util/filter_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import unittest
from typing import Any
import numpy as np
import jax.tree_util as jtu

from brainstate.util.filter import (
to_predicate,
Expand Down Expand Up @@ -206,6 +207,13 @@ def test_basic_functionality(self):
self.assertFalse(filter_weight([], None))
self.assertFalse(filter_weight(['other', 'path'], None))

def test_jax_dict_key_paths_match_by_underlying_key(self):
"""Match JAX key-path entries such as DictKey by their underlying key."""
filter_weight = PathContains('weight')
path = jtu.tree_leaves_with_path({'layer': {'weight': 1}})[0][0]

self.assertTrue(filter_weight(path, None))

def test_numeric_keys(self):
"""Test with numeric keys in path."""
filter_num = PathContains(0)
Expand Down Expand Up @@ -272,6 +280,13 @@ def test_type_attribute_check(self):
typed_obj2 = MockTypedObject(dict)
self.assertFalse(filter_list([], typed_obj2))

def test_non_type_type_attribute_does_not_raise(self):
"""A proxy object's non-class ``type`` attribute should simply not match."""
filter_list = OfType(list)
typed_obj = MockTypedObject("not-a-type")

self.assertFalse(filter_list([], typed_obj))

def test_repr(self):
"""Test string representation."""
filter_type = OfType(str)
Expand Down Expand Up @@ -909,4 +924,4 @@ def test_recursive_filter_structures(self):


if __name__ == '__main__':
unittest.main()
unittest.main()
Loading