Skip to content

Commit 1d2a672

Browse files
committed
Add __eq__ for variables
1 parent 55a2366 commit 1d2a672

File tree

2 files changed

+6
-0
lines changed

2 files changed

+6
-0
lines changed

flax/nnx/variablelib.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -777,6 +777,11 @@ def __ror__(self, other) -> A:
777777
other = other.value
778778
return self.value.__ror__(other) # type: ignore
779779

780+
def __eq__(self, other) -> bool:
781+
if isinstance(other, Variable):
782+
other = other.value
783+
return self.value.__eq__(other) # type: ignore
784+
780785
def __iadd__(self: V, other) -> V:
781786
raise NotImplementedError(
782787
'In-place operations are no longer supported for Variable.\n'

tests/nnx/variable_test.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,7 @@ def test_binary_ops(self):
8888
result = v1 + v2
8989

9090
self.assertEqual(result, 5)
91+
self.assertFalse(v1 == v2)
9192

9293
v1[...] += v2
9394

0 commit comments

Comments
 (0)