Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions docs/root_finding.rst
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,14 @@ with respect to ``factor``::
Under the hood, we use the implicit function theorem in order to differentiate the root.
See the :ref:`implicit differentiation <implicit_diff>` section for more details.

The bracketing values ``lower`` and ``upper`` are solver hyperparameters, not
arguments of ``optimality_fun``. Implicit differentiation therefore computes
derivatives with respect to the arguments passed to ``run``, such as
``factor`` above, but not with respect to ``lower`` or ``upper`` themselves.
If the bracketing interval is computed from differentiable JAX values and is
intended only to define the algorithm's search interval, stop its gradients
before constructing ``Bisection``, for example with ``jax.lax.stop_gradient``.

Scipy wrapper
-------------

Expand Down
8 changes: 6 additions & 2 deletions jaxopt/_src/bisection.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,12 @@ class Bisection(base.IterativeSolver):
optimality_fun: a function ``optimality_fun(x, *args, **kwargs)``
where ``x`` is a 1d variable. The function should have opposite signs
when evaluated at ``lower`` and at ``upper``.
lower: the lower end of the bracketing interval.
upper: the upper end of the bracketing interval.
lower: the lower end of the bracketing interval. This is a solver
hyperparameter, not an argument of ``optimality_fun``; implicit
differentiation is not with respect to this value.
upper: the upper end of the bracketing interval. This is a solver
hyperparameter, not an argument of ``optimality_fun``; implicit
differentiation is not with respect to this value.
maxiter: maximum number of iterations.
tol: tolerance.
check_bracket: whether to check correctness of the bracketing interval.
Expand Down