diff --git a/src/bloqade/geometry/dialects/grid/__init__.py b/src/bloqade/geometry/dialects/grid/__init__.py index 9915d37..69b0d46 100644 --- a/src/bloqade/geometry/dialects/grid/__init__.py +++ b/src/bloqade/geometry/dialects/grid/__init__.py @@ -31,5 +31,7 @@ Scale as Scale, Shape as Shape, Shift as Shift, + ShiftSubgridX as ShiftSubgridX, + ShiftSubgridY as ShiftSubgridY, ) from .types import Grid as Grid, GridType as GridType diff --git a/src/bloqade/geometry/dialects/grid/_interface.py b/src/bloqade/geometry/dialects/grid/_interface.py index e00f847..9e48680 100644 --- a/src/bloqade/geometry/dialects/grid/_interface.py +++ b/src/bloqade/geometry/dialects/grid/_interface.py @@ -17,6 +17,8 @@ Scale, Shape, Shift, + ShiftSubgridX, + ShiftSubgridY, ) from .types import Grid @@ -231,6 +233,38 @@ def shift(grid: Grid[Nx, Ny], x_shift: float, y_shift: float) -> Grid[Nx, Ny]: ... +@_wraps(ShiftSubgridX) +def shift_subgrid_x( + grid: Grid[Nx, Ny], x_indices: ilist.IList[int, typing.Any], x_shift: float +) -> Grid[Nx, Ny]: + """Shift a sub grid of grid in the x directions. + + Args: + grid (Grid): a grid object + x_indices (ilist.IList[int, typing.Any]): a list/ilist of x indices to shift + x_shift (float): shift in the x direction + Returns: + Grid: a new grid object that has been shifted + """ + ... + + +@_wraps(ShiftSubgridY) +def shift_subgrid_y( + grid: Grid[Nx, Ny], y_indices: ilist.IList[int, typing.Any], y_shift: float +) -> Grid[Nx, Ny]: + """Shift a sub grid of grid in the y directions. + + Args: + grid (Grid): a grid object + y_indices (ilist.IList[int, typing.Any]): a list/ilist of y indices to shift + y_shift (float): shift in the y direction + Returns: + Grid: a new grid object that has been shifted + """ + ... + + @_wraps(Shape) def shape(grid: Grid) -> tuple[int, int]: """Get the shape of a grid. diff --git a/src/bloqade/geometry/dialects/grid/concrete.py b/src/bloqade/geometry/dialects/grid/concrete.py index dddcc28..8b0bfb3 100644 --- a/src/bloqade/geometry/dialects/grid/concrete.py +++ b/src/bloqade/geometry/dialects/grid/concrete.py @@ -154,6 +154,32 @@ def shift( return (grid.shift(x_shift, y_shift),) + @impl(stmts.ShiftSubgridX) + def shift_subgrid_x( + self, + interp: Interpreter, + frame: Frame, + stmt: stmts.ShiftSubgridX, + ): + grid = frame.get_casted(stmt.zone, Grid) + x_indices = frame.get_casted(stmt.x_indices, ilist.IList) + x_shift = frame.get_casted(stmt.x_shift, float) + + return (grid.shift_subgrid_x(x_indices, x_shift),) + + @impl(stmts.ShiftSubgridY) + def shift_subgrid_y( + self, + interp: Interpreter, + frame: Frame, + stmt: stmts.ShiftSubgridY, + ): + grid = frame.get_casted(stmt.zone, Grid) + y_indices = frame.get_casted(stmt.y_indices, ilist.IList) + y_shift = frame.get_casted(stmt.y_shift, float) + + return (grid.shift_subgrid_y(y_indices, y_shift),) + @impl(stmts.Scale) def scale( self, diff --git a/src/bloqade/geometry/dialects/grid/stmts.py b/src/bloqade/geometry/dialects/grid/stmts.py index f5f1a00..4fdf2ae 100644 --- a/src/bloqade/geometry/dialects/grid/stmts.py +++ b/src/bloqade/geometry/dialects/grid/stmts.py @@ -146,6 +146,34 @@ class Shift(ir.Statement): result: ir.ResultValue = info.result(GridType[NumX, NumY]) +@statement(dialect=dialect) +class ShiftSubgridX(ir.Statement): + name = "shift_subgrid_x" + traits = frozenset({ir.Pure(), lowering.FromPythonCall()}) + zone: ir.SSAValue = info.argument( + type=GridType[NumX := types.TypeVar("NumX"), NumY := types.TypeVar("NumY")] + ) + x_indices: ir.SSAValue = info.argument( + ilist.IListType[types.Int, types.TypeVar("SubNumX")] + ) + x_shift: ir.SSAValue = info.argument(types.Float) + result: ir.ResultValue = info.result(GridType[NumX, NumY]) + + +@statement(dialect=dialect) +class ShiftSubgridY(ir.Statement): + name = "shift_subgrid_y" + traits = frozenset({ir.Pure(), lowering.FromPythonCall()}) + zone: ir.SSAValue = info.argument( + type=GridType[NumX := types.TypeVar("NumX"), NumY := types.TypeVar("NumY")] + ) + y_indices: ir.SSAValue = info.argument( + ilist.IListType[types.Int, types.TypeVar("SubNumY")] + ) + y_shift: ir.SSAValue = info.argument(types.Float) + result: ir.ResultValue = info.result(GridType[NumX, NumY]) + + @statement(dialect=dialect) class Scale(ir.Statement): name = "scale_grid" diff --git a/src/bloqade/geometry/dialects/grid/types.py b/src/bloqade/geometry/dialects/grid/types.py index b985784..8a2c0f3 100644 --- a/src/bloqade/geometry/dialects/grid/types.py +++ b/src/bloqade/geometry/dialects/grid/types.py @@ -377,6 +377,84 @@ def shift(self, x_shift: float, y_shift: float) -> "Grid[NumX, NumY]": y_init=self.y_init + y_shift if self.y_init is not None else None, ) + def shift_subgrid_x( + self, x_indices: ilist.IList[int, Nx] | slice, x_shift: float + ) -> "Grid[NumX, NumY]": + """Shift a sub grid of grid in the x directions. + + Args: + grid (Grid): a grid object + x_indices (float): a list/ilist of x indices to shift + x_shift (float): shift in the x direction + Returns: + Grid: a new grid object that has been shifted + """ + indices = get_indices(len(self.x_spacing) + 1, x_indices) + + def shift_x(index): + new_spacing = self.x_spacing[index] + if index in indices and (index + 1) not in indices: + new_spacing -= x_shift + elif index not in indices and (index + 1) in indices: + new_spacing += x_shift + return new_spacing + + new_spacing = tuple(shift_x(i) for i in range(len(self.x_spacing))) + + assert all( + x >= 0 for x in new_spacing + ), "Invalid shift: column order changes after shift." + + x_init = self.x_init + if x_init is not None and 0 in indices: + x_init += x_shift + + return Grid( + x_spacing=new_spacing, + y_spacing=self.y_spacing, + x_init=x_init, + y_init=self.y_init, + ) + + def shift_subgrid_y( + self, y_indices: ilist.IList[int, Ny] | slice, y_shift: float + ) -> "Grid[NumX, NumY]": + """Shift a sub grid of grid in the y directions. + + Args: + grid (Grid): a grid object + y_indices (float): a list/ilist of y indices to shift + y_shift (float): shift in the y direction + Returns: + Grid: a new grid object that has been shifted + """ + indices = get_indices(len(self.y_spacing) + 1, y_indices) + + def shift_y(index): + new_spacing = self.y_spacing[index] + if index in indices and (index + 1) not in indices: + new_spacing -= y_shift + elif index not in indices and (index + 1) in indices: + new_spacing += y_shift + return new_spacing + + new_spacing = tuple(shift_y(i) for i in range(len(self.y_spacing))) + + assert all( + y >= 0 for y in new_spacing + ), "Invalid shift: row order changes after shift." + + y_init = self.y_init + if y_init is not None and 0 in indices: + y_init += y_shift + + return Grid( + x_spacing=self.x_spacing, + y_spacing=new_spacing, + x_init=self.x_init, + y_init=y_init, + ) + def repeat( self, x_times: int, y_times: int, x_gap: float, y_gap: float ) -> "Grid[NumX, NumY]": diff --git a/test/grid/test_concrete.py b/test/grid/test_concrete.py index fb0ba36..0619ea5 100644 --- a/test/grid/test_concrete.py +++ b/test/grid/test_concrete.py @@ -81,6 +81,8 @@ def test_from_ranges(self): (grid.GetYPos, "y_positions", ()), (grid.Get, "get", ((1, 0),)), (grid.Shift, "shift", (1.0, 2.0)), + (grid.ShiftSubgridX, "shift_subgrid_x", (ilist.IList([0]), -1)), + (grid.ShiftSubgridY, "shift_subgrid_y", (ilist.IList([0]), -1)), (grid.Scale, "scale", (1.0, 2.0)), (grid.Repeat, "repeat", (1, 2, 0.5, 1.0)), (grid.GetSubGrid, "get_view", (ilist.IList((0,)), ilist.IList((1,)))), diff --git a/test/grid/test_types.py b/test/grid/test_types.py index d1f9a04..0161a2c 100644 --- a/test/grid/test_types.py +++ b/test/grid/test_types.py @@ -88,6 +88,137 @@ def test_shift(self): ) assert shifted_grid.is_equal(expected_grid) + @pytest.mark.parametrize( + "x_indices, x_shift, expected_grid", + [ + ( + ilist.IList([]), + 0, + Grid( + x_spacing=(1, 2, 3), + y_spacing=(4, 5), + x_init=1, + y_init=2, + ), + ), + ( + ilist.IList([0, 1]), + 1, + Grid( + x_spacing=(1, 1, 3), + y_spacing=(4, 5), + x_init=2, + y_init=2, + ), + ), + ( + ilist.IList([1]), + 1, + Grid( + x_spacing=(2, 1, 3), + y_spacing=(4, 5), + x_init=1, + y_init=2, + ), + ), + ( + ilist.IList([1, 2, 3]), + 1, + Grid( + x_spacing=(2, 2, 3), + y_spacing=(4, 5), + x_init=1, + y_init=2, + ), + ), + ( + slice(1, 4, 1), + 1, + Grid( + x_spacing=(2, 2, 3), + y_spacing=(4, 5), + x_init=1, + y_init=2, + ), + ), + (ilist.IList([1]), 3, None), + ], + ) + def test_shift_subgrid_x(self, x_indices, x_shift, expected_grid): + if expected_grid is None: + with pytest.raises(AssertionError): + shifted_grid = self.grid_obj.shift_subgrid_x(x_indices, x_shift) + return + + shifted_grid = self.grid_obj.shift_subgrid_x(x_indices, x_shift) + assert shifted_grid.is_equal(expected_grid) + + @pytest.mark.parametrize( + "y_indices, y_shift, expected_grid", + [ + ( + ilist.IList([]), + 0, + Grid( + x_spacing=(1, 2, 3), + y_spacing=(4, 5), + x_init=1, + y_init=2, + ), + ), + ( + ilist.IList([0]), + -1, + Grid( + x_spacing=(1, 2, 3), + y_spacing=(5, 5), + x_init=1, + y_init=1, + ), + ), + ( + ilist.IList([1]), + 1, + Grid( + x_spacing=(1, 2, 3), + y_spacing=(5, 4), + x_init=1, + y_init=2, + ), + ), + ( + ilist.IList([0, 2]), + 1, + Grid( + x_spacing=(1, 2, 3), + y_spacing=(3, 6), + x_init=1, + y_init=3, + ), + ), + ( + slice(0, 1, 1), + -1, + Grid( + x_spacing=(1, 2, 3), + y_spacing=(5, 5), + x_init=1, + y_init=1, + ), + ), + (ilist.IList([0]), 5, None), + ], + ) + def test_shift_subgrid_y(self, y_indices, y_shift, expected_grid): + + if expected_grid is None: + with pytest.raises(AssertionError): + shifted_grid = self.grid_obj.shift_subgrid_y(y_indices, y_shift) + return + + shifted_grid = self.grid_obj.shift_subgrid_y(y_indices, y_shift) + assert shifted_grid.is_equal(expected_grid) + def test_scale(self): scaled_grid = self.grid_obj.scale(2, 3) expected_grid = Grid(