diff --git a/src/spatialdata/_core/query/_utils.py b/src/spatialdata/_core/query/_utils.py index cb2c57b18..c8fba0196 100644 --- a/src/spatialdata/_core/query/_utils.py +++ b/src/spatialdata/_core/query/_utils.py @@ -95,11 +95,11 @@ def get_bounding_box_corners( return output.squeeze().drop_vars("box") -@nb.jit(parallel=False, nopython=True) +@nb.njit(parallel=False, nopython=True) def _create_slices_and_translation( - min_values: nb.types.Array, - max_values: nb.types.Array, -) -> tuple[nb.types.Array, nb.types.Array]: + min_values: nb.types.Array[nb.float64, nb.float64], + max_values: nb.types.Array[nb.float64, nb.float64], +) -> tuple[nb.types.Array[nb.float64, nb.float64], nb.types.Array[nb.float64, nb.float64]]: n_boxes, n_dims = min_values.shape slices = np.empty((n_boxes, n_dims, 2), dtype=np.float64) # (n_boxes, n_dims, [min, max]) translation_vectors = np.empty((n_boxes, n_dims), dtype=np.float64) # (n_boxes, n_dims) diff --git a/src/spatialdata/_core/query/spatial_query.py b/src/spatialdata/_core/query/spatial_query.py index a8a8fc251..7fef1c4b7 100644 --- a/src/spatialdata/_core/query/spatial_query.py +++ b/src/spatialdata/_core/query/spatial_query.py @@ -575,12 +575,12 @@ def _( max_coordinate = _parse_list_into_array(max_coordinate) # for triggering validation - _ = BoundingBoxRequest( - target_coordinate_system=target_coordinate_system, - axes=axes, - min_coordinate=min_coordinate, - max_coordinate=max_coordinate, - ) + # _ = BoundingBoxRequest( + # target_coordinate_system=target_coordinate_system, + # axes=axes, + # min_coordinate=min_coordinate, + # max_coordinate=max_coordinate, + # ) intrinsic_bounding_box_corners, axes = _get_bounding_box_corners_in_intrinsic_coordinates( image, axes, min_coordinate, max_coordinate, target_coordinate_system diff --git a/src/spatialdata/dataloader/datasets.py b/src/spatialdata/dataloader/datasets.py index bbefdf34a..bb5dd4530 100644 --- a/src/spatialdata/dataloader/datasets.py +++ b/src/spatialdata/dataloader/datasets.py @@ -127,7 +127,8 @@ def __init__( from spatialdata import bounding_box_query from spatialdata._core.operations.rasterize import rasterize as rasterize_fn - self._validate(sdata, regions_to_images, regions_to_coordinate_systems, return_annotations, table_name) + self.sdata = sdata + self._validate(regions_to_images, regions_to_coordinate_systems, return_annotations, table_name) self._preprocess(tile_scale, tile_dim_in_units, rasterize, table_name) if rasterize_kwargs is not None and len(rasterize_kwargs) > 0 and rasterize is False: @@ -144,21 +145,19 @@ def __init__( **dict(rasterize_kwargs), ) if rasterize - else bounding_box_query + else partial(bounding_box_query, return_request_only=True) # type: ignore[assignment] ) self._return = self._get_return(return_annotations, table_name) self.transform = transform def _validate( self, - sdata: SpatialData, regions_to_images: dict[str, str], regions_to_coordinate_systems: dict[str, str], return_annotations: str | list[str] | None, table_name: str | None, ) -> None: """Validate input parameters.""" - self.sdata = sdata if return_annotations is not None and table_name is None: raise ValueError("`table_name` must be provided if `return_annotations` is not `None`.") @@ -173,8 +172,8 @@ def _validate( image_name = regions_to_images[region_name] # get elements - region_elem = sdata[region_name] - image_elem = sdata[image_name] + region_elem = self.sdata[region_name] + image_elem = self.sdata[image_name] # check that the elements are supported if get_model(region_elem) == PointsModel: @@ -199,13 +198,13 @@ def _validate( ) if table_name is not None: - _, region_key, instance_key = get_table_keys(sdata.tables[table_name]) + _, region_key, instance_key = get_table_keys(self.sdata.tables[table_name]) if get_model(region_elem) in [Labels2DModel, Labels3DModel]: indices = get_element_instances(region_elem).tolist() else: indices = region_elem.index.tolist() - table = sdata.tables[table_name] - if not isinstance(sdata.tables[table_name].obs[region_key].dtype, CategoricalDtype): + table = self.sdata.tables[table_name] + if not isinstance(self.sdata.tables[table_name].obs[region_key].dtype, CategoricalDtype): raise TypeError( f"The `regions_element` column `{region_key}` in the table must be a categorical dtype. " f"Please convert it." @@ -228,8 +227,10 @@ def _preprocess( table_name: str | None, ) -> None: """Preprocess the dataset.""" + from spatialdata import bounding_box_query + if table_name is not None: - _, region_key, instance_key = get_table_keys(self.sdata.tables[table_name]) + _, region_key, _ = get_table_keys(self.sdata.tables[table_name]) filtered_table = self.sdata.tables[table_name][ self.sdata.tables[table_name].obs[region_key].isin(self.regions) ] # filtered table for the data loader @@ -249,6 +250,25 @@ def _preprocess( tile_scale=tile_scale, tile_dim_in_units=tile_dim_in_units, ) + tile_coords["selection"] = bounding_box_query( + self.sdata[image_name], + ("x", "y"), + min_coordinate=tile_coords[["minx", "miny"]].values, + max_coordinate=tile_coords[["maxx", "maxy"]].values, + target_coordinate_system=cs, + return_request_only=True, + ) + # tile_coords["selection"] = tile_coords.apply( + # lambda row, cs=cs, image_name=image_name: bounding_box_query( + # self.sdata[image_name], + # ("x", "y"), + # min_coordinate=row[["minx", "miny"]].values, + # max_coordinate=row[["maxx", "maxy"]].values, + # target_coordinate_system=cs, + # return_request_only=True, + # ), + # axis=1, + # ) tile_coords_df.append(tile_coords) inst = circles.index.values @@ -358,13 +378,14 @@ def __getitem__(self, idx: int) -> Any | SpatialData: t_coords = self.tiles_coords.iloc[idx] image = self.sdata[row["image"]] - tile = self._crop_image( - image, - axes=tuple(self.dims), - min_coordinate=t_coords[[f"min{i}" for i in self.dims]].values, - max_coordinate=t_coords[[f"max{i}" for i in self.dims]].values, - target_coordinate_system=row["cs"], - ) + # tile = self._crop_image( + # image, + # axes=tuple(self.dims), + # min_coordinate=t_coords[[f"min{i}" for i in self.dims]].values, + # max_coordinate=t_coords[[f"max{i}" for i in self.dims]].values, + # target_coordinate_system=row["cs"], + # ) + tile = image.sel(t_coords["selection"]) if self.transform is not None: out = self._return(idx, tile) return self.transform(out)