@@ -138,6 +138,61 @@ struct C10_API AutogradMetaInterface {
  virtual ~AutogradMetaInterface();
};

+// NOTE [ Version Counter Sharing ]
+//
+// Every Tensor has a version counter. Version counters are incremented whenever the
+// data or size of a tensor changes through in-place Variable operations. Version
+// counters are used to detect modifications to saved variables which would result in
+// incorrect gradient calculations. Version counters may be shared between Variables:
+//
+// 1. A view shares the version counter of the base Variable,
+// 2. `x.detach()` shares the version counter of `x`,
+// 3. Unpacked saved variables share the version counter of the source.
+//
+// Version counters are not shared in these scenarios:
+//
+// 1. When we replace a `Variable`'s underlying `Tensor` by calling `set_data(...)`,
+// 2. `x.data` does not share the version counter of `x`. (See discussion at
+//    https://github.com/pytorch/pytorch/issues/5396)
+//
+// Question: Why do we put the version counter in TensorImpl instead of AutogradMeta?
+//
+// Answer: After the Variable/Tensor merge, a tensor will not have AutogradMeta when
+// its `requires_grad_` is false, but when we use this tensor in the forward pass of
+// a function that requires saving this tensor for backward, we need to keep track of
+// this tensor's version to make sure it's always valid in the autograd graph.
+//
+// To achieve this goal, we put the version counter in TensorImpl instead of AutogradMeta,
+// and have it always be available. This allows us to have the optimization of not
+// carrying AutogradMeta when a tensor doesn't require gradient.
+//
+// A hypothetical alternative way to achieve this goal is to initialize AutogradMeta and
+// create the version counter for the non-requires-grad tensor only when it's saved for
+// backward. However, since saving a tensor for backward happens in the forward pass, and
+// our invariant is that the forward pass needs to be thread-safe, lazy-initializing
+// AutogradMeta when saving a tensor can introduce race conditions when we are running
+// the forward pass in multi-threaded scenarios, thus making the forward pass not
+// thread-safe anymore, which breaks the invariant.
+struct C10_API VariableVersion {
+ public:
+  // NOTE: As of C++11 and 14, default-constructing a std::atomic variable
+  // leaves it in a persistently undefined state. See
+  // https://cplusplus.github.io/LWG/issue2334.
+  VariableVersion(uint32_t version = 0)
+      : version_block_(std::make_shared<std::atomic<uint32_t>>(version)) {}
+
+  void bump() noexcept {
+    version_block_->fetch_add(1);
+  }
+
+  uint32_t current_version() const noexcept {
+    return version_block_->load();
+  }
+
+ private:
+  std::shared_ptr<std::atomic<uint32_t>> version_block_;
+};
+
 /**
  * The low-level representation of a tensor, which contains a pointer
  * to a storage (which contains the actual data) and metadata (e.g., sizes and
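The sharing rules above fall out of `VariableVersion` holding its count behind a `std::shared_ptr`: copying a `VariableVersion` copies the pointer, not the counter. A minimal standalone sketch of that behavior (illustrative only, not part of this diff; it re-declares a stripped-down `VariableVersion` without `C10_API` so it compiles on its own):

#include <atomic>
#include <cassert>
#include <cstdint>
#include <memory>

// Stripped-down copy of the VariableVersion shown in the hunk above.
struct VariableVersion {
  VariableVersion(uint32_t version = 0)
      : version_block_(std::make_shared<std::atomic<uint32_t>>(version)) {}
  void bump() noexcept { version_block_->fetch_add(1); }
  uint32_t current_version() const noexcept { return version_block_->load(); }
 private:
  std::shared_ptr<std::atomic<uint32_t>> version_block_;
};

int main() {
  VariableVersion base;                     // version counter of a base tensor
  VariableVersion view = base;              // a view copies VariableVersion -> shares the atomic block
  uint32_t saved = base.current_version();  // what autograd records when it saves the tensor

  view.bump();                              // an in-place op through the view bumps the shared counter

  assert(base.current_version() == saved + 1);  // the base observes the same bump
  // At backward time autograd compares the saved value against current_version();
  // a mismatch means the saved tensor was modified in place.
  return 0;
}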
@@ -845,13 +900,18 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
    return std::move(autograd_meta_);
  }

-  // NOTE: `shallow_copy_and_detach()` does not copy the AutogradMeta pointer
-  // because it is unique for each Variable.
+  // NOTE: `shallow_copy_and_detach()` does not copy the following TensorImpl fields:
+  // 1. the AutogradMeta pointer, because it is unique for each Variable.
+  // 2. the version counter, because although it lives in TensorImpl, the version counter is managed
+  // by autograd, and the call sites of `shallow_copy_and_detach()` (from autograd) should decide what
+  // the version counter should be for each new TensorImpl. See NOTE [ Version Counter Sharing ] for details.
+  //
  // NOTE: We don't set `allow_tensor_metadata_change_` to false here, because there are call sites
  // to this function that need to change the shallow copy's size or storage afterwards, and setting
  // `allow_tensor_metadata_change_` to false would prevent those changes from happening and is
  // undesirable.
  virtual c10::intrusive_ptr<TensorImpl> shallow_copy_and_detach() const {
+    AT_ASSERT(!is_variable()); // TODO: remove this when Variable and Tensor are merged
    auto impl = c10::make_intrusive<TensorImpl>(Storage(storage()), type_id());
    impl->set_sizes_and_strides(sizes(), strides());
    impl->storage_offset_ = storage_offset_;
@@ -862,6 +922,19 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
    return impl;
  }

+  void set_version_counter(
+      const c10::VariableVersion& version_counter) noexcept {
+    version_counter_ = version_counter;
+  }
+
+  const c10::VariableVersion& version_counter() const noexcept {
+    return version_counter_;
+  }
+
+  void bump_version() noexcept {
+    version_counter_.bump();
+  }
+
  inline void set_pyobj(PyObject* pyobj) noexcept {
    pyobj_ = pyobj;
  }
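Taken together with the `shallow_copy_and_detach()` note above, the intended division of labor is: TensorImpl stores the counter, autograd decides when to share or bump it. A hypothetical sketch of an autograd-side caller (the names `make_view_impl` and `apply_inplace_op` are invented for illustration; the real call sites live in the autograd code and are not part of this diff):

// Assumes the TensorImpl declarations from this change are visible.
c10::intrusive_ptr<c10::TensorImpl> make_view_impl(const c10::TensorImpl& base) {
  auto impl = base.shallow_copy_and_detach();  // leaves version_counter_ default-constructed
  // Rule 1 of NOTE [ Version Counter Sharing ]: a view shares the base's counter.
  impl->set_version_counter(base.version_counter());
  return impl;
}

void apply_inplace_op(c10::TensorImpl& self) {
  // ... mutate self's data or sizes in place ...
  self.bump_version();  // lets saved variables that share this counter detect the change
}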
@@ -1384,6 +1457,8 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
  // at a time).
  std::unique_ptr<c10::AutogradMetaInterface> autograd_meta_ = nullptr;

+  c10::VariableVersion version_counter_;
+
  PyObject* pyobj_ = nullptr; // weak reference

  // We could save a word or two by combining the SmallVector structs,
@@ -1470,6 +1545,8 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
  //    weak refcount
  //    storage pointer
  //    autograd metadata pointer
+  //    version counter (word 0)
+  //    version counter (word 1)
  //    PyObject pointer
  //    sizes SmallVector (begin)
  //    sizes SmallVector (end)
@@ -1494,7 +1571,7 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
  // miscellaneous bitfield
  //
  static_assert(sizeof(void*) != sizeof(int64_t) ||  // if 64-bit...
-                sizeof(TensorImpl) == sizeof(int64_t) * 27,
+                sizeof(TensorImpl) == sizeof(int64_t) * 29,
                "You changed the size of TensorImpl on 64-bit arch."
                " See Note [TensorImpl size constraints] on how to proceed.");
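For the size accounting in the two hunks above: `version_counter_` is a `VariableVersion`, whose only member is a `std::shared_ptr`, and on the common 64-bit ABIs a `shared_ptr` occupies two pointers (object pointer plus control-block pointer), hence the two extra words and the 27 -> 29 bump. A standalone check of that assumption (illustrative; the two-pointer layout is implementation-defined, not mandated by the standard):

#include <atomic>
#include <cstdint>
#include <memory>

// Holds for libstdc++, libc++, and MSVC on 64-bit targets, which is what the
// TensorImpl static_assert implicitly relies on when it counts two words for the counter.
static_assert(sizeof(std::shared_ptr<std::atomic<uint32_t>>) == 2 * sizeof(void*),
              "expected shared_ptr to be two pointers wide on this platform");

int main() { return 0; }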