@@ -27,6 +27,7 @@ def detect_platform():
2727 for path in ["/opt/rocm" , "/opt/rocm-*" ]:
2828 if "*" in path :
2929 import glob
30+
3031 matches = sorted (glob .glob (path ), reverse = True )
3132 if matches :
3233 rocm_path = matches [0 ]
@@ -45,14 +46,14 @@ def detect_platform():
4546 rocm_version = f .read ().strip ()
4647 else :
4748 # Try to extract version from path
48- match = re .search (r' rocm[-/](\d+\.\d+(?:\.\d+)?)' , rocm_path )
49+ match = re .search (r" rocm[-/](\d+\.\d+(?:\.\d+)?)" , rocm_path )
4950 if match :
5051 rocm_version = match .group (1 )
5152
5253 print (f"Detected ROCm platform at { rocm_path } " )
5354 if rocm_version :
5455 print (f"ROCm version: { rocm_version } " )
55- return (' rocm' , rocm_version , rocm_path )
56+ return (" rocm" , rocm_version , rocm_path )
5657
5758 # Check for CUDA
5859 cuda_home = os .environ .get ("CUDA_HOME" ) or os .environ .get ("CUDA_PATH" )
@@ -64,11 +65,11 @@ def detect_platform():
6465
6566 if cuda_home and os .path .exists (cuda_home ):
6667 print (f"Detected CUDA platform at { cuda_home } " )
67- return (' cuda' , None , None )
68+ return (" cuda" , None , None )
6869
6970 # Default to CUDA if nothing detected
7071 print ("No GPU platform detected, defaulting to CUDA" )
71- return (' cuda' , None , None )
72+ return (" cuda" , None , None )
7273
7374
7475def hipify_source_files (rocm_path ):
@@ -110,7 +111,7 @@ def hipify_source_files(rocm_path):
110111
111112 hipified_files = []
112113 for source_path , result in hipify_result .items ():
113- if hasattr (result , ' hipified_path' ) and result .hipified_path :
114+ if hasattr (result , " hipified_path" ) and result .hipified_path :
114115 print (f"Successfully hipified: { source_path } -> { result .hipified_path } " )
115116 hipified_files .append (result .hipified_path )
116117
@@ -126,8 +127,9 @@ def hipify_source_files(rocm_path):
126127 return hipified_files
127128
128129
129-
130- def MyExtension (name , sources , mod_name , platform_type , rocm_path = None , * args , ** kwargs ):
130+ def MyExtension (
131+ name , sources , mod_name , platform_type , rocm_path = None , * args , ** kwargs
132+ ):
131133 import pybind11
132134
133135 pybind11_path = os .path .dirname (pybind11 .__file__ )
@@ -143,7 +145,7 @@ def MyExtension(name, sources, mod_name, platform_type, rocm_path=None, *args, *
143145 kwargs ["extra_compile_args" ] = ["-fvisibility=hidden" , "-std=c++17" ]
144146
145147 # Platform-specific configuration
146- if platform_type == ' rocm' and rocm_path :
148+ if platform_type == " rocm" and rocm_path :
147149 # ROCm/HIP configuration
148150 kwargs ["define_macros" ].append (("__HIP_PLATFORM_AMD__" , "1" ))
149151 kwargs ["libraries" ].append ("amdhip64" )
@@ -168,7 +170,7 @@ def run(self):
168170 self .rocm_path = rocm_path
169171
170172 # Configure build based on platform
171- if platform_type == ' rocm' and rocm_path :
173+ if platform_type == " rocm" and rocm_path :
172174 print ("=" * 60 )
173175 print ("Building for AMD ROCm platform" )
174176 if rocm_version :
@@ -182,9 +184,14 @@ def run(self):
182184 for ext in self .extensions :
183185 new_sources = []
184186 for src in ext .sources :
185- if ' fastsafetensors/cpp/ext.cpp' in src :
187+ if " fastsafetensors/cpp/ext.cpp" in src :
186188 # torch.utils.hipify creates files in hip/ subdirectory
187- new_sources .append (src .replace ('fastsafetensors/cpp/ext.cpp' , 'fastsafetensors/cpp/hip/ext.cpp' ))
189+ new_sources .append (
190+ src .replace (
191+ "fastsafetensors/cpp/ext.cpp" ,
192+ "fastsafetensors/cpp/hip/ext.cpp" ,
193+ )
194+ )
188195 else :
189196 new_sources .append (src )
190197 ext .sources = new_sources
@@ -234,6 +241,6 @@ def run(self):
234241 )
235242 ],
236243 cmdclass = {
237- ' build_ext' : CustomBuildExt ,
244+ " build_ext" : CustomBuildExt ,
238245 },
239246)
0 commit comments