-
Notifications
You must be signed in to change notification settings - Fork 18
feat: add PPXA and ConsensusADMM algorithms #237
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: dev
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -1793,3 +1793,295 @@ def DouglasRachfordSplitting( | |
| print(f"\nTotal time (s) = {time.time() - tstart:.2f}") | ||
| print("---------------------------------------------------------\n") | ||
| return x, y | ||
|
|
||
|
|
||
| def PPXA( # pylint: disable=invalid-name | ||
| prox_ops: List[ProxOperator], | ||
| x0: NDArray | List[NDArray], | ||
| tau: float, | ||
| eta: float = 1.0, | ||
| weights: NDArray | List[float] | None = None, | ||
| niter: int = 1000, | ||
| tol: Optional[float] = 1e-7, | ||
| callback: Optional[Callable[..., None]] = None, | ||
| show: bool = False, | ||
| ) -> NDArray: | ||
| r"""Parallel Proximal Algorithm (PPXA) | ||
| Solves the following minimization problem using | ||
| Parallel Proximal Algorithm (PPXA): | ||
| .. math:: | ||
| \mathbf{x} = \argmin_\mathbf{x} \sum_{i=1}^m f_i(\mathbf{x}) | ||
| where :math:`f_i(\mathbf{x})` are any convex | ||
| functions that has known proximal operators. | ||
| Parameters | ||
| ---------- | ||
| prox_ops : :obj:`list` | ||
| A list of proximable functions :math:`f_1, \ldots, f_m`. | ||
| x0 : :obj:`np.ndarray` or :obj:`list` | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
| Initial vector :math:`\mathbf{x}` of all :math:`f_i` if 1D :obj:`np.ndarray`, | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Initial vector :math: ** write the size as otherwise it is not obvious where is the dimension of the model vs where is m |
||
| or :math:`\mathbf{x}_{i}` of each :math:`f_i` for :math:`i=1,\ldots,m` | ||
| if :obj:`list` or 2D :obj:`np.ndarray`. | ||
| tau : :obj:`float` | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is there any condition on tau that must be followed (usually there is an upper bound related to some Lipschitz constant?) |
||
| Positive scalar weight | ||
| eta : :obj:`float`, optional | ||
| Relaxation parameter (must be between 0 and 2, 0 excluded). | ||
| weights : :obj:`np.ndarray` or :obj:`list` or :obj:`None`, optional | ||
| Weights :math:`\sum_{i=1}^m w_i = 1, \ 0 < w_i < 1`, | ||
| Defaults to None, which means :math:`w_1 = \cdots = w_m = \frac{1}{m}.` | ||
| niter : :obj:`int`, optional | ||
| The maximum number of iterations. | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Try to use as much as possible the same description used for the same parameters in other solvers. Here: |
||
| tol : :obj:`float`, optional | ||
| Tolerance on change of the solution (used as stopping criterion). | ||
| If ``tol=0``, run until ``niter`` is reached. | ||
| callback : :obj:`callable`, optional | ||
| Function with signature (``callback(x)``) to call after each iteration | ||
| where ``x`` is the current model vector | ||
| show : :obj:`bool`, optional | ||
| Display iterations log | ||
| Returns | ||
| ------- | ||
| x : :obj:`numpy.ndarray` | ||
| Inverted model | ||
| See Also | ||
| -------- | ||
| ConsensusADMM: Consensus ADMM | ||
| Notes | ||
| ----- | ||
| The Parallel Proximal Algorithm (PPXA) can be expressed by the following | ||
| recursion [1]_, [2]_, [3]_, [4]_: | ||
| * :math:`\mathbf{y}_{i}^{(0)} = \mathbf{x}` or :math:`\mathbf{y}_{i}^{(0)} = \mathbf{x}_{i}` for :math:`i=1,\ldots,m` | ||
| * :math:`\mathbf{x}^{(0)} = \sum_{i=1}^m w_i \mathbf{y}_{i}^{(0)}` | ||
| * for :math:`k = 1, \ldots` | ||
| * for :math:`i = 1, \ldots, m` | ||
| * :math:`\mathbf{p}_{i}^{(k)} = \prox_{\frac{\tau}{w_i} f_i} (\mathbf{y}_{i}^{(k)})` | ||
| * :math:`\mathbf{p}^{(k)} = \sum_{i=1}^{m} w_i \mathbf{p}_{i}^{(k)}` | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. take away |
||
| * for :math:`i = 1, \ldots, m` | ||
| * :math:`\mathbf{y}_{i}^{(k+1)} = \mathbf{y}_{i}^{(k)} + \eta (2 \mathbf{p}^{(k)} - \mathbf{x}^{(k)} - \mathbf{p}_i^{(k)})` | ||
| * :math:`\mathbf{x}^{(k+1)} = \mathbf{x}^{(k)} + \eta (\mathbf{p}^{(k)} - \mathbf{x}^{(k)})` | ||
| where :math:`0 < \eta < 2` and | ||
| :math:`\sum_{i=1}^m w_i = 1, \ 0 < w_i < 1`. | ||
| In the current implementation, :math:`w_i = 1 / m` when not provided. | ||
| References | ||
| ---------- | ||
| .. [1] Combettes, P.L., Pesquet, J.-C., 2008. A proximal decomposition | ||
| method for solving convex variational inverse problems. Inverse Problems | ||
| 24, 065014. Algorithm 3.1. https://doi.org/10.1088/0266-5611/24/6/065014 | ||
| https://arxiv.org/abs/0807.2617 | ||
| .. [2] Combettes, P.L., Pesquet, J.-C., 2011. Proximal Splitting Methods in | ||
| Signal Processing, in Fixed-Point Algorithms for Inverse Problems in | ||
| Science and Engineering, Springer, pp. 185-212. Algorithm 10.27. | ||
| https://doi.org/10.1007/978-1-4419-9569-8_10 | ||
| .. [3] Bauschke, H.H., Combettes, P.L., 2011. Convex Analysis and Monotone | ||
| Operator Theory in Hilbert Spaces, 1st ed, CMS Books in Mathematics. | ||
| Springer, New York, NY. Proposition 27.8. | ||
| https://doi.org/10.1007/978-1-4419-9467-7 | ||
| .. [4] Ryu, E.K., Yin, W., 2022. Large-Scale Convex Optimization: Algorithms | ||
| & Analyses via Monotone Operators. Cambridge University Press, | ||
| Cambridge. Exercise 2.38 https://doi.org/10.1017/9781009160865 | ||
| https://large-scale-book.mathopt.com/ | ||
| """ | ||
| if show: | ||
| tstart = time.time() | ||
| print( | ||
| "Parallel Proximal Algorithm\n" | ||
| "---------------------------------------------------------" | ||
| ) | ||
| for i, prox_op in enumerate(prox_ops): | ||
| print(f"Proximal operator (f{i}): {type(prox_op)}") | ||
| print(f"tau = {tau:10e}\tniter = {niter:d}\n") | ||
| head = " Itn x[0] J=sum_i f_i" | ||
| print(head) | ||
|
|
||
| ncp = get_array_module(x0) | ||
|
|
||
| m = len(prox_ops) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Add some comments... I guess this all the way to the for loop is |
||
| if weights is None: | ||
| w = ncp.full(m, 1. / m) | ||
| else: | ||
| w = ncp.asarray(weights) | ||
|
|
||
| if isinstance(x0, list) or x0.ndim == 2: | ||
| y = ncp.asarray(x0) # yi_0 = xi_0, for i = 1, ..., m | ||
| else: | ||
| y = ncp.full((m, x0.size), x0) # y1_0 = y2_0 = ... = ym_0 = x0 | ||
|
|
||
| x = ncp.mean(y, axis=0) | ||
| x_old = x.copy() | ||
|
|
||
| for iiter in range(niter): | ||
|
|
||
| p = ncp.stack([prox_ops[i].prox(y[i], tau / w[i]) for i in range(m)]) | ||
| pn = ncp.sum(w[:, None] * p, axis=0) | ||
| y = y + eta * (2 * pn - x - p) | ||
| x = x + eta * (pn - x) | ||
|
|
||
| if callback is not None: | ||
| callback(x) | ||
|
|
||
| if show: | ||
| if iiter < 10 or niter - iiter < 10 or iiter % (niter // 10) == 0: | ||
| pf = ncp.sum([prox_ops[i](x) for i in range(m)]) | ||
| print( | ||
| f"{iiter + 1:6d} {ncp.real(to_numpy(x[0])):12.5e} " | ||
| f"{pf:10.3e}" | ||
| ) | ||
|
|
||
| if ncp.abs(x - x_old).max() < tol: | ||
| break | ||
|
|
||
| x_old = x | ||
|
|
||
| if show: | ||
| print(f"\nTotal time (s) = {time.time() - tstart:.2f}") | ||
| print("---------------------------------------------------------\n") | ||
|
|
||
| return x | ||
|
|
||
|
|
||
| def ConsensusADMM( # pylint: disable=invalid-name | ||
| prox_ops: List[ProxOperator], | ||
| x0: NDArray, | ||
| tau: float, | ||
| niter: int = 1000, | ||
| tol: Optional[float] = 1e-7, | ||
| callback: Optional[Callable[..., None]] = None, | ||
| show: bool = False, | ||
| ) -> NDArray: | ||
| r"""Consensus ADMM | ||
| Solves the following global consensus problem using ADMM: | ||
| .. math:: | ||
| \argmin_{\mathbf{x_1}, \mathbf{x_2}, \ldots, \mathbf{x_m}} | ||
| \sum_{i=1}^m f_i(\mathbf{x}_i) \quad \text{s.t.} | ||
| \quad \mathbf{x_1} = \mathbf{x_2} = \cdots = \mathbf{x_m} | ||
| where :math:`f_i(\mathbf{x})` are any convex | ||
| functions that has known proximal operators. | ||
| Parameters | ||
| ---------- | ||
| prox_ops : :obj:`list` | ||
| A list of proximable functions :math:`f_1, \ldots, f_m`. | ||
| x0 : :obj:`np.ndarray` | ||
| Initial vector | ||
| tau : :obj:`float` | ||
| Positive scalar weight | ||
| niter : :obj:`int`, optional | ||
| The maximum number of iterations. | ||
| tol : :obj:`float`, optional | ||
| Tolerance on change of the solution (used as stopping criterion). | ||
| If ``tol=0``, run until ``niter`` is reached. | ||
| callback : :obj:`callable`, optional | ||
| Function with signature (``callback(x)``) to call after each iteration | ||
| where ``x`` is the current model vector | ||
| show : :obj:`bool`, optional | ||
| Display iterations log | ||
| Returns | ||
| ------- | ||
| x : :obj:`numpy.ndarray` | ||
| Inverted model | ||
| See Also | ||
| -------- | ||
| ADMM: Alternating Direction Method of Multipliers | ||
| PPXA: Parallel Proximal Algorithm | ||
| Notes | ||
| ----- | ||
| The ADMM for the consensus problem can be expressed by the following | ||
| recursion [1]_, [2]_: | ||
| * :math:`\bar{\mathbf{x}}^{(0)} = \mathbf{x}` | ||
| * for :math:`k = 1, \ldots` | ||
| * for :math:`i = 1, \ldots, m` | ||
| * :math:`\mathbf{x}_i^{(k+1)} = \mathrm{prox}_{\tau f_i} \left(\bar{\mathbf{x}}^{(k)} - \mathbf{y}_i^{(k)}\right)` | ||
| * :math:`\bar{\mathbf{x}}^{(k+1)} = \frac{1}{m} \sum_{i=1}^m \mathbf{x}_i^{(k)}` | ||
| * for :math:`i = 1, \ldots, m` | ||
| * :math:`\mathbf{y}_i^{(k+1)} = \mathbf{y}_i^{(k)} + \mathbf{x}_i^{(k+1)} - \bar{\mathbf{x}}^{(k+1)}` | ||
| The current implementation returns :math:`\bar{\mathbf{x}}`. | ||
| References | ||
| ---------- | ||
| .. [1] Boyd, S., Parikh, N., Chu, E., Peleato, B., Eckstein, J., 2011. | ||
| Distributed Optimization and Statistical Learning via the Alternating | ||
| Direction Method of Multipliers. Foundations and Trends in Machine Learning, | ||
| Vol. 3, No. 1, pp 1-122. Section 7.1. https://doi.org/10.1561/2200000016 | ||
| https://stanford.edu/~boyd/papers/pdf/admm_distr_stats.pdf | ||
| .. [2] Parikh, N., Boyd, S., 2014. Proximal Algorithms. Foundations and | ||
| Trends in Optimization, Vol. 1, No. 3, pp 127-239. | ||
| Section 5.2.1. https://doi.org/10.1561/2400000003 | ||
| https://web.stanford.edu/~boyd/papers/pdf/prox_algs.pdf | ||
| """ | ||
| if show: | ||
| tstart = time.time() | ||
| print( | ||
| "Consensus ADMM\n" | ||
| "---------------------------------------------------------" | ||
| ) | ||
| for i, prox_op in enumerate(prox_ops): | ||
| print(f"Proximal operator (f{i}): {type(prox_op)}") | ||
| print(f"tau = {tau:10e}\tniter = {niter:d}\n") | ||
| head = " Itn x[0] J=sum_i f_i" | ||
| print(head) | ||
|
|
||
| ncp = get_array_module(x0) | ||
|
|
||
| m = len(prox_ops) | ||
| x_bar = x0.copy() | ||
| x_bar_old = x0.copy() | ||
| y = ncp.zeros_like(x0) | ||
|
|
||
| for iiter in range(niter): | ||
|
|
||
| x = ncp.stack([prox_ops[i].prox(x_bar - y[i], tau) for i in range(m)]) | ||
| x_bar = ncp.mean(x, axis=0) | ||
| y = y + x - x_bar | ||
|
|
||
| if callback is not None: | ||
| callback(x_bar) | ||
|
|
||
| if show: | ||
| if iiter < 10 or niter - iiter < 10 or iiter % (niter // 10) == 0: | ||
| pf = ncp.sum([prox_ops[i](x_bar) for i in range(m)]) | ||
| print( | ||
| f"{iiter + 1:6d} {ncp.real(to_numpy(x_bar[0])):12.5e} " | ||
| f"{pf:10.3e}" | ||
| ) | ||
|
|
||
| if ncp.abs(x_bar - x_bar_old).max() < tol: | ||
| break | ||
|
|
||
| x_bar_old = x_bar | ||
|
|
||
| if show: | ||
| print(f"\nTotal time (s) = {time.time() - tstart:.2f}") | ||
| print("---------------------------------------------------------\n") | ||
|
|
||
| return x_bar | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
To be consistent with GPG I suggest to call this
proxfs