@@ -14,8 +14,8 @@
 #

 import warnings
+from collections.abc import Callable
 from functools import partial
-from typing import Callable, List, Optional, Tuple, Union

 import torch
 from pytorchvideo.layers.utils import round_width  # usort:skip
@@ -74,12 +74,12 @@ class MultiscaleVisionTransformers(Module):
     def __init__(
         self,
         *,
-        patch_embed: Optional[Module],
+        patch_embed: Module | None,
         cls_positional_encoding: Module,
-        pos_drop: Optional[Module],
+        pos_drop: Module | None,
         blocks: ModuleList,
-        norm_embed: Optional[Module],
-        head: Optional[Module],
+        norm_embed: Module | None,
+        head: Module | None,
     ) -> None:
         """
         Args:
@@ -121,7 +121,7 @@ def forward(self, x: Tensor) -> Tensor:

 def create_multiscale_vision_transformers(
     *,
-    spatial_size: Union[int, Tuple[int, int]],
+    spatial_size: int | tuple[int, int],
     temporal_size: int,
     cls_embed_on: bool = True,
     sep_pos_embed: bool = True,
@@ -131,9 +131,9 @@ def create_multiscale_vision_transformers(
     enable_patch_embed: bool = True,
     input_channels: int = 3,
     patch_embed_dim: int = 96,
-    conv_patch_embed_kernel: Tuple[int] = (3, 7, 7),
-    conv_patch_embed_stride: Tuple[int] = (2, 4, 4),
-    conv_patch_embed_padding: Tuple[int] = (1, 3, 3),
+    conv_patch_embed_kernel: tuple[int] = (3, 7, 7),
+    conv_patch_embed_stride: tuple[int] = (2, 4, 4),
+    conv_patch_embed_padding: tuple[int] = (1, 3, 3),
     enable_patch_embed_norm: bool = False,
     use_2d_patch: bool = False,
     # Attention block config.
@@ -148,15 +148,15 @@ def create_multiscale_vision_transformers(
     depthwise_conv: bool = True,
     bias_on: bool = True,
     separate_qkv: bool = True,
-    embed_dim_mul: Optional[List[List[int]]] = None,
-    atten_head_mul: Optional[List[List[int]]] = None,
+    embed_dim_mul: list[list[int]] | None = None,
+    atten_head_mul: list[list[int]] | None = None,
     dim_mul_in_att: bool = False,
-    pool_q_stride_size: Optional[List[List[int]]] = None,
-    pool_kv_stride_size: Optional[List[List[int]]] = None,
-    pool_kv_stride_adaptive: Optional[Union[int, Tuple[int, int, int]]] = None,
-    pool_kvq_kernel: Optional[Union[int, Tuple[int, int, int]]] = None,
+    pool_q_stride_size: list[list[int]] | None = None,
+    pool_kv_stride_size: list[list[int]] | None = None,
+    pool_kv_stride_adaptive: int | tuple[int, int, int] | None = None,
+    pool_kvq_kernel: int | tuple[int, int, int] | None = None,
     # Head config.
-    head: Optional[Callable] = create_vit_basic_head,
+    head: Callable | None = create_vit_basic_head,
     head_dropout_rate: float = 0.5,
     head_activation: Callable = None,
     head_num_classes: int = 400,
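
For reference, a minimal sketch (not part of this diff) of the annotation style the change adopts: PEP 604 unions (`X | None` instead of `Optional[X]` / `Union[...]`) and PEP 585 builtin generics (`list`, `tuple` instead of `List`, `Tuple`), with `Callable` imported from `collections.abc` rather than `typing`. The function and parameter names below are hypothetical; evaluating `X | Y` in annotations at runtime requires Python 3.10+, or `from __future__ import annotations` on older versions.

```python
# Illustrative sketch only: the same typing modernization applied to a
# standalone function signature.
from collections.abc import Callable


def build_model(
    spatial_size: int | tuple[int, int],           # was: Union[int, Tuple[int, int]]
    embed_dim_mul: list[list[int]] | None = None,  # was: Optional[List[List[int]]]
    head: Callable | None = None,                  # was: Optional[Callable]
) -> None:
    """Hypothetical signature showing the annotation style used in this diff."""
    ...
```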