Gradients from keyword arguments dropped #567

Open
mcabbott opened this issue Jan 15, 2022 · 0 comments · May be fixed by #569
Labels
bug Something isn't working

Comments

@mcabbott
Member

The rules for accumulate and foldl don't compute any gradient for the init keyword argument. This can silently produce wrong gradients, which is bad. Maybe this is a bigger problem than I realised.

It would be better to return @not_implemented for it. In fact, perhaps all keyword arguments everywhere should be treated that way (or something like it), if that can be done. Better still would be to return the true answer, which IIRC these rules have enough information to compute. Can this be done?
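For reference, here is a sketch of what the "true answer" for init looks like. It uses a generic scan written as y_k = op(y_{k-1}, a_k) with y_0 = init. This notation is introduced here for illustration and is not taken from the ChainRules source. Each output y_k receives a cotangent \bar{y}_k, and the cotangent for init is obtained by running the chain rule backwards through every step:

```math
\begin{aligned}
y_0 &= \text{init}, \qquad y_k = \operatorname{op}(y_{k-1}, a_k), \quad k = 1, \dots, n,\\
J_k &= \frac{\partial \operatorname{op}(y_{k-1}, a_k)}{\partial y_{k-1}}, \qquad
s_n = \bar{y}_n, \qquad s_k = \bar{y}_k + J_{k+1}^{\top} s_{k+1},\\
\overline{\text{init}} &= J_1^{\top} s_1 .
\end{aligned}
```

In the f2 example below, op(x, i) = K*x, so every J_k = K, and sum∘hcat gives an all-ones cotangent \mathbf{1} for each y_k. With n = d-1 steps this yields \overline{\text{init}} = \sum_{k=1}^{d-1} (K^{\top})^k \mathbf{1}. Adding the direct \mathbf{1} from hcat(xi, ...) gives a total gradient for xi of \sum_{k=0}^{d-1} (K^{\top})^k \mathbf{1}. This is a hand derivation under the stated notation, not the output of any existing rule.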

Example from here: https://discourse.julialang.org/t/how-to-efficiently-build-ad-compatible-matrices-line-by-line/74632/17

function f(K, xi, d)
    x = xi
    for i = 2:d
        x = hcat(x, K*x[:, i-1])
    end
    return x
end

K = rand(3,3)
xi = rand(3,1)
f(K, xi, 50)

function f2(K, xi, d::Int)
    xs = accumulate(1:d-1; init=xi) do x, i
        K * x
    end
    hcat(xi, reduce(hcat, xs))
end

This gives the following (the Fill(1.0, 3, 1) for xi comes from hcat alone):

julia> using Zygote

julia> gradient(sum∘f, K, xi, 10)
([63.45016309970954 50.40609159573776 101.36588271461751; 23.572874731387856 18.379315265377535 35.224999619160954; 31.033286457367566 24.176359057416636 46.03455941092244], [48.455853839178204; 14.765466919614408; 18.845362109436827;;], nothing)

julia> gradient(sum∘f2, K, xi, 10) # NB the gradient for init=xi is missing!
([63.45016309970953 50.40609159573775 101.3658827146175; 23.572874731387852 18.379315265377535 35.22499961916095; 31.033286457367552 24.17635905741663 46.03455941092242], Fill(1.0, 3, 1), nothing)
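To confirm that the xi gradient from f2 really should differ from Fill(1.0, 3, 1), here is a minimal Base-only Julia sketch. It compares a closed-form gradient with central finite differences and needs no AD package. The helper names grad_xi_closed_form and grad_xi_fd, and the fixed K, xi and d, are hypothetical choices made for a reproducible check. The closed form assumes that column k+1 of f2(K, xi, d) is K^k * xi.

```julia
# f2 as in the issue (needs d >= 2 so that `xs` is non-empty).
function f2(K, xi, d::Int)
    xs = accumulate(1:d-1; init=xi) do x, i
        K * x
    end
    hcat(xi, reduce(hcat, xs))
end

# Hypothetical helper: closed-form gradient of sum(f2(K, xi, d)) w.r.t. xi.
# Column k+1 is K^k * xi for k = 0, ..., d-1, so the gradient is
# sum_k (K')^k * ones. The k = 0 term is the Fill(1.0, 3, 1) Zygote reports;
# the k >= 1 terms are what should flow back through `init`.
function grad_xi_closed_form(K, xi, d)
    g = zero(xi)
    v = ones(eltype(xi), size(xi))
    for _ in 0:d-1
        g .+= v        # add (K')^k * ones
        v = K' * v     # advance to (K')^(k+1) * ones
    end
    return g
end

# Hypothetical helper: central finite differences on each entry of xi.
# sum(f2(K, xi, d)) is linear in xi, so this is exact up to rounding.
function grad_xi_fd(K, xi, d; h=1e-6)
    g = similar(xi)
    for j in eachindex(xi)
        e = zero(xi)
        e[j] = h
        g[j] = (sum(f2(K, xi .+ e, d)) - sum(f2(K, xi .- e, d))) / (2h)
    end
    return g
end

# Fixed, hypothetical inputs (instead of rand) so the check is reproducible.
K = [0.5 0.1 0.0; 0.2 0.3 0.1; 0.0 0.1 0.4]
xi = reshape([1.0, 2.0, 3.0], 3, 1)
d = 10

g_exact = grad_xi_closed_form(K, xi, d)
g_fd = grad_xi_fd(K, xi, d)
println(g_exact)
println(isapprox(g_exact, g_fd; rtol=1e-6))
```

Both estimates agree with each other and exceed 1 in every entry. The second component returned by gradient(sum∘f2, K, xi, d) can be compared against g_exact to see how much of the gradient is being dropped.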