@@ -27,7 +27,6 @@ def get_attn_backend_cls(
         dtype,
         kv_cache_dtype,
         block_size,
-        use_v1,
         use_mla,
         has_sink,
         use_sparse,
vllm/attention/selector.py (30 additions, 11 deletions)

@@ -1,6 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
+import inspect
 import os
 from collections.abc import Generator
 from contextlib import contextmanager
@@ -141,17 +142,35 @@ def _cached_get_attn_backend(
     # get device-specific attn_backend
     from vllm.platforms import current_platform
 
-    attention_cls = current_platform.get_attn_backend_cls(
-        selected_backend,
-        head_size,
-        dtype,
-        kv_cache_dtype,
-        block_size,
-        True,
-        use_mla,
-        has_sink,
-        use_sparse,
-    )
+    sig = inspect.signature(current_platform.get_attn_backend_cls)
+    if "use_v1" in sig.parameters:
[Collaborator review comment on the line above; a standalone sketch of this pattern follows the hunk]
++! This is really a nice practice to help plugins complete interface changes smoothly. Let's remove the `if "use_v1" in sig.parameters` check in the next release.
+        logger.warning_once(
+            "use_v1 parameter for get_attn_backend_cls is deprecated and will "
+            "be removed in v0.13.0 or v1.0.0, whichever is soonest. Please "
+            "remove it from your plugin code."
+        )
+        attention_cls = current_platform.get_attn_backend_cls(
+            selected_backend,
+            head_size,
+            dtype,
+            kv_cache_dtype,
+            block_size,
+            True,  # use_v1
+            use_mla,
+            has_sink,
+            use_sparse,
+        )
+    else:
+        attention_cls = current_platform.get_attn_backend_cls(
+            selected_backend,
+            head_size,
+            dtype,
+            kv_cache_dtype,
+            block_size,
+            use_mla,
+            has_sink,
+            use_sparse,
+        )
     if not attention_cls:
         raise ValueError(
             f"Invalid attention backend for {current_platform.device_name}"
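The compatibility shim above needs only the standard-library `inspect` module. Below is a minimal, self-contained sketch of the same signature-inspection pattern; the helper and the two example selectors (`call_backend_selector`, `legacy_selector`, `new_selector`) are hypothetical names used for illustration, not part of vLLM.

```python
import inspect
import warnings


def call_backend_selector(get_attn_backend_cls, **kwargs):
    """Call a selector that may or may not still accept the old use_v1 arg."""
    sig = inspect.signature(get_attn_backend_cls)
    if "use_v1" in sig.parameters:
        # Old-style plugin: emit a deprecation warning, then pass the only
        # value the engine supports today.
        warnings.warn(
            "use_v1 is deprecated; remove it from your plugin's signature.",
            DeprecationWarning,
            stacklevel=2,
        )
        return get_attn_backend_cls(use_v1=True, **kwargs)
    # New-style plugin: call without use_v1.
    return get_attn_backend_cls(**kwargs)


def legacy_selector(head_size, use_v1, use_mla):
    # Stand-in for a plugin that has not migrated yet.
    return f"legacy-backend(head_size={head_size}, v1={use_v1}, mla={use_mla})"


def new_selector(head_size, use_mla):
    # Stand-in for a plugin already on the new signature.
    return f"new-backend(head_size={head_size}, mla={use_mla})"


print(call_backend_selector(legacy_selector, head_size=128, use_mla=False))
print(call_backend_selector(new_selector, head_size=128, use_mla=False))
```

Both calls resolve, which is why the engine can drop `use_v1` without immediately breaking out-of-tree platforms that still declare it.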
vllm/platforms/cpu.py (0 additions, 3 deletions)

@@ -131,7 +131,6 @@ def get_attn_backend_cls(
         dtype: torch.dtype,
         kv_cache_dtype: str | None,
         block_size: int,
-        use_v1: bool,
         use_mla: bool,
         has_sink: bool,
         use_sparse: bool,
@@ -144,8 +143,6 @@ def get_attn_backend_cls(
             raise NotImplementedError("MLA is not supported on CPU.")
         if use_sparse:
             raise NotImplementedError("Sparse Attention is not supported on CPU.")
-        if not use_v1:
-            raise ValueError("CPU backend only supports V1.")
         return AttentionBackendEnum.CPU_ATTN.get_path()
 
     @classmethod
vllm/platforms/cuda.py (0 additions, 7 deletions)

@@ -336,17 +336,10 @@ def get_attn_backend_cls(
         dtype: torch.dtype,
         kv_cache_dtype: "CacheDType | None",
         block_size: int | None,
-        use_v1: bool,
         use_mla: bool,
         has_sink: bool,
         use_sparse: bool,
     ) -> str:
-        if not use_v1:
-            raise RuntimeError(
-                "V0 attention backends have been removed. Set VLLM_USE_V1=1 "
-                "to select a supported backend."
-            )
-
         device_capability = cls.get_device_capability()
         assert device_capability is not None
 
vllm/platforms/interface.py (0 additions, 1 deletion)

@@ -215,7 +215,6 @@ def get_attn_backend_cls(
         dtype: torch.dtype,
         kv_cache_dtype: "CacheDType | None",
         block_size: int,
-        use_v1: bool,
         use_mla: bool,
         has_sink: bool,
         use_sparse: bool,
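Since the base `Platform.get_attn_backend_cls` interface above no longer takes `use_v1`, an out-of-tree platform plugin only needs to drop that parameter from its override. A hedged sketch of what such an update might look like; `MyAcceleratorPlatform`, its device name, and the returned backend path are invented for illustration and are not part of vLLM:

```python
import torch

from vllm.platforms.interface import Platform


class MyAcceleratorPlatform(Platform):
    """Hypothetical out-of-tree platform; only the signature mirrors this PR.

    Real plugins also configure additional platform attributes, omitted here.
    """

    device_name = "my_accel"  # made-up device name

    @classmethod
    def get_attn_backend_cls(
        cls,
        selected_backend,
        head_size: int,
        dtype: torch.dtype,
        kv_cache_dtype: str | None,
        block_size: int,
        use_mla: bool,  # use_v1 removed, matching the interface change above
        has_sink: bool,
        use_sparse: bool,
    ) -> str:
        if use_mla or use_sparse:
            raise NotImplementedError("MLA / sparse attention not supported here.")
        # Dotted path to the plugin's attention backend class (invented).
        return "my_accel_plugin.attention.MyAccelAttentionBackend"
```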
vllm/platforms/rocm.py (0 additions, 7 deletions)

@@ -213,7 +212,6 @@ def get_attn_backend_cls(
         dtype,
         kv_cache_dtype,
         block_size,
-        use_v1,
         use_mla,
         has_sink,
         use_sparse,
@@ -224,12 +223,6 @@ def get_attn_backend_cls(
         if use_sparse:
             raise NotImplementedError("Sparse Attention is not supported on ROCm.")
 
-        if not use_v1:
-            raise RuntimeError(
-                "V0 attention backends have been removed. Set VLLM_USE_V1=1 "
-                "to select a supported backend."
-            )
-
         if use_mla:
             if selected_backend is None:
                 selected_backend = (
vllm/platforms/tpu.py (0 additions, 3 deletions)

@@ -58,7 +58,6 @@ def get_attn_backend_cls(
         dtype: torch.dtype,
         kv_cache_dtype: str | None,
         block_size: int,
-        use_v1: bool,
         use_mla: bool,
         has_sink,
         use_sparse,
@@ -70,8 +69,6 @@ def get_attn_backend_cls(
         if selected_backend != AttentionBackendEnum.PALLAS:
             logger.info("Cannot use %s backend on TPU.", selected_backend)
 
-        if not use_v1:
-            raise ValueError("TPU backend only supports V1.")
         logger.info("Using Pallas V1 backend.")
         return AttentionBackendEnum.PALLAS.get_path()
 
vllm/platforms/xpu.py (1 addition, 2 deletions)

@@ -48,7 +48,6 @@ def get_attn_backend_cls(
         dtype: torch.dtype,
         kv_cache_dtype: str | None,
         block_size: int,
-        use_v1: bool,
         use_mla: bool,
         has_sink: bool,
         use_sparse,
@@ -77,7 +76,7 @@ def get_attn_backend_cls(
         elif selected_backend:
             raise ValueError(
                 f"Invalid attention backend for {cls.device_name}, "
-                f"with use_v1: {use_v1} use_mla: {use_mla}"
+                f"with use_mla: {use_mla}"
             )
 
         logger.info("Using Flash Attention backend.")