Skip to content

Commit 4a299d3

Browse files
committed
Fix: Remove deprecated .path access in Muon optimizer for TF 2.16+ compatibility
2 parents ee582b8 + 9d08112 commit 4a299d3

File tree

2 files changed

+38
-36
lines changed

2 files changed

+38
-36
lines changed

keras/src/optimizers/muon.py

Lines changed: 9 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -129,32 +129,26 @@ def __init__(
129129
self.exclude_layers = exclude_layers or []
130130

131131
def _should_use_adamw(self, variable):
132-
"""
133-
To use it with 4D convolutional filters,
134-
it works well to just flatten their last 3 dimensions.
135-
any {0,1}-D parameters should all be optimized by adam
136-
"""
137-
# Use Adam for scalar or vector parameters
132+
# To use it with 4D convolutional filters,
133+
# it works well to just flatten their last 3 dimensions.
134+
# any {0,1}-D parameters should all be optimized by adam
138135
if not 1 < len(variable.shape) < 5:
139136
return True
140137

141-
# Exclude embedding layers if specified
142-
var_identifier = getattr(variable, "name", "") or getattr(
143-
variable, "path", ""
144-
)
138+
# Get variable identifier (use .name in Keras 3+)
139+
var_identifier = variable.name
140+
141+
# Check if embedding layer should be excluded
145142
if self.exclude_embeddings and "embedding" in var_identifier.lower():
146143
return True
147144

148-
# Exclude variables matching any of the excluded layer patterns
149-
for keyword in getattr(self, "exclude_layers", []):
145+
# Check if variable matches any excluded layer patterns
146+
for keyword in self.exclude_layers:
150147
try:
151148
if re.search(keyword, var_identifier):
152149
return True
153150
except re.error:
154-
# Skip invalid regex patterns
155151
continue
156-
157-
# Otherwise, use AdamW
158152
return False
159153

160154
def build(self, var_list):

keras/src/optimizers/muon_test.py

Lines changed: 29 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import numpy as np
2+
import tensorflow as tf
23

34
from keras.src import backend
45
from keras.src import ops
@@ -28,34 +29,39 @@ def test_adamw_single_step(self):
2829
optimizer._adamw_update_step(grads, var, 0.5)
2930
self.assertAllClose(var, [0.5, 1.5, 2.5, 3.5], rtol=1e-4, atol=1e-4)
3031

31-
def test_should_use_adamw(self):
32-
# Excluded layer test
33-
var = backend.Variable([[1.0, 2.0], [3.0, 4.0]])
34-
optimizer = Muon(exclude_layers=["var"])
35-
self.assertTrue(optimizer._should_use_adamw(var))
32+
def test_should_use_adamw_excluded_layer(self):
33+
"""Ensure exclude_layers keyword works and no .path is accessed."""
34+
optimizer = Muon(exclude_layers=["dense"])
35+
dummy_var = backend.Variable(
36+
[[1.0, 2.0], [3.0, 4.0]], name="dense_kernel_0"
37+
)
38+
result = optimizer._should_use_adamw(dummy_var)
39+
self.assertTrue(result)
3640

37-
# Embedding test
41+
def test_should_use_adamw_embedding(self):
42+
"""Embedding layer should use AdamW when exclude_embeddings=True."""
3843
embedding = Embedding(2, 2)
3944
embedding.build()
4045
optimizer = Muon(exclude_embeddings=True)
41-
self.assertTrue(optimizer._should_use_adamw(embedding.weights[0]))
42-
43-
# 2D variable not excluded
44-
var2 = backend.Variable([[1.0, 2.0], [3.0, 4.0]])
45-
optimizer = Muon()
46-
self.assertFalse(optimizer._should_use_adamw(var2))
47-
48-
# Dense layer
49-
dense = Dense(2)
50-
dense.build([None, 2])
51-
self.assertFalse(optimizer._should_use_adamw(dense.weights[0]))
46+
result = optimizer._should_use_adamw(embedding.weights[0])
47+
self.assertTrue(result)
5248

53-
# Dimension rules
49+
def test_should_use_adamw_dimension_rule(self):
50+
"""Variables with dimensions not between 2–4 use AdamW."""
5451
v_1d = backend.Variable([1.0, 2.0], name="v1d")
5552
v_5d = backend.Variable(np.zeros((2, 2, 2, 2, 2)), name="v5d")
53+
optimizer = Muon()
5654
self.assertTrue(optimizer._should_use_adamw(v_1d))
5755
self.assertTrue(optimizer._should_use_adamw(v_5d))
5856

57+
def test_should_use_adamw_dense_layer(self):
58+
"""2D dense layer weights should use Muon (False)."""
59+
dense = Dense(2)
60+
dense.build([None, 2])
61+
optimizer = Muon()
62+
result = optimizer._should_use_adamw(dense.weights[0])
63+
self.assertFalse(result)
64+
5965
def test_muon_single_step(self):
6066
optimizer = Muon(learning_rate=0.5, weight_decay=0)
6167
grads = ops.array([[1.0, 6.0], [7.0, 2.0]])
@@ -79,10 +85,12 @@ def test_clip_value(self):
7985
self.assertAllClose(clipped_grad[0], [1.0, 1.0])
8086

8187
def test_no_path_attribute_error(self):
82-
"""Ensure compatibility with TF 2.16+
83-
ResourceVariable (no .path)."""
88+
"""Ensure compatibility with TF 2.16+ where
89+
ResourceVariable has no .path."""
8490
optimizer = Muon()
85-
var = backend.Variable([1.0, 2.0], name="test_var")
91+
var = tf.Variable([1.0, 2.0], name="test_var")
92+
# Force-run method that caused AttributeError in issue #21793
93+
8694
try:
8795
result = optimizer._should_use_adamw(var)
8896
self.assertIn(result, [True, False])

0 commit comments

Comments
 (0)