Skip to content
This repository was archived by the owner on Nov 15, 2022. It is now read-only.

fbsync import 20220226 #470

Open
wants to merge 6 commits into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions .circleci/unittest/linux/scripts/install.sh
Original file line number Diff line number Diff line change
@@ -38,14 +38,14 @@ else
PYVSHORT=cp${PYVSHORT}-cp${PYVSHORT}m
fi

NIGHTLY_DATE=20220202
NIGHTLY_DATE=20220224

if [ "${CU_VERSION:-}" == cpu ] ; then
pip3 install -q --pre torch==1.11.0dev${NIGHTLY_DATE} torchvision==0.12.0dev${NIGHTLY_DATE}+cpu -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html
pip3 install -q --pre torch==1.12.0dev${NIGHTLY_DATE} torchvision==0.13.0dev${NIGHTLY_DATE}+cpu -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html
conda install -y ninja
PYTORCH_VERSION="$(python -c "import torch; print(torch.__version__)")" USE_NINJA=1 python setup.py develop bdist_wheel -d $WHEELS_FOLDER
else
pip3 install -q --pre torch==1.11.0dev${NIGHTLY_DATE}+cu111 torchvision==0.12.0dev${NIGHTLY_DATE} -f https://download.pytorch.org/whl/nightly/cu111/torch_nightly.html
pip3 install -q --pre torch==1.12.0dev${NIGHTLY_DATE}+cu111 torchvision==0.13.0dev${NIGHTLY_DATE} -f https://download.pytorch.org/whl/nightly/cu111/torch_nightly.html
conda install -y ninja
PYTORCH_VERSION="$(python -c "import torch; print(torch.__version__)")" FORCE_CUDA=1 USE_NINJA=1 python setup.py develop bdist_wheel -d $WHEELS_FOLDER
fi
8 changes: 4 additions & 4 deletions nestedtensor/csrc/BinaryOps.cpp
Original file line number Diff line number Diff line change
@@ -58,7 +58,7 @@ Tensor NestedTensor_add_Tensor(
self.dtype() == c10::ScalarType::Half &&
other.dtype() == c10::ScalarType::Half) {
other = other.contiguous();
at::Tensor self_buffer = get_buffer(self);
const at::Tensor& self_buffer = get_buffer(self);
Tensor nt_sizes_ =
get_efficient_nested_size(self).sizes().to(torch::kInt32);
Tensor nt_sizes_1 = at::native::narrow(nt_sizes_, 1, 1, 1);
@@ -96,7 +96,7 @@ Tensor NestedTensor_add_Tensor(
#endif
if (self_opt_sizes[self_dim - 1] && other.dim() == 1 &&
(*(self_opt_sizes[self_dim - 1])) == other.size(0)) {
Tensor self_buffer = get_buffer(self);
const Tensor& self_buffer = get_buffer(self);
Tensor result_buffer =
at::add(self_buffer.reshape({-1, other.size(0)}), other)
.reshape({-1});
@@ -256,7 +256,7 @@ Tensor NestedTensor_mul_Tensor(const Tensor& self_, const Tensor& other_) {
self.dtype() == c10::ScalarType::Half &&
other.dtype() == c10::ScalarType::Half) {
other = other.contiguous();
at::Tensor self_buffer = get_buffer(self);
const at::Tensor& self_buffer = get_buffer(self);
Tensor nt_sizes_ =
get_efficient_nested_size(self).sizes().to(torch::kInt32);
Tensor nt_sizes_1 = at::native::narrow(nt_sizes_, 1, 1, 1);
@@ -375,7 +375,7 @@ Tensor NestedTensor_sub_Tensor(
self.dtype() == c10::ScalarType::Half &&
other.dtype() == c10::ScalarType::Half) {
other = other.contiguous();
at::Tensor self_buffer = get_buffer(self);
const at::Tensor& self_buffer = get_buffer(self);
Tensor nt_sizes_ =
get_efficient_nested_size(self).sizes().to(torch::kInt32);
Tensor nt_sizes_1 = at::native::narrow(nt_sizes_, 1, 1, 1);
2 changes: 1 addition & 1 deletion nestedtensor/csrc/EmbeddingBag.cpp
Original file line number Diff line number Diff line change
@@ -17,7 +17,7 @@ std::tuple<Tensor, Tensor, Tensor, Tensor> NestedTensor_embedding_bag(
bool sparse,
const c10::optional<Tensor>& per_sample_weights,
bool include_last_offset) {
at::Tensor indices = get_buffer(indices_).contiguous();
const at::Tensor& indices = get_buffer(indices_).contiguous();
int64_t emb_dim = weight.size(1);
SizeNode output_size = map(
[&emb_dim](std::vector<int64_t> inp) {
7 changes: 3 additions & 4 deletions nestedtensor/csrc/activation.cpp
Original file line number Diff line number Diff line change
@@ -8,15 +8,15 @@ namespace F = torch::nn::functional;

namespace at {

// Applies at::gelu to `self`.
//  - If `self` is a nested-tensor impl and is contiguous, gelu is applied once
//    to the tensor returned by get_buffer(self) and the result is re-wrapped
//    with the efficient nested size/stride of `self`.
//  - Otherwise at::gelu is applied through map_nested_tensor.
//
// NOTE(review): this span is a diff hunk rendered without +/- markers. The two
// signature lines and the two lambda lines look like before/after pairs (the
// second of each pair carries `approximate`) -- confirm against the real file.
Tensor NestedTensor_gelu(const Tensor& self) {
Tensor NestedTensor_gelu(const Tensor& self, const c10::string_view approximate) {
if (is_nested_tensor_impl(self) && get_is_contiguous(self)) {
return wrap_buffer(
// NOTE(review): this call does not pass `approximate`, whereas the lambda
// that takes `approximate` below forwards it to at::gelu. If `approximate`
// is meant to be honored on this branch too, it likely needs forwarding
// here -- TODO confirm.
at::gelu(get_buffer(self)),
get_efficient_nested_size(self),
get_efficient_nested_stride(self));
}
// Fallback: apply gelu via map_nested_tensor.
return map_nested_tensor(
[](at::Tensor tensor) { return at::gelu(tensor); }, self);
[&approximate](at::Tensor tensor) { return at::gelu(tensor, approximate); }, self);
}

Tensor NestedTensor_elu(const Tensor& self, const Scalar& alpha, const Scalar& scale, const Scalar& input_scale) {
@@ -33,7 +33,6 @@ Tensor NestedTensor_elu(const Tensor& self, const Scalar& alpha, const Scalar& s
// Registered below autograd
Tensor NestedTensor_relu(const Tensor& self) {
auto impl = get_nested_tensor_impl(self);
auto structure = get_nested_tensor_structure(self);
if (get_is_contiguous(self)) {
#ifdef TRACEPACKED
std::cout << "calling packed relu" << std::endl;
@@ -52,7 +51,7 @@ Tensor& NestedTensor_relu_(Tensor& self) {
#ifdef TRACEPACKED
std::cout << "calling packed relu_" << std::endl;
#endif
Tensor buffer = get_buffer(self);
Tensor& buffer = get_buffer(self);
at::relu_(buffer);
return self;
}
17 changes: 8 additions & 9 deletions nestedtensor/csrc/autograd_functions.cpp
Original file line number Diff line number Diff line change
@@ -135,7 +135,7 @@ Tensor NestedTensor_batch_norm(
c10::Half* running_var_ptr = running_var_cont.data_ptr<c10::Half>();

if (get_is_contiguous(input, c10::MemoryFormat::ChannelsLast)) {
Tensor input_buffer = get_buffer(input);
const Tensor& input_buffer = get_buffer(input);
int64_t num_channel = weight_cont.size(0);
at::cuda::CUDAStream defaultStream = at::cuda::getDefaultCUDAStream();
nested_tensor::cuda::batchnorm_inference_channels_last_kernelLauncher(
@@ -149,17 +149,16 @@ Tensor NestedTensor_batch_norm(
num_channel,
input_buffer.numel(),
defaultStream);
input_buffer = input_buffer.view(-1);
return wrap_buffer(std::move(input_buffer), get_efficient_nested_size(input), get_efficient_nested_stride(input));
return wrap_buffer(input_buffer.view(-1), get_efficient_nested_size(input), get_efficient_nested_stride(input));
}

Tensor output = input;
output = NestedTensor_contiguous(output);
Tensor input_buffer = get_buffer(output);
const Tensor& input_buffer = get_buffer(output);
// Tensor output_buffer = input_buffer.clone();

auto self_opt_sizes = get_opt_sizes(input);

Tensor nt_sizes_ =
get_efficient_nested_size(input).sizes(); // .to(torch::kInt32);
Tensor nt_sizes_1 = at::native::narrow(nt_sizes_, 1, 1, 1);
@@ -177,7 +176,7 @@ Tensor NestedTensor_batch_norm(
}
}
Tensor nt_sizes = numbers_t.to(at::Device(kCUDA), torch::kInt32, true, true);

at::cuda::CUDAStream defaultStream = at::cuda::getDefaultCUDAStream();
nested_tensor::cuda::batchnorm_inference_kernelLauncher(
input_buffer.data_ptr<c10::Half>(),
@@ -197,7 +196,7 @@ Tensor NestedTensor_batch_norm(
nt_sizes.data_ptr<int>(),
defaultStream
);
return wrap_buffer(std::move(input_buffer), get_efficient_nested_size(output), get_efficient_nested_stride(output));
return wrap_buffer(input_buffer, get_efficient_nested_size(output), get_efficient_nested_stride(output));
}
#endif
auto scalar_shape = make_scalar_shape(get_dim(input), n_input);
20 changes: 8 additions & 12 deletions nestedtensor/csrc/conv2d.cpp
Original file line number Diff line number Diff line change
@@ -41,9 +41,8 @@ Tensor NestedTensor_conv2d(
get_is_cuda(input)
) {
if (get_is_contiguous(input, c10::MemoryFormat::ChannelsLast)) {
Tensor input_buffer = get_buffer(input);
input_buffer = input_buffer.view({-1, weight.size(1)});
at::Tensor result_buffer = at::matmul(input_buffer,
Tensor input_buffer = get_buffer(input).view({-1, weight.size(1)});
at::Tensor result_buffer = at::matmul(input_buffer,
weight.reshape({weight.size(0), weight.size(1)}).transpose(0, 1));
int64_t weight_size_0 = weight.size(0);
auto new_sizes = map_efficient_size([&weight_size_0](int64_t* size_ptr, int64_t size) {
@@ -60,9 +59,8 @@ Tensor NestedTensor_conv2d(
}
if (get_is_contiguous(input)) {
input = transpose_nchw_nhwc(input);
Tensor input_buffer = get_buffer(input);
input_buffer = input_buffer.reshape({-1, weight.size(1)});
at::Tensor result_buffer = at::matmul(input_buffer,
Tensor input_buffer = get_buffer(input).reshape({-1, weight.size(1)});
at::Tensor result_buffer = at::matmul(input_buffer,
weight.reshape({weight.size(0), weight.size(1)}).transpose(0, 1));
int64_t weight_size_0 = weight.size(0);
auto new_sizes = map_efficient_size([&weight_size_0](int64_t* size_ptr, int64_t size) {
@@ -130,9 +128,8 @@ Tensor NestedTensor_cudnn_convolution_relu(
get_is_cuda(input)
) {
if (get_is_contiguous(input, c10::MemoryFormat::ChannelsLast)) {
Tensor input_buffer = get_buffer(input);
input_buffer = input_buffer.view({-1, weight.size(1)});
at::Tensor result_buffer = at::matmul(input_buffer,
Tensor input_buffer = get_buffer(input).view({-1, weight.size(1)});
at::Tensor result_buffer = at::matmul(input_buffer,
weight.reshape({weight.size(0), weight.size(1)}).transpose(0, 1));
int64_t weight_size_0 = weight.size(0);
auto new_sizes = map_efficient_size([&weight_size_0](int64_t* size_ptr, int64_t size) {
@@ -149,9 +146,8 @@ Tensor NestedTensor_cudnn_convolution_relu(
}
if (get_is_contiguous(input)) {
input = transpose_nchw_nhwc(input);
Tensor input_buffer = get_buffer(input);
input_buffer = input_buffer.reshape({-1, weight.size(1)});
at::Tensor result_buffer = at::matmul(input_buffer,
Tensor input_buffer = get_buffer(input).reshape({-1, weight.size(1)});
at::Tensor result_buffer = at::matmul(input_buffer,
weight.reshape({weight.size(0), weight.size(1)}).transpose(0, 1));
int64_t weight_size_0 = weight.size(0);
auto new_sizes = map_efficient_size([&weight_size_0](int64_t* size_ptr, int64_t size) {
3 changes: 1 addition & 2 deletions nestedtensor/csrc/creation.cpp
Original file line number Diff line number Diff line change
@@ -208,8 +208,7 @@ at::Tensor nested_tensor_impl(
}
}
Tensor result = wrap_tensor_node(std::move(ivalue_structure));
Tensor buffer = get_buffer(result);
buffer = buffer.to(device, dtype);
Tensor buffer = get_buffer(result).to(device, dtype);
if (pin_memory) {
buffer = buffer.pin_memory();
}
3 changes: 1 addition & 2 deletions nestedtensor/csrc/cuda/layernorm.cpp
Original file line number Diff line number Diff line change
@@ -24,10 +24,9 @@ Tensor NestedTensor_layer_norm(
auto input_opt_sizes = get_opt_sizes(input);
if (get_dim(input) == 3 && get_is_contiguous(input) &&
(*input_opt_sizes[2]) % 32 == 0) {
at::Tensor input_buffer = get_buffer(input);
const at::Tensor& input_buffer = get_buffer(input);
int size2 = (int)(*input_opt_sizes[2]);
int valid_word_num = (int)(input_buffer.numel() / size2);
at::Tensor zero_bias = torch::zeros({valid_word_num}, input.options());
at::Tensor output_buffer = torch::zeros_like(input_buffer);
at::cuda::CUDAStream defaultStream = at::cuda::getDefaultCUDAStream();
if (input_buffer.dtype() == torch::kFloat16) {
19 changes: 9 additions & 10 deletions nestedtensor/csrc/cuda/mha.cpp
Original file line number Diff line number Diff line change
@@ -58,24 +58,23 @@ at::Tensor bt_min_mha(

at::Tensor packed_padded = to_padded_tensor(packed, 0).contiguous();
std::vector<at::Tensor> packed_padded_chunks = packed_padded.chunk(3, -1);
at::Tensor query_buf = packed_padded_chunks[0];
at::Tensor key_buf = packed_padded_chunks[1];
at::Tensor val_buf = packed_padded_chunks[2];
at::Tensor query_buf = std::move(packed_padded_chunks[0]);
at::Tensor key_buf = std::move(packed_padded_chunks[1]);
at::Tensor val_buf = std::move(packed_padded_chunks[2]);
int64_t batch_size = query_buf.size(0);
int64_t seq_len = query_buf.size(1);

query_buf = query_buf.reshape({batch_size, seq_len, head_num, size_per_head}).transpose(1, 2);
key_buf = key_buf.reshape({batch_size, seq_len, head_num, size_per_head}).transpose(1, 2);
val_buf = val_buf.reshape({batch_size, seq_len, head_num, size_per_head}).transpose(1, 2);
std::array<int64_t, 4> qkv_dims = {batch_size, seq_len, head_num, size_per_head};
query_buf = query_buf.reshape(qkv_dims).transpose(1, 2);
key_buf = key_buf.reshape(qkv_dims).transpose(1, 2);
val_buf = val_buf.reshape(qkv_dims).transpose(1, 2);

key_buf = key_buf.transpose(2, 3);
at::Tensor attn_output_weights = at::matmul(query_buf, key_buf).contiguous();

auto mask_options =
torch::TensorOptions().dtype(query.dtype()).device(torch::kCUDA);
at::Tensor input_mask = to_mask(query, 2);
input_mask = input_mask.to(options);
at::Tensor attr_mask = input_mask.view({-1, 1, 1, seq_len}).to(mask_options);
at::Tensor attr_mask = to_mask(query, 2).view({-1, 1, 1, seq_len}).to(mask_options, /* non-blocking = */ true);
attr_mask = attr_mask * attr_mask.transpose(2, 3);

if (query.dtype() == torch::kFloat16) {
@@ -103,7 +102,7 @@ at::Tensor bt_min_mha(
auto attn_output = at::matmul(attn_output_weights, val_buf);
attn_output = attn_output.transpose(1, 2).reshape({batch_size, seq_len, embedding_dim}).contiguous();
at::Tensor attr_out = from_padded_tensor(attn_output, get_efficient_nested_size(query));
return at::matmul(attr_out, out_proj_weight.t());
return at::matmul(attr_out, out_proj_weight.t()) + out_proj_bias;
}

TORCH_LIBRARY_FRAGMENT(nestedtensor, m) {
14 changes: 4 additions & 10 deletions nestedtensor/csrc/functions.cpp
Original file line number Diff line number Diff line change
@@ -28,31 +28,25 @@ Tensor NestedTensor_embedding(
}
if (is_nested_tensor_impl(indices) &&
!is_nested_tensor_impl(weight) &&
get_dim(indices) == 1 &&
get_dim(indices) < 3 &&
get_dim(weight) == 2 &&
get_is_contiguous(indices) &&
get_is_contiguous(weight)) {
Tensor indices_buffer = get_buffer(indices);
const Tensor& indices_buffer = get_buffer(indices);
Tensor result_buffer = at::embedding(
weight, indices_buffer, padding_idx, scale_grad_by_freq, sparse);
EfficientSizeNode new_nested_size = get_efficient_nested_size(indices);
EfficientSizeNode new_nested_stride = get_efficient_nested_stride(indices);
const EfficientSizeNode& new_nested_size = get_efficient_nested_size(indices);
auto new_nested_size_sizes = new_nested_size.sizes();
auto new_nested_stride_sizes = new_nested_stride.sizes();
auto tmp = torch::empty(
{new_nested_size_sizes.size(0)}, new_nested_size_sizes.options());
tmp.fill_(weight.size(1));
tmp = tmp.reshape({new_nested_size_sizes.size(0), 1});
new_nested_size_sizes = at::cat({new_nested_size_sizes, tmp}, 1);
new_nested_stride_sizes = at::cat({tmp, new_nested_stride_sizes}, 1);
return wrap_buffer(
std::move(result_buffer),
EfficientSizeNode(
new_nested_size.structure(),
new_nested_size_sizes),
EfficientSizeNode(
new_nested_stride.structure(),
new_nested_stride_sizes));
new_nested_size_sizes));
}
return map_nested_tensor(
[&](at::Tensor i) {
Loading