Skip to content

Commit 687378e

Browse files
authored
allow default_addprocs_params to be specialized on ClusterManager (#38570)
I made this change to allow the set to be expanded for my own package, but I noticed this also helps unify #38353 and existing ssh-only options.
1 parent 87bfa51 commit 687378e

File tree

3 files changed

+30
-12
lines changed

3 files changed

+30
-12
lines changed

stdlib/Distributed/docs/src/index.md

+1
Original file line numberDiff line numberDiff line change
@@ -65,4 +65,5 @@ Distributed.connect(::ClusterManager, ::Int, ::WorkerConfig)
6565
Distributed.init_worker
6666
Distributed.start_worker
6767
Distributed.process_messages
68+
Distributed.default_addprocs_params
6869
```

stdlib/Distributed/src/cluster.jl

+9-5
Original file line numberDiff line numberDiff line change
@@ -448,7 +448,7 @@ function addprocs(manager::ClusterManager; kwargs...)
448448
end
449449

450450
function addprocs_locked(manager::ClusterManager; kwargs...)
451-
params = merge(default_addprocs_params(), Dict{Symbol,Any}(kwargs))
451+
params = merge(default_addprocs_params(manager), Dict{Symbol,Any}(kwargs))
452452
topology(Symbol(params[:topology]))
453453

454454
if PGRP.topology !== :all_to_all
@@ -513,12 +513,16 @@ function set_valid_processes(plist::Array{Int})
513513
end
514514
end
515515

516+
"""
517+
default_addprocs_params(mgr::ClusterManager) -> Dict{Symbol, Any}
518+
519+
Implemented by cluster managers. The default keyword parameters passed when calling
520+
`addprocs(mgr)`. The minimal set of options is available by calling
521+
`default_addprocs_params()`
522+
"""
523+
default_addprocs_params(::ClusterManager) = default_addprocs_params()
516524
default_addprocs_params() = Dict{Symbol,Any}(
517525
:topology => :all_to_all,
518-
:ssh => "ssh",
519-
:shell => :posix,
520-
:cmdline_cookie => false,
521-
:env => [],
522526
:dir => pwd(),
523527
:exename => joinpath(Sys.BINDIR::String, julia_exename()),
524528
:exeflags => ``,

stdlib/Distributed/src/managers.jl

+20-7
Original file line numberDiff line numberDiff line change
@@ -34,8 +34,8 @@ struct SSHManager <: ClusterManager
3434
end
3535

3636

37-
function check_addprocs_args(kwargs)
38-
valid_kw_names = collect(keys(default_addprocs_params()))
37+
function check_addprocs_args(manager, kwargs)
38+
valid_kw_names = keys(default_addprocs_params(manager))
3939
for keyname in keys(kwargs)
4040
!(keyname in valid_kw_names) && throw(ArgumentError("Invalid keyword argument $(keyname)"))
4141
end
@@ -137,11 +137,23 @@ This timeout can be controlled via environment variable `JULIA_WORKER_TIMEOUT`.
137137
The value of `JULIA_WORKER_TIMEOUT` on the master process specifies the number of seconds a
138138
newly launched worker waits for connection establishment.
139139
"""
140-
function addprocs(machines::AbstractVector; tunnel=false, multiplex=false, sshflags=``, max_parallel=10, kwargs...)
141-
check_addprocs_args(kwargs)
142-
addprocs(SSHManager(machines); tunnel=tunnel, multiplex=multiplex, sshflags=sshflags, max_parallel=max_parallel, kwargs...)
140+
function addprocs(machines::AbstractVector; kwargs...)
141+
manager = SSHManager(machines)
142+
check_addprocs_args(manager, kwargs)
143+
addprocs(manager; kwargs...)
143144
end
144145

146+
default_addprocs_params(::SSHManager) =
147+
merge(default_addprocs_params(),
148+
Dict{Symbol,Any}(
149+
:ssh => "ssh",
150+
:sshflags => ``,
151+
:shell => :posix,
152+
:cmdline_cookie => false,
153+
:env => [],
154+
:tunnel => false,
155+
:multiplex => false,
156+
:max_parallel => 10))
145157

146158
function launch(manager::SSHManager, params::Dict, launched::Array, launch_ntfy::Condition)
147159
# Launch one worker on each unique host in parallel. Additional workers are launched later.
@@ -426,8 +438,9 @@ processes on the local machine. If `restrict` is `true`, binding is restricted t
426438
`enable_threaded_blas` have the same effect as documented for `addprocs(machines)`.
427439
"""
428440
function addprocs(np::Integer; restrict=true, kwargs...)
429-
check_addprocs_args(kwargs)
430-
addprocs(LocalManager(np, restrict); kwargs...)
441+
manager = LocalManager(np, restrict)
442+
check_addprocs_args(manager, kwargs)
443+
addprocs(manager; kwargs...)
431444
end
432445

433446
Base.show(io::IO, manager::LocalManager) = print(io, "LocalManager()")

0 commit comments

Comments
 (0)