Skip to content

Commit 73fc2ec

Browse files
authored
Merge pull request #79 from gbaraldi/xfunc-shlib
Multiple function shared lib
2 parents d1dfd81 + 16e047e commit 73fc2ec

2 files changed

Lines changed: 100 additions & 1 deletion

File tree

src/StaticCompiler.jl

Lines changed: 74 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -435,7 +435,6 @@ function generate_shlib(f, tt, path::String = tempname(), name = GPUCompiler.saf
435435
path, name
436436
end
437437

438-
439438
function native_code_llvm(@nospecialize(func), @nospecialize(types); kwargs...)
440439
job, kwargs = native_job(func, types; kwargs...)
441440
GPUCompiler.code_llvm(stdout, job; kwargs...)
@@ -460,9 +459,83 @@ function native_code_native(@nospecialize(f), @nospecialize(tt), name = GPUCompi
460459
GPUCompiler.code_native(stdout, job; kwargs...)
461460
end
462461

462+
# Return a single LLVM module containing several functions linked together
463+
"""
    native_llvm_module(funcs::Array; demangle = false, kwargs...)

Build one LLVM module that contains every function listed in `funcs`.

Each element of `funcs` is destructured as `(f, tt)` and compiled with the
per-function `native_llvm_module(f, tt; kwargs...)` method. The module of the
first entry is used as the destination and every following module is linked
into it with `link!`.

# Keywords
- `demangle`: when `true`, every function in the linked module whose name
  starts with `"julia_"` is renamed to the name without that prefix.
- `kwargs...`: forwarded as keyword arguments to each per-function compilation.

# Returns
The linked module, after the merge-functions pass has been run on it.

# Throws
- `ArgumentError` if `funcs` is empty.
"""
function native_llvm_module(funcs::Array; demangle = false, kwargs...)
    isempty(funcs) &&
        throw(ArgumentError("funcs must contain at least one (function, types) pair"))

    f, tt = first(funcs)
    # Keyword arguments have to be forwarded after `;` — a positional splat would
    # turn each `key => value` pair into an extra positional argument.
    mod = native_llvm_module(f, tt; kwargs...)
    for (g, gtt) in Iterators.drop(funcs, 1)
        link!(mod, native_llvm_module(g, gtt; kwargs...))
    end

    if demangle
        prefix = "julia_"
        for func in functions(mod)
            fname = name(func)
            # `startswith` is safe for names shorter than the prefix.
            if startswith(fname, prefix)
                name!(func, fname[(ncodeunits(prefix) + 1):end])
            end
        end
    end

    # Run the merge-functions pass to remove duplicate functions.
    LLVM.ModulePassManager() do pass_manager
        LLVM.merge_functions!(pass_manager)
        LLVM.run!(pass_manager, mod)
    end
    return mod
end
463487

488+
"""
    generate_obj(funcs::Array, path::String = tempname(), filenamebase::String = "obj";
                 demangle = false, strip_llvm = false, strip_asm = true,
                 opt_level = 3, kwargs...)

Compile every `(f, tt)` entry of `funcs` into one object file written to
`joinpath(path, "\$filenamebase.o")`. The directory `path` is created if needed.

# Keywords
- `demangle`: forwarded to `native_llvm_module(funcs; ...)`.
- `strip_asm`: passed as `strip` to `GPUCompiler.emit_asm`.
- `strip_llvm`, `opt_level`: accepted for interface compatibility; this method
  does not read them.
- `kwargs...`: forwarded as keyword arguments to `native_job`; whatever
  `native_job` returns as its second value is forwarded to `native_llvm_module`.

# Returns
The tuple `(path, obj_path)`.

# Throws
- `ArgumentError` if `funcs` is empty.
"""
function generate_obj(funcs::Array, path::String = tempname(), filenamebase::String="obj";
                      demangle = false,
                      strip_llvm = false,
                      strip_asm = true,
                      opt_level = 3,
                      kwargs...)
    isempty(funcs) &&
        throw(ArgumentError("funcs must contain at least one (function, types) pair"))

    f, tt = first(funcs)
    mkpath(path)
    obj_path = joinpath(path, "$filenamebase.o")

    # `emit_asm` needs a job; one built from the first entry is used for the
    # whole linked module. Keywords must follow `;` so they are not turned into
    # positional arguments.
    fakejob, job_kwargs = native_job(f, tt; kwargs...)
    mod = native_llvm_module(funcs; demangle = demangle, job_kwargs...)
    obj, _ = GPUCompiler.emit_asm(fakejob, mod; strip = strip_asm, validate = false,
                                  format = LLVM.API.LLVMObjectFile)

    open(obj_path, "w") do io
        write(io, obj)
    end
    return path, obj_path
end
464505

506+
"""
    generate_shlib(funcs::Array, path::String = tempname(), filename::String = "libfoo";
                   demangle = false, kwargs...)

Compile every entry of `funcs` into one object file via `generate_obj` and link
it into the shared library `joinpath(path, "\$filename.\$(Libdl.dlext)")`.

# Keywords
- `demangle`, `kwargs...`: forwarded to `generate_obj`.

# Returns
The tuple `(path, filename)`.
"""
function generate_shlib(funcs::Array, path::String = tempname(), filename::String="libfoo";
                        demangle = false, kwargs...)
    lib_path = joinpath(path, "$filename.$(Libdl.dlext)")

    _, obj_path = generate_obj(funcs, path, filename; demangle = demangle, kwargs...)

    # Pick a Clang
    cc = Sys.isapple() ? `cc` : clang()
    # Compile!
    run(`$cc -shared -o $lib_path $obj_path`)

    # This method has no `name` argument; return the library base name instead of
    # whatever global `name` happens to be in scope.
    return path, filename
end
518+
519+
"""
    compile_shlib(funcs::Array, path::String = "./";
                  filename = "libfoo", demangle = false, kwargs...)

Check every `(f, types)` entry of `funcs`, then build one shared library holding
all of them with `generate_shlib`.

An entry is rejected with `error` when its tuple type is not concrete or when
the return type reported by `native_code_typed` is not concrete.

# Returns
The absolute path of the library: `abspath(path)` joined with
`filename * "." * Libdl.dlext`.
"""
function compile_shlib(funcs::Array, path::String="./";
                       filename="libfoo",
                       demangle=false,
                       kwargs...)
    for (f, types) in funcs
        tt = Base.to_tuple_type(types)
        if !isconcretetype(tt)
            error("input type signature $types is not concrete")
        end

        typed = only(native_code_typed(f, tt))
        rt = typed[2]
        if !isconcretetype(rt)
            error("$f$types did not infer to a concrete type. Got $rt")
        end
    end

    # Would be nice to use a compiler pass or something to check if there are any heap allocations or references to globals
    # Keep an eye on https://github.com/JuliaLang/julia/pull/43747 for this

    generate_shlib(funcs, path, filename; demangle=demangle, kwargs...)

    libfile = filename * "." * Libdl.dlext
    return joinpath(abspath(path), libfile)
end
466539

467540

468541
end # module

test/testcore.jl

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -292,3 +292,29 @@ end
292292
@test isa(r, Base.Process)
293293
@test r.exitcode == 0
294294
end
295+
296+
# Multiply `n` by itself; `@noinline` keeps it as a separate call in its callers.
@noinline function square(n)
    return n * n
end
297+
298+
# Apply `square` twice to `n`.
squaresquare(n) = square(square(n))
301+
302+
# Apply `square` to the result of `squaresquare(n)`.
squaresquaresquare(n) = square(squaresquare(n))
305+
306+
@testset "Multiple Function Dylibs" begin
    # Compile dylib: both functions go into one shared library, and
    # `demangle=true` lets them be looked up by their plain names below.
    funcs = [(squaresquare, (Float64,)), (squaresquaresquare, (Float64,))]
    filepath = compile_shlib(funcs, demangle=true)

    # The do-block form of `dlopen` closes the handle afterwards, even when a
    # symbol lookup throws.
    Libdl.dlopen(filepath, Libdl.RTLD_LOCAL) do ptr
        fptr2 = Libdl.dlsym(ptr, "squaresquare")
        @test ccall(fptr2, Float64, (Float64,), 10.) == squaresquare(10.)

        fptr = Libdl.dlsym(ptr, "squaresquaresquare")
        @test ccall(fptr, Float64, (Float64,), 10.) == squaresquaresquare(10.)
    end
end

0 commit comments

Comments
 (0)