Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Automatic tangent generation for simple types #43

Merged
merged 20 commits into from
Jun 29, 2020
Merged

Conversation

willtebbutt
Copy link
Member

First steps of an implementation to automatically construct appropriate tangent vectors. @oxinabox @sethaxen let me know if you think these are okay as a starting point / if you think that there are any other types that are crucial to have this implemented for.

I would very much like to cover general structs with a generated function, but I've not had a chance to sort that out yet.

@willtebbutt willtebbutt changed the title Implements for simple types Automatic tangent generation for simple types Jun 25, 2020
@oxinabox
Copy link
Member

I would very much like to cover general structs with a generated function, but I've not had a chance to sort that out yet.

I donb't think it should need a generated function.
It should optimize just fine with a normal function using fieldnames etc.
And even if it doesn't it is just not that performance critical

@oxinabox
Copy link
Member

I think this needs a docstring that describes x as an instance of the primal value

@willtebbutt
Copy link
Member Author

@oxinabox I've got a weird test failure. Would you mind taking a look? It's in some of your more complex code, so might be obvious to you what's going on.

@oxinabox
Copy link
Member

Which error?

https://travis-ci.com/github/JuliaDiff/ChainRulesTestUtils.jl/jobs/354027267#L263-L264
Int + DoesNotExist() erroring seems wrong.
Answer should be an Int
code for that is here: https://github.com/JuliaDiff/ChainRulesCore.jl/blob/master/src/differential_arithmetic.jl#L20


Errors on 1.3 and earlier about length are because you are broadcasting a RNG which is not supported until 1.4
JuliaLang/julia#33213

Could port that change into Compat.jl
or could just add randn.(Ref(rng), x)

@willtebbutt
Copy link
Member Author

willtebbutt commented Jun 26, 2020

Nah, this error:

generate_tangent: Error During Test at /Users/willtebbutt/.julia/dev/ChainRulesTestUtils/test/generate_tangent.jl:25
  Test threw exception
  Expression: x + rand_tangent(rng, x) isa typeof(x)
  MethodError: Cannot `convert` an object of type Int64 to an object of type DoesNotExist
  Closest candidates are:
    convert(::Type{T}, !Matched::T) where T at essentials.jl:171
  Stacktrace:
   [1] convert(::Type{Tuple{DoesNotExist}}, ::Tuple{Int64}) at ./essentials.jl:310 (repeats 2 times)
   [2] Tuple{Float64,DoesNotExist}(::Tuple{Float64,Int64}) at ./tuple.jl:225
   [3] NamedTuple{(:a, :b),Tuple{Float64,DoesNotExist}}(::Tuple{Float64,Int64}) at ./namedtuple.jl:72
   [4] macro expansion at /Users/willtebbutt/.julia/packages/ChainRulesCore/Q5Nrj/src/differentials/composite.jl:0 [inlined]
   [5] elementwise_add(::NamedTuple{(:a, :b),Tuple{Float64,Int64}}, ::NamedTuple{(:a, :b),Tuple{Float64,DoesNotExist}}) at /Users/willtebbutt/.julia/packages/ChainRulesCore/Q5Nrj/src/differentials/composite.jl:191
   [6] +(::NamedTuple{(:a, :b),Tuple{Float64,Int64}}, ::Composite{NamedTuple{(:a, :b),Tuple{Float64,Int64}},NamedTuple{(:a, :b),Tuple{Float64,DoesNotExist}}}) at /Users/willtebbutt/.julia/packages/ChainRulesCore/Q5Nrj/src/differential_arithmetic.jl:88
   [7] (::var"#3#4"{MersenneTwister})(::Tuple{NamedTuple{(:a, :b),Tuple{Float64,Int64}},UnionAll}) at /Users/willtebbutt/.julia/dev/ChainRulesTestUtils/test/generate_tangent.jl:25
   [8] foreach(::var"#3#4"{MersenneTwister}, ::Array{Tuple{Any,Type},1}) at ./abstractarray.jl:1919
   [9] top-level scope at /Users/willtebbutt/.julia/dev/ChainRulesTestUtils/test/generate_tangent.jl:6
   [10] top-level scope at /Users/julia/buildbot/worker/package_macos64/build/usr/share/julia/stdlib/v1.4/Test/src/Test.jl:1113
   [11] top-level scope at /Users/willtebbutt/.julia/dev/ChainRulesTestUtils/test/generate_tangent.jl:4

Appears to stem from the Composite code. 4 + DoesNotExist() works fine.

@oxinabox
Copy link
Member

AH yes that is a bug in ChainRulesCore. I will fix it

@willtebbutt
Copy link
Member Author

@oxinabox your fix in ChainRulesCore appears to have fixed the issue. Thanks for the quick fix.

willtebbutt and others added 2 commits June 26, 2020 23:48
Co-authored-by: Seth Axen <[email protected]>
Co-authored-by: Seth Axen <[email protected]>
@sethaxen
Copy link
Member

Unless there's more you want to add, I think all this is missing is rand_tangent(::AbstractRNG, ::Bool).

@willtebbutt
Copy link
Member Author

Unless there's more you want to add, I think all this is missing is rand_tangent(::AbstractRNG, ::Bool).

I think Bool should be covered by Integer. I'll add a test though.

@willtebbutt
Copy link
Member Author

I also still definitely want to add stuff that can handle Functions, and more general structs. That should happen today at some point.

@sethaxen
Copy link
Member

Unless there's more you want to add, I think all this is missing is rand_tangent(::AbstractRNG, ::Bool).

I think Bool should be covered by Integer. I'll add a test though.

TIL that Bool <: Integer

@willtebbutt
Copy link
Member Author

@sethaxen @oxinabox I'm happy with this PR now. Let me know if there's anything else you want.

@willtebbutt
Copy link
Member Author

Thanks for the comments, they're now incorporated. Ready for another round of reviews @sethaxen @oxinabox

@willtebbutt
Copy link
Member Author

@oxinabox default_rng wasn't a thing on 1.0, and Compat doesn't appear to have support for it. Do you have any suggestions, or shall I revert to GLOBAL_RNG?

Copy link
Member

@sethaxen sethaxen left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, once the rng part is cleared up.

@oxinabox
Copy link
Member

Revert to GLOBAL_RNG i think.

@willtebbutt willtebbutt merged commit 23c78a6 into master Jun 29, 2020
@willtebbutt willtebbutt deleted the wct/auto-tangent branch June 29, 2020 23:12
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants