Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add bisection and Brent's method for root finding #424

Merged
merged 3 commits into from
Jul 30, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion quantecon/optimize/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,4 @@
"""

from .scalar_maximization import brent_max
from .root_finding import newton, newton_halley, newton_secant
from .root_finding import newton, newton_halley, newton_secant, bisect, brentq
286 changes: 258 additions & 28 deletions quantecon/optimize/root_finding.py
Original file line number Diff line number Diff line change
@@ -1,27 +1,32 @@
import numpy as np
from numba import jit, njit
from numba import njit
from collections import namedtuple

__all__ = ['newton', 'newton_halley', 'newton_secant']
__all__ = ['newton', 'newton_halley', 'newton_secant', 'bisect', 'brentq']

_ECONVERGED = 0
_ECONVERR = -1

results = namedtuple('results',
('root function_calls iterations converged'))
_iter = 100
_xtol = 2e-12
_rtol = 4*np.finfo(float).eps

results = namedtuple('results', 'root function_calls iterations converged')


@njit
def _results(r):
    r"""Pack a raw (root, funccalls, iterations, flag) tuple into the
    `results` namedtuple, converting the status flag into a boolean
    `converged` field (flag == 0 means convergence)."""
    root, n_calls, n_iter, flag = r
    converged = flag == 0
    return results(root, n_calls, n_iter, converged)


@njit
def newton(func, x0, fprime, args=(), tol=1.48e-8, maxiter=50,
disp=True):
"""
Find a zero from the Newton-Raphson method using the jitted version of
Scipy's newton for scalars. Note that this does not provide an alternative
Find a zero from the Newton-Raphson method using the jitted version of
Scipy's newton for scalars. Note that this does not provide an alternative
method such as secant. Thus, it is important that `fprime` can be provided.

Note that `func` and `fprime` must be jitted via Numba.
Expand All @@ -38,13 +43,13 @@ def newton(func, x0, fprime, args=(), tol=1.48e-8, maxiter=50,
actual zero.
fprime : callable and jitted
The derivative of the function (when available and convenient).
args : tuple, optional
args : tuple, optional(default=())
Extra arguments to be used in the function call.
tol : float, optional
tol : float, optional(default=1.48e-8)
The allowable error of the zero value.
maxiter : int, optional
maxiter : int, optional(default=50)
Maximum number of iterations.
disp : bool, optional
disp : bool, optional(default=True)
If True, raise a RuntimeError if the algorithm didn't converge

Returns
Expand Down Expand Up @@ -85,18 +90,19 @@ def newton(func, x0, fprime, args=(), tol=1.48e-8, maxiter=50,
break
newton_step = fval / fder
# Newton step
p = p0 - newton_step
p = p0 - newton_step
if abs(p - p0) < tol:
status = _ECONVERGED
break
p0 = p

if disp and status == _ECONVERR:
msg = "Failed to converge"
raise RuntimeError(msg)

return _results((p, funcalls, itr + 1, status))


@njit
def newton_halley(func, x0, fprime, fprime2, args=(), tol=1.48e-8,
maxiter=50, disp=True):
Expand All @@ -119,13 +125,13 @@ def newton_halley(func, x0, fprime, fprime2, args=(), tol=1.48e-8,
The derivative of the function (when available and convenient).
fprime2 : callable and jitted
The second order derivative of the function
args : tuple, optional
args : tuple, optional(default=())
Extra arguments to be used in the function call.
tol : float, optional
tol : float, optional(default=1.48e-8)
The allowable error of the zero value.
maxiter : int, optional
maxiter : int, optional(default=50)
Maximum number of iterations.
disp : bool, optional
disp : bool, optional(default=True)
If True, raise a RuntimeError if the algorithm didn't converge

Returns
Expand Down Expand Up @@ -179,15 +185,16 @@ def newton_halley(func, x0, fprime, fprime2, args=(), tol=1.48e-8,

return _results((p, funcalls, itr + 1, status))


@njit
def newton_secant(func, x0, args=(), tol=1.48e-8, maxiter=50,
disp=True):
"""
Find a zero from the secant method using the jitted version of
Find a zero from the secant method using the jitted version of
Scipy's secant method.

Note that `func` must be jitted via Numba.

Parameters
----------
func : callable and jitted
Expand All @@ -197,13 +204,13 @@ def newton_secant(func, x0, args=(), tol=1.48e-8, maxiter=50,
x0 : float
An initial estimate of the zero that should be somewhere near the
actual zero.
args : tuple, optional
args : tuple, optional(default=())
Extra arguments to be used in the function call.
tol : float, optional
tol : float, optional(default=1.48e-8)
The allowable error of the zero value.
maxiter : int, optional
maxiter : int, optional(default=50)
Maximum number of iterations.
disp : bool, optional
disp : bool, optional(default=True)
If True, raise a RuntimeError if the algorithm didn't converge.

Returns
Expand All @@ -214,17 +221,17 @@ def newton_secant(func, x0, args=(), tol=1.48e-8, maxiter=50,
iterations - Number of iterations needed to find the root.
converged - True if the routine converged
"""

if tol <= 0:
raise ValueError("tol is too small <= 0")
if maxiter < 1:
raise ValueError("maxiter must be greater than 0")

# Convert to float (don't use float(x0); this works also for complex x0)
p0 = 1.0 * x0
funcalls = 0
status = _ECONVERR

# Secant method
if x0 >= 0:
p1 = x0 * (1 + 1e-4) + 1e-4
Expand All @@ -249,9 +256,232 @@ def newton_secant(func, x0, args=(), tol=1.48e-8, maxiter=50,
p1 = p
q1 = func(p1, *args)
funcalls += 1

if disp and status == _ECONVERR:
msg = "Failed to converge"
raise RuntimeError(msg)

return _results((p, funcalls, itr + 1, status))
return _results((p, funcalls, itr + 1, status))


@njit
def _bisect_interval(a, b, fa, fb):
    """Validate a bracketing interval [a, b] for bisection-type methods,
    given the function values fa = f(a) and fb = f(b).

    Raises if fa and fb share a sign; otherwise reports an immediate
    root when either endpoint is already a zero of f (preferring b when
    both endpoints are roots).
    """
    if fa * fb > 0:
        raise ValueError("f(a) and f(b) must have different signs")

    # A zero at an endpoint means no iteration is needed
    if fb == 0:
        return b, _ECONVERGED
    if fa == 0:
        return a, _ECONVERGED

    # No endpoint root found; caller must iterate
    return 0.0, _ECONVERR


@njit
def bisect(f, a, b, args=(), xtol=_xtol,
           rtol=_rtol, maxiter=_iter, disp=True):
    """
    Find a root of `f` inside a bracketing interval, adapted from Scipy's
    bisect.

    Plain bisection: repeatedly halves the interval between `a` and `b`
    until a zero of `f` is isolated. `f(a)` and `f(b)` must have
    opposite signs.

    `f` must be jitted via numba.

    Parameters
    ----------
    f : jitted and callable
        Python function returning a number. `f` must be continuous.
    a : number
        One end of the bracketing interval [a,b].
    b : number
        The other end of the bracketing interval [a,b].
    args : tuple, optional(default=())
        Extra arguments to be used in the function call.
    xtol : number, optional(default=2e-12)
        The computed root ``x0`` will satisfy ``np.allclose(x, x0,
        atol=xtol, rtol=rtol)``, where ``x`` is the exact root. The
        parameter must be nonnegative.
    rtol : number, optional(default=4*np.finfo(float).eps)
        The computed root ``x0`` will satisfy ``np.allclose(x, x0,
        atol=xtol, rtol=rtol)``, where ``x`` is the exact root.
    maxiter : number, optional(default=100)
        Maximum number of iterations.
    disp : bool, optional(default=True)
        If True, raise a RuntimeError if the algorithm didn't converge.

    Returns
    -------
    results : namedtuple

    """
    if xtol <= 0:
        raise ValueError("xtol is too small (<= 0)")
    if maxiter < 1:
        raise ValueError("maxiter must be greater than 0")

    # Promote the endpoints to float
    lo = 1.0 * a
    hi = 1.0 * b

    f_lo = f(lo, *args)
    f_hi = f(hi, *args)
    funcalls = 2

    # Validate the bracket; may already yield an endpoint root
    root, status = _bisect_interval(lo, hi, f_lo, f_hi)

    itr = 0
    if status != _ECONVERGED:
        # half_width is halved each pass; the midpoint is lo + half_width
        half_width = hi - lo
        for itr in range(maxiter):
            half_width *= 0.5
            mid = lo + half_width
            f_mid = f(mid, *args)
            funcalls += 1

            # Same sign as f(lo): the root lies in the upper half, so
            # advance lo (the sign of f at lo never changes, so f_lo
            # need not be refreshed)
            if f_mid * f_lo >= 0:
                lo = mid

            # Converged when f(mid) is exactly zero or the interval is
            # below the mixed absolute/relative tolerance
            if f_mid == 0 or abs(half_width) < xtol + rtol * abs(mid):
                root = mid
                status = _ECONVERGED
                itr += 1
                break

    if disp and status == _ECONVERR:
        raise RuntimeError("Failed to converge")

    return _results((root, funcalls, itr, status))


@njit
def brentq(f, a, b, args=(), xtol=_xtol,
           rtol=_rtol, maxiter=_iter, disp=True):
    """
    Find a root of a function in a bracketing interval using Brent's method
    adapted from Scipy's brentq.

    Uses the classic Brent's method to find a zero of the function `f` on
    the sign changing interval [a , b].

    `f` must be jitted via numba.

    Parameters
    ----------
    f : jitted and callable
        Python function returning a number. `f` must be continuous.
    a : number
        One end of the bracketing interval [a,b].
    b : number
        The other end of the bracketing interval [a,b].
    args : tuple, optional(default=())
        Extra arguments to be used in the function call.
    xtol : number, optional(default=2e-12)
        The computed root ``x0`` will satisfy ``np.allclose(x, x0,
        atol=xtol, rtol=rtol)``, where ``x`` is the exact root. The
        parameter must be nonnegative.
    rtol : number, optional(default=4*np.finfo(float).eps)
        The computed root ``x0`` will satisfy ``np.allclose(x, x0,
        atol=xtol, rtol=rtol)``, where ``x`` is the exact root.
    maxiter : number, optional(default=100)
        Maximum number of iterations.
    disp : bool, optional(default=True)
        If True, raise a RuntimeError if the algorithm didn't converge.

    Returns
    -------
    results : namedtuple

    """
    if xtol <= 0:
        raise ValueError("xtol is too small (<= 0)")
    if maxiter < 1:
        raise ValueError("maxiter must be greater than 0")

    # Convert to float
    xpre = a * 1.0
    xcur = b * 1.0

    fpre = f(xpre, *args)
    fcur = f(xcur, *args)
    funcalls = 2

    root, status = _bisect_interval(xpre, xcur, fpre, fcur)

    # Check for sign error and early termination
    if status == _ECONVERGED:
        itr = 0
    else:
        # Perform Brent's method
        # xcur is the current best iterate, xpre the previous iterate and
        # xblk the "block" point on the opposite side of the root
        for itr in range(maxiter):

            # Sign change between xpre and xcur: reset the block point to
            # the old bracket end and restart the step history spre/scur
            if fpre * fcur < 0:
                xblk = xpre
                fblk = fpre
                spre = scur = xcur - xpre
            # Keep the point with the smaller residual in xcur: the three
            # sequential assignments swap xcur and xblk (xpre ends up equal
            # to the old xcur), and likewise for the function values
            if abs(fblk) < abs(fcur):
                xpre = xcur
                xcur = xblk
                xblk = xpre

                fpre = fcur
                fcur = fblk
                fblk = fpre

            # delta: convergence tolerance at the current iterate;
            # sbis: bisection step from xcur toward xblk
            delta = (xtol + rtol * abs(xcur)) / 2
            sbis = (xblk - xcur) / 2

            # Root found
            if fcur == 0 or abs(sbis) < delta:
                status = _ECONVERGED
                root = xcur
                itr += 1
                break

            # Attempt an interpolation step only if the previous step was
            # long enough and the residual is still decreasing
            if abs(spre) > delta and abs(fcur) < abs(fpre):
                if xpre == xblk:
                    # interpolate (secant step through the last two points)
                    stry = -fcur * (xcur - xpre) / (fcur - fpre)
                else:
                    # extrapolate (inverse quadratic interpolation through
                    # xpre, xcur and xblk)
                    dpre = (fpre - fcur) / (xpre - xcur)
                    dblk = (fblk - fcur) / (xblk - xcur)
                    stry = -fcur * (fblk * dblk - fpre * dpre) / \
                           (dblk * dpre * (fblk - fpre))

                # Accept the trial step only if it stays well inside the
                # bracket and is shrinking fast enough; otherwise bisect
                if (2 * abs(stry) < min(abs(spre), 3 * abs(sbis) - delta)):
                    # good short step
                    spre = scur
                    scur = stry
                else:
                    # bisect
                    spre = sbis
                    scur = sbis
            else:
                # bisect
                spre = sbis
                scur = sbis

            # Advance the iterate, always moving at least delta so that
            # progress cannot stall below the tolerance
            xpre = xcur
            fpre = fcur
            if (abs(scur) > delta):
                xcur += scur
            else:
                xcur += (delta if sbis > 0 else -delta)
            fcur = f(xcur, *args)
            funcalls += 1

    if disp and status == _ECONVERR:
        raise RuntimeError("Failed to converge")

    return _results((root, funcalls, itr, status))
Loading