This repository was archived by the owner on Nov 15, 2022. It is now read-only.

Packed FrozenBatchNorm2d #213

Open · wants to merge 13 commits into master
77 changes: 77 additions & 0 deletions benchmarks/frozenbatchnorm2d.py
@@ -0,0 +1,77 @@
import torch
import nestedtensor
import utils
import torchvision

import random

random.seed(1010)
RAND_INTS = [random.randint(10, 30) for _ in range(2000)]
# The line below overrides the previous one: fewer, larger tensors.
RAND_INTS = [random.randint(100, 300) for _ in range(20)]


class FrozenBatchNorm2d(torch.nn.Module):
    """
    BatchNorm2d where the batch statistics and the affine parameters are fixed.

    Copy-paste from torchvision.misc.ops with added eps before rsqrt,
    without which models other than torchvision.models.resnet[18,34,50,101]
    produce NaNs.
    """

    def __init__(self, n):
        super(FrozenBatchNorm2d, self).__init__()
        self.register_buffer("weight", torch.ones(n))
        self.register_buffer("bias", torch.zeros(n))
        self.register_buffer("running_mean", torch.zeros(n))
        self.register_buffer("running_var", torch.ones(n))

    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                              missing_keys, unexpected_keys, error_msgs):
        num_batches_tracked_key = prefix + 'num_batches_tracked'
        if num_batches_tracked_key in state_dict:
            del state_dict[num_batches_tracked_key]

        super(FrozenBatchNorm2d, self)._load_from_state_dict(
            state_dict, prefix, local_metadata, strict,
            missing_keys, unexpected_keys, error_msgs)

    def forward(self, x):
        # move reshapes to the beginning
        # to make it fuser-friendly
        w = self.weight.reshape(1, -1, 1, 1)
        b = self.bias.reshape(1, -1, 1, 1)
        rv = self.running_var.reshape(1, -1, 1, 1)
        rm = self.running_mean.reshape(1, -1, 1, 1)
        eps = 1e-5
        scale = w * (rv + eps).rsqrt()
        bias = b - rm * scale
        return (x * scale + bias).squeeze(1)


MODEL = FrozenBatchNorm2d(64).cuda()


def gen_t_loop_frozenbatchnorm2d():
    tensors = [torch.rand(64, i, 256).cuda() for i in RAND_INTS]

    def t_loop():
        for t in tensors:
            MODEL(t.unsqueeze(0))
    return t_loop


def gen_nt_frozenbatchnorm2d():
    nt0 = nestedtensor.nested_tensor(
        [torch.rand(64, i, 256).cuda() for i in RAND_INTS])

    def nt():
        MODEL(nt0)
    return nt


if __name__ == "__main__":
    print(utils.benchmark_fn(gen_nt_frozenbatchnorm2d()))
    print(utils.benchmark_fn(gen_t_loop_frozenbatchnorm2d()))
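Note: the forward above folds the frozen statistics into a single affine map, y = (x - rm) / sqrt(rv + eps) * w + b = x * scale + bias, with scale = w * rsqrt(rv + eps) and bias = b - rm * scale. A minimal sanity check (not part of this PR; it reuses the FrozenBatchNorm2d class above, on CPU) against eval-mode BatchNorm2d:

# Sketch, not part of this PR: FrozenBatchNorm2d should agree with
# torch.nn.BatchNorm2d in eval mode, since eval-mode BN applies the same
# affine transform built from the running statistics (eps = 1e-5 in both).
bn = torch.nn.BatchNorm2d(64).eval()
fbn = FrozenBatchNorm2d(64)
fbn.load_state_dict(bn.state_dict())  # the hook above drops num_batches_tracked
x = torch.rand(1, 64, 20, 256)
print(torch.allclose(bn(x), fbn(x), atol=1e-5))  # expected: True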
9 changes: 4 additions & 5 deletions benchmarks/mha.py
@@ -5,7 +5,6 @@
 
 import random
 
-# Performance tanks hard for lots of small Tensors as expected
 random.seed(1010)
 RAND_INTS = [random.randint(10, 30) for _ in range(2000)]
 RAND_INTS = [random.randint(100, 300) for _ in range(20)]
@@ -14,7 +13,7 @@
 MODEL0 = torch.nn.MultiheadAttention(256, 8, dropout=0.1).cuda()
 MODEL1 = nestedtensor.nn.MultiheadAttention(256, 8, dropout=0.1).cuda()
 
-def gen_t_loop_segmentation():
+def gen_t_loop_mha():
     tensors = [torch.rand(1, i, 256).cuda() for i in RAND_INTS]
 
     def t_loop():
@@ -23,7 +22,7 @@ def t_loop():
     return t_loop
 
 
-def gen_nt_segmentation():
+def gen_nt_mha():
     nt0 = nestedtensor.nested_tensor(
         [torch.rand(i, 256).cuda() for i in RAND_INTS])
 
@@ -33,5 +32,5 @@ def nt():
 
 
 if __name__ == "__main__":
-    print(utils.benchmark_fn(gen_nt_segmentation()))
-    print(utils.benchmark_fn(gen_t_loop_segmentation()))
+    print(utils.benchmark_fn(gen_nt_mha()))
+    print(utils.benchmark_fn(gen_t_loop_mha()))
1 change: 1 addition & 0 deletions benchmarks/utils.py
@@ -28,6 +28,7 @@ def benchmark_fn(fn, run_time = 5.0, use_cprofile=False, warmup=1.0, cuda=False)
     if use_cprofile:
         pr.enable()
     fn()
+    # import sys; sys.exit(1)
     if cuda:
         torch.cuda.synchronize()
     if use_cprofile:
16 changes: 10 additions & 6 deletions nestedtensor/csrc/BinaryOps.cpp
@@ -43,14 +43,18 @@ Tensor NestedTensor_binary(const Tensor& self, const Tensor& other) {
     return map_nested_tensor(
         [&self](Tensor other) { return func(self, other); }, other);
   }
-  if (is_packed(self) && (other.dim() == 0 || (other.dim() == 1 && other.numel() == 1))) {
+  if (is_packed(self)) {
+    auto self_structure = get_nested_tensor_structure(self);
+    auto self_impl = get_nested_tensor_impl(self);
+    if (other.dim() == 0 || (other.dim() == 1 && other.numel() == 1)) {
 #ifdef TRACEPACKED
-    std::cout << "calling packed binary " << typeid(func).name() << std::endl;
+      std::cout << "calling packed binary NT x T 0-dim / 1-dim 1-numel"
+                << typeid(func).name() << std::endl;
 #endif
-    auto self_structure = get_nested_tensor_structure(self);
-    return wrap_tensor_node(torch::nested_tensor::impl::build_structure(
-        func((*self_structure.buffer()), other),
-        get_nested_tensor_impl(self)->nested_size()));
+      return wrap_tensor_node(torch::nested_tensor::impl::build_structure(
+          func((*self_structure.buffer()), other),
+          get_nested_tensor_impl(self)->nested_size()));
+    }
   }
   return map_nested_tensor(
       [&other](Tensor self) { return func(self, other); }, self);
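The fast path above runs func once over the packed buffer and re-wraps the result with the original nested_size, instead of mapping func over every component. The restriction to a 0-dim (or single-element 1-dim) other is what makes this sound: such an operand broadcasts identically against the flat buffer and against each component. A rough standalone Python model of the idea (illustrative only; packed_binary is a hypothetical helper, the real implementation is the C++ above):

# Rough Python model of the packed binary fast path (hypothetical helper).
import torch
from math import prod

def packed_binary(buffer, shapes, other, func=torch.mul):
    out = func(buffer, other)          # one kernel over the flat buffer
    sizes = [prod(s) for s in shapes]  # elements per component
    # re-split into the original component shapes (the build_structure step)
    return [o.view(s) for o, s in zip(out.split(sizes), shapes)]

components = [torch.rand(3, 4), torch.rand(5, 4)]
buf = torch.cat([c.reshape(-1) for c in components])
outs = packed_binary(buf, [c.shape for c in components], torch.tensor(2.0))
print([o.shape for o in outs])  # [torch.Size([3, 4]), torch.Size([5, 4])]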
2 changes: 1 addition & 1 deletion nestedtensor/csrc/autograd_functions.cpp
@@ -58,7 +58,7 @@ struct NestedTensorFunction_batch_norm
             cudnn_enabled)
             .squeeze(0);
       },
-      autograd_input);
+      autograd_input).contiguous();
   ctx->saved_data["0"] = weight;
   ctx->saved_data["1"] = bias;
   ctx->saved_data["2"] = autograd_output;
2 changes: 1 addition & 1 deletion nestedtensor/csrc/functions.cpp
@@ -139,7 +139,7 @@ Tensor NestedTensor_layer_norm(
       [normalized_shape, &weight, &bias, eps](const at::Tensor t) {
         return at::layer_norm(t, normalized_shape, weight, bias, eps, true);
       },
-      input);
+      input).contiguous();
 }
 
 Tensor NestedTensor_all(const Tensor& self) {
3 changes: 3 additions & 0 deletions nestedtensor/nn/mha.py
@@ -49,6 +49,9 @@ def multi_head_attention_forward(query, # type: Nested
     assert isinstance(value, nestedtensor.NestedTensor)
     assert torch.is_tensor(out_proj_weight)
     assert torch.is_tensor(out_proj_bias)
+    query = query.contiguous()
+    key = key.contiguous()
+    value = value.contiguous()
 
     # TODO: Explicitly unsupported flags
     assert not use_separate_proj_weight
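The three .contiguous() additions (batch_norm, layer_norm, and the mha entry point above) share one purpose: a mapped op returns a NestedTensor built from separate per-component results, and re-packing it keeps later ops on the packed fast path, such as the binary path in BinaryOps.cpp. A hedged usage sketch, assuming NestedTensor.contiguous() re-packs into a single buffer (which is what these call sites rely on) and that * dispatches through NestedTensor_binary:

# Sketch; assumes NestedTensor.contiguous() re-packs into one flat buffer.
import torch
import nestedtensor

nt = nestedtensor.nested_tensor([torch.rand(64, n, 256) for n in (120, 250)])
nt = nt.contiguous()          # pack the components into a single buffer
out = nt * torch.tensor(2.0)  # 0-dim other: eligible for the packed path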
4 changes: 2 additions & 2 deletions nestedtensor/version.py
@@ -1,5 +1,5 @@
-__version__ = '0.0.1.dev20208254+8c381e2'
-git_version = '8c381e21692b923eb2ba58b84c1fe5955ae207ad'
+__version__ = '0.0.1.dev20208255+faee8a1'
+git_version = 'faee8a1a2578f7ecb80098d2cb792ea7c22e61ab'
 from nestedtensor import _C
 if hasattr(_C, 'CUDA_VERSION'):
     cuda = _C.CUDA_VERSION