Skip to content

Commit 52a97dd

Browse files
authored
Merge pull request #555 from JuliaParallel/jps/fetch-all
Add fetch_all recursive helper
2 parents 2764b76 + bc042c9 commit 52a97dd

File tree

5 files changed

+48
-13
lines changed

5 files changed

+48
-13
lines changed

Diff for: Project.toml

+15-13
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ uuid = "d58978e5-989f-55fb-8d15-ea34adc7bf54"
33
version = "0.18.12"
44

55
[deps]
6+
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
67
DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
78
Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
89
Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6"
@@ -24,7 +25,21 @@ TaskLocalValues = "ed4db957-447d-4319-bfb6-7fa9ae7ecf34"
2425
TimespanLogging = "a526e669-04d3-4846-9525-c66122c55f63"
2526
UUIDs = "cf7118a7-6976-5b1a-9a39-7adc72f591a4"
2627

28+
[weakdeps]
29+
Colors = "5ae59095-9a9b-59fe-a467-6f913c188581"
30+
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
31+
GraphViz = "f526b714-d49f-11e8-06ff-31ed36ee7ee0"
32+
JSON3 = "0f8b85d8-7281-11e9-16c2-39a750bddbf1"
33+
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
34+
35+
[extensions]
36+
GraphVizExt = "GraphViz"
37+
GraphVizSimpleExt = "Colors"
38+
JSON3Ext = "JSON3"
39+
PlotsExt = ["DataFrames", "Plots"]
40+
2741
[compat]
42+
Adapt = "4.0.4"
2843
Colors = "0.12"
2944
DataFrames = "1"
3045
DataStructures = "0.18"
@@ -43,22 +58,9 @@ TaskLocalValues = "0.1"
4358
TimespanLogging = "0.1"
4459
julia = "1.8"
4560

46-
[extensions]
47-
GraphVizExt = "GraphViz"
48-
GraphVizSimpleExt = "Colors"
49-
JSON3Ext = "JSON3"
50-
PlotsExt = ["DataFrames", "Plots"]
51-
5261
[extras]
5362
Colors = "5ae59095-9a9b-59fe-a467-6f913c188581"
5463
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
5564
GraphViz = "f526b714-d49f-11e8-06ff-31ed36ee7ee0"
5665
JSON3 = "0f8b85d8-7281-11e9-16c2-39a750bddbf1"
5766
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
58-
59-
[weakdeps]
60-
Colors = "5ae59095-9a9b-59fe-a467-6f913c188581"
61-
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
62-
GraphViz = "f526b714-d49f-11e8-06ff-31ed36ee7ee0"
63-
JSON3 = "0f8b85d8-7281-11e9-16c2-39a750bddbf1"
64-
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"

Diff for: src/Dagger.jl

+2
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,8 @@ end
2929
import TimespanLogging
3030
import TimespanLogging: timespan_start, timespan_finish
3131

32+
import Adapt
33+
3234
include("lib/util.jl")
3335
include("utils/dagdebug.jl")
3436

Diff for: src/chunks.jl

+2
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,8 @@ is_task_or_chunk(c::Chunk) = true
6565
Base.:(==)(c1::Chunk, c2::Chunk) = c1.handle == c2.handle
6666
Base.hash(c::Chunk, x::UInt64) = hash(c.handle, hash(Chunk, x))
6767

68+
Adapt.adapt_storage(::FetchAdaptor, x::Chunk) = fetch(x)
69+
6870
collect_remote(chunk::Chunk) =
6971
move(chunk.processor, OSProc(), poolget(chunk.handle))
7072

Diff for: src/thunk.jl

+14
Original file line numberDiff line numberDiff line change
@@ -486,6 +486,20 @@ function spawn(f, args...; kwargs...)
486486
return task
487487
end
488488

489+
struct FetchAdaptor end
490+
Adapt.adapt_storage(::FetchAdaptor, x::DTask) = fetch(x)
491+
Adapt.adapt_structure(::FetchAdaptor, A::AbstractArray) =
492+
map(x->Adapt.adapt(FetchAdaptor(), x), A)
493+
494+
"""
495+
Dagger.fetch_all(x)
496+
497+
Recursively fetches all `DTask`s and `Chunk`s in `x`, returning an equivalent
498+
object. Useful for converting arbitrary Dagger-enabled objects into a
499+
non-Dagger form.
500+
"""
501+
fetch_all(x) = Adapt.adapt(FetchAdaptor(), x)
502+
489503
persist!(t::Thunk) = (t.persist=true; t)
490504
cache_result!(t::Thunk) = (t.cache=true; t)
491505

Diff for: test/thunk.jl

+15
Original file line numberDiff line numberDiff line change
@@ -378,4 +378,19 @@ end
378378
Dagger.@spawn error()
379379
end
380380
end
381+
@testset "fetch_all" begin
382+
ts = [Dagger.@spawn(1+1) for _ in 1:4]
383+
@test Dagger.fetch_all(ts) == [2, 2, 2, 2]
384+
cs = map(t->fetch(t; raw=true), ts)
385+
@test Dagger.fetch_all(cs) == [2, 2, 2, 2]
386+
387+
ts = Tuple(Dagger.@spawn(1+1) for _ in 1:4)
388+
@test Dagger.fetch_all(ts) == (2, 2, 2, 2)
389+
cs = fetch.(ts; raw=true)
390+
@test Dagger.fetch_all(cs) == (2, 2, 2, 2)
391+
392+
t = Dagger.@spawn 1+1
393+
@test Dagger.fetch_all(t) == 2
394+
@test Dagger.fetch_all(fetch(t; raw=true)) == 2
395+
end
381396
end

0 commit comments

Comments
 (0)