Skip to content

Commit 3cbb1e5

Browse files
committed
remove Ref allocation on task switch
1 parent 86ee57c commit 3cbb1e5

File tree

7 files changed

+50
-13
lines changed

7 files changed

+50
-13
lines changed

base/task.jl

+14-9
Original file line numberDiff line numberDiff line change
@@ -615,6 +615,8 @@ function yield()
615615
end
616616
end
617617

618+
@inline set_next_task(t::Task) = ccall(:jl_set_next_task, Cvoid, (Any,), t)
619+
618620
"""
619621
yield(t::Task, arg = nothing)
620622
@@ -624,7 +626,8 @@ immediately yields to `t` before calling the scheduler.
624626
function yield(t::Task, @nospecialize(x=nothing))
625627
t.result = x
626628
enq_work(current_task())
627-
return try_yieldto(ensure_rescheduled, Ref(t))
629+
set_next_task(t)
630+
return try_yieldto(ensure_rescheduled)
628631
end
629632

630633
"""
@@ -637,14 +640,15 @@ or scheduling in any way. Its use is discouraged.
637640
"""
638641
function yieldto(t::Task, @nospecialize(x=nothing))
639642
t.result = x
640-
return try_yieldto(identity, Ref(t))
643+
set_next_task(t)
644+
return try_yieldto(identity)
641645
end
642646

643-
function try_yieldto(undo, reftask::Ref{Task})
647+
function try_yieldto(undo)
644648
try
645-
ccall(:jl_switchto, Cvoid, (Any,), reftask)
649+
ccall(:jl_switch, Cvoid, ())
646650
catch
647-
undo(reftask[])
651+
undo(ccall(:jl_get_next_task, Ref{Task}, ()))
648652
rethrow()
649653
end
650654
ct = current_task()
@@ -696,18 +700,19 @@ function trypoptask(W::StickyWorkqueue)
696700
return t
697701
end
698702

699-
@noinline function poptaskref(W::StickyWorkqueue)
703+
@noinline function poptask(W::StickyWorkqueue)
700704
task = trypoptask(W)
701705
if !(task isa Task)
702706
task = ccall(:jl_task_get_next, Ref{Task}, (Any, Any), trypoptask, W)
703707
end
704-
return Ref(task)
708+
set_next_task(task)
709+
nothing
705710
end
706711

707712
function wait()
708713
W = Workqueues[Threads.threadid()]
709-
reftask = poptaskref(W)
710-
result = try_yieldto(ensure_rescheduled, reftask)
714+
poptask(W)
715+
result = try_yieldto(ensure_rescheduled)
711716
process_events()
712717
# return when we come out of the queue
713718
return result

src/ccall.cpp

+10
Original file line numberDiff line numberDiff line change
@@ -1619,6 +1619,16 @@ static jl_cgval_t emit_ccall(jl_codectx_t &ctx, jl_value_t **args, size_t nargs)
16191619
tbaa_decorate(tbaa_const, ctx.builder.CreateLoad(pct)),
16201620
retboxed, rt, unionall, static_rt);
16211621
}
1622+
else if (is_libjulia_func(jl_set_next_task)) {
1623+
assert(lrt == T_void);
1624+
assert(!isVa && !llvmcall && nccallargs == 1);
1625+
JL_GC_POP();
1626+
Value *ptls_pv = emit_bitcast(ctx, ctx.ptlsStates, T_ppjlvalue);
1627+
const int nt_offset = offsetof(jl_tls_states_t, next_task);
1628+
Value *pnt = ctx.builder.CreateGEP(ptls_pv, ConstantInt::get(T_size, nt_offset / sizeof(void*)));
1629+
ctx.builder.CreateStore(emit_pointer_from_objref(ctx, boxed(ctx, argv[0])), pnt);
1630+
return ghostValue(jl_nothing_type);
1631+
}
16221632
else if (is_libjulia_func(jl_sigatomic_begin)) {
16231633
assert(lrt == T_void);
16241634
assert(!isVa && !llvmcall && nccallargs == 0);

src/gc.c

+2
Original file line numberDiff line numberDiff line change
@@ -2645,6 +2645,8 @@ static void jl_gc_queue_thread_local(jl_gc_mark_cache_t *gc_cache, jl_gc_mark_sp
26452645
{
26462646
gc_mark_queue_obj(gc_cache, sp, ptls2->current_task);
26472647
gc_mark_queue_obj(gc_cache, sp, ptls2->root_task);
2648+
if (ptls2->next_task)
2649+
gc_mark_queue_obj(gc_cache, sp, ptls2->next_task);
26482650
if (ptls2->previous_exception)
26492651
gc_mark_queue_obj(gc_cache, sp, ptls2->previous_exception);
26502652
}

src/julia_internal.h

+1
Original file line numberDiff line numberDiff line change
@@ -1011,6 +1011,7 @@ JL_DLLEXPORT int jl_array_isassigned(jl_array_t *a, size_t i);
10111011

10121012
JL_DLLEXPORT uintptr_t jl_object_id_(jl_value_t *tv, jl_value_t *v) JL_NOTSAFEPOINT;
10131013
JL_DLLEXPORT jl_value_t *jl_get_current_task(void);
1014+
JL_DLLEXPORT void jl_set_next_task(jl_task_t *task);
10141015

10151016
// -- synchronization utilities -- //
10161017

src/julia_threads.h

+1
Original file line numberDiff line numberDiff line change
@@ -185,6 +185,7 @@ struct _jl_tls_states_t {
185185
uv_cond_t wake_signal;
186186
volatile sig_atomic_t defer_signal;
187187
struct _jl_task_t *current_task;
188+
struct _jl_task_t *next_task;
188189
#ifdef MIGRATE_TASKS
189190
struct _jl_task_t *previous_task;
190191
#endif

src/task.c

+21-4
Original file line numberDiff line numberDiff line change
@@ -248,10 +248,21 @@ JL_DLLEXPORT void julia_init(JL_IMAGE_SEARCH rel)
248248
_julia_init(rel);
249249
}
250250

251+
JL_DLLEXPORT void jl_set_next_task(jl_task_t *task)
252+
{
253+
jl_get_ptls_states()->next_task = task;
254+
}
255+
256+
JL_DLLEXPORT jl_task_t *jl_get_next_task(void)
257+
{
258+
return jl_get_ptls_states()->next_task;
259+
}
260+
251261
void jl_release_task_stack(jl_ptls_t ptls, jl_task_t *task);
252262

253-
static void ctx_switch(jl_ptls_t ptls, jl_task_t **pt)
263+
static void ctx_switch(jl_ptls_t ptls)
254264
{
265+
jl_task_t **pt = &ptls->next_task;
255266
jl_task_t *t = *pt;
256267
assert(t != ptls->current_task);
257268
jl_task_t *lastt = ptls->current_task;
@@ -366,10 +377,10 @@ static jl_ptls_t NOINLINE refetch_ptls(void)
366377
return jl_get_ptls_states();
367378
}
368379

369-
JL_DLLEXPORT void jl_switchto(jl_task_t **pt)
380+
JL_DLLEXPORT void jl_switch(void)
370381
{
371382
jl_ptls_t ptls = jl_get_ptls_states();
372-
jl_task_t *t = *pt;
383+
jl_task_t *t = ptls->next_task;
373384
jl_task_t *ct = ptls->current_task;
374385
if (t == ct) {
375386
return;
@@ -401,7 +412,7 @@ JL_DLLEXPORT void jl_switchto(jl_task_t **pt)
401412
jl_timing_block_stop(blk);
402413
#endif
403414

404-
ctx_switch(ptls, pt);
415+
ctx_switch(ptls);
405416

406417
#ifdef MIGRATE_TASKS
407418
ptls = refetch_ptls();
@@ -432,6 +443,12 @@ JL_DLLEXPORT void jl_switchto(jl_task_t **pt)
432443
jl_sigint_safepoint(ptls);
433444
}
434445

446+
JL_DLLEXPORT void jl_switchto(jl_task_t **pt)
447+
{
448+
jl_set_next_task(*pt);
449+
jl_switch();
450+
}
451+
435452
JL_DLLEXPORT JL_NORETURN void jl_no_exc_handler(jl_value_t *e)
436453
{
437454
jl_printf(JL_STDERR, "fatal: error thrown and no exception handler available.\n");

src/threading.c

+1
Original file line numberDiff line numberDiff line change
@@ -285,6 +285,7 @@ void jl_init_threadtls(int16_t tid)
285285
ptls->bt_data = bt_data;
286286
ptls->sig_exception = NULL;
287287
ptls->previous_exception = NULL;
288+
ptls->next_task = NULL;
288289
#ifdef _OS_WINDOWS_
289290
ptls->needs_resetstkoflw = 0;
290291
#endif

0 commit comments

Comments
 (0)