Skip to content

Commit c705d9e

Browse files
Roy Lifacebook-github-bot
Roy Li
authored andcommitted
Introduce DeprecatedTypeProperties class (pytorch#17991)
Summary: Pull Request resolved: pytorch#17991 changes: -Breaks bc: Tensor::type() now returns DeprecatedTypeProperties& rather than Type&. -Added DeprecatedTypeProperties, it serves as a temporary replacement for Type as the return value of Tensor::type(). This contributes to making Type just for dispatch purposes so that we can make it dtype agnostic. -Tensor::dispatch_type() now returns Type& like Tensor::type() used to do. -Changed callsites of Tensor::type() appropriately. Reviewed By: ezyang Differential Revision: D14443117 fbshipit-source-id: 239ccb7a09626279a71d1a37f8f82e7f57bf7d9e
1 parent 095f88e commit c705d9e

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

64 files changed

+699
-569
lines changed

aten/src/ATen/DLConvertor.cpp

+3-3
Original file line numberDiff line numberDiff line change
@@ -56,10 +56,10 @@ static DLDataType getDLDataType(const Tensor& t) {
5656
return dtype;
5757
}
5858

59-
static DLContext getDLContext(const Type& type, const int64_t& device_id) {
59+
static DLContext getDLContext(const Tensor& tensor, const int64_t& device_id) {
6060
DLContext ctx;
6161
ctx.device_id = device_id;
62-
if (type.is_cuda()) {
62+
if (tensor.is_cuda()) {
6363
ctx.device_type = DLDeviceType::kDLGPU;
6464
} else {
6565
ctx.device_type = DLDeviceType::kDLCPU;
@@ -161,7 +161,7 @@ DLManagedTensor* toDLPack(const Tensor& src) {
161161
if (src.is_cuda()) {
162162
device_id = src.get_device();
163163
}
164-
atDLMTensor->tensor.dl_tensor.ctx = getDLContext(src.type(), device_id);
164+
atDLMTensor->tensor.dl_tensor.ctx = getDLContext(src, device_id);
165165
atDLMTensor->tensor.dl_tensor.ndim = src.dim();
166166
atDLMTensor->tensor.dl_tensor.dtype = getDLDataType(src);
167167
atDLMTensor->tensor.dl_tensor.shape =

aten/src/ATen/Dispatch.h

+2-2
Original file line numberDiff line numberDiff line change
@@ -41,9 +41,9 @@ inline at::ScalarType scalar_type(at::ScalarType s) {
4141
return s;
4242
}
4343

44-
C10_DEPRECATED_MESSAGE("passing at::Type to an AT_DISPATCH macro is deprecated, " \
44+
C10_DEPRECATED_MESSAGE("passing at::DeprecatedTypeProperties to an AT_DISPATCH macro is deprecated, " \
4545
"pass an at::ScalarType instead")
46-
inline at::ScalarType scalar_type(const at::Type &t) {
46+
inline at::ScalarType scalar_type(const at::DeprecatedTypeProperties &t) {
4747
return t.scalarType();
4848
}
4949

aten/src/ATen/SparseTensorImpl.cpp

+2-1
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,8 @@ void SparseTensorImpl::set_indices_and_values_unsafe(const Tensor& indices, cons
8888
AT_CHECK(!indices.is_sparse(), "expected indices to be a dense tensor, but got indices of layout ", indices.layout());
8989
AT_CHECK(!values.is_sparse(), "expected values to be a dense tensor, but got values of layout ", values.layout());
9090

91-
AT_CHECK(values.type().toSparse() == legacyTensorType(*this), "values type must match sparse tensor type");
91+
AT_CHECK(values.device().type() == device().type(), "device type of values (", values.device().type(), ") must match device type of device().type()", device().type(), ")");
92+
AT_CHECK(values.scalar_type() == typeMetaToScalarType(dtype()), "dtype of values (", values.scalar_type(), ") must match dtype of sparse tensor (", typeMetaToScalarType(dtype()), ")");
9293
AT_CHECK(indices.scalar_type() == kLong, "indices must be an int64 tensor");
9394
AT_CHECK(indices.type().backend() == values.type().backend(), "backend of indices (", indices.type().backend(), ") must match backend of values (", values.type().backend(), ")");
9495
AT_CHECK(!indices.is_cuda() || indices.get_device() == values.get_device(), "device of indices (", indices.get_device(), ") must match device of values (", values.get_device(), ")");

aten/src/ATen/SparseTensorUtils.h

+5-2
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,10 @@ inline void alias_into_sparse(const SparseTensor& self, const LongTensor& indice
3131
// Take indices and values and makes a (data) copy of them to put into the sparse
3232
// indices/values. This used to be called THSTensor_(_set)
3333
inline void copy_into_sparse(const SparseTensor& self, const LongTensor& indices, const Tensor& values, bool non_blocking) {
34-
alias_into_sparse(self, self._indices().type().copy(indices, non_blocking), self._values().type().copy(values, non_blocking));
34+
alias_into_sparse(
35+
self,
36+
self._indices().dispatch_type().copy(indices, non_blocking),
37+
self._values().dispatch_type().copy(values, non_blocking));
3538
}
3639

3740
// TODO: put this into the public API
@@ -82,7 +85,7 @@ inline LongTensor flatten_indices(const Tensor& indices, IntArrayRef full_size,
8285
indices_mult_cpu_vec[i] = mult;
8386
mult *= full_size[i];
8487
}
85-
auto indices_mult_cpu = indices.type().cpu()
88+
auto indices_mult_cpu = indices.dispatch_type().cpu()
8689
.tensorFromBlob(indices_mult_cpu_vec.data(), /*size=*/{sparse_dim, 1});
8790
// NB: must be blocking because this blob may be freed after this closure,
8891
// and non_blocking copy will see garbage.
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
#pragma once
2+
3+
#include <c10/core/Backend.h>
4+
#include <c10/core/ScalarType.h>
5+
#include <c10/core/Layout.h>
6+
7+
8+
9+
namespace at {
10+
11+
// This class specifies a Backend and a ScalarType. Currently, it primarily
12+
// serves as a replacement return value for Tensor::type(). Previously,
13+
// Tensor::type() returned Type&, but we are changing Type to not be
14+
// dtype-specific.
15+
class DeprecatedTypeProperties {
16+
public:
17+
DeprecatedTypeProperties(Backend backend, ScalarType scalar_type)
18+
: backend_(backend), scalar_type_(scalar_type) {}
19+
20+
Backend backend() const {
21+
return backend_;
22+
}
23+
24+
bool is_sparse() const {
25+
return layout_from_backend(backend()) == kSparse;
26+
}
27+
28+
DeviceType device_type() const {
29+
return backendToDeviceType(backend_);
30+
}
31+
32+
bool is_cuda() const {
33+
return backendToDeviceType(backend_) == kCUDA;
34+
}
35+
36+
ScalarType scalarType() const {
37+
return scalar_type_;
38+
}
39+
40+
caffe2::TypeMeta typeMeta() const {
41+
return scalarTypeToTypeMeta(scalar_type_);
42+
}
43+
44+
bool is_defined() const {
45+
return backend_ != Backend::Undefined && scalar_type_ != ScalarType::Undefined;
46+
}
47+
48+
bool operator==(const DeprecatedTypeProperties& other) const {
49+
return backend_ == other.backend() && scalar_type_ == other.scalarType();
50+
}
51+
52+
bool operator!=(const DeprecatedTypeProperties& other) const {
53+
return !(*this == other);
54+
}
55+
56+
std::string toString() const {
57+
std::stringstream ss;
58+
ss << at::toString(backend()) << at::toString(scalarType()) << "Type";
59+
return ss.str();
60+
}
61+
62+
private:
63+
Backend backend_;
64+
ScalarType scalar_type_;
65+
};
66+
67+
} // namespace at
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
#include <ATen/core/DeprecatedTypePropertiesRegistry.h>
2+
3+
namespace at {
4+
5+
// TODO: This could be bad juju if someone calls globalContext() in the
6+
// destructor of an object with static lifetime.
7+
DeprecatedTypePropertiesRegistry & globalDeprecatedTypePropertiesRegistry() {
8+
static DeprecatedTypePropertiesRegistry singleton;
9+
return singleton;
10+
}
11+
12+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
#pragma once
2+
3+
// In order to preserve bc, we make DeprecatedTypeProperties instances unique
4+
// just like they are for Type.
5+
6+
#include <c10/core/Backend.h>
7+
#include <c10/core/ScalarType.h>
8+
#include <ATen/core/DeprecatedTypeProperties.h>
9+
10+
namespace at {
11+
12+
struct CAFFE2_API DeprecatedTypePropertiesDeleter {
13+
void operator()(DeprecatedTypeProperties * ptr) {
14+
delete ptr;
15+
}
16+
};
17+
18+
class CAFFE2_API DeprecatedTypePropertiesRegistry {
19+
public:
20+
using DeprecatedTypePropertiesUniquePtr =
21+
std::unique_ptr<DeprecatedTypeProperties, DeprecatedTypePropertiesDeleter>;
22+
23+
DeprecatedTypePropertiesRegistry() {
24+
for (int b = 0; b < static_cast<int>(Backend::NumOptions); ++b) {
25+
for (int s = 0; s < static_cast<int>(ScalarType::NumOptions); ++s) {
26+
registry[b][s] = DeprecatedTypePropertiesUniquePtr{
27+
new DeprecatedTypeProperties(static_cast<Backend>(b), static_cast<ScalarType>(s)),
28+
DeprecatedTypePropertiesDeleter()
29+
};
30+
}
31+
}
32+
}
33+
34+
DeprecatedTypeProperties& getDeprecatedTypeProperties(Backend p, ScalarType s) {
35+
return *registry[static_cast<int>(p)][static_cast<int>(s)];
36+
}
37+
38+
private:
39+
DeprecatedTypePropertiesUniquePtr registry
40+
[static_cast<int>(Backend::NumOptions)]
41+
[static_cast<int>(ScalarType::NumOptions)];
42+
};
43+
44+
CAFFE2_API DeprecatedTypePropertiesRegistry& globalDeprecatedTypePropertiesRegistry();
45+
46+
} // namespace at

aten/src/ATen/core/Formatting.cpp

+5-2
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,10 @@ std::ostream& operator<<(std::ostream & out, const Type& t) {
3737
return out << t.toString();
3838
}
3939

40+
std::ostream& operator<<(std::ostream & out, const DeprecatedTypeProperties& t) {
41+
return out << t.toString();
42+
}
43+
4044
static std::tuple<double, int64_t> __printFormat(std::ostream& stream, const Tensor& self) {
4145
auto size = self.numel();
4246
if(size == 0) {
@@ -238,8 +242,7 @@ std::ostream& print(std::ostream& stream, const Tensor & tensor_, int64_t linesi
238242
stream << "size:\n" << tensor_.sizes() << "\n";
239243
stream << "]";
240244
} else {
241-
Type& cpudouble = tensor_.type().toBackend(Backend::CPU).toScalarType(kDouble);
242-
Tensor tensor = tensor_.toType(cpudouble).contiguous();
245+
Tensor tensor = tensor_.to(kCPU, kDouble).contiguous();
243246
if(tensor.ndimension() == 0) {
244247
stream << defaultfloat << tensor.data<double>()[0] << std::endl;
245248
stream << "[ " << tensor_.toString() << "{} ]";

aten/src/ATen/core/Formatting.h

+1
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ CAFFE2_API std::ostream& operator<<(std::ostream& out, Backend b);
1313
namespace at {
1414

1515
CAFFE2_API std::ostream& operator<<(std::ostream& out, const Type& t);
16+
CAFFE2_API std::ostream& operator<<(std::ostream& out, const DeprecatedTypeProperties& t);
1617
CAFFE2_API std::ostream& print(
1718
std::ostream& stream,
1819
const Tensor& tensor,

aten/src/ATen/core/LegacyTypeDispatch.cpp

+3-3
Original file line numberDiff line numberDiff line change
@@ -7,13 +7,13 @@ namespace at {
77
/// Previously, in VariableType_*.cpp (generated by gen_variable_type.py), when
88
/// a function is using the 'use_derived' strategy, we call its implementation
99
/// on the base non-Variable type (`baseType`), passing unwrapped tensors to the
10-
/// call so that any `.type()` calls in the implementation can treat the passed
10+
/// call so that any `.dispatch_type()` calls in the implementation can treat the passed
1111
/// tensors as non-Variables and won't dispatch back to functions in VariableType.
1212
///
1313
/// However, after the Variable/Tensor merge, there is no concept of unwrapping
1414
/// a tensor anymore, and directly passing variables to the base type calls will
15-
/// cause the `.type()` dispatch in the implementation to treat the tensor as a
16-
/// variable, and any function dispatch based on `.type()` will dispatch back to
15+
/// cause the `.dispatch_type()` dispatch in the implementation to treat the tensor as a
16+
/// variable, and any function dispatch based on `.dispatch_type()` will dispatch back to
1717
/// VariableType, which is not what we want.
1818
///
1919
/// The solution to the above problem is to add `at::NonVariableTypeMode`, which

aten/src/ATen/core/Tensor.cpp

+2-2
Original file line numberDiff line numberDiff line change
@@ -35,14 +35,14 @@ void Tensor::enforce_invariants() {
3535

3636
void Tensor::print() const {
3737
if (defined()) {
38-
std::cerr << "[" << type().toString() << " " << sizes() << "]" << std::endl;
38+
std::cerr << "[" << dispatch_type().toString() << " " << sizes() << "]" << std::endl;
3939
} else {
4040
std::cerr << "[UndefinedTensor]" << std::endl;
4141
}
4242
}
4343

4444
const char * Tensor::toString() const {
45-
return type().toString();
45+
return dispatch_type().toString();
4646
}
4747

4848
} // namespace at

aten/src/ATen/core/Tensor.h

+6-1
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
#include <c10/util/Optional.h>
1414
#include <c10/core/Tensor.h>
1515
#include <ATen/core/LegacyTypeDispatch.h>
16+
#include <ATen/core/DeprecatedTypePropertiesRegistry.h>
1617

1718
namespace c10{
1819
struct TensorOptions;
@@ -196,7 +197,11 @@ class CAFFE2_API Tensor {
196197
return impl_->itemsize();
197198
}
198199

199-
Type & type() const {
200+
DeprecatedTypeProperties & type() const {
201+
return globalDeprecatedTypePropertiesRegistry().getDeprecatedTypeProperties(
202+
tensorTypeIdToBackend(type_id()), scalar_type());
203+
}
204+
Type & dispatch_type() const {
200205
return legacyTensorType(*impl_);
201206
}
202207
TensorTypeId type_id() const {

0 commit comments

Comments
 (0)