This repository was archived by the owner on Nov 15, 2022. It is now read-only.

Commit 5dde999

Check operator coverage (forward only) for issue 313 (#316)
1 parent c559076 commit 5dde999

6 files changed: +109, -5 lines


nestedtensor/csrc/BinaryOps.cpp (+4)

@@ -274,6 +274,10 @@ TORCH_LIBRARY_IMPL(aten, NestedTensor, m) {
   nt_impl(m, "logical_and_", NestedTensor_binary_<at::native::logical_and_>);
   nt_impl(m, "logical_and.out", NestedTensor_binary_out<at::logical_and_out>);
 
+  nt_impl(m, "logical_or", NestedTensor_binary<at::logical_or>);
+  nt_impl(m, "logical_or_", NestedTensor_binary_<at::native::logical_or_>);
+  nt_impl(m, "logical_or.out", NestedTensor_binary_out<at::logical_or_out>);
+
   nt_impl(m, "sub.Tensor", (NestedTensor_binary<Scalar, at::sub>));
   nt_impl(m, "pow.Tensor_Tensor", NestedTensor_binary<at::pow>);
 }
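These three registrations route the `logical_or` overloads to the existing generic nested binary kernels, mirroring the `logical_and` entries directly above. A minimal sketch of how the op could be exercised from Python, assuming a build of the `nestedtensor` package that includes this commit and that boolean constituents are accepted (the example inputs are illustrative, not from the commit):

```python
import torch
import nestedtensor

# Two nested tensors with the same nested structure but ragged lengths.
a = nestedtensor.nested_tensor([torch.tensor([True, False, True]),
                                torch.tensor([False, False])])
b = nestedtensor.nested_tensor([torch.tensor([False, True, True]),
                                torch.tensor([True, False])])

# Should now dispatch to NestedTensor_binary<at::logical_or> per constituent.
print(torch.logical_or(a, b))
```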

nestedtensor/csrc/ReduceOps.cpp (+46)

@@ -102,6 +102,50 @@ Tensor NestedTensor_sum_dim(
       my_sum, self, dims, keepdims, dtype);
 }
 
+std::tuple<Tensor, Tensor> NestedTensor_max_dim(
+    const Tensor& self,
+    int64_t dim,
+    bool keepdims) {
+  int64_t nested_dim = get_nested_tensor_impl(self)->nested_dim();
+  at::Tensor output = self;
+  if (dim >= nested_dim) {
+    std::vector<TensorNode> result = unzip(map(
+        [nested_dim, dim, keepdims](at::Tensor tensor) {
+          auto tmp = at::max(tensor, dim - nested_dim, keepdims);
+          std::vector<at::Tensor> result;
+          result.push_back(std::get<0>(tmp));
+          result.push_back(std::get<1>(tmp));
+          return result;
+        },
+        get_nested_tensor_structure(output)));
+    return std::make_tuple(
+        wrap_tensor_node(std::move(result[0])),
+        wrap_tensor_node(std::move(result[1])));
+  }
+  auto opt_sizes = get_opt_sizes(output);
+  TORCH_CHECK(
+      opt_sizes[dim],
+      "Current shape doesn't support reduction across nested dimension. Please open a feature request https://t.ly/62F6.");
+  auto new_nested_size = get_nested_size(output);
+  new_nested_size = squeeze(new_nested_size, dim, keepdims);
+  auto tmp =
+      at::max(NestedTensor_to_tensor(output, c10::nullopt), dim, keepdims);
+  return std::make_tuple(
+      wrap_buffer(std::get<0>(tmp).reshape({-1}), new_nested_size),
+      wrap_buffer(std::get<1>(tmp).reshape({-1}), new_nested_size));
+}
+
+Tensor NestedTensor_max(const Tensor& self) {
+  auto tensors = flatten(
+      map([](at::Tensor tensor) { return at::max(tensor); },
+          get_nested_tensor_structure(self)));
+  if (tensors.size() == 0) {
+    return at::ones({0});
+  }
+  auto all_tensor = at::stack(tensors);
+  return at::max(all_tensor);
+}
+
 Tensor NestedTensor_mean_dim(
     const Tensor& self,
     c10::ArrayRef<int64_t> dims,
@@ -336,6 +380,8 @@ TORCH_LIBRARY_IMPL(aten, NestedTensor, m) {
   nt_impl(m, "sum.dim_IntList", NestedTensor_sum_dim);
   nt_impl(m, "mean", NestedTensor_mean);
   nt_impl(m, "mean.dim", NestedTensor_mean_dim);
+  nt_impl(m, "max", NestedTensor_max);
+  nt_impl(m, "max.dim", NestedTensor_max_dim);
   nt_impl(m, "var", NestedTensor_var);
   nt_impl(m, "var.dim", NestedTensor_var_dim);
   nt_impl(m, "var_backward.dim", NestedTensor_var_backward_dim);

nestedtensor/version.py (+2, -2)

@@ -1,5 +1,5 @@
-__version__ = '0.0.1+2c3a468'
-git_version = '2c3a468de0a4a2e8c50d6dd7b41282fe98471206'
+__version__ = '0.0.1+6a95af1'
+git_version = '6a95af1cfc7efd07c7d103b3b3e5cd27148b128f'
 from nestedtensor import _C
 if hasattr(_C, 'CUDA_VERSION'):
     cuda = _C.CUDA_VERSION

test/test_coverage.py (new file, +47)

@@ -0,0 +1,47 @@
+import traceback
+import functools
+import pdb
+import sys
+import torch
+import nestedtensor
+import unittest
+import random
+from torch.nn import functional as F
+from torch import nn
+
+from utils import TestCase
+
+
+def ntnt(x): return nestedtensor.nested_tensor(x, requires_grad=True)
+def ntnt_nograd(x): return nestedtensor.nested_tensor(x)
+
+
+# Various smoke tests to confirm coverage of an operator
+
+class TestCoverage(TestCase):
+
+    def test_issues_313(self):
+        # Based on https://github.com/pytorch/nestedtensor/issues/313
+
+        def model(x):
+            torch.manual_seed(20)
+            linear = nn.Linear(9, 64)
+            norm = nn.BatchNorm1d(64)
+            # 3 voxels with 40, 50 and 90 points respectively
+            x = linear(x)
+            x = norm(x.transpose(2, 1).contiguous()
+                     ).transpose(2, 1).contiguous()
+            x = F.relu(x)
+            return torch.max(x, dim=1, keepdim=True)[0]
+
+        inputs = [torch.randn(i, 9) for i in [40, 50, 90]]
+        model(ntnt(inputs))
+
+        inputs = [torch.randn(30, 9) for _ in range(3)]
+        x0 = model(ntnt(inputs))
+        x1 = model(torch.stack(inputs))
+        self.assertEqual(torch.stack(x0.unbind()), x1)
+
+
+if __name__ == "__main__":
+    unittest.main()
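For orientation, a hedged shape walk-through of `model()` above on the ragged inputs; it simply re-runs the same steps outside the test harness, and the shape comments are illustrative (`None` marks the ragged point dimension):

```python
import torch
import nestedtensor
from torch import nn
from torch.nn import functional as F

inputs = [torch.randn(n, 9) for n in (40, 50, 90)]
x = nestedtensor.nested_tensor(inputs, requires_grad=True)  # (3, None, 9)

x = nn.Linear(9, 64)(x)                     # (3, None, 64)
x = x.transpose(2, 1).contiguous()          # (3, 64, None)
x = nn.BatchNorm1d(64)(x)                   # normalized over the 64 channels
x = x.transpose(2, 1).contiguous()          # (3, None, 64)
x = F.relu(x)
out = torch.max(x, dim=1, keepdim=True)[0]  # (3, 1, 64), via NestedTensor_max_dim
```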

test/test_nested_tensor_reduce.py (+9, -2)

@@ -30,13 +30,13 @@ def _flatten_nt(nt):
 
 class TestReduce(TestCase):
 
-    def _test_reduce_dim(self, fn, associative=True, test_keep_dim=True):
+    def _test_reduce_dim(self, fn, associative=True, test_keep_dim=True, test_multi_dim=True):
         t0 = torch.arange(9).float().reshape(3, 3)
         t1 = torch.arange(6).float().reshape(2, 3)
         t2 = torch.arange(9).float().reshape(3, 3)
         ts = [[t0, t1], [t2, t1]]
         nt = ntnt(ts)
-        if associative:
+        if associative and test_multi_dim:
             t01 = fn(torch.stack([fn(t0, 0), fn(t1, 0)]), 0)
             t21 = fn(torch.stack([fn(t2, 0), fn(t1, 0)]), 0)
             t02 = fn(torch.stack([fn(t0, 0), fn(t2, 0)]), 0)
@@ -125,6 +125,13 @@ def test_sum_all(self):
     def test_sum_dim(self):
         self._test_reduce_dim(torch.sum, True)
 
+    def test_max_all(self):
+        self._test_allreduce(lambda x: x.max())
+
+    def test_max_dim(self):
+        self._test_reduce_dim(lambda x, dim, keepdim=False: x.max(
+            dim, keepdim)[0], True, test_multi_dim=False)
+
     def test_mean_all(self):
         self._test_allreduce(lambda x: x.mean())
 
