Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 73 additions & 2 deletions lightllm/common/triton_utils/autotuner.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import re
import triton
import orjson
import os
Expand All @@ -11,8 +12,8 @@
from frozendict import frozendict
from lightllm.utils.device_utils import get_current_device_name
from lightllm.utils.log_utils import init_logger
from typing import Callable, Optional, Union, List
from lightllm.utils.envs_utils import get_triton_autotune_level
from typing import Callable, Optional, Tuple, Union, List
from lightllm.utils.envs_utils import get_env_start_args, get_triton_autotune_level
from lightllm.common.kernel_config import KernelConfigs
from lightllm.utils.dist_utils import get_global_world_size, get_global_rank, get_current_rank_in_node

Expand Down Expand Up @@ -218,6 +219,76 @@ def _try_load_cache(self, static_key):
logger.info(f"Loading cached configs for {self.kernel_name} - {static_key}")
with open(cache_file, "rb") as f:
self.cached_configs[static_key] = orjson.loads(f.read())
elif get_env_start_args().enable_kernel_config_fallback:

def parse_triton_version_tag(tag: str) -> Optional[Tuple[int, int, int]]:
"""
Parse "triton_X.Y.Z" or "triton_X.Y" to (X, Y, Z), Z defaults to 0.
Returns None if invalid.
"""
match = re.match(r"^triton_(\d+)\.(\d+)(?:\.(\d+))?$", tag)
if not match:
return None
x, y, z = match.groups()
return (int(x), int(y), int(z) if z is not None else 0)

def version_distance(v1: Tuple[int, int, int], v2: Tuple[int, int, int]) -> int:
"""
Compute weighted distance: major * 1e6 + minor * 1e3 + patch
Ensures lexicographic ordering.
"""
return abs((v1[0] - v2[0]) * 1_000_000 + (v1[1] - v2[1]) * 1_000 + (v1[2] - v2[2]))

current_triton_version = get_triton_version()
current_parsed = parse_triton_version_tag(current_triton_version)
if current_parsed is None:
logger.error("Unable to parse current Triton version. Triton may not be installed properly.")
possible_dirs = [
d
for d in os.listdir(os.path.join(Path(__file__).parent, "autotune_kernel_configs"))
if d.startswith("triton_")
]
possible_dirs.sort()
else:
config_dir = os.path.join(Path(__file__).parent, "autotune_kernel_configs")
possible_dirs = []
for d in os.listdir(config_dir):
if not d.startswith("triton_"):
continue
parsed = parse_triton_version_tag(d)
if parsed is not None:
dist = version_distance(parsed, current_parsed)
possible_dirs.append((dist, d, parsed))
else:
logger.debug(f"Skipping invalid version directory: {d}")
possible_dirs.sort(key=lambda x: x[0])
possible_dirs = [d for _, d, _ in possible_dirs]

loaded = False
for triton_version in possible_dirs:
fallback_cache_file = os.path.join(
Path(__file__).parent,
"autotune_kernel_configs",
triton_version,
get_current_device_name(),
self.kernel_name,
KernelConfigs.get_config_file_name(static_key),
)
if os.path.exists(fallback_cache_file):
try:
logger.warning(
f"Fallback loading cached configs for {self.kernel_name} - {static_key} "
f"from triton version {triton_version} (current: {current_triton_version})"
)
with open(fallback_cache_file, "rb") as f:
self.cached_configs[static_key] = orjson.loads(f.read())
loaded = True
break
except Exception as e:
logger.error(f"Failed to load fallback config from {fallback_cache_file}: {e}")

if not loaded:
logger.info(f"No fallback config found for {self.kernel_name} - {static_key}")
return True

def kernel_warmup(self, static_key, *args, **kwargs):
Expand Down
5 changes: 5 additions & 0 deletions lightllm/server/api_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -332,6 +332,11 @@ def make_argument_parser() -> argparse.ArgumentParser:
action="store_true",
help="""inference backend will use the fa3 attention kernel for prefill and decode""",
)
parser.add_argument(
"--enable_kernel_config_fallback",
action="store_true",
help="""Whether to enable kernel config fallback when triton version is not compatible.""",
)
parser.add_argument(
"--cache_capacity", type=int, default=200, help="cache server capacity for multimodal resources"
)
Expand Down
1 change: 1 addition & 0 deletions lightllm/server/core/objs/start_args_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,3 +131,4 @@ class StartArgs:

# kernel setting
enable_fa3: bool = field(default=False)
enable_kernel_config_fallback: bool = field(default=False)