Skip to content

Add Python integration #572

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Sep 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -32,13 +32,15 @@ Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
GraphViz = "f526b714-d49f-11e8-06ff-31ed36ee7ee0"
JSON3 = "0f8b85d8-7281-11e9-16c2-39a750bddbf1"
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
PythonCall = "6099a3de-0909-46bc-b1f4-468b9a2dfc0d"

[extensions]
DistributionsExt = "Distributions"
GraphVizExt = "GraphViz"
GraphVizSimpleExt = "Colors"
JSON3Ext = "JSON3"
PlotsExt = ["DataFrames", "Plots"]
PythonExt = "PythonCall"

[compat]
Adapt = "4.0.4"
Expand Down Expand Up @@ -69,3 +71,4 @@ Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
GraphViz = "f526b714-d49f-11e8-06ff-31ed36ee7ee0"
JSON3 = "0f8b85d8-7281-11e9-16c2-39a750bddbf1"
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
PythonCall = "6099a3de-0909-46bc-b1f4-468b9a2dfc0d"
3 changes: 3 additions & 0 deletions docs/make.jl
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,9 @@ makedocs(;
"Logging: Visualization" => "logging-visualization.md",
"Logging: Advanced" => "logging-advanced.md",
],
"External Languages" => [
"Python" => "external-languages/python.md",
],
"Checkpointing" => "checkpointing.md",
"Benchmarking" => "benchmarking.md",
"Dynamic Scheduler Control" => "dynamic.md",
Expand Down
49 changes: 49 additions & 0 deletions docs/src/external-languages/python.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
# Python Integration

If you're using Python as your main programming language, Dagger can be easily
integrated into your workflow. Dagger has built-in support for Python, which
you can easily access through the `pydaggerjl` library (accessible through
`pip`). This library provides a Pythonic interface to Dagger, allowing you to
spawn Dagger tasks that run Python functions on Python arguments.

Here's a simple example of interfacing between Python's `numpy` library and
Dagger:

```python
import numpy as np
from pydaggerjl import daggerjl

# Create a Dagger DTask to sum the elements of an array
task = daggerjl.spawn(np.sum, np.array([1, 2, 3]))

# Wait on task to finish
# This is purely educational, as fetch will wait for the task to finish
daggerjl.wait(task)

# Fetch the result
result = daggerjl.fetch(task)

print(f"The sum is: {result}")

# Create two numpy arrays
a = np.array([1, 2, 3])
b = np.array([4, 5, 6])

# Element-wise sum of the two arrays
task = daggerjl.spawn(np.add, a, b)

# Fetch the result
result = daggerjl.fetch(task)

print(f"The element-wise sum is: {result}")

# Element-wise sum of last result with itself
task2 = daggerjl.spawn(np.add, task, task)

# Fetch the result
result2 = daggerjl.fetch(task2)

print(f"The element-wise sum of the last result with itself is: {result2}")
```

Keep an eye on Dagger and pydaggerjl - new features are soon to come!
46 changes: 46 additions & 0 deletions ext/PythonExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
module PythonExt

if isdefined(Base, :get_extension)
using PythonCall
else
using ..PythonCall
end

import Dagger
import Dagger: Processor, OSProc, ThreadProc, Chunk
import Distributed: myid

const CPUProc = Union{OSProc, ThreadProc}

struct PythonProcessor <: Processor
owner::Int
end

Dagger.root_worker_id(proc::PythonProcessor) = proc.owner
Dagger.get_parent(proc::PythonProcessor) = OSProc(proc.owner)
Dagger.default_enabled(::PythonProcessor) = true

Dagger.iscompatible_func(::ThreadProc, opts, ::Type{Py}) = false
Dagger.iscompatible_func(::PythonProcessor, opts, ::Type{Py}) = true
Dagger.iscompatible_arg(::PythonProcessor, opts, ::Type{Py}) = true
Dagger.iscompatible_arg(::PythonProcessor, opts, ::Type{<:PyArray}) = true

Dagger.move(from_proc::CPUProc, to_proc::PythonProcessor, x::Chunk) =
Dagger.move(from_proc, to_proc, Dagger.move(from_proc, Dagger.get_parent(to_proc), x))
Dagger.move(::CPUProc, ::PythonProcessor, x) = Py(x)
Dagger.move(::CPUProc, ::PythonProcessor, x::Py) = x
Dagger.move(::CPUProc, ::PythonProcessor, x::PyArray) = x
# FIXME: Conversion from Python to Julia

function Dagger.execute!(::PythonProcessor, f, args...; kwargs...)
@assert f isa Py "Function must be a Python object"
return f(args...; kwargs...)
end

function __init__()
Dagger.add_processor_callback!(:pythonproc) do
return PythonProcessor(myid())
end
end

end # module PythonExt
6 changes: 3 additions & 3 deletions src/sch/Sch.jl
Original file line number Diff line number Diff line change
Expand Up @@ -774,7 +774,7 @@ function schedule!(ctx, state, procs=procs_to_use(ctx))

for proc in local_procs
gproc = get_parent(proc)
can_use, scope = can_use_proc(task, gproc, proc, opts, scope)
can_use, scope = can_use_proc(state, task, gproc, proc, opts, scope)
if can_use
has_cap, est_time_util, est_alloc_util, est_occupancy =
has_capacity(state, proc, gproc.pid, opts.time_util, opts.alloc_util, opts.occupancy, sig)
Expand Down Expand Up @@ -806,7 +806,7 @@ function schedule!(ctx, state, procs=procs_to_use(ctx))
cap, extra_util = nothing, nothing
procs_found = false
# N.B. if we only have one processor, we need to select it now
can_use, scope = can_use_proc(task, entry.gproc, entry.proc, opts, scope)
can_use, scope = can_use_proc(state, task, entry.gproc, entry.proc, opts, scope)
if can_use
has_cap, est_time_util, est_alloc_util, est_occupancy =
has_capacity(state, entry.proc, entry.gproc.pid, opts.time_util, opts.alloc_util, opts.occupancy, sig)
Expand All @@ -832,7 +832,7 @@ function schedule!(ctx, state, procs=procs_to_use(ctx))
@goto pop_task
end

can_use, scope = can_use_proc(task, entry.gproc, entry.proc, opts, scope)
can_use, scope = can_use_proc(state, task, entry.gproc, entry.proc, opts, scope)
if can_use
has_cap, est_time_util, est_alloc_util, est_occupancy =
has_capacity(state, entry.proc, entry.gproc.pid, opts.time_util, opts.alloc_util, opts.occupancy, sig)
Expand Down
20 changes: 19 additions & 1 deletion src/sch/util.jl
Original file line number Diff line number Diff line change
Expand Up @@ -322,7 +322,7 @@ function signature(f, args)
return sig
end

function can_use_proc(task, gproc, proc, opts, scope)
function can_use_proc(state, task, gproc, proc, opts, scope)
# Check against proclist
if opts.proclist !== nothing
@warn "The `proclist` option is deprecated, please use scopes instead\nSee https://juliaparallel.org/Dagger.jl/stable/scopes/ for details" maxlog=1
Expand Down Expand Up @@ -369,6 +369,24 @@ function can_use_proc(task, gproc, proc, opts, scope)
return false, scope
end

# Check against f/args
Tf = chunktype(task.f)
if !Dagger.iscompatible_func(proc, opts, Tf)
@dagdebug task :scope "Rejected $proc: Not compatible with function type ($Tf)"
return false, scope
end
for (_, arg) in task.inputs
arg = unwrap_weak_checked(arg)
if arg isa Thunk
arg = state.cache[arg]
end
Targ = chunktype(arg)
if !Dagger.iscompatible_arg(proc, opts, Targ)
@dagdebug task :scope "Rejected $proc: Not compatible with argument type ($Targ)"
return false, scope
end
end

@label accept

@dagdebug task :scope "Accepted $proc"
Expand Down
2 changes: 2 additions & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
[deps]
ArgParse = "c7e460c6-2fb9-53a9-8c5b-16f535851c63"
Colors = "5ae59095-9a9b-59fe-a467-6f913c188581"
CondaPkg = "992eb4ea-22a4-4c89-a5bb-47a3300528ab"
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
Expand All @@ -12,6 +13,7 @@ MemPool = "f9f48841-c794-520a-933b-121f7ba6ed94"
OnlineStats = "a15396b6-48d5-5d58-9928-6d29437db91e"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
PythonCall = "6099a3de-0909-46bc-b1f4-468b9a2dfc0d"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Serialization = "9e88b42a-f829-5b0c-bbe9-9e923198166b"
SharedArrays = "1a1011a3-84de-559e-8e89-a11a2f7dc383"
Expand Down
30 changes: 30 additions & 0 deletions test/extlang/python.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
using PythonCall
using CondaPkg

CondaPkg.add("numpy")

np = pyimport("numpy")

# Restart scheduler to see new methods
Dagger.cancel!(;halt_sch=true)

@testset "spawn" begin
a = np.array([1, 2, 3])

t = Dagger.@spawn np.sum(a)
result = fetch(t)
@test result isa Py
@test pyconvert(Int, result) == 6

b = np.array([4, 5, 6])

t = Dagger.@spawn np.add(a, b)
result = fetch(t)
@test result isa Py
@test pyconvert(Array, result) == [5, 7, 9]

t2 = Dagger.@spawn np.add(t, b)
result = fetch(t2)
@test result isa Py
@test pyconvert(Array, result) == [9, 12, 15]
end
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ tests = [
("Caching", "cache.jl"),
("Disk Caching", "diskcaching.jl"),
("File IO", "file-io.jl"),
("External Languages - Python", "extlang/python.jl"),
#("Fault Tolerance", "fault-tolerance.jl"),
]
all_test_names = map(test -> replace(last(test), ".jl"=>""), tests)
Expand Down