Skip to content

Commit e9d5b08

Browse files
WanHsuanLinweinbe58Copilot
committed
Subgrid shift (#33)
* add test cases for sub grid shift * add test in test_concrete * add implementation for shift_subgrid_x and shift_subgrid_y * add test for shift_sub_grid_x/y with slice * add implementation for shift_sub_grid_x/y with slice * fix format * Apply suggestion from @Copilot Co-authored-by: Copilot <[email protected]> * Apply suggestion from @Copilot Co-authored-by: Copilot <[email protected]> * Apply suggestion from @Copilot Co-authored-by: Copilot <[email protected]> * fix format --------- Co-authored-by: Phillip Weinberg <[email protected]> Co-authored-by: Copilot <[email protected]>
1 parent 4f626f6 commit e9d5b08

File tree

7 files changed

+301
-0
lines changed

7 files changed

+301
-0
lines changed

src/bloqade/geometry/dialects/grid/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,5 +31,7 @@
3131
Scale as Scale,
3232
Shape as Shape,
3333
Shift as Shift,
34+
ShiftSubgridX as ShiftSubgridX,
35+
ShiftSubgridY as ShiftSubgridY,
3436
)
3537
from .types import Grid as Grid, GridType as GridType

src/bloqade/geometry/dialects/grid/_interface.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717
Scale,
1818
Shape,
1919
Shift,
20+
ShiftSubgridX,
21+
ShiftSubgridY,
2022
)
2123
from .types import Grid
2224

@@ -231,6 +233,38 @@ def shift(grid: Grid[Nx, Ny], x_shift: float, y_shift: float) -> Grid[Nx, Ny]:
231233
...
232234

233235

236+
@_wraps(ShiftSubgridX)
237+
def shift_subgrid_x(
238+
grid: Grid[Nx, Ny], x_indices: ilist.IList[int, typing.Any], x_shift: float
239+
) -> Grid[Nx, Ny]:
240+
"""Shift a sub grid of grid in the x directions.
241+
242+
Args:
243+
grid (Grid): a grid object
244+
x_indices (ilist.IList[int, typing.Any]): a list/ilist of x indices to shift
245+
x_shift (float): shift in the x direction
246+
Returns:
247+
Grid: a new grid object that has been shifted
248+
"""
249+
...
250+
251+
252+
@_wraps(ShiftSubgridY)
253+
def shift_subgrid_y(
254+
grid: Grid[Nx, Ny], y_indices: ilist.IList[int, typing.Any], y_shift: float
255+
) -> Grid[Nx, Ny]:
256+
"""Shift a sub grid of grid in the y directions.
257+
258+
Args:
259+
grid (Grid): a grid object
260+
y_indices (ilist.IList[int, typing.Any]): a list/ilist of y indices to shift
261+
y_shift (float): shift in the y direction
262+
Returns:
263+
Grid: a new grid object that has been shifted
264+
"""
265+
...
266+
267+
234268
@_wraps(Shape)
235269
def shape(grid: Grid) -> tuple[int, int]:
236270
"""Get the shape of a grid.

src/bloqade/geometry/dialects/grid/concrete.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -154,6 +154,32 @@ def shift(
154154

155155
return (grid.shift(x_shift, y_shift),)
156156

157+
@impl(stmts.ShiftSubgridX)
158+
def shift_subgrid_x(
159+
self,
160+
interp: Interpreter,
161+
frame: Frame,
162+
stmt: stmts.ShiftSubgridX,
163+
):
164+
grid = frame.get_casted(stmt.zone, Grid)
165+
x_indices = frame.get_casted(stmt.x_indices, ilist.IList)
166+
x_shift = frame.get_casted(stmt.x_shift, float)
167+
168+
return (grid.shift_subgrid_x(x_indices, x_shift),)
169+
170+
@impl(stmts.ShiftSubgridY)
171+
def shift_subgrid_y(
172+
self,
173+
interp: Interpreter,
174+
frame: Frame,
175+
stmt: stmts.ShiftSubgridY,
176+
):
177+
grid = frame.get_casted(stmt.zone, Grid)
178+
y_indices = frame.get_casted(stmt.y_indices, ilist.IList)
179+
y_shift = frame.get_casted(stmt.y_shift, float)
180+
181+
return (grid.shift_subgrid_y(y_indices, y_shift),)
182+
157183
@impl(stmts.Scale)
158184
def scale(
159185
self,

src/bloqade/geometry/dialects/grid/stmts.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -146,6 +146,34 @@ class Shift(ir.Statement):
146146
result: ir.ResultValue = info.result(GridType[NumX, NumY])
147147

148148

149+
@statement(dialect=dialect)
150+
class ShiftSubgridX(ir.Statement):
151+
name = "shift_subgrid_x"
152+
traits = frozenset({ir.Pure(), lowering.FromPythonCall()})
153+
zone: ir.SSAValue = info.argument(
154+
type=GridType[NumX := types.TypeVar("NumX"), NumY := types.TypeVar("NumY")]
155+
)
156+
x_indices: ir.SSAValue = info.argument(
157+
ilist.IListType[types.Int, types.TypeVar("SubNumX")]
158+
)
159+
x_shift: ir.SSAValue = info.argument(types.Float)
160+
result: ir.ResultValue = info.result(GridType[NumX, NumY])
161+
162+
163+
@statement(dialect=dialect)
164+
class ShiftSubgridY(ir.Statement):
165+
name = "shift_subgrid_y"
166+
traits = frozenset({ir.Pure(), lowering.FromPythonCall()})
167+
zone: ir.SSAValue = info.argument(
168+
type=GridType[NumX := types.TypeVar("NumX"), NumY := types.TypeVar("NumY")]
169+
)
170+
y_indices: ir.SSAValue = info.argument(
171+
ilist.IListType[types.Int, types.TypeVar("SubNumY")]
172+
)
173+
y_shift: ir.SSAValue = info.argument(types.Float)
174+
result: ir.ResultValue = info.result(GridType[NumX, NumY])
175+
176+
149177
@statement(dialect=dialect)
150178
class Scale(ir.Statement):
151179
name = "scale_grid"

src/bloqade/geometry/dialects/grid/types.py

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -375,6 +375,84 @@ def shift(self, x_shift: float, y_shift: float) -> "Grid[NumX, NumY]":
375375
y_init=self.y_init + y_shift if self.y_init is not None else None,
376376
)
377377

378+
def shift_subgrid_x(
379+
self, x_indices: ilist.IList[int, Nx] | slice, x_shift: float
380+
) -> "Grid[NumX, NumY]":
381+
"""Shift a sub grid of grid in the x directions.
382+
383+
Args:
384+
grid (Grid): a grid object
385+
x_indices (float): a list/ilist of x indices to shift
386+
x_shift (float): shift in the x direction
387+
Returns:
388+
Grid: a new grid object that has been shifted
389+
"""
390+
indices = get_indices(len(self.x_spacing) + 1, x_indices)
391+
392+
def shift_x(index):
393+
new_spacing = self.x_spacing[index]
394+
if index in indices and (index + 1) not in indices:
395+
new_spacing -= x_shift
396+
elif index not in indices and (index + 1) in indices:
397+
new_spacing += x_shift
398+
return new_spacing
399+
400+
new_spacing = tuple(shift_x(i) for i in range(len(self.x_spacing)))
401+
402+
assert all(
403+
x >= 0 for x in new_spacing
404+
), "Invalid shift: column order changes after shift."
405+
406+
x_init = self.x_init
407+
if x_init is not None and 0 in indices:
408+
x_init += x_shift
409+
410+
return Grid(
411+
x_spacing=new_spacing,
412+
y_spacing=self.y_spacing,
413+
x_init=x_init,
414+
y_init=self.y_init,
415+
)
416+
417+
def shift_subgrid_y(
418+
self, y_indices: ilist.IList[int, Ny] | slice, y_shift: float
419+
) -> "Grid[NumX, NumY]":
420+
"""Shift a sub grid of grid in the y directions.
421+
422+
Args:
423+
grid (Grid): a grid object
424+
y_indices (float): a list/ilist of y indices to shift
425+
y_shift (float): shift in the y direction
426+
Returns:
427+
Grid: a new grid object that has been shifted
428+
"""
429+
indices = get_indices(len(self.y_spacing) + 1, y_indices)
430+
431+
def shift_y(index):
432+
new_spacing = self.y_spacing[index]
433+
if index in indices and (index + 1) not in indices:
434+
new_spacing -= y_shift
435+
elif index not in indices and (index + 1) in indices:
436+
new_spacing += y_shift
437+
return new_spacing
438+
439+
new_spacing = tuple(shift_y(i) for i in range(len(self.y_spacing)))
440+
441+
assert all(
442+
y >= 0 for y in new_spacing
443+
), "Invalid shift: row order changes after shift."
444+
445+
y_init = self.y_init
446+
if y_init is not None and 0 in indices:
447+
y_init += y_shift
448+
449+
return Grid(
450+
x_spacing=self.x_spacing,
451+
y_spacing=new_spacing,
452+
x_init=self.x_init,
453+
y_init=y_init,
454+
)
455+
378456
def repeat(
379457
self, x_times: int, y_times: int, x_gap: float, y_gap: float
380458
) -> "Grid[NumX, NumY]":

test/grid/test_concrete.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,8 @@ def test_from_ranges(self):
7979
(grid.GetYPos, "y_positions", ()),
8080
(grid.Get, "get", ((1, 0),)),
8181
(grid.Shift, "shift", (1.0, 2.0)),
82+
(grid.ShiftSubgridX, "shift_subgrid_x", (ilist.IList([0]), -1)),
83+
(grid.ShiftSubgridY, "shift_subgrid_y", (ilist.IList([0]), -1)),
8284
(grid.Scale, "scale", (1.0, 2.0)),
8385
(grid.Repeat, "repeat", (1, 2, 0.5, 1.0)),
8486
(grid.GetSubGrid, "get_view", (ilist.IList((0,)), ilist.IList((1,)))),

test/grid/test_types.py

Lines changed: 131 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,137 @@ def test_shift(self):
8888
)
8989
assert shifted_grid.is_equal(expected_grid)
9090

91+
@pytest.mark.parametrize(
92+
"x_indices, x_shift, expected_grid",
93+
[
94+
(
95+
ilist.IList([]),
96+
0,
97+
Grid(
98+
x_spacing=(1, 2, 3),
99+
y_spacing=(4, 5),
100+
x_init=1,
101+
y_init=2,
102+
),
103+
),
104+
(
105+
ilist.IList([0, 1]),
106+
1,
107+
Grid(
108+
x_spacing=(1, 1, 3),
109+
y_spacing=(4, 5),
110+
x_init=2,
111+
y_init=2,
112+
),
113+
),
114+
(
115+
ilist.IList([1]),
116+
1,
117+
Grid(
118+
x_spacing=(2, 1, 3),
119+
y_spacing=(4, 5),
120+
x_init=1,
121+
y_init=2,
122+
),
123+
),
124+
(
125+
ilist.IList([1, 2, 3]),
126+
1,
127+
Grid(
128+
x_spacing=(2, 2, 3),
129+
y_spacing=(4, 5),
130+
x_init=1,
131+
y_init=2,
132+
),
133+
),
134+
(
135+
slice(1, 4, 1),
136+
1,
137+
Grid(
138+
x_spacing=(2, 2, 3),
139+
y_spacing=(4, 5),
140+
x_init=1,
141+
y_init=2,
142+
),
143+
),
144+
(ilist.IList([1]), 3, None),
145+
],
146+
)
147+
def test_shift_subgrid_x(self, x_indices, x_shift, expected_grid):
148+
if expected_grid is None:
149+
with pytest.raises(AssertionError):
150+
shifted_grid = self.grid_obj.shift_subgrid_x(x_indices, x_shift)
151+
return
152+
153+
shifted_grid = self.grid_obj.shift_subgrid_x(x_indices, x_shift)
154+
assert shifted_grid.is_equal(expected_grid)
155+
156+
@pytest.mark.parametrize(
157+
"y_indices, y_shift, expected_grid",
158+
[
159+
(
160+
ilist.IList([]),
161+
0,
162+
Grid(
163+
x_spacing=(1, 2, 3),
164+
y_spacing=(4, 5),
165+
x_init=1,
166+
y_init=2,
167+
),
168+
),
169+
(
170+
ilist.IList([0]),
171+
-1,
172+
Grid(
173+
x_spacing=(1, 2, 3),
174+
y_spacing=(5, 5),
175+
x_init=1,
176+
y_init=1,
177+
),
178+
),
179+
(
180+
ilist.IList([1]),
181+
1,
182+
Grid(
183+
x_spacing=(1, 2, 3),
184+
y_spacing=(5, 4),
185+
x_init=1,
186+
y_init=2,
187+
),
188+
),
189+
(
190+
ilist.IList([0, 2]),
191+
1,
192+
Grid(
193+
x_spacing=(1, 2, 3),
194+
y_spacing=(3, 6),
195+
x_init=1,
196+
y_init=3,
197+
),
198+
),
199+
(
200+
slice(0, 1, 1),
201+
-1,
202+
Grid(
203+
x_spacing=(1, 2, 3),
204+
y_spacing=(5, 5),
205+
x_init=1,
206+
y_init=1,
207+
),
208+
),
209+
(ilist.IList([0]), 5, None),
210+
],
211+
)
212+
def test_shift_subgrid_y(self, y_indices, y_shift, expected_grid):
213+
214+
if expected_grid is None:
215+
with pytest.raises(AssertionError):
216+
shifted_grid = self.grid_obj.shift_subgrid_y(y_indices, y_shift)
217+
return
218+
219+
shifted_grid = self.grid_obj.shift_subgrid_y(y_indices, y_shift)
220+
assert shifted_grid.is_equal(expected_grid)
221+
91222
def test_scale(self):
92223
scaled_grid = self.grid_obj.scale(2, 3)
93224
expected_grid = Grid(

0 commit comments

Comments
 (0)