Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion decorators/affine.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ def read_fn(domain: ts.IndexDomain, array: np.ndarray,

idx = [slice(None) for _ in range(array.ndim)]
assert batch_idx is not None
idx[batch_idx] = i
idx[batch_idx] = i # pyrefly: ignore[unsupported-operation]
array[tuple(idx)] = transform.reshape(array[tuple(idx)].shape)

chunksize = [2, 3]
Expand Down
8 changes: 4 additions & 4 deletions decorators/flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,9 +93,9 @@ def _mesh_relax_flow(flow: np.ndarray, **filter_args) -> np.ndarray:

num_spatial_dim = flow.shape[0]
if num_spatial_dim == 2:
res = sofima.mesh.relax_mesh(x, flow.squeeze(), cfg)
res = sofima.mesh.relax_mesh(x, flow.squeeze(), cfg) # pyrefly: ignore[bad-argument-type]
elif num_spatial_dim == 3:
res = sofima.mesh.relax_mesh(x, flow.squeeze(), cfg,
res = sofima.mesh.relax_mesh(x, flow.squeeze(), cfg, # pyrefly: ignore[bad-argument-type]
mesh_force=sofima.mesh.elastic_mesh_3d)
else:
raise ValueError(
Expand Down Expand Up @@ -297,9 +297,9 @@ def read_fn(domain: ts.IndexDomain, array: np.ndarray,
pad_left = np.array(self._patch_zyx) // np.array(self._step_zyx) // 2
pad_width = [(0, 0)]
if num_image_dims == 2:
pad_width.append([0, 0])
pad_width.append([0, 0]) # pyrefly: ignore[bad-argument-type]
for left, total in zip(pad_left, pad_total):
pad_width.append([left, total - left])
pad_width.append([left, total - left]) # pyrefly: ignore[bad-argument-type]
array[...] = np.pad(
flow_post_to_pre, pad_width, constant_values=np.nan
).reshape(array.shape)
Expand Down
4 changes: 2 additions & 2 deletions decorators/maps.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,8 +87,8 @@ def warp_fn(domain: ts.IndexDomain, array: np.ndarray,
read_domain_coord_map = _adjust_read_domain(domain, coord_map_ts)

array[...] = compose_maps(
map1=np.array(input_ts[read_domain_input]).squeeze(),
map2=np.array(coord_map_ts[read_domain_coord_map]).squeeze(),
map1=np.array(input_ts[read_domain_input]).squeeze(), # pyrefly: ignore[bad-argument-type]
map2=np.array(coord_map_ts[read_domain_coord_map]).squeeze(), # pyrefly: ignore[bad-argument-type]
**self._compose_args).reshape(array.shape)

chunksize = []
Expand Down
14 changes: 7 additions & 7 deletions flow_field.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ def fmax(x, v=0.0):
if use_jax:
out = jnp.clip(out, min=-1, max=1)
else:
np.clip(out, min=-1, max=1, out=out)
np.clip(out, min=-1, max=1, out=out) # pyrefly: ignore[no-matching-overload]

px_threshold = 0.3 * xnp.max(overlap_masked_px, keepdims=True)
if use_jax:
Expand Down Expand Up @@ -189,7 +189,7 @@ def _peak_stats(peak1_val, peak2_val, peak1_idx, img, offset, peak_radius=5):
peak_radius = np.array(peak_radius)
size = 2 * peak_radius + 1
start = jnp.asarray(inds) - size // 2
sharpness = img[inds] / jnp.min(jax.lax.dynamic_slice(img, start, size))
sharpness = img[inds] / jnp.min(jax.lax.dynamic_slice(img, start, size)) # pyrefly: ignore[bad-argument-type]

return jnp.where(
jnp.isinf(peak1_val), #
Expand Down Expand Up @@ -237,7 +237,7 @@ def _batched_peaks(
# Apply the maximum filter as a sequence of 1d filters.
img_max = img
strides = (1,) * dim
for i, s in enumerate(size):
for i, s in enumerate(size): # pyrefly: ignore[unbound-name]
patch = [1] * dim
patch[i] = s
img_max = jnp.max(
Expand Down Expand Up @@ -358,7 +358,7 @@ def _masked_mean(source, mask):
np.array(pre_batch.shape[-len(patch_size) :])
+ post_batch.shape[-len(patch_size) :]
) // 2 - 1
return (
return ( # pyrefly: ignore[bad-return]
center_offset, # pytype: disable=bad-return-type # jax-ndarray
masked_xcorr(
pre_batch - pre_mean,
Expand Down Expand Up @@ -433,7 +433,7 @@ def batched_xcorr_peaks(
)
peaks = _batched_peaks(
xcorr,
center_offset,
center_offset, # pyrefly: ignore[bad-argument-type]
min_distance,
threshold_rel, # pytype: disable=wrong-arg-types # jax-ndarray
peak_radius,
Expand Down Expand Up @@ -596,8 +596,8 @@ def flow_field(
if post_mask is not None:
post_mask = jnp.asarray(post_mask)

pre_image = jnp.asarray(pre_image)
post_image = jnp.asarray(post_image)
pre_image = jnp.asarray(pre_image) # pyrefly: ignore[bad-assignment]
post_image = jnp.asarray(post_image) # pyrefly: ignore[bad-assignment]

# Offset to add to the starts of the 'prev' patches so that
# the 'post' patches remain centered at the same location.
Expand Down
30 changes: 15 additions & 15 deletions map_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ def _interpolate_points(
return np.array(ret)

assert method in ('linear', 'cubic')
data_points = np.array(data_points).T
data_points = np.array(data_points).T # pyrefly: ignore[bad-assignment]
tri = spatial.Delaunay(np.ascontiguousarray(data_points, dtype=np.double))

values = np.array(values).T # [N, dim]
Expand Down Expand Up @@ -254,7 +254,7 @@ def fill_missing(
elif dim == 3:
query_coords = np.mgrid[: s[-3], : s[-2], : s[-1]] # zyx

query_points = tuple([q.ravel() for q in query_coords[::-1]]) # xy[z]
query_points = tuple([q.ravel() for q in query_coords[::-1]]) # xy[z] # pyrefly: ignore[unbound-name]

rets = []

Expand Down Expand Up @@ -339,7 +339,7 @@ def outer_box(
start[i] = x_min
size[i] = -(int(-x_max) // tl) - x_min + 1

return bounding_box.BoundingBox(start, size)
return bounding_box.BoundingBox(start, size) # pyrefly: ignore[bad-argument-type]


def inner_box(
Expand Down Expand Up @@ -415,8 +415,8 @@ def invert_map(
coord_map = coord_map.astype(np.float64)
dim = coord_map.shape[0]
stride = _as_vec(stride, dim)
src_box = src_box.adjusted_by(start=-dst_box.start, end=-dst_box.start)
dst_box = dst_box.adjusted_by(start=-dst_box.start, end=-dst_box.start)
src_box = src_box.adjusted_by(start=-dst_box.start, end=-dst_box.start) # pyrefly: ignore[bad-argument-type]
dst_box = dst_box.adjusted_by(start=-dst_box.start, end=-dst_box.start) # pyrefly: ignore[bad-argument-type]
coord_map = to_absolute(coord_map, stride, src_box)

def _sel_size(box):
Expand Down Expand Up @@ -646,9 +646,9 @@ def compose_maps_fast(

stride1 = _as_vec(stride1, dim)
stride2 = _as_vec(stride2, dim)
start1 = jnp.asarray(start1)
start2 = jnp.asarray(start2)
origin = jnp.minimum(start1, start2)
start1 = jnp.asarray(start1) # pyrefly: ignore[bad-assignment]
start2 = jnp.asarray(start2) # pyrefly: ignore[bad-assignment]
origin = jnp.minimum(start1, start2) # pyrefly: ignore[bad-argument-type]

def _ref_grid(coord_map, start, stride):
start = (start - origin)[-dim:] # yx
Expand All @@ -675,7 +675,7 @@ def _ref_grid(coord_map, start, stride):
xx = (
jax.scipy.ndimage.map_coordinates(
map2[0, z, ...] + ref2[-1],
query_coords,
query_coords, # pyrefly: ignore[bad-argument-type]
order=1,
mode=mode,
cval=np.nan,
Expand All @@ -685,7 +685,7 @@ def _ref_grid(coord_map, start, stride):
yy = (
jax.scipy.ndimage.map_coordinates(
map2[1, z, ...] + ref2[-2],
query_coords,
query_coords, # pyrefly: ignore[bad-argument-type]
order=1,
mode=mode,
cval=np.nan,
Expand All @@ -702,7 +702,7 @@ def _ref_grid(coord_map, start, stride):
xx = (
jax.scipy.ndimage.map_coordinates(
map2[0, ...] + ref2[-1],
query_coords,
query_coords, # pyrefly: ignore[bad-argument-type]
order=1,
mode=mode,
cval=np.nan,
Expand All @@ -712,7 +712,7 @@ def _ref_grid(coord_map, start, stride):
yy = (
jax.scipy.ndimage.map_coordinates(
map2[1, ...] + ref2[-2],
query_coords,
query_coords, # pyrefly: ignore[bad-argument-type]
order=1,
mode=mode,
cval=np.nan,
Expand All @@ -722,7 +722,7 @@ def _ref_grid(coord_map, start, stride):
zz = (
jax.scipy.ndimage.map_coordinates(
map2[2, ...] + ref2[-3],
query_coords,
query_coords, # pyrefly: ignore[bad-argument-type]
order=1,
mode=mode,
cval=np.nan,
Expand Down Expand Up @@ -760,7 +760,7 @@ def mask_irregular(
"""
assert len(coord_map.shape) == 3
assert coord_map.shape[0] == 2
stride = np.asarray(stride)
stride = np.asarray(stride) # pyrefly: ignore[bad-assignment]

if max_frac is None:
max_frac = 2 - frac
Expand Down Expand Up @@ -799,7 +799,7 @@ def make_affine_map(
Returns:
coordinate map representing the specified affine transform
"""
coord_map = np.array(_identity_map_absolute(box.size[::-1], stride)[::-1])
coord_map = np.array(_identity_map_absolute(box.size[::-1], stride)[::-1]) # pyrefly: ignore[bad-argument-type]
coord_map[0, ...] += box.start[0]
coord_map[1, ...] += box.start[1]
coord_map[2, ...] += box.start[2]
Expand Down
6 changes: 3 additions & 3 deletions mesh.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,7 +217,7 @@ def elastic_mesh_3d(
if not isinstance(stride, collections.abc.Sequence):
stride = (stride,) * 3

stride = np.array(stride)
stride = np.array(stride) # pyrefly: ignore[bad-assignment]
f_tot = None
num_non_spatial = x.ndim - 3
for direction in links:
Expand Down Expand Up @@ -246,7 +246,7 @@ def elastic_mesh_3d(
else:
raise ValueError('Only |v| <= 1 values supported within links.')

l0 = np.array(stride * direction, dtype=np.float32).reshape(
l0 = np.array(stride * direction, dtype=np.float32).reshape( # pyrefly: ignore[unsupported-operation]
[3] + [1] * (x.ndim - 1)
)
dx = x[tuple(sel1)] - x[tuple(sel2)] + l0
Expand All @@ -256,7 +256,7 @@ def elastic_mesh_3d(
# We want to maintain constant elasticity E and E ~ k⋅l0.
# k is specified for the horizontal direction, and so l0 for it is
# stride_x.
k_eff = k * stride[0] / l0
k_eff = k * stride[0] / l0 # pyrefly: ignore[bad-index]
if prefer_orig_order:
ones = jnp.ones_like(dx[0])
factor = jnp.array([
Expand Down
12 changes: 6 additions & 6 deletions processor/flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,7 +235,7 @@ def _estimate_flow(z_prev, z_curr):
# Δz < 0: box.start.z
out_box = self.crop_box(box)
out_box = bounding_box.BoundingBox(
start=out_box.start // [self._config.stride, self._config.stride, 1],
start=out_box.start // [self._config.stride, self._config.stride, 1], # pyrefly: ignore[bad-argument-type]
size=[ret.shape[-1], ret.shape[-2], out_box.size[2]],
)
if ret.shape[0] != out_box.size[2]:
Expand Down Expand Up @@ -272,7 +272,7 @@ def expected_output_box(
- self._config.patch_size
+ self._config.stride
) // self._config.stride
return bounding_box.BoundingBox(start, size)
return bounding_box.BoundingBox(start, size) # pyrefly: ignore[bad-argument-type]


# TODO(blakely): Remove references to volinfos in favor of metadata
Expand Down Expand Up @@ -407,7 +407,7 @@ def process(self, subvol: Subvolume) -> SubvolumeOrMany:
read_box = box.scale((scale, scale, 1))
if scale < 1:
read_box = read_box.adjusted_by(
start=-self._context[0], end=self._context[1]
start=-self._context[0], end=self._context[1] # pyrefly: ignore[bad-argument-type]
)
read_box = vol.clip_box_to_volume(read_box)
assert read_box is not None
Expand Down Expand Up @@ -587,7 +587,7 @@ def __init__(
)

if config.selection_mask_configs:
config.selection_mask_configs = dataclasses.replace(
config.selection_mask_configs = dataclasses.replace( # pyrefly: ignore[read-only]
config,
selection_mask_configs=self._get_mask_configs(
config.selection_mask_configs
Expand Down Expand Up @@ -750,7 +750,7 @@ def process(self, subvol: Subvolume) -> SubvolumeOrMany:

curr_mask = None
if self._config.mask_configs:
curr_mask = full_mask[curr_z_idx, ...][curr_slice]
curr_mask = full_mask[curr_z_idx, ...][curr_slice] # pyrefly: ignore[unsupported-operation]
if np.all(curr_mask):
beam_utils.counter(namespace, 'sections-masked').inc()
continue
Expand All @@ -776,7 +776,7 @@ def process(self, subvol: Subvolume) -> SubvolumeOrMany:
t1 = time.time()

if self._config.mask_configs:
prev_mask = full_mask[prev_z_idx, ...]
prev_mask = full_mask[prev_z_idx, ...] # pyrefly: ignore[unsupported-operation]
if np.all(prev_mask):
continue

Expand Down
6 changes: 3 additions & 3 deletions processor/maps.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,7 @@ def _interpolate(
block_end_inv = load_main_inv(z1)

flat_box = bounding_box.BoundingBox(
start=box.start, size=(box.size[0], box.size[1], 1)
start=box.start, size=(box.size[0], box.size[1], 1) # pyrefly: ignore[bad-argument-type]
)

# The interpolation is done so that the first section of the block ends up
Expand Down Expand Up @@ -438,10 +438,10 @@ def process(self, subvol: Subvolume) -> SubvolumeOrMany:
dst_box = dst_box.scale([ratio, ratio, 1.0])

out_map = map_utils.resample_map(
rel_map, box, dst_box, config.stride, config.out_stride, config.method
rel_map, box, dst_box, config.stride, config.out_stride, config.method # pyrefly: ignore[bad-argument-type]
)

return [Subvolume(out_map, dst_box)]
return [Subvolume(out_map, dst_box)] # pyrefly: ignore[bad-argument-type]


class MaskIrregularities(subvolume_processor.SubvolumeProcessor):
Expand Down
24 changes: 12 additions & 12 deletions processor/mesh.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,7 @@ def compute_ref_mesh_multiz(
)

offset = np.array([0, 0, delta_z])
ref_box = box.translate(-offset)
ref_box = box.translate(-offset) # pyrefly: ignore[bad-argument-type]
logging.info('Attempting to load ref. mesh for %r', ref_box)
ref_mesh = self._load_stitched_tile(config.output_dir, ref_box)
if ref_mesh is None:
Expand All @@ -221,11 +221,11 @@ def compute_ref_mesh_multiz(

curr_flow = np.array(
map_utils.compose_maps_fast( # pytype: disable=wrong-arg-types # jax-ndarray
curr_flow,
box.start[::-1],
curr_flow, # pyrefly: ignore[bad-argument-type]
box.start[::-1], # pyrefly: ignore[bad-argument-type]
stride,
ref_mesh,
box.start[::-1],
ref_mesh, # pyrefly: ignore[bad-argument-type]
box.start[::-1], # pyrefly: ignore[bad-argument-type]
stride,
)
)
Expand Down Expand Up @@ -265,11 +265,11 @@ def compute_ref_mesh(

flow = np.array(
map_utils.compose_maps_fast( # pytype: disable=wrong-arg-types # jax-ndarray
flow,
ref_box.start[::-1],
flow, # pyrefly: ignore[bad-argument-type]
ref_box.start[::-1], # pyrefly: ignore[bad-argument-type]
stride,
ref_mesh,
ref_box.start[::-1],
ref_mesh, # pyrefly: ignore[bad-argument-type]
ref_box.start[::-1], # pyrefly: ignore[bad-argument-type]
stride,
)
)
Expand Down Expand Up @@ -347,7 +347,7 @@ def get_prev_state(
flow_field = flow_volume[bbox.to_slice4d()]
if flow_volume.meta.num_channels == 2:
offset = np.array([0, 0, flow.delta_z])
ref_box = bbox.translate(-offset)
ref_box = bbox.translate(-offset) # pyrefly: ignore[bad-argument-type]
ref_mesh = self.compute_ref_mesh(flow_field, ref_box, stride)
else:
ref_mesh = self.compute_ref_mesh_multiz(
Expand Down Expand Up @@ -482,8 +482,8 @@ def relax_mesh(
start_x = self.maybe_update_init_state(start_x, prev, config.options)

x, _, prep_steps = mesh_lib.relax_mesh( # pytype: disable=wrong-arg-types # jax-ndarray
start_x,
x,
start_x, # pyrefly: ignore[bad-argument-type]
x, # pyrefly: ignore[bad-argument-type]
dataclasses.replace(
integration_config, k0=integration_config.k0 / 10.0
),
Expand Down
Loading
Loading