Skip to content

Commit 7496bdb

Browse files
committed
Add Dykstra's algorithms.
- DykstrasProjection - (Parallel) Dykstra's projection algorithm computing convex projection to the intersection of convex sets - DykstrasProjectionProx - The corresponding indicator function (prox) - DykstraLikeProximal - (Parallel) Dykstra-like proximal algorithm computing prox of the sum of convex functions Add tests in a separate test file. - test_dykstra.py
1 parent 54ac074 commit 7496bdb

File tree

7 files changed

+915
-2
lines changed

7 files changed

+915
-2
lines changed

docs/source/api/index.rst

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ Orthogonal projections
2929
L1BallProj
3030
NuclearBallProj
3131
SimplexProj
32+
DykstrasProjection
3233

3334
Proximal operators
3435
------------------
@@ -83,6 +84,8 @@ Convex
8384
Quadratic
8485
Simplex
8586
TV
87+
DykstrasProjectionProx
88+
DykstraLikeProximal
8689

8790

8891
Non-Convex
Lines changed: 241 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,241 @@
1+
from typing import List, Callable
2+
import numpy as np
3+
from pylops.utils.typing import NDArray
4+
5+
6+
class DykstrasProjection():
7+
r"""The convex projection to the intersection of convex sets
8+
using Dykstra's algorithm.
9+
10+
11+
Parameters
12+
----------
13+
projections : :obj:`List[Callable[[np.ndarray], np.ndarray]]`
14+
A list of projection functions :math:`P_1, \ldots, P_m`.
15+
max_iter : :obj:`int`, optional, default=100
16+
The maximum number of iterations.
17+
tol : :obj:`float`, optional, default=1e-6
18+
Torrelance to stop the iteration.
19+
use_parallel : :obj:`bool`, optional, default=False
20+
If True, use the parallel version when $m=2$.
21+
22+
23+
Notes
24+
-----
25+
Given a set of convex projections :math:`P_i` for :math:`i=1, \ldots, m`,
26+
each mapping :math:`x` to its projection :math:`P_i(x)` onto a convex set
27+
:math:`C_i`, this class computes the convex projection :math:`P_C(x)`
28+
of :math:`x` using Dykstra's algorithm, where
29+
30+
.. math:: C = \cap_{i=1}^m C_i
31+
32+
is the intersection of :math:`C_i` provided :math:`C \neq \emptyset`.
33+
34+
35+
For :math:`m=2`, the projection :math:`P_C(x)` of :math:`x` is computed
36+
by the Dykstra's algorithm [1]_, [2]_, [3]_:
37+
38+
* :math:`x_0 = x, p_0 = q_0 = 0`,
39+
* for :math:`k = 1, 2, \ldots`
40+
41+
* :math:`y_k = P_1(x_k + p_k)`
42+
* :math:`p_{k+1} = x_k + p_k - y_k`
43+
* :math:`x_{k+1} = P_2(y_k + q_k)`
44+
* :math:`q_{k+1} = y_k + q_k - x_{k+1}`
45+
46+
47+
For :math:`m \ge 2`, the projection :math:`P_C(x)` is computed
48+
by the parallel Dykstra's algorithm [5]_, [6]_. The following
49+
is taken from [4]_:
50+
51+
* :math:`u_m^{(0)} = x, z_1^{(0)} = \cdots = z_m^{(0)} = 0`,
52+
* for :math:`k = 1, 2, \ldots`
53+
54+
* for :math:`i = 1, \ldots, m`
55+
56+
* :math:`u_0^{(k)} = u_m^{(k-1)}`
57+
* :math:`u_i^{(k)} = P_i(u_{i-1}^{(k)} + z_i^{(k-1)})`
58+
* :math:`z_i^{(k)} = z_i^{(k-1)} + u_{i-1}^{(k)} - u_i^{(k)}`
59+
60+
Note the this is the proximal operator of the corresponding
61+
indicator function
62+
(see :class:`pyproximal.DykstrasProjectionProx` for details).
63+
64+
65+
Examples
66+
--------
67+
>>> import numpy as np
68+
>>> from pyproximal.projection import (
69+
... BoxProj,
70+
... EuclideanBallProj,
71+
... DykstrasProjection
72+
... )
73+
74+
>>> circle_1 = EuclideanBallProj(np.array([-2.5, 0.0]), 5)
75+
>>> circle_2 = EuclideanBallProj(np.array([2.5, 0.0]), 5)
76+
>>> circle_3 = EuclideanBallProj(np.array([0.0, 3.5]), 5)
77+
>>> box = BoxProj(np.array([-5.0, -2.5]), np.array([5.0, 2.5]))
78+
79+
>>> projections = [circle_1, circle_2, circle_3, box]
80+
>>> dykstra_proj = DykstrasProjection(projections)
81+
82+
>>> rng = np.random.default_rng(10)
83+
>>> x = rng.normal(0., 3.5, size=2)
84+
85+
>>> print("x =", x)
86+
x = [-3.86168457 -2.53758624]
87+
88+
>>> xp = dykstra_proj(x)
89+
>>> print("x projection =", xp)
90+
x projection = [-2.42308423 -0.87363268]
91+
92+
93+
References
94+
----------
95+
.. [1] Bauschke, H.H., Borwein, J.M., 1994. Dykstra's Alternating
96+
Projection Algorithm for Two Sets. Journal of Approximation Theory 79,
97+
418-443. https://doi.org/10.1006/jath.1994.1136
98+
https://cmps-people.ok.ubc.ca/bauschke/Research/02.pdf
99+
.. [2] Bauschke, H.H., Burachik, R.S., Herman, D.B., Kaya, C.Y., 2020. On
100+
Dykstra's algorithm: finite convergence, stalling, and the method of
101+
alternating projections. Optim Lett 14, 1975-1987.
102+
https://doi.org/10.1007/s11590-020-01600-4
103+
https://arxiv.org/abs/2001.06747
104+
.. [3] Wikipedia, Dykstra's projection algorithm.
105+
https://en.wikipedia.org/wiki/Dykstra%27s_projection_algorithm
106+
107+
.. [4] Tibshirani, R.J., 2017. Dykstra's Algorithm, ADMM, and Coordinate
108+
Descent: Connections, Insights, and Extensions, NeurIPS2017.
109+
https://proceedings.neurips.cc/paper_files/paper/2017/hash/5ef698cd9fe650923ea331c15af3b160-Abstract.html
110+
.. [5] Bauschke, H.H., Combettes, P.L., 2011. Convex Analysis and Monotone
111+
Operator Theory in Hilbert Spaces, Theorem 29.2, 1st ed, Springer.
112+
https://doi.org/10.1007/978-1-4419-9467-7
113+
.. [6] Bauschke, H.H., Lewis, A.S., 2000. Dykstras algorithm with bregman
114+
projections: A convergence proof. Optimization 48, 409-427.
115+
https://doi.org/10.1080/02331930008844513
116+
https://people.orie.cornell.edu/aslewis/publications/00-dykstras.pdf
117+
118+
119+
See also
120+
--------
121+
pyproximal.DykstrasProjectionProx :
122+
The corresponding indicator function.
123+
pyproximal.DykstraLikeProximal :
124+
Proximal operator of a sum of two or more convex functions
125+
using Dykstra-like algorithm.
126+
"""
127+
128+
def __init__(
129+
self,
130+
projections: List[Callable[[NDArray], NDArray]],
131+
max_iter: int = 100,
132+
tol: float = 1e-6,
133+
use_parallel: bool = False,
134+
) -> None:
135+
self.projections = projections
136+
self.max_iter = max_iter
137+
self.tol = tol
138+
self.use_parallel = use_parallel
139+
140+
if len(projections) == 1:
141+
self._projection = self._single_projection
142+
elif len(projections) == 2 and not use_parallel:
143+
self._projection = self._dykstra_projection
144+
else:
145+
self._projection = self._parallel_dykstra_projection
146+
147+
def __call__(self, x: NDArray) -> NDArray:
148+
r"""compute projection :math:`P_C(x)` of :math:`x`.
149+
150+
Parameters
151+
----------
152+
x : :obj:`numpy.ndarray`
153+
A point
154+
155+
Returns
156+
-------
157+
:obj:`numpy.ndarray`
158+
projection of x
159+
160+
"""
161+
return self._projection(x)
162+
163+
def _single_projection(self, x0: NDArray) -> NDArray:
164+
r"""Compute projection :math:`P_C(x)` for :math:`m=1`.
165+
166+
Parameters
167+
----------
168+
x : :obj:`numpy.ndarray`
169+
A point
170+
171+
Returns
172+
-------
173+
:obj:`numpy.ndarray`
174+
projection of x
175+
176+
"""
177+
return self.projections[0](x0)
178+
179+
def _dykstra_projection(self, x0: NDArray) -> NDArray:
180+
r"""Compute projection :math:`P_C(x)` for :math:`m=2`.
181+
182+
Parameters
183+
----------
184+
x : :obj:`numpy.ndarray`
185+
A point
186+
187+
Returns
188+
-------
189+
:obj:`numpy.ndarray`
190+
projection of x
191+
192+
"""
193+
x = x0.copy()
194+
p = np.zeros_like(x)
195+
q = np.zeros_like(x)
196+
197+
for _ in range(self.max_iter):
198+
x_old = x.copy()
199+
200+
y = self.projections[0](x + p)
201+
p = x + p - y
202+
x = self.projections[1](y + q)
203+
q = y + q - x
204+
205+
if max(np.abs(x - x_old).max(),
206+
np.abs(y - x_old).max()) < self.tol:
207+
break
208+
return x
209+
210+
def _parallel_dykstra_projection(self, x0: NDArray) -> NDArray:
211+
r"""Compute projection :math:`P_C(x)` for :math:`m \ge 2`.
212+
213+
Parameters
214+
----------
215+
x : :obj:`numpy.ndarray`
216+
A point
217+
218+
Returns
219+
-------
220+
:obj:`numpy.ndarray`
221+
projection of x
222+
223+
"""
224+
u = x0.copy()
225+
m = len(self.projections)
226+
z = [np.zeros_like(u) for _ in range(m)]
227+
228+
for _ in range(self.max_iter):
229+
u_old = u.copy()
230+
u_prev = np.array([u.copy() for _ in range(m)])
231+
232+
for i in range(m):
233+
u = self.projections[i](u_prev[i - 1] + z[i])
234+
z[i] = z[i] + u_prev[i - 1] - u
235+
u_prev[i] = u
236+
237+
if max(np.abs(u_old - u).max(),
238+
np.abs(u_prev - u).max()) < self.tol:
239+
break
240+
241+
return u

pyproximal/projection/__init__.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
AffineSetProj Projection onto an Affine set
1818
HankelProj Projection onto the set of Hankel matrices
1919
HalfSpaceProj Projection onto a Half Space
20-
20+
DykstrasProjection Projection onto a union of given sets
2121
"""
2222

2323
from .Box import *
@@ -30,6 +30,7 @@
3030
from .AffineSet import *
3131
from .Hankel import *
3232
from .HalfSpace import *
33+
from .DykstrasProjection import *
3334

3435
__all__ = [
3536
"BoxProj",
@@ -45,4 +46,5 @@
4546
"AffineSetProj",
4647
"HankelProj",
4748
"HalfSpaceProj",
49+
"DykstrasProjection",
4850
]

0 commit comments

Comments
 (0)