diff --git a/src/spatialdata/dataloader/__init__.py b/src/spatialdata/dataloader/__init__.py index 819ab58e..e6232f7d 100644 --- a/src/spatialdata/dataloader/__init__.py +++ b/src/spatialdata/dataloader/__init__.py @@ -1,4 +1,26 @@ -try: - from spatialdata.dataloader.datasets import ImageTilesDataset -except ImportError: - ImageTilesDataset = None # type: ignore[assignment, misc] +from __future__ import annotations + +from typing import TYPE_CHECKING, Any + +if TYPE_CHECKING: + from spatialdata.dataloader.datasets import ImageTilesDataset as _ImageTilesDataset + + +class ImageTilesDataset: # noqa: D101 + _target_class: type[_ImageTilesDataset] | None = None + + def __new__(cls, *args: Any, **kwargs: Any) -> _ImageTilesDataset: # noqa: D102 + if cls._target_class is None: + try: + from spatialdata.dataloader.datasets import ( + ImageTilesDataset as ActualImageTilesDataset, + ) + + cls._target_class = ActualImageTilesDataset + + except ImportError as error: + raise ImportError( + "ImageTilesDataset could not be imported. This usually means the 'torch' dependency is missing." + ) from error + + return cls._target_class(*args, **kwargs)