Skip to content

Commit 797e87b

Browse files
committed
Add gc use intrinsic
Now that our gc root placement pass is significantly more aggressive about dropping roots, we need to be a bit more careful about our use of pointer, etc. This adds a low-level intrinsic for annotating gc uses to keep objects alive even if they would be otherwise unreferenced. As an initial use case, we get rid of a number of uses of `pointer` in string, by creating a new `unsafe_load` that keeps the string alive.
1 parent 43ced35 commit 797e87b

File tree

7 files changed

+45
-17
lines changed

7 files changed

+45
-17
lines changed

base/boot.jl

+2
Original file line numberDiff line numberDiff line change
@@ -438,4 +438,6 @@ show(@nospecialize a) = show(STDOUT, a)
438438
print(@nospecialize a...) = print(STDOUT, a...)
439439
println(@nospecialize a...) = println(STDOUT, a...)
440440

441+
gcuse(@nospecialize a) = ccall(:jl_gc_use, Void, (Any,), a)
442+
441443
ccall(:jl_set_istopmod, Void, (Any, Bool), Core, true)

base/strings/string.jl

+18-15
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,13 @@ String(s::Symbol) = unsafe_string(Cstring(s))
5959
pointer(s::String) = unsafe_convert(Ptr{UInt8}, s)
6060
pointer(s::String, i::Integer) = pointer(s)+(i-1)
6161

62+
function unsafe_load(s::String, i::Integer=1)
63+
ptr = pointer(s, i)
64+
r = unsafe_load(ptr)
65+
Core.gcuse(s)
66+
r
67+
end
68+
6269
sizeof(s::String) = Core.sizeof(s)
6370

6471
"""
@@ -73,7 +80,7 @@ codeunit(s::AbstractString, i::Integer)
7380
@boundscheck if (i < 1) | (i > sizeof(s))
7481
throw(BoundsError(s,i))
7582
end
76-
unsafe_load(pointer(s),i)
83+
unsafe_load(s, i)
7784
end
7885

7986
write(io::IO, s::String) = unsafe_write(io, pointer(s), reinterpret(UInt, sizeof(s)))
@@ -160,26 +167,24 @@ const utf8_trailing = [
160167
## required core functionality ##
161168

162169
function endof(s::String)
163-
p = pointer(s)
164170
i = sizeof(s)
165-
while i > 0 && is_valid_continuation(unsafe_load(p,i))
171+
while i > 0 && is_valid_continuation(unsafe_load(s, i))
166172
i -= 1
167173
end
168174
i
169175
end
170176

171177
function length(s::String)
172-
p = pointer(s)
173178
cnum = 0
174179
for i = 1:sizeof(s)
175-
cnum += !is_valid_continuation(unsafe_load(p,i))
180+
cnum += !is_valid_continuation(unsafe_load(s, i))
176181
end
177182
cnum
178183
end
179184

180-
@noinline function slow_utf8_next(p::Ptr{UInt8}, b::UInt8, i::Int, l::Int)
185+
@noinline function slow_utf8_next(s::String, b::UInt8, i::Int, l::Int)
181186
if is_valid_continuation(b)
182-
throw(UnicodeError(UTF_ERR_INVALID_INDEX, i, unsafe_load(p,i)))
187+
throw(UnicodeError(UTF_ERR_INVALID_INDEX, i, unsafe_load(s, i)))
183188
end
184189
trailing = utf8_trailing[b + 1]
185190
if l < i + trailing
@@ -188,7 +193,7 @@ end
188193
c::UInt32 = 0
189194
for j = 1:(trailing + 1)
190195
c <<= 6
191-
c += unsafe_load(p,i)
196+
c += unsafe_load(s, i)
192197
i += 1
193198
end
194199
c -= utf8_offset[trailing + 1]
@@ -206,12 +211,11 @@ done(s::String, state) = state > sizeof(s)
206211
@boundscheck if (i < 1) | (i > sizeof(s))
207212
throw(BoundsError(s,i))
208213
end
209-
p = pointer(s)
210-
b = unsafe_load(p, i)
214+
b = unsafe_load(s, i)
211215
if b < 0x80
212216
return Char(b), i + 1
213217
end
214-
return slow_utf8_next(p, b, i, sizeof(s))
218+
return slow_utf8_next(s, b, i, sizeof(s))
215219
end
216220

217221
function first_utf8_byte(ch::Char)
@@ -225,8 +229,7 @@ end
225229

226230
function reverseind(s::String, i::Integer)
227231
j = sizeof(s) + 1 - i
228-
p = pointer(s)
229-
while is_valid_continuation(unsafe_load(p,j))
232+
while is_valid_continuation(unsafe_load(s, j))
230233
j -= 1
231234
end
232235
return j
@@ -235,7 +238,7 @@ end
235238
## overload methods for efficiency ##
236239

237240
isvalid(s::String, i::Integer) =
238-
(1 <= i <= sizeof(s)) && !is_valid_continuation(unsafe_load(pointer(s),i))
241+
(1 <= i <= sizeof(s)) && !is_valid_continuation(unsafe_load(s, i))
239242

240243
function getindex(s::String, r::UnitRange{Int})
241244
isempty(r) && return ""
@@ -438,7 +441,7 @@ function repeat(s::String, r::Integer)
438441
n = sizeof(s)
439442
out = _string_n(n*r)
440443
if n == 1 # common case: repeating a single ASCII char
441-
ccall(:memset, Ptr{Void}, (Ptr{UInt8}, Cint, Csize_t), out, unsafe_load(pointer(s)), r)
444+
ccall(:memset, Ptr{Void}, (Ptr{UInt8}, Cint, Csize_t), out, unsafe_load(s), r)
442445
else
443446
for i=1:r
444447
unsafe_copy!(pointer(out, 1+(i-1)*n), pointer(s), n)

src/ccall.cpp

+7
Original file line numberDiff line numberDiff line change
@@ -1687,6 +1687,13 @@ static jl_cgval_t emit_ccall(jl_codectx_t &ctx, jl_value_t **args, size_t nargs)
16871687
emit_signal_fence(ctx);
16881688
return ghostValue(jl_void_type);
16891689
}
1690+
else if (is_libjulia_func(jl_gc_use)) {
1691+
assert(lrt == T_void);
1692+
assert(!isVa && !llvmcall && nargt == 1);
1693+
ctx.builder.CreateCall(prepare_call(gc_use_func), {decay_derived(boxed(ctx, argv[0]))});
1694+
JL_GC_POP();
1695+
return ghostValue(jl_void_type);
1696+
}
16901697
else if (_is_libjulia_func((uintptr_t)ptls_getter, "jl_get_ptls_states")) {
16911698
assert(lrt == T_pint8);
16921699
assert(!isVa && !llvmcall && nargt == 0);

src/codegen.cpp

+7
Original file line numberDiff line numberDiff line change
@@ -357,6 +357,7 @@ static GlobalVariable *jlgetworld_global;
357357

358358
// placeholder functions
359359
static Function *gcroot_flush_func;
360+
static Function *gc_use_func;
360361
static Function *except_enter_func;
361362
static Function *pointer_from_objref_func;
362363

@@ -6489,6 +6490,12 @@ static void init_julia_llvm_env(Module *m)
64896490
"julia.gcroot_flush");
64906491
add_named_global(gcroot_flush_func, (void*)NULL, /*dllimport*/false);
64916492

6493+
gc_use_func = Function::Create(FunctionType::get(T_void,
6494+
ArrayRef<Type*>(PointerType::get(T_jlvalue, AddressSpace::Derived)), false),
6495+
Function::ExternalLinkage,
6496+
"julia.gc_use");
6497+
add_named_global(gc_use_func, (void*)NULL, /*dllimport*/false);
6498+
64926499
pointer_from_objref_func = Function::Create(FunctionType::get(T_pjlvalue,
64936500
ArrayRef<Type*>(PointerType::get(T_jlvalue, AddressSpace::Derived)), false),
64946501
Function::ExternalLinkage,

src/julia.h

+1
Original file line numberDiff line numberDiff line change
@@ -648,6 +648,7 @@ JL_DLLEXPORT jl_value_t *jl_gc_alloc_1w(void);
648648
JL_DLLEXPORT jl_value_t *jl_gc_alloc_2w(void);
649649
JL_DLLEXPORT jl_value_t *jl_gc_alloc_3w(void);
650650
JL_DLLEXPORT jl_value_t *jl_gc_allocobj(size_t sz);
651+
JL_DLLEXPORT void jl_gc_use(jl_value_t *a);
651652

652653
JL_DLLEXPORT void jl_clear_malloc_data(void);
653654

src/llvm-late-gc-lowering.cpp

+6-2
Original file line numberDiff line numberDiff line change
@@ -320,6 +320,7 @@ struct LateLowerGCFrame: public FunctionPass {
320320
MDNode *tbaa_tag;
321321
Function *ptls_getter;
322322
Function *gc_flush_func;
323+
Function *gc_use_func;
323324
Function *pointer_from_objref_func;
324325
Function *alloc_obj_func;
325326
Function *pool_alloc_func;
@@ -745,7 +746,8 @@ State LateLowerGCFrame::LocalScan(Function &F) {
745746
}
746747
if (auto callee = CI->getCalledFunction()) {
747748
// Known functions emitted in codegen that are not safepoints
748-
if (callee == pointer_from_objref_func || callee->getName() == "memcmp") {
749+
if (callee == pointer_from_objref_func || callee == gc_use_func ||
750+
callee->getName() == "memcmp") {
749751
continue;
750752
}
751753
}
@@ -1137,7 +1139,8 @@ bool LateLowerGCFrame::CleanupIR(Function &F) {
11371139
}
11381140
CallingConv::ID CC = CI->getCallingConv();
11391141
auto callee = CI->getCalledValue();
1140-
if (gc_flush_func != nullptr && callee == gc_flush_func) {
1142+
if ((gc_flush_func != nullptr && callee == gc_flush_func) ||
1143+
(gc_use_func != nullptr && callee == gc_use_func)) {
11411144
/* No replacement */
11421145
} else if (pointer_from_objref_func != nullptr && callee == pointer_from_objref_func) {
11431146
auto *ASCI = new AddrSpaceCastInst(CI->getOperand(0),
@@ -1405,6 +1408,7 @@ static void addRetNoAlias(Function *F)
14051408
bool LateLowerGCFrame::doInitialization(Module &M) {
14061409
ptls_getter = M.getFunction("jl_get_ptls_states");
14071410
gc_flush_func = M.getFunction("julia.gcroot_flush");
1411+
gc_use_func = M.getFunction("julia.gc_use");
14081412
pointer_from_objref_func = M.getFunction("julia.pointer_from_objref");
14091413
auto &ctx = M.getContext();
14101414
T_size = M.getDataLayout().getIntPtrType(ctx);

src/rtutils.c

+4
Original file line numberDiff line numberDiff line change
@@ -305,6 +305,10 @@ JL_DLLEXPORT jl_value_t *jl_value_ptr(jl_value_t *a)
305305
{
306306
return a;
307307
}
308+
JL_DLLEXPORT void jl_gc_use(jl_value_t *a)
309+
{
310+
(void)a;
311+
}
308312

309313
// parsing --------------------------------------------------------------------
310314

0 commit comments

Comments
 (0)