From 731abe59ab616fa8009b6341d9aabd7e7cf6a2e0 Mon Sep 17 00:00:00 2001 From: Hana Joo Date: Wed, 1 Jul 2026 16:18:52 -0700 Subject: [PATCH] No public description PiperOrigin-RevId: 941350306 --- decorators/affine.py | 2 +- decorators/flow.py | 8 ++++---- decorators/maps.py | 4 ++-- flow_field.py | 14 +++++++------- map_utils.py | 30 +++++++++++++++--------------- mesh.py | 6 +++--- processor/flow.py | 12 ++++++------ processor/maps.py | 6 +++--- processor/mesh.py | 24 ++++++++++++------------ processor/warp.py | 26 +++++++++++++------------- stitch_elastic.py | 16 ++++++++-------- stitch_rigid.py | 4 ++-- warp.py | 16 ++++++++-------- 13 files changed, 84 insertions(+), 84 deletions(-) diff --git a/decorators/affine.py b/decorators/affine.py index a0433c8..a0e5ec0 100644 --- a/decorators/affine.py +++ b/decorators/affine.py @@ -157,7 +157,7 @@ def read_fn(domain: ts.IndexDomain, array: np.ndarray, idx = [slice(None) for _ in range(array.ndim)] assert batch_idx is not None - idx[batch_idx] = i + idx[batch_idx] = i # pyrefly: ignore[unsupported-operation] array[tuple(idx)] = transform.reshape(array[tuple(idx)].shape) chunksize = [2, 3] diff --git a/decorators/flow.py b/decorators/flow.py index e78defd..285ad50 100644 --- a/decorators/flow.py +++ b/decorators/flow.py @@ -93,9 +93,9 @@ def _mesh_relax_flow(flow: np.ndarray, **filter_args) -> np.ndarray: num_spatial_dim = flow.shape[0] if num_spatial_dim == 2: - res = sofima.mesh.relax_mesh(x, flow.squeeze(), cfg) + res = sofima.mesh.relax_mesh(x, flow.squeeze(), cfg) # pyrefly: ignore[bad-argument-type] elif num_spatial_dim == 3: - res = sofima.mesh.relax_mesh(x, flow.squeeze(), cfg, + res = sofima.mesh.relax_mesh(x, flow.squeeze(), cfg, # pyrefly: ignore[bad-argument-type] mesh_force=sofima.mesh.elastic_mesh_3d) else: raise ValueError( @@ -297,9 +297,9 @@ def read_fn(domain: ts.IndexDomain, array: np.ndarray, pad_left = np.array(self._patch_zyx) // np.array(self._step_zyx) // 2 pad_width = [(0, 0)] if num_image_dims == 2: - pad_width.append([0, 0]) + pad_width.append([0, 0]) # pyrefly: ignore[bad-argument-type] for left, total in zip(pad_left, pad_total): - pad_width.append([left, total - left]) + pad_width.append([left, total - left]) # pyrefly: ignore[bad-argument-type] array[...] = np.pad( flow_post_to_pre, pad_width, constant_values=np.nan ).reshape(array.shape) diff --git a/decorators/maps.py b/decorators/maps.py index a8751d5..aa1bcf2 100644 --- a/decorators/maps.py +++ b/decorators/maps.py @@ -87,8 +87,8 @@ def warp_fn(domain: ts.IndexDomain, array: np.ndarray, read_domain_coord_map = _adjust_read_domain(domain, coord_map_ts) array[...] = compose_maps( - map1=np.array(input_ts[read_domain_input]).squeeze(), - map2=np.array(coord_map_ts[read_domain_coord_map]).squeeze(), + map1=np.array(input_ts[read_domain_input]).squeeze(), # pyrefly: ignore[bad-argument-type] + map2=np.array(coord_map_ts[read_domain_coord_map]).squeeze(), # pyrefly: ignore[bad-argument-type] **self._compose_args).reshape(array.shape) chunksize = [] diff --git a/flow_field.py b/flow_field.py index da565d1..7669eea 100644 --- a/flow_field.py +++ b/flow_field.py @@ -146,7 +146,7 @@ def fmax(x, v=0.0): if use_jax: out = jnp.clip(out, min=-1, max=1) else: - np.clip(out, min=-1, max=1, out=out) + np.clip(out, min=-1, max=1, out=out) # pyrefly: ignore[no-matching-overload] px_threshold = 0.3 * xnp.max(overlap_masked_px, keepdims=True) if use_jax: @@ -189,7 +189,7 @@ def _peak_stats(peak1_val, peak2_val, peak1_idx, img, offset, peak_radius=5): peak_radius = np.array(peak_radius) size = 2 * peak_radius + 1 start = jnp.asarray(inds) - size // 2 - sharpness = img[inds] / jnp.min(jax.lax.dynamic_slice(img, start, size)) + sharpness = img[inds] / jnp.min(jax.lax.dynamic_slice(img, start, size)) # pyrefly: ignore[bad-argument-type] return jnp.where( jnp.isinf(peak1_val), # @@ -237,7 +237,7 @@ def _batched_peaks( # Apply the maximum filter as a sequence of 1d filters. img_max = img strides = (1,) * dim - for i, s in enumerate(size): + for i, s in enumerate(size): # pyrefly: ignore[unbound-name] patch = [1] * dim patch[i] = s img_max = jnp.max( @@ -358,7 +358,7 @@ def _masked_mean(source, mask): np.array(pre_batch.shape[-len(patch_size) :]) + post_batch.shape[-len(patch_size) :] ) // 2 - 1 - return ( + return ( # pyrefly: ignore[bad-return] center_offset, # pytype: disable=bad-return-type # jax-ndarray masked_xcorr( pre_batch - pre_mean, @@ -433,7 +433,7 @@ def batched_xcorr_peaks( ) peaks = _batched_peaks( xcorr, - center_offset, + center_offset, # pyrefly: ignore[bad-argument-type] min_distance, threshold_rel, # pytype: disable=wrong-arg-types # jax-ndarray peak_radius, @@ -596,8 +596,8 @@ def flow_field( if post_mask is not None: post_mask = jnp.asarray(post_mask) - pre_image = jnp.asarray(pre_image) - post_image = jnp.asarray(post_image) + pre_image = jnp.asarray(pre_image) # pyrefly: ignore[bad-assignment] + post_image = jnp.asarray(post_image) # pyrefly: ignore[bad-assignment] # Offset to add to the starts of the 'prev' patches so that # the 'post' patches remain centered at the same location. diff --git a/map_utils.py b/map_utils.py index bfb081d..5e286b7 100644 --- a/map_utils.py +++ b/map_utils.py @@ -106,7 +106,7 @@ def _interpolate_points( return np.array(ret) assert method in ('linear', 'cubic') - data_points = np.array(data_points).T + data_points = np.array(data_points).T # pyrefly: ignore[bad-assignment] tri = spatial.Delaunay(np.ascontiguousarray(data_points, dtype=np.double)) values = np.array(values).T # [N, dim] @@ -254,7 +254,7 @@ def fill_missing( elif dim == 3: query_coords = np.mgrid[: s[-3], : s[-2], : s[-1]] # zyx - query_points = tuple([q.ravel() for q in query_coords[::-1]]) # xy[z] + query_points = tuple([q.ravel() for q in query_coords[::-1]]) # xy[z] # pyrefly: ignore[unbound-name] rets = [] @@ -339,7 +339,7 @@ def outer_box( start[i] = x_min size[i] = -(int(-x_max) // tl) - x_min + 1 - return bounding_box.BoundingBox(start, size) + return bounding_box.BoundingBox(start, size) # pyrefly: ignore[bad-argument-type] def inner_box( @@ -415,8 +415,8 @@ def invert_map( coord_map = coord_map.astype(np.float64) dim = coord_map.shape[0] stride = _as_vec(stride, dim) - src_box = src_box.adjusted_by(start=-dst_box.start, end=-dst_box.start) - dst_box = dst_box.adjusted_by(start=-dst_box.start, end=-dst_box.start) + src_box = src_box.adjusted_by(start=-dst_box.start, end=-dst_box.start) # pyrefly: ignore[bad-argument-type] + dst_box = dst_box.adjusted_by(start=-dst_box.start, end=-dst_box.start) # pyrefly: ignore[bad-argument-type] coord_map = to_absolute(coord_map, stride, src_box) def _sel_size(box): @@ -646,9 +646,9 @@ def compose_maps_fast( stride1 = _as_vec(stride1, dim) stride2 = _as_vec(stride2, dim) - start1 = jnp.asarray(start1) - start2 = jnp.asarray(start2) - origin = jnp.minimum(start1, start2) + start1 = jnp.asarray(start1) # pyrefly: ignore[bad-assignment] + start2 = jnp.asarray(start2) # pyrefly: ignore[bad-assignment] + origin = jnp.minimum(start1, start2) # pyrefly: ignore[bad-argument-type] def _ref_grid(coord_map, start, stride): start = (start - origin)[-dim:] # yx @@ -675,7 +675,7 @@ def _ref_grid(coord_map, start, stride): xx = ( jax.scipy.ndimage.map_coordinates( map2[0, z, ...] + ref2[-1], - query_coords, + query_coords, # pyrefly: ignore[bad-argument-type] order=1, mode=mode, cval=np.nan, @@ -685,7 +685,7 @@ def _ref_grid(coord_map, start, stride): yy = ( jax.scipy.ndimage.map_coordinates( map2[1, z, ...] + ref2[-2], - query_coords, + query_coords, # pyrefly: ignore[bad-argument-type] order=1, mode=mode, cval=np.nan, @@ -702,7 +702,7 @@ def _ref_grid(coord_map, start, stride): xx = ( jax.scipy.ndimage.map_coordinates( map2[0, ...] + ref2[-1], - query_coords, + query_coords, # pyrefly: ignore[bad-argument-type] order=1, mode=mode, cval=np.nan, @@ -712,7 +712,7 @@ def _ref_grid(coord_map, start, stride): yy = ( jax.scipy.ndimage.map_coordinates( map2[1, ...] + ref2[-2], - query_coords, + query_coords, # pyrefly: ignore[bad-argument-type] order=1, mode=mode, cval=np.nan, @@ -722,7 +722,7 @@ def _ref_grid(coord_map, start, stride): zz = ( jax.scipy.ndimage.map_coordinates( map2[2, ...] + ref2[-3], - query_coords, + query_coords, # pyrefly: ignore[bad-argument-type] order=1, mode=mode, cval=np.nan, @@ -760,7 +760,7 @@ def mask_irregular( """ assert len(coord_map.shape) == 3 assert coord_map.shape[0] == 2 - stride = np.asarray(stride) + stride = np.asarray(stride) # pyrefly: ignore[bad-assignment] if max_frac is None: max_frac = 2 - frac @@ -799,7 +799,7 @@ def make_affine_map( Returns: coordinate map representing the specified affine transform """ - coord_map = np.array(_identity_map_absolute(box.size[::-1], stride)[::-1]) + coord_map = np.array(_identity_map_absolute(box.size[::-1], stride)[::-1]) # pyrefly: ignore[bad-argument-type] coord_map[0, ...] += box.start[0] coord_map[1, ...] += box.start[1] coord_map[2, ...] += box.start[2] diff --git a/mesh.py b/mesh.py index 432f6f1..2b53fde 100644 --- a/mesh.py +++ b/mesh.py @@ -217,7 +217,7 @@ def elastic_mesh_3d( if not isinstance(stride, collections.abc.Sequence): stride = (stride,) * 3 - stride = np.array(stride) + stride = np.array(stride) # pyrefly: ignore[bad-assignment] f_tot = None num_non_spatial = x.ndim - 3 for direction in links: @@ -246,7 +246,7 @@ def elastic_mesh_3d( else: raise ValueError('Only |v| <= 1 values supported within links.') - l0 = np.array(stride * direction, dtype=np.float32).reshape( + l0 = np.array(stride * direction, dtype=np.float32).reshape( # pyrefly: ignore[unsupported-operation] [3] + [1] * (x.ndim - 1) ) dx = x[tuple(sel1)] - x[tuple(sel2)] + l0 @@ -256,7 +256,7 @@ def elastic_mesh_3d( # We want to maintain constant elasticity E and E ~ k⋅l0. # k is specified for the horizontal direction, and so l0 for it is # stride_x. - k_eff = k * stride[0] / l0 + k_eff = k * stride[0] / l0 # pyrefly: ignore[bad-index] if prefer_orig_order: ones = jnp.ones_like(dx[0]) factor = jnp.array([ diff --git a/processor/flow.py b/processor/flow.py index 2f5d369..85b73b3 100644 --- a/processor/flow.py +++ b/processor/flow.py @@ -235,7 +235,7 @@ def _estimate_flow(z_prev, z_curr): # Δz < 0: box.start.z out_box = self.crop_box(box) out_box = bounding_box.BoundingBox( - start=out_box.start // [self._config.stride, self._config.stride, 1], + start=out_box.start // [self._config.stride, self._config.stride, 1], # pyrefly: ignore[bad-argument-type] size=[ret.shape[-1], ret.shape[-2], out_box.size[2]], ) if ret.shape[0] != out_box.size[2]: @@ -272,7 +272,7 @@ def expected_output_box( - self._config.patch_size + self._config.stride ) // self._config.stride - return bounding_box.BoundingBox(start, size) + return bounding_box.BoundingBox(start, size) # pyrefly: ignore[bad-argument-type] # TODO(blakely): Remove references to volinfos in favor of metadata @@ -407,7 +407,7 @@ def process(self, subvol: Subvolume) -> SubvolumeOrMany: read_box = box.scale((scale, scale, 1)) if scale < 1: read_box = read_box.adjusted_by( - start=-self._context[0], end=self._context[1] + start=-self._context[0], end=self._context[1] # pyrefly: ignore[bad-argument-type] ) read_box = vol.clip_box_to_volume(read_box) assert read_box is not None @@ -587,7 +587,7 @@ def __init__( ) if config.selection_mask_configs: - config.selection_mask_configs = dataclasses.replace( + config.selection_mask_configs = dataclasses.replace( # pyrefly: ignore[read-only] config, selection_mask_configs=self._get_mask_configs( config.selection_mask_configs @@ -750,7 +750,7 @@ def process(self, subvol: Subvolume) -> SubvolumeOrMany: curr_mask = None if self._config.mask_configs: - curr_mask = full_mask[curr_z_idx, ...][curr_slice] + curr_mask = full_mask[curr_z_idx, ...][curr_slice] # pyrefly: ignore[unsupported-operation] if np.all(curr_mask): beam_utils.counter(namespace, 'sections-masked').inc() continue @@ -776,7 +776,7 @@ def process(self, subvol: Subvolume) -> SubvolumeOrMany: t1 = time.time() if self._config.mask_configs: - prev_mask = full_mask[prev_z_idx, ...] + prev_mask = full_mask[prev_z_idx, ...] # pyrefly: ignore[unsupported-operation] if np.all(prev_mask): continue diff --git a/processor/maps.py b/processor/maps.py index 9fadfab..0f861be 100644 --- a/processor/maps.py +++ b/processor/maps.py @@ -184,7 +184,7 @@ def _interpolate( block_end_inv = load_main_inv(z1) flat_box = bounding_box.BoundingBox( - start=box.start, size=(box.size[0], box.size[1], 1) + start=box.start, size=(box.size[0], box.size[1], 1) # pyrefly: ignore[bad-argument-type] ) # The interpolation is done so that the first section of the block ends up @@ -438,10 +438,10 @@ def process(self, subvol: Subvolume) -> SubvolumeOrMany: dst_box = dst_box.scale([ratio, ratio, 1.0]) out_map = map_utils.resample_map( - rel_map, box, dst_box, config.stride, config.out_stride, config.method + rel_map, box, dst_box, config.stride, config.out_stride, config.method # pyrefly: ignore[bad-argument-type] ) - return [Subvolume(out_map, dst_box)] + return [Subvolume(out_map, dst_box)] # pyrefly: ignore[bad-argument-type] class MaskIrregularities(subvolume_processor.SubvolumeProcessor): diff --git a/processor/mesh.py b/processor/mesh.py index 8f77182..44e8dee 100644 --- a/processor/mesh.py +++ b/processor/mesh.py @@ -199,7 +199,7 @@ def compute_ref_mesh_multiz( ) offset = np.array([0, 0, delta_z]) - ref_box = box.translate(-offset) + ref_box = box.translate(-offset) # pyrefly: ignore[bad-argument-type] logging.info('Attempting to load ref. mesh for %r', ref_box) ref_mesh = self._load_stitched_tile(config.output_dir, ref_box) if ref_mesh is None: @@ -221,11 +221,11 @@ def compute_ref_mesh_multiz( curr_flow = np.array( map_utils.compose_maps_fast( # pytype: disable=wrong-arg-types # jax-ndarray - curr_flow, - box.start[::-1], + curr_flow, # pyrefly: ignore[bad-argument-type] + box.start[::-1], # pyrefly: ignore[bad-argument-type] stride, - ref_mesh, - box.start[::-1], + ref_mesh, # pyrefly: ignore[bad-argument-type] + box.start[::-1], # pyrefly: ignore[bad-argument-type] stride, ) ) @@ -265,11 +265,11 @@ def compute_ref_mesh( flow = np.array( map_utils.compose_maps_fast( # pytype: disable=wrong-arg-types # jax-ndarray - flow, - ref_box.start[::-1], + flow, # pyrefly: ignore[bad-argument-type] + ref_box.start[::-1], # pyrefly: ignore[bad-argument-type] stride, - ref_mesh, - ref_box.start[::-1], + ref_mesh, # pyrefly: ignore[bad-argument-type] + ref_box.start[::-1], # pyrefly: ignore[bad-argument-type] stride, ) ) @@ -347,7 +347,7 @@ def get_prev_state( flow_field = flow_volume[bbox.to_slice4d()] if flow_volume.meta.num_channels == 2: offset = np.array([0, 0, flow.delta_z]) - ref_box = bbox.translate(-offset) + ref_box = bbox.translate(-offset) # pyrefly: ignore[bad-argument-type] ref_mesh = self.compute_ref_mesh(flow_field, ref_box, stride) else: ref_mesh = self.compute_ref_mesh_multiz( @@ -482,8 +482,8 @@ def relax_mesh( start_x = self.maybe_update_init_state(start_x, prev, config.options) x, _, prep_steps = mesh_lib.relax_mesh( # pytype: disable=wrong-arg-types # jax-ndarray - start_x, - x, + start_x, # pyrefly: ignore[bad-argument-type] + x, # pyrefly: ignore[bad-argument-type] dataclasses.replace( integration_config, k0=integration_config.k0 / 10.0 ), diff --git a/processor/warp.py b/processor/warp.py index b42384e..0bcb7f5 100644 --- a/processor/warp.py +++ b/processor/warp.py @@ -118,7 +118,7 @@ def _collect_tile_boxes(self, tile_shape_zyx: ZYX): ) for i in range(tile_meshes.shape[1]): - tx, ty = StitchAndRender3dTiles._tile_idx_to_xy[i] + tx, ty = StitchAndRender3dTiles._tile_idx_to_xy[i] # pyrefly: ignore[unsupported-operation] mesh = tile_meshes[:, i, ...] tg_box = map_utils.outer_box(mesh, map_box, self._stride) @@ -158,7 +158,7 @@ def _get_dts(self, shape: ZYX, tx: int, ty: int) -> np.ndarray: mask[...] = 1 # Compute a (2d) distance transform of the mask, for use in blending. - return edt.edt(mask, black_border=True, parallel=0) + return edt.edt(mask, black_border=True, parallel=0) # pyrefly: ignore[not-callable] def _load_tile_images( self, @@ -188,7 +188,7 @@ def _load_tile_images( logging.info('Processing source %r (%r)', i, out_box) coord_map = tile_meshes[:, i, ...] - tx, ty = StitchAndRender3dTiles._tile_idx_to_xy[i] + tx, ty = StitchAndRender3dTiles._tile_idx_to_xy[i] # pyrefly: ignore[unsupported-operation] if i not in StitchAndRender3dTiles._inverted_meshes: # Add context to avoid rounding issues in map inversion. @@ -266,7 +266,7 @@ def process( if StitchAndRender3dTiles._tile_meshes is None: data_path = self._tile_mesh_path with file.Path(data_path).open('rb') as f: - data = np.load(f, allow_pickle=True) + data = np.load(f, allow_pickle=True) # pyrefly: ignore[bad-argument-type] StitchAndRender3dTiles._tile_idx_to_xy = { v: k for k, v in data['key_to_idx'].item().items() } @@ -278,7 +278,7 @@ def process( volstores = {} for i in range(StitchAndRender3dTiles._tile_meshes.shape[1]): - tile_id = self._key_to_idx[StitchAndRender3dTiles._tile_idx_to_xy[i]] + tile_id = self._key_to_idx[StitchAndRender3dTiles._tile_idx_to_xy[i]] # pyrefly: ignore[unsupported-operation] volstores[i] = self._open_tile_volume(tile_id) # Bounding boxes representing a single tile placed the origin. @@ -340,7 +340,7 @@ def process( # there are some contrast differences. ret = img ret[norm > 0] /= norm[norm > 0] - ret = ret.astype(self.output_type(subvol.data.dtype)) + ret = ret.astype(self.output_type(subvol.data.dtype)) # pyrefly: ignore[bad-argument-type] return self.crop_box_and_data(box, ret[None, ...]) @@ -486,7 +486,7 @@ def _get_map_for_box( map_vol = self._map_volinfo if self._map_decorator_specs: map_vol = metadata.DecoratedVolume( - path=self._map_volinfo, + path=self._map_volinfo, # pyrefly: ignore[bad-argument-type] decorator_specs=json.dumps(self._map_decorator_specs), ) map_vol = self._open_volume(map_vol) @@ -509,7 +509,7 @@ def _generate_boxes_to_warp(self, data_vol, box: bounding_box.BoundingBox): logging.debug('No map found for %r.', box) return - data_box = map_utils.outer_box(rel_map, map_box, self._source_stride, 1) + data_box = map_utils.outer_box(rel_map, map_box, self._source_stride, 1) # pyrefly: ignore[bad-argument-type] data_box = data_vol.clip_box_to_volume(data_box) if data_box is None or np.any(data_box.size == 0): logging.debug('Data out of bounds for map: %r.', map_box) @@ -534,9 +534,9 @@ def _generate_boxes_to_warp(self, data_vol, box: bounding_box.BoundingBox): subvol_size = np.array(list(-(-box.size[:2] // 2)) + [box.size[2]]) subvol_size = -(-subvol_size // self._downsample) * self._downsample - calc = box_generator.BoxGenerator(box, subvol_size, box_overlap=(0, 0, 0)) + calc = box_generator.BoxGenerator(box, subvol_size, box_overlap=(0, 0, 0)) # pyrefly: ignore[bad-argument-type] for sub_box in calc.boxes: - yield from self._generate_boxes_to_warp(data_vol, sub_box) + yield from self._generate_boxes_to_warp(data_vol, sub_box) # pyrefly: ignore[bad-argument-type] def process(self, subvol: subvolume.Subvolume) -> subvolume.SubvolumeOrMany: box = subvol.bbox @@ -545,7 +545,7 @@ def process(self, subvol: subvolume.Subvolume) -> subvolume.SubvolumeOrMany: data_vol = self._data_volinfo if self._data_decorator_specs: data_vol = metadata.DecoratedVolume( - path=self._data_volinfo, + path=self._data_volinfo, # pyrefly: ignore[bad-argument-type] decorator_specs=json.dumps(self._data_decorator_specs), ) data_vol = self._open_volume(data_vol) @@ -577,7 +577,7 @@ def process(self, subvol: subvolume.Subvolume) -> subvolume.SubvolumeOrMany: # Warp data section-wise. for z in range(warped.shape[1]): curr_box = bounding_box.BoundingBox( - start=box.start + [0, 0, z], size=[box.size[0], box.size[1], 1] + start=box.start + [0, 0, z], size=[box.size[0], box.size[1], 1] # pyrefly: ignore[bad-argument-type] ) logging.debug('warping z=%d', z) @@ -612,7 +612,7 @@ def process(self, subvol: subvolume.Subvolume) -> subvolume.SubvolumeOrMany: svt, warp_box, self._downsample, warped.dtype ) downsampled.append(down_data) - write_box = down_box.translate(-box.start) + write_box = down_box.translate(-box.start) # pyrefly: ignore[bad-argument-type] warped[write_box.to_slice4d()] = np.concatenate( downsampled, axis=0 ).astype(warped.dtype) diff --git a/stitch_elastic.py b/stitch_elastic.py index 930b256..3140a7d 100644 --- a/stitch_elastic.py +++ b/stitch_elastic.py @@ -77,8 +77,8 @@ def _relative_intersection( ) -> tuple[bounding_box.BoundingBox, bounding_box.BoundingBox]: ibox = box1.intersection(box2) return ( - bounding_box.BoundingBox(start=ibox.start - box1.start, size=ibox.size), - bounding_box.BoundingBox(start=ibox.start - box2.start, size=ibox.size), + bounding_box.BoundingBox(start=ibox.start - box1.start, size=ibox.size), # pyrefly: ignore[missing-attribute] + bounding_box.BoundingBox(start=ibox.start - box2.start, size=ibox.size), # pyrefly: ignore[missing-attribute] ) @@ -169,10 +169,10 @@ def compute_flow_map3d( diff = s * np.round(isec_nbor.start[ax] / s) - isec_nbor.start[ax] off[ax] = -diff - nbor_box = nbor_box.translate(off) + nbor_box = nbor_box.translate(off) # pyrefly: ignore[bad-argument-type] isec_curr, isec_nbor = _relative_intersection(curr_box, nbor_box) - assert np.all(isec_curr.start % s == 0) + assert np.all(isec_curr.start % s == 0) # pyrefly: ignore[unbound-name] assert np.all(isec_nbor.start % s == 0) offset = np.array(nbor_box.start - curr_box.start) @@ -521,10 +521,10 @@ def _apply_flow( update = map_utils.compose_maps_fast( # pytype: disable=wrong-arg-types # jnp-type nbor_flow_3d, - start, + start, # pyrefly: ignore[bad-argument-type] stride, nbor_mesh_3d, - jnp.zeros_like(start), + jnp.zeros_like(start), # pyrefly: ignore[bad-argument-type] stride, mode='constant', ) @@ -558,8 +558,8 @@ def _apply_flow( if base_mesh.shape[0] == 3: tg_start_z = jnp.where( - ((mult == 1) & (offset_z < 0)) | ((mult == -1) & (offset_z > 0)), - nbor_mesh.shape[-3] - flow_z, + ((mult == 1) & (offset_z < 0)) | ((mult == -1) & (offset_z > 0)), # pyrefly: ignore[unbound-name] + nbor_mesh.shape[-3] - flow_z, # pyrefly: ignore[unbound-name] 0, ) tg_start = (0, tg_start_z) + tg_start[1:] diff --git a/stitch_rigid.py b/stitch_rigid.py index b268b75..d1c4c8f 100644 --- a/stitch_rigid.py +++ b/stitch_rigid.py @@ -57,7 +57,7 @@ def _estimate_offset( # Apply custom overlap masks if masks is not None: a_mask |= masks[0] - b_mask |= masks[1] + b_mask |= masks[1] # pyrefly: ignore[bad-index] mfc = flow_field.JAXMaskedXCorrWithStatsCalculator() xo, yo, _, pr = mfc.flow_field( @@ -211,7 +211,7 @@ def _is_valid_offset(offset, axis): offset = estimates[max_idx] done = True - if not done or abs(offset[axis]) < min_overlap: + if not done or abs(offset[axis]) < min_overlap: # pyrefly: ignore[unbound-name] offset = np.inf, np.inf return offset diff --git a/warp.py b/warp.py index 38ef98c..be17c9c 100644 --- a/warp.py +++ b/warp.py @@ -179,9 +179,9 @@ def _warp_section(z): # Map IDs back to the original space, which might be beyond the range of # int32. if orig_to_low is not None: - warped = _relabel_segmentation(warped, orig_to_low, old_uids) + warped = _relabel_segmentation(warped, orig_to_low, old_uids) # pyrefly: ignore[unbound-name] else: - warped = warped.astype(orig_dtype) + warped = warped.astype(orig_dtype) # pyrefly: ignore[unbound-name] return warped @@ -276,9 +276,9 @@ def ndimage_warp( out_box = bounding_box.BoundingBox(start=(0, 0, 0), size=image_size_xyz) calc = box_generator.BoxGenerator( - outer_box=bounding_box.BoundingBox(start=(0, 0, 0), size=out_box.size), - box_size=work_size, - box_overlap=overlap, + outer_box=bounding_box.BoundingBox(start=(0, 0, 0), size=out_box.size), # pyrefly: ignore[bad-argument-type] + box_size=work_size, # pyrefly: ignore[bad-argument-type] + box_overlap=overlap, # pyrefly: ignore[bad-argument-type] back_shift_small_boxes=True, ) @@ -315,7 +315,7 @@ def _warp_box(i): # Crop and save data for the current subvolume. out_sub_box = calc.index_to_cropped_box(i) - rel_box = out_sub_box.translate(-in_sub_box.start) + rel_box = out_sub_box.translate(-in_sub_box.start) # pyrefly: ignore[bad-argument-type] warped[out_sub_box.to_slice3d()[sub_dim:]] = sub_warped[ rel_box.to_slice3d()[sub_dim:] @@ -330,7 +330,7 @@ def _warp_box(i): f.result() if orig_to_low is not None: - warped = _relabel_segmentation(warped, orig_to_low, old_uids) + warped = _relabel_segmentation(warped, orig_to_low, old_uids) # pyrefly: ignore[unbound-name] return warped.astype(image.dtype) @@ -474,7 +474,7 @@ def _render_tile(tile_x, tile_y, coord_map): 0, )) out_box = bounding_box.BoundingBox( - start=out_box.start, + start=out_box.start, # pyrefly: ignore[bad-argument-type] size=(tg_box.size[0] * stride[1], tg_box.size[1] * stride[0], 1), )