
Commit 4ae59e4

Will Feng authored and facebook-github-bot committed
Move version_counter_ to TensorImpl (pytorch#18223)
Summary: According to pytorch#13638 (comment), after the Variable/Tensor merge, we may capture variables without autograd metadata inside an autograd function, and we need a working version counter in these cases. This PR makes it possible by moving `version_counter_` out of autograd metadata and into TensorImpl, so that variables without autograd metadata still have version counters.

Pull Request resolved: pytorch#18223
Differential Revision: D14735123
Pulled By: yf225
fbshipit-source-id: 15f690311393ffd5a53522a226da82f5abb6c65b
1 parent 507fe66 commit 4ae59e4

File tree

7 files changed: +111 −63 lines changed


aten/src/ATen/OpaqueTensorImpl.h

+6-2
@@ -76,8 +76,12 @@ struct CAFFE2_API OpaqueTensorImpl : public TensorImpl {
     AT_ERROR("opaque tensors do not have storage");
   }
 
-  // NOTE: `shallow_copy_and_detach()` does not copy the AutogradMeta pointer
-  // because it is unique for each Variable.
+  // NOTE: `shallow_copy_and_detach()` does not copy the following TensorImpl fields:
+  // 1. the AutogradMeta pointer, because it is unique for each Variable.
+  // 2. the version counter, because although it lives in TensorImpl, the version counter is managed
+  // by autograd, and the call sites of `shallow_copy_and_detach()` (from autograd) should decide what
+  // the version counter should be for each new TensorImpl. See NOTE [ Version Counter Sharing ] for details.
+  //
   // NOTE: We don't set `allow_tensor_metadata_change_` to false here, because there are call sites
   // to this function that need to change the shallow copy's size or storage afterwards, and setting
   // `allow_tensor_metadata_change_` to false would prevent those changes from happening and is

aten/src/ATen/SparseTensorImpl.h

+6-2
@@ -183,8 +183,12 @@ struct CAFFE2_API SparseTensorImpl : public TensorImpl {
   // make it happen
   void set_indices_and_values_unsafe(const Tensor& indices, const Tensor& values);
 
-  // NOTE: `shallow_copy_and_detach()` does not copy the AutogradMeta pointer
-  // because it is unique for each Variable.
+  // NOTE: `shallow_copy_and_detach()` does not copy the following TensorImpl fields:
+  // 1. the AutogradMeta pointer, because it is unique for each Variable.
+  // 2. the version counter, because although it lives in TensorImpl, the version counter is managed
+  // by autograd, and the call sites of `shallow_copy_and_detach()` (from autograd) should decide what
+  // the version counter should be for each new TensorImpl. See NOTE [ Version Counter Sharing ] for details.
+  //
   // NOTE: We don't set `allow_tensor_metadata_change_` to false here, because there are call sites
   // to this function that need to change the shallow copy's size or storage afterwards, and setting
   // `allow_tensor_metadata_change_` to false would prevent those changes from happening and is

c10/core/TensorImpl.h

+80-3
@@ -138,6 +138,61 @@ struct C10_API AutogradMetaInterface {
   virtual ~AutogradMetaInterface();
 };
 
+// NOTE [ Version Counter Sharing ]
+//
+// Every Tensor has a version counter. Version counters are incremented whenever the
+// data or size of a tensor changes through in-place Variable operations. Version
+// counters are used to detect modifications to saved variables which would result in
+// incorrect gradient calculations. Version counters may be shared between Variables:
+//
+// 1. A view shares the version counter of the base Variable,
+// 2. `x.detach()` shares the version counter of `x`,
+// 3. Unpacked saved variables share the version counter of the source.
+//
+// Version counters are not shared in these scenarios:
+//
+// 1. When we replace a `Variable`'s underlying `Tensor` by calling `set_data(...)`,
+// 2. `x.data` does not share the version counter of `x`. (See discussion at
+// https://github.com/pytorch/pytorch/issues/5396)
+//
+// Question: Why do we put the version counter in TensorImpl instead of AutogradMeta?
+//
+// Answer: After the Variable/Tensor merge, a tensor will not have AutogradMeta when
+// its `requires_grad_` is false, but when we use this tensor in the forward pass of
+// a function that requires saving this tensor for backward, we need to keep track of
+// this tensor's version to make sure it's always valid in the autograd graph.
+//
+// To achieve this goal, we put the version counter in TensorImpl instead of AutogradMeta,
+// and have it always be available. This allows us to have the optimization of not
+// carrying AutogradMeta when a tensor doesn't require gradient.
+//
+// A hypothetical alternative way to achieve this goal is to initialize AutogradMeta and
+// create the version counter for the non-requires-grad tensor only when it's saved for
+// backward. However, since saving a tensor for backward happens in the forward pass, and
+// our invariant is that forward pass needs to be thread-safe, lazy-initializing AutogradMeta
+// when saving a tensor can introduce race conditions when we are running the forward
+// pass in multi-thread scenarios, thus making the forward pass not thread-safe anymore,
+// which breaks the invariant.
+struct C10_API VariableVersion {
+ public:
+  // NOTE: As of C++11 and 14, default-constructing a std::atomic variable
+  // leaves it in a persistently undefined state. See
+  // https://cplusplus.github.io/LWG/issue2334.
+  VariableVersion(uint32_t version = 0)
+      : version_block_(std::make_shared<std::atomic<uint32_t>>(version)) {}
+
+  void bump() noexcept {
+    version_block_->fetch_add(1);
+  }
+
+  uint32_t current_version() const noexcept {
+    return version_block_->load();
+  }
+
+ private:
+  std::shared_ptr<std::atomic<uint32_t>> version_block_;
+};
+
 /**
  * The low-level representation of a tensor, which contains a pointer
  * to a storage (which contains the actual data) and metadata (e.g., sizes and
@@ -845,13 +900,18 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
     return std::move(autograd_meta_);
   }
 
-  // NOTE: `shallow_copy_and_detach()` does not copy the AutogradMeta pointer
-  // because it is unique for each Variable.
+  // NOTE: `shallow_copy_and_detach()` does not copy the following TensorImpl fields:
+  // 1. the AutogradMeta pointer, because it is unique for each Variable.
+  // 2. the version counter, because although it lives in TensorImpl, the version counter is managed
+  // by autograd, and the call sites of `shallow_copy_and_detach()` (from autograd) should decide what
+  // the version counter should be for each new TensorImpl. See NOTE [ Version Counter Sharing ] for details.
+  //
   // NOTE: We don't set `allow_tensor_metadata_change_` to false here, because there are call sites
   // to this function that need to change the shallow copy's size or storage afterwards, and setting
   // `allow_tensor_metadata_change_` to false would prevent those changes from happening and is
   // undesirable.
   virtual c10::intrusive_ptr<TensorImpl> shallow_copy_and_detach() const {
+    AT_ASSERT(!is_variable()); // TODO: remove this when Variable and Tensor are merged
     auto impl = c10::make_intrusive<TensorImpl>(Storage(storage()), type_id());
     impl->set_sizes_and_strides(sizes(), strides());
     impl->storage_offset_ = storage_offset_;
@@ -862,6 +922,19 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
     return impl;
   }
 
+  void set_version_counter(
+      const c10::VariableVersion& version_counter) noexcept {
+    version_counter_ = version_counter;
+  }
+
+  const c10::VariableVersion& version_counter() const noexcept {
+    return version_counter_;
+  }
+
+  void bump_version() noexcept {
+    version_counter_.bump();
+  }
+
   inline void set_pyobj(PyObject* pyobj) noexcept {
     pyobj_ = pyobj;
   }
@@ -1384,6 +1457,8 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
   // at a time).
   std::unique_ptr<c10::AutogradMetaInterface> autograd_meta_ = nullptr;
 
+  c10::VariableVersion version_counter_;
+
   PyObject* pyobj_ = nullptr; // weak reference
 
   // We could save a word or two by combining the SmallVector structs,
@@ -1470,6 +1545,8 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
   // weak refcount
   // storage pointer
   // autograd metadata pointer
+  // version counter (word 0)
+  // version counter (word 1)
   // PyObject pointer
   // sizes SmallVector (begin)
   // sizes SmallVector (end)
@@ -1494,7 +1571,7 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
   // miscellaneous bitfield
   //
   static_assert(sizeof(void*) != sizeof(int64_t) || // if 64-bit...
-                sizeof(TensorImpl) == sizeof(int64_t) * 27,
+                sizeof(TensorImpl) == sizeof(int64_t) * 29,
                 "You changed the size of TensorImpl on 64-bit arch."
                 "See Note [TensorImpl size constraints] on how to proceed.");

torch/csrc/autograd/saved_variable.h

+1-2
@@ -1,7 +1,6 @@
 #pragma once
 
 #include <torch/csrc/WindowsTorchApiMacro.h>
-#include <torch/csrc/autograd/variable_version.h>
 
 #include <ATen/ATen.h>
 
@@ -47,7 +46,7 @@ class TORCH_API SavedVariable {
   // passed in to the unpack function when reconstructing the Variable.
   std::shared_ptr<Function> grad_fn_;
   std::weak_ptr<Function> grad_accumulator_;
-  VariableVersion version_counter_;
+  c10::VariableVersion version_counter_;
 
   uint32_t saved_version_ = 0;
   uint32_t output_nr_ = 0;

torch/csrc/autograd/variable.cpp

+10-6
@@ -7,7 +7,6 @@
 #include <torch/csrc/autograd/functions/tensor.h>
 #include <torch/csrc/autograd/generated/Functions.h>
 #include <torch/csrc/autograd/generated/VariableType.h>
-#include <torch/csrc/autograd/variable_version.h>
 
 #include <ATen/ATen.h>
 #include <c10/util/Exception.h>
@@ -171,8 +170,13 @@ void Variable::Impl::set_data(const at::Tensor &new_data) {
   device_opt_ = new_data.device();
   type_id_ = new_data.dispatch_type().type_id();
 
-  auto new_data_copy = at::Tensor(new_data.getIntrusivePtr()->shallow_copy_and_detach());
-  data_ = std::move(new_data_copy);
+  auto new_data_impl_copy = new_data.getIntrusivePtr()->shallow_copy_and_detach();
+  // Version counter is not shared when we replace a `Variable`'s underlying `Tensor`
+  // by calling `set_data(...)`. The original version of the `Variable` is always preserved.
+  // See NOTE [ Version Counter Sharing ] for details.
+  auto saved_version_ = data_.unsafeGetTensorImpl()->version_counter().current_version();
+  new_data_impl_copy->set_version_counter(saved_version_);
+  data_ = std::move(at::Tensor(new_data_impl_copy));
 }
 
 void Variable::Impl::release_resources() {
@@ -189,8 +193,8 @@ Variable::DifferentiableViewImpl::DifferentiableViewImpl(Variable base, at::Tens
     diff_view_meta->base_ = diff_view_meta->base_.base();
   }
   diff_view_meta->is_view_ = true;
-  diff_view_meta->version_counter_ = diff_view_meta->base_.version_counter();
-  diff_view_meta->attr_version = diff_view_meta->version_counter_.current_version();
+  data_.unsafeGetTensorImpl()->set_version_counter(diff_view_meta->base_.version_counter());
+  diff_view_meta->attr_version = data_.unsafeGetTensorImpl()->version_counter().current_version();
 }
 
@@ -200,7 +204,7 @@ const std::shared_ptr<Function>& Variable::grad_fn() const {
   if (!diff_view_meta->grad_fn_ && !diff_view_meta->base_.requires_grad()) {
     return diff_view_meta->grad_fn_;
   }
-  auto current_version = diff_view_meta->version_counter_.current_version();
+  auto current_version = this->current_version();
   if (diff_view_meta->attr_version != current_version) {
     AT_ASSERT(diff_view_meta->output_nr_ == 0);
     auto fn = std::make_shared<generated::AsStridedBackward>();

torch/csrc/autograd/variable.h

+8-10
@@ -5,7 +5,6 @@
 #include <torch/csrc/WindowsTorchApiMacro.h>
 #include <torch/csrc/autograd/edge.h>
 #include <torch/csrc/autograd/function_hook.h>
-#include <torch/csrc/autograd/variable_version.h>
 
 #include <ATen/ATen.h>
 #include <c10/util/Exception.h>
@@ -257,10 +256,10 @@ struct TORCH_API Variable : public at::Tensor {
 
   /// Increments the version count of this `Variable`.
   void bump_version() noexcept;
-  void set_version_counter(const VariableVersion& version_counter) noexcept;
+  void set_version_counter(const c10::VariableVersion& version_counter) noexcept;
 
   /// Retrieves this `Variable`s version counter.
-  const VariableVersion& version_counter() const noexcept;
+  const c10::VariableVersion& version_counter() const noexcept;
 
   /// Retrieves the current value of the `Variable`'s version counter.
   /// Equivalent to calling `version_counter().current_version()`.
@@ -335,7 +334,6 @@ struct TORCH_API Variable::AutogradMeta : public c10::AutogradMetaInterface {
   std::shared_ptr<Function> grad_fn_;
   std::weak_ptr<Function> grad_accumulator_;
 
-  VariableVersion version_counter_;
   std::vector<std::shared_ptr<FunctionPreHook>> hooks_;
 
   // Only meaningful on leaf variables (must be false otherwise)
@@ -692,20 +690,20 @@ inline bool Variable::is_leaf() const noexcept {
 //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 
 inline void Variable::set_version_counter(
-    const VariableVersion& version_counter) noexcept {
-  get_autograd_meta()->version_counter_ = version_counter;
+    const c10::VariableVersion& version_counter) noexcept {
+  data().unsafeGetTensorImpl()->set_version_counter(version_counter);
 }
 
 inline void Variable::bump_version() noexcept {
-  get_autograd_meta()->version_counter_.bump();
+  data().unsafeGetTensorImpl()->bump_version();
 }
 
 inline uint32_t Variable::current_version() const noexcept {
-  return get_autograd_meta()->version_counter_.current_version();
+  return data().unsafeGetTensorImpl()->version_counter().current_version();
 }
 
-inline const VariableVersion& Variable::version_counter() const noexcept {
-  return get_autograd_meta()->version_counter_;
+inline const c10::VariableVersion& Variable::version_counter() const noexcept {
+  return data().unsafeGetTensorImpl()->version_counter();
 }
 
 // Hooks

torch/csrc/autograd/variable_version.h

-38
This file was deleted.
