Skip to content

Commit 858d990

Browse files
authored
make index typing less strict (#42)
* make index typing less strict * IList is Sequence so just use Sequence check * make more generic * Adding test to check error
1 parent 366b61d commit 858d990

File tree

3 files changed

+12
-4
lines changed

3 files changed

+12
-4
lines changed

src/bloqade/geometry/dialects/grid/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,14 @@
11
from ._dialect import dialect as dialect
22
from ._interface import (
3+
col_ypos as col_ypos,
34
from_positions as from_positions,
45
get as get,
56
get_xpos as get_xpos,
67
get_ypos as get_ypos,
78
new as new,
89
positions as positions,
910
repeat as repeat,
11+
row_xpos as row_xpos,
1012
scale as scale,
1113
shape as shape,
1214
shift as shift,

src/bloqade/geometry/dialects/grid/types.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,10 +22,12 @@ def get_indices(size: int, index: Any) -> ilist.IList[int, Any]:
2222
raise IndexError("Index out of range")
2323

2424
return ilist.IList([index])
25-
elif isinstance(index, ilist.IList):
26-
return index
27-
else:
28-
raise TypeError("Index must be an int, slice, or IList")
25+
26+
index = ilist.IList(list(index))
27+
if any(not isinstance(i, int) for i in index.data):
28+
raise TypeError("Index must be an int, slice, or Sequence of ints")
29+
30+
return index
2931

3032

3133
@dataclasses.dataclass

test/grid/test_types.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -219,6 +219,10 @@ def test_shift_subgrid_y(self, y_indices, y_shift, expected_grid):
219219
shifted_grid = self.grid_obj.shift_subgrid_y(y_indices, y_shift)
220220
assert shifted_grid.is_equal(expected_grid)
221221

222+
def test_invalid_slice(self):
223+
with pytest.raises(TypeError):
224+
self.grid_obj[[1.5], ilist.IList([0, 1])] # type: ignore
225+
222226
def test_scale(self):
223227
scaled_grid = self.grid_obj.scale(2, 3)
224228
expected_grid = Grid(

0 commit comments

Comments
 (0)