diff --git a/brax/training/gradients.py b/brax/training/gradients.py
index 4c21d4c7..a616f087 100644
--- a/brax/training/gradients.py
+++ b/brax/training/gradients.py
@@ -60,7 +60,7 @@ def gradient_update_fn(
 
   def f(*args, optimizer_state):
     value, grads = loss_and_pgrad_fn(*args)
-    params_update, optimizer_state = optimizer.update(grads, optimizer_state)
+    params_update, optimizer_state = optimizer.update(grads, optimizer_state, args[0])
     params = optax.apply_updates(args[0], params_update)
     return value, params, optimizer_state
 