From 368609e5cfec832b4f6228e659460f64232d522e Mon Sep 17 00:00:00 2001 From: Haihan Jiang Date: Sun, 14 Jun 2026 02:22:53 -0700 Subject: [PATCH] docs: clarify Bisection bound differentiation --- docs/root_finding.rst | 8 ++++++++ jaxopt/_src/bisection.py | 8 ++++++-- 2 files changed, 14 insertions(+), 2 deletions(-) diff --git a/docs/root_finding.rst b/docs/root_finding.rst index b3c76d49..5e2afb11 100644 --- a/docs/root_finding.rst +++ b/docs/root_finding.rst @@ -61,6 +61,14 @@ with respect to ``factor``:: Under the hood, we use the implicit function theorem in order to differentiate the root. See the :ref:`implicit differentiation ` section for more details. +The bracketing values ``lower`` and ``upper`` are solver hyperparameters, not +arguments of ``optimality_fun``. Implicit differentiation therefore computes +derivatives with respect to the arguments passed to ``run``, such as +``factor`` above, but not with respect to ``lower`` or ``upper`` themselves. +If the bracketing interval is computed from differentiable JAX values and is +intended only to define the algorithm's search interval, stop its gradients +before constructing ``Bisection``, for example with ``jax.lax.stop_gradient``. + Scipy wrapper ------------- diff --git a/jaxopt/_src/bisection.py b/jaxopt/_src/bisection.py index 97f259ca..3f4b87b5 100644 --- a/jaxopt/_src/bisection.py +++ b/jaxopt/_src/bisection.py @@ -48,8 +48,12 @@ class Bisection(base.IterativeSolver): optimality_fun: a function ``optimality_fun(x, *args, **kwargs)`` where ``x`` is a 1d variable. The function should have opposite signs when evaluated at ``lower`` and at ``upper``. - lower: the lower end of the bracketing interval. - upper: the upper end of the bracketing interval. + lower: the lower end of the bracketing interval. This is a solver + hyperparameter, not an argument of ``optimality_fun``; implicit + differentiation is not with respect to this value. + upper: the upper end of the bracketing interval. This is a solver + hyperparameter, not an argument of ``optimality_fun``; implicit + differentiation is not with respect to this value. maxiter: maximum number of iterations. tol: tolerance. check_bracket: whether to check correctness of the bracketing interval.