diff --git a/flax/nnx/variablelib.py b/flax/nnx/variablelib.py index 6d4c358a1..78d204469 100644 --- a/flax/nnx/variablelib.py +++ b/flax/nnx/variablelib.py @@ -777,6 +777,11 @@ def __ror__(self, other) -> A: other = other.value return self.value.__ror__(other) # type: ignore + def __eq__(self, other) -> bool: + if isinstance(other, Variable): + other = other.value + return self.value.__eq__(other) # type: ignore + def __iadd__(self: V, other) -> V: raise NotImplementedError( 'In-place operations are no longer supported for Variable.\n' diff --git a/tests/nnx/variable_test.py b/tests/nnx/variable_test.py index 42156f291..683232488 100644 --- a/tests/nnx/variable_test.py +++ b/tests/nnx/variable_test.py @@ -88,6 +88,7 @@ def test_binary_ops(self): result = v1 + v2 self.assertEqual(result, 5) + self.assertFalse(v1 == v2) v1[...] += v2