Skip to content

Commit f183191

Browse files
committed
Fix (nn/TruncAvgPool): Remove any quant tensor manual manipulation.
1 parent 13ca170 commit f183191

File tree

1 file changed

+2
-6
lines changed

1 file changed

+2
-6
lines changed

src/brevitas/nn/quant_avg_pool.py

+2-6
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@ def _avg_scaling(self):
5858
else:
5959
return self.kernel_size * self.kernel_size
6060

61+
# TODO: Replace with functional call
6162
def forward(self, input: Union[Tensor, QuantTensor]):
6263
x = self.unpack_input(input)
6364

@@ -71,8 +72,6 @@ def forward(self, input: Union[Tensor, QuantTensor]):
7172
if not isinstance(x, QuantTensor):
7273
x = self.cache_class.quant_tensor.set(value=x)
7374
y = AvgPool2d.forward(self, x)
74-
rescaled_value = y.value * self._avg_scaling
75-
y = y.set(value=rescaled_value)
7675
y = self.trunc_quant(y)
7776
else:
7877
y = AvgPool2d.forward(self, _unpack_quant_tensor(x))
@@ -123,6 +122,7 @@ def compute_kernel_size_stride(input_shape, output_shape):
123122
stride_list.append(stride)
124123
return kernel_size_list, stride_list
125124

125+
# TODO: Replace with functional call
126126
def forward(self, input: Union[Tensor, QuantTensor]):
127127
x = self.unpack_input(input)
128128

@@ -139,10 +139,6 @@ def forward(self, input: Union[Tensor, QuantTensor]):
139139
if not isinstance(x, QuantTensor):
140140
x = self.cache_class.quant_tensor.set(value=x)
141141
y = AdaptiveAvgPool2d.forward(self, x)
142-
k_size, stride = self.compute_kernel_size_stride(x.value.shape[2:], y.value.shape[2:])
143-
reduce_size = reduce(mul, k_size, 1)
144-
rescaled_value = y.value * reduce_size # remove avg scaling
145-
y = y.set(value=rescaled_value)
146142
y = self.trunc_quant(y)
147143
else:
148144
y = AdaptiveAvgPool2d.forward(self, _unpack_quant_tensor(x))

0 commit comments

Comments
 (0)