Skip to content

Commit 5a4ccaa

Browse files
authored
Merge pull request #519 from JuliaParallel/stream-max-evals
Add support for limiting the evaluations of a streaming DAG
2 parents 1c1edab + b157f89 commit 5a4ccaa

File tree

4 files changed

+43
-10
lines changed

4 files changed

+43
-10
lines changed

Diff for: .buildkite/pipeline.yml

+11-1
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,8 @@
55
sandbox_capable: "true"
66
os: linux
77
arch: x86_64
8-
command: "julia --project -e 'using Pkg; Pkg.develop(;path=\"lib/TimespanLogging\")'"
8+
command: "julia --project -e 'using Pkg; Pkg.develop(;path=\"lib/TimespanLogging\"); Pkg.add(; url=\"https://github.com/JuliaData/MemPool.jl\", rev=\"jps/migration-helper\")'"
9+
910
.bench: &bench
1011
if: build.message =~ /\[run benchmarks\]/
1112
agents:
@@ -14,6 +15,7 @@
1415
os: linux
1516
arch: x86_64
1617
num_cpus: 16
18+
1719
steps:
1820
- label: Julia 1.8
1921
timeout_in_minutes: 90
@@ -25,6 +27,7 @@ steps:
2527
julia_args: "--threads=1"
2628
- JuliaCI/julia-coverage#v1:
2729
codecov: true
30+
2831
- label: Julia 1.9
2932
timeout_in_minutes: 90
3033
<<: *test
@@ -35,6 +38,7 @@ steps:
3538
julia_args: "--threads=1"
3639
- JuliaCI/julia-coverage#v1:
3740
codecov: true
41+
3842
- label: Julia 1.10
3943
timeout_in_minutes: 90
4044
<<: *test
@@ -45,6 +49,7 @@ steps:
4549
julia_args: "--threads=1"
4650
- JuliaCI/julia-coverage#v1:
4751
codecov: true
52+
4853
- label: Julia nightly
4954
timeout_in_minutes: 90
5055
<<: *test
@@ -55,6 +60,7 @@ steps:
5560
julia_args: "--threads=1"
5661
- JuliaCI/julia-coverage#v1:
5762
codecov: true
63+
5864
- label: Julia 1.8 (macOS)
5965
timeout_in_minutes: 90
6066
<<: *test
@@ -69,6 +75,7 @@ steps:
6975
julia_args: "--threads=1"
7076
- JuliaCI/julia-coverage#v1:
7177
codecov: true
78+
7279
- label: Julia 1.8 - TimespanLogging
7380
timeout_in_minutes: 20
7481
<<: *test
@@ -78,6 +85,7 @@ steps:
7885
- JuliaCI/julia-coverage#v1:
7986
codecov: true
8087
command: "julia --project -e 'using Pkg; Pkg.instantiate(); Pkg.develop(;path=\"lib/TimespanLogging\"); Pkg.test(\"TimespanLogging\")'"
88+
8189
- label: Julia 1.8 - DaggerWebDash
8290
timeout_in_minutes: 20
8391
<<: *test
@@ -87,6 +95,7 @@ steps:
8795
- JuliaCI/julia-coverage#v1:
8896
codecov: true
8997
command: "julia -e 'using Pkg; Pkg.develop(;path=pwd()); Pkg.develop(;path=\"lib/TimespanLogging\"); Pkg.develop(;path=\"lib/DaggerWebDash\"); include(\"lib/DaggerWebDash/test/runtests.jl\")'"
98+
9099
- label: Benchmarks
91100
timeout_in_minutes: 120
92101
<<: *bench
@@ -103,6 +112,7 @@ steps:
103112
BENCHMARK_SCALE: "5:5:50"
104113
artifacts:
105114
- benchmarks/result*
115+
106116
- label: DTables.jl stability test
107117
timeout_in_minutes: 20
108118
plugins:

Diff for: src/sch/Sch.jl

+5-3
Original file line numberDiff line numberDiff line change
@@ -253,9 +253,11 @@ end
253253
Combine `SchedulerOptions` and `ThunkOptions` into a new `ThunkOptions`.
254254
"""
255255
function Base.merge(sopts::SchedulerOptions, topts::ThunkOptions)
256-
single = topts.single !== nothing ? topts.single : sopts.single
257-
allow_errors = topts.allow_errors !== nothing ? topts.allow_errors : sopts.allow_errors
258-
proclist = topts.proclist !== nothing ? topts.proclist : sopts.proclist
256+
select_option = (sopt, topt) -> isnothing(topt) ? sopt : topt
257+
258+
single = select_option(sopts.single, topts.single)
259+
allow_errors = select_option(sopts.allow_errors, topts.allow_errors)
260+
proclist = select_option(sopts.proclist, topts.proclist)
259261
ThunkOptions(single,
260262
proclist,
261263
topts.time_util,

Diff for: src/stream.jl

+9-6
Original file line numberDiff line numberDiff line change
@@ -208,12 +208,10 @@ function initialize_streaming!(self_streams, spec, task)
208208
end
209209
output_buffer = get(spec.options, :stream_output_buffer, ProcessRingBuffer)
210210
stream = Stream{T,output_buffer}(output_buffer_amount)
211-
spec.options = NamedTuple(filter(opt -> opt[1] != :stream_output_buffer &&
212-
opt[1] != :stream_output_buffer_amount,
213-
Base.pairs(spec.options)))
214211
self_streams[task.uid] = stream
215212

216-
spec.f = StreamingFunction(spec.f, stream)
213+
max_evals = get(spec.options, :stream_max_evals, -1)
214+
spec.f = StreamingFunction(spec.f, stream, max_evals)
217215
spec.options = merge(spec.options, (;occupancy=Dict(Any=>0)))
218216

219217
# Register Stream globally
@@ -256,6 +254,7 @@ const STREAM_THUNK_ID = TaskLocalValue{Int}(()->0)
256254
struct StreamingFunction{F, S}
257255
f::F
258256
stream::S
257+
max_evals::Int
259258
end
260259
chunktype(sf::StreamingFunction{F}) where F = F
261260
function (sf::StreamingFunction)(args...; kwargs...)
@@ -319,14 +318,17 @@ end
319318
function stream!(sf::StreamingFunction, uid,
320319
args::Tuple, kwarg_names::Tuple, kwarg_values::Tuple)
321320
f = move(thunk_processor(), sf.f)
322-
while true
321+
counter = 0
322+
323+
while sf.max_evals < 0 || counter < sf.max_evals
323324
# Get values from Stream args/kwargs
324325
stream_args = _stream_take_values!(args, uid)
325326
stream_kwarg_values = _stream_take_values!(kwarg_values, uid)
326327
stream_kwargs = _stream_namedtuple(kwarg_names, stream_kwarg_values)
327328

328329
# Run a single cycle of f
329330
stream_result = f(stream_args...; stream_kwargs...)
331+
counter += 1
330332

331333
# Exit streaming on graceful request
332334
if stream_result isa FinishStream
@@ -412,7 +414,8 @@ function finalize_streaming!(tasks::Vector{Pair{DTaskSpec,DTask}}, self_streams)
412414

413415
# Filter out all streaming options
414416
to_filter = (:stream_input_buffer, :stream_input_buffer_amount,
415-
:stream_output_buffer, :stream_output_buffer_amount)
417+
:stream_output_buffer, :stream_output_buffer_amount,
418+
:stream_max_evals)
416419
spec.options = NamedTuple(filter(opt -> !(opt[1] in to_filter),
417420
Base.pairs(spec.options)))
418421
if haskey(spec.options, :propagates)

Diff for: test/streaming.jl

+18
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
end
66
return x
77
end
8+
89
function catch_interrupt(f)
910
try
1011
f()
@@ -17,6 +18,7 @@ function catch_interrupt(f)
1718
rethrow(err)
1819
end
1920
end
21+
2022
function test_finishes(f, message::String; ignore_timeout=false)
2123
t = @eval Threads.@spawn @testset $message catch_interrupt($f)
2224
if timedwait(()->istaskdone(t), 10) == :timed_out
@@ -29,6 +31,7 @@ function test_finishes(f, message::String; ignore_timeout=false)
2931
end
3032
return true
3133
end
34+
3235
@testset "Basics" begin
3336
@test test_finishes("Single task") do
3437
local x
@@ -50,6 +53,21 @@ end
5053
fetch(x)
5154
end
5255

56+
@test test_finishes("Max evaluations") do
57+
counter = 0
58+
function incrementer()
59+
counter += 1
60+
end
61+
62+
x = Dagger.with_options(; stream_max_evals=10) do
63+
Dagger.spawn_streaming() do
64+
Dagger.@spawn incrementer()
65+
end
66+
end
67+
wait(x)
68+
@test counter == 10
69+
end
70+
5371
@test test_finishes("Two tasks (sequential)") do
5472
local x, y
5573
@warn "\n\n\nStart streaming\n\n\n"

0 commit comments

Comments
 (0)