Skip to content

Commit ea5cd09

Browse files
committed
Fix: Remove deprecated .path access in Muon optimizer for TF 2.16+ compatibility
Fixes `keras.optimizers.Muon` failing with `AttributeError: 'ResourceVariable' object has no attribute 'path'` in Keras 3 / TF 2.16–2.20.

Changes:
- Replaced deprecated `.path` references with `_get_variable_index()` for variable identification
- Updated `build()` to use lists instead of dicts, initialized with `[None] * len(var_list)`
- Updated `_should_use_adamw()` logic to safely check `.path` only during build
- Updated `update_step()`, `_muon_update_step()`, and `_adamw_update_step()` to use `_get_variable_index()`
- Added robust error handling for invalid regex patterns in `exclude_layers`
- Reverted `image_utils.py` changes as requested by reviewer

Result: All tests pass. Compatible with TensorFlow 2.16+. Closes #21793
1 parent 198512f commit ea5cd09

File tree

2 files changed

+30
-32
lines changed

2 files changed

+30
-32
lines changed

keras/src/optimizers/muon.py

Lines changed: 30 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -134,47 +134,57 @@ def _should_use_adamw(self, variable):
134134
# any {0,1}-D parameters should all be optimized by adam
135135
if not 1 < len(variable.shape) < 4:
136136
return True
137-
if self.exclude_embeddings and "embedding" in variable.path.lower():
137+
# Check .path only during build (where we have keras.Variable)
138+
var_path = variable.path if hasattr(variable, "path") else None
139+
if var_path is None:
140+
return False
141+
if self.exclude_embeddings and "embedding" in var_path.lower():
138142
return True
139-
for keyword in self.exclude_layers:
140-
if re.search(keyword, variable.path):
141-
return True
143+
# Exclude any user-specified layer patterns
144+
for pattern in self.exclude_layers:
145+
try:
146+
if re.search(pattern, var_path):
147+
return True
148+
except (re.error, TypeError):
149+
# Skip invalid regex patterns in exclude_layers
150+
continue
142151
return False
143152

144153
def build(self, var_list):
145154
"""Initialize optimizer variables.
146155
147-
Adam optimizer has 3 types of variables: momentums, velocities and
148-
velocity_hat (only set when amsgrad is applied),
156+
Muon optimizer has 2 types of variables: momentums and velocities.
157+
Velocities are only set when using AdamW update step.
149158
150159
Args:
151-
var_list: list of model variables to build Adam variables on.
160+
var_list: list of model variables to build Muon variables on.
152161
"""
153162
if self.built:
154163
return
155164
super().build(var_list)
156-
self.adam_momentums = {}
157-
self.adam_velocities = {}
158-
159-
self.muon_momentums = {}
160-
self.muon_velocities = {}
165+
# Initialize lists with None for all variables
166+
self.adam_momentums = [None] * len(var_list)
167+
self.adam_velocities = [None] * len(var_list)
161168

162169
for var in var_list:
163170
if not self._overwrite_variable_with_gradient(var):
164-
self.adam_momentums[var.path] = (
171+
var_idx = self._get_variable_index(var)
172+
self.adam_momentums[var_idx] = (
165173
self.add_variable_from_reference(
166174
reference_variable=var, name="momentum"
167175
)
168176
)
169177
if self._should_use_adamw(var):
170-
self.adam_velocities[var.path] = (
178+
self.adam_velocities[var_idx] = (
171179
self.add_variable_from_reference(
172180
reference_variable=var, name="velocity"
173181
)
174182
)
175183

176184
def update_step(self, gradient, variable, learning_rate):
177-
if self._should_use_adamw(variable):
185+
var_idx = self._get_variable_index(variable)
186+
# Check if velocity exists to determine if we should use AdamW
187+
if self.adam_velocities[var_idx] is not None:
178188
# It should be noted that lr is one-tenth when using adamw.
179189
self._adamw_update_step(
180190
gradient, variable, learning_rate * self.adam_lr_ratio
@@ -183,7 +193,8 @@ def update_step(self, gradient, variable, learning_rate):
183193
self._muon_update_step(gradient, variable, learning_rate)
184194

185195
def _muon_update_step(self, gradient, variable, lr):
186-
m = self.adam_momentums[variable.path]
196+
var_idx = self._get_variable_index(variable)
197+
m = self.adam_momentums[var_idx]
187198
self.assign_add(m, ops.add(gradient, m * (self.momentum - 1)))
188199
shape = variable.shape
189200
if self.nesterov:
@@ -210,8 +221,9 @@ def _adamw_update_step(self, gradient, variable, learning_rate):
210221
ops.cast(self.adam_beta_2, variable.dtype), local_step
211222
)
212223

213-
m = self.adam_momentums[variable.path]
214-
v = self.adam_velocities[variable.path]
224+
var_idx = self._get_variable_index(variable)
225+
m = self.adam_momentums[var_idx]
226+
v = self.adam_velocities[var_idx]
215227

216228
alpha = lr * ops.sqrt(1 - adam_beta_2_power) / (1 - adam_beta_1_power)
217229

keras/src/utils/image_utils.py

Lines changed: 0 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -180,19 +180,7 @@ def save_img(path, x, data_format=None, file_format=None, scale=True, **kwargs):
180180
if file_format is None and isinstance(path, (str, pathlib.Path)):
181181
file_format = pathlib.Path(path).suffix[1:].lower()
182182

183-
# Normalize jpg → jpeg for Pillow compatibility
184-
if file_format and file_format.lower() == "jpg":
185-
file_format = "jpeg"
186-
187183
img = array_to_img(x, data_format=data_format, scale=scale)
188-
189-
# Handle RGBA → RGB conversion for JPEG
190-
if img.mode == "RGBA" and file_format == "jpeg":
191-
warnings.warn(
192-
"The JPEG format does not support RGBA images, converting to RGB."
193-
)
194-
img = img.convert("RGB")
195-
196184
img.save(path, format=file_format, **kwargs)
197185

198186

@@ -464,6 +452,4 @@ def smart_resize(
464452
img, size=size, interpolation=interpolation, data_format=data_format
465453
)
466454

467-
if isinstance(x, np.ndarray):
468-
return np.array(img)
469455
return img

0 commit comments

Comments (0)