|
3 | 3 | from copy import copy
|
4 | 4 | import numpy as np
|
5 | 5 | import io
|
| 6 | +from collections import defaultdict |
6 | 7 |
|
7 | 8 | from .calibration import (
|
8 | 9 | TensorBatchDataset,
|
@@ -326,10 +327,43 @@ def default_input_names(num_inputs):
|
326 | 327 |
|
327 | 328 | def default_output_names(num_outputs):
|
328 | 329 | return ["output_%d" % i for i in range(num_outputs)]
|
329 |
| - |
| 330 | + |
| 331 | + |
| 332 | +class LayerNamingNetworkWrapper(object): |
| 333 | + def __init__(self, ctx, network): |
| 334 | + self._ctx = ctx |
| 335 | + self._network = network |
| 336 | + self._layer_counts = defaultdict(lambda: 0) |
| 337 | + |
| 338 | + def _set_layer_name(self, layer): |
| 339 | + def arg_str(arg): |
| 340 | + if isinstance(arg, torch.Tensor): |
| 341 | + return "tensor(shape=%s, dtype=%s)" % (str(list(arg.shape)), str(arg.dtype)) |
| 342 | + return str(arg) |
| 343 | + |
| 344 | + self._layer_counts[layer.type.name] += 1 |
| 345 | + args = [arg_str(arg) for arg in self._ctx.method_args] |
| 346 | + kwargs = ["%s=%s" % (key, arg_str(arg)) for key, arg in self._ctx.method_kwargs.items()] |
| 347 | + layer.name = "[%s #%d] %s(%s)" % (layer.type.name, self._layer_counts[layer.type.name], |
| 348 | + self._ctx.method_str, ", ".join(args + kwargs)) |
| 349 | + |
| 350 | + def __getattr__(self, name): |
| 351 | + attr = getattr(self._network, name) |
| 352 | + if callable(attr): |
| 353 | + def wrapper(*args, **kwargs): |
| 354 | + ret = attr(*args, **kwargs) |
| 355 | + if isinstance(ret, trt.ILayer): |
| 356 | + self._set_layer_name(ret) |
| 357 | + return ret |
| 358 | + |
| 359 | + return wrapper |
| 360 | + else: |
| 361 | + return attr |
| 362 | + |
| 363 | + |
330 | 364 | class ConversionContext(object):
|
331 | 365 | def __init__(self, network, converters=CONVERTERS):
|
332 |
| - self.network = network |
| 366 | + self.network = LayerNamingNetworkWrapper(self, network) |
333 | 367 | self.lock = False
|
334 | 368 | self.method_args = None
|
335 | 369 | self.method_kwargs = None
|
|
0 commit comments