Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 62 additions & 0 deletions python/sglang/srt/model_loader/weight_validation.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import json
import logging
import os
import re
Expand Down Expand Up @@ -36,6 +37,61 @@ def _validate_safetensors_file(file_path: str) -> bool:
return False


def _check_index_files_exist(snapshot_dir: str) -> Tuple[bool, Optional[str]]:
    """
    Check if all files listed in safetensors index files actually exist on disk.

    This catches cases where the snapshot directory exists but files are missing
    (e.g., due to incomplete downloads or corrupted cache).

    Every ``*.safetensors.index.json`` file directly inside ``snapshot_dir`` is
    parsed and each distinct file name in its ``weight_map`` is looked up
    relative to ``snapshot_dir``. Index files that cannot be read, parsed, or
    that do not have the expected structure are logged and skipped
    (best-effort); they never cause the check to fail.

    Args:
        snapshot_dir: Path to the model snapshot directory

    Returns:
        Tuple of (all_exist, error_message). ``error_message`` is None when
        ``all_exist`` is True; otherwise it names the first index file (in
        sorted order) with missing entries and up to three of the missing
        file names (also in sorted order).

    Raises:
        OSError: If ``snapshot_dir`` itself cannot be listed.
    """
    # Sorted so that the index reported first is stable across runs and
    # platforms (os.listdir returns entries in arbitrary order).
    index_files = sorted(
        f for f in os.listdir(snapshot_dir) if f.endswith(".safetensors.index.json")
    )

    if not index_files:
        # No index files means it's not a sharded model, skip this check
        return True, None

    for index_file in index_files:
        index_path = os.path.join(snapshot_dir, index_file)

        # Keep the try body limited to the I/O and parsing that can fail.
        # json.JSONDecodeError and UnicodeDecodeError are both ValueError.
        try:
            with open(index_path, encoding="utf-8") as f:
                index_data = json.load(f)
        except (OSError, ValueError) as e:
            logger.warning("Failed to read index file %s: %s", index_file, e)
            continue

        if not isinstance(index_data, dict):
            logger.warning(
                "Failed to read index file %s: %s",
                index_file,
                "top-level JSON value is not an object",
            )
            continue

        weight_map = index_data.get("weight_map", {})
        if not weight_map:
            continue
        if not isinstance(weight_map, dict):
            logger.warning(
                "Failed to read index file %s: %s",
                index_file,
                "'weight_map' is not an object",
            )
            continue

        # Deduplicate, then sort so the reported names are deterministic.
        # Non-string values are ignored: they cannot name a file, and dropping
        # them still lets the well-formed entries be verified.
        required_files = sorted(
            {name for name in weight_map.values() if isinstance(name, str)}
        )

        # os.path.exists follows symlinks, so a broken symlink counts as missing
        missing_files = [
            file_name
            for file_name in required_files
            if not os.path.exists(os.path.join(snapshot_dir, file_name))
        ]

        if missing_files:
            return (
                False,
                f"Missing {len(missing_files)} file(s) from index {index_file}: {missing_files[:3]}{'...' if len(missing_files) > 3 else ''}",
            )

    return True, None


def _validate_sharded_model(
snapshot_dir: str, weight_files: List[str]
) -> Tuple[bool, Optional[str], List[str]]:
Expand All @@ -50,6 +106,12 @@ def _validate_sharded_model(
Tuple of (is_valid, error_message, corrupted_files)
- corrupted_files: List of file paths that are corrupted (for selective cleanup)
"""
# First, check if all files from the index actually exist
# This catches missing files that wouldn't be found by glob
index_check_valid, index_error = _check_index_files_exist(snapshot_dir)
if not index_check_valid:
return False, index_error, []

# Pattern for sharded files: model-00001-of-00009.safetensors
shard_pattern = re.compile(r"(.*?)-(\d+)-of-(\d+)\.(safetensors|bin)")

Expand Down
Loading