|
1 | 1 | #include <torch/csrc/autograd/python_saved_variable_hooks.h>
|
| 2 | +#include <ATen/SavedTensorHooks.h> |
2 | 3 |
|
3 | 4 | #include <torch/csrc/THP.h>
|
4 | 5 |
|
@@ -45,39 +46,37 @@ namespace torch { namespace autograd {
|
45 | 46 | }
|
46 | 47 | }
|
47 | 48 |
|
48 |
| - std::mutex PyDefaultSavedVariableHooks::mutex_; |
49 |
| - PyObject* PyDefaultSavedVariableHooks::pack_hook_(nullptr); |
50 |
| - PyObject* PyDefaultSavedVariableHooks::unpack_hook_(nullptr); |
51 |
| - |
52 | 49 | void PyDefaultSavedVariableHooks::set_hooks(py::function &pack_hook, py::function &unpack_hook) {
|
53 |
| - std::lock_guard<std::mutex> lock(mutex_); |
| 50 | + PyObject *pack_hook_(nullptr), *unpack_hook_(nullptr); |
| 51 | + std::tie(pack_hook_, unpack_hook_) = at::SavedTensorDefaultHooks::get_hooks(); |
54 | 52 | TORCH_CHECK(!pack_hook_ && !unpack_hook_,
|
55 | 53 | "Setting default hooks but they have already been set. "
|
56 | 54 | "Hint: only one pair of hooks is allowed at a time.");
|
57 |
| - pack_hook_ = pack_hook.release().ptr(); |
58 |
| - unpack_hook_ = unpack_hook.release().ptr(); |
| 55 | + at::SavedTensorDefaultHooks::enable(); |
| 56 | + at::SavedTensorDefaultHooks::set_hooks(pack_hook.release().ptr(), unpack_hook.release().ptr()); |
59 | 57 | }
|
60 | 58 |
|
61 | 59 | void PyDefaultSavedVariableHooks::reset_hooks() {
|
62 |
| - std::lock_guard<std::mutex> lock(mutex_); |
| 60 | + PyObject *pack_hook(nullptr), *unpack_hook(nullptr); |
| 61 | + std::tie(pack_hook, unpack_hook) = at::SavedTensorDefaultHooks::get_hooks(); |
63 | 62 | if (Py_IsInitialized()) {
|
64 | 63 | py::gil_scoped_acquire gil;
|
65 |
| - Py_XDECREF(pack_hook_); |
66 |
| - Py_XDECREF(unpack_hook_); |
| 64 | + Py_XDECREF(pack_hook); |
| 65 | + Py_XDECREF(unpack_hook); |
67 | 66 | }
|
68 |
| - pack_hook_ = nullptr; |
69 |
| - unpack_hook_ = nullptr; |
| 67 | + at::SavedTensorDefaultHooks::set_hooks(nullptr, nullptr); |
70 | 68 | }
|
71 | 69 |
|
72 | 70 | std::unique_ptr<SavedVariableHooks> PyDefaultSavedVariableHooks::get_hooks() {
|
73 |
| - if (!pack_hook_ || !unpack_hook_) { |
| 71 | + PyObject *pack_hook(nullptr), *unpack_hook(nullptr); |
| 72 | + std::tie(pack_hook, unpack_hook) = at::SavedTensorDefaultHooks::get_hooks(); |
| 73 | + if (!pack_hook || !unpack_hook) { |
74 | 74 | return nullptr;
|
75 | 75 | }
|
76 |
| - std::lock_guard<std::mutex> lock(mutex_); |
77 | 76 | py::gil_scoped_acquire gil;
|
78 |
| - py::function pack_hook = py::reinterpret_borrow<py::function>(pack_hook_); |
79 |
| - py::function unpack_hook = py::reinterpret_borrow<py::function>(unpack_hook_); |
80 |
| - return std::make_unique<PySavedVariableHooks>(pack_hook, unpack_hook); |
| 77 | + py::function pack_hook_ = py::reinterpret_borrow<py::function>(pack_hook); |
| 78 | + py::function unpack_hook_ = py::reinterpret_borrow<py::function>(unpack_hook); |
| 79 | + return std::make_unique<PySavedVariableHooks>(pack_hook_, unpack_hook_); |
81 | 80 | }
|
82 | 81 |
|
83 | 82 | }}
|
0 commit comments