Add keyword argument support #394

Merged · 2 commits · May 18, 2023
12 changes: 11 additions & 1 deletion .buildkite/pipeline.yml
@@ -35,12 +35,22 @@ steps:
julia_args: "--threads=1"
- JuliaCI/julia-coverage#v1:
codecov: true
- label: Julia 1.9
timeout_in_minutes: 60
<<: *test
plugins:
- JuliaCI/julia#v1:
version: "1.9"
- JuliaCI/julia-test#v1:
julia_args: "--threads=1"
- JuliaCI/julia-coverage#v1:
codecov: true
- label: Julia nightly
timeout_in_minutes: 60
<<: *test
plugins:
- JuliaCI/julia#v1:
version: "1.9-nightly"
version: "1.10-nightly"
- JuliaCI/julia-test#v1:
julia_args: "--threads=1"
- JuliaCI/julia-coverage#v1:
6 changes: 3 additions & 3 deletions docs/src/checkpointing.md
@@ -54,17 +54,17 @@ Let's see how we'd modify the above example to use checkpointing:

```julia
using Serialization

X = compute(randn(Blocks(128,128), 1024, 1024))
Y = [delayed(sum; options=Dagger.Sch.ThunkOptions(;
checkpoint=(thunk,result)->begin
Y = [delayed(sum; checkpoint=(thunk,result)->begin
open("checkpoint-$idx.bin", "w") do io
serialize(io, collect(result))
end
end, restore=(thunk)->begin
open("checkpoint-$idx.bin", "r") do io
Dagger.tochunk(deserialize(io))
end
end))(chunk) for (idx,chunk) in enumerate(X.chunks)]
end)(chunk) for (idx,chunk) in enumerate(X.chunks)]
inner(x...) = sqrt(sum(x))
Z = delayed(inner)(Y...)
z = collect(Z)
45 changes: 23 additions & 22 deletions docs/src/index.md
@@ -2,32 +2,34 @@

## Usage

The main function for using Dagger is `spawn`:
The main entrypoint to Dagger is `@spawn`:

`Dagger.spawn(f, args...; options...)`
`Dagger.@spawn [option=value]... f(args...; kwargs...)`

or `@spawn` for the more convenient macro form:
or `spawn` if it's more convenient:

`Dagger.@spawn [option=value]... f(args...)`
`Dagger.spawn(f, Dagger.Options(options), args...; kwargs...)`

When called, it creates an `EagerThunk` (also known as a "thunk" or "task")
object representing a call to function `f` with the arguments `args`. If it is
called with other thunks as inputs, such as in `Dagger.@spawn f(Dagger.@spawn
g())`, then the function `f` gets passed the results of those input thunks. If
those thunks aren't yet finished executing, then the execution of `f` waits on
all of its input thunks to complete before executing.
object representing a call to function `f` with the arguments `args` and
keyword arguments `kwargs`. If it is called with other thunks as args/kwargs,
such as in `Dagger.@spawn f(Dagger.@spawn g())`, then the function `f` gets
passed the results of those input thunks, once they're available. If those
thunks aren't yet finished executing, then the execution of `f` waits on all of
its input thunks to complete before executing.

The key point is that, for each argument to a thunk, if the argument is an
`EagerThunk`, it'll be executed before this node and its result will be passed
into the function `f`. If the argument is *not* an `EagerThunk` (instead, some
other type of Julia object), it'll be passed as-is to the function `f`.
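
As a brief sketch of the behavior described above (assuming the keyword-argument support this PR adds; `init` is just an ordinary keyword of `Base.sum`):

```julia
using Dagger

a = Dagger.@spawn rand(4)            # returns an EagerThunk
b = Dagger.@spawn sum(a; init=10.0)  # `a` resolves to its result before `sum` runs
fetch(b)                             # a Float64 of at least 10.0
```

Here `a` is an `EagerThunk` argument to `sum`, so `b` waits on `a` and receives its value, while `init=10.0` is a plain keyword argument forwarded to `sum` unchanged.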

Thunks don't accept regular keyword arguments for the function `f`. Instead,
the `options` kwargs are passed to the scheduler to control its behavior:
The `Options` struct in the second argument position is optional; if provided,
it is passed to the scheduler to control its behavior. `Options` contains a
`NamedTuple` of option key-value pairs, which can be any of:
- Any field in `Dagger.Sch.ThunkOptions` (see [Scheduler and Thunk options](@ref))
- `meta::Bool` -- Pass the input `Chunk` objects themselves to `f` and not the value contained in them

There are also some extra kwargs that can be passed, although they're considered advanced options to be used only by developers or library authors:
There are also some extra options that can be passed, although they're considered advanced options to be used only by developers or library authors:
- `get_result::Bool` -- return the actual result to the scheduler instead of `Chunk` objects. Used when `f` explicitly constructs a Chunk or when return value is small (e.g. in case of reduce)
- `persist::Bool` -- the result of this Thunk should not be released after it becomes unused in the DAG
- `cache::Bool` -- cache the result of this Thunk such that if the thunk is evaluated again, one can just reuse the cached value. If it’s been removed from cache, recompute the value.
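
A hedged sketch of the two equivalent call forms described above (option names as documented; `single=1` pins execution to worker 1):

```julia
using Dagger

x = rand(100)

# Macro form: options precede the call
t1 = Dagger.@spawn single=1 sum(x; init=0.0)

# Function form: options go in a Dagger.Options in the second argument position
t2 = Dagger.spawn(sum, Dagger.Options(;single=1), x; init=0.0)

fetch(t1) == fetch(t2)  # same input, same worker, same result
```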
@@ -133,18 +135,18 @@ via `@par` or `delayed`. The above computation can be executed with the lazy
API by substituting `@spawn` with `@par` and `fetch` with `collect`:

```julia
p = @par add1(4)
q = @par add2(p)
r = @par add1(3)
s = @par combine(p, q, r)
p = Dagger.@par add1(4)
q = Dagger.@par add2(p)
r = Dagger.@par add1(3)
s = Dagger.@par combine(p, q, r)

@assert collect(s) == 16
```

or similarly, in block form:

```julia
s = @par begin
s = Dagger.@par begin
p = add1(4)
q = add2(p)
r = add1(3)
@@ -159,7 +161,7 @@ operation, you can call `compute` on the thunk. This will return a `Chunk`
object which references the result (see [Chunks](@ref) for more details):

```julia
x = @par 1+2
x = Dagger.@par 1+2
cx = compute(x)
cx::Chunk
@assert collect(cx) == 3
@@ -207,7 +209,7 @@ Scheduler options can be constructed and passed to `collect()` or `compute()`
as the keyword argument `options` for lazy API usage:

```julia
t = @par 1+2
t = Dagger.@par 1+2
opts = Dagger.Sch.SchedulerOptions(;single=1) # Execute on worker 1

compute(t; options=opts)
@@ -221,10 +223,9 @@ Thunk options can be passed to `@spawn/spawn`, `@par`, and `delayed` similarly:
# Execute on worker 1

Dagger.@spawn single=1 1+2
Dagger.spawn(+, 1, 2; single=1)
Dagger.spawn(+, Dagger.Options(;single=1), 1, 2)

opts = Dagger.Sch.ThunkOptions(;single=1)
delayed(+)(1, 2; options=opts)
delayed(+; single=1)(1, 2)
```

### Core vs. Worker Schedulers
2 changes: 1 addition & 1 deletion docs/src/propagation.md
@@ -1,6 +1,6 @@
# Option Propagation

Most options passed to Dagger are passed via `delayed` or `Dagger.@spawn`
Most options passed to Dagger are passed via `@spawn/spawn` or `delayed`
directly. This works well when an option only needs to be set for a single
thunk, but is cumbersome when the same option needs to be set on multiple
thunks, or set recursively on thunks spawned within other thunks. Thankfully,
2 changes: 1 addition & 1 deletion src/array/darray.jl
@@ -251,7 +251,7 @@ function thunkize(ctx::Context, c::DArray; persist=true)
if persist
foreach(persist!, thunks)
end
Thunk(thunks...; meta=true) do results...
Thunk(map(thunk->nothing=>thunk, thunks)...; meta=true) do results...
t = eltype(results[1])
DArray(t, dmn, dmnchunks,
reshape(Union{Chunk,Thunk}[results...], sz))
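
The `nothing=>thunk` wrapping above (and throughout this diff) reflects the new internal representation of `Thunk` inputs: each input is stored as a `position => argument` pair, where the position is `nothing` for a positional argument and a `Symbol` for a keyword argument. A minimal illustrative sketch of splitting such pairs back into args and kwargs (not Dagger's actual internals):

```julia
# Each input is a position=>argument pair: `nothing` marks a positional arg,
# a Symbol marks the keyword it binds to.
inputs = Pair{Union{Symbol,Nothing},Any}[nothing => 2, nothing => 3, :init => 10.0]

args   = Any[v for (k, v) in inputs if k === nothing]
kwargs = Pair{Symbol,Any}[k => v for (k, v) in inputs if k !== nothing]

f(xs...; init=0) = init + sum(xs)
f(args...; kwargs...)  # 10.0 + (2 + 3) == 15.0
```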
10 changes: 5 additions & 5 deletions src/array/map-reduce.jl
@@ -19,7 +19,7 @@ function stage(ctx::Context, node::Map)
f = node.f
for i=eachindex(domains)
inps = map(x->chunks(x)[i], inputs)
thunks[i] = Thunk((args...) -> map(f, args...), inps...)
thunks[i] = Thunk((args...) -> map(f, args...), map(inp->nothing=>inp, inps)...)
end
DArray(Any, domain(primary), domainchunks(primary), thunks)
end
@@ -40,8 +40,8 @@

function stage(ctx::Context, r::ReduceBlock)
inp = stage(ctx, r.input)
reduced_parts = map(x -> Thunk(r.op, x; get_result=r.get_result), chunks(inp))
Thunk((xs...) -> r.op_master(xs), reduced_parts...; meta=true)
reduced_parts = map(x -> Thunk(r.op, nothing=>x; get_result=r.get_result), chunks(inp))
Thunk((xs...) -> r.op_master(xs), map(part->nothing=>part, reduced_parts)...; meta=true)
end

reduceblock_async(f, x::ArrayOp; get_result=true) = ReduceBlock(f, f, x, get_result)
@@ -126,10 +126,10 @@ function stage(ctx::Context, r::Reducedim)
inp = cached_stage(ctx, r.input)
thunks = let op = r.op, dims=r.dims
# do reducedim on each block
tmp = map(p->Thunk(b->reduce(op,b,dims=dims), p), chunks(inp))
tmp = map(p->Thunk(b->reduce(op,b,dims=dims), nothing=>p), chunks(inp))
# combine the results in tree fashion
treereducedim(tmp, r.dims) do x,y
Thunk(op, x,y)
Thunk(op, nothing=>x, nothing=>y)
end
end
c = domainchunks(inp)
10 changes: 5 additions & 5 deletions src/array/matrix.jl
@@ -17,10 +17,10 @@ function size(x::Transpose)
end

transpose(x::ArrayOp) = Transpose(transpose, x)
transpose(x::Union{Chunk, Thunk}) = Thunk(transpose, x)
transpose(x::Union{Chunk, Thunk}) = Thunk(transpose, nothing=>x)

adjoint(x::ArrayOp) = Transpose(adjoint, x)
adjoint(x::Union{Chunk, Thunk}) = Thunk(adjoint, x)
adjoint(x::Union{Chunk, Thunk}) = Thunk(adjoint, nothing=>x)

function adjoint(x::ArrayDomain{2})
d = indexes(x)
@@ -91,8 +91,8 @@ function (+)(a::ArrayDomain, b::ArrayDomain)
a
end

(*)(a::Union{Chunk, Thunk}, b::Union{Chunk, Thunk}) = Thunk(*, a,b)
(+)(a::Union{Chunk, Thunk}, b::Union{Chunk, Thunk}) = Thunk(+, a,b)
(*)(a::Union{Chunk, Thunk}, b::Union{Chunk, Thunk}) = Thunk(*, nothing=>a, nothing=>b)
(+)(a::Union{Chunk, Thunk}, b::Union{Chunk, Thunk}) = Thunk(+, nothing=>a, nothing=>b)

# we define our own matmat and matvec multiply
# for computing the new domains and thunks.
@@ -211,7 +211,7 @@ end
function _scale(l, r)
res = similar(r, Any)
for i=1:length(l)
res[i,:] = map(x->Thunk((a,b) -> Diagonal(a)*b, l[i], x), r[i,:])
res[i,:] = map(x->Thunk((a,b) -> Diagonal(a)*b, nothing=>l[i], nothing=>x), r[i,:])
end
res
end
2 changes: 1 addition & 1 deletion src/array/operators.jl
@@ -107,7 +107,7 @@ Base.@deprecate mappart(args...) mapchunk(args...)
function stage(ctx::Context, node::MapChunk)
inputs = map(x->cached_stage(ctx, x), node.input)
thunks = map(map(chunks, inputs)...) do ps...
Thunk(node.f, ps...)
Thunk(node.f, map(p->nothing=>p, ps)...)
end

DArray(Any, domain(inputs[1]), domainchunks(inputs[1]), thunks)
2 changes: 1 addition & 1 deletion src/array/setindex.jl
@@ -35,7 +35,7 @@ function stage(ctx::Context, sidx::SetIndex)
local_dmn = ArrayDomain(map(x->x[2], idx_and_dmn))
s = subdmns[idx...]
part_to_set = sidx.val
ps[idx...] = Thunk(ps[idx...]) do p
ps[idx...] = Thunk(nothing=>ps[idx...]) do p
q = copy(p)
q[indexes(project(s, local_dmn))...] .= part_to_set
q
4 changes: 2 additions & 2 deletions src/compute.jl
@@ -97,7 +97,7 @@ function dependents(node::Thunk)
if !haskey(deps, next)
deps[next] = Set{Thunk}()
end
for inp in inputs(next)
for (_, inp) in next.inputs
if istask(inp) || (inp isa Chunk)
s = get!(()->Set{Thunk}(), deps, inp)
push!(s, next)
@@ -165,7 +165,7 @@ function order(node::Thunk, ndeps)
haskey(output, next) && continue
s += 1
output[next] = s
parents = filter(istask, inputs(next))
parents = filter(istask, map(last, next.inputs))
if !isempty(parents)
# If parents is empty, sort! should be a no-op, but raises an ambiguity error
# when InlineStrings.jl is loaded (at least, version 1.1.0), because InlineStrings
12 changes: 6 additions & 6 deletions src/processor.jl
@@ -31,11 +31,11 @@ function delete_processor_callback!(name::Symbol)
end

"""
execute!(proc::Processor, f, args...) -> Any
execute!(proc::Processor, f, args...; kwargs...) -> Any

Executes the function `f` with arguments `args` on processor `proc`. This
function can be overloaded by `Processor` subtypes to allow executing function
calls differently than normal Julia.
Executes the function `f` with arguments `args` and keyword arguments `kwargs`
on processor `proc`. This function can be overloaded by `Processor` subtypes to
allow executing function calls differently than normal Julia.
"""
function execute! end
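
Given the updated signature above, a custom `Processor` overload would forward keyword arguments as well. A hypothetical sketch (`LoggingProc` is illustrative, not part of Dagger):

```julia
# Illustrative custom processor; assumes the kwargs-forwarding execute! contract above.
struct LoggingProc <: Dagger.Processor end

function Dagger.execute!(proc::LoggingProc, f, args...; kwargs...)
    @info "Executing $f" nargs=length(args)
    return f(args...; kwargs...)  # forward positional and keyword arguments unchanged
end
```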

@@ -154,12 +154,12 @@ end
iscompatible(proc::ThreadProc, opts, f, args...) = true
iscompatible_func(proc::ThreadProc, opts, f) = true
iscompatible_arg(proc::ThreadProc, opts, x) = true
function execute!(proc::ThreadProc, @nospecialize(f), @nospecialize(args...))
function execute!(proc::ThreadProc, @nospecialize(f), @nospecialize(args...); @nospecialize(kwargs...))
tls = get_tls()
task = Task() do
set_tls!(tls)
TimespanLogging.prof_task_put!(tls.sch_handle.thunk_id.id)
@invokelatest f(args...)
@invokelatest f(args...; kwargs...)
end
task.sticky = true
ret = ccall(:jl_set_task_tid, Cint, (Any, Cint), task, proc.tid-1)