diff --git a/docs/qonnx-custom-ops/trunc_op.md b/docs/qonnx-custom-ops/trunc_op.md index 1b5f0d04..51b5e3a4 100644 --- a/docs/qonnx-custom-ops/trunc_op.md +++ b/docs/qonnx-custom-ops/trunc_op.md @@ -6,13 +6,20 @@ The attribute rounding_mode defines how truncated values are rounded. #### Version -This operator is not part of the ONNX standard and is not currently versioned. +This operator is not part of the ONNX standard. +The description of this operator in this document corresponds to `qonnx.custom_ops.general` opset version 2. #### Attributes
rounding_mode : string (default is "FLOOR")
Defines how rounding should be applied during truncation. Currently available modes are: "ROUND", "CEIL" and "FLOOR". Here "ROUND" implies a round-to-even operation. Lowercase variants for the rounding mode string are also supported: "round", "ceil", "floor".
+
signed : int (default is 1)
+
Defines whether the quantization includes a sign bit. E.g. at 8 bits: unsigned=[0, 255] vs. signed=[-128, 127].
+
narrow : int (default is 0)
+
Defines whether the value range should be interpreted as narrow when signed=1. E.g. at 8 bits: regular=[-128, 127] vs. narrow=[-127, 127].
+
output_scale : float32 (default is -1.0)
+
The scale factor of the output as a scalar. The output scale must represent a shift with respect to the input scale (i.e., the scale input) and therefore must be the input scale multiplied by a power-of-2. If output_scale is less than or equal to 0, it is calculated as 2 ** (in_bitwidth - out_bitwidth) to approximately match the behaviour of qonnx.custom_op.general opset version 1.
#### Inputs @@ -91,26 +98,32 @@ from __future__ import unicode_literals import numpy as np -def trunc(inp_tensor, scale, zeropt, input_bit_width, output_bit_width, rounding_mode): - # Port of TruncIntQuant class from Brevitas: https://bit.ly/3wzIpTR +def trunc(inp_tensor, scale, zeropt, input_bit_width, narrow, signed, output_scale, output_bit_width, rounding_mode): # Scaling y = inp_tensor / scale y = y + zeropt # Rounding y = np.round(y) - # Truncate - trunc_bit_width = input_bit_width - output_bit_width - trunc_scale = 2.0 ** trunc_bit_width + # Rescale + trunc_scale = 2 ** np.round( + np.log2(output_scale / scale) + ) # Trunc scale should be a power-of-two - ensure that is the case y = y / trunc_scale - # To int + # Clamping + min_int_val = min_int(signed, narrow, output_bit_width) + max_int_val = max_int(signed, narrow, output_bit_width) + y = np.where(y > max_int_val, max_int_val.astype(y.dtype), y) + y = np.where(y < min_int_val, min_int_val.astype(y.dtype), y) + # To int (truncate) rounding_fx = resolve_rounding_mode(rounding_mode) y = rounding_fx(y) # Rescale - y = y - zeropt - y = y * scale + output_zeropt = zeropt / trunc_scale # Rescale zero-point + y = y - output_zeropt + y = y * output_scale return y diff --git a/src/qonnx/custom_op/general/trunc.py b/src/qonnx/custom_op/general/trunc.py index 8e2eaa19..85cd1db6 100644 --- a/src/qonnx/custom_op/general/trunc.py +++ b/src/qonnx/custom_op/general/trunc.py @@ -31,10 +31,10 @@ from qonnx.core.datatype import DataType from qonnx.custom_op.base import CustomOp -from qonnx.custom_op.general.quant import resolve_rounding_mode +from qonnx.custom_op.general.quant import max_int, min_int, resolve_rounding_mode -def trunc(inp_tensor, scale, zeropt, input_bit_width, output_bit_width, rounding_mode): +def trunc(inp_tensor, scale, zeropt, input_bit_width, narrow, signed, output_scale, output_bit_width, rounding_mode): # Port of TruncIntQuant class from Brevitas: https://bit.ly/3wzIpTR # Scaling @@ -42,27 +42,34 @@ 
def trunc(inp_tensor, scale, zeropt, input_bit_width, output_bit_width, rounding y = y + zeropt # Rounding y = np.round(y) - # Truncate - trunc_bit_width = input_bit_width - output_bit_width - trunc_scale = 2.0**trunc_bit_width + # Rescale + trunc_scale = 2 ** np.round( + np.log2(output_scale / scale) + ) # Trunc scale should be a power-of-two - ensure that is the case y = y / trunc_scale - # To int + # Clamping + min_int_val = min_int(signed, narrow, output_bit_width) + max_int_val = max_int(signed, narrow, output_bit_width) + y = np.where(y > max_int_val, max_int_val.astype(y.dtype), y) + y = np.where(y < min_int_val, min_int_val.astype(y.dtype), y) + # To int (truncate) rounding_fx = resolve_rounding_mode(rounding_mode) y = rounding_fx(y) # Rescale - y = y - zeropt - y = y * scale + output_zeropt = zeropt / trunc_scale # Rescale zero-point + y = y - output_zeropt + y = y * output_scale return y class Trunc(CustomOp): - """Generic truncation operation for QONNX. Takes four inputs: - - input tensor to truncate - - the scale - - the zero-point + """Generic truncation operation for QONNX. 
Takes four inputs: + - input tensor to truncate + - the scale + - the zero-point - the truncation bit-width The output is a tensor of the same shape as the input tensor, with truncated @@ -73,6 +80,13 @@ def get_nodeattr_types(self): return { # The rounding mode, which is used for the trunc function "rounding_mode": ("s", True, "FLOOR"), + "narrow": ("i", False, 0, {0, 1}), + "signed": ("i", False, 1, {0, 1}), + "output_scale": ( + "f", + False, + -1.0, + ), # Invalid scale signifies that it needs to be computed from input/output bit_width } def make_shape_compatible_op(self, model): @@ -93,8 +107,14 @@ def execute_node(self, context, graph): output_bit_width = context[node.input[4]] # save attributes rounding_mode = self.get_nodeattr("rounding_mode") + narrow = self.get_nodeattr("narrow") + signed = self.get_nodeattr("signed") + output_scale = self.get_nodeattr("output_scale") + output_scale = 2 ** (input_bit_width - output_bit_width) if output_scale <= 0.0 else output_scale # calculate output - ret = trunc(inp_tensor, scale, zeropt, input_bit_width, output_bit_width, rounding_mode) + ret = trunc( + inp_tensor, scale, zeropt, input_bit_width, narrow, signed, output_scale, output_bit_width, rounding_mode + ) # set context according to output name context[node.output[0]] = ret