diff --git a/docs/qonnx-custom-ops/trunc_op.md b/docs/qonnx-custom-ops/trunc_op.md
index 1b5f0d04..51b5e3a4 100644
--- a/docs/qonnx-custom-ops/trunc_op.md
+++ b/docs/qonnx-custom-ops/trunc_op.md
@@ -6,13 +6,20 @@ The attribute rounding_mode defines how truncated values are rounded.
#### Version
-This operator is not part of the ONNX standard and is not currently versioned.
+This operator is not part of the ONNX standard.
+The description of this operator in this document corresponds to `qonnx.custom_op.general` opset version 2.
#### Attributes
- rounding_mode : string (default is "FLOOR")
- Defines how rounding should be applied during truncation. Currently available modes are: "ROUND", "CEIL" and "FLOOR". Here "ROUND" implies a round-to-even operation. Lowercase variants for the rounding mode string are also supported: "round", "ceil", "floor".
+- signed : int (default is 1)
+- Defines if the quantization includes a sign bit. E.g. at 8b unsigned=[0, 255] vs signed=[-128, 127].
+- narrow : int (default is 0)
+- Defines if the value range should be interpreted as narrow, when signed=1. E.g. at 8b regular=[-128, 127] vs narrow=[-127, 127].
+- output_scale : float32 (default is -1.0)
+- The scale factor of the output as a scalar. The output scale must represent a shift with respect to the input scale (i.e., the `scale` input) and therefore must be the input scale multiplied by a power of 2. If output_scale is less than or equal to 0, it is calculated as 2 ** (in_bitwidth - out_bitwidth) to approximately match the behaviour of qonnx.custom_op.general opset version 1.
#### Inputs
@@ -91,26 +98,32 @@ from __future__ import unicode_literals
import numpy as np
-def trunc(inp_tensor, scale, zeropt, input_bit_width, output_bit_width, rounding_mode):
- # Port of TruncIntQuant class from Brevitas: https://bit.ly/3wzIpTR
+def trunc(inp_tensor, scale, zeropt, input_bit_width, narrow, signed, output_scale, output_bit_width, rounding_mode):
# Scaling
y = inp_tensor / scale
y = y + zeropt
# Rounding
y = np.round(y)
- # Truncate
- trunc_bit_width = input_bit_width - output_bit_width
- trunc_scale = 2.0 ** trunc_bit_width
+ # Rescale
+ trunc_scale = 2 ** np.round(
+ np.log2(output_scale / scale)
+ ) # Trunc scale should be a power-of-two - ensure that is the case
y = y / trunc_scale
- # To int
+ # Clamping
+ min_int_val = min_int(signed, narrow, output_bit_width)
+ max_int_val = max_int(signed, narrow, output_bit_width)
+ y = np.where(y > max_int_val, max_int_val.astype(y.dtype), y)
+ y = np.where(y < min_int_val, min_int_val.astype(y.dtype), y)
+ # To int (truncate)
rounding_fx = resolve_rounding_mode(rounding_mode)
y = rounding_fx(y)
# Rescale
- y = y - zeropt
- y = y * scale
+ output_zeropt = zeropt / trunc_scale # Rescale zero-point
+ y = y - output_zeropt
+ y = y * output_scale
return y
diff --git a/src/qonnx/custom_op/general/trunc.py b/src/qonnx/custom_op/general/trunc.py
index 8e2eaa19..85cd1db6 100644
--- a/src/qonnx/custom_op/general/trunc.py
+++ b/src/qonnx/custom_op/general/trunc.py
@@ -31,10 +31,10 @@
from qonnx.core.datatype import DataType
from qonnx.custom_op.base import CustomOp
-from qonnx.custom_op.general.quant import resolve_rounding_mode
+from qonnx.custom_op.general.quant import max_int, min_int, resolve_rounding_mode
-def trunc(inp_tensor, scale, zeropt, input_bit_width, output_bit_width, rounding_mode):
+def trunc(inp_tensor, scale, zeropt, input_bit_width, narrow, signed, output_scale, output_bit_width, rounding_mode):
# Port of TruncIntQuant class from Brevitas: https://bit.ly/3wzIpTR
# Scaling
@@ -42,27 +42,34 @@ def trunc(inp_tensor, scale, zeropt, input_bit_width, output_bit_width, rounding
y = y + zeropt
# Rounding
y = np.round(y)
- # Truncate
- trunc_bit_width = input_bit_width - output_bit_width
- trunc_scale = 2.0**trunc_bit_width
+ # Rescale
+ trunc_scale = 2 ** np.round(
+ np.log2(output_scale / scale)
+ ) # Trunc scale should be a power-of-two - ensure that is the case
y = y / trunc_scale
- # To int
+ # Clamping
+ min_int_val = min_int(signed, narrow, output_bit_width)
+ max_int_val = max_int(signed, narrow, output_bit_width)
+ y = np.where(y > max_int_val, max_int_val.astype(y.dtype), y)
+ y = np.where(y < min_int_val, min_int_val.astype(y.dtype), y)
+ # To int (truncate)
rounding_fx = resolve_rounding_mode(rounding_mode)
y = rounding_fx(y)
# Rescale
- y = y - zeropt
- y = y * scale
+ output_zeropt = zeropt / trunc_scale # Rescale zero-point
+ y = y - output_zeropt
+ y = y * output_scale
return y
class Trunc(CustomOp):
- """Generic truncation operation for QONNX. Takes four inputs:
- - input tensor to truncate
- - the scale
- - the zero-point
+ """Generic truncation operation for QONNX. Takes four inputs:
+ - input tensor to truncate
+ - the scale
+ - the zero-point
- the truncation bit-width
The output is a tensor of the same shape as the input tensor, with truncated
@@ -73,6 +80,13 @@ def get_nodeattr_types(self):
return {
# The rounding mode, which is used for the trunc function
"rounding_mode": ("s", True, "FLOOR"),
+ "narrow": ("i", False, 0, {0, 1}),
+ "signed": ("i", False, 1, {0, 1}),
+ "output_scale": (
+ "f",
+ False,
+ -1.0,
+ ), # Invalid scale signifies that it needs to be computed from input/output bit_width
}
def make_shape_compatible_op(self, model):
@@ -93,8 +107,14 @@ def execute_node(self, context, graph):
output_bit_width = context[node.input[4]]
# save attributes
rounding_mode = self.get_nodeattr("rounding_mode")
+ narrow = self.get_nodeattr("narrow")
+ signed = self.get_nodeattr("signed")
+ output_scale = self.get_nodeattr("output_scale")
+ output_scale = 2 ** (input_bit_width - output_bit_width) if output_scale <= 0.0 else output_scale
# calculate output
- ret = trunc(inp_tensor, scale, zeropt, input_bit_width, output_bit_width, rounding_mode)
+ ret = trunc(
+ inp_tensor, scale, zeropt, input_bit_width, narrow, signed, output_scale, output_bit_width, rounding_mode
+ )
# set context according to output name
context[node.output[0]] = ret