|
1 |
| -using StaticArrays, ChainRulesCore, ChainRulesTestUtils, JLArrays, Test |
| 1 | +using StaticArrays, ChainRulesCore, ChainRulesTestUtils, JLArrays, LinearAlgebra, Test |
2 | 2 |
|
3 |
| -@testset "Chain Rules Integration" begin |
| 3 | +@testset "ChainRules Integration" begin |
4 | 4 | @testset "Projection" begin
|
| 5 | + # There is no code for this, but when argument isa StaticArray, axes(x) === axes(dx) |
| 6 | + # implies a check, and reshape will wrap a Vector into a static SizedVector: |
| 7 | + pstat = ProjectTo(SA[1, 2, 3]) |
| 8 | + @test axes(pstat(rand(3))) === (SOneTo(3),) |
| 9 | + |
| 10 | + # This recurses into structured arrays: |
| 11 | + pst = ProjectTo(transpose(SA[1, 2, 3])) |
| 12 | + @test axes(pst(rand(1,3))) === (SOneTo(1), SOneTo(3)) |
| 13 | + @test pst(rand(1,3)) isa Transpose |
| 14 | + |
| 15 | + # When the argument is an ordinary Array, static gradients are allowed to pass, |
| 16 | + # like FillArrays. Collecting to an Array would cost a copy. |
| 17 | + pvec3 = ProjectTo([1, 2, 3]) |
| 18 | + @test pvec3(SA[1, 2, 3]) isa StaticArray |
| 19 | + end |
| 20 | + |
| 21 | + @testset "Constructor rrules" begin |
5 | 22 | test_rrule(SMatrix{1, 4}, (1.0, 1.0, 1.0, 1.0))
|
6 | 23 | test_rrule(SMatrix{4, 1}, (1.0, 1.0, 1.0, 1.0))
|
7 | 24 | test_rrule(SMatrix{2, 2}, (1.0, 1.0, 1.0, 1.0))
|
|
0 commit comments