Skip to content

Commit 1437477

Browse files
cxlclpre-commit-ci[bot]KumoLiuericspod
authored andcommitted
Stein's Unbiased Risk Estimator (SURE) loss and Conjugate Gradient (Project-MONAI#7308)
Based on the discussion topic [here](Project-MONAI#7161 (comment)), we implemented the Conjugate-Gradient algorithm for linear operator inversion, and Stein's Unbiased Risk Estimator (SURE) [1] loss for ground-truth-date free diffusion process guidance that is proposed in [2] and illustrated in the algorithm below: <img width="650" alt="Screenshot 2023-12-10 at 10 19 25 PM" src="https://github.com/Project-MONAI/MONAI/assets/8581162/97069466-cbaf-44e0-b7a7-ae9deb8fd7f2"> The Conjugate-Gradient (CG) algorithm is used to solve for the inversion of the linear operator in Line-4 in the algorithm above, where the linear operator is too large to store explicitly as a matrix (such as FFT/IFFT of an image) and invert directly. Instead, we can solve for the linear inversion iteratively as in CG. The SURE loss is applied for Line-6 above. This is a differentiable loss function that can be used to train/giude an operator (e.g. neural network), where the pseudo ground truth is available but the reference ground truth is not. For example, in the MRI reconstruction, the pseudo ground truth is the zero-filled reconstruction and the reference ground truth is the fully sampled reconstruction. The reference ground truth is not available due to the lack of fully sampled. **Reference** [1] Stein, C.M.: Estimation of the mean of a multivariate normal distribution. Annals of Statistics 1981 [[paper link](https://projecteuclid.org/journals/annals-of-statistics/volume-9/issue-6/Estimation-of-the-Mean-of-a-Multivariate-Normal-Distribution/10.1214/aos/1176345632.full)] [2] B. Ozturkler et al. SMRD: SURE-based Robust MRI Reconstruction with Diffusion Models. MICCAI 2023 [[paper link](https://arxiv.org/pdf/2310.01799.pdf)] <!--- Put an `x` in all the boxes that apply, and remove the not applicable items --> - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [x] New tests added to cover the changes. - [x] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [x] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [x] In-line docstrings updated. - [x] Documentation updated, tested `make html` command in the `docs/` folder. --------- Signed-off-by: chaoliu <[email protected]> Signed-off-by: cxlcl <[email protected]> Signed-off-by: chaoliu <[email protected]> Signed-off-by: YunLiu <[email protected]> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: YunLiu <[email protected]> Co-authored-by: Eric Kerfoot <[email protected]> Signed-off-by: Nikolas Schmitz <[email protected]>
1 parent 5168a12 commit 1437477

File tree

2 files changed

+126
-0
lines changed

2 files changed

+126
-0
lines changed

tests/test_conjugate_gradient.py

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
# Copyright (c) MONAI Consortium
2+
# Licensed under the Apache License, Version 2.0 (the "License");
3+
# you may not use this file except in compliance with the License.
4+
# You may obtain a copy of the License at
5+
# http://www.apache.org/licenses/LICENSE-2.0
6+
# Unless required by applicable law or agreed to in writing, software
7+
# distributed under the License is distributed on an "AS IS" BASIS,
8+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9+
# See the License for the specific language governing permissions and
10+
# limitations under the License.
11+
12+
from __future__ import annotations
13+
14+
import unittest
15+
16+
import torch
17+
18+
from monai.networks.layers import ConjugateGradient
19+
20+
21+
class TestConjugateGradient(unittest.TestCase):
22+
def test_real_valued_inverse(self):
23+
"""Test ConjugateGradient with real-valued input: when the input is real
24+
value, the output should be the inverse of the matrix."""
25+
a_dim = 3
26+
a_mat = torch.tensor([[1, 2, 3], [2, 1, 2], [3, 2, 1]], dtype=torch.float)
27+
28+
def a_op(x):
29+
return a_mat @ x
30+
31+
cg_solver = ConjugateGradient(a_op, num_iter=100)
32+
# define the measurement
33+
y = torch.tensor([1, 2, 3], dtype=torch.float)
34+
# solve for x
35+
x = cg_solver(torch.zeros(a_dim), y)
36+
x_ref = torch.linalg.solve(a_mat, y)
37+
# assert torch.allclose(x, x_ref, atol=1e-6), 'CG solver failed to converge to reference solution'
38+
self.assertTrue(torch.allclose(x, x_ref, atol=1e-6))
39+
40+
def test_complex_valued_inverse(self):
41+
a_dim = 3
42+
a_mat = torch.tensor([[1, 2, 3], [2, 1, 2], [3, 2, 1]], dtype=torch.complex64)
43+
44+
def a_op(x):
45+
return a_mat @ x
46+
47+
cg_solver = ConjugateGradient(a_op, num_iter=100)
48+
y = torch.tensor([1, 2, 3], dtype=torch.complex64)
49+
x = cg_solver(torch.zeros(a_dim, dtype=torch.complex64), y)
50+
x_ref = torch.linalg.solve(a_mat, y)
51+
self.assertTrue(torch.allclose(x, x_ref, atol=1e-6))
52+
53+
54+
if __name__ == "__main__":
55+
unittest.main()

tests/test_sure_loss.py

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
# Copyright (c) MONAI Consortium
2+
# Licensed under the Apache License, Version 2.0 (the "License");
3+
# you may not use this file except in compliance with the License.
4+
# You may obtain a copy of the License at
5+
# http://www.apache.org/licenses/LICENSE-2.0
6+
# Unless required by applicable law or agreed to in writing, software
7+
# distributed under the License is distributed on an "AS IS" BASIS,
8+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9+
# See the License for the specific language governing permissions and
10+
# limitations under the License.
11+
12+
from __future__ import annotations
13+
14+
import unittest
15+
16+
import torch
17+
18+
from monai.losses import SURELoss
19+
20+
21+
class TestSURELoss(unittest.TestCase):
22+
def test_real_value(self):
23+
"""Test SURELoss with real-valued input: when the input is real value, the loss should be 0.0."""
24+
sure_loss_real = SURELoss(perturb_noise=torch.zeros(2, 1, 128, 128), eps=0.1)
25+
26+
def operator(x):
27+
return x
28+
29+
y_pseudo_gt = torch.randn(2, 1, 128, 128)
30+
x = torch.randn(2, 1, 128, 128)
31+
loss = sure_loss_real(operator, x, y_pseudo_gt, complex_input=False)
32+
self.assertAlmostEqual(loss.item(), 0.0)
33+
34+
def test_complex_value(self):
35+
"""Test SURELoss with complex-valued input: when the input is complex value, the loss should be 0.0."""
36+
37+
def operator(x):
38+
return x
39+
40+
sure_loss_complex = SURELoss(perturb_noise=torch.zeros(2, 2, 128, 128), eps=0.1)
41+
y_pseudo_gt = torch.randn(2, 2, 128, 128)
42+
x = torch.randn(2, 2, 128, 128)
43+
loss = sure_loss_complex(operator, x, y_pseudo_gt, complex_input=True)
44+
self.assertAlmostEqual(loss.item(), 0.0)
45+
46+
def test_complex_general_input(self):
47+
"""Test SURELoss with complex-valued input: when the input is general complex value, the loss should be 0.0."""
48+
49+
def operator(x):
50+
return x
51+
52+
perturb_noise_real = torch.randn(2, 1, 128, 128)
53+
perturb_noise_complex = torch.zeros(2, 2, 128, 128)
54+
perturb_noise_complex[:, 0, :, :] = perturb_noise_real.squeeze()
55+
y_pseudo_gt_real = torch.randn(2, 1, 128, 128)
56+
y_pseudo_gt_complex = torch.zeros(2, 2, 128, 128)
57+
y_pseudo_gt_complex[:, 0, :, :] = y_pseudo_gt_real.squeeze()
58+
x_real = torch.randn(2, 1, 128, 128)
59+
x_complex = torch.zeros(2, 2, 128, 128)
60+
x_complex[:, 0, :, :] = x_real.squeeze()
61+
62+
sure_loss_real = SURELoss(perturb_noise=perturb_noise_real, eps=0.1)
63+
sure_loss_complex = SURELoss(perturb_noise=perturb_noise_complex, eps=0.1)
64+
65+
loss_real = sure_loss_real(operator, x_real, y_pseudo_gt_real, complex_input=False)
66+
loss_complex = sure_loss_complex(operator, x_complex, y_pseudo_gt_complex, complex_input=True)
67+
self.assertAlmostEqual(loss_real.item(), loss_complex.abs().item(), places=6)
68+
69+
70+
if __name__ == "__main__":
71+
unittest.main()

0 commit comments

Comments
 (0)