diff --git a/transformer_engine/pytorch/ops/_common.py b/transformer_engine/pytorch/ops/_common.py
index e21915a5a6..c86cc8cb15 100644
--- a/transformer_engine/pytorch/ops/_common.py
+++ b/transformer_engine/pytorch/ops/_common.py
@@ -6,6 +6,7 @@ from __future__ import annotations
 
 import functools
+import inspect
 from importlib.metadata import PackageNotFoundError, version as get_pkg_version
 from typing import Optional
 
 
diff --git a/transformer_engine/pytorch/ops/fused/backward_grouped_mlp.py b/transformer_engine/pytorch/ops/fused/backward_grouped_mlp.py
index 389dfbc838..4cbc5cb620 100644
--- a/transformer_engine/pytorch/ops/fused/backward_grouped_mlp.py
+++ b/transformer_engine/pytorch/ops/fused/backward_grouped_mlp.py
@@ -44,6 +44,7 @@ def _cudnn_compute_wgrad(
     accumulate: bool,
     wgrad_kernel_fn,
     single_grouped_weight: bool,
+    current_stream=None,
 ):
     """Compute wgrad using the cuDNN CuTe DSL grouped GEMM wgrad kernel.
 
@@ -88,6 +89,7 @@ def _cudnn_compute_wgrad(
             wgrad_dtype=wgrad_tensor.dtype,
             sf_vec_size=MXFP8_BLOCK_SCALING_SIZE,
             accumulate_on_output=accumulate,
+            current_stream=current_stream,
         )
     else:
         # Discrete mode: per-expert wgrad device pointers
@@ -104,6 +106,7 @@ def _cudnn_compute_wgrad(
             wgrad_dtype=wgrad_output[0].dtype,
             sf_vec_size=MXFP8_BLOCK_SCALING_SIZE,
             accumulate_on_output=accumulate,
+            current_stream=current_stream,
         )
 
 
@@ -214,6 +217,7 @@ def _compute_grad_params(
             accumulate=accumulate_into_main_grad,
             wgrad_kernel_fn=cudnn_wgrad_kernel_fn,
             single_grouped_weight=fc_op.single_grouped_weight,
+            current_stream=torch.cuda.current_stream().cuda_stream,
         )
     else:
         gemm_fn = functools.partial(
@@ -295,11 +299,9 @@ def grouped_gemm_quant_kernel(cls) -> Callable:
 @functools.lru_cache(maxsize=None)
 def grouped_gemm_wgrad_kernel(cls) -> Optional[Callable]:
     """CuTe DSL kernel for grouped GEMM wgrad on SM100+.
 
-    Returns ``None`` when the cuDNN front-end package is older than
-    1.23.0.
+    Returns ``None`` when the cuDNN front-end wgrad API is not
+    available or lacks the required wgrad_tensor/wgrad_ptrs params.
    """
-    if not _nvidia_cudnn_frontend_supports_wgrad():
-        return None
     from cudnn import grouped_gemm_wgrad_wrapper_sm100  # pylint: disable=no-name-in-module
     return grouped_gemm_wgrad_wrapper_sm100
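
Note on the availability check: the hard version gate (`_nvidia_cudnn_frontend_supports_wgrad()`, previously keyed to cuDNN front end 1.23.0) is dropped in favor of the signature-based contract described in the new docstring, and `inspect` is now imported in `_common.py`. The sketch below is only a guess at what such a signature probe could look like; the helper name, its placement, and the exact parameter test are assumptions, not the PR's actual code.

```python
import inspect


def _supports_wgrad() -> bool:
    """Hypothetical sketch: detect wgrad support from the kernel's signature."""
    try:
        from cudnn import grouped_gemm_wgrad_wrapper_sm100  # pylint: disable=no-name-in-module
    except ImportError:
        return False
    params = inspect.signature(grouped_gemm_wgrad_wrapper_sm100).parameters
    # Both calling conventions used above must be supported: a single
    # grouped `wgrad_tensor` and discrete per-expert `wgrad_ptrs`.
    return {"wgrad_tensor", "wgrad_ptrs"} <= params.keys()
```

Probing the signature instead of the package version tolerates backports and pre-release builds that expose the new keywords without carrying the 1.23.0 version string.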
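
On the `current_stream` plumbing in `_compute_grad_params`: `torch.cuda.current_stream().cuda_stream` is the raw `cudaStream_t` handle (a plain Python `int`) of the stream PyTorch is currently issuing work on. Threading it down to the CuTe DSL wrapper lets the wgrad kernel launch on that same stream, keeping it ordered with the surrounding PyTorch ops. A minimal standalone illustration (not TE code):

```python
import torch

side_stream = torch.cuda.Stream()
with torch.cuda.stream(side_stream):
    # Inside the context, the "current" stream is side_stream, and
    # .cuda_stream exposes its raw cudaStream_t handle as an int.
    handle = torch.cuda.current_stream().cuda_stream
    assert handle == side_stream.cuda_stream
```

Defaulting `current_stream=None` in `_cudnn_compute_wgrad` keeps the signature backward compatible for callers that do not thread a stream through.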