Skip to content

Commit db14903

Browse files
committed
remove threadedregion and move jl_threading_run to julia
1 parent a1165b8 commit db14903

File tree

5 files changed

+26
-111
lines changed

5 files changed

+26
-111
lines changed

base/task.jl

+8
Original file line numberDiff line numberDiff line change
@@ -431,6 +431,14 @@ function __preinit_threads__()
431431
nothing
432432
end
433433

434+
function _run_on(t::Task, tid)
435+
@assert !istaskstarted(t)
436+
t.sticky = true
437+
ccall(:jl_set_task_tid, Cvoid, (Any, Cint), t, tid-1)
438+
schedule(t)
439+
return t
440+
end
441+
434442
function enq_work(t::Task)
435443
(t.state == :runnable && t.queue === nothing) || error("schedule: Task not runnable")
436444
tid = Threads.threadid(t)

base/threadingconstructs.jl

+13-21
Original file line numberDiff line numberDiff line change
@@ -18,26 +18,18 @@ on `threadid()`.
1818
"""
1919
nthreads() = Int(unsafe_load(cglobal(:jl_n_threads, Cint)))
2020

21-
# Only read/written by the main thread
22-
const in_threaded_loop = Ref(false)
23-
2421
function _threadsfor(iter,lbody)
2522
lidx = iter.args[1] # index
2623
range = iter.args[2]
2724
quote
2825
local threadsfor_fun
2926
let range = $(esc(range))
30-
function threadsfor_fun(onethread=false)
27+
function threadsfor_fun()
3128
r = range # Load into local variable
3229
lenr = length(r)
3330
# divide loop iterations among threads
34-
if onethread
35-
tid = 1
36-
len, rem = lenr, 0
37-
else
38-
tid = threadid()
39-
len, rem = divrem(lenr, nthreads())
40-
end
31+
tid = threadid()
32+
len, rem = divrem(lenr, nthreads())
4133
# not enough iterations for all the threads?
4234
if len == 0
4335
if tid > rem
@@ -65,20 +57,20 @@ function _threadsfor(iter,lbody)
6557
end
6658
end
6759
end
68-
# Hack to make nested threaded loops kinda work
69-
if threadid() != 1 || in_threaded_loop[]
70-
# We are in a nested threaded loop
71-
Base.invokelatest(threadsfor_fun, true)
72-
else
73-
in_threaded_loop[] = true
74-
# the ccall is not expected to throw
75-
ccall(:jl_threading_run, Cvoid, (Any,), threadsfor_fun)
76-
in_threaded_loop[] = false
77-
end
60+
threading_run(threadsfor_fun)
7861
nothing
7962
end
8063
end
8164

65+
function threading_run(func)
66+
tasks = Vector{Task}(undef, nthreads())
67+
for tid = 1:nthreads()
68+
tasks[tid] = Base._run_on(Task(func), tid)
69+
end
70+
foreach(wait, tasks)
71+
return nothing
72+
end
73+
8274
"""
8375
Threads.@threads
8476

src/jl_uv.c

+2-4
Original file line numberDiff line numberDiff line change
@@ -201,12 +201,10 @@ JL_DLLEXPORT void jl_uv_req_set_data(uv_req_t *req, void *data) { req->data = da
201201
JL_DLLEXPORT void *jl_uv_handle_data(uv_handle_t *handle) { return handle->data; }
202202
JL_DLLEXPORT void *jl_uv_write_handle(uv_write_t *req) { return req->handle; }
203203

204-
extern volatile unsigned _threadedregion;
205-
206204
JL_DLLEXPORT int jl_run_once(uv_loop_t *loop)
207205
{
208206
jl_ptls_t ptls = jl_get_ptls_states();
209-
if (loop && (_threadedregion || ptls->tid == 0)) {
207+
if (loop) {
210208
jl_gc_safepoint_(ptls);
211209
JL_UV_LOCK();
212210
loop->stop_flag = 0;
@@ -220,7 +218,7 @@ JL_DLLEXPORT int jl_run_once(uv_loop_t *loop)
220218
JL_DLLEXPORT int jl_process_events(uv_loop_t *loop)
221219
{
222220
jl_ptls_t ptls = jl_get_ptls_states();
223-
if (loop && (_threadedregion || ptls->tid == 0)) {
221+
if (loop) {
224222
jl_gc_safepoint_(ptls);
225223
if (jl_mutex_trylock(&jl_uv_mutex)) {
226224
loop->stop_flag = 0;

src/partr.c

+3-14
Original file line numberDiff line numberDiff line change
@@ -411,8 +411,6 @@ static int may_sleep(jl_ptls_t ptls)
411411
return jl_atomic_load(&sleep_check_state) == sleeping && jl_atomic_load(&ptls->sleep_check_state) == sleeping;
412412
}
413413

414-
extern volatile unsigned _threadedregion;
415-
416414
JL_DLLEXPORT jl_task_t *jl_task_get_next(jl_value_t *getsticky)
417415
{
418416
jl_ptls_t ptls = jl_get_ptls_states();
@@ -434,7 +432,7 @@ JL_DLLEXPORT jl_task_t *jl_task_get_next(jl_value_t *getsticky)
434432
#endif
435433

436434
jl_cpu_pause();
437-
if (sleep_check_after_threshold(&start_cycles) || (!_threadedregion && ptls->tid == 0)) {
435+
if (sleep_check_after_threshold(&start_cycles)) {
438436
if (!sleep_check_now(ptls->tid))
439437
continue;
440438
jl_atomic_store(&ptls->sleep_check_state, sleeping); // acquire sleep-check lock
@@ -446,14 +444,7 @@ JL_DLLEXPORT jl_task_t *jl_task_get_next(jl_value_t *getsticky)
446444
// although none are allowed to create new ones
447445
// outside of threaded regions, all IO is permitted,
448446
// but only on thread 1
449-
int uvlock = 0;
450-
if (_threadedregion) {
451-
uvlock = jl_mutex_trylock(&jl_uv_mutex);
452-
}
453-
else if (ptls->tid == 0) {
454-
uvlock = 1;
455-
JL_UV_LOCK();
456-
}
447+
int uvlock = jl_mutex_trylock(&jl_uv_mutex);
457448
if (uvlock) {
458449
int active = 1;
459450
if (jl_atomic_load(&jl_uv_n_waiters) != 0) {
@@ -483,9 +474,7 @@ JL_DLLEXPORT jl_task_t *jl_task_get_next(jl_value_t *getsticky)
483474
// to the last thread to do an explicit operation,
484475
// which may starve other threads of critical work
485476
}
486-
if (!_threadedregion && active && ptls->tid == 0) {
487-
// thread 0 is the only thread permitted to run the event loop
488-
// so it needs to stay alive
477+
if (active) {
489478
start_cycles = 0;
490479
continue;
491480
}

src/threading.c

-72
Original file line numberDiff line numberDiff line change
@@ -473,78 +473,6 @@ void jl_start_threads(void)
473473

474474
#endif
475475

476-
unsigned volatile _threadedregion; // HACK: keep track of whether it is safe to do IO
477-
478-
// simple fork/join mode code
479-
JL_DLLEXPORT void jl_threading_run(jl_value_t *func)
480-
{
481-
jl_ptls_t ptls = jl_get_ptls_states();
482-
int8_t gc_state = jl_gc_unsafe_enter(ptls);
483-
size_t world = jl_world_counter;
484-
jl_method_instance_t *mfunc = jl_lookup_generic(&func, 1, jl_int32hash_fast(jl_return_address()), world);
485-
// Ignore constant return value for now.
486-
jl_code_instance_t *fptr = jl_compile_method_internal(mfunc, world);
487-
if (fptr->invoke == jl_fptr_const_return)
488-
return;
489-
490-
size_t nthreads = jl_n_threads;
491-
jl_svec_t *ts = jl_alloc_svec(nthreads);
492-
JL_GC_PUSH1(&ts);
493-
jl_value_t *wait_func = jl_get_global(jl_base_module, jl_symbol("wait"));
494-
jl_value_t *schd_func = jl_get_global(jl_base_module, jl_symbol("schedule"));
495-
// create and schedule all tasks
496-
_threadedregion += 1;
497-
for (int i = 0; i < nthreads; i++) {
498-
jl_value_t *args2[2];
499-
args2[0] = (jl_value_t*)jl_task_type;
500-
args2[1] = func;
501-
jl_task_t *t = (jl_task_t*)jl_apply(args2, 2);
502-
jl_svecset(ts, i, t);
503-
t->sticky = 1;
504-
t->tid = i;
505-
args2[0] = schd_func;
506-
args2[1] = (jl_value_t*)t;
507-
jl_apply(args2, 2);
508-
#ifdef JULIA_ENABLE_THREADING
509-
if (i == 1) {
510-
// let threads know work is coming (optimistic)
511-
jl_wakeup_thread(-1);
512-
}
513-
#endif
514-
}
515-
#ifdef JULIA_ENABLE_THREADING
516-
if (nthreads > 2) {
517-
// let threads know work is ready (guaranteed)
518-
jl_wakeup_thread(-1);
519-
}
520-
#endif
521-
// join with all tasks
522-
JL_TRY {
523-
for (int i = 0; i < nthreads; i++) {
524-
jl_value_t *t = jl_svecref(ts, i);
525-
jl_value_t *args[2] = { wait_func, t };
526-
jl_apply(args, 2);
527-
}
528-
}
529-
JL_CATCH {
530-
_threadedregion -= 1;
531-
jl_wake_libuv();
532-
JL_UV_LOCK();
533-
JL_UV_UNLOCK();
534-
jl_rethrow();
535-
}
536-
// make sure no threads are sitting in the event loop
537-
_threadedregion -= 1;
538-
jl_wake_libuv();
539-
// make sure no more callbacks will run while user code continues
540-
// outside thread region and might touch an I/O object.
541-
JL_UV_LOCK();
542-
JL_UV_UNLOCK();
543-
JL_GC_POP();
544-
jl_gc_unsafe_leave(ptls, gc_state);
545-
}
546-
547-
548476
#ifndef JULIA_ENABLE_THREADING
549477

550478
void jl_init_threading(void)

0 commit comments

Comments
 (0)