1
1
using DynamicPPL: LogDensityFunction
2
2
3
3
@testset " Automatic differentiation" begin
4
+ # Used as the ground truth that others are compared against.
5
+ ref_adtype = AutoForwardDiff ()
6
+ test_adtypes = [
7
+ AutoReverseDiff (; compile= false ),
8
+ AutoReverseDiff (; compile= true ),
9
+ AutoMooncake (; config= nothing ),
10
+ ]
11
+
4
12
@testset " Unsupported backends" begin
5
13
@model demo () = x ~ Normal ()
6
14
@test_logs (:warn , r" not officially supported" ) LogDensityFunction (
@@ -18,15 +26,10 @@ using DynamicPPL: LogDensityFunction
18
26
f = LogDensityFunction (m, varinfo)
19
27
x = DynamicPPL. getparams (f)
20
28
# Calculate reference logp + gradient of logp using ForwardDiff
21
- ref_adtype = ADTypes. AutoForwardDiff ()
22
29
ref_ldf = LogDensityFunction (m, varinfo; adtype= ref_adtype)
23
30
ref_logp, ref_grad = LogDensityProblems. logdensity_and_gradient (ref_ldf, x)
24
31
25
- @testset " $adtype " for adtype in [
26
- AutoReverseDiff (; compile= false ),
27
- AutoReverseDiff (; compile= true ),
28
- AutoMooncake (; config= nothing ),
29
- ]
32
+ @testset " $adtype " for adtype in test_adtypes
30
33
@info " Testing AD on: $(m. f) - $(short_varinfo_name (varinfo)) - $adtype "
31
34
32
35
# Put predicates here to avoid long lines
@@ -103,4 +106,66 @@ using DynamicPPL: LogDensityFunction
103
106
)
104
107
@test LogDensityProblems. logdensity_and_gradient (ldf, vi[:]) isa Any
105
108
end
109
+
110
+ # Test that various different ways of specifying array types as arguments work with all
111
+ # ADTypes.
112
+ @testset " Array argument types" begin
113
+ test_m = randn (2 , 3 )
114
+
115
+ function eval_logp_and_grad (model, m, adtype)
116
+ ldf = LogDensityFunction (model (); adtype= adtype)
117
+ return LogDensityProblems. logdensity_and_gradient (ldf, m[:])
118
+ end
119
+
120
+ @model function scalar_matrix_model (:: Type{T} = Float64) where {T<: Real }
121
+ m = Matrix {T} (undef, 2 , 3 )
122
+ return m ~ filldist (MvNormal (zeros (2 ), I), 3 )
123
+ end
124
+
125
+ scalar_matrix_model_reference = eval_logp_and_grad (
126
+ scalar_matrix_model, test_m, ref_adtype
127
+ )
128
+
129
+ @model function matrix_model (:: Type{T} = Matrix{Float64}) where {T}
130
+ m = T (undef, 2 , 3 )
131
+ return m ~ filldist (MvNormal (zeros (2 ), I), 3 )
132
+ end
133
+
134
+ matrix_model_reference = eval_logp_and_grad (matrix_model, test_m, ref_adtype)
135
+
136
+ @model function scalar_array_model (:: Type{T} = Float64) where {T<: Real }
137
+ m = Array {T} (undef, 2 , 3 )
138
+ return m ~ filldist (MvNormal (zeros (2 ), I), 3 )
139
+ end
140
+
141
+ scalar_array_model_reference = eval_logp_and_grad (
142
+ scalar_array_model, test_m, ref_adtype
143
+ )
144
+
145
+ @model function array_model (:: Type{T} = Array{Float64}) where {T}
146
+ m = T (undef, 2 , 3 )
147
+ return m ~ filldist (MvNormal (zeros (2 ), I), 3 )
148
+ end
149
+
150
+ array_model_reference = eval_logp_and_grad (array_model, test_m, ref_adtype)
151
+
152
+ @testset " $adtype " for adtype in test_adtypes
153
+ scalar_matrix_model_logp_and_grad = eval_logp_and_grad (
154
+ scalar_matrix_model, test_m, adtype
155
+ )
156
+ @test scalar_matrix_model_logp_and_grad[1 ] ≈ scalar_matrix_model_reference[1 ]
157
+ @test scalar_matrix_model_logp_and_grad[2 ] ≈ scalar_matrix_model_reference[2 ]
158
+ matrix_model_logp_and_grad = eval_logp_and_grad (matrix_model, test_m, adtype)
159
+ @test matrix_model_logp_and_grad[1 ] ≈ matrix_model_reference[1 ]
160
+ @test matrix_model_logp_and_grad[2 ] ≈ matrix_model_reference[2 ]
161
+ scalar_array_model_logp_and_grad = eval_logp_and_grad (
162
+ scalar_array_model, test_m, adtype
163
+ )
164
+ @test scalar_array_model_logp_and_grad[1 ] ≈ scalar_array_model_reference[1 ]
165
+ @test scalar_array_model_logp_and_grad[2 ] ≈ scalar_array_model_reference[2 ]
166
+ array_model_logp_and_grad = eval_logp_and_grad (array_model, test_m, adtype)
167
+ @test array_model_logp_and_grad[1 ] ≈ array_model_reference[1 ]
168
+ @test array_model_logp_and_grad[2 ] ≈ array_model_reference[2 ]
169
+ end
170
+ end
106
171
end
0 commit comments