Skip to content

Commit d39a270

Browse files
authored
Merge pull request #53 from MasonProtter/master
New `compile` interface
2 parents 6e20094 + 4d348e5 commit d39a270

7 files changed

Lines changed: 305 additions & 83 deletions

File tree

Manifest.toml

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
# This file is machine-generated - editing it directly is not advised
22

3+
julia_version = "1.7.1"
34
manifest_format = "2.0"
45

56
[[deps.ArgTools]]
@@ -16,6 +17,12 @@ git-tree-sha1 = "215a9aa4a1f23fbd05b92769fdd62559488d70e9"
1617
uuid = "fa961155-64e5-5f13-b03f-caf6b980ea82"
1718
version = "0.4.1"
1819

20+
[[deps.Clang_jll]]
21+
deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg", "Zlib_jll", "libLLVM_jll"]
22+
git-tree-sha1 = "8cf7e67e264dedc5d321ec87e78525e958aea057"
23+
uuid = "0ee61d77-7f21-5576-8119-9fcc46b10100"
24+
version = "12.0.1+3"
25+
1926
[[deps.Dates]]
2027
deps = ["Printf"]
2128
uuid = "ade2ca70-3891-5945-98fb-dc099432e06a"
@@ -149,6 +156,10 @@ uuid = "4ec0a83e-493e-50e2-b9ac-8f72acf5a8f5"
149156
deps = ["Libdl"]
150157
uuid = "83775a58-1f1d-513f-b197-d71354ab007a"
151158

159+
[[deps.libLLVM_jll]]
160+
deps = ["Artifacts", "Libdl"]
161+
uuid = "8f36deef-c2a5-5394-99ed-8e07531fb29a"
162+
152163
[[deps.nghttp2_jll]]
153164
deps = ["Artifacts", "Libdl"]
154165
uuid = "8e850ede-7688-5339-a07c-302acd2aaf8d"

Project.toml

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,19 @@
11
name = "StaticCompiler"
22
uuid = "81625895-6c0f-48fc-b932-11a18313743c"
33
authors = ["Tom Short"]
4-
version = "0.2.0"
4+
version = "0.3.0"
55

66
[deps]
7+
Clang_jll = "0ee61d77-7f21-5576-8119-9fcc46b10100"
78
GPUCompiler = "61eb1bfa-7361-4325-ad38-22787b887f55"
89
LLVM = "929cbde3-209d-540e-8aea-75f648917ca0"
910
Libdl = "8f399da3-3557-5675-b5ff-fb832c97cbdb"
11+
Serialization = "9e88b42a-f829-5b0c-bbe9-9e923198166b"
1012

1113
[compat]
12-
julia = "1.7"
1314
GPUCompiler = "0.13"
1415
LLVM = "4"
16+
julia = "1.7"
1517

1618
[extras]
1719
Formatting = "59287772-0a20-5a39-b81b-1366585eb4c0"

README.md

Lines changed: 23 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -16,21 +16,29 @@ Pkg.add(PackageSpec( url = "https://github.com/tshort/StaticCompiler.jl", rev =
1616
```
1717

1818
```julia
19-
using StaticCompiler
20-
f(x) = 2x
19+
julia> using StaticCompiler
2120

22-
# compile `f` and return an LLVM module
23-
m = compile(f, (Int,))
21+
julia> fib(n) = n <= 1 ? n : fib(n - 1) + fib(n - 2)
22+
fib (generic function with 1 method)
2423

25-
# compile `f` and write to a shared library ("f.so" or "f.dll")
26-
generate_shlib(f, (Int,), "libf")
27-
# find a function pointer for this shared library
28-
fptr = generate_shlib_fptr("libf", "f")
29-
@ccall $fptr(2::Int)::Int
24+
julia> fib_compiled, path = compile(fib, Tuple{Int}, "fib")
25+
(f = fib(::Int64) :: Int64, path = "fib")
3026

31-
# do this in one step (this time with a temporary shared library)
32-
fptr = generate_shlib_fptr(f, (Int,))
33-
@ccall $fptr(2::Int)::Int
27+
julia> fib_compiled(10)
28+
55
29+
```
30+
Now we can quit this session and load a new one where `fib` is not defined:
31+
```julia
32+
julia> using StaticCompiler
33+
34+
julia> fib
35+
ERROR: UndefVarError: fib not defined
36+
37+
julia> fib_compiled = load_function("fib")
38+
fib(::Int64) :: Int64
39+
40+
julia> fib_compiled(10)
41+
55
3442
```
3543

3644
## Approach
@@ -39,8 +47,7 @@ This package uses the [GPUCompiler package](https://github.com/JuliaGPU/GPUCompi
3947

4048
## Limitations
4149

42-
* This package currently requires that you have `gcc` installed and in your system's `PATH`. This is probably pretty easy to fix, we only use `gcc` for linking. In theory Clang_jll or LLVM_full_jll should be able to do this, and be managed through Julia's package manager.
4350
* No heap allocations (e.g. creating an array or a string) are allowed inside a statically compiled function body. If you try to run such a function, you will get a segfault.
44-
** It's sometimes possible you won't get a segfault if you define and run the function in the same session, but trying to call the compiled function in a new julia session will definitely segfault.
45-
* Lots of other limitations too. E.g. there's an example in tests/runtests.jl where summing a vector of `Complex{Float32}` is fine, but segfaults on `Complex{Float64}`.
46-
* Doesn't currently work on Windows
51+
** It's sometimes possible you won't get a segfault if you define and run the function in the same session, but trying to call the compiled function in a new julia session will definitely segfault if you allocate memory.
52+
* Doesn't currently work on Windows
53+
* If you find any other limitations, let us know. There's probably lots.

src/StaticCompiler.jl

Lines changed: 219 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,136 @@ module StaticCompiler
33
using GPUCompiler: GPUCompiler
44
using LLVM: LLVM
55
using Libdl: Libdl
6+
using Base: RefValue
7+
using Serialization: serialize, deserialize
8+
using Clang_jll: clang
69

10+
export compile, load_function
11+
export native_code_llvm, native_code_typed, native_llvm_module
712

8-
export generate_shlib, generate_shlib_fptr, compile, native_code_llvm, native_code_typed, native_llvm_module
13+
"""
14+
compile(f, types, path::String = tempname()) --> (compiled_f, path)
15+
16+
!!! Warning: this will fail on programs that heap allocate any memory, or have dynamic dispatch !!!
17+
18+
Statically compile the method of a function `f` specialized to arguments of the type given by `types`.
19+
20+
This will create a directory at the specified path (or in a temporary directory if you exclude that argument)
21+
that contains the files needed for your static compiled function. `compile` will return a
22+
`StaticCompiledFunction` object and `obj_path` which is the absolute path of the directory containing the
23+
compilation artifacts. The `StaticCompiledFunction` can be treated as if it is a function with a single
24+
method corresponding to the types you specified when it was compiled.
25+
26+
To deserialize and instantiate a previously compiled function, simply execute `load_function(path)`, which
27+
returns a callable `StaticCompiledFunction`.
28+
29+
### Example:
30+
31+
Define and compile a `fib` function:
32+
```julia
33+
julia> using StaticCompiler
34+
35+
julia> fib(n) = n <= 1 ? n : fib(n - 1) + fib(n - 2)
36+
fib (generic function with 1 method)
37+
38+
julia> fib_compiled, path = compile(fib, Tuple{Int}, "fib")
39+
(f = fib(::Int64) :: Int64, path = "fib")
40+
41+
julia> fib_compiled(10)
42+
55
43+
```
44+
Now we can quit this session and load a new one where `fib` is not defined:
45+
```julia
46+
julia> fib
47+
ERROR: UndefVarError: fib not defined
48+
49+
julia> using StaticCompiler
50+
51+
julia> fib_compiled = load_function("fib.cjl")
52+
fib(::Int64) :: Int64
53+
54+
julia> fib_compiled(10)
55+
55
56+
```
57+
Tada!
58+
59+
### Details:
60+
61+
Here is the structure of the directory created by `compile` in the above example:
62+
```julia
63+
shell> tree fib
64+
path
65+
├── obj.cjl
66+
├── obj.o
67+
└── obj.so
68+
69+
0 directories, 3 files
70+
````
71+
* `obj.so` (or `.dylib` on MacOS) is a shared object file that can be linked to in order to execute your
72+
compiled julia function.
73+
* `obj.cjl` is a serialized `LazyStaticCompiledFunction` object which will be deserialized and instantiated
74+
with `load_function(path)`. `LazyStaticcompiledfunction`s contain the requisite information needed to link to the
75+
`obj.so` inside a julia session. Once it is instantiated in a julia session (i.e. by
76+
`instantiate(::LazyStaticCompiledFunction)`, this happens automatically in `load_function`), it will be of type
77+
`StaticCompiledFunction` and may be called with arguments of type `types` as if it were a function with a
78+
single method (the method determined by `types`).
79+
"""
80+
function compile(f, _tt, path::String = tempname(); name = GPUCompiler.safe_name(repr(f)), kwargs...)
81+
tt = Base.to_tuple_type(_tt)
82+
isconcretetype(tt) || error("input type signature $_tt is not concrete")
83+
84+
rt = only(native_code_typed(f, tt))[2]
85+
isconcretetype(rt) || error("$f on $_tt did not infer to a concrete type. Got $rt")
86+
87+
# Would be nice to use a compiler pass or something to check if there are any heap allocations or references to globals
88+
# Keep an eye on https://github.com/JuliaLang/julia/pull/43747 for this
89+
90+
f_wrap!(out::Ref, args::Ref{<:Tuple}) = (out[] = f(args[]...); nothing)
91+
92+
generate_shlib(f_wrap!, Tuple{RefValue{rt}, RefValue{tt}}, path, name; kwargs...)
93+
94+
lf = LazyStaticCompiledFunction{rt, tt}(Symbol(f), path, name)
95+
cjl_path = joinpath(path, "obj.cjl")
96+
serialize(cjl_path, lf)
97+
(; f = instantiate(lf), path=abspath(path))
98+
end
99+
100+
"""
101+
load_function(path) --> compiled_f
102+
103+
load a `StaticCompiledFunction` from a given path. This object is callable.
104+
"""
105+
load_function(path) = instantiate(deserialize(joinpath(path, "obj.cjl")) :: LazyStaticCompiledFunction)
106+
107+
struct LazyStaticCompiledFunction{rt, tt}
108+
f::Symbol
109+
path::String
110+
name::String
111+
end
112+
113+
function instantiate(p::LazyStaticCompiledFunction{rt, tt}) where {rt, tt}
114+
StaticCompiledFunction{rt, tt}(p.f, generate_shlib_fptr(p.path::String, p.name))
115+
end
116+
117+
struct StaticCompiledFunction{rt, tt}
118+
f::Symbol
119+
ptr::Ptr{Nothing}
120+
end
121+
122+
function Base.show(io::IO, f::StaticCompiledFunction{rt, tt}) where {rt, tt}
123+
types = [tt.parameters...]
124+
print(io, String(f.f), "(", join(("::$T" for T tt.parameters), ',') ,") :: $rt")
125+
end
126+
127+
function (f::StaticCompiledFunction{rt, tt})(args...) where {rt, tt}
128+
Tuple{typeof.(args)...} == tt || error("Input types don't match compiled target $((tt.parameters...,)). Got arguments of type $(typeof.(args))")
129+
out = RefValue{rt}()
130+
refargs = Ref(args)
131+
ccall(f.ptr, Nothing, (Ref{rt}, Ref{tt}), out, refargs)
132+
out[]
133+
end
134+
135+
instantiate(f::StaticCompiledFunction) = f
9136

10137
module TestRuntime
11138
# dummy methods
@@ -30,39 +157,121 @@ function native_job(@nospecialize(func), @nospecialize(types); kernel::Bool=fals
30157
GPUCompiler.CompilerJob(target, source, params), kwargs
31158
end
32159

160+
161+
"""
162+
```julia
163+
generate_shlib(f, tt, path::String, name::String; kwargs...)
164+
```
165+
Low level interface for compiling a shared object / dynamically loaded library
166+
(`.so` / `.dylib`) for function `f` given a tuple type `tt` characterizing
167+
the types of the arguments for which the function will be compiled.
168+
169+
See also `StaticCompiler.generate_shlib_fptr`.
170+
171+
### Examples
172+
```julia
173+
julia> function test(n)
174+
r = 0.0
175+
for i=1:n
176+
r += log(sqrt(i))
177+
end
178+
return r/n
179+
end
180+
test (generic function with 1 method)
181+
182+
julia> path, name = StaticCompiler.generate_shlib(test, Tuple{Int64}, "./test")
183+
("./test", "test")
184+
185+
shell> tree \$path
186+
./test
187+
|-- obj.o
188+
`-- obj.so
189+
190+
0 directories, 2 files
191+
192+
julia> test(100_000)
193+
5.256496109495593
194+
195+
julia> ccall(StaticCompiler.generate_shlib_fptr(path, name), Float64, (Int64,), 100_000)
196+
5.256496109495593
197+
```
198+
"""
33199
function generate_shlib(f, tt, path::String = tempname(), name = GPUCompiler.safe_name(repr(f)); kwargs...)
34-
open(path, "w") do io
200+
mkpath(path)
201+
obj_path = joinpath(path, "obj.o")
202+
lib_path = joinpath(path, "obj.$(Libdl.dlext)")
203+
open(obj_path, "w") do io
35204
job, kwargs = native_job(f, tt; name, kwargs...)
36205
obj, _ = GPUCompiler.codegen(:obj, job; strip=true, only_entry=false, validate=false)
37-
206+
38207
write(io, obj)
39208
flush(io)
40-
run(`gcc -shared -o $path.$(Libdl.dlext) $path`)
41-
rm(path)
209+
210+
# Pick a Clang
211+
cc = Sys.isapple() ? `cc` : clang()
212+
# Compile!
213+
run(`$cc -shared -o $lib_path $obj_path`)
42214
end
43215
path, name
44216
end
45217

218+
46219
function generate_shlib_fptr(f, tt, path::String=tempname(), name = GPUCompiler.safe_name(repr(f)); temp::Bool=true, kwargs...)
47220
generate_shlib(f, tt, path, name; kwargs...)
48-
ptr = Libdl.dlopen("$(abspath(path)).$(Libdl.dlext)", Libdl.RTLD_LOCAL)
221+
lib_path = joinpath(abspath(path), "obj.$(Libdl.dlext)")
222+
ptr = Libdl.dlopen(lib_path, Libdl.RTLD_LOCAL)
49223
fptr = Libdl.dlsym(ptr, "julia_$name")
50224
@assert fptr != C_NULL
51225
if temp
52-
atexit(()->rm("$path.$(Libdl.dlext)"))
226+
atexit(()->rm(path; recursive=true))
53227
end
54228
fptr
55229
end
56230

231+
"""
232+
```julia
233+
generate_shlib_fptr(path::String, name)
234+
```
235+
Low level interface for obtaining a function pointer by `dlopen`ing a shared
236+
library given the `path` and `name` of a `.so`/`.dylib` already compiled by
237+
`generate_shlib`.
238+
239+
See also `StaticCompiler.enerate_shlib`.
240+
241+
### Examples
242+
```julia
243+
julia> function test(n)
244+
r = 0.0
245+
for i=1:n
246+
r += log(sqrt(i))
247+
end
248+
return r/n
249+
end
250+
test (generic function with 1 method)
251+
252+
julia> path, name = StaticCompiler.generate_shlib(test, Tuple{Int64}, "./test");
253+
254+
julia> test_ptr = StaticCompiler.generate_shlib_fptr(path, name)
255+
Ptr{Nothing} @0x000000015209f600
256+
257+
julia> ccall(test_ptr, Float64, (Int64,), 100_000)
258+
5.256496109495593
259+
260+
julia> @ccall \$test_ptr(100_000::Int64)::Float64 # Equivalently
261+
5.256496109495593
262+
263+
julia> test(100_000)
264+
5.256496109495593
265+
```
266+
"""
57267
function generate_shlib_fptr(path::String, name)
58-
ptr = Libdl.dlopen("$(abspath(path)).$(Libdl.dlext)", Libdl.RTLD_LOCAL)
268+
lib_path = joinpath(abspath(path), "obj.$(Libdl.dlext)")
269+
ptr = Libdl.dlopen(lib_path, Libdl.RTLD_LOCAL)
59270
fptr = Libdl.dlsym(ptr, "julia_$name")
60271
@assert fptr != C_NULL
61272
fptr
62273
end
63274

64-
65-
66275
function native_code_llvm(@nospecialize(func), @nospecialize(types); kwargs...)
67276
job, kwargs = native_job(func, types; kwargs...)
68277
GPUCompiler.code_llvm(stdout, job; kwargs...)
@@ -80,5 +289,4 @@ function native_llvm_module(f, tt, name = GPUCompiler.safe_name(repr(f)); kwargs
80289
return m
81290
end
82291

83-
84292
end # module

test/Project.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,4 +6,5 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
66
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
77
LoopVectorization = "bdcacae8-1622-11e9-2a5c-532679323890"
88
ManualMemory = "d125e4d3-2237-4719-b19c-fa641b8a4667"
9-
StrideArraysCore = "7792a7ef-975c-4747-a70f-980b88e8d1da"
9+
StrideArraysCore = "7792a7ef-975c-4747-a70f-980b88e8d1da"
10+
Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"

0 commit comments

Comments
 (0)