Skip to content

Commit ab0c261

Browse files
authored
Merge pull request #450 from QBatista/invalid_inputs_brent_max
ENH: Add errors for invalid inputs for `brent_max`
2 parents 083e003 + cde2f7b commit ab0c261

File tree

2 files changed

+34
-8
lines changed

2 files changed

+34
-8
lines changed

quantecon/optimize/scalar_maximization.py

+10-2
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ def brent_max(func, a, b, args=(), xtol=1e-5, maxiter=500):
3434
info : tuple
3535
A tuple of the form (status_flag, num_iter). Here status_flag
3636
indicates whether or not the maximum number of function calls was
37-
attained. A value of 0 implies that the maximum was not hit.
37+
attained. A value of 0 implies that the maximum was not hit.
3838
The value `num_iter` is the number of function calls.
3939
4040
Example
@@ -49,7 +49,15 @@ def f(x):
4949
```
5050
5151
"""
52-
52+
if not np.isfinite(a):
53+
raise ValueError("a must be finite.")
54+
55+
if not np.isfinite(b):
56+
raise ValueError("b must be finite.")
57+
58+
if not a < b:
59+
raise ValueError("a must be less than b.")
60+
5361
maxfun = maxiter
5462
status_flag = 0
5563

quantecon/optimize/tests/test_scalar_max.py

+24-6
Original file line numberDiff line numberDiff line change
@@ -4,38 +4,43 @@
44
"""
55
import numpy as np
66
from numpy.testing import assert_almost_equal
7+
from nose.tools import raises
78
from numba import njit
89

910
from quantecon.optimize import brent_max
1011

12+
1113
@njit
1214
def f(x):
1315
"""
1416
A function for testing on.
1517
"""
1618
return -(x + 2.0)**2 + 1.0
1719

20+
1821
def test_brent_max():
1922
"""
20-
Uses the function f defined above to test the scalar maximization
23+
Uses the function f defined above to test the scalar maximization
2124
routine.
2225
"""
2326
true_fval = 1.0
2427
true_xf = -2.0
2528
xf, fval, info = brent_max(f, -2, 2)
2629
assert_almost_equal(true_fval, fval, decimal=4)
2730
assert_almost_equal(true_xf, xf, decimal=4)
28-
31+
32+
2933
@njit
3034
def g(x, y):
3135
"""
3236
A multivariate function for testing on.
3337
"""
3438
return -x**2 + y
35-
39+
40+
3641
def test_brent_max():
3742
"""
38-
Uses the function f defined above to test the scalar maximization
43+
Uses the function f defined above to test the scalar maximization
3944
routine.
4045
"""
4146
y = 5
@@ -46,6 +51,21 @@ def test_brent_max():
4651
assert_almost_equal(true_xf, xf, decimal=4)
4752

4853

54+
@raises(ValueError)
55+
def test_invalid_a_brent_max():
56+
brent_max(f, -np.inf, 2)
57+
58+
59+
@raises(ValueError)
60+
def test_invalid_b_brent_max():
61+
brent_max(f, -2, np.inf)
62+
63+
64+
@raises(ValueError)
65+
def test_invalid_a_b_brent_max():
66+
brent_max(f, 1, 0)
67+
68+
4969
if __name__ == '__main__':
5070
import sys
5171
import nose
@@ -54,5 +74,3 @@ def test_brent_max():
5474
argv.append('--verbose')
5575
argv.append('--nocapture')
5676
nose.main(argv=argv, defaultTest=__file__)
57-
58-

0 commit comments

Comments
 (0)