
Commit 000f667

add checkpoint function, fixes HIPS#182 (thanks @j-towns)
1 parent e3101c7 commit 000f667

3 files changed (+57, -3 lines)


Diff for: autograd/__init__.py

+2 -1

@@ -3,4 +3,5 @@
 from . import container_types
 from .convenience_wrappers import (grad, multigrad, multigrad_dict, elementwise_grad,
                                    value_and_grad, grad_and_aux, hessian_vector_product,
-                                   hessian, jacobian, vector_jacobian_product, grad_named)
+                                   hessian, jacobian, vector_jacobian_product, grad_named,
+                                   checkpoint)

Diff for: autograd/convenience_wrappers.py

+13 -1

@@ -2,7 +2,7 @@
 from __future__ import absolute_import
 from functools import partial
 import autograd.numpy as np
-from autograd.core import make_vjp, getval, isnode, vspace
+from autograd.core import make_vjp, getval, isnode, vspace, primitive
 from .errors import add_error_hints
 from collections import OrderedDict
 from inspect import getargspec
@@ -180,6 +180,18 @@ def gradfun(*args, **kwargs):
 
     return gradfun
 
+def checkpoint(fun):
+    """Returns a checkpointed version of `fun`, where intermediate values
+    computed during the forward pass of `fun` are discarded and then recomputed
+    for the backward pass. Useful to save memory, effectively trading off time
+    and memory. See e.g. arxiv.org/abs/1604.06174.
+    """
+    def wrapped_grad(argnum, g, ans, vs, gvs, args, kwargs):
+        return make_vjp(fun, argnum)(*args, **kwargs)[0](g)
+    wrapped = primitive(fun)
+    wrapped.vjp = wrapped_grad
+    return wrapped
+
 def attach_name_and_doc(fun, argnum, opname):
     namestr = "{op}_{fun}_wrt_argnum_{argnum}".format(
         op=opname.lower(), fun=getattr(fun, '__name__', '[unknown name]'), argnum=argnum)
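
Usage note (not part of this diff): a minimal sketch of how the new checkpoint wrapper might be applied, mirroring the pattern in the test added below. The names block and loss are illustrative, not from the commit; the intent is that only the inputs and outputs of the wrapped function are kept during the forward pass, with its internals recomputed when the backward pass needs them.

# Minimal sketch (illustrative names; assumes this commit is installed).
import autograd.numpy as np
from autograd import grad, checkpoint

def block(x):
    # an intermediate-heavy sub-computation
    for _ in range(10):
        x = np.sin(x**2 + 1)
    return x

# Wrapping with checkpoint discards block's intermediates after the
# forward pass; they are recomputed when the backward pass reaches it.
checkpointed_block = checkpoint(block)

def loss(x):
    for _ in range(5):
        x = checkpointed_block(x)
    return np.sum(x)

# The gradient matches the un-checkpointed version, at lower peak memory.
print(grad(loss)(np.linspace(0., 1., 5)))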

Diff for: tests/test_wrappers.py

+42 -1

@@ -6,7 +6,7 @@
 from autograd.util import *
 from autograd import (grad, elementwise_grad, jacobian, value_and_grad,
                       grad_and_aux, hessian_vector_product, hessian, multigrad,
-                      jacobian, vector_jacobian_product, primitive)
+                      jacobian, vector_jacobian_product, primitive, checkpoint)
 from builtins import range
 
 npr.seed(1)
@@ -161,3 +161,44 @@ def f(x):
 
     y = np.random.randn(10, 10).astype(np.float16)
     assert grad(f)(y).dtype.type is np.float16
+
+def test_checkpoint_correctness():
+    bar = lambda x, y: 2*x + y + 5
+    checkpointed_bar = checkpoint(bar)
+    foo = lambda x: bar(x, x/3.) + bar(x, x**2)
+    foo2 = lambda x: checkpointed_bar(x, x/3.) + checkpointed_bar(x, x**2)
+    assert np.allclose(foo(3.), foo2(3.))
+    assert np.allclose(grad(foo)(3.), grad(foo2)(3.))
+
+    baz = lambda *args: sum(args)
+    checkpointed_baz = checkpoint(baz)
+    foobaz = lambda x: baz(x, x/3.)
+    foobaz2 = lambda x: checkpointed_baz(x, x/3.)
+    assert np.allclose(foobaz(3.), foobaz2(3.))
+    assert np.allclose(grad(foobaz)(3.), grad(foobaz2)(3.))
+
+def checkpoint_memory():
+    '''This test is meant to be run manually, since it depends on
+    memory_profiler and its behavior may vary.'''
+    try:
+        from memory_profiler import memory_usage
+    except ImportError:
+        return
+
+    def f(a):
+        for _ in range(10):
+            a = np.sin(a**2 + 1)
+        return a
+    checkpointed_f = checkpoint(f)
+
+    def testfun(f, x):
+        for _ in range(5):
+            x = f(x)
+        return np.sum(x)
+    gradfun = grad(testfun, 1)
+
+    A = npr.RandomState(0).randn(100000)
+    max_usage = max(memory_usage((gradfun, (f, A))))
+    max_checkpointed_usage = max(memory_usage((gradfun, (checkpointed_f, A))))
+
+    assert max_checkpointed_usage < max_usage / 2.
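
Note (assumption, not from the commit): checkpoint_memory is deliberately not named test_*, so the default test runner should skip it, consistent with its docstring about running it manually. One hypothetical way to invoke it, with memory_profiler installed:

# run from the repository root (path and invocation are assumptions)
#   pip install memory_profiler
#   python -c "from tests.test_wrappers import checkpoint_memory; checkpoint_memory()"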
