Skip to content

Commit 67fd265

Browse files
committed
Merge pull request #7 from mlubin/nanmath
use NaNMath package to return NaNs instead of throwing DomainErrors
2 parents 5cc7d9a + 2818354 commit 67fd265

File tree

4 files changed

+22
-9
lines changed

4 files changed

+22
-9
lines changed

REQUIRE

+1
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
11
julia 0.3
22
DualNumbers
3+
NaNMath
34
Graphs

src/ReverseDiffSparse.jl

+6
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,12 @@ module ReverseDiffSparse
33
import Calculus
44
using DualNumbers
55
using Base.Meta
6+
# Override basic math functions to return NaN instead of throwing errors.
7+
# This is what NLP solvers expect, and
8+
# sometimes the results aren't needed anyway,
9+
# because the code may compute derivatives wrt constants.
10+
import NaNMath: sin, cos, tan, asin, acos, acosh, atanh, log, log2, log10, lgamma, log1p, pow
11+
612
if isdir(Pkg.dir("ArrayViews"))
713
eval(Expr(:import,:ArrayViews))
814
const subarr = ArrayViews.view

src/revmode.jl

+8-9
Original file line numberDiff line numberDiff line change
@@ -244,7 +244,11 @@ function forwardpass(x::ExprNode, expr_out)
244244
for i in 2:length(x.ex.args)
245245
push!(values, forwardpass(x.ex.args[i], expr_out))
246246
end
247-
fcall = Expr(:call, x.ex.args[1], values...)
247+
if x.ex.args[1] == :(^) # Use NaNMath.pow instead of ^
248+
fcall = Expr(:call, :pow, values...)
249+
else
250+
fcall = Expr(:call, x.ex.args[1], values...)
251+
end
248252
push!(expr_out.args, :( $(x.value) = $fcall ))
249253
return x.value
250254
elseif isexpr(x.ex, :curly)
@@ -284,11 +288,6 @@ forwardpass(x, expr_out) = :(forwardvalue($x, __placevalues, __placeindex_in))
284288
forwardvalue(x::Placeholder, placevalues, placeindex_in) = placevalues[placeindex_in[getplaceindex(x)]]
285289
forwardvalue(x, placevalues, placeindex_in) = float(x)
286290

287-
# better to return NaNs than throw DomainErrors.
288-
# sometimes the results aren't needed anyway,
289-
# because the code may compute derivatives wrt constants.
290-
log(x) = x <= 0 ? NaN : Base.log(x)
291-
292291
function revpass(x::ExprNode, expr_out)
293292
@assert isexpr(expr_out, :block)
294293
# compute the partial drivative wrt. each expression down the graph
@@ -342,12 +341,12 @@ function revpass(x::ExprNode, expr_out)
342341
if k == 2 # base
343342
exponent = getvalue(p.ex.args[3])
344343
push!(expr_out.args,
345-
:( $(x.deriv) += $(p.deriv)*$exponent*$(x.value)^($exponent-1) ))
344+
:( $(x.deriv) += $(p.deriv)*$exponent*pow($(x.value),$exponent-1) ))
346345
else
347346
@assert k == 3
348347
base = getvalue(p.ex.args[2])
349348
push!(expr_out.args,
350-
:( $(x.deriv) += $(p.deriv)*$base^($(x.value))*log($base) ))
349+
:( $(x.deriv) += $(p.deriv)*pow($base,$(x.value))*log($base) ))
351350
end
352351
elseif f == :(/)
353352
if k == 2 # numerator
@@ -358,7 +357,7 @@ function revpass(x::ExprNode, expr_out)
358357
@assert k == 3 # denominator
359358
numer = getvalue(p.ex.args[2])
360359
push!(expr_out.args,
361-
:( $(x.deriv) += -1*$(p.deriv)*$numer*($(x.value))^(-2) ))
360+
:( $(x.deriv) += -1*$(p.deriv)*$numer*pow($(x.value),-2) ))
362361
end
363362
else
364363
# try one of the derivative rules

test/test_grad.jl

+7
Original file line numberDiff line numberDiff line change
@@ -205,6 +205,13 @@ fval = fg([-3.0],out)
205205
@test_approx_eq fval (-3)^2
206206
@test_approx_eq out[1] 2*-3
207207

208+
y = -2
209+
ex = @processNLExpr y^x[1]
210+
fg = genfgrad_simple(ex)
211+
fval = fg([0.3],out)
212+
@test isnan(fval)
213+
214+
208215
# zeros in products
209216
ex = @processNLExpr prod{ x[i], i = 1:2 }
210217
fg = genfgrad_simple(ex)

0 commit comments

Comments
 (0)