diff --git a/tritonbench/operators/blackwell_attentions/operator.py b/tritonbench/operators/blackwell_attentions/operator.py index 21ddfffd..02da5656 100644 --- a/tritonbench/operators/blackwell_attentions/operator.py +++ b/tritonbench/operators/blackwell_attentions/operator.py @@ -120,7 +120,9 @@ torch.cuda.is_available() and torch.version.cuda and torch.version.cuda >= "12.4" ) -IS_B200 = is_cuda() and "B200" in get_nvidia_gpu_model() +IS_BLACKWELL = is_cuda() and ( + "B200" in get_nvidia_gpu_model() or "B300" in get_nvidia_gpu_model() +) def parse_op_args(args: List[str]): @@ -385,7 +387,7 @@ def xformers_splitk( ) @register_benchmark( - enabled=IS_B200 and _is_sdpa_cudnn_attention_available(), + enabled=IS_BLACKWELL and _is_sdpa_cudnn_attention_available(), label=f"cudnn-sdpa-{torch.backends.cudnn.version()}", ) def cudnn_sdpa(self, q, k, v): @@ -398,7 +400,7 @@ def cudnn_sdpa(self, q, k, v): ) @register_benchmark( - enabled=(IS_B200 and HAS_FLASH_CUTE), label="FAv4", fwd_only=True + enabled=(IS_BLACKWELL and HAS_FLASH_CUTE), label="FAv4", fwd_only=True ) def cutedsl_blackwell( self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor