Skip to content

Commit d71eddb

Browse files
authored
rft: make gamma_inc, gamma_inc_inv GPU-compatible (#514)
- replace `Vector`s by statically sized `Tuple`s - update type restriction in `chepolsum` - use Base.@nif to statically unroll findfirst into if/elseif/else chain at parse time - replace interpolated strings in errors with LazyString - manually inline single recursion step in auxgam; GPU compilers cannot statically prove termination - add tests that checks these functions are inferrable and do not allocation memory
1 parent 539d6f8 commit d71eddb

File tree

6 files changed

+74
-20
lines changed

6 files changed

+74
-20
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "SpecialFunctions"
22
uuid = "276daf66-3868-5448-9aa4-cd146d93841b"
3-
version = "2.7.0"
3+
version = "2.7.1"
44

55
[deps]
66
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"

src/bessel.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ struct AmosException <: Exception
77
end
88

99
function Base.showerror(io::IO, ex::AmosException)
10-
print(io, "AmosException with id $(ex.id): ")
10+
print(io, LazyString("AmosException with id ", ex.id, ": "))
1111
if ex.id == 0
1212
print(io, "normal return, computation complete.")
1313
elseif ex.id == 1

src/expint.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ macro E₁_cf64(x, n::Integer)
5252
end
5353

5454
function E₁_taylor_coefficients(::Type{T}, n::Integer) where {T<:Number}
55-
n < 0 && throw(ArgumentError("$n ≥ 0 is required"))
55+
n < 0 && throw(ArgumentError(LazyString(n, " ≥ 0 is required")))
5656
n == 0 && return T[]
5757
n == 1 && return T[-eulergamma]
5858
# iteratively compute the terms in the series, starting with k=1

src/gamma_inc.jl

Lines changed: 20 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,25 @@
11
using Base.MPFR: ROUNDING_MODE
22
#useful constants
3-
const acc0 = [5.0e-15, 5.0e-7, 5.0e-4] #accuracy options
4-
const big1 = [25.0, 14.0, 10.0]
5-
const e0 = [0.25e-3, 0.25e-1, 0.14]
6-
const x0 = [31.0, 17.0, 9.7]
3+
const acc0 = (5.0e-15, 5.0e-7, 5.0e-4) #accuracy options
4+
const big1 = (25.0, 14.0, 10.0)
5+
const e0 = (0.25e-3, 0.25e-1, 0.14)
6+
const x0 = (31.0, 17.0, 9.7)
77
const alog10 = log(10)
88
const rt2pin = Float64(invsqrt2π)
99
const rtpi = Float64(sqrtπ)
10-
const stirling_coef = [1.996379051590076518221, -0.17971032528832887213e-2, 0.131292857963846713e-4, -0.2340875228178749e-6, 0.72291210671127e-8, -0.3280997607821e-9, 0.198750709010e-10, -0.15092141830e-11, 0.1375340084e-12, -0.145728923e-13, 0.17532367e-14, -0.2351465e-15, 0.346551e-16, -0.55471e-17, 0.9548e-18, -0.1748e-18, 0.332e-19, -0.58e-20]
11-
const auxgam_coef = [-1.013609258009865776949, 0.784903531024782283535e-1, 0.67588668743258315530e-2, -0.12790434869623468120e-2, 0.462939838642739585e-4, 0.43381681744740352e-5, -0.5326872422618006e-6, 0.172233457410539e-7, 0.8300542107118e-9, -0.10553994239968e-9, 0.39415842851e-11, 0.362068537e-13, -0.107440229e-13, 0.5000413e-15, -0.62452e-17, -0.5185e-18, 0.347e-19, -0.9e-21]
10+
const stirling_coef = (1.996379051590076518221, -0.17971032528832887213e-2, 0.131292857963846713e-4, -0.2340875228178749e-6, 0.72291210671127e-8, -0.3280997607821e-9, 0.198750709010e-10, -0.15092141830e-11, 0.1375340084e-12, -0.145728923e-13, 0.17532367e-14, -0.2351465e-15, 0.346551e-16, -0.55471e-17, 0.9548e-18, -0.1748e-18, 0.332e-19, -0.58e-20)
11+
const auxgam_coef = (-1.013609258009865776949, 0.784903531024782283535e-1, 0.67588668743258315530e-2, -0.12790434869623468120e-2, 0.462939838642739585e-4, 0.43381681744740352e-5, -0.5326872422618006e-6, 0.172233457410539e-7, 0.8300542107118e-9, -0.10553994239968e-9, 0.39415842851e-11, 0.362068537e-13, -0.107440229e-13, 0.5000413e-15, -0.62452e-17, -0.5185e-18, 0.347e-19, -0.9e-21)
1212

1313
#----------------COEFFICIENTS FOR TEMME EXPANSION------------------
1414

1515
const d00 = -.333333333333333E+00
16-
const d0 = [.833333333333333E-01, -.148148148148148E-01, .115740740740741E-02, .352733686067019E-03, -.178755144032922E-03, .391926317852244E-04]
16+
const d0 = (.833333333333333E-01, -.148148148148148E-01, .115740740740741E-02, .352733686067019E-03, -.178755144032922E-03, .391926317852244E-04)
1717
const d10 = -.185185185185185E-02
18-
const d1 = [-.347222222222222E-02, .264550264550265E-02, -.990226337448560E-03, .205761316872428E-03]
18+
const d1 = (-.347222222222222E-02, .264550264550265E-02, -.990226337448560E-03, .205761316872428E-03)
1919
const d20 = .413359788359788E-02
20-
const d2 = [-.268132716049383E-02, .771604938271605E-03]
20+
const d2 = (-.268132716049383E-02, .771604938271605E-03)
2121
const d30 = .649434156378601E-03
22-
const d3 =[.229472093621399E-03, -.469189494395256E-03]
22+
const d3 = (.229472093621399E-03, -.469189494395256E-03)
2323
const d40 = -.861888290916712E-03
2424
const d4 = .784039221720067E-03
2525
const d50 = -.336798553366358E-03
@@ -88,11 +88,13 @@ Compute function `g` in ``1/\Gamma(x+1) = 1 + x (x-1) g(x)``, `-1 <= x <= 1`.
8888
"""
8989
function auxgam(x::Float64)
9090
@assert -1 <= x <= 1
91-
if x < 0
92-
return -(1.0 + (1.0 + x)*(1.0 + x)*auxgam(1.0 + x))/(1.0 - x)
91+
xp = ifelse(x < 0, 1 + x, x)
92+
t = 2xp - 1
93+
cheb = chepolsum(t, auxgam_coef)
94+
return if x < 0
95+
- (1 + xp * xp * cheb) / (1 - x)
9396
else
94-
t = 2*x - 1.0
95-
return chepolsum(t, auxgam_coef)
97+
cheb
9698
end
9799
end
98100

@@ -111,7 +113,7 @@ end
111113
112114
Computes a series of Chebyshev Polynomials given by: `a[1]/2 + a[2]T1(x) + .... + a[n]T{n-1}(X)`.
113115
"""
114-
function chepolsum(x::Float64, a::Array{Float64,1})
116+
function chepolsum(x::Float64, a::Tuple{Float64,Vararg{Float64}})
115117
n = length(a)
116118
if n == 1
117119
return a[1]/2.0
@@ -471,7 +473,8 @@ function gamma_inc_asym(a::Float64, x::Float64, ind::Integer)
471473
ts = cumprod(ntuple(i -> (a - i) / x, Val(21)))
472474

473475
# sum the smaller terms directly
474-
first_small_t = something(findfirst(x -> abs(x) < 1.0e-3, ts), 21)
476+
# Unrolled findfirst: finds the first index i in 1:21 where abs(ts[i]) < 1e-3, defaulting to 21
477+
first_small_t = Base.@nif 21 (i -> abs(ts[i]) < 1e-3) (i -> i) (i -> 21)
475478
sm = t = ts[first_small_t]
476479
amn = a - first_small_t
477480
while abs(t) acc
@@ -998,7 +1001,7 @@ end
9981001
# floating point numbers of the same type
9991002
function _gamma_inc_inv(a::T, p::T, q::T) where {T<:Real}
10001003
if p + q != 1
1001-
throw(ArgumentError("p + q must equal one but is $(p + q)"))
1004+
throw(ArgumentError(LazyString("p + q must equal one but is ", p + q)))
10021005
end
10031006

10041007
if iszero(p)

test/gamma_inc.jl

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -277,3 +277,53 @@ end
277277
@test loggamma(7, -300.2) log(gamma(7, -300.2))
278278
@test_throws DomainError loggamma(6, -3.2)
279279
end
280+
281+
@testset "GPU compatibility ($FT)" for FT in (Float64, Float32, Float16)
282+
# Note: This test is a proxy for GPU compatibility by checking that the functions
283+
# are type stable and do not allocate memory. It does not launch any GPU kernels.
284+
@testset "gamma_inc type stability" begin
285+
@test @inferred(gamma_inc(FT(30.0), FT(29.99999), 0)) isa Tuple{FT,FT}
286+
end
287+
@testset "gamma_inc_inv type stability" begin
288+
@test @inferred(gamma_inc_inv(FT(1.0), FT(0.01), FT(0.99))) isa FT
289+
end
290+
291+
@testset "gamma_inc allocations" begin
292+
# `@allocated` checks for allocations for specific code paths
293+
## a >= 1
294+
### gamma_inc_temme_1: simplified Temme expansion
295+
@test iszero((FT -> @allocated(gamma_inc(FT(30.0), FT(29.99999), 0)))(FT))
296+
### gamma_inc_minimax: minimax approximation
297+
@test iszero((FT -> @allocated(gamma_inc(FT(100.0), FT(80.0), 0)))(FT))
298+
### gamma_inc_temme: Temme expansion
299+
@test iszero((FT -> @allocated(gamma_inc(FT(100.0), FT(80.0), 1)))(FT))
300+
### gamma_inc_cf: Continued fraction
301+
@test iszero((FT -> @allocated(gamma_inc(FT(1.7), FT(2.5))))(FT))
302+
### gamma_inc_taylor: Taylor series
303+
@test iszero((FT -> @allocated(gamma_inc(FT(11.1), FT(0.001))))(FT))
304+
### gamma_inc_asym: Asymptotic expansion
305+
@test iszero((FT -> @allocated(gamma_inc(FT(10.0), FT(35.0))))(FT))
306+
### gamma_inc_fsum: Finite sums
307+
@test iszero((FT -> @allocated(gamma_inc(FT(24.0), FT(25))))(FT))
308+
## a==0.5
309+
### erfc
310+
@test iszero((FT -> @allocated(gamma_inc(FT(0.5), FT(0.5))))(FT))
311+
## x < 1.1
312+
### gamma_inc_taylor_x
313+
@test iszero((FT -> @allocated(gamma_inc(FT(0.9), FT(0.8))))(FT))
314+
## else
315+
### gamma_inc_cf
316+
@test iszero((FT -> @allocated(gamma_inc(FT(0.7), FT(2.5))))(FT))
317+
end
318+
319+
@testset "gamma_inc_inv allocations" begin
320+
# `@allocated` checks for allocations for specific code paths
321+
## x0 approximation paths
322+
### gamma_inc_inv_psmall
323+
@test iszero((FT -> @allocated(gamma_inc_inv(FT(1.0), FT(0.01), FT(0.99))))(FT))
324+
### gamma_inc_inv_qsmall
325+
@test iszero((FT -> @allocated(gamma_inc_inv(FT(5.0), FT(0.99), FT(0.01))))(FT))
326+
### gamma_inc_inv_alarge
327+
@test iszero((FT -> @allocated(gamma_inc_inv(FT(50.0), FT(0.3), FT(0.7))))(FT))
328+
end
329+
end

test/qa.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ end
4545
:ROUNDING_MODE, # Base.MPFR
4646
:_fact_table64, # Base
4747
:version, # Base.MPFR
48+
Symbol("@nif"), # Base
4849
(VERSION < v"1.11" ? (:depwarn,) : ())..., # Base
4950
),
5051
) === nothing

0 commit comments

Comments
 (0)