Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions src/spatialdata/_core/query/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,11 +95,11 @@ def get_bounding_box_corners(
return output.squeeze().drop_vars("box")


@nb.jit(parallel=False, nopython=True)
@nb.njit(parallel=False, nopython=True)
def _create_slices_and_translation(
min_values: nb.types.Array,
max_values: nb.types.Array,
) -> tuple[nb.types.Array, nb.types.Array]:
min_values: nb.types.Array[nb.float64, nb.float64],
max_values: nb.types.Array[nb.float64, nb.float64],
) -> tuple[nb.types.Array[nb.float64, nb.float64], nb.types.Array[nb.float64, nb.float64]]:
n_boxes, n_dims = min_values.shape
slices = np.empty((n_boxes, n_dims, 2), dtype=np.float64) # (n_boxes, n_dims, [min, max])
translation_vectors = np.empty((n_boxes, n_dims), dtype=np.float64) # (n_boxes, n_dims)
Expand Down
12 changes: 6 additions & 6 deletions src/spatialdata/_core/query/spatial_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -575,12 +575,12 @@ def _(
max_coordinate = _parse_list_into_array(max_coordinate)

# for triggering validation
_ = BoundingBoxRequest(
target_coordinate_system=target_coordinate_system,
axes=axes,
min_coordinate=min_coordinate,
max_coordinate=max_coordinate,
)
# _ = BoundingBoxRequest(
# target_coordinate_system=target_coordinate_system,
# axes=axes,
# min_coordinate=min_coordinate,
# max_coordinate=max_coordinate,
# )

intrinsic_bounding_box_corners, axes = _get_bounding_box_corners_in_intrinsic_coordinates(
image, axes, min_coordinate, max_coordinate, target_coordinate_system
Expand Down
55 changes: 38 additions & 17 deletions src/spatialdata/dataloader/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,8 @@ def __init__(
from spatialdata import bounding_box_query
from spatialdata._core.operations.rasterize import rasterize as rasterize_fn

self._validate(sdata, regions_to_images, regions_to_coordinate_systems, return_annotations, table_name)
self.sdata = sdata
self._validate(regions_to_images, regions_to_coordinate_systems, return_annotations, table_name)
self._preprocess(tile_scale, tile_dim_in_units, rasterize, table_name)

if rasterize_kwargs is not None and len(rasterize_kwargs) > 0 and rasterize is False:
Expand All @@ -144,21 +145,19 @@ def __init__(
**dict(rasterize_kwargs),
)
if rasterize
else bounding_box_query
else partial(bounding_box_query, return_request_only=True) # type: ignore[assignment]
)
self._return = self._get_return(return_annotations, table_name)
self.transform = transform

def _validate(
self,
sdata: SpatialData,
regions_to_images: dict[str, str],
regions_to_coordinate_systems: dict[str, str],
return_annotations: str | list[str] | None,
table_name: str | None,
) -> None:
"""Validate input parameters."""
self.sdata = sdata
if return_annotations is not None and table_name is None:
raise ValueError("`table_name` must be provided if `return_annotations` is not `None`.")

Expand All @@ -173,8 +172,8 @@ def _validate(
image_name = regions_to_images[region_name]

# get elements
region_elem = sdata[region_name]
image_elem = sdata[image_name]
region_elem = self.sdata[region_name]
image_elem = self.sdata[image_name]

# check that the elements are supported
if get_model(region_elem) == PointsModel:
Expand All @@ -199,13 +198,13 @@ def _validate(
)

if table_name is not None:
_, region_key, instance_key = get_table_keys(sdata.tables[table_name])
_, region_key, instance_key = get_table_keys(self.sdata.tables[table_name])
if get_model(region_elem) in [Labels2DModel, Labels3DModel]:
indices = get_element_instances(region_elem).tolist()
else:
indices = region_elem.index.tolist()
table = sdata.tables[table_name]
if not isinstance(sdata.tables[table_name].obs[region_key].dtype, CategoricalDtype):
table = self.sdata.tables[table_name]
if not isinstance(self.sdata.tables[table_name].obs[region_key].dtype, CategoricalDtype):
raise TypeError(
f"The `regions_element` column `{region_key}` in the table must be a categorical dtype. "
f"Please convert it."
Expand All @@ -228,8 +227,10 @@ def _preprocess(
table_name: str | None,
) -> None:
"""Preprocess the dataset."""
from spatialdata import bounding_box_query

if table_name is not None:
_, region_key, instance_key = get_table_keys(self.sdata.tables[table_name])
_, region_key, _ = get_table_keys(self.sdata.tables[table_name])
filtered_table = self.sdata.tables[table_name][
self.sdata.tables[table_name].obs[region_key].isin(self.regions)
] # filtered table for the data loader
Expand All @@ -249,6 +250,25 @@ def _preprocess(
tile_scale=tile_scale,
tile_dim_in_units=tile_dim_in_units,
)
tile_coords["selection"] = bounding_box_query(
self.sdata[image_name],
("x", "y"),
min_coordinate=tile_coords[["minx", "miny"]].values,
max_coordinate=tile_coords[["maxx", "maxy"]].values,
target_coordinate_system=cs,
return_request_only=True,
)
# tile_coords["selection"] = tile_coords.apply(
# lambda row, cs=cs, image_name=image_name: bounding_box_query(
# self.sdata[image_name],
# ("x", "y"),
# min_coordinate=row[["minx", "miny"]].values,
# max_coordinate=row[["maxx", "maxy"]].values,
# target_coordinate_system=cs,
# return_request_only=True,
# ),
# axis=1,
# )
tile_coords_df.append(tile_coords)

inst = circles.index.values
Expand Down Expand Up @@ -358,13 +378,14 @@ def __getitem__(self, idx: int) -> Any | SpatialData:
t_coords = self.tiles_coords.iloc[idx]

image = self.sdata[row["image"]]
tile = self._crop_image(
image,
axes=tuple(self.dims),
min_coordinate=t_coords[[f"min{i}" for i in self.dims]].values,
max_coordinate=t_coords[[f"max{i}" for i in self.dims]].values,
target_coordinate_system=row["cs"],
)
# tile = self._crop_image(
# image,
# axes=tuple(self.dims),
# min_coordinate=t_coords[[f"min{i}" for i in self.dims]].values,
# max_coordinate=t_coords[[f"max{i}" for i in self.dims]].values,
# target_coordinate_system=row["cs"],
# )
tile = image.sel(t_coords["selection"])
if self.transform is not None:
out = self._return(idx, tile)
return self.transform(out)
Expand Down
Loading