Skip to content

Commit d1a98c6

Browse files
mhauru and penelopeysm
authored
Add a couple of tests being removed from Turing.jl (#840)
* Add a couple of tests being removed from Turing.jl * Simplify a test Co-authored-by: Penelope Yong <[email protected]> * Simplify ad tests --------- Co-authored-by: Penelope Yong <[email protected]>
1 parent 4494438 commit d1a98c6

File tree

2 files changed

+85
-6
lines changed

2 files changed

+85
-6
lines changed

test/ad.jl

+71-6
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,14 @@
11
using DynamicPPL: LogDensityFunction
22

33
@testset "Automatic differentiation" begin
4+
# Used as the ground truth that others are compared against.
5+
ref_adtype = AutoForwardDiff()
6+
test_adtypes = [
7+
AutoReverseDiff(; compile=false),
8+
AutoReverseDiff(; compile=true),
9+
AutoMooncake(; config=nothing),
10+
]
11+
412
@testset "Unsupported backends" begin
513
@model demo() = x ~ Normal()
614
@test_logs (:warn, r"not officially supported") LogDensityFunction(
@@ -18,15 +26,10 @@ using DynamicPPL: LogDensityFunction
1826
f = LogDensityFunction(m, varinfo)
1927
x = DynamicPPL.getparams(f)
2028
# Calculate reference logp + gradient of logp using ForwardDiff
21-
ref_adtype = ADTypes.AutoForwardDiff()
2229
ref_ldf = LogDensityFunction(m, varinfo; adtype=ref_adtype)
2330
ref_logp, ref_grad = LogDensityProblems.logdensity_and_gradient(ref_ldf, x)
2431

25-
@testset "$adtype" for adtype in [
26-
AutoReverseDiff(; compile=false),
27-
AutoReverseDiff(; compile=true),
28-
AutoMooncake(; config=nothing),
29-
]
32+
@testset "$adtype" for adtype in test_adtypes
3033
@info "Testing AD on: $(m.f) - $(short_varinfo_name(varinfo)) - $adtype"
3134

3235
# Put predicates here to avoid long lines
@@ -103,4 +106,66 @@ using DynamicPPL: LogDensityFunction
103106
)
104107
@test LogDensityProblems.logdensity_and_gradient(ldf, vi[:]) isa Any
105108
end
109+
110+
# Test that various different ways of specifying array types as arguments work with all
111+
# ADTypes.
112+
@testset "Array argument types" begin
113+
test_m = randn(2, 3)
114+
115+
function eval_logp_and_grad(model, m, adtype)
116+
ldf = LogDensityFunction(model(); adtype=adtype)
117+
return LogDensityProblems.logdensity_and_gradient(ldf, m[:])
118+
end
119+
120+
@model function scalar_matrix_model(::Type{T}=Float64) where {T<:Real}
121+
m = Matrix{T}(undef, 2, 3)
122+
return m ~ filldist(MvNormal(zeros(2), I), 3)
123+
end
124+
125+
scalar_matrix_model_reference = eval_logp_and_grad(
126+
scalar_matrix_model, test_m, ref_adtype
127+
)
128+
129+
@model function matrix_model(::Type{T}=Matrix{Float64}) where {T}
130+
m = T(undef, 2, 3)
131+
return m ~ filldist(MvNormal(zeros(2), I), 3)
132+
end
133+
134+
matrix_model_reference = eval_logp_and_grad(matrix_model, test_m, ref_adtype)
135+
136+
@model function scalar_array_model(::Type{T}=Float64) where {T<:Real}
137+
m = Array{T}(undef, 2, 3)
138+
return m ~ filldist(MvNormal(zeros(2), I), 3)
139+
end
140+
141+
scalar_array_model_reference = eval_logp_and_grad(
142+
scalar_array_model, test_m, ref_adtype
143+
)
144+
145+
@model function array_model(::Type{T}=Array{Float64}) where {T}
146+
m = T(undef, 2, 3)
147+
return m ~ filldist(MvNormal(zeros(2), I), 3)
148+
end
149+
150+
array_model_reference = eval_logp_and_grad(array_model, test_m, ref_adtype)
151+
152+
@testset "$adtype" for adtype in test_adtypes
153+
scalar_matrix_model_logp_and_grad = eval_logp_and_grad(
154+
scalar_matrix_model, test_m, adtype
155+
)
156+
@test scalar_matrix_model_logp_and_grad[1] ≈ scalar_matrix_model_reference[1]
157+
@test scalar_matrix_model_logp_and_grad[2] ≈ scalar_matrix_model_reference[2]
158+
matrix_model_logp_and_grad = eval_logp_and_grad(matrix_model, test_m, adtype)
159+
@test matrix_model_logp_and_grad[1] ≈ matrix_model_reference[1]
160+
@test matrix_model_logp_and_grad[2] ≈ matrix_model_reference[2]
161+
scalar_array_model_logp_and_grad = eval_logp_and_grad(
162+
scalar_array_model, test_m, adtype
163+
)
164+
@test scalar_array_model_logp_and_grad[1] ≈ scalar_array_model_reference[1]
165+
@test scalar_array_model_logp_and_grad[2] ≈ scalar_array_model_reference[2]
166+
array_model_logp_and_grad = eval_logp_and_grad(array_model, test_m, adtype)
167+
@test array_model_logp_and_grad[1] ≈ array_model_reference[1]
168+
@test array_model_logp_and_grad[2] ≈ array_model_reference[2]
169+
end
170+
end
106171
end

test/compiler.jl

+14
Original file line numberDiff line numberDiff line change
@@ -289,6 +289,20 @@ module Issue537 end
289289
@test all((isassigned(x, i) for i in eachindex(x)))
290290
end
291291

292+
# Test that using @. to stop unwanted broadcasting on the RHS works.
293+
@testset "@. ~ with interpolation" begin
294+
@model function at_dot_with_interpolation()
295+
x = Vector{Float64}(undef, 2)
296+
# Without the interpolation the RHS would turn into `Normal.(sum.([1.0, 2.0]))`,
297+
# which would crash.
298+
@. x ~ $(Normal(sum([1.0, 2.0])))
299+
end
300+
301+
# The main check is just that calling at_dot_with_interpolation() doesn't crash,
302+
# the check of the keys is not very important.
303+
@test keys(VarInfo(at_dot_with_interpolation())) == [@varname(x[1]), @varname(x[2])]
304+
end
305+
292306
# A couple of uses of .~ that are no longer valid as of v0.35.
293307
@testset "old .~ syntax" begin
294308
@model function multivariate_dot_tilde()

0 commit comments

Comments
 (0)