Skip to content

Commit 1ef500f

Browse files
committed
aitemplate [A] [B] [A] (#1040)
Summary: Pull Request resolved: facebookincubator/AITemplate#1040 Reviewed By: muchulee8 Differential Revision: D81852038
1 parent c695a65 commit 1ef500f

File tree

8 files changed

+51
-55
lines changed

8 files changed

+51
-55
lines changed

python/aitemplate/frontend/nn/patch_embed.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
patch_embed Module.
1717
"""
1818

19-
from typing import Callable, Tuple
19+
from collections.abc import Callable
2020

2121
from aitemplate.compiler import ops
2222
from aitemplate.frontend import Tensor
@@ -60,9 +60,9 @@ def create_conv_patch_embed(
6060
*,
6161
in_channels: int,
6262
out_channels: int,
63-
conv_kernel_size: Tuple[int] = (1, 16, 16),
64-
conv_stride: Tuple[int] = (1, 4, 4),
65-
conv_padding: Tuple[int] = (1, 7, 7),
63+
conv_kernel_size: tuple[int] = (1, 16, 16),
64+
conv_stride: tuple[int] = (1, 4, 4),
65+
conv_padding: tuple[int] = (1, 7, 7),
6666
conv_bias: bool = True,
6767
conv: Callable = Conv3d,
6868
) -> Module:

python/aitemplate/frontend/nn/positional_encoding.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717
"""
1818

1919
import logging
20-
from typing import Tuple
2120

2221
from aitemplate.compiler import ops
2322
from aitemplate.compiler.base import IntImm
@@ -112,7 +111,7 @@ class SpatioTemporalClsPositionalEncoding(Module):
112111
def __init__(
113112
self,
114113
embed_dim: int,
115-
patch_embed_shape: Tuple[int, int, int],
114+
patch_embed_shape: tuple[int, int, int],
116115
sep_pos_embed: bool = False,
117116
has_cls: bool = True,
118117
dtype: str = "float16",
@@ -168,7 +167,7 @@ def __init__(
168167
self.pos_embed_temporal = Parameter(shape=[], dtype=dtype)
169168
self.pos_embed_class = Parameter(shape=[], dtype=dtype)
170169

171-
def patch_embed_shape(self) -> Tuple[int, int, int]:
170+
def patch_embed_shape(self) -> tuple[int, int, int]:
172171
return self._patch_embed_shape
173172

174173
def forward(self, x: Tensor) -> Tensor:

python/aitemplate/frontend/nn/softmax.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,6 @@
1616
softmax Module.
1717
"""
1818

19-
from typing import Optional
20-
2119
from aitemplate.compiler import ops
2220
from aitemplate.frontend.nn.module import Module
2321

@@ -52,7 +50,7 @@ class Softmax(Module):
5250

5351
def __init__(
5452
self,
55-
dim: Optional[int] = None,
53+
dim: int | None = None,
5654
):
5755
super().__init__()
5856
self.dim = dim

python/aitemplate/frontend/nn/vision_transformers.py

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,8 @@
1414
#
1515

1616
import warnings
17+
from collections.abc import Callable
1718
from functools import partial
18-
from typing import Callable, List, Optional, Tuple, Union
1919

2020
import torch
2121
from pytorchvideo.layers.utils import round_width # usort:skip
@@ -74,12 +74,12 @@ class MultiscaleVisionTransformers(Module):
7474
def __init__(
7575
self,
7676
*,
77-
patch_embed: Optional[Module],
77+
patch_embed: Module | None,
7878
cls_positional_encoding: Module,
79-
pos_drop: Optional[Module],
79+
pos_drop: Module | None,
8080
blocks: ModuleList,
81-
norm_embed: Optional[Module],
82-
head: Optional[Module],
81+
norm_embed: Module | None,
82+
head: Module | None,
8383
) -> None:
8484
"""
8585
Args:
@@ -121,7 +121,7 @@ def forward(self, x: Tensor) -> Tensor:
121121

122122
def create_multiscale_vision_transformers(
123123
*,
124-
spatial_size: Union[int, Tuple[int, int]],
124+
spatial_size: int | tuple[int, int],
125125
temporal_size: int,
126126
cls_embed_on: bool = True,
127127
sep_pos_embed: bool = True,
@@ -131,9 +131,9 @@ def create_multiscale_vision_transformers(
131131
enable_patch_embed: bool = True,
132132
input_channels: int = 3,
133133
patch_embed_dim: int = 96,
134-
conv_patch_embed_kernel: Tuple[int] = (3, 7, 7),
135-
conv_patch_embed_stride: Tuple[int] = (2, 4, 4),
136-
conv_patch_embed_padding: Tuple[int] = (1, 3, 3),
134+
conv_patch_embed_kernel: tuple[int] = (3, 7, 7),
135+
conv_patch_embed_stride: tuple[int] = (2, 4, 4),
136+
conv_patch_embed_padding: tuple[int] = (1, 3, 3),
137137
enable_patch_embed_norm: bool = False,
138138
use_2d_patch: bool = False,
139139
# Attention block config.
@@ -148,15 +148,15 @@ def create_multiscale_vision_transformers(
148148
depthwise_conv: bool = True,
149149
bias_on: bool = True,
150150
separate_qkv: bool = True,
151-
embed_dim_mul: Optional[List[List[int]]] = None,
152-
atten_head_mul: Optional[List[List[int]]] = None,
151+
embed_dim_mul: list[list[int]] | None = None,
152+
atten_head_mul: list[list[int]] | None = None,
153153
dim_mul_in_att: bool = False,
154-
pool_q_stride_size: Optional[List[List[int]]] = None,
155-
pool_kv_stride_size: Optional[List[List[int]]] = None,
156-
pool_kv_stride_adaptive: Optional[Union[int, Tuple[int, int, int]]] = None,
157-
pool_kvq_kernel: Optional[Union[int, Tuple[int, int, int]]] = None,
154+
pool_q_stride_size: list[list[int]] | None = None,
155+
pool_kv_stride_size: list[list[int]] | None = None,
156+
pool_kv_stride_adaptive: int | tuple[int, int, int] | None = None,
157+
pool_kvq_kernel: int | tuple[int, int, int] | None = None,
158158
# Head config.
159-
head: Optional[Callable] = create_vit_basic_head,
159+
head: Callable | None = create_vit_basic_head,
160160
head_dropout_rate: float = 0.5,
161161
head_activation: Callable = None,
162162
head_num_classes: int = 400,

python/aitemplate/testing/benchmark_ait.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
# limitations under the License.
1414
#
1515

16-
from typing import Optional
1716

1817
import torch
1918

@@ -61,7 +60,7 @@ def run_module_with_pools(
6160
inputs_pool,
6261
outputs_pool,
6362
num_iters,
64-
stream_ptr: Optional[int] = None,
63+
stream_ptr: int | None = None,
6564
sync: bool = False,
6665
graph_mode: bool = False,
6766
):
@@ -101,7 +100,7 @@ def run_benchmark(
101100
outputs_pool,
102101
num_iters,
103102
num_warmup_iters,
104-
stream: Optional[torch.cuda.Stream] = None,
103+
stream: torch.cuda.Stream | None = None,
105104
sync: bool = False,
106105
graph_mode: bool = False,
107106
):

python/aitemplate/testing/jagged_utils.py

Lines changed: 12 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414
#
1515
import random
1616
from itertools import product
17-
from typing import List, Tuple
1817

1918
import torch
2019

@@ -23,7 +22,7 @@
2322

2423

2524
def _check_offsets(
26-
offsets_list: List[List[int]],
25+
offsets_list: list[list[int]],
2726
) -> None:
2827
offsets_len = len(offsets_list[0])
2928
for offsets in offsets_list:
@@ -36,8 +35,8 @@ def _check_offsets(
3635

3736
def _get_preceding_offset_idx(
3837
idx: int,
39-
offsets: List[int],
40-
) -> Tuple[int, int]:
38+
offsets: list[int],
39+
) -> tuple[int, int]:
4140
result = None
4241
left, right = 0, len(offsets) - 1
4342
while left <= right:
@@ -54,8 +53,8 @@ def _get_preceding_offset_idx(
5453

5554
def _jagged_idx_to_dense_idx(
5655
jagged_idx: int,
57-
offsets_list: List[List[int]],
58-
) -> List[int]:
56+
offsets_list: list[list[int]],
57+
) -> list[int]:
5958
assert jagged_idx < offsets_list[-1][-1]
6059

6160
result = []
@@ -73,8 +72,8 @@ def _jagged_idx_to_dense_idx(
7372

7473
def jagged_to_dense(
7574
jagged: torch.Tensor,
76-
offsets_list: List[torch.Tensor],
77-
dense_shape: List[int],
75+
offsets_list: list[torch.Tensor],
76+
dense_shape: list[int],
7877
padding_value: float = 0.0,
7978
) -> torch.Tensor:
8079
"""
@@ -153,8 +152,8 @@ def jagged_to_dense(
153152

154153

155154
def _dense_idx_to_jagged_idx(
156-
dense_idx: List[int],
157-
offsets_list: List[List[int]],
155+
dense_idx: list[int],
156+
offsets_list: list[list[int]],
158157
) -> int:
159158
assert len(dense_idx) == 1 + len(offsets_list)
160159

@@ -172,7 +171,7 @@ def _dense_idx_to_jagged_idx(
172171

173172
def dense_to_jagged(
174173
dense: torch.Tensor,
175-
offsets_list: List[torch.Tensor],
174+
offsets_list: list[torch.Tensor],
176175
padding_value: float = 0.0,
177176
) -> torch.Tensor:
178177
"""
@@ -375,9 +374,9 @@ def batched_dense_vec_jagged_2d_mul_ref(
375374

376375
def add_jagged_dense_ref(
377376
jagged: torch.Tensor,
378-
offsets_list: List[torch.Tensor],
377+
offsets_list: list[torch.Tensor],
379378
dense: torch.Tensor,
380-
jagged_max_shape: List[int] = None,
379+
jagged_max_shape: list[int] = None,
381380
) -> torch.Tensor:
382381
"""The reference function for jagged / dense elementwise add."""
383382
if jagged_max_shape is None:

python/aitemplate/testing/profile.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,8 @@
1717
"""
1818

1919
import logging
20+
from collections.abc import Callable
2021
from operator import itemgetter
21-
from typing import Callable, List, Tuple
2222

2323
import torch
2424

@@ -29,7 +29,7 @@ def profile_callable(
2929
func: Callable,
3030
cache_flush_slab: torch.Tensor,
3131
n_iter: int,
32-
) -> Tuple[List[int], List[int]]:
32+
) -> tuple[list[int], list[int]]:
3333
"""
3434
Profile the callable and return the device and wall time for each iteration.
3535
We assume the iterations happen sequentially, not concurrently.

python/aitemplate/testing/test_utils.py

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,9 @@
2020
import itertools
2121
import os
2222
import unittest
23+
from collections.abc import Callable
2324
from enum import Enum
24-
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type
25+
from typing import Any
2526

2627
import torch
2728

@@ -54,7 +55,7 @@ def _SM90_filter(method_name: str) -> bool:
5455
return method_name.endswith("sm90")
5556

5657

57-
_TEST_ENV_TO_FILTER_METHOD: Dict[TestEnv, Callable[[str], bool]] = {
58+
_TEST_ENV_TO_FILTER_METHOD: dict[TestEnv, Callable[[str], bool]] = {
5859
TestEnv.CUDA_LESS_THAN_SM80: (
5960
lambda method_name: not (
6061
_SM80_filter(method_name)
@@ -71,7 +72,7 @@ def _SM90_filter(method_name: str) -> bool:
7172
# maps each test env (key) to the set of all test envs compatible with
7273
# it (value). "compatible" means that a tests that can run in *any*
7374
# env in the value Set[TestEnv] can also run in the key TestEnv.
74-
_COMPATIBLE_TEST_ENVS: Dict[TestEnv, Set[TestEnv]] = {
75+
_COMPATIBLE_TEST_ENVS: dict[TestEnv, set[TestEnv]] = {
7576
TestEnv.ROCM: {
7677
TestEnv.ROCM,
7778
},
@@ -122,7 +123,7 @@ def _test_runnable_in_env(test_name: str, env: TestEnv) -> bool:
122123
return False
123124

124125

125-
def filter_test_cases_by_params(params: Dict[TestEnv, List[Tuple[Any]]]):
126+
def filter_test_cases_by_params(params: dict[TestEnv, list[tuple[Any]]]):
126127
"""Filters test cases to run by given params.
127128
128129
The params corresponding to any test env compatible with
@@ -143,7 +144,7 @@ def filter_test_cases_by_params(params: Dict[TestEnv, List[Tuple[Any]]]):
143144
}
144145

145146

146-
def filter_test_cases_by_test_env(cls: Type[unittest.TestCase]):
147+
def filter_test_cases_by_test_env(cls: type[unittest.TestCase]):
147148
"""Filters test cases to run by test case names implicitly.
148149
149150
The test cases filtered by any test env compatible with
@@ -204,19 +205,19 @@ def get_torch_full_tensor(shape, fill_value, dtype="float16"):
204205
)
205206

206207

207-
def has_op(sorted_ops: List[Operator], op_name: str) -> bool:
208+
def has_op(sorted_ops: list[Operator], op_name: str) -> bool:
208209
for op in sorted_ops:
209210
op_type = op._attrs["op"]
210211
if op_type == op_name:
211212
return True
212213
return False
213214

214215

215-
def graph_has_op(graph: List[Tensor], op_name: str) -> bool:
216+
def graph_has_op(graph: list[Tensor], op_name: str) -> bool:
216217
return has_op(get_sorted_ops(graph), op_name)
217218

218219

219-
def count_ops(sorted_ops: List[Operator], op_name: str):
220+
def count_ops(sorted_ops: list[Operator], op_name: str):
220221
count = 0
221222
for op in sorted_ops:
222223
op_type = op._attrs["op"]
@@ -226,7 +227,7 @@ def count_ops(sorted_ops: List[Operator], op_name: str):
226227

227228

228229
def gen_input_tensor(
229-
shape: List[Any], dtype: str = "float16", name: Optional[str] = None
230+
shape: list[Any], dtype: str = "float16", name: str | None = None
230231
) -> Tensor:
231232
tensor = Tensor(
232233
shape=shape,
@@ -252,7 +253,7 @@ def get_src_input(tensor: Tensor) -> Tensor:
252253
return src_op._attrs["inputs"][0]
253254

254255

255-
def get_shape(shape: List[IntVar], dim_to_value_dict: Dict[str, int]):
256+
def get_shape(shape: list[IntVar], dim_to_value_dict: dict[str, int]):
256257
res = [
257258
(
258259
dim.value()
@@ -327,7 +328,7 @@ def benchmark_module(
327328
pt_mod: torch.nn.Module,
328329
ait_mod: AITModule,
329330
iters: int = 100,
330-
permute_inputs: Optional[List[int]] = None,
331+
permute_inputs: list[int] | None = None,
331332
):
332333
input_shape = inputs.size()
333334
batch_size = input_shape[0]

0 commit comments

Comments
 (0)