Skip to content

Commit fb62a9b

Browse files
authored
Merge pull request #32 from slimgroup/integrate-rrule
ChainRules
2 parents 2d203c2 + 3357c9d commit fb62a9b

File tree

10 files changed

+56
-28
lines changed

10 files changed

+56
-28
lines changed

Diff for: .github/workflows/ci-joli.yaml

+1-8
Original file line numberDiff line numberDiff line change
@@ -19,14 +19,7 @@ jobs:
1919
fail-fast: false
2020

2121
matrix:
22-
version:
23-
- '1.1'
24-
- '1.2'
25-
- '1.3'
26-
- '1.4'
27-
- '1.5'
28-
- '1.6'
29-
22+
version: ['1.6', '1.7']
3023
os:
3124
- ubuntu-latest
3225
- macos-latest

Diff for: .travis.yml

+1-1
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ os:
44
- linux
55
- osx
66
julia:
7-
- 1.3
7+
- 1.6
88
# - nightly
99
branches:
1010
only:

Diff for: Project.toml

+11-4
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "JOLI"
22
uuid = "bb331ad6-a1cf-11e9-23da-9bcb53c69f6f"
33
authors = ["Henryk Modzelewski <[email protected]>"]
4-
version = "0.7.16"
4+
version = "0.8.0"
55

66
[deps]
77
Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
@@ -20,17 +20,24 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
2020
SharedArrays = "1a1011a3-84de-559e-8e89-a11a2f7dc383"
2121
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
2222
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
23-
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
2423
Wavelets = "29a6e085-ba6d-5f35-a997-948ac2efa89a"
2524

2625
[compat]
2726
DistributedArrays = "0.5, 0.6"
2827
FFTW = "1"
28+
Flux = "0.12"
2929
InplaceOps = "0.3.0"
3030
IterativeSolvers = "0.8, 0.9"
31-
NFFT = "0.4, 0.5, 0.6"
31+
NFFT = "0.6 - 0.12"
3232
Nullables = "1"
3333
PyCall = "1.18, 1.90, 1.91, 1.62"
34-
SpecialFunctions = "0.7, 0.8, 0.9, 0.10, 1, 2"
34+
SpecialFunctions = "1.2, 2"
3535
Wavelets = "0.8, 0.9"
3636
julia = "1"
37+
38+
[extras]
39+
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
40+
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
41+
42+
[targets]
43+
test = ["Test", "Flux"]

Diff for: src/JOLI.jl

+8
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@ using NFFT
5656
using Wavelets
5757
using PyCall
5858
using SpecialFunctions
59+
using SpecialFunctions.ChainRulesCore
5960

6061
# what's imported from Base
6162
import Base.eltype
@@ -84,6 +85,9 @@ import DistributedArrays.SPMD: scatter
8485
# what's imported from IterativeSolvers
8586
import IterativeSolvers.Adivtype
8687

88+
# what's imported from ChainRulesCore
89+
import SpecialFunctions.ChainRulesCore.rrule
90+
8791
# extra exported methods
8892
export deltype, reltype
8993
export elements, hasinverse, issquare, istall, iswide, iscomplex, islinear, isadjoint, isequiv
@@ -134,4 +138,8 @@ include("joLinearFunctionConstructors.jl")
134138
include("joLinearOperatorConstructors.jl")
135139
include("joMixedConstructors.jl")
136140

141+
142+
# ChainRules
143+
include("rrule.jl")
144+
137145
end # module

Diff for: src/joLinearFunctionConstructors/joNFFT.jl

+13-13
Original file line numberDiff line numberDiff line change
@@ -7,27 +7,27 @@ module joNFFT_etc
77
using JOLI: jo_convert
88
function apply_nfft_centered(pln,n,v::Vector{vdt},rdt::DataType) where vdt<:Union{AbstractFloat,Complex}
99
iv=jo_convert(ComplexF64,v,false)
10-
rv=nfft(pln,iv)/sqrt(n)
10+
rv= (pln*iv)/sqrt(n)
1111
rv=fftshift(rv)
1212
rv=jo_convert(rdt,rv,false)
1313
return rv
1414
end
1515
function apply_infft_centered(pln,n,v::Vector{vdt},rdt::DataType) where vdt<:Union{AbstractFloat,Complex}
1616
iv=jo_convert(ComplexF64,v,false)
1717
iv=ifftshift(iv)
18-
rv=nfft_adjoint(pln,iv)/sqrt(n)
18+
rv=(adjoint(pln)*iv)/sqrt(n)
1919
rv=jo_convert(rdt,rv,false)
2020
return rv
2121
end
2222
function apply_nfft(pln,n,v::Vector{vdt},rdt::DataType) where vdt<:Union{AbstractFloat,Complex}
2323
iv=jo_convert(ComplexF64,v,false)
24-
rv=nfft(pln,iv)/sqrt(n)
24+
rv= (pln*iv)/sqrt(n)
2525
rv=jo_convert(rdt,rv,false)
2626
return rv
2727
end
2828
function apply_infft(pln,n,v::Vector{vdt},rdt::DataType) where vdt<:Union{AbstractFloat,Complex}
2929
iv=jo_convert(ComplexF64,v,false)
30-
rv=nfft_adjoint(pln,iv)/sqrt(n)
30+
rv=(adjoint(pln)*iv)/sqrt(n)
3131
rv=jo_convert(rdt,rv,false)
3232
return rv
3333
end
@@ -36,21 +36,21 @@ using .joNFFT_etc
3636

3737
export joNFFT
3838
"""
39-
julia> op = joNFFT(N,nodes[,m=...][,sigma=...][,window=...][,K=...];
39+
julia> op = joNFFT(N,pos[,m=...][,sigma=...][,window=...];
4040
[centered=...,][DDT=...,][RDT=...,][name=...])
4141
4242
1D NFFT transform over fast dimension (wrapper to https://github.com/tknopp/NFFT.jl)
4343
4444
# Signature
4545
4646
function joNFFT(N::Integer,pos::Vector{joFloat},
47-
m=4,sigma=2.0,window=:kaiser_bessel,K=2000; centered::Bool=false,
47+
m=4,sigma=2.0,window=:kaiser_bessel; centered::Bool=false,
4848
DDT::DataType=joComplex,RDT::DataType=DDT,name::String="joNFFT")
4949
5050
# Arguments
5151
5252
- `N`: size
53-
- `nodes`: nodes' positions
53+
- `pos`: nodes positions
5454
- optional
5555
- see https://github.com/tknopp/NFFT.jl for info about optional parameters to NFFTplan: `m`, `sigma`, `window`, and `K`
5656
- keywords
@@ -67,24 +67,24 @@ export joNFFT
6767
6868
1D NFFT
6969
70-
joNFFT(N,nodes)
70+
joNFFT(N,pos)
7171
7272
centered coefficients
7373
74-
joNFFT(N,nodes; centered=true)
74+
joNFFT(N,pos; centered=true)
7575
7676
examples with DDT/RDT
7777
78-
% joNFFT(N,nodes; DDT=ComplexF32)
79-
% joNFFT(N,nodes; DDT=ComplexF32,RDT=ComplexF64)
78+
% joNFFT(N,pos; DDT=ComplexF32)
79+
% joNFFT(N,pos; DDT=ComplexF32,RDT=ComplexF64)
8080
8181
"""
8282
function joNFFT(N::Integer,pos::Vector{joFloat},
83-
m=4,sigma=2.0,window=:kaiser_bessel,K=2000; centered::Bool=false,
83+
m=4,sigma=2.0,window=:kaiser_bessel; centered::Bool=false,
8484
DDT::DataType=joComplex,RDT::DataType=DDT,name::String="joNFFT")
8585

8686
M = length(pos)
87-
p = try NFFTPlan(pos,N,m,sigma,window,K) catch; plan_nfft(pos,N,m,sigma,window,K) end
87+
p = plan_nfft(pos, N; m=m, σ=sigma, window=window)
8888
if centered
8989
return joLinearFunctionFwd_A(M,N,
9090
v1->joNFFT_etc.apply_nfft_centered(p,N,v1,RDT),

Diff for: src/joMixedConstructors/joGaussian.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -122,7 +122,7 @@ examples with DDT/RDT
122122
"""
123123
function joGaussian(M::Integer,N::Integer=M;
124124
implicit::Bool=false,normalized::Bool=false,orthonormal::Bool=false,
125-
RNG::AbstractRNG=Random.seed!(),
125+
RNG::AbstractRNG=MersenneTwister(),
126126
DDT::DataType=joFloat,RDT::DataType=DDT,
127127
name::String="joGaussian")
128128

Diff for: src/rrule.jl

+8
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
function rrule(::typeof(*), A::joAbstractLinearOperator{ADDT,ARDT}, v) where {ADDT,ARDT}
2+
Y = A*v
3+
function time_pullback(dy)
4+
DY = unthunk(Vector(dy))
5+
return NoTangent(), NoTangent(), @thunk(A' * DY)
6+
end
7+
return Y, time_pullback
8+
end

Diff for: test/runtests.jl

+1
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ include("test_joCurvelet2DnoFFT.jl")
4646
include("test_joExtend.jl")
4747
include("test_joGaussian.jl")
4848
include("test_joOuterProd.jl")
49+
include("test_rrules.jl")
4950
etime=time()
5051
dtime=etime-stime
5152
println("\nTest Total elapsed time: ",round(dtime,digits=1),"s")

Diff for: test/test_joNFFT.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ tsname="joNFFT"
44
for t=1:T # start test loop
55
m=4^t
66
n=3^t
7-
nodes=sort(rand(joFloat,n))
7+
nodes=sort(rand(joFloat,n)) .- joFloat(.5)
88

99
verbose && println("$tsname ($m[,$m]) - not centered")
1010
@testset "$m [x $m]" begin

Diff for: test/test_rrules.jl

+11
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
using Flux
2+
3+
x = rand(Float32, 3)
4+
5+
W = randn(Float32, 3, 3)
6+
W_JOLI = joMatrix(W)
7+
8+
g = gradient(x -> sum(W*x), x)[1]
9+
g_JOLI = gradient(x -> sum(W_JOLI*x), x)[1]
10+
11+
@test isapprox(sum((g - g_JOLI).^2)/sum(g_JOLI.^2), 0f0; atol=1f-8)

0 commit comments

Comments
 (0)