Skip to content

Commit 961a53b

Browse files
authored
updating api for grid.get_view to be more flexible (#29)
1 parent 171baff commit 961a53b

File tree

2 files changed

+36
-10
lines changed

2 files changed

+36
-10
lines changed

src/bloqade/geometry/dialects/grid/types.py

Lines changed: 34 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -205,19 +205,47 @@ def get(self, idx: tuple[int, int]) -> tuple[float, float]:
205205
Nx = TypeVar("Nx")
206206
Ny = TypeVar("Ny")
207207

208+
@overload
208209
def get_view(
209210
self, x_indices: ilist.IList[int, Nx], y_indices: ilist.IList[int, Ny]
210-
) -> "Grid[Nx, Ny]":
211+
) -> "Grid[Nx, Ny]": ...
212+
213+
@overload
214+
def get_view(
215+
self, x_indices: Sequence[int], y_indices: ilist.IList[int, Ny]
216+
) -> "Grid[Any, Ny]": ...
217+
218+
@overload
219+
def get_view(
220+
self, x_indices: ilist.IList[int, Nx], y_indices: Sequence[int]
221+
) -> "Grid[Nx, Any]": ...
222+
223+
@overload
224+
def get_view(
225+
self, x_indices: Sequence[int], y_indices: Sequence[int]
226+
) -> "Grid[Any, Any]": ...
227+
228+
def get_view(self, x_indices, y_indices) -> "Grid":
211229
"""Get a sub-grid view based on the specified x and y indices.
212230
213231
Args:
214-
x_indices (ilist.IList[int, Nx]): The x indices to include in the sub-grid.
215-
y_indices (ilist.IList[int, Ny]): The y indices to include in the sub-grid.
232+
x_indices (Sequence[int]): The x indices to include in the sub-grid.
233+
y_indices (Sequence[int]): The y indices to include in the sub-grid.
216234
217235
Returns:
218-
Grid[Nx, Ny]: The sub-grid view.
236+
Grid: The sub-grid view.
219237
"""
220-
return SubGrid(parent=self, x_indices=x_indices, y_indices=y_indices)
238+
if isinstance(x_indices, ilist.IList):
239+
x_indices = x_indices.data
240+
241+
if isinstance(y_indices, ilist.IList):
242+
y_indices = y_indices.data
243+
244+
return SubGrid(
245+
parent=self,
246+
x_indices=ilist.IList(x_indices),
247+
y_indices=ilist.IList(y_indices),
248+
)
221249

222250
@overload
223251
def __getitem__(
@@ -428,9 +456,7 @@ def __post_init__(self):
428456
types.Literal(len(self.y_indices)),
429457
)
430458

431-
def get_view(
432-
self, x_indices: ilist.IList[int, Any], y_indices: ilist.IList[int, Any]
433-
):
459+
def get_view(self, x_indices, y_indices):
434460
return self.parent.get_view(
435461
x_indices=ilist.IList([self.x_indices[x_index] for x_index in x_indices]),
436462
y_indices=ilist.IList([self.y_indices[y_index] for y_index in y_indices]),

test/grid/test_typeinfer.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ def test_typeinfer_1():
1313
def test_1(spacing: ilist.IList[float, Literal[2]]):
1414
return grid.new(spacing, [1.0, 2.0], 0.0, 0.0)
1515

16-
assert test_1.return_type.is_equal(
16+
assert test_1.return_type.is_subseteq(
1717
grid.GridType[types.Literal(3), types.Literal(3)]
1818
)
1919

@@ -23,7 +23,7 @@ def test_typeinfer_2():
2323
def test_2(spacing: ilist.IList[float, Any]):
2424
return grid.new(spacing, [1.0, 2.0], 0.0, 0.0)
2525

26-
assert test_2.return_type.is_equal(grid.GridType[types.Any, types.Literal(3)])
26+
assert test_2.return_type.is_subseteq(grid.GridType[types.Any, types.Literal(3)])
2727

2828

2929
def test_typeinfer_get_index_1():

0 commit comments

Comments
 (0)