@@ -263,98 +263,98 @@ function compute_m(
263263 return m
264264end
265265
266- """
267- lambertw0(x)
266+ const MINARG = - inv (Base. MathConstants. e)
268267
269- Compute the principal branch (W₀) of the Lambert W function for x ∈ [-1/e, ∞).
268+ """
269+ _lambertw0_initial_guess(x::T) where {T<:AbstractFloat}
270270
271- The Lambert W function satisfies W(x)·exp(W(x)) = x. This implementation uses
272- Halley's method for fast convergence, typically requiring only 2-3 iterations.
271+ Provide a robust initial guess for the Lambert W₀ function for use in iterative solvers.
273272
274273# Arguments
275- - `x::Real `: Input value, must be ≥ -1/e ≈ -0.36788
274+ - `x::T `: Input value, should be ≥ -1/e
276275
277276# Returns
278- - `W::Float64`: Lambert W₀(x), the principal branch value
277+ - Initial guess for W₀(x)
279278
280279# Algorithm
281- Uses Halley's method with an appropriate initial guess:
282- - For x near -1/e: use series expansion
283- - For x ∈ [-1/e, -0.1]: use fitted approximation
284- - For x ∈ [-0.1, 10]: use log-based approximation
285- - For x > 10: use asymptotic expansion
286-
287- # References
288- Corless et al. (1996) "On the Lambert W function"
280+ - For x > 1: uses log(x) - log(log(x)) approximation
281+ - For x < -0.32 (near -1/e): uses series expansion for accurate convergence near branch point
282+ - For -0.32 ≤ x ≤ 1: uses max(x, -0.3) as a simple starting point
289283"""
290- function lambertw0 (x:: T ) where {T <: Real }
291- # Check domain
292- min_x = - one (T) / T (ℯ)
293- if x < min_x
294- throw (
295- DomainError (
296- x,
297- " Lambert W₀ is only defined for x ≥ -1/e ≈ -0.36788" ,
298- ),
299- )
300- end
301-
302- # Special cases
303- if x == zero (T)
304- return zero (T)
305- elseif abs (x - min_x) < T (1e-10 )
306- return - one (T)
307- end
308-
309- # Choose initial guess based on the region
310- if x < T (- 0.32 ) # Near the branch point -1/e
311- # Series expansion near -1/e
284+ @inline function _lambertw0_initial_guess (x:: T ) where {T <: AbstractFloat }
285+ if x > one (T)
286+ return log (x) - log (max (log (x), T (1e-6 )))
287+ elseif x < T (- 0.32 )
288+ # Near the branch point -1/e, use series expansion
289+ # This handles the singular behavior at x = -1/e where W(x) = -1
312290 p = sqrt (T (2 ) * (T (ℯ) * x + one (T)))
313- w = - one (T) + p - p^ 2 / T (3 ) + p^ 3 * T (11 ) / T (72 )
314- elseif x < zero (T)
315- # For x ∈ [-0.32, 0], use a rational approximation
316- w = x / (one (T) + x) # Simple approximation that's good enough for Halley
317- elseif x < T (2.5 )
318- # For small positive x, start with x as the guess (works well up to ~2.5)
319- w = x * (one (T) - x / T (3 )) # Slightly better than just x
320- elseif x < T (10 )
321- # Log-based approximation (safe since x >= 2.5)
322- l1 = log (x)
323- l2 = log (l1)
324- w = l1 - l2 + l2 / l1
291+ return - one (T) + p - p^ 2 / T (3 ) + p^ 3 * T (11 ) / T (72 )
325292 else
326- # Asymptotic expansion for large x
327- l1 = log (x)
328- l2 = log (l1)
329- w = l1 - l2 + l2 / l1 + l2 * (l2 - T (2 )) / (T (2 ) * l1 * l1)
293+ return max (x, T (- 0.3 ))
330294 end
295+ end
331296
332- # Halley's method refinement (typically converges in 2-3 iterations)
333- for _ in 1 : 10 # Maximum iterations
334- ew = exp (w)
335- wew = w * ew
336- f = wew - x
297+ """
298+ lambertw0(x::T; maxiter::Int = 16) where {T<:AbstractFloat}
337299
338- # Check convergence
339- if abs (f) < T (1e-14 ) * (one (T) + abs (x))
340- break
341- end
300+ Compute the principal branch (W₀) of the Lambert W function for x ∈ [-1/e, ∞).
342301
343- # Halley's method update
344- # w_new = w - f / (f' - f * f'' / (2 * f'))
345- # where f = w*exp(w) - x
346- # f' = exp(w) * (w + 1)
347- # f'' = exp(w) * (w + 2)
348- w1 = w + one (T)
349- denom = ew * w1 - f * (w + T (2 )) / (T (2 ) * w1)
302+ This is a GPU-device-friendly implementation using a fixed number of Halley iterations.
303+ The Lambert W function satisfies W(x)·exp(W(x)) = x.
350304
351- if abs (denom) < T ( 1e-20 )
352- break
353- end
305+ # Arguments
306+ - `x::T`: Input value, must be ≥ -1/e ≈ -0.36788
307+ - `maxiter::Int`: Maximum number of Halley iterations (default: 16)
354308
355- w = w - f / denom
356- end
309+ # Returns
310+ - `W::T`: Lambert W₀(x), the principal branch value, or NaN for invalid inputs
311+
312+ # Algorithm
313+ Uses Halley's method with a fixed number of iterations for GPU compatibility:
314+ - No dynamic memory allocation
315+ - No conditional breaks (runs all iterations)
316+ - Broadcastable for use with CuArrays: lambertw0.(cuarray)
317+
318+ # Device Compatibility
319+ This implementation is designed to work on both CPU and GPU:
320+ - All operations are scalar and supported on CUDA.jl
321+ - No array allocations or dynamic loops
322+ - Type-generic over AbstractFloat (Float32, Float64)
357323
324+ # References
325+ Corless et al. (1996) "On the Lambert W function"
326+ """
327+ @inline function lambertw0 (x:: T ; maxiter:: Int = 16 ) where {T <: AbstractFloat }
328+ if ! (isfinite (x)) || x < T (MINARG)
329+ return T (NaN )
330+ end
331+ w = _lambertw0_initial_guess (x)
332+ for i in 1 : maxiter
333+ ew = exp (w)
334+ f = w * ew - x
335+ # Halley denominator
336+ # Special case: when w ≈ -1, both numerator and denominator approach 0
337+ # This happens at the branch point x = -1/e, where W(-1/e) = -1
338+ w_plus_1 = w + one (T)
339+ if abs (w_plus_1) < eps (T)
340+ # Already at or very near the solution w = -1, no update needed
341+ Δ = zero (T)
342+ else
343+ two_w_plus_2 = T (2 ) * w_plus_1
344+ if abs (two_w_plus_2) < eps (T)
345+ # Near w = -1, use Newton's method instead of Halley
346+ Δ = f / (ew * w_plus_1)
347+ else
348+ denom = ew * w_plus_1 - (w + T (2 )) * f / two_w_plus_2
349+ if abs (denom) < eps (T)
350+ Δ = f / (ew * w_plus_1)
351+ else
352+ Δ = f / denom
353+ end
354+ end
355+ end
356+ w -= Δ
357+ end
358358 return w
359359end
360360
0 commit comments