Skip to content

Commit 6cc898b

Browse files
committed
Initial commit for override
1 parent d33e25c commit 6cc898b

5 files changed

Lines changed: 219 additions & 32 deletions

File tree

Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ GPUCompiler = "61eb1bfa-7361-4325-ad38-22787b887f55"
99
LLVM = "929cbde3-209d-540e-8aea-75f648917ca0"
1010
Libdl = "8f399da3-3557-5675-b5ff-fb832c97cbdb"
1111
Serialization = "9e88b42a-f829-5b0c-bbe9-9e923198166b"
12+
StaticTools = "86c06d3c-3f03-46de-9781-57580aa96d0a"
1213

1314
[compat]
1415
GPUCompiler = "0.16"

src/StaticCompiler.jl

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,13 @@ using Libdl: Libdl, dlsym, dlopen
88
using Base: RefValue
99
using Serialization: serialize, deserialize
1010
using Clang_jll: clang
11+
using StaticTools: @symbolcall, @c_str
1112

1213
export compile, load_function, compile_shlib, compile_executable
1314
export native_code_llvm, native_code_typed, native_llvm_module, native_code_native
1415

16+
include("utils.jl")
17+
include("quirks.jl")
1518
include("target.jl")
1619
include("pointer_patching.jl")
1720
include("code_loading.jl")
@@ -95,8 +98,13 @@ function compile(f, _tt, path::String = tempname(); name = GPUCompiler.safe_nam
9598
isconcretetype(rt) || error("$f on $_tt did not infer to a concrete type. Got $rt")
9699

97100
f_wrap!(out::Ref, args::Ref{<:Tuple}) = (out[] = f(args[]...); nothing)
98-
_, _, table = generate_obj(f_wrap!, Tuple{RefValue{rt}, RefValue{tt}}, path, name; opt_level, strip_llvm, strip_asm, filename, kwargs...)
99-
101+
@eval GPUCompiler.method_table(@nospecialize(job::GPUCompiler.CompilerJob{<:Any,StaticCompilerParams})) = nothing
102+
@eval GPUCompiler.method_table(@nospecialize(job::GPUCompiler.CompilerJob{NativeCompilerTarget})) = nothing
103+
@eval GPUCompiler.method_table(@nospecialize(job::GPUCompiler.CompilerJob{NativeCompilerTarget, StaticCompilerParams})) = nothing
104+
_, _, table = generate_obj(f_wrap!, Tuple{RefValue{rt}, RefValue{tt}}, path, name; opt_level, strip_llvm, strip_asm, filename, ext = false, kwargs...)
105+
@eval GPUCompiler.method_table(@nospecialize(job::GPUCompiler.CompilerJob{<:Any,StaticCompilerParams})) = method_table
106+
@eval GPUCompiler.method_table(@nospecialize(job::GPUCompiler.CompilerJob{NativeCompilerTarget})) = method_table
107+
@eval GPUCompiler.method_table(@nospecialize(job::GPUCompiler.CompilerJob{NativeCompilerTarget, StaticCompilerParams})) = method_table
100108
lf = LazyStaticCompiledFunction{rt, tt}(Symbol(f), path, name, filename, table)
101109
cjl_path = joinpath(path, "$filename.cjl")
102110
serialize(cjl_path, lf)
@@ -106,7 +114,7 @@ end
106114

107115
"""
108116
```julia
109-
generate_obj(f, tt, path::String = tempname(), name = GPUCompiler.safe_name(repr(f)), filenamebase::String="obj";
117+
(f, tt, path::String = tempname(), name = GPUCompiler.safe_name(repr(f)), filenamebase::String="obj";
110118
\tstrip_llvm = false,
111119
\tstrip_asm = true,
112120
\topt_level=3,
@@ -135,11 +143,12 @@ function generate_obj(f, tt, path::String = tempname(), name = GPUCompiler.safe_
135143
strip_llvm = false,
136144
strip_asm = true,
137145
opt_level=3,
146+
ext = true,
138147
kwargs...)
139148
mkpath(path)
140149
obj_path = joinpath(path, "$filenamebase.o")
141-
tm = GPUCompiler.llvm_machine(NativeCompilerTarget())
142-
job, kwargs = native_job(f, tt; name, kwargs...)
150+
job, kwargs = native_job(f, tt; name,ext = ext, kwargs...)
151+
tm = GPUCompiler.llvm_machine(job.target)
143152
#Get LLVM to generated a module of code for us. We don't want GPUCompiler's optimization passes.
144153
mod, meta = GPUCompiler.JuliaContext() do context
145154
GPUCompiler.codegen(:llvm, job; strip=strip_llvm, only_entry=false, validate=false, optimize=false, ctx=context)
@@ -500,11 +509,12 @@ function generate_obj(funcs::Array, path::String = tempname(), filenamebase::Str
500509
strip_llvm = false,
501510
strip_asm = true,
502511
opt_level=3,
512+
ext = true,
503513
kwargs...)
504514
f,tt = funcs[1]
505515
mkpath(path)
506516
obj_path = joinpath(path, "$filenamebase.o")
507-
fakejob, kwargs = native_job(f,tt, kwargs...)
517+
fakejob, kwargs = native_job(f,tt, ext = true, kwargs...)
508518
mod = native_llvm_module(funcs; demangle = demangle, kwargs...)
509519
obj, _ = GPUCompiler.emit_asm(fakejob, mod; strip=strip_asm, validate=false, format=LLVM.API.LLVMObjectFile)
510520
open(obj_path, "w") do io

src/quirks.jl

Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
macro print_and_throw(err)
2+
quote
3+
println(err)
4+
libcexit(Int32(1))
5+
end
6+
end
7+
8+
# math.jl
9+
@device_override @noinline Base.Math.throw_complex_domainerror(f::Symbol, x) =
10+
@print_and_throw c"This operation requires a complex input to return a complex result"
11+
@device_override @noinline Base.Math.throw_exp_domainerror(f::Symbol, x) =
12+
@print_and_throw c"Exponentiation yielding a complex result requires a complex argument"
13+
14+
# intfuncs.jl
15+
@device_override @noinline Base.throw_domerr_powbysq(::Any, p) =
16+
@print_and_throw c"Cannot raise an integer to a negative power"
17+
@device_override @noinline Base.throw_domerr_powbysq(::Integer, p) =
18+
@print_and_throw c"Cannot raise an integer to a negative power"
19+
@device_override @noinline Base.throw_domerr_powbysq(::AbstractMatrix, p) =
20+
@print_and_throw c"Cannot raise an integer to a negative power"
21+
@device_override @noinline Base.__throw_gcd_overflow(a, b) =
22+
@print_and_throw c"gcd overflow"
23+
24+
# checked.jl
25+
@device_override @noinline Base.Checked.throw_overflowerr_binaryop(op, x, y) =
26+
@print_and_throw c"Binary operation overflowed"
27+
@device_override @noinline Base.Checked.throw_overflowerr_negation(op, x, y) =
28+
@print_and_throw c"Negation overflowed"
29+
@device_override function Base.Checked.checked_abs(x::Base.Checked.SignedInt)
30+
r = ifelse(x<0, -x, x)
31+
r<0 && @print_and_throw(c"checked arithmetic: cannot compute |x|")
32+
r
33+
end
34+
35+
# boot.jl
36+
@device_override @noinline Core.throw_inexacterror(f::Symbol, ::Type{T}, val) where {T} =
37+
@print_and_throw c"Inexact conversion"
38+
39+
# abstractarray.jl
40+
@device_override @noinline Base.throw_boundserror(A, I) =
41+
@print_and_throw c"Out-of-bounds array access"
42+
43+
# trig.jl
44+
@device_override @noinline Base.Math.sincos_domain_error(x) =
45+
@print_and_throw c"sincos(x) is only defined for finite x."
46+
47+
48+
# range.jl
49+
@static if VERSION >= v"1.7-"
50+
@eval begin
51+
@device_override function Base.StepRangeLen{T,R,S,L}(ref::R, step::S, len::Integer,
52+
offset::Integer=1) where {T,R,S,L}
53+
if T <: Integer && !isinteger(ref + step)
54+
@print_and_throw(c"StepRangeLen{<:Integer} cannot have non-integer step")
55+
end
56+
len = convert(L, len)
57+
len >= zero(len) || @print_and_throw(c"StepRangeLen length cannot be negative")
58+
offset = convert(L, offset)
59+
L1 = oneunit(typeof(len))
60+
L1 <= offset <= max(L1, len) || @print_and_throw(c"StepRangeLen: offset must be in [1,...]")
61+
$(
62+
Expr(:new, :(StepRangeLen{T,R,S,L}), :ref, :step, :len, :offset)
63+
)
64+
end
65+
end
66+
else
67+
@device_override function Base.StepRangeLen{T,R,S}(ref::R, step::S, len::Integer,
68+
offset::Integer=1) where {T,R,S}
69+
if T <: Integer && !isinteger(ref + step)
70+
@print_and_throw(c"StepRangeLen{<:Integer} cannot have non-integer step")
71+
end
72+
len >= 0 || @print_and_throw(c"StepRangeLen length cannot be negative")
73+
1 <= offset <= max(1,len) || @print_and_throw(c"StepRangeLen: offset must be in [1,...]")
74+
new(ref, step, len, offset)
75+
end
76+
end
77+
78+
79+
# fastmath.jl
80+
@static if VERSION <= v"1.7-"
81+
## prevent fallbacks to libm
82+
for f in (:acosh, :asinh, :atanh, :cbrt, :cosh, :exp2, :expm1, :log1p, :sinh, :tanh)
83+
f_fast = Base.FastMath.fast_op[f]
84+
@eval begin
85+
@device_override Base.FastMath.$f_fast(x::Float32) = $f(x)
86+
@device_override Base.FastMath.$f_fast(x::Float64) = $f(x)
87+
end
88+
end
89+
end

src/target.jl

Lines changed: 44 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -3,44 +3,62 @@ Base.@kwdef struct NativeCompilerTarget <: GPUCompiler.AbstractCompilerTarget
33
features::String=(LLVM.version() < v"8") ? "" : unsafe_string(LLVM.API.LLVMGetHostCPUFeatures())
44
end
55

6-
GPUCompiler.llvm_triple(::NativeCompilerTarget) = Sys.MACHINE
6+
Base.@kwdef struct ExternalNativeCompilerTarget <: GPUCompiler.AbstractCompilerTarget
7+
cpu::String=(LLVM.version() < v"8") ? "" : unsafe_string(LLVM.API.LLVMGetHostCPUName())
8+
features::String=(LLVM.version() < v"8") ? "" : unsafe_string(LLVM.API.LLVMGetHostCPUFeatures())
9+
end
710

8-
function GPUCompiler.llvm_machine(target::NativeCompilerTarget)
9-
triple = GPUCompiler.llvm_triple(target)
11+
module StaticRuntime
12+
# the runtime library
13+
signal_exception() = return
14+
malloc(sz) = ccall("extern malloc", llvmcall, Csize_t, (Csize_t,), sz)
15+
report_oom(sz) = return
16+
report_exception(ex) = return
17+
report_exception_name(ex) = return
18+
report_exception_frame(idx, func, file, line) = return
19+
end
1020

11-
t = LLVM.Target(triple=triple)
21+
struct StaticCompilerParams <: GPUCompiler.AbstractCompilerParams end
1222

13-
tm = LLVM.TargetMachine(t, triple, target.cpu, target.features, reloc=LLVM.API.LLVMRelocPIC)
14-
GPUCompiler.asm_verbosity!(tm, true)
23+
for target in (:NativeCompilerTarget, :ExternalNativeCompilerTarget)
24+
@eval begin
25+
GPUCompiler.llvm_triple(::$target) = Sys.MACHINE
1526

16-
return tm
17-
end
27+
function GPUCompiler.llvm_machine(target::$target)
28+
triple = GPUCompiler.llvm_triple(target)
1829

19-
GPUCompiler.runtime_slug(job::GPUCompiler.CompilerJob{NativeCompilerTarget}) = "native_$(job.target.cpu)-$(hash(job.target.features))"
30+
t = LLVM.Target(triple=triple)
2031

21-
module StaticRuntime
22-
# the runtime library
23-
signal_exception() = return
24-
malloc(sz) = ccall("extern malloc", llvmcall, Csize_t, (Csize_t,), sz)
25-
report_oom(sz) = return
26-
report_exception(ex) = return
27-
report_exception_name(ex) = return
28-
report_exception_frame(idx, func, file, line) = return
29-
end
32+
tm = LLVM.TargetMachine(t, triple, target.cpu, target.features, reloc=LLVM.API.LLVMRelocPIC)
33+
GPUCompiler.asm_verbosity!(tm, true)
3034

31-
struct StaticCompilerParams <: GPUCompiler.AbstractCompilerParams end
35+
return tm
36+
end
3237

33-
GPUCompiler.runtime_module(::GPUCompiler.CompilerJob{<:Any,StaticCompilerParams}) = StaticRuntime
34-
GPUCompiler.runtime_module(::GPUCompiler.CompilerJob{NativeCompilerTarget}) = StaticRuntime
35-
GPUCompiler.runtime_module(::GPUCompiler.CompilerJob{NativeCompilerTarget, StaticCompilerParams}) = StaticRuntime
38+
GPUCompiler.runtime_slug(job::GPUCompiler.CompilerJob{$target}) = "native_$(job.target.cpu)-$(hash(job.target.features))"
39+
40+
GPUCompiler.runtime_module(::GPUCompiler.CompilerJob{$target}) = StaticRuntime
41+
GPUCompiler.runtime_module(::GPUCompiler.CompilerJob{$target, StaticCompilerParams}) = StaticRuntime
3642

43+
44+
GPUCompiler.can_throw(job::GPUCompiler.CompilerJob{$target, StaticCompilerParams}) = true
45+
GPUCompiler.can_throw(job::GPUCompiler.CompilerJob{$target}) = true
46+
end
47+
end
48+
49+
GPUCompiler.runtime_module(::GPUCompiler.CompilerJob{<:Any,StaticCompilerParams}) = StaticRuntime
3750
GPUCompiler.can_throw(job::GPUCompiler.CompilerJob{<:Any,StaticCompilerParams}) = true
38-
GPUCompiler.can_throw(job::GPUCompiler.CompilerJob{NativeCompilerTarget, StaticCompilerParams}) = true
39-
GPUCompiler.can_throw(job::GPUCompiler.CompilerJob{NativeCompilerTarget}) = true
4051

41-
function native_job(@nospecialize(func), @nospecialize(types); kernel::Bool=false, name=GPUCompiler.safe_name(repr(func)), kwargs...)
52+
# GPUCompiler.method_table(@nospecialize(job::GPUCompiler.CompilerJob{<:Any,StaticCompilerParams})) = nothing
53+
# GPUCompiler.method_table(@nospecialize(job::GPUCompiler.CompilerJob{NativeCompilerTarget})) = nothing
54+
# GPUCompiler.method_table(@nospecialize(job::GPUCompiler.CompilerJob{NativeCompilerTarget, StaticCompilerParams})) = nothing
55+
56+
GPUCompiler.method_table(@nospecialize(job::GPUCompiler.CompilerJob{ExternalNativeCompilerTarget})) = method_table
57+
GPUCompiler.method_table(@nospecialize(job::GPUCompiler.CompilerJob{ExternalNativeCompilerTarget, StaticCompilerParams})) = method_table
58+
59+
function native_job(@nospecialize(func), @nospecialize(types); kernel::Bool=false, name=GPUCompiler.safe_name(repr(func)), ext = true, kwargs...)
4260
source = GPUCompiler.FunctionSpec(func, Base.to_tuple_type(types), kernel, name)
43-
target = NativeCompilerTarget()
61+
target = ext ? ExternalNativeCompilerTarget() : NativeCompilerTarget()
4462
params = StaticCompilerParams()
4563
GPUCompiler.CompilerJob(target, source, params), kwargs
4664
end

src/utils.jl

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
2+
3+
# local method table for device functions
4+
@static if isdefined(Base.Experimental, Symbol("@overlay"))
5+
Base.Experimental.@MethodTable(method_table)
6+
else
7+
const method_table = nothing
8+
end
9+
10+
# list of overrides (only for Julia 1.6)
11+
const overrides = Expr[]
12+
13+
macro device_override(ex)
14+
ex = macroexpand(__module__, ex)
15+
if Meta.isexpr(ex, :call)
16+
@show ex = eval(ex)
17+
error()
18+
end
19+
code = quote
20+
$GPUCompiler.@override(StaticCompiler.method_table, $ex)
21+
end
22+
if isdefined(Base.Experimental, Symbol("@overlay"))
23+
return esc(code)
24+
else
25+
push!(overrides, code)
26+
return
27+
end
28+
end
29+
30+
macro device_function(ex)
31+
ex = macroexpand(__module__, ex)
32+
def = splitdef(ex)
33+
34+
# generate a function that errors
35+
def[:body] = quote
36+
error("This function is not intended for use on the CPU")
37+
end
38+
39+
esc(quote
40+
$(combinedef(def))
41+
@device_override $ex
42+
end)
43+
end
44+
45+
macro device_functions(ex)
46+
ex = macroexpand(__module__, ex)
47+
48+
# recursively prepend `@device_function` to all function definitions
49+
function rewrite(block)
50+
out = Expr(:block)
51+
for arg in block.args
52+
if Meta.isexpr(arg, :block)
53+
# descend in blocks
54+
push!(out.args, rewrite(arg))
55+
elseif Meta.isexpr(arg, [:function, :(=)])
56+
# rewrite function definitions
57+
push!(out.args, :(@device_function $arg))
58+
else
59+
# preserve all the rest
60+
push!(out.args, arg)
61+
end
62+
end
63+
out
64+
end
65+
66+
esc(rewrite(ex))
67+
end
68+
69+
libcexit(x::Int32) = @symbolcall exit(x::Int32)::Nothing

0 commit comments

Comments
 (0)