
Commit 9b1a65b

mruberry authored and facebook-github-bot committed
Extends type and shape tracing with device (pytorch#9796)
Summary: This PR extends the existing type and shape metadata tracing and verification done in autograd with device information. This expansion of tracing is required for pytorch#8354, is likely useful in other scenarios, and is a healthy sanity check, just like type and shape tracing.

The precise changes are:

- TypeAndShape -> InputMetadata, which now includes device()
- Creating InputMetadata is simplified to require just a tensor, and callers were updated to use this simpler invocation wherever possible
- The gradient accumulator of a variable is now reset when set_data() is called if either the type or device changes, and this reset now locks to avoid contention with acquiring the gradient accumulator
- Mismatched devices during backward() will throw a runtime error, just like mismatched type and shape
- (Bonus!) Two uninitialized pointers in THCReduce are now initialized (to nullptr) to prevent build warnings

fyi colesbury

Pull Request resolved: pytorch#9796
Reviewed By: goldsborough
Differential Revision: D9119325
Pulled By: ezyang
fbshipit-source-id: 76d1861b8d4f74db0575ff1f3bd965e18f9463de
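Editor's note: the new device check in validate_outputs (torch/csrc/autograd/engine.cpp below) matters when a gradient has the right type but the wrong CUDA device index. A minimal Python sketch of that situation follows; it is not part of this commit, assumes a build with at least two GPUs, and uses a deliberately broken custom Function for illustration only.

    import torch
    from torch.autograd import Function

    class BadDeviceGrad(Function):
        """Hypothetical op whose backward (incorrectly) moves the gradient to another GPU."""

        @staticmethod
        def forward(ctx, x):
            return x.clone()

        @staticmethod
        def backward(ctx, grad_output):
            # Same dtype and backend, wrong device index on purpose.
            return grad_output.cuda(1)

    x = torch.randn(3, device="cuda:0", requires_grad=True)
    loss = BadDeviceGrad.apply(x).sum()
    try:
        loss.backward()
    except RuntimeError as e:
        # With this change the engine is expected to report something like
        # "invalid gradient at index 0 - expected device 0 but got 1".
        print(e)

Before this commit, only the type and shape of the gradient were validated, so a mismatch like this could propagate silently.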
1 parent 2993c42 commit 9b1a65b

12 files changed, +95 -56 lines

aten/src/THC/THCReduce.cuh (+3 -3)

@@ -517,9 +517,9 @@ bool THC_reduceDim(THCState* state,
         (TYPE) outElements, init, modifyOp, reduceOp, finalizeOp); \
     } \
     else \
-    { \
-      void* stagingData; \
-      void* semaphores; \
+    { \
+      void* stagingData = nullptr; \
+      void* semaphores = nullptr; \
                                    \
       if(grid.y > 1) \
       { \

tools/autograd/templates/VariableType.cpp (+2 -2)

@@ -343,7 +343,7 @@ static void throw_error_out_requires_grad(const char* name) {
 
 static void rebase_history(Variable& var, std::shared_ptr<Function> grad_fn) {
   if (grad_fn && var.defined()) {
-    grad_fn->add_input_metadata(var.type(), var.sizes());
+    grad_fn->add_input_metadata(var);
     var.rebase_history({std::move(grad_fn), 0});
   }
 }
@@ -353,7 +353,7 @@ static void rebase_history(ArrayRef<Variable> vars, std::shared_ptr<Function> gr
   for (auto& var : vars) {
     if (var.defined()) {
       // TODO: eliminate const_cast
-      auto output_nr = grad_fn->add_input_metadata(var.type(), var.sizes());
+      auto output_nr = grad_fn->add_input_metadata(var);
       const_cast<Variable&>(var).rebase_history({grad_fn, output_nr});
     } else {
       grad_fn->add_input_metadata(Function::undefined_input());

torch/csrc/autograd/engine.cpp (+7)

@@ -338,6 +338,13 @@ static void validate_outputs(const edge_list& edges, variable_list& grads, const
       ss << metadata.type() << " but got " << grads[i].type();
       throw std::runtime_error(format_error(ss.str()));
     }
+    const auto output_device = output.is_cuda() ? output.get_device() : -1;
+    if (output_device != metadata.device()) {
+      std::stringstream ss;
+      ss << "invalid gradient at index " << i << " - expected device ";
+      ss << metadata.device() << " but got " << output_device;
+      throw std::runtime_error(format_error(ss.str()));
+    }
   }
 }

torch/csrc/autograd/function.h (+15 -6)

@@ -5,7 +5,7 @@
 #include "torch/csrc/autograd/anomaly_mode.h"
 #include "torch/csrc/autograd/profiler.h"
 #include "torch/csrc/autograd/saved_variable.h"
-#include "torch/csrc/autograd/type_and_shape.h"
+#include "torch/csrc/autograd/input_metadata.h"
 #include "torch/csrc/autograd/variable.h"
 #include "torch/csrc/utils/python_stub.h"
 #include "torch/csrc/utils/variadic.h"
@@ -128,9 +128,18 @@ struct TORCH_API Function : std::enable_shared_from_this<Function> {
 
   /// Adds the type and shape metadata for a new input. Returns the index of
   /// of the new input.
-  uint32_t add_input_metadata(const at::Type& type, at::IntList shape) noexcept {
+  uint32_t add_input_metadata(
+    const at::Type& type
+  , at::IntList shape
+  , const int64_t device) noexcept {
     uint32_t input_nr = input_metadata_.size();
-    input_metadata_.emplace_back(type, shape);
+    input_metadata_.emplace_back(type, shape, device);
+    return input_nr;
+  }
+
+  uint32_t add_input_metadata(const at::Tensor& t) noexcept {
+    uint32_t input_nr = input_metadata_.size();
+    input_metadata_.emplace_back(t);
     return input_nr;
   }
 
@@ -145,7 +154,7 @@ struct TORCH_API Function : std::enable_shared_from_this<Function> {
     return input_metadata_.size();
   }
 
-  const TypeAndShape& input_metadata(size_t index) const {
+  const InputMetadata& input_metadata(size_t index) const {
     return input_metadata_[index];
   }
 
@@ -322,7 +331,7 @@ struct TORCH_API Function : std::enable_shared_from_this<Function> {
   std::unique_ptr<AnomalyMetadata> anomaly_metadata_ = nullptr;
   std::vector<std::unique_ptr<FunctionPreHook>> pre_hooks_;
   std::vector<std::unique_ptr<FunctionPostHook>> post_hooks_;
-  at::SmallVector<TypeAndShape, 2> input_metadata_;
+  at::SmallVector<InputMetadata, 2> input_metadata_;
 };
 
 /// See Function::is_traceable() for definition.
@@ -367,7 +376,7 @@ inline void create_gradient_edge(
     Variable& variable,
     std::shared_ptr<Function> function) {
   // Copy before move.
-  const auto input_nr = function->add_input_metadata(variable.type(), variable.sizes());
+  const auto input_nr = function->add_input_metadata(variable);
   variable.set_gradient_edge({std::move(function), input_nr});
 }

torch/csrc/autograd/functions/accumulate_grad.cpp (+1 -1)

@@ -19,7 +19,7 @@ namespace torch { namespace autograd {
 AccumulateGrad::AccumulateGrad(Variable variable_)
     : Function(/*sequence_nr=*/UINT64_MAX)
     , variable(std::move(variable_)) {
-  add_input_metadata(variable.type(), variable.sizes());
+  add_input_metadata(variable);
 }
 
 auto AccumulateGrad::apply(variable_list&& grads) -> variable_list {

torch/csrc/autograd/functions/tensor.cpp (+1 -1)

@@ -43,7 +43,7 @@ CopySlices::CopySlices(
     fn(std::move(fn_)) {
   // Take the next_edges of fn as our own, except for index 0 which goes
   // to base instead of the view.
-  add_input_metadata(base_var.type(), base_var.sizes());
+  add_input_metadata(base_var);
   const auto num_outputs = fn->num_outputs();
   next_edges_.reserve(num_outputs);
   add_next_edge(base_var.gradient_edge());

torch/csrc/autograd/functions/utils.h (+1 -1)

@@ -54,7 +54,7 @@ inline void set_history(
   if (grad_fn) {
     if (variable.defined()) {
       auto output_nr =
-          grad_fn->add_input_metadata(variable.type(), variable.sizes());
+          grad_fn->add_input_metadata(variable);
       as_variable_ref(variable).set_gradient_edge({grad_fn, output_nr});
     } else {
       grad_fn->add_input_metadata(Function::undefined_input());

torch/csrc/autograd/input_metadata.h (+44, new file)

@@ -0,0 +1,44 @@
+#pragma once
+
+#include <ATen/ATen.h>
+
+#include <cstdint>
+
+namespace torch { namespace autograd {
+
+/// A tensor's type and shape. Each Function records the required type and
+/// shape of its inputs. If is_valid() is false, then the corresponding input
+/// is not used and may be an undefined tensor.
+struct InputMetadata {
+  InputMetadata() = default;
+
+  InputMetadata(const at::Type& type, at::IntList shape, const int64_t device)
+  : type_{&type} , shape_{shape}, device_{device} { }
+
+  InputMetadata(const at::Tensor& t)
+  : InputMetadata(t.type(), t.sizes(), t.is_cuda() ? t.get_device() : -1) { }
+
+  bool is_valid() const {
+    return type_ != nullptr;
+  }
+
+  const at::Type& type() const {
+    AT_ASSERT(type_);
+    return *type_;
+  }
+
+  at::IntList shape() const {
+    return shape_;
+  }
+
+  int64_t device() const {
+    return device_;
+  }
+
+private:
+  const at::Type* type_ = nullptr;
+  at::DimVector shape_;
+  const int64_t device_ = -1;
+};
+
+}}
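Editor's note: InputMetadata stores the device as a bare int64_t, using the CUDA device index for CUDA tensors and -1 for CPU tensors (see the Tensor constructor above). An illustrative Python snippet mirroring that convention; it is not part of the commit, and the CUDA line assumes a CUDA-capable build.

    import torch

    def device_index(t):
        # Mirrors InputMetadata's convention: CUDA device index, or -1 for CPU.
        return t.get_device() if t.is_cuda else -1

    print(device_index(torch.zeros(2)))                   # -1 (CPU)
    print(device_index(torch.zeros(2, device="cuda:0")))  # 0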

torch/csrc/autograd/python_function.cpp (+1 -1)

@@ -433,7 +433,7 @@ static void _wrap_outputs(THPFunction *self,
     // to set_history wins.
     auto var = as_variable(obj, i);
     if (cdata) {
-      auto output_nr = cdata->add_input_metadata(var.type(), var.sizes());
+      auto output_nr = cdata->add_input_metadata(var);
       AT_ASSERT(i == (int)output_nr);
     }
     set_history(var, i, is_input, is_modified, is_differentiable);

torch/csrc/autograd/python_legacy_variable.cpp (+1 -1)

@@ -57,7 +57,7 @@ static PyObject *THPVariable_pynew(PyTypeObject* type, PyObject *args, PyObject
   Variable var;
   if (grad_fn) {
     auto grad_fn_ = THPFunction_asFunction((THPFunction*)grad_fn);
-    Edge edge(grad_fn_, grad_fn_->add_input_metadata(tensor.type(), tensor.sizes()));
+    Edge edge(grad_fn_, grad_fn_->add_input_metadata(tensor));
     var = make_variable(std::move(tensor), std::move(edge));
   } else {
     var = make_variable(std::move(tensor), requires_grad);

torch/csrc/autograd/type_and_shape.h (-33, deleted file)

@@ -1,33 +0,0 @@
-#pragma once
-
-#include <ATen/ATen.h>
-
-namespace torch { namespace autograd {
-
-/// A tensor's type and shape. Each Function records the required type and
-/// shape of its inputs. If is_valid() is false, then the corresponding input
-/// is not used and may be an undefined tensor.
-struct TypeAndShape {
-  TypeAndShape() : type_(nullptr) {}
-
-  TypeAndShape(const at::Type& type, at::IntList shape)
-  : type_(&type) , shape_(shape) {}
-
-  bool is_valid() const {
-    return type_ != nullptr;
-  }
-
-  const at::Type& type() const {
-    AT_ASSERT(type_);
-    return *type_;
-  }
-
-  at::IntList shape() const {
-    return shape_;
-  }
-
-  const at::Type* type_;
-  at::DimVector shape_;
-};
-
-}}

torch/csrc/autograd/variable.cpp (+19 -7)

@@ -117,13 +117,22 @@ void Variable::Impl::backward(
 }
 
 void Variable::Impl::set_data(Tensor new_data) {
-  if (new_data.type() != data_.type()) {
-    scalar_type_ = new_data.type().scalarType();
-    backend_ = new_data.type().backend();
-    is_variable_ = true;
-    // Clear grad_accumulator if it exists, since it stores the old type info.
-    grad_accumulator_.reset();
+  // Resets gradient accumulator if metadata is out of date
+  std::lock_guard<std::mutex> lock(mutex_);
+  auto prior_accumulator = grad_accumulator_.lock();
+  if (prior_accumulator) {
+    const auto prior_device = prior_accumulator->input_metadata(0).device();
+    const auto new_device = new_data.is_cuda() ? new_data.get_device() : -1;
+
+    if (new_data.type() != data_.type() || prior_device != new_device) {
+      grad_accumulator_.reset();
+    }
   }
+
+  // Updates metadata
+  scalar_type_ = new_data.type().scalarType();
+  backend_ = new_data.type().backend();
+  is_variable_ = true;
   data_ = std::move(new_data);
 }
 
@@ -160,7 +169,10 @@ std::shared_ptr<Function>& Variable::ViewImpl::get_grad_fn() {
     fn->stride = strides().vec();
     fn->storage_offset = data_.storage_offset();
     fn->set_next_edges(collect_next_edges(base_));
-    fn->add_input_metadata(base_.type(), sizes());
+    fn->add_input_metadata(
+      base_.type()
+    , sizes() // Note: sizes(), not base_.sizes(), is intentional
+    , base_.is_cuda() ? base_.get_device() : -1);
     grad_fn_ = std::move(fn);
     attr_version = current_version;
   }
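Editor's note: Variable::Impl::set_data is what backs assignment to .data on the Python side. A rough Python-level sketch of the case the new logic covers, where the at::Type is unchanged but the device index changes; this is not from the commit, it assumes a build with at least two GPUs, and the accumulator reset itself happens internally rather than being directly observable here.

    import torch

    # Leaf variable on cuda:0; backward() materializes its gradient accumulator.
    w = torch.randn(4, device="cuda:0", requires_grad=True)
    w.sum().backward()
    print(w.grad.device)  # cuda:0

    # Same dtype and backend (CUDA float), different device index. Before this
    # commit the accumulator was only cleared on a *type* change, so its recorded
    # metadata could go stale; now a device change clears it too, under a lock.
    w.data = torch.randn(4, device="cuda:1")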
