Skip to content

Commit f9c2034

Browse files
committed
address reviewer feedback
Signed-off-by: tjtanaa <[email protected]>
1 parent afeba57 commit f9c2034

File tree

12 files changed

+214
-95
lines changed

12 files changed

+214
-95
lines changed

fastsafetensors/common.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,14 @@
1313
from .frameworks import FrameworkOpBase, TensorBase
1414
from .st_types import Device, DType
1515

16-
# Add compatibility alias for is_cuda_found -> is_hip_found
17-
# This allows code written for CUDA to work transparently on both CUDA and ROCm
18-
if not hasattr(fstcpp, 'is_cuda_found'):
19-
fstcpp.is_cuda_found = fstcpp.is_hip_found
16+
17+
def is_gpu_found():
18+
"""Check if any GPU (CUDA or HIP) is available.
19+
20+
Returns True if either CUDA or ROCm/HIP GPUs are detected.
21+
This allows code to work transparently across both platforms.
22+
"""
23+
return fstcpp.is_cuda_found() or fstcpp.is_hip_found()
2024

2125

2226
def get_device_numa_node(device: Optional[int]) -> Optional[int]:

fastsafetensors/cpp/cuda_compat.h

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,25 @@
1+
// SPDX-License-Identifier: Apache-2.0
12
/*
23
* Copyright 2024 IBM Inc. All rights reserved
3-
* SPDX-License-Identifier: Apache-2.0
44
*
55
* CUDA/HIP compatibility layer for fastsafetensors
66
* Minimal compatibility header - only defines what hipify-perl doesn't handle
77
*/
88

9-
#pragma once
9+
#ifndef __CUDA_COMPAT_H__
10+
#define __CUDA_COMPAT_H__
1011

1112
// Platform detection - this gets hipified to check __HIP_PLATFORM_AMD__
1213
#ifdef __HIP_PLATFORM_AMD__
1314
#ifndef USE_ROCM
1415
#define USE_ROCM
1516
#endif
16-
#include <hip/hip_runtime.h>
17+
// Note: We do NOT include <hip/hip_runtime.h> here to avoid compile-time dependencies.
18+
// Instead, we dynamically load the ROCm runtime library (libamdhip64.so) at runtime
19+
// using dlopen(), just like we do for CUDA (libcudart.so).
20+
// Minimal types are defined in ext.hpp.
1721
#else
18-
// For CUDA platform, or when CUDA headers aren't available, we define minimal types in ext.hpp
22+
// For CUDA platform, we also avoid including headers and define minimal types in ext.hpp
1923
#endif
2024

2125
// Runtime library name - hipify-perl doesn't change string literals
@@ -31,3 +35,5 @@
3135
#define cudaDeviceMalloc hipDeviceMalloc
3236
#define cudaDeviceFree hipDeviceFree
3337
#endif
38+
39+
#endif // __CUDA_COMPAT_H__

fastsafetensors/cpp/ext.cpp

Lines changed: 38 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,7 @@ ext_funcs_t cpu_fns = ext_funcs_t {
7979
ext_funcs_t cuda_fns;
8080

8181
static bool cuda_found = false;
82+
static bool is_hip_runtime = false; // Track if we loaded HIP (not auto-hipified)
8283
static bool cufile_found = false;
8384

8485
static int cufile_ver = 0;
@@ -123,8 +124,12 @@ static void load_nvidia_functions() {
123124
count = 0; // why cudaGetDeviceCount returns non-zero for errors?
124125
}
125126
cuda_found = count > 0;
127+
// Detect if we loaded HIP runtime (ROCm) vs CUDA runtime
128+
if (cuda_found && std::string(cudartLib).find("hip") != std::string::npos) {
129+
is_hip_runtime = true;
130+
}
126131
if (init_log) {
127-
fprintf(stderr, "[DEBUG] device count=%d, cuda_found=%d\n", count, cuda_found);
132+
fprintf(stderr, "[DEBUG] device count=%d, cuda_found=%d, is_hip_runtime=%d\n", count, cuda_found, is_hip_runtime);
128133
}
129134
} else {
130135
cuda_found = false;
@@ -218,11 +223,28 @@ static void load_nvidia_functions() {
218223
}
219224
}
220225

226+
// Note: is_cuda_found gets auto-hipified to is_hip_found on ROCm builds
227+
// So this function will be is_hip_found() after hipification on ROCm
221228
bool is_cuda_found()
222229
{
223230
return cuda_found;
224231
}
225232

233+
// Separate function that always returns false on ROCm (CUDA not available on ROCm)
234+
// This will be used for the "is_cuda_found" Python export on ROCm builds
235+
bool cuda_not_available()
236+
{
237+
return false; // On ROCm, CUDA is never available
238+
}
239+
240+
// Separate function for checking HIP runtime detection (not hipified)
241+
// On CUDA: checks if HIP runtime was detected
242+
// On ROCm: not used (is_cuda_found gets hipified to is_hip_found)
243+
bool check_hip_runtime()
244+
{
245+
return is_hip_runtime;
246+
}
247+
226248
bool is_cufile_found()
227249
{
228250
return cufile_found;
@@ -719,7 +741,21 @@ cpp_metrics_t get_cpp_metrics() {
719741

720742
PYBIND11_MODULE(__MOD_NAME__, m)
721743
{
722-
m.def("is_cuda_found", &is_cuda_found);
744+
// Export both is_cuda_found and is_hip_found on all platforms
745+
// Use string concatenation to prevent hipify from converting the export names
746+
#ifdef USE_ROCM
747+
// On ROCm after hipify:
748+
// - is_cuda_found() becomes is_hip_found(), so export it as "is_hip_found"
749+
// - Export cuda_not_available() as "is_cuda_found" (CUDA not available on ROCm)
750+
m.def(("is_" "cuda" "_found"), &cuda_not_available); // Returns false on ROCm
751+
m.def(("is_" "hip" "_found"), &is_cuda_found); // hipified to is_hip_found, returns hip status
752+
#else
753+
// On CUDA:
754+
// - is_cuda_found() checks for CUDA
755+
// - check_hip_runtime() checks if HIP runtime was loaded
756+
m.def(("is_" "cuda" "_found"), &is_cuda_found);
757+
m.def(("is_" "hip" "_found"), &check_hip_runtime);
758+
#endif
723759
m.def("is_cufile_found", &is_cufile_found);
724760
m.def("cufile_version", &cufile_version);
725761
m.def("set_debug_log", &set_debug_log);

fastsafetensors/cpp/ext.hpp

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -36,11 +36,10 @@ typedef struct CUfileDescr_t {
3636
} CUfileDescr_t;
3737
typedef struct CUfileError { CUfileOpError err; } CUfileError_t;
3838

39-
// Only define minimal CUDA types if not using ROCm (where real headers are included)
40-
#ifndef USE_ROCM
39+
// Define minimal CUDA/HIP types for both platforms to avoid compile-time dependencies
40+
// We load all GPU functions dynamically at runtime via dlopen()
4141
typedef enum cudaError { cudaSuccess = 0, cudaErrorMemoryAllocation = 2 } cudaError_t;
4242
enum cudaMemcpyKind { cudaMemcpyHostToDevice=2, cudaMemcpyDefault = 4 };
43-
#endif
4443

4544

4645
typedef enum CUfileFeatureFlags {

fastsafetensors/dlpack.py

Lines changed: 29 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -12,26 +12,43 @@
1212
_c_str_dltensor = b"dltensor"
1313

1414

15-
# Detect GPU type at module load time
15+
# Lazy GPU type detection - avoid calling framework-specific code at module load time
16+
_GPU_DEVICE_TYPE = None # Will be detected lazily
17+
18+
1619
def _detect_gpu_type():
17-
"""Detect if we're running on ROCm or CUDA"""
18-
try:
19-
import torch
20-
if torch.cuda.is_available():
21-
# Check if this is ROCm build
22-
if hasattr(torch.version, 'hip') and torch.version.hip is not None:
23-
return 10 # kDLROCM
24-
except:
25-
pass
20+
"""Detect if we're running on ROCm or CUDA.
21+
22+
This detection is now done lazily to avoid framework-specific calls at module load time.
23+
Uses the C++ extension's is_hip_found() to determine the platform.
24+
"""
25+
# Import here to avoid circular dependency
26+
from . import cpp as fstcpp
27+
28+
# Check if we loaded HIP runtime (ROCm)
29+
if fstcpp.is_hip_found():
30+
return 10 # kDLROCM
2631
return 2 # kDLCUDA
2732

2833

29-
_GPU_DEVICE_TYPE = _detect_gpu_type()
34+
def _get_gpu_device_type():
35+
"""Get the GPU device type, detecting it lazily if needed."""
36+
global _GPU_DEVICE_TYPE
37+
if _GPU_DEVICE_TYPE is None:
38+
_GPU_DEVICE_TYPE = _detect_gpu_type()
39+
return _GPU_DEVICE_TYPE
3040

3141

3242
class DLDevice(ctypes.Structure):
3343
def __init__(self, dev: Device):
34-
self.device_type = self.DeviceToDL[dev.type]
44+
# Use lazy detection to get the GPU device type
45+
gpu_type = _get_gpu_device_type()
46+
device_to_dl = {
47+
DeviceType.CPU: self.kDLCPU,
48+
DeviceType.CUDA: gpu_type,
49+
DeviceType.GPU: gpu_type,
50+
}
51+
self.device_type = device_to_dl[dev.type]
3552
self.device_id = dev.index if dev.index is not None else 0
3653

3754
kDLCPU = 1
@@ -42,12 +59,6 @@ def __init__(self, dev: Device):
4259
("device_id", ctypes.c_int),
4360
]
4461

45-
DeviceToDL = {
46-
DeviceType.CPU: kDLCPU,
47-
DeviceType.CUDA: _GPU_DEVICE_TYPE,
48-
DeviceType.GPU: _GPU_DEVICE_TYPE,
49-
}
50-
5162

5263
class c_DLDataType(ctypes.Structure):
5364
def __init__(self, dtype: DType):

fastsafetensors/frameworks/_torch.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -186,9 +186,18 @@ def copy_tensor(self, dst: TorchTensor, src: TorchTensor):
186186
dst.real_tensor.copy_(src.real_tensor)
187187

188188
def get_cuda_ver(self) -> str:
189+
"""Get GPU runtime version with platform indicator.
190+
191+
Returns a string like 'hip-5.7.0' for ROCm or 'cuda-12.1' for CUDA,
192+
or 'none' if no GPU is available. This allows code to distinguish
193+
between different GPU platforms without using torch directly.
194+
"""
189195
if torch.cuda.is_available():
190-
return str(torch.version.cuda)
191-
return "0.0"
196+
# Check if this is ROCm/HIP build
197+
if hasattr(torch.version, "hip") and torch.version.hip is not None:
198+
return f"hip-{torch.version.hip}"
199+
return f"cuda-{torch.version.cuda}"
200+
return "none"
192201

193202
def get_device_ptr_align(self) -> int:
194203
CUDA_PTR_ALIGN: int = 16

fastsafetensors/loader.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from typing import Any, Dict, List, Optional, OrderedDict, Tuple, Union
77

88
from . import cpp as fstcpp
9-
from .common import SafeTensorsMetadata, TensorFrame, get_device_numa_node
9+
from .common import SafeTensorsMetadata, TensorFrame, get_device_numa_node, is_gpu_found
1010
from .file_buffer import FilesBufferOnDevice
1111
from .frameworks import TensorBase, get_framework_op
1212
from .st_types import DeviceType, DType
@@ -69,8 +69,10 @@ def __init__(
6969
gl_set_numa = True
7070
fstcpp.set_debug_log(debug_log)
7171
device_is_not_cpu = self.device.type != DeviceType.CPU
72-
if device_is_not_cpu and not fstcpp.is_cuda_found():
73-
raise Exception("[FAIL] libcudart.so does not exist")
72+
if device_is_not_cpu and not is_gpu_found():
73+
raise Exception(
74+
"[FAIL] GPU runtime library (libcudart.so or libamdhip64.so) does not exist"
75+
)
7476
if not fstcpp.is_cufile_found() and not nogds:
7577
warnings.warn(
7678
"libcufile.so does not exist but nogds is False. use nogds=True",

setup.py

Lines changed: 19 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ def detect_platform():
2727
for path in ["/opt/rocm", "/opt/rocm-*"]:
2828
if "*" in path:
2929
import glob
30+
3031
matches = sorted(glob.glob(path), reverse=True)
3132
if matches:
3233
rocm_path = matches[0]
@@ -45,14 +46,14 @@ def detect_platform():
4546
rocm_version = f.read().strip()
4647
else:
4748
# Try to extract version from path
48-
match = re.search(r'rocm[-/](\d+\.\d+(?:\.\d+)?)', rocm_path)
49+
match = re.search(r"rocm[-/](\d+\.\d+(?:\.\d+)?)", rocm_path)
4950
if match:
5051
rocm_version = match.group(1)
5152

5253
print(f"Detected ROCm platform at {rocm_path}")
5354
if rocm_version:
5455
print(f"ROCm version: {rocm_version}")
55-
return ('rocm', rocm_version, rocm_path)
56+
return ("rocm", rocm_version, rocm_path)
5657

5758
# Check for CUDA
5859
cuda_home = os.environ.get("CUDA_HOME") or os.environ.get("CUDA_PATH")
@@ -64,11 +65,11 @@ def detect_platform():
6465

6566
if cuda_home and os.path.exists(cuda_home):
6667
print(f"Detected CUDA platform at {cuda_home}")
67-
return ('cuda', None, None)
68+
return ("cuda", None, None)
6869

6970
# Default to CUDA if nothing detected
7071
print("No GPU platform detected, defaulting to CUDA")
71-
return ('cuda', None, None)
72+
return ("cuda", None, None)
7273

7374

7475
def hipify_source_files(rocm_path):
@@ -110,7 +111,7 @@ def hipify_source_files(rocm_path):
110111

111112
hipified_files = []
112113
for source_path, result in hipify_result.items():
113-
if hasattr(result, 'hipified_path') and result.hipified_path:
114+
if hasattr(result, "hipified_path") and result.hipified_path:
114115
print(f"Successfully hipified: {source_path} -> {result.hipified_path}")
115116
hipified_files.append(result.hipified_path)
116117

@@ -126,8 +127,9 @@ def hipify_source_files(rocm_path):
126127
return hipified_files
127128

128129

129-
130-
def MyExtension(name, sources, mod_name, platform_type, rocm_path=None, *args, **kwargs):
130+
def MyExtension(
131+
name, sources, mod_name, platform_type, rocm_path=None, *args, **kwargs
132+
):
131133
import pybind11
132134

133135
pybind11_path = os.path.dirname(pybind11.__file__)
@@ -143,7 +145,7 @@ def MyExtension(name, sources, mod_name, platform_type, rocm_path=None, *args, *
143145
kwargs["extra_compile_args"] = ["-fvisibility=hidden", "-std=c++17"]
144146

145147
# Platform-specific configuration
146-
if platform_type == 'rocm' and rocm_path:
148+
if platform_type == "rocm" and rocm_path:
147149
# ROCm/HIP configuration
148150
kwargs["define_macros"].append(("__HIP_PLATFORM_AMD__", "1"))
149151
kwargs["libraries"].append("amdhip64")
@@ -168,7 +170,7 @@ def run(self):
168170
self.rocm_path = rocm_path
169171

170172
# Configure build based on platform
171-
if platform_type == 'rocm' and rocm_path:
173+
if platform_type == "rocm" and rocm_path:
172174
print("=" * 60)
173175
print("Building for AMD ROCm platform")
174176
if rocm_version:
@@ -182,9 +184,14 @@ def run(self):
182184
for ext in self.extensions:
183185
new_sources = []
184186
for src in ext.sources:
185-
if 'fastsafetensors/cpp/ext.cpp' in src:
187+
if "fastsafetensors/cpp/ext.cpp" in src:
186188
# torch.utils.hipify creates files in hip/ subdirectory
187-
new_sources.append(src.replace('fastsafetensors/cpp/ext.cpp', 'fastsafetensors/cpp/hip/ext.cpp'))
189+
new_sources.append(
190+
src.replace(
191+
"fastsafetensors/cpp/ext.cpp",
192+
"fastsafetensors/cpp/hip/ext.cpp",
193+
)
194+
)
188195
else:
189196
new_sources.append(src)
190197
ext.sources = new_sources
@@ -234,6 +241,6 @@ def run(self):
234241
)
235242
],
236243
cmdclass={
237-
'build_ext': CustomBuildExt,
244+
"build_ext": CustomBuildExt,
238245
},
239246
)

tests/conftest.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
from fastsafetensors import SingleGroup
88
from fastsafetensors import cpp as fstcpp
9+
from fastsafetensors.common import is_gpu_found
910
from fastsafetensors.cpp import load_nvidia_functions
1011
from fastsafetensors.frameworks import FrameworkOpBase, get_framework_op
1112
from fastsafetensors.st_types import Device
@@ -14,6 +15,7 @@
1415
TESTS_DIR = os.path.dirname(__file__)
1516
sys.path.insert(0, TESTS_DIR)
1617
from platform_utils import get_platform_info, is_rocm_platform
18+
1719
REPO_ROOT = os.path.dirname(os.path.dirname(TESTS_DIR))
1820
DATA_DIR = os.path.join(REPO_ROOT, ".testdata")
1921
TF_DIR = os.path.join(DATA_DIR, "transformers_cache")
@@ -81,7 +83,7 @@ def pg():
8183

8284
@pytest.fixture(scope="session", autouse=True)
8385
def dev_init() -> None:
84-
if fstcpp.is_cuda_found():
86+
if is_gpu_found():
8587
dev_str = "cuda:0" if FRAMEWORK.get_name() == "pytorch" else "gpu:0"
8688
else:
8789
dev_str = "cpu"

0 commit comments

Comments
 (0)