Skip to content

Commit aeb491a

Browse files
Merge pull request #502 from LilithHafner/lh/PyCall-extension
Add PyCall extension
2 parents 75b1925 + 931449c commit aeb491a

File tree

6 files changed

+92
-1
lines changed

6 files changed

+92
-1
lines changed

.github/workflows/CI.yml

+1
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ jobs:
1818
group:
1919
- Core
2020
- Downstream
21+
- Python
2122
version:
2223
- '1'
2324
steps:

Project.toml

+4-1
Original file line numberDiff line numberDiff line change
@@ -33,9 +33,11 @@ TruncatedStacktraces = "781d530d-4396-4725-bb49-402e4bee1e77"
3333
ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"
3434

3535
[weakdeps]
36+
PyCall = "438e738f-606a-5dbb-bf0a-cddfbfd45ab0"
3637
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
3738

3839
[extensions]
40+
PyCallExt = "PyCall"
3941
ZygoteExt = "Zygote"
4042

4143
[compat]
@@ -69,11 +71,12 @@ Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
6971
DelayDiffEq = "bcd4f6db-9728-5f36-b5f7-82caef46ccdb"
7072
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
7173
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
74+
PyCall = "438e738f-606a-5dbb-bf0a-cddfbfd45ab0"
7275
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
7376
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
7477
StochasticDiffEq = "789caeaf-c7a9-5a7d-9973-96adeb23e2a0"
7578
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
7679
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
7780

7881
[targets]
79-
test = ["Pkg", "SafeTestsets", "Test", "StaticArrays", "StochasticDiffEq", "Aqua", "Zygote"]
82+
test = ["Pkg", "PyCall", "SafeTestsets", "Test", "StaticArrays", "StochasticDiffEq", "Aqua", "Zygote"]

ext/PyCallExt.jl

+20
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
module PyCallExt
2+
3+
using PyCall: PyCall, PyObject, PyAny, pyfunctionret, pyimport, hasproperty
4+
using SciMLBase: SciMLBase, solve
5+
6+
# SciML uses a function's arity (number of arguments) to determine if it operates in place.
7+
# PyCall does not preserve arity, so we inspect Python functions to find their arity.
8+
function SciMLBase.numargs(f::PyObject)
9+
inspect = pyimport("inspect")
10+
f2 = hasproperty(f, :py_func) ? f.py_func : f
11+
# if `f` is a bound method (i.e., `self.f`), `getfullargspec` includes
12+
# `self` in the `args` list. So, we subtract 1 in that case:
13+
length(first(inspect.getfullargspec(f2))) - inspect.ismethod(f2)
14+
end
15+
16+
# differential equation solutions can be converted to lists, this tells PyCall not
17+
# to perform that conversion automatically when a solution is returned from `solve`
18+
PyCall.PyObject(::typeof(solve)) = pyfunctionret(solve, Any, Vararg{PyAny})
19+
20+
end

test/python/Project.toml

+9
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
[deps]
2+
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
3+
PyCall = "438e738f-606a-5dbb-bf0a-cddfbfd45ab0"
4+
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
5+
6+
[compat]
7+
OrdinaryDiffEq = "6.33"
8+
PyCall = "1.96"
9+
SciMLBase = "2"

test/python/pycall.jl

+45
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
using PyCall, SciMLBase, OrdinaryDiffEq
2+
3+
py""" # This is a mess because normal site-packages is not writeable in CI
4+
import subprocess, sys, site
5+
subprocess.run([sys.executable, '-m', 'pip', 'install', '--user', 'julia'])
6+
sys.path.append(site.getusersitepackages())
7+
"""
8+
9+
@testset "numargs" begin
10+
py"""
11+
def three_arg(a, b, c):
12+
return a + b + c
13+
14+
def four_arg(a, b, c, d):
15+
return a + b + c + d
16+
17+
class MyClass:
18+
def three_arg_method(self, a, b, c):
19+
return a + b + c
20+
21+
def four_arg_method(self, a, b, c, d):
22+
return a + b + c + d
23+
"""
24+
25+
@test SciMLBase.numargs(py"three_arg") === 3
26+
@test SciMLBase.numargs(py"four_arg") === 4
27+
x = py"MyClass()"
28+
@test SciMLBase.numargs(x.three_arg_method) === 3
29+
@test SciMLBase.numargs(x.four_arg_method) === 4
30+
end
31+
32+
@testset "solution handling" begin
33+
py"""
34+
from julia import OrdinaryDiffEq as ode
35+
36+
def f(u,p,t):
37+
return -u
38+
39+
u0 = 0.5
40+
tspan = (0., 1.)
41+
prob = ode.ODEProblem(f, u0, tspan)
42+
sol = ode.solve(prob, ode.Tsit5())
43+
"""
44+
@test py"sol" isa ODESolution
45+
end

test/runtests.jl

+13
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,12 @@ function activate_downstream_env()
1515
Pkg.instantiate()
1616
end
1717

18+
function activate_python_env()
19+
Pkg.activate("python")
20+
Pkg.develop(PackageSpec(path = dirname(@__DIR__)))
21+
Pkg.instantiate()
22+
end
23+
1824
@time begin
1925
if GROUP == "Core" || GROUP == "All"
2026
@time @safetestset "Aqua" begin
@@ -93,4 +99,11 @@ end
9399
include("downstream/remake_autodiff.jl")
94100
end
95101
end
102+
103+
if !is_APPVEYOR && GROUP == "Python"
104+
activate_python_env()
105+
@time @safetestset "PyCall" begin
106+
include("python/pycall.jl")
107+
end
108+
end
96109
end

0 commit comments

Comments
 (0)