Skip to content

Commit 5abeac3

Browse files
Varal7 authored and facebook-github-bot committed
Make saved tensors default hooks thread local (pytorch#62909)
Summary: Pull Request resolved: pytorch#62909 This PR makes saved tensors default hooks thread local. This allows using default hooks in a multithreaded context. Test Plan: Imported from OSS Reviewed By: albanD Differential Revision: D30165416 Pulled By: Varal7 fbshipit-source-id: 10a7d580661d3d94bdaf398c4e076b7bea11c16b
1 parent cb23976 commit 5abeac3

8 files changed

+125
-25
lines changed

aten/src/ATen/SavedTensorHooks.cpp

+39
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
#include <ATen/SavedTensorHooks.h>
#include <c10/util/Exception.h>

#include <atomic>

namespace at {

namespace {
  // PyObject is defined in c10/util/python_stub.h
  // Reference counting is handled by the caller of `set_hooks`.
  thread_local PyObject* pack_hook_(nullptr);
  thread_local PyObject* unpack_hook_(nullptr);

  // This flag is set to true the first time default hooks are registered
  // and left at true for the rest of the execution.
  // It's an optimization so that users who never use default hooks don't need
  // to read the thread_local variables pack_hook_ and unpack_hook_.
  //
  // NOTE(review): enable() may run on one thread while other threads read the
  // flag in set_hooks()/get_hooks(). A plain `bool` here is a data race
  // (undefined behavior); a relaxed atomic has the same cost on common
  // architectures and is well-defined. The flag only ever transitions
  // false -> true, so relaxed ordering suffices.
  std::atomic<bool> is_enabled{false};
}

// One-way switch: marks that default hooks have been used at least once.
void SavedTensorDefaultHooks::enable() {
  is_enabled.store(true, std::memory_order_relaxed);
}

// Installs (or clears, with nullptrs) the default hooks for the CURRENT
// thread. The caller retains responsibility for the Python reference counts.
void SavedTensorDefaultHooks::set_hooks(PyObject* pack_hook, PyObject* unpack_hook) {
  if (!is_enabled.load(std::memory_order_relaxed)) {
    // Before the feature is enabled, only resetting to nullptr is legal.
    TORCH_INTERNAL_ASSERT(pack_hook == nullptr && unpack_hook == nullptr);
    return;
  }
  pack_hook_ = pack_hook;
  unpack_hook_ = unpack_hook;
}

// Returns {pack_hook, unpack_hook} for the CURRENT thread, or
// {nullptr, nullptr} if default hooks were never enabled or not set here.
std::pair<PyObject*, PyObject*> SavedTensorDefaultHooks::get_hooks() {
  if (!is_enabled.load(std::memory_order_relaxed)) {
    return std::make_pair(nullptr, nullptr);
  }
  return std::make_pair(pack_hook_, unpack_hook_);
}

}

aten/src/ATen/SavedTensorHooks.h

+16
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
#pragma once

#include <c10/macros/Export.h>
#include <c10/util/python_stub.h>

#include <utility>

namespace at {

// Registry for the default saved-tensor hooks (Python callables).
// The hooks themselves are stored thread-locally (see SavedTensorHooks.cpp),
// so each thread registers and observes its own pair of hooks.
struct TORCH_API SavedTensorDefaultHooks {
  // Installs pack/unpack hooks for the current thread; passing nullptrs
  // clears them. Reference counting of the PyObjects is the caller's job.
  static void set_hooks(PyObject* pack_hook, PyObject* unpack_hook);
  // Returns the current thread's {pack_hook, unpack_hook}, or
  // {nullptr, nullptr} when none are in effect.
  static std::pair<PyObject*, PyObject*> get_hooks();
  // One-way switch that activates the feature the first time hooks are
  // registered; lets code that never uses hooks skip the TLS reads.
  static void enable();
};

} // namespace at

aten/src/ATen/ThreadLocalState.cpp

+6
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
#endif
66

77
#include <ATen/record_function.h>
8+
#include <ATen/SavedTensorHooks.h>
89

910
namespace at {
1011

@@ -13,6 +14,7 @@ ThreadLocalState::ThreadLocalState(bool keep_grad_mode)
1314
debug_info_(c10::ThreadLocalDebugInfo::current()),
1415
inference_mode_enabled_(c10::InferenceMode::is_enabled()) {
1516
rf_tls_ = at::get_record_function_tls_();
17+
saved_tensors_default_hooks_ = SavedTensorDefaultHooks::get_hooks();
1618

1719
#if !defined(CAFFE2_IS_XPLAT_BUILD) && !defined(C10_MOBILE)
1820
keep_grad_mode_ = keep_grad_mode;
@@ -34,6 +36,10 @@ void ThreadLocalState::setThreadLocalState(
3436

3537
at::set_record_function_tls_(state.rf_tls_);
3638

39+
SavedTensorDefaultHooks::set_hooks(
40+
state.saved_tensors_default_hooks_.first,
41+
state.saved_tensors_default_hooks_.second);
42+
3743
c10::ThreadLocalDebugInfo::_forceCurrentDebugInfo(state.debug_info_);
3844

3945
c10::impl::_force_tls_local_dispatch_key_set(state.dispatch_key_);

aten/src/ATen/ThreadLocalState.h

+3
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,9 @@ class TORCH_API ThreadLocalState {
4343
// TLS for InferenceMode
4444
bool inference_mode_enabled_;
4545

46+
// TLS for saved tensors default hooks
47+
std::pair<PyObject*, PyObject*> saved_tensors_default_hooks_;
48+
4649
// Whether pre-sampling RecordFunction optimization was enabled
4750
bool bumped_record_all_functions_ = false;
4851

test/test_autograd.py

+44
Original file line numberDiff line numberDiff line change
@@ -9341,6 +9341,50 @@ def train_fn_grad(x):
93419341
# be accumulate to the same place and should be the same
93429342
self._run_py_multithread_fn(train_fn_grad, (x,))
93439343

9344+
def test_multithread_saved_tensors_hooks(self):
9345+
def pack(x):
9346+
warnings.warn("pack")
9347+
return x
9348+
9349+
def registers_hooks_for_each_thread():
9350+
with torch.autograd.graph.saved_tensors_hooks(pack, lambda x: x):
9351+
x = torch.ones(5, 5, requires_grad=True)
9352+
with warnings.catch_warnings(record=True) as w:
9353+
y = x * x
9354+
# should raise two warnings from x being saved twice
9355+
self.assertEqual(len(w), 2)
9356+
y.sum().backward()
9357+
9358+
    def test_dataparallel_saved_tensors_hooks(self):
        # Default hooks are thread-local, so the worker threads spawned by
        # DataParallel (which do not propagate TLS) must NOT observe hooks
        # registered on the main thread.
        def pack(x):
            warnings.warn("pack")
            return x

        _self = self  # `self` is shadowed by Model.forward's own `self`

        class Model(torch.nn.Module):
            def forward(self, x):
                with warnings.catch_warnings(record=True) as w:
                    y = x * x
                    if torch.cuda.device_count() >= 2:
                        # DataParallel is calling the forward in different threads
                        # without propagating TLS, so hooks should not be called here
                        _self.assertEqual(len(w), 0)
                    else:
                        # DataParallel only uses one thread
                        # so hooks should be called here
                        _self.assertEqual(len(w), 2)

        x = torch.ones(5, 5, requires_grad=True)
        model = torch.nn.DataParallel(Model())

        with torch.autograd.graph.saved_tensors_hooks(pack, lambda x: x):
            model(x)
            with warnings.catch_warnings(record=True) as w:
                y = x * x
                # hooks should be called here (still on the registering thread)
                _self.assertEqual(len(w), 2)
9387+
93449388
def test_python_thread_in_middle(self):
93459389
# User might write a network that starts on one CPU thread, then runs its second half
93469390
# concurrently with other threads (either via python threading or fork/join calls),

tools/build_variables.bzl

+1
Original file line numberDiff line numberDiff line change
@@ -848,6 +848,7 @@ aten_cpu_source_non_codegen_list = [
848848
"aten/src/ATen/native/mkldnn/Utils.cpp",
849849
"aten/src/ATen/native/quantized/cpu/init_qnnpack.cpp",
850850
"aten/src/ATen/record_function.cpp",
851+
"aten/src/ATen/SavedTensorHooks.cpp",
851852
"aten/src/ATen/vulkan/Context.cpp",
852853
]
853854

torch/csrc/autograd/python_saved_variable_hooks.cpp

+16-17
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
#include <torch/csrc/autograd/python_saved_variable_hooks.h>
2+
#include <ATen/SavedTensorHooks.h>
23

34
#include <torch/csrc/THP.h>
45

@@ -45,39 +46,37 @@ namespace torch { namespace autograd {
4546
}
4647
}
4748

48-
std::mutex PyDefaultSavedVariableHooks::mutex_;
49-
PyObject* PyDefaultSavedVariableHooks::pack_hook_(nullptr);
50-
PyObject* PyDefaultSavedVariableHooks::unpack_hook_(nullptr);
51-
5249
void PyDefaultSavedVariableHooks::set_hooks(py::function &pack_hook, py::function &unpack_hook) {
  // Hooks live in thread-local storage owned by at::SavedTensorDefaultHooks,
  // so registration only affects the current thread (no mutex needed).
  PyObject *pack_hook_(nullptr), *unpack_hook_(nullptr);
  std::tie(pack_hook_, unpack_hook_) = at::SavedTensorDefaultHooks::get_hooks();
  TORCH_CHECK(!pack_hook_ && !unpack_hook_,
    "Setting default hooks but they have already been set. "
    "Hint: only one pair of hooks is allowed at a time.");
  at::SavedTensorDefaultHooks::enable();
  // release(): ownership of one Python reference per hook transfers to the
  // TLS slots; reset_hooks() performs the matching decref.
  at::SavedTensorDefaultHooks::set_hooks(pack_hook.release().ptr(), unpack_hook.release().ptr());
}
6058

6159
void PyDefaultSavedVariableHooks::reset_hooks() {
  PyObject *pack_hook(nullptr), *unpack_hook(nullptr);
  std::tie(pack_hook, unpack_hook) = at::SavedTensorDefaultHooks::get_hooks();
  if (Py_IsInitialized()) {
    // Drop the references taken by set_hooks(); decref requires the GIL.
    // If the interpreter is already finalized, skip (leak) rather than crash.
    py::gil_scoped_acquire gil;
    Py_XDECREF(pack_hook);
    Py_XDECREF(unpack_hook);
  }
  // Clear the current thread's TLS slots regardless.
  at::SavedTensorDefaultHooks::set_hooks(nullptr, nullptr);
}
7169

7270
std::unique_ptr<SavedVariableHooks> PyDefaultSavedVariableHooks::get_hooks() {
  // Read the hooks registered for the current thread; return nullptr when no
  // default hooks are in effect.
  PyObject *pack_hook(nullptr), *unpack_hook(nullptr);
  std::tie(pack_hook, unpack_hook) = at::SavedTensorDefaultHooks::get_hooks();
  if (!pack_hook || !unpack_hook) {
    return nullptr;
  }
  py::gil_scoped_acquire gil;
  // reinterpret_borrow: the TLS slots keep their owning references; the
  // py::function copies passed below presumably add their own refs for the
  // lifetime of PySavedVariableHooks.
  py::function pack_hook_ = py::reinterpret_borrow<py::function>(pack_hook);
  py::function unpack_hook_ = py::reinterpret_borrow<py::function>(unpack_hook);
  return std::make_unique<PySavedVariableHooks>(pack_hook_, unpack_hook_);
}
8281

8382
}}

torch/csrc/autograd/python_saved_variable_hooks.h

-8
Original file line numberDiff line numberDiff line change
@@ -27,14 +27,6 @@ struct PyDefaultSavedVariableHooks {
2727
static void set_hooks(py::function &pack_hook, py::function &unpack_hook);
2828
static void reset_hooks();
2929
static std::unique_ptr<SavedVariableHooks> get_hooks();
30-
31-
private:
32-
static PyObject* pack_hook_;
33-
static PyObject* unpack_hook_;
34-
35-
// Mutex to ensure that concurrent operations that modify default pack_hook_ and
36-
// unpack_hook_ are thread-safe.
37-
static std::mutex mutex_;
3830
};
3931

4032
}}

0 commit comments

Comments
 (0)