diff --git a/benchmarks/frozenbatchnorm2d.py b/benchmarks/frozenbatchnorm2d.py
new file mode 100644
index 00000000..4945de2f
--- /dev/null
+++ b/benchmarks/frozenbatchnorm2d.py
@@ -0,0 +1,77 @@
+import torch
+import nestedtensor
+import utils
+import torchvision
+
+import random
+
+random.seed(1010)
+RAND_INTS = [random.randint(10, 30) for _ in range(2000)]
+RAND_INTS = [random.randint(100, 300) for _ in range(20)]
+
+
+class FrozenBatchNorm2d(torch.nn.Module):
+    """
+    BatchNorm2d where the batch statistics and the affine parameters are fixed.
+
+    Copy-paste from torchvision.misc.ops with added eps before rqsrt,
+    without which any other models than torchvision.models.resnet[18,34,50,101]
+    produce nans.
+    """
+
+    def __init__(self, n):
+        super(FrozenBatchNorm2d, self).__init__()
+        self.register_buffer("weight", torch.ones(n))
+        self.register_buffer("bias", torch.zeros(n))
+        self.register_buffer("running_mean", torch.zeros(n))
+        self.register_buffer("running_var", torch.ones(n))
+
+    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
+                              missing_keys, unexpected_keys, error_msgs):
+        num_batches_tracked_key = prefix + 'num_batches_tracked'
+        if num_batches_tracked_key in state_dict:
+            del state_dict[num_batches_tracked_key]
+
+        super(FrozenBatchNorm2d, self)._load_from_state_dict(
+            state_dict, prefix, local_metadata, strict,
+            missing_keys, unexpected_keys, error_msgs)
+
+    def forward(self, x):
+        # move reshapes to the beginning
+        # to make it fuser-friendly
+        w = self.weight.reshape(1, -1, 1, 1)
+        b = self.bias.reshape(1, -1, 1, 1)
+        rv = self.running_var.reshape(1, -1, 1, 1)
+        rm = self.running_mean.reshape(1, -1, 1, 1)
+        eps = 1e-5
+        scale = w * (rv + eps).rsqrt()
+        bias = b - rm * scale
+        # print(scale.size())
+        # print(bias.size())
+        # print(type(scale))
+        # print(type(bias))
+        # print(x.nested_size())
+        return (x * scale + bias).squeeze(1)
+MODEL = FrozenBatchNorm2d(64).cuda()
+
+def gen_t_loop_frozenbatchnorm2d():
+    tensors = [torch.rand(64, i, 256).cuda() for i in RAND_INTS]
+
+    def t_loop():
+        for t in tensors:
+            MODEL(t.unsqueeze(0))
+    return t_loop
+
+
+def gen_nt_frozenbatchnorm2d():
+    nt0 = nestedtensor.nested_tensor(
+        [torch.rand(64, i, 256).cuda() for i in RAND_INTS])
+
+    def nt():
+        MODEL(nt0)
+    return nt
+
+
+if __name__ == "__main__":
+    print(utils.benchmark_fn(gen_nt_frozenbatchnorm2d()))
+    print(utils.benchmark_fn(gen_t_loop_frozenbatchnorm2d()))
diff --git a/benchmarks/mha.py b/benchmarks/mha.py
index d07dc516..8fbcbaf4 100644
--- a/benchmarks/mha.py
+++ b/benchmarks/mha.py
@@ -5,7 +5,6 @@
 import random
 
 
-# Performance tanks hard for lots of small Tensors as expected
 random.seed(1010)
 RAND_INTS = [random.randint(10, 30) for _ in range(2000)]
 RAND_INTS = [random.randint(100, 300) for _ in range(20)]
@@ -14,7 +13,7 @@
 MODEL0 = torch.nn.MultiheadAttention(256, 8, dropout=0.1).cuda()
 MODEL1 = nestedtensor.nn.MultiheadAttention(256, 8, dropout=0.1).cuda()
 
-def gen_t_loop_segmentation():
+def gen_t_loop_mha():
     tensors = [torch.rand(1, i, 256).cuda() for i in RAND_INTS]
 
     def t_loop():
@@ -23,7 +22,7 @@ def t_loop():
     return t_loop
 
 
-def gen_nt_segmentation():
+def gen_nt_mha():
     nt0 = nestedtensor.nested_tensor(
         [torch.rand(i, 256).cuda() for i in RAND_INTS])
 
@@ -33,5 +32,5 @@ def nt():
 
 
 if __name__ == "__main__":
-    print(utils.benchmark_fn(gen_nt_segmentation()))
-    print(utils.benchmark_fn(gen_t_loop_segmentation()))
+    print(utils.benchmark_fn(gen_nt_mha()))
+    print(utils.benchmark_fn(gen_t_loop_mha()))
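Editor's note: the new frozenbatchnorm2d.py benchmark above times the same frozen batch norm two ways, a Python loop that pushes each variably-sized sample through the module one at a time, and a single call on a nestedtensor that packs all samples. The module's forward() folds the fixed statistics into a per-channel multiply-add. The following standalone sketch (plain PyTorch, no nestedtensor; shapes, variable names, and tolerances are illustrative, not taken from the patch) checks that this x * scale + bias rewrite matches F.batch_norm in eval mode:

import torch
import torch.nn.functional as F

n, eps = 64, 1e-5
x = torch.rand(1, n, 17, 256)  # NCHW input; H is arbitrary here
weight, bias_param = torch.ones(n), torch.zeros(n)
running_mean, running_var = torch.zeros(n), torch.ones(n)

# Precompute the per-channel affine form, mirroring FrozenBatchNorm2d.forward.
scale = weight.reshape(1, -1, 1, 1) * (running_var.reshape(1, -1, 1, 1) + eps).rsqrt()
shift = bias_param.reshape(1, -1, 1, 1) - running_mean.reshape(1, -1, 1, 1) * scale
frozen = x * scale + shift

# Reference: regular batch norm with fixed statistics (eval mode).
reference = F.batch_norm(x, running_mean, running_var, weight, bias_param,
                         training=False, eps=eps)
print(torch.allclose(frozen, reference, atol=1e-6))  # expected: True

Because the frozen form is just a broadcasted multiply-add over dense scale/shift tensors, the benchmark is effectively measuring elementwise NestedTensor-times-Tensor arithmetic in the two execution styles.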
diff --git a/benchmarks/utils.py b/benchmarks/utils.py
index d1268602..61e0fc5f 100644
--- a/benchmarks/utils.py
+++ b/benchmarks/utils.py
@@ -28,6 +28,7 @@ def benchmark_fn(fn, run_time = 5.0, use_cprofile=False, warmup=1.0, cuda=False)
         if use_cprofile:
             pr.enable()
         fn()
+        # import sys; sys.exit(1)
         if cuda:
             torch.cuda.synchronize()
         if use_cprofile:
diff --git a/nestedtensor/csrc/BinaryOps.cpp b/nestedtensor/csrc/BinaryOps.cpp
index c91a0587..b020dfe3 100644
--- a/nestedtensor/csrc/BinaryOps.cpp
+++ b/nestedtensor/csrc/BinaryOps.cpp
@@ -43,14 +43,18 @@ Tensor NestedTensor_binary(const Tensor& self, const Tensor& other) {
     return map_nested_tensor(
         [&self](Tensor other) { return func(self, other); }, other);
   }
-  if (is_packed(self) && (other.dim() == 0 || (other.dim() == 1 && other.numel() == 1))) {
+  if (is_packed(self)) {
+    auto self_structure = get_nested_tensor_structure(self);
+    auto self_impl = get_nested_tensor_impl(self);
+    if (other.dim() == 0 || (other.dim() == 1 && other.numel() == 1)) {
 #ifdef TRACEPACKED
-    std::cout << "calling packed binary " << typeid(func).name() << std::endl;
+      std::cout << "calling packed binary NT x T 0-dim / 1-dim 1-numel"
+                << typeid(func).name() << std::endl;
 #endif
-    auto self_structure = get_nested_tensor_structure(self);
-    return wrap_tensor_node(torch::nested_tensor::impl::build_structure(
-        func((*self_structure.buffer()), other),
-        get_nested_tensor_impl(self)->nested_size()));
+      return wrap_tensor_node(torch::nested_tensor::impl::build_structure(
+          func((*self_structure.buffer()), other),
+          get_nested_tensor_impl(self)->nested_size()));
+    }
   }
   return map_nested_tensor(
       [&other](Tensor self) { return func(self, other); }, self);
diff --git a/nestedtensor/csrc/autograd_functions.cpp b/nestedtensor/csrc/autograd_functions.cpp
index 5018dd82..7312acc1 100644
--- a/nestedtensor/csrc/autograd_functions.cpp
+++ b/nestedtensor/csrc/autograd_functions.cpp
@@ -58,7 +58,7 @@ struct NestedTensorFunction_batch_norm
                 cudnn_enabled)
                 .squeeze(0);
         },
-        autograd_input);
+        autograd_input).contiguous();
     ctx->saved_data["0"] = weight;
     ctx->saved_data["1"] = bias;
     ctx->saved_data["2"] = autograd_output;
diff --git a/nestedtensor/csrc/functions.cpp b/nestedtensor/csrc/functions.cpp
index aa6b5ddf..d15dd2c6 100644
--- a/nestedtensor/csrc/functions.cpp
+++ b/nestedtensor/csrc/functions.cpp
@@ -139,7 +139,7 @@ Tensor NestedTensor_layer_norm(
       [normalized_shape, &weight, &bias, eps](const at::Tensor t) {
         return at::layer_norm(t, normalized_shape, weight, bias, eps, true);
       },
-      input);
+      input).contiguous();
 }
 
 Tensor NestedTensor_all(const Tensor& self) {
diff --git a/nestedtensor/nn/mha.py b/nestedtensor/nn/mha.py
index 963a3444..84e6f7af 100644
--- a/nestedtensor/nn/mha.py
+++ b/nestedtensor/nn/mha.py
@@ -49,6 +49,9 @@ def multi_head_attention_forward(query,                           # type: Nested
     assert isinstance(value, nestedtensor.NestedTensor)
     assert torch.is_tensor(out_proj_weight)
     assert torch.is_tensor(out_proj_bias)
+    query = query.contiguous()
+    key = key.contiguous()
+    value = value.contiguous()
 
     # TODO: Explicitly unsupported flags
     assert not use_separate_proj_weight
diff --git a/nestedtensor/version.py b/nestedtensor/version.py
index 49ba8adc..73166668 100644
--- a/nestedtensor/version.py
+++ b/nestedtensor/version.py
@@ -1,5 +1,5 @@
-__version__ = '0.0.1.dev20208254+8c381e2'
-git_version = '8c381e21692b923eb2ba58b84c1fe5955ae207ad'
+__version__ = '0.0.1.dev20208255+faee8a1'
+git_version = 'faee8a1a2578f7ecb80098d2cb792ea7c22e61ab'
 from nestedtensor import _C
 if hasattr(_C, 'CUDA_VERSION'):
     cuda = _C.CUDA_VERSION
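Editor's note: the BinaryOps.cpp hunk above only takes the packed fast path when the plain-tensor operand is 0-dimensional or holds a single element, so the binary op can run once over the flat buffer instead of being mapped over every constituent tensor; the .contiguous() calls added in autograd_functions.cpp, functions.cpp, and nn/mha.py presumably keep intermediate NestedTensors in that packed (contiguous) layout so such paths remain reachable. A minimal sketch of the targeted case, assuming the Python frontend exposes elementwise ops and unbind() on NestedTensor as the benchmarks suggest (the verification loop is illustrative and not part of the patch):

import torch
import nestedtensor

# Two samples with different lengths, mirroring the benchmarks' layout.
nt = nestedtensor.nested_tensor(
    [torch.rand(64, i, 256) for i in (100, 200)])
scale = torch.tensor(0.5)  # 0-dim tensor -> eligible for the packed NT x T branch

out = nt * scale  # should hit the branch guarded by #ifdef TRACEPACKED above
for out_t, in_t in zip(out.unbind(), nt.unbind()):
    print(torch.allclose(out_t, in_t * scale))  # expected: True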