Skip to content

Commit 5abfd97

Browse files
jaybdubalsrgv
andauthored
Add layer names (NVIDIA-AI-IOT#432)
* Auto-generate custom layer names * fixed layer name count key * updated changelog for adding layer names Co-authored-by: Alex Sergeev <[email protected]>
1 parent f087c03 commit 5abfd97

File tree

2 files changed

+37
-2
lines changed

2 files changed

+37
-2
lines changed

CHANGELOG.md

+1
Original file line numberDiff line numberDiff line change
@@ -4,5 +4,6 @@
44

55
### Added
66

7+
- Added names for TensorRT layers
78
- Replaced Tensor.ndim references with len(tensor.shape) to support older pytorch versions
89
- Added reduced precision documentation page

torch2trt/torch2trt.py

+36-2
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from copy import copy
44
import numpy as np
55
import io
6+
from collections import defaultdict
67

78
from .calibration import (
89
TensorBatchDataset,
@@ -326,10 +327,43 @@ def default_input_names(num_inputs):
326327

327328
def default_output_names(num_outputs):
328329
return ["output_%d" % i for i in range(num_outputs)]
329-
330+
331+
332+
class LayerNamingNetworkWrapper(object):
333+
def __init__(self, ctx, network):
334+
self._ctx = ctx
335+
self._network = network
336+
self._layer_counts = defaultdict(lambda: 0)
337+
338+
def _set_layer_name(self, layer):
339+
def arg_str(arg):
340+
if isinstance(arg, torch.Tensor):
341+
return "tensor(shape=%s, dtype=%s)" % (str(list(arg.shape)), str(arg.dtype))
342+
return str(arg)
343+
344+
self._layer_counts[layer.type.name] += 1
345+
args = [arg_str(arg) for arg in self._ctx.method_args]
346+
kwargs = ["%s=%s" % (key, arg_str(arg)) for key, arg in self._ctx.method_kwargs.items()]
347+
layer.name = "[%s #%d] %s(%s)" % (layer.type.name, self._layer_counts[layer.type.name],
348+
self._ctx.method_str, ", ".join(args + kwargs))
349+
350+
def __getattr__(self, name):
351+
attr = getattr(self._network, name)
352+
if callable(attr):
353+
def wrapper(*args, **kwargs):
354+
ret = attr(*args, **kwargs)
355+
if isinstance(ret, trt.ILayer):
356+
self._set_layer_name(ret)
357+
return ret
358+
359+
return wrapper
360+
else:
361+
return attr
362+
363+
330364
class ConversionContext(object):
331365
def __init__(self, network, converters=CONVERTERS):
332-
self.network = network
366+
self.network = LayerNamingNetworkWrapper(self, network)
333367
self.lock = False
334368
self.method_args = None
335369
self.method_kwargs = None

0 commit comments

Comments
 (0)