From 1d2a672b52fd190d1f08d0777ec48a04b4562e7b Mon Sep 17 00:00:00 2001 From: Sam Anklesaria Date: Wed, 12 Nov 2025 09:05:55 -0600 Subject: [PATCH] Add __eq__ for variables --- flax/nnx/variablelib.py | 5 +++++ tests/nnx/variable_test.py | 1 + 2 files changed, 6 insertions(+) diff --git a/flax/nnx/variablelib.py b/flax/nnx/variablelib.py index 6d4c358a1..78d204469 100644 --- a/flax/nnx/variablelib.py +++ b/flax/nnx/variablelib.py @@ -777,6 +777,11 @@ def __ror__(self, other) -> A: other = other.value return self.value.__ror__(other) # type: ignore + def __eq__(self, other) -> bool: + if isinstance(other, Variable): + other = other.value + return self.value.__eq__(other) # type: ignore + def __iadd__(self: V, other) -> V: raise NotImplementedError( 'In-place operations are no longer supported for Variable.\n' diff --git a/tests/nnx/variable_test.py b/tests/nnx/variable_test.py index 42156f291..683232488 100644 --- a/tests/nnx/variable_test.py +++ b/tests/nnx/variable_test.py @@ -88,6 +88,7 @@ def test_binary_ops(self): result = v1 + v2 self.assertEqual(result, 5) + self.assertFalse(v1 == v2) v1[...] += v2