@@ -244,7 +244,11 @@ function forwardpass(x::ExprNode, expr_out)
244
244
for i in 2 : length (x. ex. args)
245
245
push! (values, forwardpass (x. ex. args[i], expr_out))
246
246
end
247
- fcall = Expr (:call , x. ex. args[1 ], values... )
247
+ if x. ex. args[1 ] == :(^ ) # Use NaNMath.pow instead of ^
248
+ fcall = Expr (:call , :pow , values... )
249
+ else
250
+ fcall = Expr (:call , x. ex. args[1 ], values... )
251
+ end
248
252
push! (expr_out. args, :( $ (x. value) = $ fcall ))
249
253
return x. value
250
254
elseif isexpr (x. ex, :curly )
@@ -284,11 +288,6 @@ forwardpass(x, expr_out) = :(forwardvalue($x, __placevalues, __placeindex_in))
284
288
forwardvalue (x:: Placeholder , placevalues, placeindex_in) = placevalues[placeindex_in[getplaceindex (x)]]
285
289
forwardvalue (x, placevalues, placeindex_in) = float (x)
286
290
287
- # better to return NaNs than throw DomainErrors.
288
- # sometimes the results aren't needed anyway,
289
- # because the code may compute derivatives wrt constants.
290
- log (x) = x <= 0 ? NaN : Base. log (x)
291
-
292
291
function revpass (x:: ExprNode , expr_out)
293
292
@assert isexpr (expr_out, :block )
294
293
# compute the partial drivative wrt. each expression down the graph
@@ -342,12 +341,12 @@ function revpass(x::ExprNode, expr_out)
342
341
if k == 2 # base
343
342
exponent = getvalue (p. ex. args[3 ])
344
343
push! (expr_out. args,
345
- :( $ (x. deriv) += $ (p. deriv)* $ exponent* $ (x. value)^ ( $ exponent- 1 ) ))
344
+ :( $ (x. deriv) += $ (p. deriv)* $ exponent* pow ( $ (x. value), $ exponent- 1 ) ))
346
345
else
347
346
@assert k == 3
348
347
base = getvalue (p. ex. args[2 ])
349
348
push! (expr_out. args,
350
- :( $ (x. deriv) += $ (p. deriv)* $ base^ ( $ (x. value))* log ($ base) ))
349
+ :( $ (x. deriv) += $ (p. deriv)* pow ( $ base, $ (x. value))* log ($ base) ))
351
350
end
352
351
elseif f == :(/ )
353
352
if k == 2 # numerator
@@ -358,7 +357,7 @@ function revpass(x::ExprNode, expr_out)
358
357
@assert k == 3 # denominator
359
358
numer = getvalue (p. ex. args[2 ])
360
359
push! (expr_out. args,
361
- :( $ (x. deriv) += - 1 * $ (p. deriv)* $ numer* ($ (x. value)) ^ ( - 2 ) ))
360
+ :( $ (x. deriv) += - 1 * $ (p. deriv)* $ numer* pow ($ (x. value), - 2 ) ))
362
361
end
363
362
else
364
363
# try one of the derivative rules
0 commit comments