Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions micro_sam/prompt_based_segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,6 @@

import torch

from nifty.tools import blocking

from segment_anything.predictor import SamPredictor
from segment_anything.utils.transforms import ResizeLongestSide

Expand Down Expand Up @@ -155,6 +153,7 @@ def _process_box(box, shape, original_size=None, box_extension=0):
# and bring the points to the coordinate system of the tile.
# Discard points that are not in the tile and warn if this happens.
def _points_to_tile(prompts, shape, tile_shape, halo):
from nifty.tools import blocking
points, labels = prompts

tiling = blocking([0, 0], shape, tile_shape)
Expand Down Expand Up @@ -186,6 +185,7 @@ def _points_to_tile(prompts, shape, tile_shape, halo):


def _box_to_tile(box, shape, tile_shape, halo):
from nifty.tools import blocking
tiling = blocking([0, 0], shape, tile_shape)
center = np.array([(box[0] + box[2]) / 2, (box[1] + box[3]) / 2]).round().astype("int").tolist()
tile_id = tiling.coordinatesToBlockId(center)
Expand All @@ -205,6 +205,7 @@ def _box_to_tile(box, shape, tile_shape, halo):


def _mask_to_tile(mask, shape, tile_shape, halo):
from nifty.tools import blocking
tiling = blocking([0, 0], shape, tile_shape)

coords = np.where(mask)
Expand Down
9 changes: 5 additions & 4 deletions micro_sam/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,18 +12,14 @@
from collections import OrderedDict
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union, Callable

import elf.parallel as parallel_impl
import imageio.v3 as imageio
import numpy as np
import pooch
import segment_anything.utils.amg as amg_utils
import torch
import vigra
import xxhash
import zarr

from elf.io import open_file
from nifty.tools import blocking
from skimage.measure import regionprops
from skimage.segmentation import relabel_sequential
from torchvision.ops.boxes import batched_nms
Expand Down Expand Up @@ -762,6 +758,7 @@ def _check_mask(tile_id):


def _compute_tiled_features_2d(predictor, input_, tile_shape, halo, f, pbar_init, pbar_update, batch_size, mask):
from nifty.tools import blocking
tiling = blocking([0, 0], input_.shape[:2], tile_shape)
n_tiles = tiling.numberOfBlocks

Expand Down Expand Up @@ -850,6 +847,7 @@ def __next__(self):


def _compute_tiled_features_3d(predictor, input_, tile_shape, halo, f, pbar_init, pbar_update, batch_size, mask):
from nifty.tools import blocking
assert input_.ndim == 3

shape = input_.shape[1:]
Expand Down Expand Up @@ -1301,6 +1299,7 @@ def get_centers_and_bounding_boxes(
if mode == "p":
center_coordinates = {prop.label: prop.centroid for prop in properties}
elif mode == "v":
import vigra
center_coordinates = vigra.filters.eccentricityCenters(segmentation.astype('float32'))
center_coordinates = {i: coord for i, coord in enumerate(center_coordinates) if i > 0}

Expand All @@ -1324,6 +1323,7 @@ def load_image_data(path: str, key: Optional[str] = None, lazy_loading: bool = F
if key is None:
image_data = imageio.imread(path)
else:
from elf.io import open_file
with open_file(path, mode="r") as f:
image_data = f[key]
if not lazy_loading:
Expand Down Expand Up @@ -1807,6 +1807,7 @@ def require_numpy(mask):
segmentation[this_mask] = this_seg_id
seg_id = this_seg_id + 1

import elf.parallel as parallel_impl
block_shape = (512, 512)
if label_masks:
segmentation_cc = np.zeros_like(segmentation, dtype=segmentation.dtype)
Expand Down
12 changes: 12 additions & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,18 @@ project_urls =
[options]
packages = find:
python_requires = >=3.10
install_requires =
imageio
numpy
pooch
scikit-image
scipy
segment-anything
torch>=2.5
torchvision
tqdm
xxhash
zarr
include_package_data = True
package_dir =
= .
Expand Down