From dad98696a28b3c2781f13c53b5c463e58f1df891 Mon Sep 17 00:00:00 2001
From: jvreca <jure.vreca@ijs.si>
Date: Fri, 22 Mar 2024 15:23:37 +0100
Subject: [PATCH 01/13] Added test, docs/, and updated resolve_rounding_mode
 function to return new rounding modes.

---
 docs/qonnx-custom-ops/quant_op.md    | 19 ++++++++++++++++++-
 src/qonnx/custom_op/general/quant.py | 16 +++++++++++++++-
 tests/custom_op/test_runding_mode.py | 20 ++++++++++++++++++++
 3 files changed, 53 insertions(+), 2 deletions(-)
 create mode 100644 tests/custom_op/test_runding_mode.py

diff --git a/docs/qonnx-custom-ops/quant_op.md b/docs/qonnx-custom-ops/quant_op.md
index 02d115fb..953fdca7 100644
--- a/docs/qonnx-custom-ops/quant_op.md
+++ b/docs/qonnx-custom-ops/quant_op.md
@@ -21,7 +21,7 @@ This operator is not part of the ONNX standard and is not currently versioned.
 <dt><tt>narrow</tt> : int (default is 0)</dt>
 <dd>Defines if the value range should be interpreted as narrow, when signed=1. E.g. at 8b regular=[-128, 127] vs narrow=[-127, 127].</dd>
 <dt><tt>rounding_mode</tt> : string (default is "ROUND")</dt>
-<dd>Defines how rounding should be applied during quantization. Currently available modes are: "ROUND", "CEIL" and "FLOOR". Here "ROUND" implies a round-to-even operation. Lowercase variants for the rounding mode string are also supported: "round", "ceil", "floor".</dd>
+<dd>Defines how rounding should be applied during quantization. Avaiable options are ROUND, CEIL, FLOOR, UP, DOWN, HALF_UP, HALF_DOWN. The rounding modes are described in the table bellow. The names of rounding modes can be upper case or lower case.</dd>
 </dl>
 
 #### Inputs
@@ -46,6 +46,23 @@ This operator is not part of the ONNX standard and is not currently versioned.
 </dl>
 
 
+#### Rounding modes
+<details>
+<summary>rounding modes</summary>
+| **Number \ ROUNDING_MODE** 	| ROUND=HALF_EVEN 	| CEIL 	| FLOOR 	| UP 	| DOWN 	| HALF_UP 	| HALF_DOWN 	|
+|----------------------------	|-----------------	|------	|-------	|----	|------	|---------	|-----------	|
+| 5.5                        	| 6               	| 6    	| 5     	| 6  	| 5    	| 6       	| 5         	|
+| 2.5                        	| 2               	| 3    	| 2     	| 3  	| 2    	| 3       	| 2         	|
+| 1.6                        	| 2               	| 2    	| 1     	| 2  	| 1    	| 2       	| 2         	|
+| 1.1                        	| 1               	| 2    	| 1     	| 2  	| 1    	| 1       	| 1         	|
+| 1.0                        	| 1               	| 1    	| 1     	| 1  	| 1    	| 1       	| 1         	|
+| -1.0                       	| -1              	| -1   	| -1    	| -1 	| -1   	| -1      	| -1        	|
+| -1.1                       	| -1              	| -1   	| -2    	| -2 	| -1   	| -1      	| -1        	|
+| -1.6                       	| -2              	| -1   	| -2    	| -2 	| -1   	| -2      	| -2        	|
+| -2.5                       	| -2              	| -2   	| -3    	| -3 	| -2   	| -3      	| -2        	|
+| -5.5                       	| -6              	| -5   	| -6    	| -6 	| -5   	| -6      	| -5        	|
+</details>
+
 #### Examples
 <details>
 <summary>Quant</summary>
diff --git a/src/qonnx/custom_op/general/quant.py b/src/qonnx/custom_op/general/quant.py
index f552e7a8..15afd048 100644
--- a/src/qonnx/custom_op/general/quant.py
+++ b/src/qonnx/custom_op/general/quant.py
@@ -135,12 +135,26 @@ def resolve_rounding_mode(mode_string):
     """Resolve the rounding mode string of Quant and Trunc ops
     to the corresponding numpy functions."""
     normalized_mode_string = mode_string.upper()
-    if normalized_mode_string == "ROUND":
+    if normalized_mode_string == "ROUND" or normalized_mode_string == "HALF_TO_EVEN":
         return np.round
     elif normalized_mode_string == "CEIL":
         return np.ceil
     elif normalized_mode_string == "FLOOR":
         return np.floor
+    elif normalized_mode_string == "UP":
+        def round_up(x):
+            return np.sign(x) * np.ceil(np.abs(x))
+        return round_up
+    elif normalized_mode_string == "DOWN":
+        return np.fix
+    elif normalized_mode_string == "HALF_UP":
+        def round_half_up(x):
+            return np.sign(x) * np.floor(np.abs(x) + 0.5)
+        return round_half_up
+    elif normalized_mode_string == "HALF_DOWN":
+        def round_half_down(x):
+            return np.sign(x) * np.ceil(np.abs(x) - 0.5)
+        return round_half_down
     else:
         raise ValueError(f"Could not resolve rounding mode called: {normalized_mode_string}")
 
diff --git a/tests/custom_op/test_runding_mode.py b/tests/custom_op/test_runding_mode.py
new file mode 100644
index 00000000..54a81f0e
--- /dev/null
+++ b/tests/custom_op/test_runding_mode.py
@@ -0,0 +1,20 @@
+import pytest
+
+import numpy as np
+
+from qonnx.custom_op.general.quant import resolve_rounding_mode
+
+@pytest.mark.parametrize("rmode,exp", [
+        ("ROUND", np.array([6, 2, 2, 1, 1, -1, -1, -2, -2, -6])),
+        ("CEIL", np.array([6, 3, 2, 2, 1, -1, -1, -1, -2, - 5])),
+        ("FLOOR", np.array([5, 2, 1, 1, 1, -1, -2, -2, -3, -6])),
+        ("UP", np.array([6, 3, 2, 2, 1, -1, -2, -2, -3, -6])),
+        ("DOWN", np.array([5, 2, 1, 1, 1, -1, -1, -1, -2, -5])),
+        ("HALF_UP", np.array([6, 3, 2, 1, 1, -1, -1, -2, -3, -6])),
+        ("HALF_DOWN", np.array([5, 2, 2, 1, 1, -1, -1, -2, -2, -5]))
+    ]
+)
+def test_rounding_modes(rmode, exp):
+    test_array = np.array([5.5, 2.5, 1.6, 1.1, 1.0, -1.0, -1.1, -1.6, -2.5, -5.5])
+    rounding_fn = resolve_rounding_mode(rmode)
+    assert np.array_equal(rounding_fn(test_array), exp)

From 47a88e4fb4d0bc69059297bbea39e69650f95d1f Mon Sep 17 00:00:00 2001
From: jvreca <jure.vreca@ijs.si>
Date: Fri, 22 Mar 2024 15:42:21 +0100
Subject: [PATCH 02/13] Fix table visualization.

---
 docs/qonnx-custom-ops/quant_op.md | 24 ++++++++++++------------
 1 file changed, 12 insertions(+), 12 deletions(-)

diff --git a/docs/qonnx-custom-ops/quant_op.md b/docs/qonnx-custom-ops/quant_op.md
index 953fdca7..b9e11c79 100644
--- a/docs/qonnx-custom-ops/quant_op.md
+++ b/docs/qonnx-custom-ops/quant_op.md
@@ -49,18 +49,18 @@ This operator is not part of the ONNX standard and is not currently versioned.
 #### Rounding modes
 <details>
 <summary>rounding modes</summary>
-| **Number \ ROUNDING_MODE** 	| ROUND=HALF_EVEN 	| CEIL 	| FLOOR 	| UP 	| DOWN 	| HALF_UP 	| HALF_DOWN 	|
-|----------------------------	|-----------------	|------	|-------	|----	|------	|---------	|-----------	|
-| 5.5                        	| 6               	| 6    	| 5     	| 6  	| 5    	| 6       	| 5         	|
-| 2.5                        	| 2               	| 3    	| 2     	| 3  	| 2    	| 3       	| 2         	|
-| 1.6                        	| 2               	| 2    	| 1     	| 2  	| 1    	| 2       	| 2         	|
-| 1.1                        	| 1               	| 2    	| 1     	| 2  	| 1    	| 1       	| 1         	|
-| 1.0                        	| 1               	| 1    	| 1     	| 1  	| 1    	| 1       	| 1         	|
-| -1.0                       	| -1              	| -1   	| -1    	| -1 	| -1   	| -1      	| -1        	|
-| -1.1                       	| -1              	| -1   	| -2    	| -2 	| -1   	| -1      	| -1        	|
-| -1.6                       	| -2              	| -1   	| -2    	| -2 	| -1   	| -2      	| -2        	|
-| -2.5                       	| -2              	| -2   	| -3    	| -3 	| -2   	| -3      	| -2        	|
-| -5.5                       	| -6              	| -5   	| -6    	| -6 	| -5   	| -6      	| -5        	|
+| **Number \ ROUNDING_MODE** | ROUND=HALF_EVEN | CEIL | FLOOR | UP | DOWN | HALF_UP | HALF_DOWN |
+|----------------------------|-----------------|------|-------|----|------|---------|-----------|
+| 5.5                        | 6               | 6    | 5     | 6  | 5    | 6       | 5         |
+| 2.5                        | 2               | 3    | 2     | 3  | 2    | 3       | 2         |
+| 1.6                        | 2               | 2    | 1     | 2  | 1    | 2       | 2         |
+| 1.1                        | 1               | 2    | 1     | 2  | 1    | 1       | 1         |
+| 1.0                        | 1               | 1    | 1     | 1  | 1    | 1       | 1         |
+| -1.0                       | -1              | -1   | -1    | -1 | -1   | -1      | -1        |
+| -1.1                       | -1              | -1   | -2    | -2 | -1   | -1      | -1        |
+| -1.6                       | -2              | -1   | -2    | -2 | -1   | -2      | -2        |
+| -2.5                       | -2              | -2   | -3    | -3 | -2   | -3      | -2        |
+| -5.5                       | -6              | -5   | -6    | -6 | -5   | -6      | -5        |
 </details>
 
 #### Examples

From e2c15045d9ccf5f4c8162c1555c366c337616fee Mon Sep 17 00:00:00 2001
From: jvreca <jure.vreca@ijs.si>
Date: Fri, 22 Mar 2024 15:43:56 +0100
Subject: [PATCH 03/13] Fix table visualization again.

---
 docs/qonnx-custom-ops/quant_op.md | 1 +
 1 file changed, 1 insertion(+)

diff --git a/docs/qonnx-custom-ops/quant_op.md b/docs/qonnx-custom-ops/quant_op.md
index b9e11c79..68029406 100644
--- a/docs/qonnx-custom-ops/quant_op.md
+++ b/docs/qonnx-custom-ops/quant_op.md
@@ -49,6 +49,7 @@ This operator is not part of the ONNX standard and is not currently versioned.
 #### Rounding modes
 <details>
 <summary>rounding modes</summary>
+
 | **Number \ ROUNDING_MODE** | ROUND=HALF_EVEN | CEIL | FLOOR | UP | DOWN | HALF_UP | HALF_DOWN |
 |----------------------------|-----------------|------|-------|----|------|---------|-----------|
 | 5.5                        | 6               | 6    | 5     | 6  | 5    | 6       | 5         |

From baa0df36b48b69a515604a3f5f4c04a00bd0712f Mon Sep 17 00:00:00 2001
From: jvreca <jure.vreca@ijs.si>
Date: Fri, 9 Aug 2024 14:06:07 +0200
Subject: [PATCH 04/13] Fixed converter to allow alpha/scale to be a tensor

Fixed rounding_mode specifier in convert_quantized_bits
---
 src/qonnx/converters/qkeras/onnx.py       | 14 ++++++++++++--
 src/qonnx/converters/qkeras/quantizers.py | 14 +++++++++-----
 src/qonnx/custom_op/general/quant.py      |  2 +-
 3 files changed, 22 insertions(+), 8 deletions(-)

diff --git a/src/qonnx/converters/qkeras/onnx.py b/src/qonnx/converters/qkeras/onnx.py
index 1f34d653..3bb1fed9 100644
--- a/src/qonnx/converters/qkeras/onnx.py
+++ b/src/qonnx/converters/qkeras/onnx.py
@@ -55,7 +55,7 @@ def qlayer_handler(ctx, node, name, args):
     quantizers = all_quantizers[keras_name]
     if quantizers.get("kernel_quantizer"):
         weights = node.inputs[1].get_tensor_value(as_list=True)
-        quant_params = get_quant_params(weights, quantizers["kernel_quantizer"])
+        quant_params = get_quant_params(weights, quantizers["kernel_initializer"]['config']['quantizer'])
         attr = quant_params["attributes"]
         input_nodes = [node.input[1]]
         for key in quant_params["inputs"].keys():
@@ -63,9 +63,19 @@ def qlayer_handler(ctx, node, name, args):
             np_val = np.asarray(quant_params["inputs"][key])
             ctx.make_const(name, np_val)
             input_nodes.append(name)
-        ctx.insert_new_node_on_input(
+        quant_node = ctx.insert_new_node_on_input(
             node, "Quant", input_nodes, name=node.name + "_kernel_quantizer", **attr, domain="qonnx"
         )
+        scale_node = ctx.make_const(
+            name = node.name + "_kernel_scale",
+            np_val = quant_params['inputs']['scale'].astype(np.float32)
+        )
+        ctx.insert_new_node_on_output(
+            op_type = "Mul", 
+            output_name = quant_node.output[0],
+            name = node.name + "_kernel_requantizer",
+            inputs = [quant_node.output[0], scale_node.name]
+        )
 
     if quantizers.get("bias_quantizer") and len(node.input) == 3:
         bias = node.inputs[2].get_tensor_value(as_list=True)
diff --git a/src/qonnx/converters/qkeras/quantizers.py b/src/qonnx/converters/qkeras/quantizers.py
index 983cc997..c6a00a00 100644
--- a/src/qonnx/converters/qkeras/quantizers.py
+++ b/src/qonnx/converters/qkeras/quantizers.py
@@ -1,9 +1,9 @@
 import qkeras
 import six
-
+import numpy as np
 
 def get_quant_params(tensor, qkeras_quantizer):
-    if isinstance(qkeras_quantizer, str):
+    if isinstance(qkeras_quantizer, (str, dict)):
         qkeras_quantizer = qkeras.get_quantizer(qkeras_quantizer)
 
     return handler_map[qkeras_quantizer.__class__.__name__](tensor, qkeras_quantizer)
@@ -34,11 +34,15 @@ def convert_quantized_bits(tensor, quantizer):
     signed = int(config["keep_negative"])
     narrow = int(config["symmetric"])
     qscale = _get_quantizer_scale(tensor, quantizer)
-    assert qscale == 1, "Non-unity alpha is not yet supported"
-    scale = 1.0 / 2 ** (int(config["bits"]) - int(config["integer"] + signed))
+    if not isinstance(qscale, np.ndarray):
+        qscale = np.array(qscale) 
+    scale = qscale / 2 ** (int(config["bits"]) - int(config["integer"] + signed))
     zero_point = 0
     bit_width = int(config["bits"])
-    rounding_mode = "ROUND"
+    if config['alpha'] == "auto_po2":
+        rounding_mode = "ROUND_UP"
+    else:
+        rounding_mode = "HALF_EVEN"
 
     settings = {
         "attributes": {"signed": signed, "narrow": narrow, "rounding_mode": rounding_mode},
diff --git a/src/qonnx/custom_op/general/quant.py b/src/qonnx/custom_op/general/quant.py
index 15afd048..5af3f9f3 100644
--- a/src/qonnx/custom_op/general/quant.py
+++ b/src/qonnx/custom_op/general/quant.py
@@ -135,7 +135,7 @@ def resolve_rounding_mode(mode_string):
     """Resolve the rounding mode string of Quant and Trunc ops
     to the corresponding numpy functions."""
     normalized_mode_string = mode_string.upper()
-    if normalized_mode_string == "ROUND" or normalized_mode_string == "HALF_TO_EVEN":
+    if normalized_mode_string == "ROUND" or normalized_mode_string == "HALF_EVEN":
         return np.round
     elif normalized_mode_string == "CEIL":
         return np.ceil

From de9f73173705cb757daaa80a58044ed8f09a376d Mon Sep 17 00:00:00 2001
From: jvreca <jure.vreca@ijs.si>
Date: Fri, 9 Aug 2024 15:08:07 +0200
Subject: [PATCH 05/13] Added a check to see if tensor is representable by the
 quantization parameters.

---
 src/qonnx/converters/qkeras/onnx.py | 19 ++++++++++++++++++-
 1 file changed, 18 insertions(+), 1 deletion(-)

diff --git a/src/qonnx/converters/qkeras/onnx.py b/src/qonnx/converters/qkeras/onnx.py
index 3bb1fed9..cadb10c4 100644
--- a/src/qonnx/converters/qkeras/onnx.py
+++ b/src/qonnx/converters/qkeras/onnx.py
@@ -4,7 +4,7 @@
 from tf2onnx.onnx_opset.nn import BiasAdd, ConvOp
 
 from .quantizers import get_quant_params
-
+from qonnx.custom_op.general.quant import quant
 
 def get_qkeras_onnx_handlers(all_quantizers):
     """Returns the handlers for each kind of layer
@@ -58,6 +58,23 @@ def qlayer_handler(ctx, node, name, args):
         quant_params = get_quant_params(weights, quantizers["kernel_initializer"]['config']['quantizer'])
         attr = quant_params["attributes"]
         input_nodes = [node.input[1]]
+        qweights = quant(inp_tensor=np.array(weights), 
+                         scale=np.array(quant_params['inputs']['scale']),
+                         zeropt=np.array(quant_params['inputs']['zero_point']),
+                         bitwidth=np.array(quant_params['inputs']['bit_width']),
+                         signed=quant_params['attributes']['signed'],
+                         narrow=quant_params['attributes']['narrow'],
+                         rounding_mode=quant_params['attributes']['rounding_mode']
+                    )
+        assert np.array_equal(weights, qweights), f"""Weights of tensor {node.name} are not representable with the given quantization settings.
+                                                      The original weight tensor is: {np.array(weights)} and the quantized tensor is: {qweights}; 
+                                                      scale: {np.array(quant_params['inputs']['scale'])}, 
+                                                      zeropt: {np.array(quant_params['inputs']['zero_point'])}, 
+                                                      bitwidth: {np.array(quant_params['inputs']['bit_width'])},
+                                                      signed: {quant_params['attributes']['signed']},
+                                                      narrow: {quant_params['attributes']['narrow']},
+                                                      rounding_mode: {quant_params['attributes']['rounding_mode']}
+                                                      """
         for key in quant_params["inputs"].keys():
             name = f"{node.name}_kernel_quantizer_{key}"
             np_val = np.asarray(quant_params["inputs"][key])

From 72b994a923e7ea0e012de4f069af4b443ca60d83 Mon Sep 17 00:00:00 2001
From: jvreca <jure.vreca@ijs.si>
Date: Mon, 12 Aug 2024 11:17:57 +0200
Subject: [PATCH 06/13] Extra Mul node inserted only when neccessary

Commented out assertion on non-representability
---
 src/qonnx/converters/qkeras/onnx.py | 59 ++++++++++++++++-------------
 1 file changed, 32 insertions(+), 27 deletions(-)

diff --git a/src/qonnx/converters/qkeras/onnx.py b/src/qonnx/converters/qkeras/onnx.py
index cadb10c4..1444865f 100644
--- a/src/qonnx/converters/qkeras/onnx.py
+++ b/src/qonnx/converters/qkeras/onnx.py
@@ -58,23 +58,23 @@ def qlayer_handler(ctx, node, name, args):
         quant_params = get_quant_params(weights, quantizers["kernel_initializer"]['config']['quantizer'])
         attr = quant_params["attributes"]
         input_nodes = [node.input[1]]
-        qweights = quant(inp_tensor=np.array(weights), 
-                         scale=np.array(quant_params['inputs']['scale']),
-                         zeropt=np.array(quant_params['inputs']['zero_point']),
-                         bitwidth=np.array(quant_params['inputs']['bit_width']),
-                         signed=quant_params['attributes']['signed'],
-                         narrow=quant_params['attributes']['narrow'],
-                         rounding_mode=quant_params['attributes']['rounding_mode']
-                    )
-        assert np.array_equal(weights, qweights), f"""Weights of tensor {node.name} are not representable with the given quantization settings.
-                                                      The original weight tensor is: {np.array(weights)} and the quantized tensor is: {qweights}; 
-                                                      scale: {np.array(quant_params['inputs']['scale'])}, 
-                                                      zeropt: {np.array(quant_params['inputs']['zero_point'])}, 
-                                                      bitwidth: {np.array(quant_params['inputs']['bit_width'])},
-                                                      signed: {quant_params['attributes']['signed']},
-                                                      narrow: {quant_params['attributes']['narrow']},
-                                                      rounding_mode: {quant_params['attributes']['rounding_mode']}
-                                                      """
+        #qweights = quant(inp_tensor=np.array(weights), 
+        #                 scale=np.array(quant_params['inputs']['scale']),
+        #                 zeropt=np.array(quant_params['inputs']['zero_point']),
+        #                 bitwidth=np.array(quant_params['inputs']['bit_width']),
+        #                 signed=quant_params['attributes']['signed'],
+        #                 narrow=quant_params['attributes']['narrow'],
+        #                 rounding_mode=quant_params['attributes']['rounding_mode']
+        #            )
+        #assert np.array_equal(weights, qweights), f"""Weights of tensor {node.name} are not representable with the given quantization settings.
+        #                                              The original weight tensor is: {np.array(weights)} and the quantized tensor is: {qweights}; 
+        #                                              scale: {np.array(quant_params['inputs']['scale'])}, 
+        #                                              zeropt: {np.array(quant_params['inputs']['zero_point'])}, 
+        #                                              bitwidth: {np.array(quant_params['inputs']['bit_width'])},
+        #                                              signed: {quant_params['attributes']['signed']},
+        #                                              narrow: {quant_params['attributes']['narrow']},
+        #                                              rounding_mode: {quant_params['attributes']['rounding_mode']}
+        #                                              """
         for key in quant_params["inputs"].keys():
             name = f"{node.name}_kernel_quantizer_{key}"
             np_val = np.asarray(quant_params["inputs"][key])
@@ -83,16 +83,21 @@ def qlayer_handler(ctx, node, name, args):
         quant_node = ctx.insert_new_node_on_input(
             node, "Quant", input_nodes, name=node.name + "_kernel_quantizer", **attr, domain="qonnx"
         )
-        scale_node = ctx.make_const(
-            name = node.name + "_kernel_scale",
-            np_val = quant_params['inputs']['scale'].astype(np.float32)
-        )
-        ctx.insert_new_node_on_output(
-            op_type = "Mul", 
-            output_name = quant_node.output[0],
-            name = node.name + "_kernel_requantizer",
-            inputs = [quant_node.output[0], scale_node.name]
-        )
+        if quantizers["kernel_initializer"]['config']['quantizer']['class_name'] == 'quantized_bits':
+            bits = quantizers["kernel_initializer"]['config']['quantizer']['config']['bits']
+            integer = quantizers["kernel_initializer"]['config']['quantizer']['config']['integer']
+            keep_negative = quantizers["kernel_initializer"]['config']['quantizer']['config']['keep_negative']
+            if bits == integer + keep_negative:
+                scale_node = ctx.make_const(
+                    name = node.name + "_kernel_scale",
+                    np_val = quant_params['inputs']['scale'].astype(np.float32)
+                )
+                ctx.insert_new_node_on_output(
+                    op_type = "Mul", 
+                    output_name = quant_node.output[0],
+                    name = node.name + "_kernel_requantizer",
+                    inputs = [quant_node.output[0], scale_node.name]
+                )
 
     if quantizers.get("bias_quantizer") and len(node.input) == 3:
         bias = node.inputs[2].get_tensor_value(as_list=True)

From 75b40ab81d5b544533373ba8b0654911ab3ec4e4 Mon Sep 17 00:00:00 2001
From: jvreca <jure.vreca@ijs.si>
Date: Mon, 12 Aug 2024 15:05:29 +0200
Subject: [PATCH 07/13] Added parameterized test for tensor style alpha.

---
 tests/keras/test_keras_convert.py | 65 ++++++++++++++++++++++++++++++-
 1 file changed, 64 insertions(+), 1 deletion(-)

diff --git a/tests/keras/test_keras_convert.py b/tests/keras/test_keras_convert.py
index 388f39a4..46f445ef 100644
--- a/tests/keras/test_keras_convert.py
+++ b/tests/keras/test_keras_convert.py
@@ -4,6 +4,9 @@
 import onnx
 import os
 import tensorflow as tf
+tf.config.run_functions_eagerly(True)
+tf.keras.utils.set_random_seed(42)
+np.random.seed(42)
 from qkeras import QActivation, QConv2D, QDense, binary, quantized_bits, quantized_relu, ternary
 from tensorflow.keras.layers import Activation, Conv2D, Dense, Flatten, Input
 from tensorflow.keras.models import Model
@@ -323,7 +326,67 @@ def test_qkeras_qdense_4(quantizers, request):
     np.testing.assert_allclose(y_qkeras, y_qonnx, rtol=1e-4, atol=1e-4)
     os.remove(model_path)
 
-
+@pytest.mark.parametrize("bits,signed,alpha",[
+    (8, True,  [1.000, 1.000, 1.000, 1.000]),
+    (8, False, [1.000, 1.000, 1.000, 1.000]),
+    (4, True,  [1.000, 1.000, 1.000, 1.000]),
+    (4, False, [1.000, 1.000, 1.000, 1.000]),
+    (8, True,  [0.125, 0.250, 0.500, 1.000]),
+    (8, False, [0.125, 0.250, 0.500, 1.000]),
+    (5, True,  [0.250, 0.250, 0.125, 0.125]),
+    (5, False, [0.250, 0.250, 0.125, 0.125]),
+    (4, True,  [0.125, 0.250, 0.500, 1.000]),
+    (4, False, [0.125, 0.250, 0.500, 1.000]),
+    (3, True,  [0.125, 0.125, 0.250, 0.125]),
+    (3, False, [0.125, 0.125, 0.250, 0.125])
+])
+def test_qkeras_tensor_alpha(bits, signed, alpha, request):
+    random_state = np.random.RandomState(seed=42)
+    max_val = np.array(alpha) * 2**(bits-signed)
+    min_val = -(max_val + 1)
+    w1 = random_state.randint(low=min_val, high=max_val, size=(3, 4))
+    b1 = np.array([0.0, 0.0, 0.0, 0.0])
+    x = x_in = tf.keras.layers.Input(shape=3)
+    x = QActivation(
+        quantized_bits(bits=4, integer=3, keep_negative=True)
+    )(x)
+    x = QDense(
+        4,
+        kernel_quantizer=quantized_bits(
+            bits=bits, integer=(bits-signed), keep_negative=signed, alpha=alpha
+        ),
+    )(x)
+    x = QActivation(quantized_relu(bits=3, integer=3))(x)
+    model = tf.keras.Model(inputs=[x_in], outputs=[x])
+    model.compile()
+    model.layers[2].set_weights([w1, b1])
+    onnx_model, _ = from_keras(model)
+    model_path = f"model_test_qkeras_tensor_alpha_{request.node.callspec.id}.onnx"
+    onnx.save(onnx_model, model_path)
+    onnx_model = ModelWrapper(onnx_model)
+
+    data = np.array(
+        [
+            [[0.0, 0.0, 0.0]],
+            [[0.0, 1.0, 2.0]],
+            [[2.0, 1.0, 0.0]],
+            [[4.0, 4.0, 4.0]],
+            [[7.0, 7.0, 7.0]],
+            [[6.0, 0.0, 7.0]],
+            [[3.0, 3.0, 3.0]],
+            [[7.0, 0.0, 0.0]],
+            [[0.0, 7.0, 0.0]],
+            [[0.0, 0.0, 7.0]],
+        ]
+    ).astype(np.float32)
+    for x in data:
+        y_qkeras = model.predict(x, verbose=0)
+        idict = {onnx_model.graph.input[0].name: x}
+        odict = oxe.execute_onnx(onnx_model, idict, True)
+        y_qonnx = odict[onnx_model.graph.output[0].name]
+        assert np.array_equal(y_qkeras, y_qonnx)
+    os.remove(model_path)
+    
 @pytest.mark.parametrize("quantizers", kb_quantizers, ids=kb_quantizers_ids)
 def test_qkeras_qconv2d_1(quantizers, request):
     kq, bq = quantizers

From 7b5bf4a750ec0a0b66f094f917aa91aa7e677826 Mon Sep 17 00:00:00 2001
From: jvreca <jure.vreca@ijs.si>
Date: Wed, 14 Aug 2024 08:16:56 +0200
Subject: [PATCH 08/13] Updated QKeras converter to support auto_po2

---
 src/qonnx/converters/qkeras/onnx.py       | 61 ++++++++++++-----------
 src/qonnx/converters/qkeras/qlayers.py    | 19 +++++--
 src/qonnx/converters/qkeras/quantizers.py | 10 ++--
 tests/keras/test_keras_convert.py         | 53 +++++++++++++++++++-
 4 files changed, 104 insertions(+), 39 deletions(-)

diff --git a/src/qonnx/converters/qkeras/onnx.py b/src/qonnx/converters/qkeras/onnx.py
index 1444865f..1383bec9 100644
--- a/src/qonnx/converters/qkeras/onnx.py
+++ b/src/qonnx/converters/qkeras/onnx.py
@@ -53,28 +53,13 @@ def qlayer_handler(ctx, node, name, args):
     if not keras_name:
         return  # Not found in quantizers, nothing to do
     quantizers = all_quantizers[keras_name]
-    if quantizers.get("kernel_quantizer"):
+    
+    if quantizers.get("kernel_quantizer_cfg"):
         weights = node.inputs[1].get_tensor_value(as_list=True)
-        quant_params = get_quant_params(weights, quantizers["kernel_initializer"]['config']['quantizer'])
+        quant_params = get_quant_params(weights, quantizers['kernel_quantizer_cfg'])
         attr = quant_params["attributes"]
         input_nodes = [node.input[1]]
-        #qweights = quant(inp_tensor=np.array(weights), 
-        #                 scale=np.array(quant_params['inputs']['scale']),
-        #                 zeropt=np.array(quant_params['inputs']['zero_point']),
-        #                 bitwidth=np.array(quant_params['inputs']['bit_width']),
-        #                 signed=quant_params['attributes']['signed'],
-        #                 narrow=quant_params['attributes']['narrow'],
-        #                 rounding_mode=quant_params['attributes']['rounding_mode']
-        #            )
-        #assert np.array_equal(weights, qweights), f"""Weights of tensor {node.name} are not representable with the given quantization settings.
-        #                                              The original weight tensor is: {np.array(weights)} and the quantized tensor is: {qweights}; 
-        #                                              scale: {np.array(quant_params['inputs']['scale'])}, 
-        #                                              zeropt: {np.array(quant_params['inputs']['zero_point'])}, 
-        #                                              bitwidth: {np.array(quant_params['inputs']['bit_width'])},
-        #                                              signed: {quant_params['attributes']['signed']},
-        #                                              narrow: {quant_params['attributes']['narrow']},
-        #                                              rounding_mode: {quant_params['attributes']['rounding_mode']}
-        #                                              """
+
         for key in quant_params["inputs"].keys():
             name = f"{node.name}_kernel_quantizer_{key}"
             np_val = np.asarray(quant_params["inputs"][key])
@@ -83,10 +68,10 @@ def qlayer_handler(ctx, node, name, args):
         quant_node = ctx.insert_new_node_on_input(
             node, "Quant", input_nodes, name=node.name + "_kernel_quantizer", **attr, domain="qonnx"
         )
-        if quantizers["kernel_initializer"]['config']['quantizer']['class_name'] == 'quantized_bits':
-            bits = quantizers["kernel_initializer"]['config']['quantizer']['config']['bits']
-            integer = quantizers["kernel_initializer"]['config']['quantizer']['config']['integer']
-            keep_negative = quantizers["kernel_initializer"]['config']['quantizer']['config']['keep_negative']
+        if quantizers['kernel_quantizer_cfg']['class_name'] == 'quantized_bits':
+            bits = quantizers['kernel_quantizer_cfg']['config']['bits']
+            integer = quantizers['kernel_quantizer_cfg']['config']['integer']
+            keep_negative = quantizers['kernel_quantizer_cfg']['config']['keep_negative']
             if bits == integer + keep_negative:
                 scale_node = ctx.make_const(
                     name = node.name + "_kernel_scale",
@@ -99,17 +84,32 @@ def qlayer_handler(ctx, node, name, args):
                     inputs = [quant_node.output[0], scale_node.name]
                 )
 
-    if quantizers.get("bias_quantizer") and len(node.input) == 3:
-        bias = node.inputs[2].get_tensor_value(as_list=True)
-        quant_params = get_quant_params(bias, quantizers["bias_quantizer"])
+    if quantizers.get("bias_quantizer_cfg") and len(node.input) == 3:
+        bias = node.inputs[-1].get_tensor_value(as_list=True)
+        quant_params = get_quant_params(bias, quantizers['bias_quantizer_cfg'])
         attr = quant_params["attributes"]
-        input_nodes = [node.input[2]]
+        input_nodes = [node.input[-1]]
         for key in quant_params["inputs"].keys():
             name = f"{node.name}_bias_quantizer_{key}"
             np_val = np.asarray(quant_params["inputs"][key])
             ctx.make_const(name, np_val)
             input_nodes.append(name)
         ctx.insert_new_node_on_input(node, "Quant", input_nodes, name=node.name + "_bias_quantizer", **attr, domain="qonnx")
+        if quantizers['bias_quantizer_cfg']['class_name'] == 'quantized_bits':
+            bits = quantizers['bias_quantizer_cfg']['config']['bits']
+            integer = quantizers['bias_quantizer_cfg']['config']['integer']
+            keep_negative = quantizers['bias_quantizer_cfg']['config']['keep_negative']
+            if bits == integer + keep_negative:
+                scale_node = ctx.make_const(
+                    name = node.name + "_bias_scale",
+                    np_val = quant_params['inputs']['scale'].astype(np.float32)
+                )
+                ctx.insert_new_node_on_output(
+                    op_type = "Mul", 
+                    output_name = quant_node.output[0],
+                    name = node.name + "_bias_requantizer",
+                    inputs = [quant_node.output[0], scale_node.name]
+                )
 
     if quantizers.get("activation"):
         dtypes = [ctx.get_dtype(node.output[0])]
@@ -141,6 +141,9 @@ def qact_handler(ctx, node, name, args):
     quantizers = all_quantizers[keras_name]
     if quantizers.get("activation"):
         dtypes = [ctx.get_dtype(node.output[0])]
+        if "auto" in quantizers['activation']:
+            if not node.graph.get_node_by_output(node.input[0]).is_const():
+                raise AttributeError(f'Automatic quantizers (auto/auto_po2) must have a const input. Invalid topology at node: {name}.')
         quant_params = get_quant_params(None, quantizers["activation"])
         attr = quant_params["attributes"]
         input_nodes = [node.output[0]]
@@ -180,9 +183,9 @@ def bias_handler(ctx, node, name, args):
         return  # Not found in quantizers, nothing to do
     quantizers = all_quantizers[keras_name]
 
-    if quantizers.get("bias_quantizer"):
+    if quantizers.get("bias_quantizer_cfg"):
         bias = node.inputs[1].get_tensor_value(as_list=True)
-        quant_params = get_quant_params(bias, quantizers["bias_quantizer"])
+        quant_params = get_quant_params(bias, quantizers["bias_quantizer_cfg"])
         attr = quant_params["attributes"]
         input_nodes = [node.input[1]]
         for key in quant_params["inputs"].keys():
diff --git a/src/qonnx/converters/qkeras/qlayers.py b/src/qonnx/converters/qkeras/qlayers.py
index fdbca71b..0a4f907d 100644
--- a/src/qonnx/converters/qkeras/qlayers.py
+++ b/src/qonnx/converters/qkeras/qlayers.py
@@ -101,13 +101,26 @@ def _replace_activation(quant_act):
 
 def extract_qlayer(layer):
     quantizers = layer.get_quantization_config()
-
+    
     keras_config = layer.get_config()
 
-    keras_config.pop("kernel_quantizer", None)
-    keras_config.pop("bias_quantizer", None)
+    kernel_quant_cfg = keras_config.pop("kernel_quantizer", None)
+    bias_quant_cfg = keras_config.pop("bias_quantizer", None)
     keras_config.pop("kernel_range", None)
     keras_config.pop("bias_range", None)
+    
+    quantizers['kernel_quantizer_cfg'] = kernel_quant_cfg
+    quantizers['bias_quantizer_cfg'] = bias_quant_cfg
+
+    # For some reason downstream can't handle auto_po2, so we just calculate the scale value now
+    if kernel_quant_cfg['config']['alpha'] == "auto_po2":
+        layer.kernel_quantizer_internal(layer.kernel) # sets .scale (see auto_po2)
+        quantizers['kernel_quantizer_cfg']['config']['alpha'] = layer.kernel_quantizer_internal.scale.numpy().flatten().tolist()
+    if bias_quant_cfg['config']['alpha'] == "auto_po2":
+        layer.bias_quantizer_internal(layer.bias)
+        quantizers['bias_quantizer_cfg']['config']['alpha'] = layer.bias_quantizer_internal.scale.numpy().flatten().tolist()
+    quantizers.pop('kernel_quantizer', None)
+    quantizers.pop('bias_quantizer', None)
 
     # Check if activation is quantized
     if _is_keras_quantizer(keras_config["activation"]):
diff --git a/src/qonnx/converters/qkeras/quantizers.py b/src/qonnx/converters/qkeras/quantizers.py
index c6a00a00..e38cf710 100644
--- a/src/qonnx/converters/qkeras/quantizers.py
+++ b/src/qonnx/converters/qkeras/quantizers.py
@@ -1,6 +1,8 @@
 import qkeras
 import six
 import numpy as np
+import tensorflow as tf
+
 
 def get_quant_params(tensor, qkeras_quantizer):
     if isinstance(qkeras_quantizer, (str, dict)):
@@ -24,7 +26,6 @@ def _get_scale_from_alpha(tensor, quantizer):
 def _get_quantizer_scale(tensor, quantizer):
     # call the quantizer on the tensor to get its scale
     import numpy as np
-
     quantizer(np.array(tensor).astype(np.float32))
     return quantizer.scale
 
@@ -34,15 +35,12 @@ def convert_quantized_bits(tensor, quantizer):
     signed = int(config["keep_negative"])
     narrow = int(config["symmetric"])
     qscale = _get_quantizer_scale(tensor, quantizer)
-    if not isinstance(qscale, np.ndarray):
+    if not isinstance(qscale, (np.ndarray, tf.Tensor)):
         qscale = np.array(qscale) 
     scale = qscale / 2 ** (int(config["bits"]) - int(config["integer"] + signed))
     zero_point = 0
     bit_width = int(config["bits"])
-    if config['alpha'] == "auto_po2":
-        rounding_mode = "ROUND_UP"
-    else:
-        rounding_mode = "HALF_EVEN"
+    rounding_mode = "HALF_EVEN"
 
     settings = {
         "attributes": {"signed": signed, "narrow": narrow, "rounding_mode": rounding_mode},
diff --git a/tests/keras/test_keras_convert.py b/tests/keras/test_keras_convert.py
index 46f445ef..eb3e9b2e 100644
--- a/tests/keras/test_keras_convert.py
+++ b/tests/keras/test_keras_convert.py
@@ -4,7 +4,6 @@
 import onnx
 import os
 import tensorflow as tf
-tf.config.run_functions_eagerly(True)
 tf.keras.utils.set_random_seed(42)
 np.random.seed(42)
 from qkeras import QActivation, QConv2D, QDense, binary, quantized_bits, quantized_relu, ternary
@@ -66,6 +65,58 @@ def test_qkeras_qactivation(quantizer, request):
     np.testing.assert_allclose(y_qkeras, y_qonnx, rtol=1e-5, atol=1e-5)
     os.remove(model_path)
 
+@pytest.mark.parametrize("quantizer", [
+    quantized_relu(bits=4, integer=4), 
+    quantized_bits(bits=4, integer=4, keep_negative=False, alpha=1),
+    ])
+def test_qkeras_quantizers_rounding_modes(quantizer, request):  
+    x = x_in = Input((10,), name="input")
+    x = QActivation(activation=quantizer)(x)
+    model = Model(inputs=[x_in], outputs=[x])
+    model.compile()
+    
+    onnx_model, _ = from_keras(model)
+    model_path = f"model_test_qkeras_quantizers_rounding_modes_{request.node.callspec.id}.onnx"
+    onnx.save(onnx_model, model_path)
+    onnx_model = ModelWrapper(onnx_model)
+
+    x_test = np.array([[5.5, 2.5, 1.6, 1.1, 1.0, -1.0, -1.1, -1.6, -2.5, -5.5]]).astype(np.float32)
+    idict = {onnx_model.graph.input[0].name: x_test}
+    odict = oxe.execute_onnx(onnx_model, idict, True)
+    y_qonnx = odict[onnx_model.graph.output[0].name]
+    y_qkeras = model.predict(x_test)
+    assert np.array_equal(y_qkeras, y_qonnx)
+    os.remove(model_path)
+
+@pytest.mark.parametrize("bias", [5.5, 2.5, 1.6, 1.1, 1.0, -1.0, -1.1, -1.6, -2.5, -5.5])
+def test_qkeras_quantizers_autopo2_rounding_modes(bias, request):
+    kq = bq = quantized_bits(4, 4, 1, alpha='auto_po2')
+    # Initialize the kernel & bias to RandonUniform within the range of the quantizers
+    x = x_in = Input((10), name="input")
+    x = QDense(
+        1,
+        kernel_quantizer=kq,
+        bias_quantizer=bq,
+        kernel_initializer=tf.keras.initializers.Constant(1.0),
+        bias_initializer=tf.keras.initializers.Constant(bias),
+        name="dense_0",
+    )(x)
+    model = Model(inputs=[x_in], outputs=[x])
+    x_test = np.random.uniform(low=-1.0, high=1.0, size=(1, 10)).astype(dtype=np.float32)
+    _ = model.predict(x_test, verbose=0)
+
+    onnx_model, _ = from_keras(model)
+    model_path = f"model_test_qkeras_quantizers_auto_rounding_modes_{request.node.callspec.id}.onnx"
+    onnx.save(onnx_model, model_path)
+    onnx_model = ModelWrapper(onnx_model)
+
+    x_test = np.zeros(shape=(1, 10), dtype=np.float32)
+    idict = {onnx_model.graph.input[0].name: x_test}
+    odict = oxe.execute_onnx(onnx_model, idict, True)
+    y_qonnx = odict[onnx_model.graph.output[0].name]
+    y_qkeras = model.predict(x_test, verbose=0)
+    assert np.array_equal(y_qkeras, y_qonnx)
+    os.remove(model_path)
 
 # pairs of quantizers for kernel and bias
 kb_quantizers = [

From 9fd5f6a263fd7b8bb2d11cccdb899f61b4479fb4 Mon Sep 17 00:00:00 2001
From: jvreca <jure.vreca@ijs.si>
Date: Wed, 14 Aug 2024 08:54:12 +0200
Subject: [PATCH 09/13] Added check for none.

---
 src/qonnx/converters/qkeras/qlayers.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/src/qonnx/converters/qkeras/qlayers.py b/src/qonnx/converters/qkeras/qlayers.py
index 0a4f907d..4a62b5d7 100644
--- a/src/qonnx/converters/qkeras/qlayers.py
+++ b/src/qonnx/converters/qkeras/qlayers.py
@@ -113,10 +113,10 @@ def extract_qlayer(layer):
     quantizers['bias_quantizer_cfg'] = bias_quant_cfg
 
     # For some reason downstream can't handle auto_po2, so we just calculate the scale value now
-    if kernel_quant_cfg['config']['alpha'] == "auto_po2":
+    if kernel_quant_cfg is not None and kernel_quant_cfg['config']['alpha'] == "auto_po2":
         layer.kernel_quantizer_internal(layer.kernel) # sets .scale (see auto_po2)
         quantizers['kernel_quantizer_cfg']['config']['alpha'] = layer.kernel_quantizer_internal.scale.numpy().flatten().tolist()
-    if bias_quant_cfg['config']['alpha'] == "auto_po2":
+    if bias_quant_cfg is not None and bias_quant_cfg['config']['alpha'] == "auto_po2":
         layer.bias_quantizer_internal(layer.bias)
         quantizers['bias_quantizer_cfg']['config']['alpha'] = layer.bias_quantizer_internal.scale.numpy().flatten().tolist()
     quantizers.pop('kernel_quantizer', None)

From d8a66a7db408772f052d3386b74083d2d1c66f51 Mon Sep 17 00:00:00 2001
From: jvreca <jure.vreca@ijs.si>
Date: Wed, 14 Aug 2024 09:36:25 +0200
Subject: [PATCH 10/13] Fixed pre-commit issues.

---
 src/qonnx/converters/qkeras/onnx.py       | 52 ++++++++---------
 src/qonnx/converters/qkeras/qlayers.py    | 24 ++++----
 src/qonnx/converters/qkeras/quantizers.py |  5 +-
 src/qonnx/custom_op/general/quant.py      |  6 ++
 tests/custom_op/test_runding_mode.py      | 11 ++--
 tests/keras/test_keras_convert.py         | 71 +++++++++++++----------
 6 files changed, 95 insertions(+), 74 deletions(-)

diff --git a/src/qonnx/converters/qkeras/onnx.py b/src/qonnx/converters/qkeras/onnx.py
index 1383bec9..41b022d7 100644
--- a/src/qonnx/converters/qkeras/onnx.py
+++ b/src/qonnx/converters/qkeras/onnx.py
@@ -4,7 +4,7 @@
 from tf2onnx.onnx_opset.nn import BiasAdd, ConvOp
 
 from .quantizers import get_quant_params
-from qonnx.custom_op.general.quant import quant
+
 
 def get_qkeras_onnx_handlers(all_quantizers):
     """Returns the handlers for each kind of layer
@@ -53,10 +53,10 @@ def qlayer_handler(ctx, node, name, args):
     if not keras_name:
         return  # Not found in quantizers, nothing to do
     quantizers = all_quantizers[keras_name]
-    
+
     if quantizers.get("kernel_quantizer_cfg"):
         weights = node.inputs[1].get_tensor_value(as_list=True)
-        quant_params = get_quant_params(weights, quantizers['kernel_quantizer_cfg'])
+        quant_params = get_quant_params(weights, quantizers["kernel_quantizer_cfg"])
         attr = quant_params["attributes"]
         input_nodes = [node.input[1]]
 
@@ -68,25 +68,24 @@ def qlayer_handler(ctx, node, name, args):
         quant_node = ctx.insert_new_node_on_input(
             node, "Quant", input_nodes, name=node.name + "_kernel_quantizer", **attr, domain="qonnx"
         )
-        if quantizers['kernel_quantizer_cfg']['class_name'] == 'quantized_bits':
-            bits = quantizers['kernel_quantizer_cfg']['config']['bits']
-            integer = quantizers['kernel_quantizer_cfg']['config']['integer']
-            keep_negative = quantizers['kernel_quantizer_cfg']['config']['keep_negative']
+        if quantizers["kernel_quantizer_cfg"]["class_name"] == "quantized_bits":
+            bits = quantizers["kernel_quantizer_cfg"]["config"]["bits"]
+            integer = quantizers["kernel_quantizer_cfg"]["config"]["integer"]
+            keep_negative = quantizers["kernel_quantizer_cfg"]["config"]["keep_negative"]
             if bits == integer + keep_negative:
                 scale_node = ctx.make_const(
-                    name = node.name + "_kernel_scale",
-                    np_val = quant_params['inputs']['scale'].astype(np.float32)
+                    name=node.name + "_kernel_scale", np_val=quant_params["inputs"]["scale"].astype(np.float32)
                 )
                 ctx.insert_new_node_on_output(
-                    op_type = "Mul", 
-                    output_name = quant_node.output[0],
-                    name = node.name + "_kernel_requantizer",
-                    inputs = [quant_node.output[0], scale_node.name]
+                    op_type="Mul",
+                    output_name=quant_node.output[0],
+                    name=node.name + "_kernel_requantizer",
+                    inputs=[quant_node.output[0], scale_node.name],
                 )
 
     if quantizers.get("bias_quantizer_cfg") and len(node.input) == 3:
         bias = node.inputs[-1].get_tensor_value(as_list=True)
-        quant_params = get_quant_params(bias, quantizers['bias_quantizer_cfg'])
+        quant_params = get_quant_params(bias, quantizers["bias_quantizer_cfg"])
         attr = quant_params["attributes"]
         input_nodes = [node.input[-1]]
         for key in quant_params["inputs"].keys():
@@ -95,20 +94,19 @@ def qlayer_handler(ctx, node, name, args):
             ctx.make_const(name, np_val)
             input_nodes.append(name)
         ctx.insert_new_node_on_input(node, "Quant", input_nodes, name=node.name + "_bias_quantizer", **attr, domain="qonnx")
-        if quantizers['bias_quantizer_cfg']['class_name'] == 'quantized_bits':
-            bits = quantizers['bias_quantizer_cfg']['config']['bits']
-            integer = quantizers['bias_quantizer_cfg']['config']['integer']
-            keep_negative = quantizers['bias_quantizer_cfg']['config']['keep_negative']
+        if quantizers["bias_quantizer_cfg"]["class_name"] == "quantized_bits":
+            bits = quantizers["bias_quantizer_cfg"]["config"]["bits"]
+            integer = quantizers["bias_quantizer_cfg"]["config"]["integer"]
+            keep_negative = quantizers["bias_quantizer_cfg"]["config"]["keep_negative"]
             if bits == integer + keep_negative:
                 scale_node = ctx.make_const(
-                    name = node.name + "_bias_scale",
-                    np_val = quant_params['inputs']['scale'].astype(np.float32)
+                    name=node.name + "_bias_scale", np_val=quant_params["inputs"]["scale"].astype(np.float32)
                 )
                 ctx.insert_new_node_on_output(
-                    op_type = "Mul", 
-                    output_name = quant_node.output[0],
-                    name = node.name + "_bias_requantizer",
-                    inputs = [quant_node.output[0], scale_node.name]
+                    op_type="Mul",
+                    output_name=quant_node.output[0],
+                    name=node.name + "_bias_requantizer",
+                    inputs=[quant_node.output[0], scale_node.name],
                 )
 
     if quantizers.get("activation"):
@@ -141,9 +139,11 @@ def qact_handler(ctx, node, name, args):
     quantizers = all_quantizers[keras_name]
     if quantizers.get("activation"):
         dtypes = [ctx.get_dtype(node.output[0])]
-        if "auto" in quantizers['activation']:
+        if "auto" in quantizers["activation"]:
             if not node.graph.get_node_by_output(node.input[0]).is_const():
-                raise AttributeError(f'Automatic quantizers (auto/auto_po2) must have a const input. Invalid topology at node: {name}.')
+                raise AttributeError(
+                    f"Automatic quantizers (auto/auto_po2) must have a const input. Invalid topology at node: {name}."
+                )
         quant_params = get_quant_params(None, quantizers["activation"])
         attr = quant_params["attributes"]
         input_nodes = [node.output[0]]
diff --git a/src/qonnx/converters/qkeras/qlayers.py b/src/qonnx/converters/qkeras/qlayers.py
index 4a62b5d7..3bfc7fa7 100644
--- a/src/qonnx/converters/qkeras/qlayers.py
+++ b/src/qonnx/converters/qkeras/qlayers.py
@@ -101,26 +101,28 @@ def _replace_activation(quant_act):
 
 def extract_qlayer(layer):
     quantizers = layer.get_quantization_config()
-    
+
     keras_config = layer.get_config()
 
     kernel_quant_cfg = keras_config.pop("kernel_quantizer", None)
     bias_quant_cfg = keras_config.pop("bias_quantizer", None)
     keras_config.pop("kernel_range", None)
     keras_config.pop("bias_range", None)
-    
-    quantizers['kernel_quantizer_cfg'] = kernel_quant_cfg
-    quantizers['bias_quantizer_cfg'] = bias_quant_cfg
+
+    quantizers["kernel_quantizer_cfg"] = kernel_quant_cfg
+    quantizers["bias_quantizer_cfg"] = bias_quant_cfg
 
     # For some reason downstream can't handle auto_po2, so we just calculate the scale value now
-    if kernel_quant_cfg is not None and kernel_quant_cfg['config']['alpha'] == "auto_po2":
-        layer.kernel_quantizer_internal(layer.kernel) # sets .scale (see auto_po2)
-        quantizers['kernel_quantizer_cfg']['config']['alpha'] = layer.kernel_quantizer_internal.scale.numpy().flatten().tolist()
-    if bias_quant_cfg is not None and bias_quant_cfg['config']['alpha'] == "auto_po2":
+    if kernel_quant_cfg is not None and kernel_quant_cfg["config"]["alpha"] == "auto_po2":
+        layer.kernel_quantizer_internal(layer.kernel)  # sets .scale (see auto_po2)
+        quantizers["kernel_quantizer_cfg"]["config"]["alpha"] = (
+            layer.kernel_quantizer_internal.scale.numpy().flatten().tolist()
+        )
+    if bias_quant_cfg is not None and bias_quant_cfg["config"]["alpha"] == "auto_po2":
         layer.bias_quantizer_internal(layer.bias)
-        quantizers['bias_quantizer_cfg']['config']['alpha'] = layer.bias_quantizer_internal.scale.numpy().flatten().tolist()
-    quantizers.pop('kernel_quantizer', None)
-    quantizers.pop('bias_quantizer', None)
+        quantizers["bias_quantizer_cfg"]["config"]["alpha"] = layer.bias_quantizer_internal.scale.numpy().flatten().tolist()
+    quantizers.pop("kernel_quantizer", None)
+    quantizers.pop("bias_quantizer", None)
 
     # Check if activation is quantized
     if _is_keras_quantizer(keras_config["activation"]):
diff --git a/src/qonnx/converters/qkeras/quantizers.py b/src/qonnx/converters/qkeras/quantizers.py
index e38cf710..3d232390 100644
--- a/src/qonnx/converters/qkeras/quantizers.py
+++ b/src/qonnx/converters/qkeras/quantizers.py
@@ -1,6 +1,6 @@
+import numpy as np
 import qkeras
 import six
-import numpy as np
 import tensorflow as tf
 
 
@@ -26,6 +26,7 @@ def _get_scale_from_alpha(tensor, quantizer):
 def _get_quantizer_scale(tensor, quantizer):
     # call the quantizer on the tensor to get its scale
     import numpy as np
+
     quantizer(np.array(tensor).astype(np.float32))
     return quantizer.scale
 
@@ -36,7 +37,7 @@ def convert_quantized_bits(tensor, quantizer):
     narrow = int(config["symmetric"])
     qscale = _get_quantizer_scale(tensor, quantizer)
     if not isinstance(qscale, (np.ndarray, tf.Tensor)):
-        qscale = np.array(qscale) 
+        qscale = np.array(qscale)
     scale = qscale / 2 ** (int(config["bits"]) - int(config["integer"] + signed))
     zero_point = 0
     bit_width = int(config["bits"])
diff --git a/src/qonnx/custom_op/general/quant.py b/src/qonnx/custom_op/general/quant.py
index 5af3f9f3..b0b50b9a 100644
--- a/src/qonnx/custom_op/general/quant.py
+++ b/src/qonnx/custom_op/general/quant.py
@@ -142,18 +142,24 @@ def resolve_rounding_mode(mode_string):
     elif normalized_mode_string == "FLOOR":
         return np.floor
     elif normalized_mode_string == "UP":
+
         def round_up(x):
             return np.sign(x) * np.ceil(np.abs(x))
+
         return round_up
     elif normalized_mode_string == "DOWN":
         return np.fix
     elif normalized_mode_string == "HALF_UP":
+
         def round_half_up(x):
             return np.sign(x) * np.floor(np.abs(x) + 0.5)
+
         return round_half_up
     elif normalized_mode_string == "HALF_DOWN":
+
         def round_half_down(x):
             return np.sign(x) * np.ceil(np.abs(x) - 0.5)
+
         return round_half_down
     else:
         raise ValueError(f"Could not resolve rounding mode called: {normalized_mode_string}")
diff --git a/tests/custom_op/test_runding_mode.py b/tests/custom_op/test_runding_mode.py
index 54a81f0e..eb48d644 100644
--- a/tests/custom_op/test_runding_mode.py
+++ b/tests/custom_op/test_runding_mode.py
@@ -4,15 +4,18 @@
 
 from qonnx.custom_op.general.quant import resolve_rounding_mode
 
-@pytest.mark.parametrize("rmode,exp", [
+
+@pytest.mark.parametrize(
+    "rmode,exp",
+    [
         ("ROUND", np.array([6, 2, 2, 1, 1, -1, -1, -2, -2, -6])),
-        ("CEIL", np.array([6, 3, 2, 2, 1, -1, -1, -1, -2, - 5])),
+        ("CEIL", np.array([6, 3, 2, 2, 1, -1, -1, -1, -2, -5])),
         ("FLOOR", np.array([5, 2, 1, 1, 1, -1, -2, -2, -3, -6])),
         ("UP", np.array([6, 3, 2, 2, 1, -1, -2, -2, -3, -6])),
         ("DOWN", np.array([5, 2, 1, 1, 1, -1, -1, -1, -2, -5])),
         ("HALF_UP", np.array([6, 3, 2, 1, 1, -1, -1, -2, -3, -6])),
-        ("HALF_DOWN", np.array([5, 2, 2, 1, 1, -1, -1, -2, -2, -5]))
-    ]
+        ("HALF_DOWN", np.array([5, 2, 2, 1, 1, -1, -1, -2, -2, -5])),
+    ],
 )
 def test_rounding_modes(rmode, exp):
     test_array = np.array([5.5, 2.5, 1.6, 1.1, 1.0, -1.0, -1.1, -1.6, -2.5, -5.5])
diff --git a/tests/keras/test_keras_convert.py b/tests/keras/test_keras_convert.py
index eb3e9b2e..964a2e3c 100644
--- a/tests/keras/test_keras_convert.py
+++ b/tests/keras/test_keras_convert.py
@@ -4,8 +4,6 @@
 import onnx
 import os
 import tensorflow as tf
-tf.keras.utils.set_random_seed(42)
-np.random.seed(42)
 from qkeras import QActivation, QConv2D, QDense, binary, quantized_bits, quantized_relu, ternary
 from tensorflow.keras.layers import Activation, Conv2D, Dense, Flatten, Input
 from tensorflow.keras.models import Model
@@ -15,6 +13,10 @@
 from qonnx.core.modelwrapper import ModelWrapper
 from qonnx.transformation.infer_shapes import InferShapes
 
+# For reproducibility
+tf.keras.utils.set_random_seed(42)
+np.random.seed(42)
+
 act_quantizers = [
     quantized_bits(8, 4, 0, alpha=1),
     quantized_bits(8, 4, 1, alpha=1),
@@ -65,16 +67,20 @@ def test_qkeras_qactivation(quantizer, request):
     np.testing.assert_allclose(y_qkeras, y_qonnx, rtol=1e-5, atol=1e-5)
     os.remove(model_path)
 
-@pytest.mark.parametrize("quantizer", [
-    quantized_relu(bits=4, integer=4), 
-    quantized_bits(bits=4, integer=4, keep_negative=False, alpha=1),
-    ])
-def test_qkeras_quantizers_rounding_modes(quantizer, request):  
+
+@pytest.mark.parametrize(
+    "quantizer",
+    [
+        quantized_relu(bits=4, integer=4),
+        quantized_bits(bits=4, integer=4, keep_negative=False, alpha=1),
+    ],
+)
+def test_qkeras_quantizers_rounding_modes(quantizer, request):
     x = x_in = Input((10,), name="input")
     x = QActivation(activation=quantizer)(x)
     model = Model(inputs=[x_in], outputs=[x])
     model.compile()
-    
+
     onnx_model, _ = from_keras(model)
     model_path = f"model_test_qkeras_quantizers_rounding_modes_{request.node.callspec.id}.onnx"
     onnx.save(onnx_model, model_path)
@@ -88,9 +94,10 @@ def test_qkeras_quantizers_rounding_modes(quantizer, request):
     assert np.array_equal(y_qkeras, y_qonnx)
     os.remove(model_path)
 
+
 @pytest.mark.parametrize("bias", [5.5, 2.5, 1.6, 1.1, 1.0, -1.0, -1.1, -1.6, -2.5, -5.5])
 def test_qkeras_quantizers_autopo2_rounding_modes(bias, request):
-    kq = bq = quantized_bits(4, 4, 1, alpha='auto_po2')
+    kq = bq = quantized_bits(4, 4, 1, alpha="auto_po2")
     # Initialize the kernel & bias to RandonUniform within the range of the quantizers
     x = x_in = Input((10), name="input")
     x = QDense(
@@ -118,6 +125,7 @@ def test_qkeras_quantizers_autopo2_rounding_modes(bias, request):
     assert np.array_equal(y_qkeras, y_qonnx)
     os.remove(model_path)
 
+
 # pairs of quantizers for kernel and bias
 kb_quantizers = [
     (quantized_bits(8, 4, 0, alpha=1), quantized_bits(8, 4, 0, alpha=1)),
@@ -377,35 +385,35 @@ def test_qkeras_qdense_4(quantizers, request):
     np.testing.assert_allclose(y_qkeras, y_qonnx, rtol=1e-4, atol=1e-4)
     os.remove(model_path)
 
-@pytest.mark.parametrize("bits,signed,alpha",[
-    (8, True,  [1.000, 1.000, 1.000, 1.000]),
-    (8, False, [1.000, 1.000, 1.000, 1.000]),
-    (4, True,  [1.000, 1.000, 1.000, 1.000]),
-    (4, False, [1.000, 1.000, 1.000, 1.000]),
-    (8, True,  [0.125, 0.250, 0.500, 1.000]),
-    (8, False, [0.125, 0.250, 0.500, 1.000]),
-    (5, True,  [0.250, 0.250, 0.125, 0.125]),
-    (5, False, [0.250, 0.250, 0.125, 0.125]),
-    (4, True,  [0.125, 0.250, 0.500, 1.000]),
-    (4, False, [0.125, 0.250, 0.500, 1.000]),
-    (3, True,  [0.125, 0.125, 0.250, 0.125]),
-    (3, False, [0.125, 0.125, 0.250, 0.125])
-])
+
+@pytest.mark.parametrize(
+    "bits,signed,alpha",
+    [
+        (8, True, [1.000, 1.000, 1.000, 1.000]),
+        (8, False, [1.000, 1.000, 1.000, 1.000]),
+        (4, True, [1.000, 1.000, 1.000, 1.000]),
+        (4, False, [1.000, 1.000, 1.000, 1.000]),
+        (8, True, [0.125, 0.250, 0.500, 1.000]),
+        (8, False, [0.125, 0.250, 0.500, 1.000]),
+        (5, True, [0.250, 0.250, 0.125, 0.125]),
+        (5, False, [0.250, 0.250, 0.125, 0.125]),
+        (4, True, [0.125, 0.250, 0.500, 1.000]),
+        (4, False, [0.125, 0.250, 0.500, 1.000]),
+        (3, True, [0.125, 0.125, 0.250, 0.125]),
+        (3, False, [0.125, 0.125, 0.250, 0.125]),
+    ],
+)
 def test_qkeras_tensor_alpha(bits, signed, alpha, request):
     random_state = np.random.RandomState(seed=42)
-    max_val = np.array(alpha) * 2**(bits-signed)
+    max_val = np.array(alpha) * 2 ** (bits - signed)
     min_val = -(max_val + 1)
     w1 = random_state.randint(low=min_val, high=max_val, size=(3, 4))
     b1 = np.array([0.0, 0.0, 0.0, 0.0])
     x = x_in = tf.keras.layers.Input(shape=3)
-    x = QActivation(
-        quantized_bits(bits=4, integer=3, keep_negative=True)
-    )(x)
+    x = QActivation(quantized_bits(bits=4, integer=3, keep_negative=True))(x)
     x = QDense(
         4,
-        kernel_quantizer=quantized_bits(
-            bits=bits, integer=(bits-signed), keep_negative=signed, alpha=alpha
-        ),
+        kernel_quantizer=quantized_bits(bits=bits, integer=(bits - signed), keep_negative=signed, alpha=alpha),
     )(x)
     x = QActivation(quantized_relu(bits=3, integer=3))(x)
     model = tf.keras.Model(inputs=[x_in], outputs=[x])
@@ -437,7 +445,8 @@ def test_qkeras_tensor_alpha(bits, signed, alpha, request):
         y_qonnx = odict[onnx_model.graph.output[0].name]
         assert np.array_equal(y_qkeras, y_qonnx)
     os.remove(model_path)
-    
+
+
 @pytest.mark.parametrize("quantizers", kb_quantizers, ids=kb_quantizers_ids)
 def test_qkeras_qconv2d_1(quantizers, request):
     kq, bq = quantizers

From cd453206cf3ed1c90026924b24ae7fc2af6e785e Mon Sep 17 00:00:00 2001
From: jvreca <jure.vreca@ijs.si>
Date: Wed, 14 Aug 2024 10:12:38 +0200
Subject: [PATCH 11/13] Added check if tensor is repsentable with quant
 setting.s

---
 src/qonnx/converters/qkeras/onnx.py | 25 ++++++++++++++++++++++++-
 1 file changed, 24 insertions(+), 1 deletion(-)

diff --git a/src/qonnx/converters/qkeras/onnx.py b/src/qonnx/converters/qkeras/onnx.py
index 41b022d7..bbccde8d 100644
--- a/src/qonnx/converters/qkeras/onnx.py
+++ b/src/qonnx/converters/qkeras/onnx.py
@@ -1,10 +1,15 @@
+import logging
 import numpy as np
 from tf2onnx.late_rewriters import channel_order_rewriters
 from tf2onnx.onnx_opset.math import DirectOp, MatMul
 from tf2onnx.onnx_opset.nn import BiasAdd, ConvOp
 
+from qonnx.custom_op.general.quant import quant
+
 from .quantizers import get_quant_params
 
+logger = logging.getLogger(__name__)
+
 
 def get_qkeras_onnx_handlers(all_quantizers):
     """Returns the handlers for each kind of layer
@@ -47,6 +52,23 @@ def _extract_node_name(onnx_node, keras_quantizers):
     return None
 
 
+def check_tensor_is_representable(tensor, quant_params, node):
+    "Gives a Warning iftensor is not representable with the providede quantization settings"
+    qtensor = quant(
+        inp_tensor=np.array(tensor),
+        scale=np.array(quant_params["inputs"]["scale"]),
+        zeropt=np.array(quant_params["inputs"]["zero_point"]),
+        bitwidth=np.array(quant_params["inputs"]["bit_width"]),
+        signed=quant_params["attributes"]["signed"],
+        narrow=quant_params["attributes"]["narrow"],
+        rounding_mode=quant_params["attributes"]["rounding_mode"],
+    )
+    if not np.array_equal(tensor, qtensor):
+        logger.warn(
+            f"Tensor of node: {node.name} is not representable with the provided quantization settings: {quant_params}"
+        )
+
+
 def qlayer_handler(ctx, node, name, args):
     all_quantizers = args[0]
     keras_name = _extract_node_name(node, all_quantizers)
@@ -57,9 +79,9 @@ def qlayer_handler(ctx, node, name, args):
     if quantizers.get("kernel_quantizer_cfg"):
         weights = node.inputs[1].get_tensor_value(as_list=True)
         quant_params = get_quant_params(weights, quantizers["kernel_quantizer_cfg"])
+        check_tensor_is_representable(weights, quant_params, node)
         attr = quant_params["attributes"]
         input_nodes = [node.input[1]]
-
         for key in quant_params["inputs"].keys():
             name = f"{node.name}_kernel_quantizer_{key}"
             np_val = np.asarray(quant_params["inputs"][key])
@@ -86,6 +108,7 @@ def qlayer_handler(ctx, node, name, args):
     if quantizers.get("bias_quantizer_cfg") and len(node.input) == 3:
         bias = node.inputs[-1].get_tensor_value(as_list=True)
         quant_params = get_quant_params(bias, quantizers["bias_quantizer_cfg"])
+        check_tensor_is_representable(bias, quant_params, node)
         attr = quant_params["attributes"]
         input_nodes = [node.input[-1]]
         for key in quant_params["inputs"].keys():

From 7e98eb3287de5c36e463084fc5724a449751c064 Mon Sep 17 00:00:00 2001
From: jvreca <jure.vreca@ijs.si>
Date: Mon, 26 Aug 2024 09:16:02 +0200
Subject: [PATCH 12/13] reformated with pre-commit hooks.

---
 src/qonnx/custom_op/general/quant.py |  6 ++++++
 tests/custom_op/test_runding_mode.py | 11 +++++++----
 2 files changed, 13 insertions(+), 4 deletions(-)

diff --git a/src/qonnx/custom_op/general/quant.py b/src/qonnx/custom_op/general/quant.py
index 15afd048..5cdc1294 100644
--- a/src/qonnx/custom_op/general/quant.py
+++ b/src/qonnx/custom_op/general/quant.py
@@ -142,18 +142,24 @@ def resolve_rounding_mode(mode_string):
     elif normalized_mode_string == "FLOOR":
         return np.floor
     elif normalized_mode_string == "UP":
+
         def round_up(x):
             return np.sign(x) * np.ceil(np.abs(x))
+
         return round_up
     elif normalized_mode_string == "DOWN":
         return np.fix
     elif normalized_mode_string == "HALF_UP":
+
         def round_half_up(x):
             return np.sign(x) * np.floor(np.abs(x) + 0.5)
+
         return round_half_up
     elif normalized_mode_string == "HALF_DOWN":
+
         def round_half_down(x):
             return np.sign(x) * np.ceil(np.abs(x) - 0.5)
+
         return round_half_down
     else:
         raise ValueError(f"Could not resolve rounding mode called: {normalized_mode_string}")
diff --git a/tests/custom_op/test_runding_mode.py b/tests/custom_op/test_runding_mode.py
index 54a81f0e..eb48d644 100644
--- a/tests/custom_op/test_runding_mode.py
+++ b/tests/custom_op/test_runding_mode.py
@@ -4,15 +4,18 @@
 
 from qonnx.custom_op.general.quant import resolve_rounding_mode
 
-@pytest.mark.parametrize("rmode,exp", [
+
+@pytest.mark.parametrize(
+    "rmode,exp",
+    [
         ("ROUND", np.array([6, 2, 2, 1, 1, -1, -1, -2, -2, -6])),
-        ("CEIL", np.array([6, 3, 2, 2, 1, -1, -1, -1, -2, - 5])),
+        ("CEIL", np.array([6, 3, 2, 2, 1, -1, -1, -1, -2, -5])),
         ("FLOOR", np.array([5, 2, 1, 1, 1, -1, -2, -2, -3, -6])),
         ("UP", np.array([6, 3, 2, 2, 1, -1, -2, -2, -3, -6])),
         ("DOWN", np.array([5, 2, 1, 1, 1, -1, -1, -1, -2, -5])),
         ("HALF_UP", np.array([6, 3, 2, 1, 1, -1, -1, -2, -3, -6])),
-        ("HALF_DOWN", np.array([5, 2, 2, 1, 1, -1, -1, -2, -2, -5]))
-    ]
+        ("HALF_DOWN", np.array([5, 2, 2, 1, 1, -1, -1, -2, -2, -5])),
+    ],
 )
 def test_rounding_modes(rmode, exp):
     test_array = np.array([5.5, 2.5, 1.6, 1.1, 1.0, -1.0, -1.1, -1.6, -2.5, -5.5])

From fafc4d7f5496fed4129bc820a2815c99cd0105b0 Mon Sep 17 00:00:00 2001
From: jvreca <jure.vreca@ijs.si>
Date: Mon, 26 Aug 2024 09:18:49 +0200
Subject: [PATCH 13/13] Removed the _TO_ to make it consitant with others

---
 src/qonnx/custom_op/general/quant.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/qonnx/custom_op/general/quant.py b/src/qonnx/custom_op/general/quant.py
index 5cdc1294..b0b50b9a 100644
--- a/src/qonnx/custom_op/general/quant.py
+++ b/src/qonnx/custom_op/general/quant.py
@@ -135,7 +135,7 @@ def resolve_rounding_mode(mode_string):
     """Resolve the rounding mode string of Quant and Trunc ops
     to the corresponding numpy functions."""
     normalized_mode_string = mode_string.upper()
-    if normalized_mode_string == "ROUND" or normalized_mode_string == "HALF_TO_EVEN":
+    if normalized_mode_string == "ROUND" or normalized_mode_string == "HALF_EVEN":
         return np.round
     elif normalized_mode_string == "CEIL":
         return np.ceil