From 05be02d6aa1a7febeb10a8e91dcd840c303f2e38 Mon Sep 17 00:00:00 2001 From: Julian P Samaroo Date: Mon, 12 Feb 2024 13:41:26 -0700 Subject: [PATCH] Add processor helpers Add all_processors helper for fetching all known processors Add compatible_processors for fetching all processors matching a scope Reorganize some processor, context, and task TLS files --- src/Dagger.jl | 5 ++ src/context.jl | 102 ++++++++++++++++++++++ src/processor.jl | 181 ---------------------------------------- src/task-tls.jl | 39 +++++++++ src/threadproc.jl | 38 +++++++++ src/utils/processors.jl | 17 ++++ src/utils/scopes.jl | 24 ++++++ test/processors.jl | 8 ++ test/scopes.jl | 24 ++++++ 9 files changed, 257 insertions(+), 181 deletions(-) create mode 100644 src/context.jl create mode 100644 src/task-tls.jl create mode 100644 src/threadproc.jl create mode 100644 src/utils/processors.jl create mode 100644 src/utils/scopes.jl diff --git a/src/Dagger.jl b/src/Dagger.jl index 826f3509f..d4dd58002 100644 --- a/src/Dagger.jl +++ b/src/Dagger.jl @@ -28,7 +28,12 @@ include("utils/locked-object.jl") include("utils/tasks.jl") include("options.jl") include("processor.jl") +include("threadproc.jl") +include("context.jl") +include("utils/processors.jl") +include("task-tls.jl") include("scopes.jl") +include("utils/scopes.jl") include("eager_thunk.jl") include("queue.jl") include("thunk.jl") diff --git a/src/context.jl b/src/context.jl new file mode 100644 index 000000000..6ad3bfc7f --- /dev/null +++ b/src/context.jl @@ -0,0 +1,102 @@ +""" + Context(xs::Vector{OSProc}) -> Context + Context(xs::Vector{Int}) -> Context + +Create a Context, by default adding each available worker. + +It is also possible to create a Context from a vector of [`OSProc`](@ref), +or equivalently the underlying process ids can also be passed directly +as a `Vector{Int}`. + +Special fields include: +- 'log_sink': A log sink object to use, if any. +- `log_file::Union{String,Nothing}`: Path to logfile. If specified, at +scheduler termination, logs will be collected, combined with input thunks, and +written out in DOT format to this location. +- `profile::Bool`: Whether or not to perform profiling with Profile stdlib. +""" +mutable struct Context + procs::Vector{Processor} + proc_lock::ReentrantLock + proc_notify::Threads.Condition + log_sink::Any + log_file::Union{String,Nothing} + profile::Bool + options +end + +Context(procs::Vector{P}=Processor[OSProc(w) for w in procs()]; + proc_lock=ReentrantLock(), proc_notify=Threads.Condition(), + log_sink=TimespanLogging.NoOpLog(), log_file=nothing, profile=false, + options=nothing) where {P<:Processor} = + Context(procs, proc_lock, proc_notify, log_sink, log_file, + profile, options) +Context(xs::Vector{Int}; kwargs...) = Context(map(OSProc, xs); kwargs...) +Context(ctx::Context, xs::Vector=copy(procs(ctx))) = # make a copy + Context(xs; log_sink=ctx.log_sink, log_file=ctx.log_file, + profile=ctx.profile, options=ctx.options) + +const GLOBAL_CONTEXT = Ref{Context}() +function global_context() + if !isassigned(GLOBAL_CONTEXT) + GLOBAL_CONTEXT[] = Context() + end + return GLOBAL_CONTEXT[] +end + +""" + lock(f, ctx::Context) + +Acquire `ctx.proc_lock`, execute `f` with the lock held, and release the lock +when `f` returns. +""" +Base.lock(f, ctx::Context) = lock(f, ctx.proc_lock) + +""" + procs(ctx::Context) + +Fetch the list of procs currently known to `ctx`. +""" +procs(ctx::Context) = lock(ctx) do + copy(ctx.procs) +end + +""" + addprocs!(ctx::Context, xs) + +Add new workers `xs` to `ctx`. + +Workers will typically be assigned new tasks in the next scheduling iteration +if scheduling is ongoing. + +Workers can be either `Processor`s or the underlying process IDs as `Integer`s. +""" +addprocs!(ctx::Context, xs::AbstractVector{<:Integer}) = addprocs!(ctx, map(OSProc, xs)) +function addprocs!(ctx::Context, xs::AbstractVector{<:OSProc}) + lock(ctx) do + append!(ctx.procs, xs) + end + lock(ctx.proc_notify) do + notify(ctx.proc_notify) + end +end + +""" + rmprocs!(ctx::Context, xs) + +Remove the specified workers `xs` from `ctx`. + +Workers will typically finish all their assigned tasks if scheduling is ongoing +but will not be assigned new tasks after removal. + +Workers can be either `Processor`s or the underlying process IDs as `Integer`s. +""" +rmprocs!(ctx::Context, xs::AbstractVector{<:Integer}) = rmprocs!(ctx, map(OSProc, xs)) +function rmprocs!(ctx::Context, xs::AbstractVector{<:OSProc}) + lock(ctx) do + filter!(p -> (p ∉ xs), ctx.procs) + end + lock(ctx.proc_notify) do + notify(ctx.proc_notify) + end +end diff --git a/src/processor.jl b/src/processor.jl index c77453f62..d54eb249c 100644 --- a/src/processor.jl +++ b/src/processor.jl @@ -146,184 +146,3 @@ iscompatible_arg(proc::OSProc, opts, args...) = any(child-> all(arg->iscompatible_arg(child, opts, arg), args), children(proc)) - -""" - ThreadProc <: Processor - -Julia CPU (OS) thread, identified by Julia thread ID. -""" -struct ThreadProc <: Processor - owner::Int - tid::Int -end -iscompatible(proc::ThreadProc, opts, f, args...) = true -iscompatible_func(proc::ThreadProc, opts, f) = true -iscompatible_arg(proc::ThreadProc, opts, x) = true -function execute!(proc::ThreadProc, @nospecialize(f), @nospecialize(args...); @nospecialize(kwargs...)) - tls = get_tls() - task = Task() do - set_tls!(tls) - TimespanLogging.prof_task_put!(tls.sch_handle.thunk_id.id) - @invokelatest f(args...; kwargs...) - end - set_task_tid!(task, proc.tid) - schedule(task) - try - fetch(task) - catch err - @static if VERSION < v"1.7-rc1" - stk = Base.catch_stack(task) - else - stk = Base.current_exceptions(task) - end - err, frames = stk[1] - rethrow(CapturedException(err, frames)) - end -end -get_parent(proc::ThreadProc) = OSProc(proc.owner) -default_enabled(proc::ThreadProc) = true - -# TODO: ThreadGroupProc? - -""" - Context(xs::Vector{OSProc}) -> Context - Context(xs::Vector{Int}) -> Context - -Create a Context, by default adding each available worker. - -It is also possible to create a Context from a vector of [`OSProc`](@ref), -or equivalently the underlying process ids can also be passed directly -as a `Vector{Int}`. - -Special fields include: -- 'log_sink': A log sink object to use, if any. -- `log_file::Union{String,Nothing}`: Path to logfile. If specified, at -scheduler termination, logs will be collected, combined with input thunks, and -written out in DOT format to this location. -- `profile::Bool`: Whether or not to perform profiling with Profile stdlib. -""" -mutable struct Context - procs::Vector{Processor} - proc_lock::ReentrantLock - proc_notify::Threads.Condition - log_sink::Any - log_file::Union{String,Nothing} - profile::Bool - options -end - -Context(procs::Vector{P}=Processor[OSProc(w) for w in procs()]; - proc_lock=ReentrantLock(), proc_notify=Threads.Condition(), - log_sink=TimespanLogging.NoOpLog(), log_file=nothing, profile=false, - options=nothing) where {P<:Processor} = - Context(procs, proc_lock, proc_notify, log_sink, log_file, - profile, options) -Context(xs::Vector{Int}; kwargs...) = Context(map(OSProc, xs); kwargs...) -Context(ctx::Context, xs::Vector=copy(procs(ctx))) = # make a copy - Context(xs; log_sink=ctx.log_sink, log_file=ctx.log_file, - profile=ctx.profile, options=ctx.options) - -const GLOBAL_CONTEXT = Ref{Context}() -function global_context() - if !isassigned(GLOBAL_CONTEXT) - GLOBAL_CONTEXT[] = Context() - end - return GLOBAL_CONTEXT[] -end - -""" - lock(f, ctx::Context) - -Acquire `ctx.proc_lock`, execute `f` with the lock held, and release the lock -when `f` returns. -""" -Base.lock(f, ctx::Context) = lock(f, ctx.proc_lock) - -""" - procs(ctx::Context) - -Fetch the list of procs currently known to `ctx`. -""" -procs(ctx::Context) = lock(ctx) do - copy(ctx.procs) -end - -""" - addprocs!(ctx::Context, xs) - -Add new workers `xs` to `ctx`. - -Workers will typically be assigned new tasks in the next scheduling iteration -if scheduling is ongoing. - -Workers can be either `Processor`s or the underlying process IDs as `Integer`s. -""" -addprocs!(ctx::Context, xs::AbstractVector{<:Integer}) = addprocs!(ctx, map(OSProc, xs)) -function addprocs!(ctx::Context, xs::AbstractVector{<:OSProc}) - lock(ctx) do - append!(ctx.procs, xs) - end - lock(ctx.proc_notify) do - notify(ctx.proc_notify) - end -end - -""" - rmprocs!(ctx::Context, xs) - -Remove the specified workers `xs` from `ctx`. - -Workers will typically finish all their assigned tasks if scheduling is ongoing -but will not be assigned new tasks after removal. - -Workers can be either `Processor`s or the underlying process IDs as `Integer`s. -""" -rmprocs!(ctx::Context, xs::AbstractVector{<:Integer}) = rmprocs!(ctx, map(OSProc, xs)) -function rmprocs!(ctx::Context, xs::AbstractVector{<:OSProc}) - lock(ctx) do - filter!(p -> (p ∉ xs), ctx.procs) - end - lock(ctx.proc_notify) do - notify(ctx.proc_notify) - end -end - -# In-Thunk Helpers - -""" - thunk_processor() - -Get the current processor executing the current thunk. -""" -thunk_processor() = task_local_storage(:_dagger_processor)::Processor - -""" - in_thunk() - -Returns `true` if currently in a [`Thunk`](@ref) process, else `false`. -""" -in_thunk() = haskey(task_local_storage(), :_dagger_sch_uid) - -""" - get_tls() - -Gets all Dagger TLS variable as a `NamedTuple`. -""" -get_tls() = ( - sch_uid=task_local_storage(:_dagger_sch_uid), - sch_handle=task_local_storage(:_dagger_sch_handle), - processor=thunk_processor(), - task_spec=task_local_storage(:_dagger_task_spec), -) - -""" - set_tls!(tls) - -Sets all Dagger TLS variables from the `NamedTuple` `tls`. -""" -function set_tls!(tls) - task_local_storage(:_dagger_sch_uid, tls.sch_uid) - task_local_storage(:_dagger_sch_handle, tls.sch_handle) - task_local_storage(:_dagger_processor, tls.processor) - task_local_storage(:_dagger_task_spec, tls.task_spec) -end diff --git a/src/task-tls.jl b/src/task-tls.jl new file mode 100644 index 000000000..fb42dfbc9 --- /dev/null +++ b/src/task-tls.jl @@ -0,0 +1,39 @@ +# In-Thunk Helpers + +""" + thunk_processor() + +Get the current processor executing the current thunk. +""" +thunk_processor() = task_local_storage(:_dagger_processor)::Processor + +""" + in_thunk() + +Returns `true` if currently in a [`Thunk`](@ref) process, else `false`. +""" +in_thunk() = haskey(task_local_storage(), :_dagger_sch_uid) + +""" + get_tls() + +Gets all Dagger TLS variable as a `NamedTuple`. +""" +get_tls() = ( + sch_uid=task_local_storage(:_dagger_sch_uid), + sch_handle=task_local_storage(:_dagger_sch_handle), + processor=thunk_processor(), + task_spec=task_local_storage(:_dagger_task_spec), +) + +""" + set_tls!(tls) + +Sets all Dagger TLS variables from the `NamedTuple` `tls`. +""" +function set_tls!(tls) + task_local_storage(:_dagger_sch_uid, tls.sch_uid) + task_local_storage(:_dagger_sch_handle, tls.sch_handle) + task_local_storage(:_dagger_processor, tls.processor) + task_local_storage(:_dagger_task_spec, tls.task_spec) +end diff --git a/src/threadproc.jl b/src/threadproc.jl new file mode 100644 index 000000000..d0e897586 --- /dev/null +++ b/src/threadproc.jl @@ -0,0 +1,38 @@ +""" + ThreadProc <: Processor + +Julia CPU (OS) thread, identified by Julia thread ID. +""" +struct ThreadProc <: Processor + owner::Int + tid::Int +end +iscompatible(proc::ThreadProc, opts, f, args...) = true +iscompatible_func(proc::ThreadProc, opts, f) = true +iscompatible_arg(proc::ThreadProc, opts, x) = true +function execute!(proc::ThreadProc, @nospecialize(f), @nospecialize(args...); @nospecialize(kwargs...)) + tls = get_tls() + task = Task() do + set_tls!(tls) + TimespanLogging.prof_task_put!(tls.sch_handle.thunk_id.id) + @invokelatest f(args...; kwargs...) + end + set_task_tid!(task, proc.tid) + schedule(task) + try + fetch(task) + catch err + @static if VERSION < v"1.7-rc1" + stk = Base.catch_stack(task) + else + stk = Base.current_exceptions(task) + end + err, frames = stk[1] + rethrow(CapturedException(err, frames)) + end +end +get_parent(proc::ThreadProc) = OSProc(proc.owner) +default_enabled(proc::ThreadProc) = true + +# TODO: ThreadGroupProc? + diff --git a/src/utils/processors.jl b/src/utils/processors.jl new file mode 100644 index 000000000..a7dad2494 --- /dev/null +++ b/src/utils/processors.jl @@ -0,0 +1,17 @@ +# Processor utilities + +""" + all_processors(ctx::Context=Sch.eager_context()) -> Set{Processor} + +Returns the set of all processors available to the scheduler, across all +Distributed workers. +""" +function all_processors(ctx::Context=Sch.eager_context()) + all_procs = Set{Processor}() + for gproc in procs(ctx) + for proc in get_processors(gproc) + push!(all_procs, proc) + end + end + return all_procs +end diff --git a/src/utils/scopes.jl b/src/utils/scopes.jl new file mode 100644 index 000000000..27f970af6 --- /dev/null +++ b/src/utils/scopes.jl @@ -0,0 +1,24 @@ +# Scope-Processor helpers + +""" + compatible_processors(scope::AbstractScope, ctx::Context=Sch.eager_context()) -> Set{Processor} + +Returns the set of all processors (across all Distributed workers) that are +compatible with the given scope. +""" +function compatible_processors(scope::AbstractScope, ctx::Context=Sch.eager_context()) + compat_procs = Set{Processor}() + for gproc in procs(ctx) + # Fast-path in case entire process is incompatible + gproc_scope = ProcessScope(gproc) + if !isa(constrain(scope, gproc_scope), InvalidScope) + for proc in get_processors(gproc) + proc_scope = ExactScope(proc) + if !isa(constrain(scope, proc_scope), InvalidScope) + push!(compat_procs, proc) + end + end + end + end + return compat_procs +end diff --git a/test/processors.jl b/test/processors.jl index 9efcc1a4b..b960da68f 100644 --- a/test/processors.jl +++ b/test/processors.jl @@ -93,4 +93,12 @@ end end @test collect(delayed(mythunk)(1)) === ThreadProc end + + @testset "all_processors" begin + all_procs = Dagger.all_processors() + for w in procs() + w_procs = Dagger.get_processors(OSProc(w)) + @test all(proc->proc in all_procs, w_procs) + end + end end diff --git a/test/scopes.jl b/test/scopes.jl index d5d7e0a6b..5388273df 100644 --- a/test/scopes.jl +++ b/test/scopes.jl @@ -238,5 +238,29 @@ end end + @testset "compatible_processors" begin + scope = Dagger.scope(workers=[]) + comp_procs = Dagger.compatible_processors(scope) + @test !any(proc->proc in comp_procs, Dagger.get_processors(OSProc(wid1))) + @test !any(proc->proc in comp_procs, Dagger.get_processors(OSProc(wid2))) + + scope = Dagger.scope(worker=wid1) + comp_procs = Dagger.compatible_processors(scope) + @test all(proc->proc in comp_procs, Dagger.get_processors(OSProc(wid1))) + @test !any(proc->proc in comp_procs, Dagger.get_processors(OSProc(wid2))) + + scope = Dagger.scope(worker=wid1, thread=2) + comp_procs = Dagger.compatible_processors(scope) + @test length(comp_procs) == 1 + @test !all(proc->proc in comp_procs, Dagger.get_processors(OSProc(wid1))) + @test !all(proc->proc in comp_procs, Dagger.get_processors(OSProc(wid2))) + @test Dagger.ThreadProc(wid1, 2) in comp_procs + + scope = Dagger.scope(workers=[wid1, wid2]) + comp_procs = Dagger.compatible_processors(scope) + @test all(proc->proc in comp_procs, Dagger.get_processors(OSProc(wid1))) + @test all(proc->proc in comp_procs, Dagger.get_processors(OSProc(wid2))) + end + rmprocs([wid1, wid2]) end