Commit 7232d34

jaybdub and sandeepkumar-skb authored
Sandeepkumar skb groupnorm plugin (NVIDIA-AI-IOT#437)
* added plugin for GroupNorm

Co-authored-by: sandeepkumar-skb <[email protected]>
1 parent 5abfd97 commit 7232d34

8 files changed: +380 −17
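
For context, the plugin delegates normalization to PyTorch's aten::group_norm at runtime. A minimal sketch of exercising the new converter end to end, assuming torch2trt was installed with the --plugins flag and a CUDA device is available (shapes mirror the tests added in this commit):

    import torch
    from torch2trt import torch2trt

    # GroupNorm module to convert
    model = torch.nn.GroupNorm(2, 10).cuda().eval()
    x = torch.randn(1, 10, 112, 112).cuda()

    # conversion dispatches to the plugin-backed converter registered below
    model_trt = torch2trt(model, [x])

    # compare against the original module
    print(torch.max(torch.abs(model(x) - model_trt(x))))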

CHANGELOG.md

+1

@@ -5,5 +5,6 @@
 ### Added
 
 - Added names for TensorRT layers
+- Added GroupNorm plugin which internally uses PyTorch aten::group_norm
 - Replaced Tensor.ndim references with len(tensor.shape) to support older pytorch versions
 - Added reduced precision documentation page

build.py

+1

@@ -5,6 +5,7 @@
 
 PLUGINS = [
     'interpolate',
+    'group_norm',
 ]
 
 BASE_FOLDER = 'torch2trt/converters'
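
The PLUGINS list in build.py feeds generation of the single plugins.cpp translation unit that setup.py now compiles (see the setup.py change below). The generation step itself is not part of this diff; a hypothetical sketch of the idea:

    # hypothetical amalgamation step, not taken from build.py
    PLUGINS = ['interpolate', 'group_norm']

    with open('torch2trt/plugins/plugins.cpp', 'w') as out:
        for name in PLUGINS:
            with open('torch2trt/plugins/%s.cpp' % name) as src:
                out.write(src.read() + '\n')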

setup.py

+2 −3

@@ -14,7 +14,7 @@ def trt_lib_dir():
 plugins_ext_module = CUDAExtension(
     name='plugins',
     sources=[
-        'torch2trt/plugins/interpolate.cpp'
+        'torch2trt/plugins/plugins.cpp'
     ],
     include_dirs=[
         trt_inc_dir()
@@ -29,8 +29,7 @@ def trt_lib_dir():
         'cxx': ['-DUSE_DEPRECATED_INTLIST'] if torch.__version__ < "1.5" else [],
         'nvcc': []
     }
-)
-
+)
 if '--plugins' in sys.argv:
     ext_modules.append(plugins_ext_module)
     sys.argv.remove('--plugins')
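
The plugins extension remains opt-in: it is built only when --plugins is passed to setup.py. Whether the build succeeded can be checked with the same import this commit uses in has_group_norm_plugin:

    # quick availability check; GroupNormPlugin is exposed by the compiled extension
    try:
        from torch2trt.plugins import GroupNormPlugin
        print('group_norm plugin available')
    except ImportError:
        print('plugins not built; reinstall with the --plugins flag')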

torch2trt/converters/__init__.py

+1

@@ -32,6 +32,7 @@
 from .identity import *
 from .instance_norm import *
 from .interpolate import *
+from .group_norm import *
 from .max import *
 from .max_pool2d import *
 from .mean import *
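
Importing the module here is what activates the converter: the @tensorrt_converter decorator runs at import time and records the function in a registry keyed by the patched method name, skipping registration when its enabled argument is false. Roughly, with names simplified (a sketch, not the torch2trt source):

    CONVERTERS = {}

    def tensorrt_converter(method, enabled=True):
        def register(converter):
            if enabled:
                CONVERTERS[method] = converter
            return converter
        return register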

torch2trt/converters/group_norm.py

+48

@@ -0,0 +1,48 @@
+import torch.nn as nn
+from torch2trt.torch2trt import *
+from torch2trt.module_test import add_module_test
+
+
+def has_group_norm_plugin():
+    try:
+        from torch2trt.plugins import GroupNormPlugin
+        return True
+    except ImportError:
+        return False
+
+
+def get_group_norm_plugin(num_groups, weight, bias, eps):
+    from torch2trt.plugins import GroupNormPlugin
+    PLUGIN_NAME = 'group_norm'
+    registry = trt.get_plugin_registry()
+    creator = [c for c in registry.plugin_creator_list if c.name == PLUGIN_NAME and c.plugin_namespace == 'torch2trt'][0]
+    torch2trt_plugin = GroupNormPlugin(num_groups=num_groups, weight=weight, bias=bias, eps=eps)
+    # round-trip through serialization so TensorRT owns a copy made by the registered creator
+    return creator.deserialize_plugin(PLUGIN_NAME, torch2trt_plugin.serializeToString())
+
+
+@tensorrt_converter('torch.nn.GroupNorm.forward', has_group_norm_plugin())
+def convert_group_norm_trt(ctx):
+    module = ctx.method_args[0]
+    input = ctx.method_args[1]
+    num_groups = module.num_groups
+    weight = module.weight
+    bias = module.bias
+    eps = module.eps
+    input_trt = add_missing_trt_tensors(ctx.network, [input])
+    output = ctx.method_return
+    plugin = get_group_norm_plugin(num_groups, weight, bias, eps)
+
+    layer = ctx.network.add_plugin_v2(input_trt, plugin)
+
+    output._trt = layer.get_output(0)
+
+
+@add_module_test(torch.float32, torch.device('cuda'), [(1, 10, 112, 112)], has_group_norm_plugin())
+def test_group_norm_trt_g2_fp32():
+    return torch.nn.GroupNorm(2, 10)
+
+
+@add_module_test(torch.float32, torch.device('cuda'), [(1, 10, 112, 112)], has_group_norm_plugin())
+def test_group_norm_trt_g2_eps_fp32():
+    return torch.nn.GroupNorm(2, 10, eps=1e-4)
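
Because the plugin simply calls aten::group_norm, the converted layer should match PyTorch's own GroupNorm numerics. That equivalence can be sanity-checked in pure PyTorch (torch.group_norm is the Python binding of aten::group_norm):

    import torch

    gn = torch.nn.GroupNorm(2, 10, eps=1e-4).cuda()
    x = torch.randn(1, 10, 112, 112, device='cuda')

    y_aten = torch.group_norm(x, gn.num_groups, gn.weight, gn.bias, gn.eps)
    y_module = gn(x)
    print(torch.allclose(y_aten, y_module))  # expected: True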

torch2trt/plugins/group_norm.cpp

+296

@@ -0,0 +1,296 @@
+#include <torch/extension.h>
+#include <torch/script.h>
+#include <iostream>
+#include <string>
+#include <sstream>
+#include <NvInfer.h>
+#include <ATen/ATen.h>
+#include <ATen/cuda/CUDAEvent.h>
+#include <torch/torch.h>
+#include <cuda_runtime_api.h>
+
+using namespace nvinfer1;
+
+namespace torch2trt {
+
+class GroupNormPlugin : public IPluginV2 {
+private:
+  // configured by class
+  at::TensorOptions tensor_options;
+  std::vector<int64_t> input_sizes;
+  std::vector<int64_t> output_sizes;
+  DataType dtype;
+
+  // group norm parameters, configured by user
+  int64_t num_groups;
+  at::Tensor weight;
+  at::Tensor bias;
+  double eps;
+
+public:
+
+  // create from arguments
+  GroupNormPlugin(int64_t num_groups, at::Tensor weight, at::Tensor bias, double eps) :
+    num_groups{num_groups}, weight{weight}, bias{bias}, eps{eps}
+  {}
+
+  GroupNormPlugin(const char *data, size_t length) : GroupNormPlugin(std::string(data, length)) {}
+
+  GroupNormPlugin(const std::string &data) {
+    deserializeFromString(data);
+  }
+
+  void deserializeFromString(const std::string &data) {
+    std::istringstream data_stream(data);
+    torch::serialize::InputArchive input_archive;
+    input_archive.load_from(data_stream);
+    {
+      torch::IValue value;
+      input_archive.read("num_groups", value);
+      // scalar field; the deprecated-IntList path applies only to the size vectors below
+      num_groups = value.toInt();
+    }
+    {
+      torch::IValue value;
+      input_archive.read("weight", value);
+      weight = value.toTensor();
+    }
+    {
+      torch::IValue value;
+      input_archive.read("bias", value);
+      bias = value.toTensor();
+    }
+    {
+      torch::IValue value;
+      input_archive.read("eps", value);
+      eps = value.toDouble();
+    }
+    {
+      torch::IValue value;
+      input_archive.read("dtype", value);
+      dtype = (DataType) value.toInt();
+    }
+    {
+      torch::IValue value;
+      input_archive.read("input_sizes", value);
+#ifdef USE_DEPRECATED_INTLIST
+      input_sizes = value.toIntListRef().vec();
+#else
+      input_sizes = value.toIntVector();
+#endif
+    }
+    {
+      torch::IValue value;
+      input_archive.read("output_sizes", value);
+#ifdef USE_DEPRECATED_INTLIST
+      output_sizes = value.toIntListRef().vec();
+#else
+      output_sizes = value.toIntVector();
+#endif
+    }
+  }
+
+  std::string serializeToString() const {
+    torch::serialize::OutputArchive output_archive;
+    output_archive.write("num_groups", torch::IValue(num_groups));
+    output_archive.write("weight", torch::IValue(weight));
+    output_archive.write("bias", torch::IValue(bias));
+    output_archive.write("eps", torch::IValue(eps));
+    output_archive.write("dtype", torch::IValue((int) dtype));
+    output_archive.write("input_sizes", torch::IValue(input_sizes));
+    output_archive.write("output_sizes", torch::IValue(output_sizes));
+    std::ostringstream data_str;
+    output_archive.save_to(data_str);
+    return data_str.str();
+  }
+
+  const char* getPluginType() const override {
+    return "group_norm";
+  }
+
+  const char* getPluginVersion() const override {
+    return "1";
+  }
+
+  int getNbOutputs() const override {
+    return 1;
+  }
+
+  Dims getOutputDimensions(int index, const Dims* inputs, int nbInputDims) override {
+    // group norm is shape-preserving: output dims match input dims
+    Dims dims;
+    dims.nbDims = inputs->nbDims;
+    for (int i = 0; i < inputs->nbDims; i++) {
+      dims.d[i] = inputs->d[i];
+    }
+    return dims;
+  }
+
+  bool supportsFormat(DataType type, PluginFormat format) const override {
+    if (format != PluginFormat::kNCHW) {
+      return false;
+    }
+    if (type == DataType::kINT32 || type == DataType::kINT8) {
+      return false;
+    }
+    return true;
+  }
+
+  void configureWithFormat(const Dims* inputDims, int nbInputs, const Dims* outputDims,
+      int nbOutputs, DataType type, PluginFormat format, int maxBatchSize) override {
+
+    // set data type
+    if (type == DataType::kFLOAT) {
+      tensor_options = tensor_options.dtype(c10::kFloat);
+      dtype = type;
+    } else if (type == DataType::kHALF) {
+      tensor_options = tensor_options.dtype(c10::kHalf);
+      dtype = type;
+    }
+
+    // set input sizes
+    input_sizes.resize(inputDims[0].nbDims);
+    for (int i = 0; i < inputDims[0].nbDims; i++) {
+      input_sizes[i] = inputDims[0].d[i];
+    }
+
+    // set output sizes
+    output_sizes.resize(outputDims[0].nbDims);
+    for (int i = 0; i < outputDims[0].nbDims; i++) {
+      output_sizes[i] = outputDims[0].d[i];
+    }
+  }
+
+  int initialize() override {
+    // set device
+    tensor_options = tensor_options.device(c10::kCUDA);
+
+    // set data type
+    if (dtype == DataType::kFLOAT) {
+      tensor_options = tensor_options.dtype(c10::kFloat);
+    } else if (dtype == DataType::kHALF) {
+      tensor_options = tensor_options.dtype(c10::kHalf);
+    }
+
+    // move affine parameters onto the GPU with the configured dtype
+    weight = weight.to(tensor_options);
+    bias = bias.to(tensor_options);
+
+    return 0;
+  }
+
+  void terminate() override {}
+
+  size_t getWorkspaceSize(int maxBatchSize) const override { return 0; }
+
+  int enqueue(int batchSize, const void* const* inputs, void** outputs, void* workspace, cudaStream_t stream) override {
+    // get input / output dimensions, prepending the runtime batch size
+    std::vector<long> batch_input_sizes = input_sizes;
+    std::vector<long> batch_output_sizes = output_sizes;
+    batch_input_sizes.insert(batch_input_sizes.begin(), batchSize);
+    batch_output_sizes.insert(batch_output_sizes.begin(), batchSize);
+
+    // create tensor wrappers around the raw TensorRT buffers (no copy, no-op deleter)
+    at::Tensor input = at::from_blob((void*) inputs[0], batch_input_sizes, [](void*){}, tensor_options);
+    at::Tensor output = at::from_blob(outputs[0], batch_output_sizes, [](void*){}, tensor_options);
+
+    // create new torch cuda stream
+    at::cuda::CUDAStream torch_stream = at::cuda::getStreamFromPool();
+    at::cuda::CUDAStreamGuard torch_guard(torch_stream);
+
+    // capture current work on tensorrt cuda stream
+    cudaEvent_t event;
+    cudaEventCreate(&event);
+    cudaEventRecord(event, stream);
+
+    // make torch cuda stream wait on tensorrt work
+    cudaStreamWaitEvent(torch_stream.stream(), event, 0);
+
+    // enqueue work
+    // group_norm function from PyTorch: https://pytorch.org/cppdocs/api/function_namespaceat_1a6bc1e9504ea440c6c96ff8a8b94333f2.html#exhale-function-namespaceat-1a6bc1e9504ea440c6c96ff8a8b94333f2
+    at::Tensor output_tmp = at::group_norm(input, num_groups, weight, bias, eps);
+    output.copy_(output_tmp);
+
+    // capture event on enqueued stream and make tensorrt stream wait on torch work
+    cudaEvent_t torch_event;
+    cudaEventCreate(&torch_event);
+    cudaEventRecord(torch_event, torch_stream.stream());
+    cudaStreamWaitEvent(stream, torch_event, 0);
+
+    cudaEventDestroy(event);
+    cudaEventDestroy(torch_event);
+
+    return 0;
+  }
+
+  size_t getSerializationSize() const override {
+    return serializeToString().size();
+  }
+
+  void serialize(void* buffer) const override {
+    std::string data = serializeToString();
+    size_t size = getSerializationSize();
+    data.copy((char *) buffer, size);
+  }
+
+  void destroy() override {}
+
+  IPluginV2* clone() const override {
+    return new GroupNormPlugin(num_groups, weight, bias, eps);
+  }
+
+  void setPluginNamespace(const char* pluginNamespace) override {}
+
+  const char *getPluginNamespace() const override {
+    return "torch2trt";
+  }
+
+};
+
+class GroupNormPluginCreator : public IPluginCreator {
+public:
+  GroupNormPluginCreator() {}
+
+  const char *getPluginNamespace() const override {
+    return "torch2trt";
+  }
+
+  const char *getPluginName() const override {
+    return "group_norm";
+  }
+
+  const char *getPluginVersion() const override {
+    return "1";
+  }
+
+  IPluginV2 *deserializePlugin(const char *name, const void *data, size_t length) override {
+    return new GroupNormPlugin((const char*) data, length);
+  }
+
+  void setPluginNamespace(const char *N) override {}
+
+  const PluginFieldCollection *getFieldNames() override { return nullptr; }
+
+  IPluginV2 *createPlugin(const char *name, const PluginFieldCollection *fc) override { return nullptr; }
+
+};
+
+REGISTER_TENSORRT_PLUGIN(GroupNormPluginCreator);
+
+} // namespace torch2trt
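
The interesting part of the plugin is enqueue: it wraps TensorRT's raw device buffers with at::from_blob (no copy), runs aten::group_norm on a stream from PyTorch's pool, and uses CUDA events so the torch stream waits for TensorRT's pending work and TensorRT's stream waits for the group_norm result. The same ordering pattern expressed with PyTorch's Python stream API, purely for illustration (stream and event names here are ours, not the plugin's):

    import torch

    trt_stream = torch.cuda.Stream()    # stands in for TensorRT's enqueue stream
    torch_stream = torch.cuda.Stream()  # stands in for the pool stream the plugin grabs

    x = torch.randn(1, 10, 112, 112, device='cuda')
    gn = torch.nn.GroupNorm(2, 10).cuda()

    # capture work already queued on the "TensorRT" stream
    ready = torch.cuda.Event()
    ready.record(trt_stream)

    torch_stream.wait_event(ready)      # group_norm starts after TensorRT's work
    with torch.cuda.stream(torch_stream):
        y = torch.group_norm(x, gn.num_groups, gn.weight, gn.bias, gn.eps)
        done = torch.cuda.Event()
        done.record(torch_stream)

    trt_stream.wait_event(done)         # TensorRT resumes after group_norm finishes
    torch.cuda.synchronize()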
