transform.py
# Copyright 2022 Computer Systems Department, Jozef Stefan Institute
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# https://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging

import onnx
import qonnx.converters
import qonnx.custom_op.registry
import qonnx.util.cleanup
import torch
from brevitas.export import export_qonnx
from onnx.onnx_ml_pb2 import NodeProto
from qonnx.core.modelwrapper import ModelWrapper
from qonnx.transformation.double_to_single_float import DoubleToSingleFloat
from qonnx.transformation.extract_conv_bias import ExtractBiasFromConv
from qonnx.transformation.gemm_to_matmul import GemmToMatMul
from qonnx.transformation.general import SortGraph
from qonnx.transformation.infer_data_layouts import InferDataLayouts
from qonnx.transformation.infer_shapes import InferShapes
from qonnx.transformation.quant_constant_folding import FoldTransposeIntoQuantInit
from qonnx.transformation.remove import RemoveIdentityOps

import chisel4ml.lbir.lbir_pb2 as lbir
from chisel4ml.transforms import AddDummyBiasToConv
from chisel4ml.transforms import AddFFTrealOutputShape
from chisel4ml.transforms import AddInputOrOutputQTensorToReshape
from chisel4ml.transforms import CleanupQTensors
from chisel4ml.transforms import ExtractQuantizedBiasFromConv
from chisel4ml.transforms import InputReluQTensorToQTensor
from chisel4ml.transforms import QONNXToLBIR
from chisel4ml.transforms import QuantToQTensor
from chisel4ml.transforms import RemoveFlattenNode
from chisel4ml.transforms import UnquantizedBiasToQTensor
from chisel4ml.transforms import UnquantizedOutputToQTensor
from chisel4ml.transforms import WeightQuantToQTensor
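
# Graph-normalization passes run before lowering: unify float precision,
# infer data layouts and shapes, rewrite Gemm to MatMul, fold transposes
# into Quant initializers, and prune identity ops.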
DEFAULT_QONNX_TRANSFORMS = [
    DoubleToSingleFloat(),
    InferDataLayouts(),
    AddFFTrealOutputShape(),
    GemmToMatMul(),
    FoldTransposeIntoQuantInit(),
    InferShapes(),
    RemoveIdentityOps(),
    SortGraph(),
]
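
# Lowering passes: move biases and Quant nodes into QTensor annotations,
# drop Flatten nodes, and rewrite the annotated QONNX nodes into LBIR layers.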
QONNX_TO_LBIR_TRANSFORMS = [
    AddDummyBiasToConv(),
    ExtractQuantizedBiasFromConv(),
    ExtractBiasFromConv(),
    WeightQuantToQTensor(),
    QuantToQTensor(),
    AddInputOrOutputQTensorToReshape(),
    UnquantizedBiasToQTensor(),
    UnquantizedOutputToQTensor(),
    InputReluQTensorToQTensor(),
    RemoveFlattenNode(),
    QONNXToLBIR(),
    CleanupQTensors(),
]


def qkeras_to_qonnx(qkeras_model):
    """Converts a QKeras model to a QONNX ModelWrapper."""
    qonnx_proto, _ = qonnx.converters.from_keras(qkeras_model)
    modelwrap = ModelWrapper(qonnx_proto)
    return modelwrap


def brevitas_to_qonnx(brevitas_model, ishape):
    """Exports a Brevitas model to QONNX, tracing it with a random input of shape ishape."""
    qonnx_proto = export_qonnx(brevitas_model, torch.randn(ishape))
    modelwrap = ModelWrapper(qonnx_proto)
    return modelwrap
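
# Example (a sketch, not part of this module): `MyQuantNet` and the input
# shape below are hypothetical stand-ins for a user-defined Brevitas model.
#
#   net = MyQuantNet()
#   modelwrap = brevitas_to_qonnx(net, ishape=(1, 1, 28, 28))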


def qonnx_to_lbir(
    modelwrap: ModelWrapper,
    name="chisel4ml_model",
    custom_trans_list=[],
    cleanup=True,
    debug=False,
) -> lbir.Model:
    """Applies transformations to a QONNX model and returns an LBIR model."""
    if len(custom_trans_list) == 0:
        transforms = DEFAULT_QONNX_TRANSFORMS
    else:
        transforms = custom_trans_list
    qonnx.custom_op.registry.register_custom_domain("chisel4ml")
    if cleanup:
        modelwrap = qonnx.util.cleanup.cleanup_model(modelwrap)
    for ind, trans in enumerate(transforms + QONNX_TO_LBIR_TRANSFORMS):
        logging.info(f"Running transform {type(trans).__name__}.")
        if debug:
            onnx.save(
                modelwrap.model,
                f"DEBUG_{name}_{ind}_BEFORE_{type(trans).__name__}.onnx",
            )
        modelwrap = modelwrap.transform(trans)
    if debug:
        onnx.save(
            modelwrap.model,
            f"DEBUG_{name}_FINAL_.onnx",
        )
    lbir_model = _unwrap_qonnx_to_lbir(modelwrap, name)
    return lbir_model
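
# Continuing the sketch above: lower the wrapped QONNX model to LBIR. The
# model name is arbitrary; with debug=True a snapshot named
# DEBUG_<name>_<ind>_BEFORE_<Transform>.onnx is saved before each transform,
# plus DEBUG_<name>_FINAL_.onnx at the end.
#
#   lbir_model = qonnx_to_lbir(modelwrap, name="mnist_net", debug=True)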


def _unwrap_qonnx_to_lbir(onnx_model: ModelWrapper, name: str) -> lbir.Model:
    if (
        onnx_model.graph.node[0].op_type == "Reshape"
        and onnx_model.graph.node[-1].op_type == "Reshape"
    ):
        # This condition typically arises from QKeras conv models that have a
        # different tensor memory layout, hence the reshape ops.
        layers = onnx_model.graph.node[1:-1]
        input_channel_first = True
    else:
        layers = onnx_model.graph.node
        input_channel_first = False
    return lbir.Model(
        name=name,
        layers=[_unwrap_qonnx_layer_to_lbir(lay) for lay in layers],
        input_channel_first=input_channel_first,
    )


def _unwrap_qonnx_layer_to_lbir(layer: NodeProto) -> lbir.LayerWrap:
    if layer.op_type == "QDense":
        qdense_str = onnx.helper.get_node_attr_value(layer, "qdense")
        return lbir.LayerWrap(dense=lbir.DenseConfig.FromString(qdense_str))
    elif layer.op_type == "QConv":
        qconv_str = onnx.helper.get_node_attr_value(layer, "qconv")
        return lbir.LayerWrap(conv2d=lbir.Conv2DConfig.FromString(qconv_str))
    elif layer.op_type == "MaxPool2D":
        maxpool_str = onnx.helper.get_node_attr_value(layer, "maxpool2d")
        return lbir.LayerWrap(maxpool2d=lbir.MaxPool2DConfig.FromString(maxpool_str))
    elif layer.op_type == "FFTConfig":
        fft_str = onnx.helper.get_node_attr_value(layer, "fft")
        return lbir.LayerWrap(fft=lbir.FFTConfig.FromString(fft_str))
    elif layer.op_type == "LMFEConfig":
        lmfe_str = onnx.helper.get_node_attr_value(layer, "lmfe")
        return lbir.LayerWrap(lmfe=lbir.LMFEConfig.FromString(lmfe_str))
    else:
        raise NotImplementedError(f"Unsupported op_type: {layer.op_type}.")