-
Notifications
You must be signed in to change notification settings - Fork 14
[WIP] Improve downsampling performance #1406
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Changes from all commits
0dfd7be
6ef0074
c94b527
a169ea4
e09484a
a03dd68
2c45246
e74fff6
9a17837
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,11 +1,12 @@ | ||
| import logging | ||
| import math | ||
| import math # | ||
| import warnings | ||
| from collections.abc import Callable | ||
| from enum import Enum | ||
| from itertools import product | ||
| from typing import TYPE_CHECKING, Union | ||
|
|
||
| import numba | ||
| import numpy as np | ||
| from scipy.ndimage import zoom | ||
|
|
||
|
|
@@ -32,10 +33,10 @@ class InterpolationModes(Enum): | |
|
|
||
def determine_downsample_buffer_shape(array_info: ArrayInfo) -> Vec3Int:
    # Buffer shape in the downsampling *target* magnification; the matching
    # read in the source magnification therefore covers up to 1024³ vx.
    # Larger shapes use a lot of RAM, especially for segmentation layers,
    # which apply the mode filter.
    # See https://scm.slack.com/archives/CMBMU5684/p1749771929954699 for more context.
    return Vec3Int.full(512).pairmin(array_info.shard_shape)
|
Comment on lines
+36
to
+39
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The |
||
|
|
||
|
|
||
| def determine_upsample_buffer_shape(array_info: ArrayInfo) -> Vec3Int: | ||
|
|
@@ -262,6 +263,32 @@ def _mode(x: np.ndarray) -> np.ndarray: | |
| return sort[tuple(index)] | ||
|
|
||
|
|
||
@numba.jit(nopython=True, nogil=True)
def fast_mode(input_array: np.ndarray) -> np.ndarray:
    """Compute the mode of each column of a 2D array.

    For every column (second axis) of ``input_array``, return the value that
    occurs most often among the rows (first axis). Ties are broken in favor
    of the value whose first occurrence comes earliest in the column.

    Assumes ``input_array.shape[0] >= 1`` (at least one row).
    """
    # Scratch buffers reused across columns: the distinct values seen in the
    # current column, and how often each one *recurs*. Counts are count-1
    # (first occurrences are not counted), which is uniform across values and
    # therefore preserves the argmax.
    values = np.zeros(input_array.shape[0], dtype=input_array.dtype)
    # int64 instead of uint8: a uint8 counter silently wraps once a value
    # occurs more than 256 times in a column, producing a wrong mode.
    counter = np.zeros(input_array.shape[0], dtype=np.int64)
    output_array = np.zeros(input_array.shape[1], dtype=input_array.dtype)
    for row_index in range(input_array.shape[1]):
        values[0] = input_array[0, row_index]
        counter[:] = 0
        value_offset = 1
        for col_index in range(1, input_array.shape[0]):
            value = input_array[col_index, row_index]
            found_value = False
            # Only compare against the `value_offset` entries that belong to
            # the current column. Scanning `range(col_index)` (as before)
            # could match stale entries left over from a previous column and
            # even return a value that does not occur in this column at all.
            for i in range(value_offset):
                if value == values[i]:
                    counter[i] = counter[i] + 1
                    found_value = True
                    break
            if not found_value:
                values[value_offset] = value
                value_offset += 1
        mode = values[np.argmax(counter)]
        output_array[row_index] = mode
    return output_array
|
Comment on lines
+266
to
+289
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The |
||
|
|
||
|
|
||
| def downsample_unpadded_data( | ||
| buffer: np.ndarray, target_mag: Mag, interpolation_mode: InterpolationModes | ||
| ) -> np.ndarray: | ||
|
|
@@ -288,7 +315,7 @@ def downsample_cube( | |
| cube_buffer: np.ndarray, factors: list[int], interpolation_mode: InterpolationModes | ||
| ) -> np.ndarray: | ||
| if interpolation_mode == InterpolationModes.MODE: | ||
| return non_linear_filter_3d(cube_buffer, factors, _mode) | ||
| return non_linear_filter_3d(cube_buffer, factors, fast_mode) | ||
| elif interpolation_mode == InterpolationModes.MEDIAN: | ||
| return non_linear_filter_3d(cube_buffer, factors, _median) | ||
| elif interpolation_mode == InterpolationModes.NEAREST: | ||
|
|
@@ -318,7 +345,7 @@ def downsample_cube_job( | |
| target_bbox_in_mag = target_view.bounding_box.in_mag(target_view.mag) | ||
| shape = (num_channels,) + target_bbox_in_mag.size.to_tuple() | ||
| shape_xyz = target_bbox_in_mag.size_xyz | ||
| file_buffer = np.zeros(shape, target_view.get_dtype()) | ||
| file_buffer = np.zeros(shape, target_view.get_dtype(), order="F") | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Specifying |
||
|
|
||
| tiles = product( | ||
| *( | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -1161,19 +1161,23 @@ def downsample_mag( | |
| # perform downsampling | ||
| with get_executor_for_args(None, executor) as executor: | ||
| if buffer_shape is None: | ||
| buffer_shape = determine_downsample_buffer_shape(prev_mag_view.info) | ||
| buffer_shape = determine_downsample_buffer_shape(target_view.info) | ||
| func = named_partial( | ||
| downsample_cube_job, | ||
| mag_factors=mag_factors, | ||
| interpolation_mode=parsed_interpolation_mode, | ||
| buffer_shape=buffer_shape, | ||
| ) | ||
|
|
||
| target_chunk_shape = Vec3Int([1024, 1024, 512]).pairmax( | ||
| target_view.info.shard_shape | ||
| ) | ||
| source_view.for_zipped_chunks( | ||
| # this view is restricted to the bounding box specified in the properties | ||
| func, | ||
| target_view=target_view, | ||
| executor=executor, | ||
| source_chunk_shape=target_chunk_shape * target_mag.to_np(), | ||
| target_chunk_shape=target_chunk_shape * target_mag.to_np(), | ||
|
Comment on lines
+1171
to
+1180
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. In |
||
| progress_desc=f"Downsampling layer {self.name} from Mag {from_mag} to Mag {target_mag}", | ||
| ) | ||
|
|
||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The previous implementation of `_try_open_zarr` attempted to open with `zarr3` first and then fell back to `zarr` (v2) if `zarr3` failed. The current change removes this fallback, exclusively trying `zarr3`. If the system is expected to handle older `zarr` (v2) datasets, this change could lead to a `TensorStoreError` for those files. Please confirm whether `zarr` (v2) support is intentionally being dropped or if a fallback mechanism is still required.