The rules for `accumulate` and `foldl` don't compute anything for the `init` keyword. This can lead to silently wrong gradients, which is bad. Maybe this is a bigger problem than I realised.

It would be better to return `@not_implemented`. In fact, it's possible that all keywords everywhere should be that, or something like it, if possible. And even better to return the true answer, which IIRC these functions do know. Can this be done?
```julia
function f(K, xi, d)
    x = xi
    for i = 2:d
        x = hcat(x, K*x[:, i-1])
    end
    return x
end

K = rand(3,3)
xi = rand(3,1)
f(K, xi, 50)

function f2(K, xi, d::Int)
    xs = accumulate(1:d-1; init=xi) do x, i
        K * x
    end
    hcat(xi, reduce(hcat, xs))
end
```
Gives this, `Fill(1.0, 3, 1)` is from `hcat` alone:

```julia
julia> using Zygote

julia> gradient(sum∘f, K, xi, 10)
([63.45016309970954 50.40609159573776 101.36588271461751; 23.572874731387856 18.379315265377535 35.224999619160954; 31.033286457367566 24.176359057416636 46.03455941092244], [48.455853839178204; 14.765466919614408; 18.845362109436827;;], nothing)

julia> gradient(sum∘f2, K, xi, 10) # NB the gradient for init=xi is missing!
([63.45016309970953 50.40609159573775 101.3658827146175; 23.572874731387852 18.379315265377535 35.22499961916095; 31.033286457367552 24.17635905741663 46.03455941092242], Fill(1.0, 3, 1), nothing)
```
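
For reference, the expected gradient with respect to `xi` can be checked without any AD package. This is only a sketch: the helper names `grad_xi` and `fd_grad_xi` and the fixed test values are made up for illustration and are not part of any library. Since column `j` of `f(K, xi, d)` is `K^(j-1) * xi`, the gradient of `sum∘f` with respect to `xi` should be the sum of `(K')^(j-1) * ones` over `j = 1:d`. The `j = 1` term alone is the all-ones contribution that `hcat` provides.

```julia
# Same loop-based definition as `f` in this issue, repeated so the block runs on its own.
function f(K, xi, d)
    x = xi
    for i = 2:d
        x = hcat(x, K * x[:, i-1])
    end
    return x
end

# Hypothetical helper: analytic gradient of sum(f(K, xi, d)) with respect to xi,
# accumulated as ones + K'*ones + (K')^2*ones + ... (d terms in total).
function grad_xi(K, d)
    g = zeros(size(K, 1), 1)
    v = ones(size(K, 1), 1)    # the (K')^0 * ones term
    for _ in 1:d
        g .+= v
        v = K' * v
    end
    return g
end

# Hypothetical helper: central finite differences, as an independent check.
function fd_grad_xi(K, xi, d; h=1e-6)
    g = zeros(size(xi))
    for k in eachindex(xi)
        e = zeros(size(xi))
        e[k] = h
        g[k] = (sum(f(K, xi .+ e, d)) - sum(f(K, xi .- e, d))) / (2h)
    end
    return g
end

# Fixed values (an assumption, chosen only to make the check reproducible).
K = [0.1 0.2 0.3; 0.4 0.5 0.6; 0.7 0.8 0.9]
xi = reshape([1.0, 2.0, 3.0], 3, 1)

# Largest absolute disagreement between the two estimates.
println(maximum(abs.(grad_xi(K, 10) .- fd_grad_xi(K, xi, 10))))
```

A correct rule for `init` should agree with `grad_xi`, rather than stopping at the first all-ones term.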
Example from here: https://discourse.julialang.org/t/how-to-efficiently-build-ad-compatible-matrices-line-by-line/74632/17