2 changes: 2 additions & 0 deletions docs/source/api/index.rst
@@ -171,6 +171,8 @@ Primal
ProximalPoint
TwIST
DouglasRachfordSplitting
PPXA
ConsensusADMM

.. currentmodule:: pyproximal.optimization.palm

2 changes: 2 additions & 0 deletions pyproximal/optimization/__init__.py
@@ -20,6 +20,8 @@
TwIST Two-step Iterative Shrinkage/Threshold
PlugAndPlay Plug-and-Play Prior with ADMM
DouglasRachfordSplitting Douglas-Rachford algorithm
PPXA Parallel Proximal Algorithm
ConsensusADMM Consensus problem with ADMM
A list of solvers in ``pyproximal.optimization.proximaldual`` using both proximal
and dual proximal operators:
292 changes: 292 additions & 0 deletions pyproximal/optimization/primal.py
@@ -1793,3 +1793,295 @@ def DouglasRachfordSplitting(
print(f"\nTotal time (s) = {time.time() - tstart:.2f}")
print("---------------------------------------------------------\n")
return x, y


def PPXA( # pylint: disable=invalid-name
prox_ops: List[ProxOperator],
Contributor: To be consistent with GPG, I suggest calling this proxfs.
x0: NDArray | List[NDArray],
tau: float,
eta: float = 1.0,
weights: NDArray | List[float] | None = None,
niter: int = 1000,
tol: Optional[float] = 1e-7,
callback: Optional[Callable[..., None]] = None,
show: bool = False,
) -> NDArray:
r"""Parallel Proximal Algorithm (PPXA)
Solves the following minimization problem using
Parallel Proximal Algorithm (PPXA):
.. math::
\mathbf{x} = \argmin_\mathbf{x} \sum_{i=1}^m f_i(\mathbf{x})
where :math:`f_i(\mathbf{x})` are any convex
functions that has known proximal operators.
Parameters
----------
prox_ops : :obj:`list`
A list of proximable functions :math:`f_1, \ldots, f_m`.
x0 : :obj:`np.ndarray` or :obj:`list`
Contributor: Write x0 : :obj:`numpy.ndarray` or :obj:`list` (and similarly for the other parameters below).
Initial vector :math:`\mathbf{x}` for all :math:`f_i` if 1D :obj:`np.ndarray`,
Contributor: Suggested wording: "Initial vector :math:`\mathbf{x}` for all :math:`f_i` if a 1-dimensional array is provided, or initial vectors :math:`\mathbf{x}_{i}` for each :math:`f_i` for :math:`i=1,\ldots,m` if a :obj:`list` of 1-dimensional arrays or a 2-dimensional array of size ** is provided". (** write the size, as otherwise it is not obvious which dimension is the model and which is m.)
or initial vectors :math:`\mathbf{x}_{i}`, one for each :math:`f_i` (:math:`i=1,\ldots,m`),
if :obj:`list` or 2D :obj:`np.ndarray`.
tau : :obj:`float`
Contributor: Is there any condition on tau that must be followed? (Usually there is an upper bound related to some Lipschitz constant.)
Positive scalar weight
eta : :obj:`float`, optional
Relaxation parameter (must be between 0 and 2, 0 excluded).
weights : :obj:`np.ndarray` or :obj:`list` or :obj:`None`, optional
Weights satisfying :math:`\sum_{i=1}^m w_i = 1, \ 0 < w_i < 1`.
Defaults to ``None``, which means :math:`w_1 = \cdots = w_m = \frac{1}{m}`.
niter : :obj:`int`, optional
The maximum number of iterations.
Contributor (@mrava87, Jan 1, 2026): Try to use as much as possible the same description used for the same parameters in other solvers; here: "Number of iterations of iterative scheme".
tol : :obj:`float`, optional
Tolerance on change of the solution (used as stopping criterion).
If ``tol=0``, run until ``niter`` is reached.
callback : :obj:`callable`, optional
Function with signature (``callback(x)``) to call after each iteration
where ``x`` is the current model vector
show : :obj:`bool`, optional
Display iterations log
Returns
-------
x : :obj:`numpy.ndarray`
Inverted model
See Also
--------
ConsensusADMM: Consensus ADMM
Notes
-----
The Parallel Proximal Algorithm (PPXA) can be expressed by the following
recursion [1]_, [2]_, [3]_, [4]_:
* :math:`\mathbf{y}_{i}^{(0)} = \mathbf{x}` or :math:`\mathbf{y}_{i}^{(0)} = \mathbf{x}_{i}` for :math:`i=1,\ldots,m`
* :math:`\mathbf{x}^{(0)} = \sum_{i=1}^m w_i \mathbf{y}_{i}^{(0)}`
* for :math:`k = 1, \ldots`
* for :math:`i = 1, \ldots, m`
* :math:`\mathbf{p}_{i}^{(k)} = \prox_{\frac{\tau}{w_i} f_i} (\mathbf{y}_{i}^{(k)})`
* :math:`\mathbf{p}^{(k)} = \sum_{i=1}^{m} w_i \mathbf{p}_{i}^{(k)}`
Contributor: Take away the parentheses from the superscripts, to be consistent with the Notes of other solvers.
* for :math:`i = 1, \ldots, m`
* :math:`\mathbf{y}_{i}^{(k+1)} = \mathbf{y}_{i}^{(k)} + \eta (2 \mathbf{p}^{(k)} - \mathbf{x}^{(k)} - \mathbf{p}_i^{(k)})`
* :math:`\mathbf{x}^{(k+1)} = \mathbf{x}^{(k)} + \eta (\mathbf{p}^{(k)} - \mathbf{x}^{(k)})`
where :math:`0 < \eta < 2` and
:math:`\sum_{i=1}^m w_i = 1, \ 0 < w_i < 1`.
In the current implementation, :math:`w_i = 1 / m` when not provided.
References
----------
.. [1] Combettes, P.L., Pesquet, J.-C., 2008. A proximal decomposition
method for solving convex variational inverse problems. Inverse Problems
24, 065014. Algorithm 3.1. https://doi.org/10.1088/0266-5611/24/6/065014
https://arxiv.org/abs/0807.2617
.. [2] Combettes, P.L., Pesquet, J.-C., 2011. Proximal Splitting Methods in
Signal Processing, in Fixed-Point Algorithms for Inverse Problems in
Science and Engineering, Springer, pp. 185-212. Algorithm 10.27.
https://doi.org/10.1007/978-1-4419-9569-8_10
.. [3] Bauschke, H.H., Combettes, P.L., 2011. Convex Analysis and Monotone
Operator Theory in Hilbert Spaces, 1st ed, CMS Books in Mathematics.
Springer, New York, NY. Proposition 27.8.
https://doi.org/10.1007/978-1-4419-9467-7
.. [4] Ryu, E.K., Yin, W., 2022. Large-Scale Convex Optimization: Algorithms
& Analyses via Monotone Operators. Cambridge University Press,
Cambridge. Exercise 2.38 https://doi.org/10.1017/9781009160865
https://large-scale-book.mathopt.com/
"""
if show:
tstart = time.time()
print(
"Parallel Proximal Algorithm\n"
"---------------------------------------------------------"
)
for i, prox_op in enumerate(prox_ops):
print(f"Proximal operator (f{i}): {type(prox_op)}")
print(f"tau = {tau:10e}\tniter = {niter:d}\n")
head = " Itn x[0] J=sum_i f_i"
print(head)

ncp = get_array_module(x0)

m = len(prox_ops)
Contributor: Add some comments... I guess everything from here down to the for loop is # initialize model?

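# initialize weights (uniform w_i = 1 / m when not provided)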
if weights is None:
w = ncp.full(m, 1. / m)
else:
w = ncp.asarray(weights)

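# initialize per-function model vectors yi from x0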
if isinstance(x0, list) or x0.ndim == 2:
y = ncp.asarray(x0) # yi_0 = xi_0, for i = 1, ..., m
else:
y = ncp.full((m, x0.size), x0) # y1_0 = y2_0 = ... = ym_0 = x0

x = ncp.sum(w[:, None] * y, axis=0) # x_0 = sum_i wi * yi_0, as in the Notes recursion
x_old = x.copy()

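# PPXA iterations (see recursion in Notes)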
for iiter in range(niter):

p = ncp.stack([prox_ops[i].prox(y[i], tau / w[i]) for i in range(m)]) # pi_k = prox_{tau/wi fi}(yi_k)
pn = ncp.sum(w[:, None] * p, axis=0) # p_k = sum_i wi * pi_k
y = y + eta * (2 * pn - x - p) # yi_k+1 = yi_k + eta * (2 p_k - x_k - pi_k)
x = x + eta * (pn - x) # x_k+1 = x_k + eta * (p_k - x_k)

if callback is not None:
callback(x)

if show:
if iiter < 10 or niter - iiter < 10 or iiter % (niter // 10) == 0:
pf = ncp.sum([prox_ops[i](x) for i in range(m)])
print(
f"{iiter + 1:6d} {ncp.real(to_numpy(x[0])):12.5e} "
f"{pf:10.3e}"
)

if ncp.abs(x - x_old).max() < tol:
break

x_old = x

if show:
print(f"\nTotal time (s) = {time.time() - tstart:.2f}")
print("---------------------------------------------------------\n")

return x


def ConsensusADMM( # pylint: disable=invalid-name
prox_ops: List[ProxOperator],
x0: NDArray,
tau: float,
niter: int = 1000,
tol: Optional[float] = 1e-7,
callback: Optional[Callable[..., None]] = None,
show: bool = False,
) -> NDArray:
r"""Consensus ADMM
Solves the following global consensus problem using ADMM:
.. math::
\argmin_{\mathbf{x_1}, \mathbf{x_2}, \ldots, \mathbf{x_m}}
\sum_{i=1}^m f_i(\mathbf{x}_i) \quad \text{s.t.}
\quad \mathbf{x_1} = \mathbf{x_2} = \cdots = \mathbf{x_m}
where :math:`f_i(\mathbf{x})` are any convex
functions that has known proximal operators.
Parameters
----------
prox_ops : :obj:`list`
A list of proximable functions :math:`f_1, \ldots, f_m`.
x0 : :obj:`np.ndarray`
Initial vector
tau : :obj:`float`
Positive scalar weight
niter : :obj:`int`, optional
The maximum number of iterations.
tol : :obj:`float`, optional
Tolerance on change of the solution (used as stopping criterion).
If ``tol=0``, run until ``niter`` is reached.
callback : :obj:`callable`, optional
Function with signature (``callback(x)``) to call after each iteration
where ``x`` is the current model vector
show : :obj:`bool`, optional
Display iterations log
Returns
-------
x : :obj:`numpy.ndarray`
Inverted model
See Also
--------
ADMM: Alternating Direction Method of Multipliers
PPXA: Parallel Proximal Algorithm
Notes
-----
The ADMM for the consensus problem can be expressed by the following
recursion [1]_, [2]_:
* :math:`\bar{\mathbf{x}}^{(0)} = \mathbf{x}`
* for :math:`k = 1, \ldots`
* for :math:`i = 1, \ldots, m`
* :math:`\mathbf{x}_i^{(k+1)} = \mathrm{prox}_{\tau f_i} \left(\bar{\mathbf{x}}^{(k)} - \mathbf{y}_i^{(k)}\right)`
* :math:`\bar{\mathbf{x}}^{(k+1)} = \frac{1}{m} \sum_{i=1}^m \mathbf{x}_i^{(k+1)}`
* for :math:`i = 1, \ldots, m`
* :math:`\mathbf{y}_i^{(k+1)} = \mathbf{y}_i^{(k)} + \mathbf{x}_i^{(k+1)} - \bar{\mathbf{x}}^{(k+1)}`
The current implementation returns :math:`\bar{\mathbf{x}}`.
References
----------
.. [1] Boyd, S., Parikh, N., Chu, E., Peleato, B., Eckstein, J., 2011.
Distributed Optimization and Statistical Learning via the Alternating
Direction Method of Multipliers. Foundations and Trends in Machine Learning,
Vol. 3, No. 1, pp 1-122. Section 7.1. https://doi.org/10.1561/2200000016
https://stanford.edu/~boyd/papers/pdf/admm_distr_stats.pdf
.. [2] Parikh, N., Boyd, S., 2014. Proximal Algorithms. Foundations and
Trends in Optimization, Vol. 1, No. 3, pp 127-239.
Section 5.2.1. https://doi.org/10.1561/2400000003
https://web.stanford.edu/~boyd/papers/pdf/prox_algs.pdf
"""
if show:
tstart = time.time()
print(
"Consensus ADMM\n"
"---------------------------------------------------------"
)
for i, prox_op in enumerate(prox_ops):
print(f"Proximal operator (f{i}): {type(prox_op)}")
print(f"tau = {tau:10e}\tniter = {niter:d}\n")
head = " Itn x[0] J=sum_i f_i"
print(head)

ncp = get_array_module(x0)

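# initialize consensus variable x_bar and dual variables yi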
m = len(prox_ops)
x_bar = x0.copy()
x_bar_old = x0.copy()
y = ncp.zeros((m, x0.size), dtype=x0.dtype) # one dual variable yi per fi

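# ADMM iterations (see recursion in Notes)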
for iiter in range(niter):

x = ncp.stack([prox_ops[i].prox(x_bar - y[i], tau) for i in range(m)]) # xi_k+1 = prox_{tau fi}(xbar_k - yi_k)
x_bar = ncp.mean(x, axis=0) # consensus average xbar_k+1
y = y + x - x_bar # dual update yi_k+1 = yi_k + xi_k+1 - xbar_k+1

if callback is not None:
callback(x_bar)

if show:
if iiter < 10 or niter - iiter < 10 or iiter % (niter // 10) == 0:
pf = ncp.sum([prox_ops[i](x_bar) for i in range(m)])
print(
f"{iiter + 1:6d} {ncp.real(to_numpy(x_bar[0])):12.5e} "
f"{pf:10.3e}"
)

if ncp.abs(x_bar - x_bar_old).max() < tol:
break

x_bar_old = x_bar

if show:
print(f"\nTotal time (s) = {time.time() - tstart:.2f}")
print("---------------------------------------------------------\n")

return x_bar
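
A minimal usage sketch of the two new solvers, assuming the signatures in this diff and pyproximal's existing L1 and L2 proximal operators (the toy split below is illustrative only, not part of the PR):

import numpy as np
from pyproximal import L1, L2
from pyproximal.optimization.primal import PPXA, ConsensusADMM

# toy denoising problem: argmin_x 0.5 * ||x - b||_2^2 + 0.5 * ||x||_1,
# split as f1 = L2 misfit and f2 = L1 penalty
rng = np.random.default_rng(0)
b = rng.standard_normal(50)
prox_ops = [L2(b=b), L1(sigma=0.5)]
x0 = np.zeros_like(b)

x_ppxa = PPXA(prox_ops, x0, tau=1.0, niter=200) # PPXA solution
x_cadmm = ConsensusADMM(prox_ops, x0, tau=1.0, niter=200) # consensus ADMM solution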