Skip to content

Commit 9a712ae

Browse files
authored
remove Ref allocation on task switch (#35606)
1 parent 47087cb commit 9a712ae

File tree

7 files changed

+56
-16
lines changed

7 files changed

+56
-16
lines changed

base/task.jl

+14-9
Original file line numberDiff line numberDiff line change
@@ -615,6 +615,8 @@ function yield()
615615
end
616616
end
617617

618+
@inline set_next_task(t::Task) = ccall(:jl_set_next_task, Cvoid, (Any,), t)
619+
618620
"""
619621
yield(t::Task, arg = nothing)
620622
@@ -624,7 +626,8 @@ immediately yields to `t` before calling the scheduler.
624626
function yield(t::Task, @nospecialize(x=nothing))
625627
t.result = x
626628
enq_work(current_task())
627-
return try_yieldto(ensure_rescheduled, Ref(t))
629+
set_next_task(t)
630+
return try_yieldto(ensure_rescheduled)
628631
end
629632

630633
"""
@@ -637,14 +640,15 @@ or scheduling in any way. Its use is discouraged.
637640
"""
638641
function yieldto(t::Task, @nospecialize(x=nothing))
639642
t.result = x
640-
return try_yieldto(identity, Ref(t))
643+
set_next_task(t)
644+
return try_yieldto(identity)
641645
end
642646

643-
function try_yieldto(undo, reftask::Ref{Task})
647+
function try_yieldto(undo)
644648
try
645-
ccall(:jl_switchto, Cvoid, (Any,), reftask)
649+
ccall(:jl_switch, Cvoid, ())
646650
catch
647-
undo(reftask[])
651+
undo(ccall(:jl_get_next_task, Ref{Task}, ()))
648652
rethrow()
649653
end
650654
ct = current_task()
@@ -696,18 +700,19 @@ function trypoptask(W::StickyWorkqueue)
696700
return t
697701
end
698702

699-
@noinline function poptaskref(W::StickyWorkqueue)
703+
@noinline function poptask(W::StickyWorkqueue)
700704
task = trypoptask(W)
701705
if !(task isa Task)
702706
task = ccall(:jl_task_get_next, Ref{Task}, (Any, Any), trypoptask, W)
703707
end
704-
return Ref(task)
708+
set_next_task(task)
709+
nothing
705710
end
706711

707712
function wait()
708713
W = Workqueues[Threads.threadid()]
709-
reftask = poptaskref(W)
710-
result = try_yieldto(ensure_rescheduled, reftask)
714+
poptask(W)
715+
result = try_yieldto(ensure_rescheduled)
711716
process_events()
712717
# return when we come out of the queue
713718
return result

src/ccall.cpp

+10
Original file line numberDiff line numberDiff line change
@@ -1619,6 +1619,16 @@ static jl_cgval_t emit_ccall(jl_codectx_t &ctx, jl_value_t **args, size_t nargs)
16191619
tbaa_decorate(tbaa_const, ctx.builder.CreateLoad(pct)),
16201620
retboxed, rt, unionall, static_rt);
16211621
}
1622+
else if (is_libjulia_func(jl_set_next_task)) {
1623+
assert(lrt == T_void);
1624+
assert(!isVa && !llvmcall && nccallargs == 1);
1625+
JL_GC_POP();
1626+
Value *ptls_pv = emit_bitcast(ctx, ctx.ptlsStates, T_ppjlvalue);
1627+
const int nt_offset = offsetof(jl_tls_states_t, next_task);
1628+
Value *pnt = ctx.builder.CreateGEP(ptls_pv, ConstantInt::get(T_size, nt_offset / sizeof(void*)));
1629+
ctx.builder.CreateStore(emit_pointer_from_objref(ctx, boxed(ctx, argv[0])), pnt);
1630+
return ghostValue(jl_nothing_type);
1631+
}
16221632
else if (is_libjulia_func(jl_sigatomic_begin)) {
16231633
assert(lrt == T_void);
16241634
assert(!isVa && !llvmcall && nccallargs == 0);

src/gc.c

+2
Original file line numberDiff line numberDiff line change
@@ -2645,6 +2645,8 @@ static void jl_gc_queue_thread_local(jl_gc_mark_cache_t *gc_cache, jl_gc_mark_sp
26452645
{
26462646
gc_mark_queue_obj(gc_cache, sp, ptls2->current_task);
26472647
gc_mark_queue_obj(gc_cache, sp, ptls2->root_task);
2648+
if (ptls2->next_task)
2649+
gc_mark_queue_obj(gc_cache, sp, ptls2->next_task);
26482650
if (ptls2->previous_exception)
26492651
gc_mark_queue_obj(gc_cache, sp, ptls2->previous_exception);
26502652
}

src/julia_internal.h

+1
Original file line numberDiff line numberDiff line change
@@ -1011,6 +1011,7 @@ JL_DLLEXPORT int jl_array_isassigned(jl_array_t *a, size_t i);
10111011

10121012
JL_DLLEXPORT uintptr_t jl_object_id_(jl_value_t *tv, jl_value_t *v) JL_NOTSAFEPOINT;
10131013
JL_DLLEXPORT jl_value_t *jl_get_current_task(void);
1014+
JL_DLLEXPORT void jl_set_next_task(jl_task_t *task);
10141015

10151016
// -- synchronization utilities -- //
10161017

src/julia_threads.h

+1
Original file line numberDiff line numberDiff line change
@@ -185,6 +185,7 @@ struct _jl_tls_states_t {
185185
uv_cond_t wake_signal;
186186
volatile sig_atomic_t defer_signal;
187187
struct _jl_task_t *current_task;
188+
struct _jl_task_t *next_task;
188189
#ifdef MIGRATE_TASKS
189190
struct _jl_task_t *previous_task;
190191
#endif

src/task.c

+27-7
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,7 @@ static void NOINLINE save_stack(jl_ptls_t ptls, jl_task_t *lastt, jl_task_t **pt
125125
else {
126126
buf = lastt->stkbuf;
127127
}
128-
*pt = lastt; // clear the gc-root for the target task before copying the stack for saving
128+
*pt = NULL; // clear the gc-root for the target task before copying the stack for saving
129129
lastt->copy_stack = nb;
130130
lastt->sticky = 1;
131131
memcpy_a16((uint64_t*)buf, (uint64_t*)frame_addr, nb);
@@ -248,10 +248,24 @@ JL_DLLEXPORT void julia_init(JL_IMAGE_SEARCH rel)
248248
_julia_init(rel);
249249
}
250250

251+
JL_DLLEXPORT void jl_set_next_task(jl_task_t *task)
252+
{
253+
jl_get_ptls_states()->next_task = task;
254+
}
255+
256+
JL_DLLEXPORT jl_task_t *jl_get_next_task(void)
257+
{
258+
jl_ptls_t ptls = jl_get_ptls_states();
259+
if (ptls->next_task)
260+
return ptls->next_task;
261+
return ptls->current_task;
262+
}
263+
251264
void jl_release_task_stack(jl_ptls_t ptls, jl_task_t *task);
252265

253-
static void ctx_switch(jl_ptls_t ptls, jl_task_t **pt)
266+
static void ctx_switch(jl_ptls_t ptls)
254267
{
268+
jl_task_t **pt = &ptls->next_task;
255269
jl_task_t *t = *pt;
256270
assert(t != ptls->current_task);
257271
jl_task_t *lastt = ptls->current_task;
@@ -283,7 +297,7 @@ static void ctx_switch(jl_ptls_t ptls, jl_task_t **pt)
283297
}
284298

285299
if (killed) {
286-
*pt = lastt; // can't fail after here: clear the gc-root for the target task now
300+
*pt = NULL; // can't fail after here: clear the gc-root for the target task now
287301
lastt->gcstack = NULL;
288302
if (!lastt->copy_stack && lastt->stkbuf) {
289303
// early free of stkbuf back to the pool
@@ -302,7 +316,7 @@ static void ctx_switch(jl_ptls_t ptls, jl_task_t **pt)
302316
}
303317
else
304318
#endif
305-
*pt = lastt; // can't fail after here: clear the gc-root for the target task now
319+
*pt = NULL; // can't fail after here: clear the gc-root for the target task now
306320
lastt->gcstack = ptls->pgcstack;
307321
}
308322

@@ -366,10 +380,10 @@ static jl_ptls_t NOINLINE refetch_ptls(void)
366380
return jl_get_ptls_states();
367381
}
368382

369-
JL_DLLEXPORT void jl_switchto(jl_task_t **pt)
383+
JL_DLLEXPORT void jl_switch(void)
370384
{
371385
jl_ptls_t ptls = jl_get_ptls_states();
372-
jl_task_t *t = *pt;
386+
jl_task_t *t = ptls->next_task;
373387
jl_task_t *ct = ptls->current_task;
374388
if (t == ct) {
375389
return;
@@ -401,7 +415,7 @@ JL_DLLEXPORT void jl_switchto(jl_task_t **pt)
401415
jl_timing_block_stop(blk);
402416
#endif
403417

404-
ctx_switch(ptls, pt);
418+
ctx_switch(ptls);
405419

406420
#ifdef MIGRATE_TASKS
407421
ptls = refetch_ptls();
@@ -432,6 +446,12 @@ JL_DLLEXPORT void jl_switchto(jl_task_t **pt)
432446
jl_sigint_safepoint(ptls);
433447
}
434448

449+
JL_DLLEXPORT void jl_switchto(jl_task_t **pt)
450+
{
451+
jl_set_next_task(*pt);
452+
jl_switch();
453+
}
454+
435455
JL_DLLEXPORT JL_NORETURN void jl_no_exc_handler(jl_value_t *e)
436456
{
437457
jl_printf(JL_STDERR, "fatal: error thrown and no exception handler available.\n");

src/threading.c

+1
Original file line numberDiff line numberDiff line change
@@ -285,6 +285,7 @@ void jl_init_threadtls(int16_t tid)
285285
ptls->bt_data = bt_data;
286286
ptls->sig_exception = NULL;
287287
ptls->previous_exception = NULL;
288+
ptls->next_task = NULL;
288289
#ifdef _OS_WINDOWS_
289290
ptls->needs_resetstkoflw = 0;
290291
#endif

0 commit comments

Comments
 (0)