
Commit ab78449

Roy Li authored and facebook-github-bot committed

Add ScalarType argument to Type::options() (pytorch#19270)
Summary: Pull Request resolved: pytorch#19270
ghimport-source-id: a5ade61
Differential Revision: D14938707
Pulled By: li-roy
fbshipit-source-id: 018fb3f01706531a06515d6d861e5683a455a705
1 parent a044ba1 commit ab78449

Showing 18 changed files with 76 additions and 73 deletions.

aten/src/ATen/Context.h (+5)

@@ -176,6 +176,11 @@ CAFFE2_API TypeExtendedInterface& getType(const Tensor&);
 
 CAFFE2_API Allocator* getCPUAllocator();
 
+static inline DeprecatedTypeProperties& getNonVariableDeprecatedTypeProperties(Backend p, ScalarType s) {
+  return globalDeprecatedTypePropertiesRegistry().getDeprecatedTypeProperties(
+      p, s, /*is_variable*/false);
+}
+
 static inline DeprecatedTypeProperties& CPU(ScalarType s) {
   return globalDeprecatedTypePropertiesRegistry().getDeprecatedTypeProperties(
       Backend::CPU, s, /*is_variable*/false);
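
For orientation, a minimal usage sketch of the helper added above, assuming only the ATen headers; the updated TensorOptions tests further down exercise the same pattern:

    #include <ATen/ATen.h>

    // Sketch: look up the non-variable DeprecatedTypeProperties for a
    // backend/dtype pair, then build TensorOptions from it.
    at::DeprecatedTypeProperties& props =
        at::getNonVariableDeprecatedTypeProperties(at::Backend::SparseCPU, at::kByte);
    at::TensorOptions opts = at::TensorOptions(props);  // CPU device, Byte dtype, sparse layout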

aten/src/ATen/core/Type.h (+5 -10)

@@ -176,30 +176,25 @@ struct CAFFE2_API Type {
     return this != &other;
   }
 
-  /// Constructs the `TensorOptions` from a type and a `device_index`.
-  TensorOptions options(int16_t device_index = -1) const {
-    return TensorOptions().dtype(typeMeta())
+  TensorOptions options(ScalarType s, int16_t device_index = -1) const {
+    return TensorOptions().dtype(s)
         .device(device_type(), device_index)
         .layout(layout())
         .is_variable(is_variable());
   }
 
   /// Constructs the `TensorOptions` from a type and a Device. Asserts that
   /// the device type matches the device type of the type.
-  TensorOptions options(c10::optional<Device> device_opt) const {
+  TensorOptions options(ScalarType s, c10::optional<Device> device_opt) const {
     if (!device_opt.has_value()) {
-      return options(-1);
+      return options(s, -1);
     } else {
       Device device = device_opt.value();
       AT_ASSERT(device.type() == device_type());
-      return options(device.index());
+      return options(s, device.index());
     }
   }
 
-  operator TensorOptions() const {
-    return options();
-  }
-
   // example
   // virtual Tensor * add(Tensor & a, Tensor & b) = 0;
   virtual Tensor abs(const Tensor & self) const = 0;
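
For illustration, a hedged sketch of how a call site adapts to this signature change (the tensor variable here is hypothetical; dispatch_type() and scalar_type() are the Tensor accessors used by other files in this commit):

    // Before: the dtype came implicitly from the Type's typeMeta(), and Type
    // also converted implicitly to TensorOptions.
    //   auto opts = tensor.dispatch_type().options();
    // After: the caller passes the ScalarType explicitly; the implicit
    // operator TensorOptions() is removed.
    auto opts = tensor.dispatch_type().options(tensor.scalar_type());
    auto zeros = at::zeros({2, 3}, opts);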

aten/src/ATen/function_wrapper.py (+2 -1)

@@ -1604,7 +1604,8 @@ def emit_body(env, option, scalar_type_cases):
         # e.g. x.sum(0) and x.sum() return the same type. We explicitly cast to the
         # ScalarType before constructing the scalar_tensor to avoid overflow checking.
         elif ret['type'] == 'accreal' or ret['type'] == 'real':
-            return_scalar = 'return at::scalar_tensor(convert<${ScalarType}>(${call}), options());'
+            return_scalar = ('return at::scalar_tensor(convert<${ScalarType}>(${call}), '
+                             'options(ScalarType::${ScalarName}));')
             case_body.append(CodeTemplate(return_scalar).substitute(case_env, call=call))
         else:
             # we using int64_t for long in the API, so correct it here...
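
A hedged illustration of what the generated return statement might look like once the template is substituted, shown here for a Float case (the_call stands in for the elided backend call):

    // Previously generated: dtype taken implicitly from the Type's options().
    //   return at::scalar_tensor(convert<float>(the_call), options());
    // Now generated: the ScalarType is spelled out explicitly.
    return at::scalar_tensor(convert<float>(the_call), options(ScalarType::Float));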

aten/src/ATen/templates/Type.h (+5 -10)

@@ -119,30 +119,25 @@ struct CAFFE2_API Type {
     return this != &other;
   }
 
-  /// Constructs the `TensorOptions` from a type and a `device_index`.
-  TensorOptions options(int16_t device_index = -1) const {
-    return TensorOptions().dtype(typeMeta())
+  TensorOptions options(ScalarType s, int16_t device_index = -1) const {
+    return TensorOptions().dtype(s)
         .device(device_type(), device_index)
         .layout(layout())
         .is_variable(is_variable());
   }
 
   /// Constructs the `TensorOptions` from a type and a Device. Asserts that
   /// the device type matches the device type of the type.
-  TensorOptions options(c10::optional<Device> device_opt) const {
+  TensorOptions options(ScalarType s, c10::optional<Device> device_opt) const {
     if (!device_opt.has_value()) {
-      return options(-1);
+      return options(s, -1);
     } else {
       Device device = device_opt.value();
       AT_ASSERT(device.type() == device_type());
-      return options(device.index());
+      return options(s, device.index());
     }
   }
 
-  operator TensorOptions() const {
-    return options();
-  }
-
   // example
   // virtual Tensor * add(Tensor & a, Tensor & b) = 0;
   ${pure_virtual_type_method_declarations}

test/cpp/api/tensor_options.cpp (+3 -3)

@@ -66,18 +66,18 @@ TEST(TensorOptionsTest, ConstructsWellFromCPUTypes) {
   options = TensorOptions(kInt);
   REQUIRE_OPTIONS(kCPU, -1, kInt, kStrided);
 
-  options = TensorOptions(getNonVariableType(Backend::SparseCPU, kFloat));
+  options = TensorOptions(getNonVariableDeprecatedTypeProperties(Backend::SparseCPU, kFloat));
   REQUIRE_OPTIONS(kCPU, -1, kFloat, kSparse);
 
-  options = TensorOptions(getNonVariableType(Backend::SparseCPU, kByte));
+  options = TensorOptions(getNonVariableDeprecatedTypeProperties(Backend::SparseCPU, kByte));
   REQUIRE_OPTIONS(kCPU, -1, kByte, kSparse);
 }
 
 TEST(TensorOptionsTest, ConstructsWellFromCPUTensors) {
   auto options = empty(5, kDouble).options();
   REQUIRE_OPTIONS(kCPU, -1, kDouble, kStrided);
 
-  options = empty(5, getNonVariableType(Backend::SparseCPU, kByte)).options();
+  options = empty(5, getNonVariableDeprecatedTypeProperties(Backend::SparseCPU, kByte)).options();
   REQUIRE_OPTIONS(kCPU, -1, kByte, kSparse);
 }

test/cpp/api/tensor_options_cuda.cpp (+4 -4)

@@ -42,25 +42,25 @@ TEST(TensorOptionsTest, ConstructsWellFromCUDATypes_CUDA) {
   options = CUDA(kInt).options();
   REQUIRE_OPTIONS(kCUDA, -1, kInt, kStrided);
 
-  options = getNonVariableType(Backend::SparseCUDA, kFloat).options();
+  options = getNonVariableDeprecatedTypeProperties(Backend::SparseCUDA, kFloat).options();
   REQUIRE_OPTIONS(kCUDA, -1, kFloat, kSparse);
 
-  options = getNonVariableType(Backend::SparseCUDA, kByte).options();
+  options = getNonVariableDeprecatedTypeProperties(Backend::SparseCUDA, kByte).options();
   REQUIRE_OPTIONS(kCUDA, -1, kByte, kSparse);
 
   options = CUDA(kFloat).options(/*device=*/5);
   REQUIRE_OPTIONS(kCUDA, 5, kFloat, kStrided);
 
   options =
-      getNonVariableType(Backend::SparseCUDA, kFloat).options(/*device=*/5);
+      getNonVariableDeprecatedTypeProperties(Backend::SparseCUDA, kFloat).options(/*device=*/5);
   REQUIRE_OPTIONS(kCUDA, 5, kFloat, kSparse);
 }
 
 TEST(TensorOptionsTest, ConstructsWellFromCUDATensors_MultiCUDA) {
   auto options = empty(5, device(kCUDA).dtype(kDouble)).options();
   REQUIRE_OPTIONS(kCUDA, 0, kDouble, kStrided);
 
-  options = empty(5, getNonVariableType(Backend::SparseCUDA, kByte)).options();
+  options = empty(5, getNonVariableDeprecatedTypeProperties(Backend::SparseCUDA, kByte)).options();
   REQUIRE_OPTIONS(kCUDA, 0, kByte, kSparse);
 
   if (torch::cuda::device_count() > 1) {

tools/autograd/templates/Functions.h (+2 -2)

@@ -35,13 +35,13 @@ struct TypeAndSize {
   /* implicit */
   TypeAndSize(const Tensor & t)
     : sizes(t.sizes().vec())
-    , type(&t.dispatch_type()) {}
+    , type(&t.type()) {}
 
   Tensor zeros() { return at::zeros(sizes, *type); }
 
  private:
   std::vector<int64_t> sizes;
-  Type* type;
+  at::DeprecatedTypeProperties* type;
 };
 
 ${autograd_function_declarations}
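
A hedged note on the swap from Type* to at::DeprecatedTypeProperties*: per the registry lookup added in Context.h above, a DeprecatedTypeProperties is keyed by backend, scalar type, and variable-ness, so Tensor::type() still carries everything the unchanged at::zeros(sizes, *type) call needs. Illustrative sketch (the tensor t is hypothetical):

    at::Tensor t = at::ones({4}, at::kFloat);
    at::DeprecatedTypeProperties& props = t.type();     // backend + scalar type + is_variable
    at::TensorOptions opts = at::TensorOptions(props);  // same constructor the tests above use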

torch/csrc/autograd/engine.cpp (+2 -2)

@@ -334,7 +334,7 @@ static variable_list call_post_hooks(Function& fn, variable_list outputs, const
   return outputs;
 }
 
-static bool is_compatible_type(const at::Type& expected, const at::Type& actual) {
+static bool is_compatible_type(const at::DeprecatedTypeProperties& expected, const at::DeprecatedTypeProperties& actual) {
   // Types are compatible if they exactly match or if the gradient is a sparse
   // version of the expected type.
   return expected == actual || (actual.is_sparse() &&
@@ -372,7 +372,7 @@ static void validate_outputs(const edge_list& edges, variable_list& grads, const
     }
     grads[i] = at::sum_to(std::move(grads[i]), metadata.shape());
   }
-  if (!is_compatible_type(metadata.type(), grads[i].dispatch_type())) {
+  if (!is_compatible_type(metadata.type(), grads[i].type())) {
     std::stringstream ss;
     ss << "invalid gradient at index " << i << " - expected type ";
     ss << metadata.type() << " but got " << grads[i].type();

torch/csrc/autograd/function.h (+1 -1)

@@ -130,7 +130,7 @@ struct TORCH_API Function : std::enable_shared_from_this<Function> {
   /// Adds the type and shape metadata for a new input. Returns the index of
   /// of the new input.
   uint32_t add_input_metadata(
-    const at::Type& type
+    const at::DeprecatedTypeProperties& type
   , at::IntArrayRef shape
   , at::Device device) noexcept {
     uint32_t input_nr = input_metadata_.size();

torch/csrc/autograd/input_metadata.h (+4 -4)

@@ -12,17 +12,17 @@ namespace torch { namespace autograd {
 struct InputMetadata {
   InputMetadata() = default;
 
-  InputMetadata(const at::Type& type, at::IntArrayRef shape, at::Device device)
+  InputMetadata(const at::DeprecatedTypeProperties& type, at::IntArrayRef shape, at::Device device)
   : type_{&type} , shape_{shape}, device_{device} { }
 
   InputMetadata(const at::Tensor& t)
-  : InputMetadata(t.dispatch_type(), t.sizes(), t.device()) { }
+  : InputMetadata(t.type(), t.sizes(), t.device()) { }
 
   bool is_valid() const {
     return type_ != nullptr;
   }
 
-  const at::Type& type() const {
+  const at::DeprecatedTypeProperties& type() const {
     AT_ASSERT(type_);
     return *type_;
   }
@@ -40,7 +40,7 @@ struct InputMetadata {
   }
 
 private:
-  const at::Type* type_ = nullptr;
+  const at::DeprecatedTypeProperties* type_ = nullptr;
   at::DimVector shape_;
   at::Device device_ = at::kCPU;
 };

torch/csrc/autograd/python_function.cpp (+2 -1)

@@ -46,14 +46,15 @@ namespace torch { namespace autograd {
 VariableInfo::VariableInfo(const Variable& var)
   : type(&var.dispatch_type())
   , device(var.device())
+  , scalar_type(var.scalar_type())
   , size(var.sizes().vec())
   , requires_grad(var.requires_grad()) {
 }
 
 Variable VariableInfo::zeros(at::OptionalDeviceGuard& device_guard) const {
   // NB: This will NOT work if we ever get mixed device gradients
   device_guard.reset_device(device);
-  return at::zeros(size, type->options());
+  return at::zeros(size, type->options(scalar_type));
 }
 
 auto PyFunction::legacy_apply(const variable_list& inputs) -> variable_list {

torch/csrc/autograd/python_function.h (+1)

@@ -25,6 +25,7 @@ struct VariableInfo {
 
   at::Type* type;
   at::Device device = at::kCPU;
+  at::ScalarType scalar_type = at::kFloat;
   std::vector<int64_t> size;
   bool requires_grad;
 };
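
Taken together with the python_function.cpp hunk above, the pattern is: keep the dispatch Type pointer, record the ScalarType separately at construction, and recombine the two when allocating. A minimal sketch (the var variable is hypothetical):

    // Record both pieces, as VariableInfo now does.
    at::Type* type = &var.dispatch_type();
    at::ScalarType scalar_type = var.scalar_type();
    // Later, rebuild TensorOptions by feeding the ScalarType back into the Type.
    at::Tensor grad = at::zeros(var.sizes(), type->options(scalar_type));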

torch/csrc/autograd/python_legacy_variable.cpp (+2 -1)

@@ -46,7 +46,8 @@ static PyObject *THPVariable_pynew(PyTypeObject* type, PyObject *args, PyObject
   if (!data || data == Py_None) {
     // For legacy serialization code, create an empty tensor. This is also used
     // by nn.Parameter() with no arguments.
-    auto var = at::empty({0}, torch::tensors::get_default_tensor_type().options());
+    auto scalar_type = torch::tensors::get_default_scalar_type();
+    auto var = at::empty({0}, torch::tensors::get_default_tensor_type().options(scalar_type));
     tensor = static_cast<Variable&>(var).data();
   } else if (THPVariable_Check(data)) {
     tensor = ((THPVariable*)data)->cdata.data();

torch/csrc/autograd/python_variable_indexing.cpp (+4 -4)

@@ -110,15 +110,15 @@ static Variable sequenceToVariable(const at::Type& type, PyObject* seq) {
   return torch::utils::indexing_tensor_from_data(idx_type, kLong, c10::nullopt, seq);
 }
 
-static Variable valueToTensor(const at::Type & type, PyObject* value) {
+static Variable valueToTensor(const at::Type & type, const ScalarType scalar_type, PyObject* value) {
   if (THPVariable_Check(value)) {
     return reinterpret_cast<THPVariable*>(value)->cdata;
   }
   if (THPUtils_checkLong(value) || PyBool_Check(value)) {
-    return at::scalar_tensor(Scalar(THPUtils_unpackLong(value)), type.options());
+    return at::scalar_tensor(Scalar(THPUtils_unpackLong(value)), type.options(scalar_type));
   }
   if (PyFloat_Check(value)) {
-    return at::scalar_tensor(Scalar(THPUtils_unpackDouble(value)), type.options());
+    return at::scalar_tensor(Scalar(THPUtils_unpackDouble(value)), type.options(scalar_type));
   }
   throw TypeError("can't assign a %s to a %s", Py_TYPE(value)->tp_name, type.toString());
 }
@@ -334,7 +334,7 @@ int THPVariable_setitem(PyObject* self, PyObject* index, PyObject* py_value) {
 
   auto& self_ = reinterpret_cast<THPVariable*>(self)->cdata;
   OptionalDeviceGuard device_guard(device_of(self_));
-  auto value = valueToTensor(self_.dispatch_type(), py_value);
+  auto value = valueToTensor(self_.dispatch_type(), self_.scalar_type(), py_value);
 
   // handle simple types: integers, slices, ellipsis, bool
   if (index == Py_False) { // NOLINT(cppcoreguidelines-pro-type-cstyle-cast)

torch/csrc/autograd/variable.cpp (+1 -1)

@@ -214,7 +214,7 @@ const std::shared_ptr<Function>& Variable::grad_fn() const {
     fn->storage_offset = data().storage_offset();
     fn->set_next_edges(collect_next_edges(diff_view_meta->base_));
     fn->add_input_metadata(
-      diff_view_meta->base_.dispatch_type()
+      diff_view_meta->base_.type()
     , sizes() // Note: sizes(), not base_.sizes(), is intentional
     , diff_view_meta->base_.device());
     diff_view_meta->grad_fn_ = std::move(fn);

torch/csrc/cuda/comm.cpp (+1 -1)

@@ -59,7 +59,7 @@ std::vector<Tensor> broadcast(const Tensor& tensor, IntArrayRef devices) {
     tensors.push_back(tensor);
     for (auto device : devices.slice(1)) {
       _device_guard.set_index(device);
-      tensors.push_back(at::empty(tensor.sizes(), type.options()));
+      tensors.push_back(at::empty(tensor.sizes(), type.options(tensor.scalar_type())));
     }
     nccl::broadcast(tensors);
   } else {

torch/csrc/jit/passes/shape_analysis.cpp (+2 -3)

@@ -157,10 +157,9 @@ class ShapePropagator {
       return *iv;
     }
     if (CompleteTensorTypePtr type = type_->cast<CompleteTensorType>()) {
-      auto backend =
-          type->device().is_cpu() ? at::Backend::CPU : at::Backend::CUDA;
+      auto attype = type->device().is_cpu() ?
+        at::CPU(type->scalarType()) : at::CUDA(type->scalarType());
       at::DeviceGuard device_guard(type->device());
-      auto& attype = at::getNonVariableType(backend, type->scalarType());
       auto t =
           at::empty_strided(type->sizes(), type->strides(), attype.options())
               .zero_();
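
A hedged note on why attype.options() takes no argument here: at::CPU(s) and at::CUDA(s) return a DeprecatedTypeProperties that already binds the scalar type, so its options() needs no extra ScalarType, unlike Type::options() above. Sketch only (shapes made up):

    at::DeprecatedTypeProperties& attype = at::CPU(at::kFloat);
    at::Tensor t = at::empty_strided({2, 3}, {3, 1}, attype.options()).zero_();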
