[trunc/quant_avg_pool] Update Trunc and QuantAveragePool to match how Brevitas Ops work #170

Open: wants to merge 4 commits into base: main
31 changes: 22 additions & 9 deletions docs/qonnx-custom-ops/trunc_op.md
@@ -6,13 +6,20 @@ The attribute rounding_mode defines how truncated values are rounded.

#### Version

This operator is not part of the ONNX standard and is not currently versioned.
This operator is not part of the ONNX standard.
The description of this operator in this document corresponds to `qonnx.custom_ops.general` opset version 2.

#### Attributes

<dl>
<dt><tt>rounding_mode</tt> : string (default is "FLOOR")</dt>
<dd>Defines how rounding should be applied during truncation. Currently available modes are: "ROUND", "CEIL" and "FLOOR". Here "ROUND" implies a round-to-even operation. Lowercase variants for the rounding mode string are also supported: "round", "ceil", "floor".</dd>
<dt><tt>signed</tt> : int (default is 1)</dt>
<dd>Defines whether the quantization includes a sign bit. E.g. at 8b unsigned=[0, 255] vs signed=[-128, 127].</dd>
<dt><tt>narrow</tt> : int (default is 0)</dt>
<dd>Defines whether the value range is interpreted as narrow when signed=1. E.g. at 8b regular=[-128, 127] vs narrow=[-127, 127].</dd>
<dt><tt>output_scale</tt> : float32 (default is -1.0)</dt>
<dd>The scale factor of the output, as a scalar. The output scale must represent a shift with respect to the input scale (i.e., the scale input) and must therefore be the input scale multiplied by a power of 2. If output_scale is less than or equal to 0, it is calculated as 2 ** (in_bitwidth - out_bitwidth) to approximately match the behaviour of qonnx.custom_ops.general opset version 1.</dd>
</dl>
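
The integer ranges and the `output_scale` fallback described above can be sketched as follows. This is a minimal illustration, not the library code: `int_range` and `resolve_output_scale` are hypothetical helper names, and only the cases spelled out in the attribute descriptions are covered (the combination signed=0, narrow=1 is not described there, so the sketch rejects it).

```python
def int_range(bit_width, signed=1, narrow=0):
    """Return (min, max) of the output integer grid.

    Hypothetical helper, assuming the ranges quoted in the attribute
    descriptions (e.g. 8b signed -> [-128, 127], narrow -> [-127, 127]).
    """
    if signed:
        max_val = 2 ** (bit_width - 1) - 1
        # A narrow range drops the most negative value to make it symmetric.
        min_val = -max_val if narrow else -(2 ** (bit_width - 1))
        return min_val, max_val
    if narrow:
        # Not described in the attribute text, so not covered by this sketch.
        raise ValueError("narrow is only illustrated for signed=1")
    return 0, 2 ** bit_width - 1


def resolve_output_scale(output_scale, in_bitwidth, out_bitwidth):
    """Apply the documented fallback for a non-positive output_scale."""
    if output_scale <= 0:
        return 2.0 ** (in_bitwidth - out_bitwidth)
    return output_scale


print(int_range(8, signed=0))             # (0, 255)
print(int_range(8, signed=1, narrow=0))   # (-128, 127)
print(int_range(8, signed=1, narrow=1))   # (-127, 127)
print(resolve_output_scale(-1.0, 8, 4))   # 16.0
```
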

#### Inputs
@@ -91,26 +98,32 @@ from __future__ import unicode_literals

import numpy as np

def trunc(inp_tensor, scale, zeropt, input_bit_width, output_bit_width, rounding_mode):
# Port of TruncIntQuant class from Brevitas: https://bit.ly/3wzIpTR
def trunc(inp_tensor, scale, zeropt, input_bit_width, narrow, signed, output_scale, output_bit_width, rounding_mode):

# Scaling
y = inp_tensor / scale
y = y + zeropt
# Rounding
y = np.round(y)
# Truncate
trunc_bit_width = input_bit_width - output_bit_width
trunc_scale = 2.0 ** trunc_bit_width
# Rescale
trunc_scale = 2 ** np.round(
np.log2(output_scale / scale)
) # Trunc scale should be a power-of-two - ensure that is the case
y = y / trunc_scale

# To int
# Clamping
min_int_val = min_int(signed, narrow, output_bit_width)
max_int_val = max_int(signed, narrow, output_bit_width)
y = np.where(y > max_int_val, max_int_val.astype(y.dtype), y)
y = np.where(y < min_int_val, min_int_val.astype(y.dtype), y)
# To int (truncate)
rounding_fx = resolve_rounding_mode(rounding_mode)
y = rounding_fx(y)

# Rescale
y = y - zeropt
y = y * scale
output_zeropt = zeropt / trunc_scale # Rescale zero-point
y = y - output_zeropt
y = y * output_scale

return y

46 changes: 33 additions & 13 deletions src/qonnx/custom_op/general/trunc.py
@@ -31,38 +31,45 @@

from qonnx.core.datatype import DataType
from qonnx.custom_op.base import CustomOp
from qonnx.custom_op.general.quant import resolve_rounding_mode
from qonnx.custom_op.general.quant import max_int, min_int, resolve_rounding_mode


def trunc(inp_tensor, scale, zeropt, input_bit_width, output_bit_width, rounding_mode):
def trunc(inp_tensor, scale, zeropt, input_bit_width, narrow, signed, output_scale, output_bit_width, rounding_mode):
# Port of TruncIntQuant class from Brevitas: https://bit.ly/3wzIpTR

# Scaling
y = inp_tensor / scale
y = y + zeropt
# Rounding
y = np.round(y)
# Truncate
trunc_bit_width = input_bit_width - output_bit_width
trunc_scale = 2.0**trunc_bit_width
# Rescale
trunc_scale = 2 ** np.round(
np.log2(output_scale / scale)
) # Trunc scale should be a power-of-two - ensure that is the case
y = y / trunc_scale

# To int
# Clamping
min_int_val = min_int(signed, narrow, output_bit_width)
max_int_val = max_int(signed, narrow, output_bit_width)
y = np.where(y > max_int_val, max_int_val.astype(y.dtype), y)
y = np.where(y < min_int_val, min_int_val.astype(y.dtype), y)
# To int (truncate)
rounding_fx = resolve_rounding_mode(rounding_mode)
y = rounding_fx(y)

# Rescale
y = y - zeropt
y = y * scale
output_zeropt = zeropt / trunc_scale # Rescale zero-point
y = y - output_zeropt
y = y * output_scale

return y


class Trunc(CustomOp):
"""Generic truncation operation for QONNX. Takes four inputs:
- input tensor to truncate
- the scale
- the zero-point
"""Generic truncation operation for QONNX. Takes four inputs:
- input tensor to truncate
- the scale
- the zero-point
- the truncation bit-width

The output is a tensor of the same shape as the input tensor, with truncated
@@ -73,6 +80,13 @@ def get_nodeattr_types(self):
return {
# The rounding mode, which is used for the trunc function
"rounding_mode": ("s", True, "FLOOR"),
"narrow": ("i", False, 0, {0, 1}),
"signed": ("i", False, 1, {0, 1}),
"output_scale": (
"f",
False,
-1.0,
), # Invalid scale signifies that it needs to be computed from input/output bit_width
}

def make_shape_compatible_op(self, model):
@@ -93,8 +107,14 @@ def execute_node(self, context, graph):
output_bit_width = context[node.input[4]]
# save attributes
rounding_mode = self.get_nodeattr("rounding_mode")
narrow = self.get_nodeattr("narrow")
signed = self.get_nodeattr("signed")
output_scale = self.get_nodeattr("output_scale")
output_scale = 2 ** (input_bit_width - output_bit_width) if output_scale <= 0.0 else output_scale
# calculate output
ret = trunc(inp_tensor, scale, zeropt, input_bit_width, output_bit_width, rounding_mode)
ret = trunc(
inp_tensor, scale, zeropt, input_bit_width, narrow, signed, output_scale, output_bit_width, rounding_mode
)
# set context according to output name
context[node.output[0]] = ret
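
To make the order of operations in the updated `trunc` easier to follow, here is a scalar, pure-Python restatement of the same steps with a small worked example. It is a sketch under stated assumptions, not the code in this PR: `trunc_scalar` is a hypothetical name, the clamping bounds are computed inline rather than via `min_int`/`max_int`, and the rounding table stands in for `resolve_rounding_mode` (Python's built-in `round` rounds halves to even, matching the documented "ROUND" mode).

```python
import math

# Stand-in for resolve_rounding_mode (assumption: three documented modes only).
ROUNDING = {"FLOOR": math.floor, "CEIL": math.ceil, "ROUND": round}


def trunc_scalar(x, scale, zeropt, output_scale, output_bit_width,
                 signed=1, narrow=0, rounding_mode="FLOOR"):
    # Scaling: map the real value onto the input integer grid.
    y = round(x / scale + zeropt)
    # Rescale: snap the ratio of the two scales to a power of two.
    trunc_scale = 2.0 ** round(math.log2(output_scale / scale))
    y = y / trunc_scale
    # Clamping to the output integer range (signed/narrow as documented).
    if signed:
        max_int_val = 2 ** (output_bit_width - 1) - 1
        min_int_val = -max_int_val if narrow else -(2 ** (output_bit_width - 1))
    else:
        min_int_val, max_int_val = 0, 2 ** output_bit_width - 1
    y = min(max(y, min_int_val), max_int_val)
    # To int (truncate) with the selected rounding mode.
    y = ROUNDING[rounding_mode.upper()](y)
    # Rescale: shift by the rescaled zero-point, then apply the output scale.
    return (y - zeropt / trunc_scale) * output_scale


# Input grid 1/32, output grid 1/8 (a shift by 2 bits), 4-bit signed output.
# 0.71875 is integer 23 on the input grid; 23 / 4 = 5.75 on the output grid.
print(trunc_scalar(0.71875, 0.03125, 0.0, 0.125, 4))                         # 0.625
print(trunc_scalar(0.71875, 0.03125, 0.0, 0.125, 4, rounding_mode="round"))  # 0.75
# 1.25 is integer 40 -> 10.0 on the output grid, clamped to 7 before rounding.
print(trunc_scalar(1.25, 0.03125, 0.0, 0.125, 4))                            # 0.875
```

Note that clamping happens before the final rounding step, so a value that exceeds the output range saturates to the range limit regardless of the rounding mode.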
