diff --git a/.circleci/unittest/linux/scripts/install.sh b/.circleci/unittest/linux/scripts/install.sh
index 5f40403c..78f59f32 100755
--- a/.circleci/unittest/linux/scripts/install.sh
+++ b/.circleci/unittest/linux/scripts/install.sh
@@ -38,14 +38,14 @@ else
     PYVSHORT=cp${PYVSHORT}-cp${PYVSHORT}m
 fi
 
-NIGHTLY_DATE=20220202
+NIGHTLY_DATE=20220224
 
 if [ "${CU_VERSION:-}" == cpu ] ; then
-    pip3 install -q --pre torch==1.11.0dev${NIGHTLY_DATE} torchvision==0.12.0dev${NIGHTLY_DATE}+cpu -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html
+    pip3 install -q --pre torch==1.12.0dev${NIGHTLY_DATE} torchvision==0.13.0dev${NIGHTLY_DATE}+cpu -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html
     conda install -y ninja
     PYTORCH_VERSION="$(python -c "import torch; print(torch.__version__)")" USE_NINJA=1 python setup.py develop bdist_wheel -d $WHEELS_FOLDER
 else
-    pip3 install -q --pre torch==1.11.0dev${NIGHTLY_DATE}+cu111 torchvision==0.12.0dev${NIGHTLY_DATE} -f https://download.pytorch.org/whl/nightly/cu111/torch_nightly.html
+    pip3 install -q --pre torch==1.12.0dev${NIGHTLY_DATE}+cu111 torchvision==0.13.0dev${NIGHTLY_DATE} -f https://download.pytorch.org/whl/nightly/cu111/torch_nightly.html
     conda install -y ninja
     PYTORCH_VERSION="$(python -c "import torch; print(torch.__version__)")" FORCE_CUDA=1 USE_NINJA=1 python setup.py develop bdist_wheel -d $WHEELS_FOLDER
 fi
diff --git a/nestedtensor/csrc/BinaryOps.cpp b/nestedtensor/csrc/BinaryOps.cpp
index 696c1030..fe8d7d8a 100644
--- a/nestedtensor/csrc/BinaryOps.cpp
+++ b/nestedtensor/csrc/BinaryOps.cpp
@@ -58,7 +58,7 @@ Tensor NestedTensor_add_Tensor(
       self.dtype() == c10::ScalarType::Half &&
       other.dtype() == c10::ScalarType::Half) {
     other = other.contiguous();
-    at::Tensor self_buffer = get_buffer(self);
+    const at::Tensor& self_buffer = get_buffer(self);
     Tensor nt_sizes_ =
         get_efficient_nested_size(self).sizes().to(torch::kInt32);
     Tensor nt_sizes_1 = at::native::narrow(nt_sizes_, 1, 1, 1);
@@ -96,7 +96,7 @@ Tensor NestedTensor_add_Tensor(
 #endif
   if (self_opt_sizes[self_dim - 1] && other.dim() == 1 &&
       (*(self_opt_sizes[self_dim - 1])) == other.size(0)) {
-    Tensor self_buffer = get_buffer(self);
+    const Tensor& self_buffer = get_buffer(self);
     Tensor result_buffer =
         at::add(self_buffer.reshape({-1, other.size(0)}), other)
             .reshape({-1});
@@ -256,7 +256,7 @@ Tensor NestedTensor_mul_Tensor(const Tensor& self_, const Tensor& other_) {
       self.dtype() == c10::ScalarType::Half &&
       other.dtype() == c10::ScalarType::Half) {
     other = other.contiguous();
-    at::Tensor self_buffer = get_buffer(self);
+    const at::Tensor& self_buffer = get_buffer(self);
     Tensor nt_sizes_ =
         get_efficient_nested_size(self).sizes().to(torch::kInt32);
     Tensor nt_sizes_1 = at::native::narrow(nt_sizes_, 1, 1, 1);
@@ -375,7 +375,7 @@ Tensor NestedTensor_sub_Tensor(
       self.dtype() == c10::ScalarType::Half &&
       other.dtype() == c10::ScalarType::Half) {
     other = other.contiguous();
-    at::Tensor self_buffer = get_buffer(self);
+    const at::Tensor& self_buffer = get_buffer(self);
     Tensor nt_sizes_ =
         get_efficient_nested_size(self).sizes().to(torch::kInt32);
     Tensor nt_sizes_1 = at::native::narrow(nt_sizes_, 1, 1, 1);
diff --git a/nestedtensor/csrc/EmbeddingBag.cpp b/nestedtensor/csrc/EmbeddingBag.cpp
index 91be7070..2c6c7269 100644
--- a/nestedtensor/csrc/EmbeddingBag.cpp
+++ b/nestedtensor/csrc/EmbeddingBag.cpp
@@ -17,7 +17,7 @@ std::tuple<Tensor, Tensor, Tensor, Tensor> NestedTensor_embedding_bag(
     bool sparse,
     const c10::optional<Tensor>& per_sample_weights,
     bool include_last_offset) {
-  at::Tensor indices = get_buffer(indices_).contiguous();
+  const at::Tensor& indices = get_buffer(indices_).contiguous();
   int64_t emb_dim = weight.size(1);
   SizeNode output_size = map(
       [&emb_dim](std::vector<int64_t> inp) {
diff --git a/nestedtensor/csrc/activation.cpp b/nestedtensor/csrc/activation.cpp
index cb7f5688..f810cfe9 100644
--- a/nestedtensor/csrc/activation.cpp
+++ b/nestedtensor/csrc/activation.cpp
@@ -8,7 +8,7 @@ namespace F = torch::nn::functional;
 
 namespace at {
 
-Tensor NestedTensor_gelu(const Tensor& self) {
+Tensor NestedTensor_gelu(const Tensor& self, const c10::string_view approximate) {
   if (is_nested_tensor_impl(self) && get_is_contiguous(self)) {
     return wrap_buffer(
         at::gelu(get_buffer(self)),
@@ -16,7 +16,7 @@ Tensor NestedTensor_gelu(const Tensor& self) {
         get_efficient_nested_stride(self));
   }
   return map_nested_tensor(
-      [](at::Tensor tensor) { return at::gelu(tensor); }, self);
+      [&approximate](at::Tensor tensor) { return at::gelu(tensor, approximate); }, self);
 }
 
 Tensor NestedTensor_elu(const Tensor& self, const Scalar& alpha, const Scalar& scale, const Scalar& input_scale) {
@@ -33,7 +33,6 @@ Tensor NestedTensor_elu(const Tensor& self, const Scalar& alpha, const Scalar& s
 // Registered below autograd
 Tensor NestedTensor_relu(const Tensor& self) {
   auto impl = get_nested_tensor_impl(self);
-  auto structure = get_nested_tensor_structure(self);
   if (get_is_contiguous(self)) {
 #ifdef TRACEPACKED
     std::cout << "calling packed relu" << std::endl;
@@ -52,7 +51,7 @@ Tensor& NestedTensor_relu_(Tensor& self) {
 #ifdef TRACEPACKED
     std::cout << "calling packed relu_" << std::endl;
 #endif
-    Tensor buffer = get_buffer(self);
+    Tensor& buffer = get_buffer(self);
     at::relu_(buffer);
     return self;
   }
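
Note (editorial, not part of the patch): NestedTensor_gelu above picks up the approximate argument that at::gelu gained in the PyTorch nightlies this PR moves to, and the map-based fallback now forwards it per constituent tensor. A minimal plain-tensor sketch of that argument, assuming one of those nightly builds:

    #include <torch/torch.h>

    // "none" is the exact erf-based GELU; "tanh" selects the tanh approximation.
    at::Tensor gelu_both_ways(const at::Tensor& t) {
      at::Tensor exact = at::gelu(t);           // approximate defaults to "none"
      at::Tensor approx = at::gelu(t, "tanh");  // tanh approximation
      return exact - approx;                    // small but nonzero difference
    }
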
diff --git a/nestedtensor/csrc/autograd_functions.cpp b/nestedtensor/csrc/autograd_functions.cpp
index 6fa0cd65..cdcb1fd5 100644
--- a/nestedtensor/csrc/autograd_functions.cpp
+++ b/nestedtensor/csrc/autograd_functions.cpp
@@ -135,7 +135,7 @@ Tensor NestedTensor_batch_norm(
   c10::Half* running_var_ptr = running_var_cont.data_ptr<c10::Half>();
 
   if (get_is_contiguous(input, c10::MemoryFormat::ChannelsLast)) {
-    Tensor input_buffer = get_buffer(input);
+    const Tensor& input_buffer = get_buffer(input);
     int64_t num_channel = weight_cont.size(0);
     at::cuda::CUDAStream defaultStream = at::cuda::getDefaultCUDAStream();
     nested_tensor::cuda::batchnorm_inference_channels_last_kernelLauncher(
@@ -149,17 +149,16 @@ Tensor NestedTensor_batch_norm(
         num_channel,
         input_buffer.numel(),
         defaultStream);
-    input_buffer = input_buffer.view(-1);
-    return wrap_buffer(std::move(input_buffer), get_efficient_nested_size(input), get_efficient_nested_stride(input));
+    return wrap_buffer(input_buffer.view(-1), get_efficient_nested_size(input), get_efficient_nested_stride(input));
   }
-  
+
   Tensor output = input;
   output = NestedTensor_contiguous(output);
-  Tensor input_buffer = get_buffer(output);
+  const Tensor& input_buffer = get_buffer(output);
   // Tensor output_buffer = input_buffer.clone();
-  
+
   auto self_opt_sizes = get_opt_sizes(input);
-  
+
   Tensor nt_sizes_ =
       get_efficient_nested_size(input).sizes(); // .to(torch::kInt32);
   Tensor nt_sizes_1 = at::native::narrow(nt_sizes_, 1, 1, 1);
@@ -177,7 +176,7 @@ Tensor NestedTensor_batch_norm(
     }
   }
   Tensor nt_sizes = numbers_t.to(at::Device(kCUDA), torch::kInt32, true, true);
-  
+
   at::cuda::CUDAStream defaultStream = at::cuda::getDefaultCUDAStream();
   nested_tensor::cuda::batchnorm_inference_kernelLauncher(
       input_buffer.data_ptr<c10::Half>(),
@@ -197,7 +196,7 @@ Tensor NestedTensor_batch_norm(
       nt_sizes.data_ptr<int>(),
       defaultStream
       );
-  return wrap_buffer(std::move(input_buffer), get_efficient_nested_size(output), get_efficient_nested_stride(output));
+  return wrap_buffer(input_buffer, get_efficient_nested_size(output), get_efficient_nested_stride(output));
 }
 #endif
 auto scalar_shape = make_scalar_shape(get_dim(input), n_input);
diff --git a/nestedtensor/csrc/conv2d.cpp b/nestedtensor/csrc/conv2d.cpp
index 218070b1..4da89a5f 100644
--- a/nestedtensor/csrc/conv2d.cpp
+++ b/nestedtensor/csrc/conv2d.cpp
@@ -41,9 +41,8 @@ Tensor NestedTensor_conv2d(
       get_is_cuda(input)
       ) {
     if (get_is_contiguous(input, c10::MemoryFormat::ChannelsLast)) {
-      Tensor input_buffer = get_buffer(input);
-      input_buffer = input_buffer.view({-1, weight.size(1)});
-      at::Tensor result_buffer = at::matmul(input_buffer, 
+      Tensor input_buffer = get_buffer(input).view({-1, weight.size(1)});
+      at::Tensor result_buffer = at::matmul(input_buffer,
          weight.reshape({weight.size(0), weight.size(1)}).transpose(0, 1));
       int64_t weight_size_0 = weight.size(0);
       auto new_sizes = map_efficient_size([&weight_size_0](int64_t* size_ptr, int64_t size) {
@@ -60,9 +59,8 @@ Tensor NestedTensor_conv2d(
     }
     if (get_is_contiguous(input)) {
       input = transpose_nchw_nhwc(input);
-      Tensor input_buffer = get_buffer(input);
-      input_buffer = input_buffer.reshape({-1, weight.size(1)});
-      at::Tensor result_buffer = at::matmul(input_buffer, 
+      Tensor input_buffer = get_buffer(input).reshape({-1, weight.size(1)});
+      at::Tensor result_buffer = at::matmul(input_buffer,
          weight.reshape({weight.size(0), weight.size(1)}).transpose(0, 1));
       int64_t weight_size_0 = weight.size(0);
       auto new_sizes = map_efficient_size([&weight_size_0](int64_t* size_ptr, int64_t size) {
@@ -130,9 +128,8 @@ Tensor NestedTensor_cudnn_convolution_relu(
       get_is_cuda(input)
       ) {
     if (get_is_contiguous(input, c10::MemoryFormat::ChannelsLast)) {
-      Tensor input_buffer = get_buffer(input);
-      input_buffer = input_buffer.view({-1, weight.size(1)});
-      at::Tensor result_buffer = at::matmul(input_buffer, 
+      Tensor input_buffer = get_buffer(input).view({-1, weight.size(1)});
+      at::Tensor result_buffer = at::matmul(input_buffer,
          weight.reshape({weight.size(0), weight.size(1)}).transpose(0, 1));
       int64_t weight_size_0 = weight.size(0);
       auto new_sizes = map_efficient_size([&weight_size_0](int64_t* size_ptr, int64_t size) {
@@ -149,9 +146,8 @@ Tensor NestedTensor_cudnn_convolution_relu(
     }
     if (get_is_contiguous(input)) {
       input = transpose_nchw_nhwc(input);
-      Tensor input_buffer = get_buffer(input);
-      input_buffer = input_buffer.reshape({-1, weight.size(1)});
-      at::Tensor result_buffer = at::matmul(input_buffer, 
+      Tensor input_buffer = get_buffer(input).reshape({-1, weight.size(1)});
+      at::Tensor result_buffer = at::matmul(input_buffer,
          weight.reshape({weight.size(0), weight.size(1)}).transpose(0, 1));
       int64_t weight_size_0 = weight.size(0);
       auto new_sizes = map_efficient_size([&weight_size_0](int64_t* size_ptr, int64_t size) {
diff --git a/nestedtensor/csrc/creation.cpp b/nestedtensor/csrc/creation.cpp
index 39ee2068..681ca5e0 100644
--- a/nestedtensor/csrc/creation.cpp
+++ b/nestedtensor/csrc/creation.cpp
@@ -208,8 +208,7 @@ at::Tensor nested_tensor_impl(
     }
   }
   Tensor result = wrap_tensor_node(std::move(ivalue_structure));
-  Tensor buffer = get_buffer(result);
-  buffer = buffer.to(device, dtype);
+  Tensor buffer = get_buffer(result).to(device, dtype);
   if (pin_memory) {
     buffer = buffer.pin_memory();
   }
diff --git a/nestedtensor/csrc/cuda/layernorm.cpp b/nestedtensor/csrc/cuda/layernorm.cpp
index fd7c68fc..cc7866a2 100644
--- a/nestedtensor/csrc/cuda/layernorm.cpp
+++ b/nestedtensor/csrc/cuda/layernorm.cpp
@@ -24,10 +24,9 @@ Tensor NestedTensor_layer_norm(
   auto input_opt_sizes = get_opt_sizes(input);
   if (get_dim(input) == 3 && get_is_contiguous(input) &&
       (*input_opt_sizes[2]) % 32 == 0) {
-    at::Tensor input_buffer = get_buffer(input);
+    const at::Tensor& input_buffer = get_buffer(input);
     int size2 = (int)(*input_opt_sizes[2]);
     int valid_word_num = (int)(input_buffer.numel() / size2);
-    at::Tensor zero_bias = torch::zeros({valid_word_num}, input.options());
     at::Tensor output_buffer = torch::zeros_like(input_buffer);
     at::cuda::CUDAStream defaultStream = at::cuda::getDefaultCUDAStream();
     if (input_buffer.dtype() == torch::kFloat16) {
diff --git a/nestedtensor/csrc/cuda/mha.cpp b/nestedtensor/csrc/cuda/mha.cpp
index e9bc933b..14eed479 100644
--- a/nestedtensor/csrc/cuda/mha.cpp
+++ b/nestedtensor/csrc/cuda/mha.cpp
@@ -58,24 +58,23 @@ at::Tensor bt_min_mha(
   at::Tensor packed_padded = to_padded_tensor(packed, 0).contiguous();
   std::vector<at::Tensor> packed_padded_chunks = packed_padded.chunk(3, -1);
-  at::Tensor query_buf = packed_padded_chunks[0];
-  at::Tensor key_buf = packed_padded_chunks[1];
-  at::Tensor val_buf = packed_padded_chunks[2];
+  at::Tensor query_buf = std::move(packed_padded_chunks[0]);
+  at::Tensor key_buf = std::move(packed_padded_chunks[1]);
+  at::Tensor val_buf = std::move(packed_padded_chunks[2]);
   int64_t batch_size = query_buf.size(0);
   int64_t seq_len = query_buf.size(1);
-  query_buf = query_buf.reshape({batch_size, seq_len, head_num, size_per_head}).transpose(1, 2);
-  key_buf = key_buf.reshape({batch_size, seq_len, head_num, size_per_head}).transpose(1, 2);
-  val_buf = val_buf.reshape({batch_size, seq_len, head_num, size_per_head}).transpose(1, 2);
+  std::array<int64_t, 4> qkv_dims = {batch_size, seq_len, head_num, size_per_head};
+  query_buf = query_buf.reshape(qkv_dims).transpose(1, 2);
+  key_buf = key_buf.reshape(qkv_dims).transpose(1, 2);
+  val_buf = val_buf.reshape(qkv_dims).transpose(1, 2);
   key_buf = key_buf.transpose(2, 3);
   at::Tensor attn_output_weights = at::matmul(query_buf, key_buf).contiguous();
   auto mask_options =
       torch::TensorOptions().dtype(query.dtype()).device(torch::kCUDA);
-  at::Tensor input_mask = to_mask(query, 2);
-  input_mask = input_mask.to(options);
-  at::Tensor attr_mask = input_mask.view({-1, 1, 1, seq_len}).to(mask_options);
+  at::Tensor attr_mask = to_mask(query, 2).view({-1, 1, 1, seq_len}).to(mask_options, /* non-blocking = */ true);
   attr_mask = attr_mask * attr_mask.transpose(2, 3);
 
   if (query.dtype() == torch::kFloat16) {
@@ -103,7 +102,7 @@ at::Tensor bt_min_mha(
   auto attn_output = at::matmul(attn_output_weights, val_buf);
   attn_output = attn_output.transpose(1, 2).reshape({batch_size, seq_len, embedding_dim}).contiguous();
   at::Tensor attr_out = from_padded_tensor(attn_output, get_efficient_nested_size(query));
-  return at::matmul(attr_out, out_proj_weight.t());
+  return at::matmul(attr_out, out_proj_weight.t()) + out_proj_bias;
 }
 
 TORCH_LIBRARY_FRAGMENT(nestedtensor, m) {
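
Note (editorial, not part of the patch): besides moving the chunked buffers and sharing one shape array, bt_min_mha now adds out_proj_bias, so the fused attention path matches a standard output projection (y = x * W^T + b, i.e. torch::nn::functional::linear). A sketch of that final step on plain tensors:

    // attr_out: (tokens, embedding_dim); out_proj_weight: (embedding_dim, embedding_dim);
    // out_proj_bias: (embedding_dim). Broadcasting adds the bias to every row.
    at::Tensor output_projection(const at::Tensor& attr_out,
                                 const at::Tensor& out_proj_weight,
                                 const at::Tensor& out_proj_bias) {
      return at::matmul(attr_out, out_proj_weight.t()) + out_proj_bias;
    }
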
diff --git a/nestedtensor/csrc/functions.cpp b/nestedtensor/csrc/functions.cpp
index 5dd2d8c6..401f10ee 100644
--- a/nestedtensor/csrc/functions.cpp
+++ b/nestedtensor/csrc/functions.cpp
@@ -28,31 +28,25 @@ Tensor NestedTensor_embedding(
   }
   if (is_nested_tensor_impl(indices) && !is_nested_tensor_impl(weight) &&
-      get_dim(indices) == 1 &&
+      get_dim(indices) < 3 &&
       get_dim(weight) == 2 &&
       get_is_contiguous(indices) &&
       get_is_contiguous(weight)) {
-    Tensor indices_buffer = get_buffer(indices);
+    const Tensor& indices_buffer = get_buffer(indices);
     Tensor result_buffer = at::embedding(
         weight, indices_buffer, padding_idx, scale_grad_by_freq, sparse);
-    EfficientSizeNode new_nested_size = get_efficient_nested_size(indices);
-    EfficientSizeNode new_nested_stride = get_efficient_nested_stride(indices);
+    const EfficientSizeNode& new_nested_size = get_efficient_nested_size(indices);
     auto new_nested_size_sizes = new_nested_size.sizes();
-    auto new_nested_stride_sizes = new_nested_stride.sizes();
     auto tmp = torch::empty(
         {new_nested_size_sizes.size(0)}, new_nested_size_sizes.options());
     tmp.fill_(weight.size(1));
     tmp = tmp.reshape({new_nested_size_sizes.size(0), 1});
     new_nested_size_sizes = at::cat({new_nested_size_sizes, tmp}, 1);
-    new_nested_stride_sizes = at::cat({tmp, new_nested_stride_sizes}, 1);
     return wrap_buffer(
         std::move(result_buffer),
         EfficientSizeNode(
             new_nested_size.structure(),
-            new_nested_size_sizes),
-        EfficientSizeNode(
-            new_nested_stride.structure(),
-            new_nested_stride_sizes));
+            new_nested_size_sizes));
   }
   return map_nested_tensor(
       [&](at::Tensor i) {
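
Note (editorial, not part of the patch): the embedding fast path now also accepts nested indices of dimension 2 (get_dim(indices) < 3) and only rebuilds the nested-size metadata, relying on the wrap_buffer overload that derives contiguous strides. The metadata update simply appends the embedding width as one more size column per constituent; a standalone sketch of that step (the helper name is hypothetical):

    #include <torch/torch.h>

    // sizes: (num_constituents x ndims) int64 matrix of per-tensor shapes.
    // Returns the same matrix with an extra column equal to emb_dim,
    // mirroring the at::cat({new_nested_size_sizes, tmp}, 1) above.
    at::Tensor append_embedding_dim(const at::Tensor& sizes, int64_t emb_dim) {
      at::Tensor col = torch::full({sizes.size(0), 1}, emb_dim, sizes.options());
      return at::cat({sizes, col}, 1);
    }
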
diff --git a/nestedtensor/csrc/masking.cpp b/nestedtensor/csrc/masking.cpp
index 13ec62c3..91663a27 100644
--- a/nestedtensor/csrc/masking.cpp
+++ b/nestedtensor/csrc/masking.cpp
@@ -2,6 +2,7 @@
 #include 
 #ifdef WITH_CUDA
 #include 
+#include 
 #include 
 #include 
 #endif
@@ -80,10 +81,10 @@ std::vector<int64_t> _get_max_size(const SizeNode& size_node) {
   return result;
 }
 
-std::vector<int64_t> get_max_size_from_efficient_size(EfficientSizeNode esize) {
+static std::vector<int64_t> get_max_size_from_efficient_size(const EfficientSizeNode& esize) {
   auto nt_opt_sizes = esize.opt_sizes();
   if (nt_opt_sizes.size() > 0 && *nt_opt_sizes[0] > 0) {
-    auto sizes = esize.sizes();
+    const auto& sizes = esize.sizes();
     int64_t* sizes_ptr = sizes.data_ptr<int64_t>();
     int64_t sizes_size_0 = sizes.size(0);
     int64_t sizes_size_1 = sizes.size(1);
@@ -107,14 +108,15 @@ std::vector<int64_t> get_max_size(const Tensor& nt) {
 }
 
-Tensor batch_offsets_from_efficient_size(EfficientSizeNode ef) {
-  Tensor ef_sizes = ef.sizes();
+Tensor batch_offsets_from_efficient_size(const EfficientSizeNode& ef, int64_t extra_elements) {
+  const Tensor& ef_sizes = ef.sizes();
   int64_t* nt_sizes_ptr = ef_sizes.data_ptr<int64_t>();
-  Tensor offsets = torch::empty({1 + ef_sizes.size(0)}, torch::kInt64);
-  int64_t* offsets_ptr = offsets.data_ptr<int64_t>();
+  Tensor offsets = torch::empty({1 + ef_sizes.size(0) + extra_elements}, torch::kInt32);
+  int32_t* offsets_ptr = offsets.data_ptr<int32_t>();
   offsets_ptr[0] = 0;
   int64_t ef_sizes_size_1 = ef_sizes.size(1);
-  for (int64_t i = 0; i < ef_sizes.size(0); i++) {
+  const auto ef_sizes_size_0 = ef_sizes.size(0);
+  for (int64_t i = 0; i < ef_sizes_size_0; i++) {
     int64_t prod = 1;
     for (int64_t j = 0; j < ef_sizes_size_1; j++) {
       prod = prod * nt_sizes_ptr[i * ef_sizes_size_1 + j];
@@ -124,15 +126,11 @@ Tensor batch_offsets_from_efficient_size(EfficientSizeNode ef) {
   return offsets;
 }
 
-std::vector<int64_t> padded_size_from_efficient_size(EfficientSizeNode ef_size) {
-  Tensor nt_sizes = ef_size.sizes();
+static std::vector<int64_t> padded_size_from_efficient_size(const EfficientSizeNode& ef_size) {
+  const Tensor& nt_sizes = ef_size.sizes();
   auto max_size = get_max_size_from_efficient_size(ef_size);
-  std::vector<int64_t> new_size;
-  new_size.push_back(nt_sizes.size(0));
-  for (int64_t i = 0; i < max_size.size(); i++) {
-    new_size.push_back(max_size[i]);
-  }
-  return new_size;
+  max_size.insert(max_size.begin(), nt_sizes.size(0));
+  return max_size;
 }
 
 std::tuple<Tensor, Tensor> pad_nt(Tensor nt, std::vector<int64_t> shape) {
@@ -254,8 +252,8 @@ std::tuple<Tensor, Tensor> to_tensor_mask(
 #ifdef WITH_CUDA
   if (get_dim(nt) == 3 && get_is_contiguous(nt) && mask_dim && *mask_dim == 2) {
     auto nt_opt_size = get_opt_sizes(nt);
-    Tensor nt_buffer = get_buffer(nt);
-    if (nt_opt_size[2] && nt_buffer.is_cuda()) {
+    auto nt_buffer = c10::MaybeOwned<Tensor>::borrowed(get_buffer(nt));
+    if (nt_opt_size[2] && nt_buffer->is_cuda()) {
       Tensor nt_sizes_ =
           get_efficient_nested_size(nt).sizes().to(torch::kInt32);
       TORCH_CHECK(nt_sizes_.dim() == 2, "NestedTensor metadata of unexpected dimension.")
@@ -265,19 +263,19 @@ std::tuple<Tensor, Tensor> to_tensor_mask(
           at::cumsum(nt_sizes, 0).to(torch::kInt32).reshape({-1});
       nt_sizes =
           at::cat({torch::tensor({0}, torch::kInt32), nt_sizes});
       Tensor output = torch::zeros(
-          {*nt_opt_size[0], max_size_1, *nt_opt_size[2]}, nt_buffer.options());
+          {*nt_opt_size[0], max_size_1, *nt_opt_size[2]}, nt_buffer->options());
       nt_sizes = nt_sizes.to(torch::kCUDA);
       Tensor output_mask = torch::zeros(
-          {*nt_opt_size[0], max_size_1}, nt_buffer.options());
+          {*nt_opt_size[0], max_size_1}, nt_buffer->options());
       output_mask = output_mask.to(torch::kInt32);
       at::cuda::CUDAStream defaultStream = at::cuda::getDefaultCUDAStream();
       if (nt.dtype() == torch::kFloat16) {
-        nt_buffer = nt_buffer.to(torch::kFloat);
+        nt_buffer = c10::MaybeOwned<Tensor>::owned(nt_buffer->to(torch::kFloat));
         output = output.to(torch::kFloat);
       }
-      if (nt_buffer.dtype() == torch::kFloat) {
+      if (nt_buffer->dtype() == torch::kFloat) {
         nested_tensor::cuda::add_padding_mask_kernelLauncher(
-            nt_buffer.data_ptr<float>(),
+            nt_buffer->data_ptr<float>(),
             output.data_ptr<float>(),
             output_mask.data_ptr<int>(),
             nt_sizes.data_ptr<int>(),
@@ -305,11 +303,9 @@ std::tuple<Tensor, Tensor> to_tensor_mask(
   auto opt_sizes = get_opt_sizes(nt);
   if (opt_sizes.size() == 1 && *opt_sizes[0] == 1) {
     nt = NestedTensor_contiguous(nt);
-    Tensor nt_buffer = get_buffer(nt);
-    nt_buffer = nt_buffer.reshape({-1});
     Tensor result_mask = !mask_dim || *mask_dim == 0 ? torch::tensor(true) : torch::tensor({true});
-    return std::make_tuple(nt_buffer, result_mask);
+    return std::make_tuple(get_buffer(nt).reshape({-1}), result_mask);
   }
 
   auto max_size = get_max_size(nt);
@@ -434,8 +430,8 @@ Tensor to_mask(
     max_size.push_back(tmp_max_size[i - 1]);
   }
   if (*mask_dim == 2 && get_dim(nt) == 3) {
-    auto nt_size = get_efficient_nested_size(nt);
-    auto esizes = nt_size.sizes();
+    const auto& nt_size = get_efficient_nested_size(nt);
+    const auto& esizes = nt_size.sizes();
     auto options = torch::TensorOptions().dtype(torch::kByte);
     auto result =
         torch::zeros({*opt_sizes[0], tmp_max_size[0]}, options);
@@ -456,31 +452,30 @@ Tensor to_mask(
   return merge_mask(res_mask, mask_dim);
 }
 
-Tensor from_padded_tensor(Tensor padded, EfficientSizeNode target_size) {
+Tensor from_padded_tensor(const Tensor& padded, const EfficientSizeNode& target_size) {
   TORCH_CHECK(padded.dim() == target_size.dim(),
       "Target size has different dimension as input padded Tensor.");
 #ifdef WITH_CUDA
   if (padded.dim() > 1 && padded.dim() < 5 &&
       get_is_contiguous(padded) && padded.is_cuda()) {
-    Tensor target_offsets = batch_offsets_from_efficient_size(target_size);
-    std::vector<int64_t> padded_sizes = padded.sizes().vec();
-    Tensor padded_sizes_tensor = torch::tensor(padded_sizes);
+    Tensor target_offsets = batch_offsets_from_efficient_size(target_size, 0);
+    Tensor padded_sizes_tensor = torch::tensor(padded.sizes());
     Tensor output = torch::empty({target_size.numel()}, padded.options());
-    Tensor target_size_sizes = target_size.sizes();
+    Tensor target_size_sizes = target_size.sizes().reshape(-1);
 
-    at::Tensor metadata = at::cat({target_size_sizes.reshape(-1), padded_sizes_tensor, target_offsets});
+    at::Tensor metadata = at::cat({target_size_sizes, padded_sizes_tensor, target_offsets});
     metadata = metadata.to(at::Device(kCUDA), torch::kInt32, true, true);
 
-    std::vector<int64_t> split_sizes;
-    split_sizes.push_back(target_size_sizes.numel());
-    split_sizes.push_back(padded_sizes_tensor.numel());
-    split_sizes.push_back(target_offsets.numel());
+    std::array<int64_t, 3> split_sizes = {
+        target_size_sizes.numel(),
+        padded_sizes_tensor.numel(),
+        target_offsets.numel()};
 
-    std::vector<Tensor> split = at::split_with_sizes(metadata, IntArrayRef(split_sizes), 0);
+    std::vector<Tensor> split = at::split_with_sizes(metadata, split_sizes, 0);
 
-    target_size_sizes = split[0];
-    padded_sizes_tensor = split[1];
-    target_offsets = split[2];
+    target_size_sizes = std::move(split[0]);
+    padded_sizes_tensor = std::move(split[1]);
+    target_offsets = std::move(split[2]);
 
     at::cuda::CUDAStream defaultStream = at::cuda::getDefaultCUDAStream();
     if (padded.dtype() == torch::kFloat16) {
@@ -538,7 +533,7 @@ Tensor _collapse_two_dims_3(Tensor input, int64_t dim1, int64_t dim2) {
   TORCH_CHECK(dim2 - 1 == dim1, "dim2 must be one more than dim1.")
   TORCH_CHECK(dim1 == 1, "dim1 must be 1.")
   TORCH_CHECK(get_dim(input) == 3, "Expected input to be 3 dim.");
-  auto input_esizes = get_efficient_nested_size(input);
+  const auto& input_esizes = get_efficient_nested_size(input);
   Tensor nt_sizes = input_esizes.sizes();
 
   Tensor sizes_dim1 = at::native::narrow(nt_sizes, 1, 0, 1);
@@ -555,45 +550,40 @@ Tensor _collapse_two_dims_3(Tensor input, int64_t dim1, int64_t dim2) {
   return result;
 }
 
-Tensor to_padded_tensor(Tensor nt, double padding) {
+Tensor to_padded_tensor(const Tensor& t, double padding) {
 #ifdef WITH_CUDA
-  if ((get_dim(nt) >= 2 && get_dim(nt) <= 4)) {
-    nt = NestedTensor_contiguous(nt, c10::MemoryFormat::Contiguous);
+  if ((get_dim(t) >= 2 && get_dim(t) <= 4)) {
+    auto nt = NestedTensor_contiguous(t, c10::MemoryFormat::Contiguous);
     auto nt_opt_size = get_opt_sizes(nt);
     auto orig_nt_dim = get_dim(nt);
-    Tensor nt_buffer = get_buffer(nt);
+    const Tensor& nt_buffer = get_buffer(nt);
     if (nt_buffer.is_cuda()) {
       if (get_dim(nt) == 3 && nt_opt_size[2]) {
         nt = _collapse_two_dims_3(nt, 1, 2);
       }
-      auto esize = get_efficient_nested_size(nt);
-      at::Tensor nt_sizes = esize.sizes();
-      Tensor offsets = batch_offsets_from_efficient_size(esize);
+      const auto& esize = get_efficient_nested_size(nt);
+      const at::Tensor& nt_sizes = esize.sizes();
+      Tensor offsets = batch_offsets_from_efficient_size(esize, nt_sizes.numel());
       std::vector<int64_t> new_size = padded_size_from_efficient_size(esize);
       at::cuda::CUDAStream defaultStream = at::cuda::getDefaultCUDAStream();
       Tensor output = at::empty(IntArrayRef(new_size), nt_buffer.options());
       int64_t input_dim = nt_sizes.size(1);
       int64_t batch_size = nt_sizes.size(0);
-      at::Tensor metadata = at::cat({offsets, nt_sizes.reshape(-1)});
-      metadata = metadata.to(at::Device(kCUDA), torch::kInt32, true, true);
+      TORCH_INTERNAL_ASSERT_DEBUG_ONLY(offsets.numel() == nt_sizes.numel() * 2 + 1, "error in metadata size logic");
+      at::native::narrow(offsets, 0, nt_sizes.numel() + 1, nt_sizes.numel()).copy_(nt_sizes.reshape({-1}));
+      at::Tensor metadata = offsets.to(at::Device(kCUDA), torch::kInt32, true, true);
 
-      std::vector<int64_t> split_sizes;
-      split_sizes.push_back(offsets.numel());
-      split_sizes.push_back(nt_sizes.numel());
-
-      std::vector<Tensor> split = at::split_with_sizes(metadata, IntArrayRef(split_sizes), 0);
-
-      offsets = split[0];
-      nt_sizes = split[1];
+      const auto offsets_ptr = metadata.data_ptr<int>();
+      const auto nt_sizes_ptr = offsets_ptr + nt_sizes.numel() + 1;
 
       if (nt_buffer.dtype() == torch::kFloat16) {
         nested_tensor::cuda::add_padding_kernelLauncher(
             nt_buffer.data_ptr<c10::Half>(),
             output.data_ptr<c10::Half>(),
             (c10::Half)(padding),
-            offsets.data_ptr<int>(),
-            nt_sizes.data_ptr<int>(),
+            offsets_ptr,
+            nt_sizes_ptr,
             input_dim,
             new_size,
             batch_size,
@@ -608,8 +598,8 @@ Tensor to_padded_tensor(Tensor nt, double padding) {
             nt_buffer.data_ptr<float>(),
             output.data_ptr<float>(),
             (float)(padding),
-            offsets.data_ptr<int>(),
-            nt_sizes.data_ptr<int>(),
+            offsets_ptr,
+            nt_sizes_ptr,
             input_dim,
             new_size,
             batch_size,
@@ -624,13 +614,12 @@ Tensor to_padded_tensor(Tensor nt, double padding) {
     }
   }
 #endif
-  auto opt_sizes = get_opt_sizes(nt);
+  auto opt_sizes = get_opt_sizes(t);
   if (opt_sizes.size() == 1 && *opt_sizes[0] == 1) {
-    nt = NestedTensor_contiguous(nt);
-    return get_buffer(nt);
+    return get_buffer(NestedTensor_contiguous(t));
   }
-  auto max_size = get_max_size(nt);
-  TensorNode structure = get_nested_tensor_structure(nt);
+  auto max_size = get_max_size(t);
+  TensorNode structure = get_nested_tensor_structure(t);
   if (structure.degree() == 0) {
     return torch::tensor({padding});
   }
diff --git a/nestedtensor/csrc/masking.h b/nestedtensor/csrc/masking.h
index e851b393..fb521f00 100644
--- a/nestedtensor/csrc/masking.h
+++ b/nestedtensor/csrc/masking.h
@@ -18,17 +18,12 @@ at::Tensor to_mask(
     c10::optional<int64_t> mask_dim);
 
 at::Tensor to_padded_tensor(
-    at::Tensor nt,
+    const at::Tensor& nt,
     double padding);
 
 at::Tensor from_padded_tensor(
-    at::Tensor nt,
-    torch::nested_tensor::EfficientSizeNode target_size,
-    torch::nested_tensor::EfficientSizeNode target_stride);
-
-at::Tensor from_padded_tensor(
-    at::Tensor nt,
-    torch::nested_tensor::EfficientSizeNode target_size);
+    const at::Tensor& nt,
+    const torch::nested_tensor::EfficientSizeNode& target_size);
 
 c10::optional<at::Tensor> nt_from_tensor_mask(
     at::Tensor tensor,
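
Note (editorial, not part of the patch): batch_offsets_from_efficient_size now writes int32 offsets and can reserve extra_elements, which lets to_padded_tensor append the flattened per-tensor sizes behind the offsets and ship both to the GPU in one copy instead of a cat plus split_with_sizes. A self-contained sketch of what the batch offsets are (exclusive prefix sums of each constituent's numel):

    #include <cstdint>
    #include <vector>

    // offsets[i] is where constituent i starts in the packed buffer;
    // offsets.back() is the total number of elements.
    std::vector<int32_t> batch_offsets(const std::vector<std::vector<int64_t>>& sizes) {
      std::vector<int32_t> offsets(sizes.size() + 1, 0);
      for (size_t i = 0; i < sizes.size(); i++) {
        int64_t prod = 1;
        for (int64_t d : sizes[i]) {
          prod *= d;
        }
        offsets[i + 1] = offsets[i] + static_cast<int32_t>(prod);
      }
      return offsets;
    }
    // e.g. sizes {{2, 3}, {1, 3}, {4, 3}} -> offsets {0, 6, 9, 21}
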
diff --git a/nestedtensor/csrc/matmul.cpp b/nestedtensor/csrc/matmul.cpp
index ed1ada9a..aaf1fc36 100644
--- a/nestedtensor/csrc/matmul.cpp
+++ b/nestedtensor/csrc/matmul.cpp
@@ -15,7 +15,7 @@ Tensor NestedTensor_matmul(const Tensor& self, const Tensor& other) {
     auto self_opt_sizes = get_opt_sizes(self);
     if (self_opt_sizes[2]) {
       if (*self_opt_sizes[2] == other.size(0)) {
-        Tensor self_buffer = get_buffer(self);
+        const Tensor& self_buffer = get_buffer(self);
         Tensor result_buffer =
             at::matmul(self_buffer.reshape({-1, other.size(0)}), other);
         result_buffer = result_buffer.reshape({-1});
@@ -37,7 +37,7 @@ Tensor NestedTensor_matmul(const Tensor& self, const Tensor& other) {
             new_nested_size,
             new_nested_stride);
         return wrap_buffer(
-            std::move(result_buffer), new_nested_size, new_nested_stride);
+            std::move(result_buffer), std::move(new_nested_size), std::move(new_nested_stride));
       }
     }
   }
diff --git a/nestedtensor/csrc/nested_tensor_impl.cpp b/nestedtensor/csrc/nested_tensor_impl.cpp
index 5bbcef11..4a7c21a0 100644
--- a/nestedtensor/csrc/nested_tensor_impl.cpp
+++ b/nestedtensor/csrc/nested_tensor_impl.cpp
@@ -28,34 +28,38 @@ TensorNode _unbind_tensors(TensorNode structure) {
   return TensorNode(std::move(result_nodes));
 }
 
+// We cannot delegate from the second constructor that uses this macro
+// to the first one because we would need to both move from
+// nested_size and use it.
+#define NESTED_TENSOR_IMPL_CONSTRUCTOR_BODY(stride) \
+    : TensorImpl( \
+          c10::DispatchKeySet({NestedTensorKey}), \
+          buffer.dtype(), \
+          buffer.device()), \
+      _buffer(std::move(buffer)), \
+      _nested_size(std::move(nested_size)), \
+      _nested_stride(stride), \
+      _is_pinned(_buffer.is_pinned()), \
+      _is_contiguous(torch::nested_tensor::impl::storage_is_contiguous( \
+          _buffer, \
+          _nested_size, \
+          _nested_stride)), \
+      _is_contiguous_channels_last(torch::nested_tensor::impl::storage_is_contiguous_channels_last( \
+          _buffer, \
+          _nested_size, \
+          _nested_stride)) { \
+  remove_autograd_key(); \
+  key_set_ = key_set_ - c10::DispatchKeySet({c10::DispatchKey::ADInplaceOrView}); \
+}
+
 NestedTensorImpl::NestedTensorImpl(at::Tensor&& buffer,
                                    EfficientSizeNode nested_size,
                                    EfficientSizeNode nested_stride)
-    : TensorImpl(
-          c10::DispatchKeySet({NestedTensorKey}),
-          buffer.dtype(),
-          buffer.device()),
-      _buffer(buffer),
-      _nested_size(nested_size),
-      _nested_stride(nested_stride),
-      _is_pinned(_buffer.is_pinned()),
-      _is_contiguous(torch::nested_tensor::impl::storage_is_contiguous(
-          _buffer,
-          _nested_size,
-          _nested_stride)),
-      _is_contiguous_channels_last(torch::nested_tensor::impl::storage_is_contiguous_channels_last(
-          _buffer,
-          _nested_size,
-          _nested_stride)) {
-  remove_autograd_key();
-  key_set_ = key_set_ - c10::DispatchKeySet({c10::DispatchKey::ADInplaceOrView});
-}
+NESTED_TENSOR_IMPL_CONSTRUCTOR_BODY(std::move(nested_stride))
 
 NestedTensorImpl::NestedTensorImpl(at::Tensor&& buffer,
                                    EfficientSizeNode nested_size)
-    : NestedTensorImpl(std::move(buffer),
-                       nested_size,
-                       torch::nested_tensor::impl::_cont_stride(nested_size)) {}
+NESTED_TENSOR_IMPL_CONSTRUCTOR_BODY(torch::nested_tensor::impl::_cont_stride(nested_size))
 
 NestedTensorImpl::NestedTensorImpl(at::Tensor&& buffer,
                                    SizeNode nested_size,
@@ -116,7 +120,7 @@ std::vector<at::Tensor> wrap_tensor_node(std::vector<TensorNode> input) {
   return result;
 }
 
-at::Tensor wrap_buffer(at::Tensor&& buffer, SizeNode nested_size) {
+at::Tensor wrap_buffer(at::Tensor buffer, SizeNode nested_size) {
   TORCH_CHECK(buffer.is_contiguous(), "Given buffer must be contiguous.");
   if (nested_size.is_leaf()) {
     return buffer.reshape(IntArrayRef(nested_size.payload()));
@@ -126,7 +130,7 @@ at::Tensor wrap_buffer(at::Tensor&& buffer, SizeNode nested_size) {
 }
 
 at::Tensor wrap_buffer(
-    at::Tensor&& buffer,
+    at::Tensor buffer,
     EfficientSizeNode efficient_nested_size,
     EfficientSizeNode efficient_nested_stride) {
   TORCH_CHECK(buffer.is_contiguous(), "Given buffer must be contiguous.");
@@ -138,12 +142,12 @@ at::Tensor wrap_buffer(
       "Internal error: expected nested_size of non-zero height.");
   return at::detail::make_tensor<NestedTensorImpl>(
       std::move(buffer),
-      efficient_nested_size,
-      efficient_nested_stride);
+      std::move(efficient_nested_size),
+      std::move(efficient_nested_stride));
 }
 
 at::Tensor wrap_buffer(
-    at::Tensor&& buffer,
+    at::Tensor buffer,
     EfficientSizeNode efficient_nested_size) {
   TORCH_CHECK(buffer.is_contiguous(), "Given buffer must be contiguous.");
   TORCH_CHECK(
@@ -151,7 +155,7 @@ at::Tensor wrap_buffer(
       "Internal error: expected nested_size of non-zero height.");
   return at::detail::make_tensor<NestedTensorImpl>(
       std::move(buffer),
-      efficient_nested_size);
+      std::move(efficient_nested_size));
 }
 
 Tensor NestedTensor_contiguous(const Tensor& self, MemoryFormat memory_format) {
diff --git a/nestedtensor/csrc/nested_tensor_impl.h b/nestedtensor/csrc/nested_tensor_impl.h
index 004afbed..28a58dd9 100644
--- a/nestedtensor/csrc/nested_tensor_impl.h
+++ b/nestedtensor/csrc/nested_tensor_impl.h
@@ -75,10 +75,10 @@ struct NestedTensorImpl : public c10::TensorImpl {
             _nested_size,
             _nested_stride));
   }
-  EfficientSizeNode get_nested_size() {
+  const EfficientSizeNode& get_nested_size() {
     return _nested_size;
   }
-  EfficientSizeNode get_nested_stride() {
+  const EfficientSizeNode& get_nested_stride() {
     return _nested_stride;
   }
   int64_t nested_dim() const {
@@ -198,7 +198,11 @@ inline TensorNode get_nested_tensor_structure(at::Tensor tensor) {
   return get_nested_tensor_impl(tensor)->get_structure();
 }
 
-inline at::Tensor get_buffer(const at::Tensor& tensor) {
+inline const at::Tensor& get_buffer(const at::Tensor& tensor) {
+  return get_nested_tensor_impl(tensor)->get_buffer();
+}
+
+inline at::Tensor& get_buffer(at::Tensor& tensor) {
   return get_nested_tensor_impl(tensor)->get_buffer();
 }
 
@@ -209,13 +213,13 @@ inline const std::vector<c10::optional<int64_t>> get_opt_sizes(
   return get_nested_tensor_impl(tensor)->opt_sizes();
 }
 
-inline const EfficientSizeNode get_efficient_nested_size(const at::Tensor& tensor) {
+inline const EfficientSizeNode& get_efficient_nested_size(const at::Tensor& tensor) {
   TORCH_CHECK(
       is_nested_tensor_impl(tensor), "Given tensor must be NestedTensor.");
   return get_nested_tensor_impl(tensor)->get_nested_size();
 }
 
-inline const EfficientSizeNode get_efficient_nested_stride(const at::Tensor& tensor) {
+inline const EfficientSizeNode& get_efficient_nested_stride(const at::Tensor& tensor) {
   TORCH_CHECK(
       is_nested_tensor_impl(tensor), "Given tensor must be NestedTensor.");
   return get_nested_tensor_impl(tensor)->get_nested_stride();
@@ -282,13 +286,13 @@ inline int64_t get_nested_dim(const at::Tensor& tensor) {
 at::Tensor wrap_tensor_node(NestedTensorImpl);
 at::Tensor wrap_tensor_node(TensorNode&&);
 std::vector<at::Tensor> wrap_tensor_node(std::vector<TensorNode>);
-at::Tensor wrap_buffer(at::Tensor&&, SizeNode nested_size);
+at::Tensor wrap_buffer(at::Tensor, SizeNode nested_size);
 at::Tensor wrap_buffer(
-    at::Tensor&&,
+    at::Tensor,
     EfficientSizeNode efficient_nested_size,
     EfficientSizeNode efficient_nested_stride);
 at::Tensor wrap_buffer(
-    at::Tensor&&,
+    at::Tensor,
     EfficientSizeNode efficient_nested_size);
 
 template
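
Note (editorial, not part of the patch): get_buffer now comes in const and non-const overloads returning references, and the size/stride getters return const references; this is what lets the call sites above bind const Tensor& / const EfficientSizeNode& instead of copying, while NestedTensor_relu_ can still mutate the buffer in place. The accessor pattern, shown on a simplified holder type:

    #include <ATen/ATen.h>

    struct PackedHolder {
      at::Tensor _buffer;
      // Read-only access without a refcount bump.
      const at::Tensor& buffer() const { return _buffer; }
      // Mutable access for in-place ops such as at::relu_.
      at::Tensor& buffer() { return _buffer; }
    };
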
diff --git a/nestedtensor/csrc/shape.cpp b/nestedtensor/csrc/shape.cpp
index d6b42385..04ca9749 100644
--- a/nestedtensor/csrc/shape.cpp
+++ b/nestedtensor/csrc/shape.cpp
@@ -61,8 +61,8 @@ Tensor NestedTensor_transpose(const Tensor& self, int64_t dim0, int64_t dim1) {
   TORCH_CHECK(
       dim0 >= nested_dim && dim1 >= nested_dim,
       "Transposition of nested dimensions is not implemented yet.");
-  EfficientSizeNode ef_sizes = get_efficient_nested_size(self);
-  EfficientSizeNode ef_strides = get_efficient_nested_stride(self);
+  const EfficientSizeNode& ef_sizes = get_efficient_nested_size(self);
+  const EfficientSizeNode& ef_strides = get_efficient_nested_stride(self);
   auto new_ef_sizes = map_efficient_size(
       [dim0, dim1, nested_dim](int64_t* size_ptr, int64_t size) {
         int64_t tmp = size_ptr[dim0 - nested_dim];
diff --git a/nestedtensor/csrc/storage/EfficientSizeNode.h b/nestedtensor/csrc/storage/EfficientSizeNode.h
index 9c5be06f..e57651bf 100644
--- a/nestedtensor/csrc/storage/EfficientSizeNode.h
+++ b/nestedtensor/csrc/storage/EfficientSizeNode.h
@@ -104,7 +104,7 @@ struct EfficientSizeNode {
   const at::Tensor& sizes() const {
     return _sizes;
   }
-  const int64_t structure() const {
+  int64_t structure() const {
     return _structure;
   }
   EfficientSizeNode clone() const {
diff --git a/nestedtensor/csrc/storage/Packed.h b/nestedtensor/csrc/storage/Packed.h
index 04a5b6f6..35a37742 100644
--- a/nestedtensor/csrc/storage/Packed.h
+++ b/nestedtensor/csrc/storage/Packed.h
@@ -88,7 +88,7 @@ inline at::Tensor pack(const TensorNode& structure) {
   at::Tensor result_buffer = empty({full_numel}, tensors[0].options());
   int64_t index = 0;
   for (size_t i = 0; i < tensors.size(); i++) {
-    at::Tensor narrowed_result_buffer = 
+    at::Tensor narrowed_result_buffer =
         result_buffer.narrow(0, index, tensors[i].numel());
     narrowed_result_buffer = narrowed_result_buffer.reshape(tensors[i].sizes());
     narrowed_result_buffer.copy_(tensors[i], true);
@@ -111,11 +111,14 @@ inline bool storage_is_contiguous(
   const at::Tensor& strides_sizes = nested_stride.sizes();
   int64_t* sizes_sizes_ptr = sizes_sizes.data_ptr<int64_t>();
   int64_t* strides_sizes_ptr = strides_sizes.data_ptr<int64_t>();
-  for (int64_t i = 0; i < sizes_sizes.size(0); i++) {
+  const auto sizes_sizes_0 = sizes_sizes.size(0);
+  const auto sizes_sizes_1 = sizes_sizes.size(1);
+  const auto strides_sizes_1 = strides_sizes.size(1);
+  for (int64_t i = 0; i < sizes_sizes_0; i++) {
     if (!_is_cont_stride(
-            sizes_sizes_ptr + i * sizes_sizes.size(1),
-            strides_sizes_ptr + i * strides_sizes.size(1),
-            sizes_sizes.size(1))) {
+            sizes_sizes_ptr + i * sizes_sizes_1,
+            strides_sizes_ptr + i * strides_sizes_1,
+            sizes_sizes_1)) {
       return false;
     }
   }
diff --git a/nestedtensor/csrc/transpose.cpp b/nestedtensor/csrc/transpose.cpp
index 41cc3cf8..07ec5c2a 100644
--- a/nestedtensor/csrc/transpose.cpp
+++ b/nestedtensor/csrc/transpose.cpp
@@ -21,7 +21,7 @@ Tensor _collapse_two_dims(Tensor input, int64_t dim1, int64_t dim2) {
   TORCH_CHECK(dim2 - 1 == dim1, "dim2 must be one more than dim1.")
   TORCH_CHECK(dim1 == 1 || dim1 == 2, "dim1 must be 1 or 2.")
   TORCH_CHECK(get_dim(input) == 4, "Expected input to be 4 dim.");
-  auto input_esizes = get_efficient_nested_size(input);
+  const auto& input_esizes = get_efficient_nested_size(input);
   Tensor nt_sizes = input_esizes.sizes();
 
   Tensor sizes_dim1 = at::native::narrow(nt_sizes, 1, 0, 1).contiguous();
@@ -95,8 +95,8 @@ Tensor _transpose_nchw_nhwc(Tensor input, Tensor output) {
   Tensor block_offsets;
   std::tie(offsets, block_offsets) = _create_offsets<32>(collapsed_input);
   at::cuda::CUDAStream defaultStream = at::cuda::getDefaultCUDAStream();
-  Tensor input_buffer = get_buffer(input);
-  Tensor output_buffer = get_buffer(output);
+  const Tensor& input_buffer = get_buffer(input);
+  const Tensor& output_buffer = get_buffer(output);
   TORCH_CHECK(input_buffer.is_cuda(), "Expected input_buffer to be CUDA.");
   TORCH_CHECK(output_buffer.is_cuda(), "Expected output_buffer to be CUDA.");
   int* block_offsets_ptr = block_offsets.data_ptr<int>();
@@ -123,7 +123,7 @@ Tensor transpose_nchw_nhwc(Tensor input) {
   TORCH_CHECK(get_is_contiguous(input), "transpose_nchw_nhwc input needs to be contiguous.");
   auto input_opt_sizes = get_opt_sizes(input);
   TORCH_CHECK(input_opt_sizes[1], "Expected first dimension to be regular.");
-  Tensor input_buffer = get_buffer(input);
+  const Tensor& input_buffer = get_buffer(input);
   auto new_sizes = map_efficient_size([](int64_t* size_ptr, int64_t size) {
     int64_t tmp = size_ptr[0];
     size_ptr[0] = size_ptr[2];
@@ -153,8 +153,8 @@ Tensor _transpose_nhwc_nchw(Tensor input, Tensor output) {
   Tensor block_offsets;
   std::tie(offsets, block_offsets) = _create_offsets<32>(collapsed_input);
   at::cuda::CUDAStream defaultStream = at::cuda::getDefaultCUDAStream();
-  Tensor input_buffer = get_buffer(input);
-  Tensor output_buffer = get_buffer(output);
+  const Tensor& input_buffer = get_buffer(input);
+  const Tensor& output_buffer = get_buffer(output);
   int* block_offsets_ptr = block_offsets.data_ptr<int>();
   int batch_size = sizes_dim3.numel();
   int block_numel = block_offsets_ptr[batch_size];
@@ -179,7 +179,7 @@ Tensor transpose_nhwc_nchw(Tensor input) {
   TORCH_CHECK(get_is_contiguous(input), "transpose_nhwc_nchw input needs to be contiguous.");
   auto input_opt_sizes = get_opt_sizes(input);
   TORCH_CHECK(input_opt_sizes[3], "Expected last dimension to be regular.");
-  Tensor input_buffer = get_buffer(input);
+  const Tensor& input_buffer = get_buffer(input);
   auto new_sizes = map_efficient_size([](int64_t* size_ptr, int64_t size) {
     // nhwc
     int64_t tmp = size_ptr[0];
diff --git a/nestedtensor/csrc/utils/nested_node.h b/nestedtensor/csrc/utils/nested_node.h
index 89c47375..fcc28a6b 100644
--- a/nestedtensor/csrc/utils/nested_node.h
+++ b/nestedtensor/csrc/utils/nested_node.h
@@ -100,44 +100,35 @@ class _map> {
       const NestedNode&... nested_node) {
     size_t degree = 0;
     bool all_leaf = true;
-    c10::guts::tuple_map(
-        std::forward_as_tuple(nested_node...), [&all_leaf, &degree](auto n) {
-          all_leaf = all_leaf && (n.is_leaf());
-          if (degree > 1 && n.degree() > 1) {
-            TORCH_CHECK(degree == n.degree(), "NestedNodes don't broadcast.");
-          }
-          if (n.degree() > degree) {
-            degree = n.degree();
-          }
-          return nullptr;
-        });
+    auto find_max_degree = [&all_leaf, &degree](auto n) {
+      all_leaf = all_leaf && (n.is_leaf());
+      if (degree > 1 && n.degree() > 1) {
+        TORCH_CHECK(degree == n.degree(), "NestedNodes don't broadcast.");
+      }
+      if (n.degree() > degree) {
+        degree = n.degree();
+      }
+    };
+    std::initializer_list<int> unused = {(find_max_degree(nested_node), 0)...};
     if (all_leaf) {
       return NestedNode(std::forward<F>(fn)(nested_node.payload()...));
     }
     std::vector> result;
     for (size_t i = 0; i < degree; i++) {
-      std::tuple...> children = c10::guts::tuple_map(
-          std::forward_as_tuple(nested_node...), [&i](auto a) {
-            static_assert(
-                c10::guts::is_instantiation_of::value,
-                "Internal error.");
-            if (a.is_leaf()) {
-              return a;
-            }
-            if (a.degree() == 1 && a.height() > 0) {
-              return a.children(0);
-            }
-            TORCH_CHECK(a.degree() > 0, "Internal assert.");
-            return a.children(i);
-          });
-      // TODO: Due to the experiences with to_vector and the inversion I'm a bit
-      // wary of apply but I haven't been able to reproduce the argument
-      // inversion behavior in other contexts.
-      c10::guts::apply(
-          [&result, &fn](NestedNode... filtered) {
-            result.emplace_back(function(std::forward<F>(fn), filtered...));
-          },
-          std::move(children));
+      auto get_child = [&i](auto a) {
+        static_assert(
+            c10::guts::is_instantiation_of::value,
+            "Internal error.");
+        if (a.is_leaf()) {
+          return a;
+        }
+        if (a.degree() == 1 && a.height() > 0) {
+          return a.children(0);
+        }
+        TORCH_CHECK(a.degree() > 0, "Internal assert.");
+        return a.children(i);
+      };
+      result.emplace_back(function(std::forward<F>(fn), get_child(nested_node)...));
     }
     return NestedNode(std::move(result));
   }
@@ -307,8 +298,7 @@ class _apply> {
   static void function(F&& fn, NestedNode... nested_node) {
     size_t degree = 0;
     bool all_leaf = true;
-    c10::guts::tuple_map(
-        std::forward_as_tuple(nested_node...), [&all_leaf, &degree](auto n) {
+    auto find_degree = [&all_leaf, &degree](auto n) {
           all_leaf = all_leaf && (n.is_leaf());
           if (degree == 0 && n.degree() > 0) {
             degree = n.degree();
@@ -317,31 +307,27 @@ class _apply> {
          }
          if (degree > 1 && n.degree() > 1) {
            TORCH_CHECK(degree == n.degree(), "NestedNodes don't broadcast.");
          }
          return nullptr;
-        });
+        };
+    std::initializer_list<int> unused = {(find_degree(nested_node), 0)...};
     if (all_leaf) {
       std::forward<F>(fn)(nested_node.payload()...);
     } else {
       for (size_t i = 0; i < degree; i++) {
-        std::tuple...> children = c10::guts::tuple_map(
-            std::forward_as_tuple(nested_node...), [&i](auto a) {
-              static_assert(
-                  c10::guts::is_instantiation_of::
-                      value,
-                  "Internal error.");
-              if (a.is_leaf()) {
-                return a;
-              }
-              if (a.degree() == 1) {
-                return a.children(0);
-              }
-              TORCH_CHECK(a.degree() > 0, "Internal assert.");
-              return a.children(i);
-            });
-        c10::guts::apply(
-            [&fn](NestedNode... filtered) {
-              function(std::forward<F>(fn), filtered...);
-            },
-            std::move(children));
+        auto get_child = [&i](auto a) {
+          static_assert(
+              c10::guts::is_instantiation_of::
+                  value,
+              "Internal error.");
+          if (a.is_leaf()) {
+            return a;
+          }
+          if (a.degree() == 1) {
+            return a.children(0);
+          }
+          TORCH_CHECK(a.degree() > 0, "Internal assert.");
+          return a.children(i);
+        };
+        function(std::forward<F>(fn), get_child(nested_node)...);
       }
     }
   }
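
Note (editorial, not part of the patch): in nested_node.h the c10::guts::tuple_map/apply combination is replaced by invoking a lambda directly over the parameter pack; expanding the pack inside a braced initializer list evaluates the lambda once per argument, in order, without materializing a std::tuple. A minimal standalone illustration of the idiom:

    #include <cstddef>
    #include <initializer_list>

    // Visits every node for its side effect; the comma operator discards the
    // lambda's result and the initializer_list fixes left-to-right order.
    template <typename... Nodes>
    size_t max_degree(const Nodes&... nodes) {
      size_t degree = 0;
      auto visit = [&degree](const auto& n) {
        if (n.degree() > degree) {
          degree = n.degree();
        }
      };
      std::initializer_list<int> unused = {(visit(nodes), 0)...};
      (void)unused;
      return degree;
    }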