Skip to content

Commit 1155267

Browse files
authored
Merge pull request #417 from chrishyland/root-finding
ENH: Root finding
2 parents 2a612cc + 876c797 commit 1155267

File tree

3 files changed

+381
-1
lines changed

3 files changed

+381
-1
lines changed

quantecon/optimize/__init__.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -3,4 +3,4 @@
33
"""
44

55
from .scalar_maximization import brent_max
6-
6+
from .root_finding import newton, newton_halley, newton_secant

quantecon/optimize/root_finding.py

+257
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,257 @@
1+
import numpy as np
2+
from numba import jit, njit
3+
from collections import namedtuple
4+
5+
__all__ = ['newton', 'newton_halley', 'newton_secant']
6+
7+
_ECONVERGED = 0
8+
_ECONVERR = -1
9+
10+
results = namedtuple('results',
11+
('root function_calls iterations converged'))
12+
13+
@njit
def _results(r):
    r"""Pack a raw (root, funccalls, iterations, flag) tuple into a
    ``results`` namedtuple, translating ``flag == 0`` into
    ``converged=True``."""
    root, n_calls, n_iter, flag = r
    return results(root, n_calls, n_iter, flag == 0)
18+
19+
@njit
def newton(func, x0, fprime, args=(), tol=1.48e-8, maxiter=50,
           disp=True):
    """
    Find a zero of a scalar function by the Newton-Raphson method, as a
    jitted port of Scipy's ``newton`` for scalars.  No fallback method
    (e.g. secant) is provided here, so `fprime` must be supplied.

    Both `func` and `fprime` must be jitted via Numba; ``njit`` is
    recommended for performance.

    Parameters
    ----------
    func : callable and jitted
        The function whose zero is wanted.  It must be a function of a
        single variable of the form f(x,a,b,c...), where a,b,c... are
        extra arguments passed in the `args` parameter.
    x0 : float
        An initial estimate of the zero, ideally somewhere near the
        actual zero.
    fprime : callable and jitted
        The derivative of the function.
    args : tuple, optional
        Extra arguments to be used in the function call.
    tol : float, optional
        The allowable error of the zero value.
    maxiter : int, optional
        Maximum number of iterations.
    disp : bool, optional
        If True, raise a RuntimeError if the algorithm didn't converge.

    Returns
    -------
    results : namedtuple
        root - Estimated location where function is zero.
        function_calls - Number of times the function was called.
        iterations - Number of iterations needed to find the root.
        converged - True if the routine converged.
    """

    if tol <= 0:
        raise ValueError("tol is too small <= 0")
    if maxiter < 1:
        raise ValueError("maxiter must be greater than 0")

    # Multiply rather than call float() so complex x0 also works.
    x = 1.0 * x0
    n_calls = 0
    status = _ECONVERR

    # Newton-Raphson iteration
    for it in range(maxiter):
        fx = func(x, *args)
        n_calls += 1

        if fx == 0:
            # Landed exactly on a root; no step was taken this pass, so
            # discount it from the reported iteration count.
            status = _ECONVERGED
            x_new = x
            it -= 1
            break

        dfx = fprime(x, *args)
        n_calls += 1

        if dfx == 0:
            # Flat tangent: no Newton step possible; stop unconverged
            # at the current point.
            x_new = x
            break

        # Newton step: x_{k+1} = x_k - f(x_k) / f'(x_k)
        x_new = x - fx / dfx

        if abs(x_new - x) < tol:
            status = _ECONVERGED
            break

        x = x_new

    if disp and status == _ECONVERR:
        raise RuntimeError("Failed to converge")

    return _results((x_new, n_calls, it + 1, status))
99+
100+
@njit
def newton_halley(func, x0, fprime, fprime2, args=(), tol=1.48e-8,
                  maxiter=50, disp=True):
    """
    Find a zero of a scalar function by Halley's method, as a jitted
    port of Scipy's implementation.

    `func`, `fprime` and `fprime2` must be jitted via Numba.

    Parameters
    ----------
    func : callable and jitted
        The function whose zero is wanted.  It must be a function of a
        single variable of the form f(x,a,b,c...), where a,b,c... are
        extra arguments passed in the `args` parameter.
    x0 : float
        An initial estimate of the zero, ideally somewhere near the
        actual zero.
    fprime : callable and jitted
        The derivative of the function.
    fprime2 : callable and jitted
        The second order derivative of the function.
    args : tuple, optional
        Extra arguments to be used in the function call.
    tol : float, optional
        The allowable error of the zero value.
    maxiter : int, optional
        Maximum number of iterations.
    disp : bool, optional
        If True, raise a RuntimeError if the algorithm didn't converge.

    Returns
    -------
    results : namedtuple
        root - Estimated location where function is zero.
        function_calls - Number of times the function was called.
        iterations - Number of iterations needed to find the root.
        converged - True if the routine converged.
    """

    if tol <= 0:
        raise ValueError("tol is too small <= 0")
    if maxiter < 1:
        raise ValueError("maxiter must be greater than 0")

    # Multiply rather than call float() so complex x0 also works.
    x = 1.0 * x0
    n_calls = 0
    status = _ECONVERR

    # Halley iteration
    for it in range(maxiter):
        fx = func(x, *args)
        n_calls += 1

        if fx == 0:
            # Landed exactly on a root; no step was taken this pass, so
            # discount it from the reported iteration count.
            status = _ECONVERGED
            x_new = x
            it -= 1
            break

        dfx = fprime(x, *args)
        n_calls += 1

        if dfx == 0:
            # Flat tangent: no step possible; stop unconverged here.
            x_new = x
            break

        step = fx / dfx
        # Halley's correction to the plain Newton step.  Note the
        # second-derivative evaluation is not added to the call count
        # (matching the original implementation).
        d2fx = fprime2(x, *args)
        x_new = x - step / (1.0 - 0.5 * step * d2fx / dfx)

        if abs(x_new - x) < tol:
            status = _ECONVERGED
            break

        x = x_new

    if disp and status == _ECONVERR:
        raise RuntimeError("Failed to converge")

    return _results((x_new, n_calls, it + 1, status))
181+
182+
@njit
def newton_secant(func, x0, args=(), tol=1.48e-8, maxiter=50,
                  disp=True):
    """
    Find a zero of a scalar function by the secant method, as a jitted
    port of Scipy's secant method.

    `func` must be jitted via Numba.

    Parameters
    ----------
    func : callable and jitted
        The function whose zero is wanted.  It must be a function of a
        single variable of the form f(x,a,b,c...), where a,b,c... are
        extra arguments passed in the `args` parameter.
    x0 : float
        An initial estimate of the zero, ideally somewhere near the
        actual zero.
    args : tuple, optional
        Extra arguments to be used in the function call.
    tol : float, optional
        The allowable error of the zero value.
    maxiter : int, optional
        Maximum number of iterations.
    disp : bool, optional
        If True, raise a RuntimeError if the algorithm didn't converge.

    Returns
    -------
    results : namedtuple
        root - Estimated location where function is zero.
        function_calls - Number of times the function was called.
        iterations - Number of iterations needed to find the root.
        converged - True if the routine converged.
    """

    if tol <= 0:
        raise ValueError("tol is too small <= 0")
    if maxiter < 1:
        raise ValueError("maxiter must be greater than 0")

    # Multiply rather than call float() so complex x0 also works.
    x_prev = 1.0 * x0
    n_calls = 0
    status = _ECONVERR

    # Second starting point: nudge x0 by a small relative plus absolute
    # amount, directed away from zero.
    if x0 >= 0:
        x_curr = x0 * (1 + 1e-4) + 1e-4
    else:
        x_curr = x0 * (1 + 1e-4) - 1e-4

    f_prev = func(x_prev, *args)
    n_calls += 1
    f_curr = func(x_curr, *args)
    n_calls += 1

    # Secant iteration
    for it in range(maxiter):
        if f_curr == f_prev:
            # Horizontal secant line: settle on the midpoint and stop.
            x_next = (x_curr + x_prev) / 2.0
            status = _ECONVERGED
            break
        else:
            # Secant update using the last two iterates.
            x_next = x_curr - f_curr * (x_curr - x_prev) / (f_curr - f_prev)

        if np.abs(x_next - x_curr) < tol:
            status = _ECONVERGED
            break

        # Rotate the iterate pair and evaluate at the new point.
        x_prev, f_prev = x_curr, f_curr
        x_curr = x_next
        f_curr = func(x_curr, *args)
        n_calls += 1

    if disp and status == _ECONVERR:
        raise RuntimeError("Failed to converge")

    return _results((x_next, n_calls, it + 1, status))
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,123 @@
1+
import numpy as np
2+
from numpy.testing import assert_almost_equal, assert_allclose
3+
from numba import njit
4+
5+
from quantecon.optimize import newton, newton_halley, newton_secant
6+
7+
@njit
def func(x):
    """
    Simple cubic test function; its real root is x = 1.
    """
    return x ** 3 - 1
13+
14+
15+
@njit
def func_prime(x):
    """
    First derivative of ``func``.
    """
    return 3 * x ** 2
21+
22+
@njit
def func_prime2(x):
    """
    Second derivative of ``func``.
    """
    return 6 * x
28+
29+
@njit
def func_two(x):
    """
    Harder test function (oscillatory plus a steep polynomial term).
    """
    return np.sin(4 * (x - 1 / 4)) + x + x ** 20 - 1
35+
36+
37+
@njit
def func_two_prime(x):
    """
    First derivative of ``func_two``.
    """
    return 4 * np.cos(4 * (x - 1 / 4)) + 20 * x ** 19 + 1
43+
44+
@njit
def func_two_prime2(x):
    """
    Second derivative of ``func_two``.
    """
    return 380 * x ** 18 - 16 * np.sin(4 * (x - 1 / 4))
50+
51+
52+
def test_newton_basic():
    """
    Newton's method should recover the root x = 1 of the cubic ``func``.
    """
    res = newton(func, 5, func_prime)
    assert_almost_equal(1.0, res.root, decimal=4)
60+
61+
62+
def test_newton_basic_two():
    """
    Same root as ``test_newton_basic`` but checked with a relative
    tolerance via ``assert_allclose``.
    """
    res = newton(func, 5, func_prime)
    assert_allclose(1.0, res.root, rtol=1e-5, atol=0)
70+
71+
72+
def test_newton_hard():
    """
    Harder convergence test for Newton's method on ``func_two``.
    """
    res = newton(func_two, 0.4, func_two_prime)
    assert_allclose(0.408, res.root, rtol=1e-5, atol=0.01)
79+
80+
def test_halley_basic():
    """
    Halley's method should recover the root x = 1 of the cubic ``func``.
    """
    res = newton_halley(func, 5, func_prime, func_prime2)
    assert_almost_equal(1.0, res.root, decimal=4)
87+
88+
def test_halley_hard():
    """
    Harder convergence test for Halley's method on ``func_two``.
    """
    res = newton_halley(func_two, 0.4, func_two_prime, func_two_prime2)
    assert_allclose(0.408, res.root, rtol=1e-5, atol=0.01)
95+
96+
def test_secant_basic():
    """
    The secant method should recover the root x = 1 of ``func`` without
    needing a derivative.
    """
    res = newton_secant(func, 5)
    assert_allclose(1.0, res.root, rtol=1e-5, atol=0.001)
103+
104+
105+
def test_secant_hard():
    """
    Harder convergence test for the secant method on ``func_two``.
    """
    res = newton_secant(func_two, 0.4)
    assert_allclose(0.408, res.root, rtol=1e-5, atol=0.01)
112+
113+
114+
# executing testcases.

if __name__ == '__main__':
    import sys
    # NOTE(review): nose is unmaintained on modern Python; consider
    # running these tests with pytest instead.
    import nose

    # Run this file's tests verbosely, without capturing stdout.
    argv = sys.argv[:]
    argv.append('--verbose')
    argv.append('--nocapture')
    nose.main(argv=argv, defaultTest=__file__)

0 commit comments

Comments
 (0)