diff --git a/src/brevitas/export/common/handler/qcdq.py b/src/brevitas/export/common/handler/qcdq.py index b8f3c1fc8..c4659ac87 100644 --- a/src/brevitas/export/common/handler/qcdq.py +++ b/src/brevitas/export/common/handler/qcdq.py @@ -236,7 +236,8 @@ class QCDQCastDecoupledWeightQuantProxyHandlerMixin(QCDQCastWeightQuantProxyHand def symbolic_execution(self, x: Tensor): out, scale, zero_point, bit_width = super().symbolic_execution(x) # Return post-rounding scale and zero-point in place of pre-rounding as a placeholder - return out, scale, zero_point, scale, zero_point, bit_width + # The order of arguments must match the order in the forward method of DecoupledRescalingIntQuant + return out, scale, zero_point, bit_width, scale, zero_point class QCDQCastDecoupledWeightQuantWithInputProxyHandlerMixin(