Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions .scripts/download_zenodo.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,15 @@ def calculate_md5(filename):

def download_zenodo_files(output_dir: Path):
"""
Download all files from Zenodo record 14938787 and verify their checksums.
Download all files from Zenodo record 14979785 and verify their checksums.

Args:
output_dir: Directory where files should be downloaded
"""
try:
print("Fetching files from Zenodo record 14938787...")
print("Fetching files from Zenodo record 14979785...")
with urllib.request.urlopen(
"https://zenodo.org/api/records/14938787"
"https://zenodo.org/api/records/14979785"
) as response:
data = json.loads(response.read())

Expand Down
87 changes: 44 additions & 43 deletions httomolibgpu/prep/stripe.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,14 +201,23 @@ def remove_all_stripe(
Corrected 3D tomographic data as a CuPy or NumPy array.

"""
matindex = _create_matindex(data.shape[2], data.shape[0])
for m in range(data.shape[1]):
sino = data[:, m, :]
sino = _rs_dead(sino, snr, la_size, matindex)
sino = _rs_sort(sino, sm_size, dim)
sino = cp.nan_to_num(sino)
data[:, m, :] = sino
return data
streams = [cp.cuda.Stream() for _ in range(4)]
output = data.copy()
def process_slice(m, stream):
with stream:
output[:, m, :] = _rs_dead(output[:, m, :], snr, la_size)
output[:, m, :] = _rs_sort(output[:, m, :], sm_size, dim)
output[:, m, :] = cp.nan_to_num(output[:, m, :])

# Distribute slices across streams
for i in range(data.shape[1]):
stream = streams[i % 4]
process_slice(i, stream)

for stream in streams:
stream.synchronize()

return output


def _mpolyfit(x, y):
Expand Down Expand Up @@ -252,7 +261,7 @@ def _detect_stripe(listdata, snr):
return listmask


def _rs_large(sinogram, snr, size, matindex, drop_ratio=0.1, norm=True):
def _rs_large(sinogram, snr, size, drop_ratio=0.1, norm=True):
"""
Remove large stripes.
"""
Expand All @@ -264,35 +273,35 @@ def _rs_large(sinogram, snr, size, matindex, drop_ratio=0.1, norm=True):
list1 = cp.mean(sinosort[ndrop : nrow - ndrop], axis=0)
list2 = cp.mean(sinosmooth[ndrop : nrow - ndrop], axis=0)
listfact = list1 / list2

# Locate stripes
listmask = _detect_stripe(listfact, snr)
listmask = binary_dilation(listmask, iterations=1).astype(listmask.dtype)
matfact = cp.tile(listfact, (nrow, 1))

# Normalize
if norm is True:
sinogram = sinogram / matfact
sinogram1 = cp.transpose(sinogram)
matcombine = cp.asarray(cp.dstack((matindex, sinogram1)))

ids = cp.argsort(matcombine[:, :, 1], axis=1)
matsort = matcombine.copy()
matsort[:, :, 0] = cp.take_along_axis(matsort[:, :, 0], ids, axis=1)
matsort[:, :, 1] = cp.take_along_axis(matsort[:, :, 1], ids, axis=1)

matsort[:, :, 1] = cp.transpose(sinosmooth)
ids = cp.argsort(matsort[:, :, 0], axis=1)
matsortback = matsort.copy()
matsortback[:, :, 0] = cp.take_along_axis(matsortback[:, :, 0], ids, axis=1)
matsortback[:, :, 1] = cp.take_along_axis(matsortback[:, :, 1], ids, axis=1)

sino_corrected = cp.transpose(matsortback[:, :, 1])
if norm:
sinogram /= cp.tile(listfact, (nrow, 1))

sino_transposed = sinogram.T
ids_sort = cp.argsort(sino_transposed, axis=1)

# Apply sorting without explicit matindex
sino_sorted = cp.take_along_axis(sino_transposed, ids_sort, axis=1)

# Smoothen sorted sinogram
sino_sorted[:, :] = cp.transpose(sinosmooth)

# Restore original order
ids_restore = cp.argsort(ids_sort, axis=1)
sino_corrected = cp.take_along_axis(sino_sorted, ids_restore, axis=1).T

# Apply corrections only to affected columns
listxmiss = cp.where(listmask > 0.0)[0]
sinogram[:, listxmiss] = sino_corrected[:, listxmiss]

return sinogram


def _rs_dead(sinogram, snr, size, matindex, norm=True):
def _rs_dead(sinogram, snr, size, norm=True):
"""remove unresponsive and fluctuating stripes"""
sinogram = cp.copy(sinogram) # Make it mutable
(nrow, _) = sinogram.shape
Expand All @@ -316,14 +325,15 @@ def _rs_dead(sinogram, snr, size, matindex, norm=True):
if len(listxmiss) > 0:
ids = cp.searchsorted(listx, listxmiss)
weights = (listxmiss - listx[ids - 1]) / (listx[ids] - listx[ids - 1])
# direct interpolation without making an extra copy
sinogram[:, listxmiss] = sinogram[:, listx[ids - 1]] + weights * (
sinogram[:, listx[ids]] - sinogram[:, listx[ids - 1]]
)
left_vals = cp.take(sinogram, listx[ids - 1], axis=1)
right_vals = cp.take(sinogram, listx[ids], axis=1)
diff = right_vals - left_vals
diff *= weights
sinogram[:, listxmiss] = left_vals + diff

# Remove residual stripes
if norm is True:
sinogram = _rs_large(sinogram, snr, size, matindex)
sinogram = _rs_large(sinogram, snr, size)
return sinogram


Expand Down Expand Up @@ -416,12 +426,3 @@ def raven_filter(
data = data[pad_y : height - pad_y, :, pad_x : width - pad_x].real

return cp.require(data, requirements="C")


def _create_matindex(nrow, ncol):
    """
    Build the float32 index matrix used by the sorting technique.

    A ramp ``0.0, 1.0, ...`` (stopping before ``ncol``) is generated once
    and repeated ``nrow`` times along the first axis.
    """
    column_ramp = cp.arange(0.0, ncol, 1.0)
    return cp.tile(column_ramp, (nrow, 1)).astype(np.float32)
74 changes: 74 additions & 0 deletions remove_all_stripe.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
import cupy as cp
import numpy as np
import os
import time
from cupy.cuda import memory_hooks
from datetime import datetime
from math import isclose
from cupyx.profiler import time_range

from httomolibgpu.prep.stripe import remove_all_stripe

# --- Input data -------------------------------------------------------------
test_data_path = "/mnt/gpfs03/scratch/data/imaging/tomography/zenodo"
data_path = os.path.join(test_data_path, "synth_tomophantom1.npz")
data_file = np.load(data_path)
# The first two axes of the stored projections are swapped before the upload.
projdata = cp.asarray(cp.swapaxes(data_file["projdata"], 0, 1))
# Uploaded as well, although nothing further down in this script reads it.
angles = cp.asarray(data_file["angles"])

# Keyword set shared by the profiling, warm-up and timing invocations.
bench_kwargs = {"snr": 0.1, "la_size": 71, "sm_size": 31, "dim": 1}

# --- Profiler range ---------------------------------------------------------
# This is the very first call of the script, executed inside a named
# cupyx.profiler range.
with time_range("all_stripe", color_id=0):
    remove_all_stripe(cp.copy(projdata), **bench_kwargs)

# --- Untimed call -----------------------------------------------------------
# Originally labelled "cold run"; it is the second invocation and runs
# before the measured loop starts.
remove_all_stripe(cp.copy(projdata), **bench_kwargs)

# --- Timed loop -------------------------------------------------------------
dev = cp.cuda.Device()
dev.synchronize()  # finish outstanding GPU work before the clock starts
start = time.perf_counter_ns()
for _ in range(10):
    remove_all_stripe(cp.copy(projdata), **bench_kwargs)

dev.synchronize()  # wait for the queued work so it is included in the timing
# nanoseconds -> milliseconds, averaged over the 10 iterations
duration_ms = float(time.perf_counter_ns() - start) * 1e-6 / 10

print(duration_ms)

# --- Regression checks ------------------------------------------------------
# Norm of (input - output) compared against stored reference values for two
# snr settings.
for snr_value, expected_norm in ((0.1, 67917.71), (0.001, 70015.51)):
    output = remove_all_stripe(
        cp.copy(projdata), snr=snr_value, la_size=61, sm_size=21, dim=1
    )
    residual_calc = projdata - output
    norm_res = cp.linalg.norm(residual_calc.flatten())
    assert isclose(norm_res, expected_norm, abs_tol=10**-2)

# --- Memory line profile ----------------------------------------------------
hook = memory_hooks.LineProfileHook()
with hook:
    remove_all_stripe(cp.copy(projdata), **bench_kwargs)
hook.print_report()
14 changes: 14 additions & 0 deletions zenodo-tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,20 @@ def geant4_dataset1(geant4_dataset1_file):
)


@pytest.fixture(scope="session")
def synth_tomophantom1_file(test_data_path):
    """Load ``synth_tomophantom1.npz`` from ``test_data_path`` with ``np.load``."""
    return np.load(os.path.join(test_data_path, "synth_tomophantom1.npz"))


@pytest.fixture
def synth_tomophantom1_dataset(synth_tomophantom1_file):
    """
    Return a ``(projdata, angles)`` pair taken from the loaded archive.

    ``projdata`` has its first two axes swapped and is converted to a CuPy
    array; ``angles`` is returned exactly as stored in the archive.
    """
    projections = cp.asarray(
        cp.swapaxes(synth_tomophantom1_file["projdata"], 0, 1)
    )
    angles = synth_tomophantom1_file["angles"]
    return (projections, angles)


@pytest.fixture
def ensure_clean_memory():
gc.collect()
Expand Down
32 changes: 32 additions & 0 deletions zenodo-tests/test_prep/test_stripe.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,38 @@ def test_remove_all_stripe_i12_dataset4(
assert output.flags.c_contiguous


@pytest.mark.parametrize(
    "dataset_fixture, snr_val, la_size_val, sm_size_val, norm_res_expected",
    [
        ("synth_tomophantom1_dataset", 1.0, 61, 21, 53435.61),
        ("synth_tomophantom1_dataset", 0.1, 61, 21, 67917.71),
        ("synth_tomophantom1_dataset", 0.001, 61, 21, 70015.51),
    ],
    ids=["snr_1", "snr_2", "snr_3"],
)
def test_remove_all_stripe_synth_tomophantom1_dataset(
    request, dataset_fixture, snr_val, la_size_val, sm_size_val, norm_res_expected
):
    """
    Run ``remove_all_stripe`` on the synthetic phantom for several ``snr``
    values and compare the residual norm with a stored reference.
    """
    # The fixture is resolved by name so the dataset can be parametrized.
    dataset = request.getfixturevalue(dataset_fixture)
    force_clean_gpu_memory()
    projections = dataset[0]

    # Work on a copy so the fixture data stays available for the residual.
    corrected = remove_all_stripe(
        cp.copy(projections),
        snr=snr_val,
        la_size=la_size_val,
        sm_size=sm_size_val,
        dim=1,
    )

    difference = projections - corrected
    residual_norm = cp.linalg.norm(difference.flatten())

    assert isclose(residual_norm, norm_res_expected, abs_tol=10**-2)
    assert not np.isnan(corrected).any(), "Output contains NaN values"
    assert corrected.dtype == np.float32
    assert corrected.flags.c_contiguous


@pytest.mark.parametrize(
"dataset_fixture, nvalue_val, vvalue_val, norm_res_expected",
[
Expand Down
Loading