Skip to content

Commit 48dec2c

Browse files
avik-paloxinaboxhyrodiummateuszbaran
authored
Bring over tests from ChainRulesCore (#1231)
* Bring over tests from ChainRulesCore * add missing begin Co-authored-by: Yuto Horikawa <[email protected]> * use axes to store the axes and not size * Update ext/StaticArraysChainRulesCoreExt.jl Co-authored-by: Frames White <[email protected]> --------- Co-authored-by: Frames White <[email protected]> Co-authored-by: Frames White <[email protected]> Co-authored-by: Yuto Horikawa <[email protected]> Co-authored-by: Mateusz Baran <[email protected]>
1 parent f06af93 commit 48dec2c

File tree

2 files changed

+23
-4
lines changed

2 files changed

+23
-4
lines changed

ext/StaticArraysChainRulesCoreExt.jl

+4-2
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,14 @@ end
1515

1616
# Project SArray to SArray
1717
function ProjectTo(x::SArray{S, T}) where {S, T}
18-
return ProjectTo{SArray}(; element = CRC._eltype_projectto(T), axes = Size(x))
18+
# We have a axes field because it is expected by other ProjectTo's like the one for Transpose
19+
return ProjectTo{SArray}(; element = CRC._eltype_projectto(T), axes = axes(x),
20+
size = Size(x))
1921
end
2022

2123
@inline _sarray_from_array(::Size{T}, dx::AbstractArray) where {T} = SArray{Tuple{T...}}(dx)
2224

23-
(project::ProjectTo{SArray})(dx::AbstractArray) = _sarray_from_array(project.axes, dx)
25+
(project::ProjectTo{SArray})(dx::AbstractArray) = _sarray_from_array(project.size, dx)
2426

2527
# Adjoint for SArray constructor
2628
function rrule(::Type{T}, x::Tuple) where {T <: SArray}

test/chainrules.jl

+19-2
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,24 @@
1-
using StaticArrays, ChainRulesCore, ChainRulesTestUtils, JLArrays, Test
1+
using StaticArrays, ChainRulesCore, ChainRulesTestUtils, JLArrays, LinearAlgebra, Test
22

3-
@testset "Chain Rules Integration" begin
3+
@testset "ChainRules Integration" begin
44
@testset "Projection" begin
5+
# There is no code for this, but when argument isa StaticArray, axes(x) === axes(dx)
6+
# implies a check, and reshape will wrap a Vector into a static SizedVector:
7+
pstat = ProjectTo(SA[1, 2, 3])
8+
@test axes(pstat(rand(3))) === (SOneTo(3),)
9+
10+
# This recurses into structured arrays:
11+
pst = ProjectTo(transpose(SA[1, 2, 3]))
12+
@test axes(pst(rand(1,3))) === (SOneTo(1), SOneTo(3))
13+
@test pst(rand(1,3)) isa Transpose
14+
15+
# When the argument is an ordinary Array, static gradients are allowed to pass,
16+
# like FillArrays. Collecting to an Array would cost a copy.
17+
pvec3 = ProjectTo([1, 2, 3])
18+
@test pvec3(SA[1, 2, 3]) isa StaticArray
19+
end
20+
21+
@testset "Constructor rrules" begin
522
test_rrule(SMatrix{1, 4}, (1.0, 1.0, 1.0, 1.0))
623
test_rrule(SMatrix{4, 1}, (1.0, 1.0, 1.0, 1.0))
724
test_rrule(SMatrix{2, 2}, (1.0, 1.0, 1.0, 1.0))

0 commit comments

Comments
 (0)