8 changes: 4 additions & 4 deletions python/aitemplate/frontend/nn/patch_embed.py
@@ -16,7 +16,7 @@
patch_embed Module.
"""

-from typing import Callable, Tuple
+from collections.abc import Callable

from aitemplate.compiler import ops
from aitemplate.frontend import Tensor
@@ -60,9 +60,9 @@ def create_conv_patch_embed(
*,
in_channels: int,
out_channels: int,
-conv_kernel_size: Tuple[int] = (1, 16, 16),
-conv_stride: Tuple[int] = (1, 4, 4),
-conv_padding: Tuple[int] = (1, 7, 7),
+conv_kernel_size: tuple[int] = (1, 16, 16),
+conv_stride: tuple[int] = (1, 4, 4),
+conv_padding: tuple[int] = (1, 7, 7),
conv_bias: bool = True,
conv: Callable = Conv3d,
) -> Module:
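The hunks above, like the rest of this diff, apply one mechanical change: generics imported from `typing` (`Tuple`, `Callable`) are replaced by the built-in `tuple` generic (PEP 585) and `collections.abc.Callable`. A minimal sketch of the before/after, assuming Python 3.9 or newer at runtime; the function and defaults below are illustrative only, not code from the PR:

```python
# Hypothetical illustration of the PEP 585 / collections.abc pattern in this diff.
from collections.abc import Callable  # replaces typing.Callable


def make_patch_embed_config(
    conv_kernel_size: tuple[int, int, int] = (1, 16, 16),  # was typing.Tuple[...]
    conv_stride: tuple[int, int, int] = (1, 4, 4),
    conv: Callable = dict,  # placeholder callable; the real code passes Conv3d
) -> dict:
    # Built-in generics (tuple[...], list[...], dict[...]) are subscriptable at
    # runtime from Python 3.9 on, so the typing aliases are no longer needed.
    return {"kernel": conv_kernel_size, "stride": conv_stride, "factory": conv}
```

Note that the annotations themselves are carried over unchanged: `Tuple[int]` becomes `tuple[int]`, which still describes a one-element tuple even though the defaults have three elements; `tuple[int, int, int]` or `tuple[int, ...]` would be the stricter spelling.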
5 changes: 2 additions & 3 deletions python/aitemplate/frontend/nn/positional_encoding.py
@@ -17,7 +17,6 @@
"""

import logging
-from typing import Tuple

from aitemplate.compiler import ops
from aitemplate.compiler.base import IntImm
@@ -112,7 +111,7 @@ class SpatioTemporalClsPositionalEncoding(Module):
def __init__(
self,
embed_dim: int,
-patch_embed_shape: Tuple[int, int, int],
+patch_embed_shape: tuple[int, int, int],
sep_pos_embed: bool = False,
has_cls: bool = True,
dtype: str = "float16",
@@ -168,7 +167,7 @@ def __init__(
self.pos_embed_temporal = Parameter(shape=[], dtype=dtype)
self.pos_embed_class = Parameter(shape=[], dtype=dtype)

-def patch_embed_shape(self) -> Tuple[int, int, int]:
+def patch_embed_shape(self) -> tuple[int, int, int]:
return self._patch_embed_shape

def forward(self, x: Tensor) -> Tensor:
4 changes: 1 addition & 3 deletions python/aitemplate/frontend/nn/softmax.py
@@ -16,8 +16,6 @@
softmax Module.
"""

-from typing import Optional

from aitemplate.compiler import ops
from aitemplate.frontend.nn.module import Module

@@ -52,7 +50,7 @@ class Softmax(Module):

def __init__(
self,
-dim: Optional[int] = None,
+dim: int | None = None,
):
super().__init__()
self.dim = dim
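Later files in the diff also convert `Optional[...]` and `Union[...]` to PEP 604 `|` unions, as in `dim: int | None` above. A small stand-in example (not the aitemplate class), assuming Python 3.10 or newer, since a `|` union in an annotation is evaluated when the function is defined; on 3.9 and earlier it needs `from __future__ import annotations` to defer that evaluation:

```python
# Hypothetical stand-in showing Optional[int] -> int | None (PEP 604).
class SoftmaxConfig:
    def __init__(self, dim: int | None = None):
        # `int | None` replaces typing.Optional[int]; the annotation is
        # evaluated at definition time, hence the Python 3.10+ requirement
        # (or deferred annotations on older interpreters).
        self.dim = dim


config = SoftmaxConfig(dim=-1)  # e.g. an explicit softmax axis
```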
32 changes: 16 additions & 16 deletions python/aitemplate/frontend/nn/vision_transformers.py
@@ -14,8 +14,8 @@
#

import warnings
+from collections.abc import Callable
from functools import partial
-from typing import Callable, List, Optional, Tuple, Union

import torch
from pytorchvideo.layers.utils import round_width # usort:skip
@@ -74,12 +74,12 @@ class MultiscaleVisionTransformers(Module):
def __init__(
self,
*,
-patch_embed: Optional[Module],
+patch_embed: Module | None,
cls_positional_encoding: Module,
-pos_drop: Optional[Module],
+pos_drop: Module | None,
blocks: ModuleList,
-norm_embed: Optional[Module],
-head: Optional[Module],
+norm_embed: Module | None,
+head: Module | None,
) -> None:
"""
Args:
@@ -121,7 +121,7 @@ def forward(self, x: Tensor) -> Tensor:

def create_multiscale_vision_transformers(
*,
-spatial_size: Union[int, Tuple[int, int]],
+spatial_size: int | tuple[int, int],
temporal_size: int,
cls_embed_on: bool = True,
sep_pos_embed: bool = True,
@@ -131,9 +131,9 @@
enable_patch_embed: bool = True,
input_channels: int = 3,
patch_embed_dim: int = 96,
-conv_patch_embed_kernel: Tuple[int] = (3, 7, 7),
-conv_patch_embed_stride: Tuple[int] = (2, 4, 4),
-conv_patch_embed_padding: Tuple[int] = (1, 3, 3),
+conv_patch_embed_kernel: tuple[int] = (3, 7, 7),
+conv_patch_embed_stride: tuple[int] = (2, 4, 4),
+conv_patch_embed_padding: tuple[int] = (1, 3, 3),
enable_patch_embed_norm: bool = False,
use_2d_patch: bool = False,
# Attention block config.
@@ -148,15 +148,15 @@
depthwise_conv: bool = True,
bias_on: bool = True,
separate_qkv: bool = True,
-embed_dim_mul: Optional[List[List[int]]] = None,
-atten_head_mul: Optional[List[List[int]]] = None,
+embed_dim_mul: list[list[int]] | None = None,
+atten_head_mul: list[list[int]] | None = None,
dim_mul_in_att: bool = False,
-pool_q_stride_size: Optional[List[List[int]]] = None,
-pool_kv_stride_size: Optional[List[List[int]]] = None,
-pool_kv_stride_adaptive: Optional[Union[int, Tuple[int, int, int]]] = None,
-pool_kvq_kernel: Optional[Union[int, Tuple[int, int, int]]] = None,
+pool_q_stride_size: list[list[int]] | None = None,
+pool_kv_stride_size: list[list[int]] | None = None,
+pool_kv_stride_adaptive: int | tuple[int, int, int] | None = None,
+pool_kvq_kernel: int | tuple[int, int, int] | None = None,
# Head config.
-head: Optional[Callable] = create_vit_basic_head,
+head: Callable | None = create_vit_basic_head,
head_dropout_rate: float = 0.5,
head_activation: Callable = None,
head_num_classes: int = 400,
5 changes: 2 additions & 3 deletions python/aitemplate/testing/benchmark_ait.py
@@ -13,7 +13,6 @@
# limitations under the License.
#

-from typing import Optional

import torch

@@ -61,7 +60,7 @@ def run_module_with_pools(
inputs_pool,
outputs_pool,
num_iters,
-stream_ptr: Optional[int] = None,
+stream_ptr: int | None = None,
sync: bool = False,
graph_mode: bool = False,
):
@@ -101,7 +100,7 @@ def run_benchmark(
outputs_pool,
num_iters,
num_warmup_iters,
-stream: Optional[torch.cuda.Stream] = None,
+stream: torch.cuda.Stream | None = None,
sync: bool = False,
graph_mode: bool = False,
):
25 changes: 12 additions & 13 deletions python/aitemplate/testing/jagged_utils.py
@@ -14,7 +14,6 @@
#
import random
from itertools import product
-from typing import List, Tuple

import torch

@@ -23,7 +22,7 @@


def _check_offsets(
-offsets_list: List[List[int]],
+offsets_list: list[list[int]],
) -> None:
offsets_len = len(offsets_list[0])
for offsets in offsets_list:
@@ -36,8 +35,8 @@

def _get_preceding_offset_idx(
idx: int,
-offsets: List[int],
-) -> Tuple[int, int]:
+offsets: list[int],
+) -> tuple[int, int]:
result = None
left, right = 0, len(offsets) - 1
while left <= right:
@@ -54,8 +53,8 @@

def _jagged_idx_to_dense_idx(
jagged_idx: int,
-offsets_list: List[List[int]],
-) -> List[int]:
+offsets_list: list[list[int]],
+) -> list[int]:
assert jagged_idx < offsets_list[-1][-1]

result = []
@@ -73,8 +72,8 @@

def jagged_to_dense(
jagged: torch.Tensor,
-offsets_list: List[torch.Tensor],
-dense_shape: List[int],
+offsets_list: list[torch.Tensor],
+dense_shape: list[int],
padding_value: float = 0.0,
) -> torch.Tensor:
"""
@@ -153,8 +152,8 @@


def _dense_idx_to_jagged_idx(
-dense_idx: List[int],
-offsets_list: List[List[int]],
+dense_idx: list[int],
+offsets_list: list[list[int]],
) -> int:
assert len(dense_idx) == 1 + len(offsets_list)

@@ -172,7 +171,7 @@

def dense_to_jagged(
dense: torch.Tensor,
-offsets_list: List[torch.Tensor],
+offsets_list: list[torch.Tensor],
padding_value: float = 0.0,
) -> torch.Tensor:
"""
@@ -375,9 +374,9 @@ def batched_dense_vec_jagged_2d_mul_ref(

def add_jagged_dense_ref(
jagged: torch.Tensor,
-offsets_list: List[torch.Tensor],
+offsets_list: list[torch.Tensor],
dense: torch.Tensor,
-jagged_max_shape: List[int] = None,
+jagged_max_shape: list[int] = None,
) -> torch.Tensor:
"""The reference function for jagged / dense elementwise add."""
if jagged_max_shape is None:
4 changes: 2 additions & 2 deletions python/aitemplate/testing/profile.py
@@ -17,8 +17,8 @@
"""

import logging
+from collections.abc import Callable
from operator import itemgetter
-from typing import Callable, List, Tuple

import torch

@@ -29,7 +29,7 @@ def profile_callable(
func: Callable,
cache_flush_slab: torch.Tensor,
n_iter: int,
-) -> Tuple[List[int], List[int]]:
+) -> tuple[list[int], list[int]]:
"""
Profile the callable and return the device and wall time for each iteration.
We assume the iterations happen sequentially, not concurrently.
23 changes: 12 additions & 11 deletions python/aitemplate/testing/test_utils.py
@@ -20,8 +20,9 @@
import itertools
import os
import unittest
+from collections.abc import Callable
from enum import Enum
-from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type
+from typing import Any

import torch

@@ -54,7 +55,7 @@ def _SM90_filter(method_name: str) -> bool:
return method_name.endswith("sm90")


-_TEST_ENV_TO_FILTER_METHOD: Dict[TestEnv, Callable[[str], bool]] = {
+_TEST_ENV_TO_FILTER_METHOD: dict[TestEnv, Callable[[str], bool]] = {
TestEnv.CUDA_LESS_THAN_SM80: (
lambda method_name: not (
_SM80_filter(method_name)
@@ -71,7 +72,7 @@ def _SM90_filter(method_name: str) -> bool:
# maps each test env (key) to the set of all test envs compatible with
# it (value). "compatible" means that a tests that can run in *any*
# env in the value Set[TestEnv] can also run in the key TestEnv.
-_COMPATIBLE_TEST_ENVS: Dict[TestEnv, Set[TestEnv]] = {
+_COMPATIBLE_TEST_ENVS: dict[TestEnv, set[TestEnv]] = {
TestEnv.ROCM: {
TestEnv.ROCM,
},
@@ -122,7 +123,7 @@ def _test_runnable_in_env(test_name: str, env: TestEnv) -> bool:
return False


-def filter_test_cases_by_params(params: Dict[TestEnv, List[Tuple[Any]]]):
+def filter_test_cases_by_params(params: dict[TestEnv, list[tuple[Any]]]):
"""Filters test cases to run by given params.

The params corresponding to any test env compatible with
@@ -143,7 +144,7 @@ def filter_test_cases_by_params(params: Dict[TestEnv, List[Tuple[Any]]]):
}


-def filter_test_cases_by_test_env(cls: Type[unittest.TestCase]):
+def filter_test_cases_by_test_env(cls: type[unittest.TestCase]):
"""Filters test cases to run by test case names implicitly.

The test cases filtered by any test env compatible with
@@ -204,19 +205,19 @@ def get_torch_full_tensor(shape, fill_value, dtype="float16"):
)


-def has_op(sorted_ops: List[Operator], op_name: str) -> bool:
+def has_op(sorted_ops: list[Operator], op_name: str) -> bool:
for op in sorted_ops:
op_type = op._attrs["op"]
if op_type == op_name:
return True
return False


-def graph_has_op(graph: List[Tensor], op_name: str) -> bool:
+def graph_has_op(graph: list[Tensor], op_name: str) -> bool:
return has_op(get_sorted_ops(graph), op_name)


-def count_ops(sorted_ops: List[Operator], op_name: str):
+def count_ops(sorted_ops: list[Operator], op_name: str):
count = 0
for op in sorted_ops:
op_type = op._attrs["op"]
@@ -226,7 +227,7 @@ def count_ops(sorted_ops: List[Operator], op_name: str):


def gen_input_tensor(
-shape: List[Any], dtype: str = "float16", name: Optional[str] = None
+shape: list[Any], dtype: str = "float16", name: str | None = None
) -> Tensor:
tensor = Tensor(
shape=shape,
@@ -252,7 +253,7 @@ def get_src_input(tensor: Tensor) -> Tensor:
return src_op._attrs["inputs"][0]


-def get_shape(shape: List[IntVar], dim_to_value_dict: Dict[str, int]):
+def get_shape(shape: list[IntVar], dim_to_value_dict: dict[str, int]):
res = [
(
dim.value()
@@ -327,7 +328,7 @@ def benchmark_module(
pt_mod: torch.nn.Module,
ait_mod: AITModule,
iters: int = 100,
-permute_inputs: Optional[List[int]] = None,
+permute_inputs: list[int] | None = None,
):
input_shape = inputs.size()
batch_size = input_shape[0]