Skip to content

Commit e74b4b4

Browse files
authored
Fix global hostcall detection (#787)
1 parent cc9f619 commit e74b4b4

File tree

5 files changed

+209
-298
lines changed

5 files changed

+209
-298
lines changed

docs/src/tutorials/perf.md

Lines changed: 0 additions & 93 deletions
Original file line numberDiff line numberDiff line change
@@ -44,96 +44,3 @@ julia> GPUArrays.unsafe_free!(cache)
4444
For a more sophisticated real-world example, see how
4545
[GaussianSplatting.jl](https://github.com/JuliaNeuralGraphics/GaussianSplatting.jl/blob/e4ef1324c187371e336bef875b053023afe7fb2c/src/training.jl#L183)
4646
handles it.
47-
48-
## Avoid triggering Hostcalls
49-
50-
Some functions in the kernel may cause an exception,
51-
capturing the original value of the variable that caused it.
52-
These are usually related to float-to-integer conversion, so functions like
53-
`Int(1.0)`, `ceil(Int, 1.0)`, `floor(Int, 1.0)` will cause it.
54-
55-
This will perform dynamic memory allocation and launch a `Hostcall` for that,
56-
which will sit in a background thread until the kernel finishes execution and the user synchronizes the `stream`.
57-
Having a hostcall unnecessarily slows execution down, and you can avoid that by using
58-
a "GPU-friendly" version of the function.
59-
60-
!!! info "Hostcalls"
61-
Hostcalls should be used mostly for debugging. When performance matters, they should be avoided.
62-
63-
For example, let's see how we may deal with `ceil(Int, x)` and convert it to a GPU-friendly version.
64-
65-
Starting with the bad example:
66-
67-
```jldoctest hostcall
68-
julia> function bad_kernel!(y, x)
69-
@inbounds y[1] = ceil(Int, x[1])
70-
return
71-
end
72-
bad_kernel! (generic function with 1 method)
73-
74-
julia> x = ROCArray(Float32[1.1f0]);
75-
76-
julia> y = ROCArray(zeros(Int, 1));
77-
78-
julia> @roc bad_kernel!(y, x);
79-
┌ Info: Global hostcalls detected!
80-
│ - Source: MethodInstance for bad_kernel!(::AMDGPU.Device.ROCDeviceVector{Int64, 1}, ::AMDGPU.Device.ROCDeviceVector{Float32, 1})
81-
│ - Hostcalls: [:malloc_hostcall]
82-
83-
│ Use `AMDGPU.synchronize(; stop_hostcalls=true)` to synchronize and stop them.
84-
└ Otherwise, performance might degrade if they keep running in the background.
85-
86-
julia> y
87-
1-element ROCArray{Int64, 1, AMDGPU.Runtime.Mem.HIPBuffer}:
88-
2
89-
90-
julia> AMDGPU.synchronize(; stop_hostcalls=true)
91-
[ Info: Stopped global hostcall: `malloc_hostcall`.
92-
```
93-
94-
Here we can see that using the "un-optimized" version of `ceil(Int, x[1])`
95-
causes a `malloc_hostcall` to be launched,
96-
which we then have to stop by passing `stop_hostcalls=true` to the synchronization functions.
97-
98-
We can avoid this by using the "unsafe" version, which avoids checking for errors under the hood.
99-
100-
```jldoctest hostcall
101-
julia> function good_kernel!(y, x)
102-
@inbounds y[1] = unsafe_trunc(Int, ceil(x[1]))
103-
return
104-
end
105-
good_kernel! (generic function with 1 method)
106-
107-
julia> fill!(y, 0);
108-
109-
julia> @roc good_kernel!(y, x);
110-
111-
julia> AMDGPU.synchronize(; stop_hostcalls=true) # Nothing is printed, so no hostcall was launched & stopped.
112-
113-
julia> y
114-
1-element ROCArray{Int64, 1, AMDGPU.Runtime.Mem.HIPBuffer}:
115-
2
116-
```
117-
118-
By doing `ceil(x[1])` first and then "unsafely" converting the `Float32` to `Int`,
119-
we can avoid error checking and hostcalls.
120-
121-
We can compare LLVM IR of the function that converts `Float32` to `Int` to see how they differ:
122-
123-
::: tabs
124-
125-
== unsafe_trunc(Int, 1.0)
126-
127-
```@example good-conversion
128-
using InteractiveUtils
129-
InteractiveUtils.@code_llvm unsafe_trunc(Int, 1.0)
130-
```
131-
132-
== Int(1.0)
133-
134-
```@example bad-conversion
135-
using InteractiveUtils
136-
InteractiveUtils.@code_llvm Int(1.0)
137-
```
138-
139-
:::

src/compiler/codegen.jl

Lines changed: 22 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@ const _hip_compiler_cache = Dict{HIP.HIPDevice, Dict{Any, HIP.HIPFunction}}()
1616

1717
# hash(fun, hash(f, hash(tt))) => HIPKernel
1818
const _kernel_instances = Dict{UInt, Runtime.HIPKernel}()
19+
# UInt (hash(job)) => Vector{Symbol} (global hostcall names)
20+
const _global_hostcalls = Dict{UInt, Vector{Symbol}}()
1921

2022
function compiler_cache(dev::HIP.HIPDevice)
2123
get!(() -> Dict{UInt, Any}(), _hip_compiler_cache, dev)
@@ -34,6 +36,10 @@ function GPUCompiler.link_libraries!(
3436
invoke(GPUCompiler.link_libraries!,
3537
Tuple{CompilerJob{GCNCompilerTarget}, typeof(mod), typeof(undefined_fns)},
3638
job, mod, undefined_fns)
39+
40+
# Detect global hostcalls here, before optimizations & cleanup occur.
41+
_global_hostcalls[hash(job)] = find_global_hostcalls(mod)
42+
3743
# Link only if there are undefined functions.
3844
# Everything else was loaded in `finish_module!` stage.
3945
link_device_libs!(
@@ -189,21 +195,24 @@ function create_executable(obj)
189195
return bin
190196
end
191197

192-
function hipcompile(@nospecialize(job::CompilerJob))
193-
obj, meta = JuliaContext() do ctx
194-
GPUCompiler.compile(:obj, job)
195-
end
196-
197-
entry = LLVM.name(meta.entry)
198-
globals = filter(isextinit, collect(LLVM.globals(meta.ir))) .|> LLVM.name
199-
198+
function find_global_hostcalls(mod::LLVM.Module)
200199
global_hostcall_names = (
201200
:malloc_hostcall, :free_hostcall, :print_hostcall, :printf_hostcall)
201+
202202
global_hostcalls = Symbol[]
203-
for gbl in LLVM.globals(meta.ir), gbl_name in global_hostcall_names
203+
for gbl in LLVM.globals(mod), gbl_name in global_hostcall_names
204204
occursin("__$gbl_name", LLVM.name(gbl)) || continue
205205
push!(global_hostcalls, gbl_name)
206206
end
207+
return global_hostcalls
208+
end
209+
210+
function hipcompile(@nospecialize(job::CompilerJob))
211+
obj, meta = JuliaContext() do ctx
212+
GPUCompiler.compile(:obj, job)
213+
end
214+
215+
global_hostcalls = pop!(_global_hostcalls, hash(job))
207216
if !isempty(global_hostcalls)
208217
@info """Global hostcalls detected!
209218
- Source: $(job.source)
@@ -214,11 +223,13 @@ function hipcompile(@nospecialize(job::CompilerJob))
214223
"""
215224
end
216225

217-
if !isempty(globals)
226+
entry = LLVM.name(meta.entry)
227+
extinit_globals = filter(isextinit, collect(LLVM.globals(meta.ir))) .|> LLVM.name
228+
if !isempty(extinit_globals)
218229
@warn """
219230
HIP backend does not support setting extinit globals.
220231
But kernel `$entry` has following:
221-
$globals
232+
$extinit_globals
222233
223234
Compilation will likely fail.
224235
"""

src/hip/module.jl

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@ mutable struct HIPModule
44
function HIPModule(data)
55
device_synchronize()
66

7-
# TODO use alloc_retry?
87
mod_ref = Ref{hipModule_t}()
98
hipModuleLoadData(mod_ref, data)
109
mod = new(mod_ref[])
@@ -21,9 +20,7 @@ struct HIPFunction
2120
mod::HIPModule
2221
global_hostcalls::Vector{Symbol}
2322

24-
function HIPFunction(
25-
mod::HIPModule, name::String, global_hostcalls::Vector{Symbol},
26-
)
23+
function HIPFunction(mod::HIPModule, name::String, global_hostcalls::Vector{Symbol})
2724
fun_ref = Ref{hipFunction_t}()
2825
hipModuleGetFunction(fun_ref, mod, name)
2926
new(fun_ref[], mod, global_hostcalls)
@@ -32,9 +29,7 @@ end
3229

3330
Base.unsafe_convert(::Type{hipFunction_t}, fun::HIPFunction) = fun.handle
3431

35-
function launch_configuration(
36-
fun::HIPFunction; shmem::Integer = 0, max_block_size::Integer = 0,
37-
)
32+
function launch_configuration(fun::HIPFunction; shmem::Integer = 0, max_block_size::Integer = 0)
3833
grid_size_ref, block_size_ref = Ref{Cint}(), Ref{Cint}()
3934
hipModuleOccupancyMaxPotentialBlockSize(
4035
grid_size_ref, block_size_ref, fun, shmem, max_block_size)

0 commit comments

Comments
 (0)