diff --git a/transformer_engine/pytorch/attention/dot_product_attention/backends.py b/transformer_engine/pytorch/attention/dot_product_attention/backends.py
index 19da8ebffe..4978aafb84 100644
--- a/transformer_engine/pytorch/attention/dot_product_attention/backends.py
+++ b/transformer_engine/pytorch/attention/dot_product_attention/backends.py
@@ -1486,9 +1486,11 @@ def backward(ctx, d_out, *_args):
         rest = [None]
         if ctx.use_FAv2_bwd:
             softmax_lse, rng_state = aux_ctx_tensors
-            dq = torch.empty_like(q)
-            dk = torch.empty_like(k)
-            dv = torch.empty_like(v)
+            # THD + CP fix: zeros_like ensures padded positions start from safe values,
+            # preventing garbage from propagating through backward gradient accumulation.
+            dq = torch.zeros_like(q)
+            dk = torch.zeros_like(k)
+            dv = torch.zeros_like(v)
             d_out, q, k, v, out = [dpa_utils.maybe_contiguous(x) for x in (d_out, q, k, v, out)]
             # from transformer_engine.pytorch.attention.dot_product_attention import flash_attn_cuda_bwd
             flash_attn_cuda_bwd(
diff --git a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py
index 64cccaac6e..f664677f6f 100644
--- a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py
+++ b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py
@@ -2044,6 +2044,17 @@ def forward(
 
         nvtx_range_pop(f"{nvtx_label}")
 
+        # THD CUDA Graph: zero-fill output at padded positions after CP assembly.
+        # cu_seqlens_q_padded is GLOBAL; divide by cp_size to get local actual_T.
+        if qkv_format == "thd" and out_ret is not None and hasattr(out_ret, "shape"):
+            import torch as _torch
+
+            _local_aT = cu_seqlens_q_padded[-1] // cp_size
+            if out_ret.shape[0] > 0:
+                _m = _torch.arange(out_ret.shape[0], device=out_ret.device) >= _local_aT
+                out_ret.data[_m] = 0
+                out.data[_m.view(-1, *([1] * (out.dim() - 1))).expand_as(out)] = 0
+
         if return_max_logit:
             return out_ret, max_logit
         return out_ret
@@ -2680,10 +2691,17 @@ def backward(ctx, dout, *_args):
             dim = ctx.qkv_format.index("s")
             dq, dk, dv = [x.view(*x.shape[:dim], -1, *x.shape[dim + 2 :]) for x in [dq, dk, dv]]
 
+        # THD CUDA Graph fix: reading cu_seqlens[-1] as a Python index triggers
+        # GPU->CPU sync during graph capture. Use .shape[0] instead when capturing.
         if ctx.qkv_format == "thd" and not ctx.use_fused_attention:
-            dq[cu_seqlens_q_padded[-1] :].fill_(0)
-            dk[cu_seqlens_kv_padded[-1] :].fill_(0)
-            dv[cu_seqlens_kv_padded[-1] :].fill_(0)
+            if torch.cuda.is_current_stream_capturing():
+                _q_end, _kv_end = dq.shape[0], dk.shape[0]
+            else:
+                _q_end = cu_seqlens_q_padded[-1]
+                _kv_end = cu_seqlens_kv_padded[-1]
+            dq[_q_end:].fill_(0)
+            dk[_kv_end:].fill_(0)
+            dv[_kv_end:].fill_(0)
 
         if ctx.fp8 and ctx.is_input_fp8:
             dq, dk, dv = combine_and_quantize(qkv_layout, dq, dk, dv, ctx.dQKV_quantizer)
@@ -2731,6 +2749,16 @@ def backward(ctx, dout, *_args):
 
         nvtx_range_pop(f"{nvtx_label}")
 
+        # THD CUDA Graph: zero-fill dQ/dK/dV at padded positions after CP backward.
+        if ctx.qkv_format == "thd":
+            import torch as _torch
+
+            _local_aT_bwd = cu_seqlens_q_padded[-1] // get_distributed_world_size(ctx.cp_group)
+            for _dg in [dq, dk, dv]:
+                if _dg is not None and hasattr(_dg, "shape") and _dg.shape[0] > 0:
+                    _mb = _torch.arange(_dg.shape[0], device=_dg.device) >= _local_aT_bwd
+                    _dg[_mb] = 0
+
         return (
             None,
             dq,
diff --git a/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py b/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py
index 588c708e10..ef48921849 100644
--- a/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py
+++ b/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py
@@ -1330,13 +1330,10 @@ def forward(
 
         # check if there is padding between sequences when qkv_format='thd'
         if pad_between_seqs is None:
             if qkv_format == "thd":
-                pad_between_seqs = (
-                    cu_seqlens_q_padded is not None
-                    and not torch.equal(cu_seqlens_q_padded[:-1], cu_seqlens_q[:-1])
-                ) or (
-                    cu_seqlens_kv_padded is not None
-                    and not torch.equal(cu_seqlens_kv_padded[:-1], cu_seqlens_kv[:-1])
-                )
+                # THD + CUDA Graph fix: torch.equal() triggers GPU->CPU sync,
+                # which is forbidden during CUDA graph capture.
+                # pad_between_seqs=True is always safe for THD with padded cu_seqlens.
+                pad_between_seqs = True
             else:
                 pad_between_seqs = False
diff --git a/transformer_engine/pytorch/cpp_extensions/fused_attn.py b/transformer_engine/pytorch/cpp_extensions/fused_attn.py
index 06bfb6ef3c..31781cfd3b 100644
--- a/transformer_engine/pytorch/cpp_extensions/fused_attn.py
+++ b/transformer_engine/pytorch/cpp_extensions/fused_attn.py
@@ -351,6 +351,15 @@ def fused_attn_fwd(
         cuda_graph,
     )
 
+    # THD CUDA Graph: zero-fill output at positions beyond cu_seqlens[-1].
+    # Uses pure CUDA ops (no CPU sync) for CUDA graph capture compatibility.
+    if qkv_layout in ("t3hd", "th3d", "thd_t2hd", "thd_th2d", "thd_thd_thd"):
+        _out = output_tensors[0]
+        _aT_fwd = cu_seqlens_q[-1]
+        if _out.shape[0] > 0:
+            _m_fwd = torch.arange(_out.shape[0], device=_out.device) >= _aT_fwd
+            _out[_m_fwd] = 0
+
     if return_max_logit:
         qkv_format = qkv_layout.replace("3", "").replace("2", "").split("_")[0]
         # thd (newer cuDNN runtimes, non-sm120): output_tensors: out [tq, h, d], Stats [tq, h, 1], Max [tq, h, 1]
@@ -607,4 +616,12 @@ def fused_attn_bwd(
         cuda_graph,
     )
 
+    # THD CUDA Graph: zero-fill dQ/dK/dV at positions beyond cu_seqlens[-1].
+    if qkv_layout in ("t3hd", "th3d", "thd_t2hd", "thd_th2d", "thd_thd_thd"):
+        _aT_bwd = cu_seqlens_q[-1]
+        for _dt in output_tensors[:3]:
+            if hasattr(_dt, "shape") and _dt.shape[0] > 0:
+                _m_bwd = torch.arange(_dt.shape[0], device=_dt.device) >= _aT_bwd
+                _dt[_m_bwd] = 0
+
    return output_tensors
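
Note: the hunks above repeat one sync-free pattern for zeroing THD positions past the last valid token. A minimal standalone sketch of that pattern, assuming a [total_tokens, ...]-shaped tensor and a cu_seqlens tensor on the same CUDA device (the function name and shapes are illustrative, not Transformer Engine API):

    import torch

    def zero_fill_padded(out: torch.Tensor, cu_seqlens: torch.Tensor) -> None:
        # cu_seqlens[-1] is the number of valid tokens; keep it as a 0-d CUDA
        # tensor (no .item() or Python indexing on the CPU) so no GPU->CPU sync
        # is introduced while a CUDA graph is being captured.
        valid_tokens = cu_seqlens[-1]
        mask = torch.arange(out.shape[0], device=out.device) >= valid_tokens
        # Broadcast the 1-D token mask over the trailing dims and zero in place.
        out.masked_fill_(mask.view(-1, *([1] * (out.dim() - 1))), 0)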