From 6b825e99614f62d7cfae30776b6fc8f962a39ddc Mon Sep 17 00:00:00 2001
From: Jake Stevens
Date: Tue, 10 Feb 2026 11:55:17 -0800
Subject: [PATCH] Add quantize fused convbn bias pass (#17348)

Summary:
When performing QAT with a model that has a conv layer with no bias followed by batch norm, the fusion process creates a bias. This is done *after* observers are attached so the resulting bias is kept as float.

This diff adds a pass which grabs the proper qparams and applies them to the non-quantized bias.

Differential Revision: D92733079
---
 backends/transforms/BUCK                      |  23 ++
 .../quantize_fused_convbn_bias_pass.py        | 327 ++++++++++++++++++
 backends/transforms/targets.bzl               |  11 +
 .../test_quantize_fused_convbn_bias_pass.py   | 208 ++++++++++++++
 4 files changed, 569 insertions(+)
 create mode 100644 backends/transforms/quantize_fused_convbn_bias_pass.py
 create mode 100644 backends/transforms/test/test_quantize_fused_convbn_bias_pass.py

diff --git a/backends/transforms/BUCK b/backends/transforms/BUCK
index f5029903c21..86039c8db76 100644
--- a/backends/transforms/BUCK
+++ b/backends/transforms/BUCK
@@ -1,6 +1,29 @@
 load("@fbcode_macros//build_defs:build_file_migration.bzl", "fbcode_target", "non_fbcode_target")
+load("@fbcode_macros//build_defs:python_pytest.bzl", "python_pytest")
 load(":targets.bzl", "define_common_targets")
 
 oncall("executorch")
 
 fbcode_target(_kind = define_common_targets,)
+
+fbcode_target(_kind = python_pytest,
+    name = "test_quantize_fused_convbn_bias_pass",
+    srcs = [
+        "test/test_quantize_fused_convbn_bias_pass.py",
+    ],
+    deps = [
+        "//caffe2:torch",
+        ":quantize_fused_convbn_bias_pass",
+        "//executorch/backends/arm/quantizer:arm_quantizer",
+        "//executorch/backends/arm/test:arm_tester_lib",
+        "//executorch/backends/arm/test:arm_tester_serialize",
+        "//executorch/backends/arm/test:common",
+        "//executorch/backends/arm/tosa:tosa",
+        "//executorch/backends/nxp:quantizer",
+        "//executorch/backends/nxp:neutron_backend",
+        "//executorch/backends/xnnpack/test/tester:tester",
+        "//executorch/exir:lib",
+        "//executorch/kernels/quantized:custom_ops_generated_lib",
+        "fbsource//third-party/pypi/pytest:pytest",
+    ],
+)
diff --git a/backends/transforms/quantize_fused_convbn_bias_pass.py b/backends/transforms/quantize_fused_convbn_bias_pass.py
new file mode 100644
index 00000000000..f5afeac92e7
--- /dev/null
+++ b/backends/transforms/quantize_fused_convbn_bias_pass.py
@@ -0,0 +1,327 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+
+import torch
+from executorch.exir.dialects._ops import ops as exir_ops
+from executorch.exir.pass_base import ExportPass, PassResult
+from torch import fx
+from torch._export.utils import (
+    get_buffer,
+    get_lifted_tensor_constant,
+    get_param,
+    is_lifted_tensor_constant,
+    is_param,
+)
+from torch._guards import detect_fake_mode
+from torch.export.exported_program import InputKind, InputSpec, TensorArgument
+
+
+def _set_param(exported_program, node_or_name, tensor):
+    """Set or create a parameter in an exported program.
+
+    If node_or_name is a Node, updates the existing parameter or constant value.
+    If node_or_name is a string, creates a new parameter placeholder.
+
+    Args:
+        exported_program: Program whose state dict / constants, graph signature
+            and placeholder metadata are updated in place.
+        node_or_name: Existing placeholder node to overwrite, or the name of a
+            new parameter to create.
+        tensor: The new value.
+
+    Returns:
+        The updated or newly created placeholder node.
+
+    Raises:
+        ValueError: If node_or_name is a Node that is neither a parameter nor a
+            lifted tensor constant.
+    """
+    fake_mode = detect_fake_mode(
+        tuple(
+            node.meta["val"]
+            for node in exported_program.graph.nodes
+            if node.op == "placeholder"
+        )
+    )
+
+    if isinstance(node_or_name, fx.Node):
+        node = node_or_name
+        if node.name in exported_program.graph_signature.inputs_to_parameters:
+            name = exported_program.graph_signature.inputs_to_parameters[node.name]
+            exported_program.state_dict[name] = tensor
+        elif (
+            node.name
+            in exported_program.graph_signature.inputs_to_lifted_tensor_constants
+        ):
+            name = exported_program.graph_signature.inputs_to_lifted_tensor_constants[
+                node.name
+            ]
+            exported_program.constants[name] = tensor
+        else:
+            raise ValueError(
+                f"Node {node.name} is not a parameter or lifted tensor constant"
+            )
+        node.meta["val"] = fake_mode.from_tensor(tensor, static_shapes=True)
+        node.meta["val"].constant = tensor
+        return node
+
+    # Create a new parameter from string name
+    name = node_or_name
+    graph = exported_program.graph_module.graph
+    placeholders = [n for n in graph.nodes if n.op == "placeholder"]
+    input_name = f"arg_{name}"
+    with graph.inserting_before(placeholders[0]):
+        new_placeholder = graph.placeholder(input_name)
+    exported_program.graph_signature.input_specs.insert(
+        0,
+        InputSpec(
+            kind=InputKind.PARAMETER,
+            arg=TensorArgument(name=input_name),
+            target=name,
+            persistent=None,
+        ),
+    )
+    exported_program.state_dict[name] = tensor
+    new_placeholder.meta["val"] = fake_mode.from_tensor(tensor, static_shapes=True)
+    new_placeholder.meta["val"].constant = tensor
+    return new_placeholder
+
+
+class QuantizeFusedConvBnBiasPass(ExportPass):
+    """
+    BatchNorm fusion or QAT would introduce a bias that is not quantized if user
+    specified bias=False because it's not there yet when the quantizer runs. This pass
+    quantizes these biases so downstream passes can run.
+
+    Supports both aten and edge dialect operators.
+
+    Args:
+        exported_program: Program that owns the graph module passed to ``call``;
+            its parameters/constants are rewritten in place.
+        default_zero_bias: If True, convolutions without a bias get a zero bias
+            parameter, which is then quantized like any other float bias.
+    """
+
+    def __init__(self, exported_program, default_zero_bias=False) -> None:
+        super().__init__()
+        self.exported_program = exported_program
+        self.default_zero_bias = default_zero_bias
+
+    def _is_conv_node(self, node):
+        """Check if node is a convolution operation."""
+        return node.target in (
+            exir_ops.edge.aten.convolution.default,
+            torch.ops.aten.convolution.default,
+            torch.ops.aten.conv2d.default,
+        )
+
+    def _is_edge_dialect(self, node):
+        """Check if node uses edge dialect operators."""
+        return node.target == exir_ops.edge.aten.convolution.default
+
+    def _get_or_create_bias_node(self, node):
+        """Get existing bias node or create a default zero bias if enabled.
+
+        A bias that is missing from ``node.args`` entirely (the bias argument of
+        ``aten.conv2d`` is optional) is looked up in ``node.kwargs`` and otherwise
+        treated like an explicit ``None`` instead of failing to unpack.
+
+        Returns:
+            ``(input, weight, bias)`` nodes of the convolution, or
+            ``(None, None, None)`` if there is no bias and ``default_zero_bias``
+            is disabled.
+        """
+        args = list(node.args)
+        input_dequant, weight_dequant = args[0], args[1]
+        bias_node = args[2] if len(args) > 2 else node.kwargs.get("bias")
+        if bias_node is not None:
+            return input_dequant, weight_dequant, bias_node
+        if not self.default_zero_bias:
+            return None, None, None
+
+        # Dim 1 of the conv output value is used as the number of channels.
+        channel = node.meta["val"].shape[1]
+        bias_node = _set_param(
+            self.exported_program,
+            node.name + "_default_zero_bias",
+            torch.zeros(channel),
+        )
+        if len(args) > 2:
+            args[2] = bias_node
+        else:
+            # Bias was omitted positionally; add it as the third positional
+            # argument and drop any explicit ``bias=None`` keyword.
+            args.append(bias_node)
+            node.kwargs = {k: v for k, v in node.kwargs.items() if k != "bias"}
+        node.args = tuple(args)
+        return input_dequant, weight_dequant, bias_node
+
+    def _get_bias_tensor(self, bias_node):
+        """Extract bias tensor from parameter or lifted constant."""
+        if is_param(self.exported_program, bias_node):
+            return get_param(self.exported_program, bias_node)
+        elif is_lifted_tensor_constant(self.exported_program, bias_node):
+            return get_lifted_tensor_constant(self.exported_program, bias_node)
+        return None
+
+    def _unwrap_unsqueeze(self, input_dequant, is_edge):
+        """Unwrap unsqueeze operations from input dequantize node."""
+        if is_edge:
+            unsqueeze_targets = (exir_ops.edge.aten.unsqueeze_copy.default,)
+        else:
+            unsqueeze_targets = (
+                torch.ops.aten.unsqueeze_copy.default,
+                torch.ops.aten.unsqueeze.default,
+            )
+        if input_dequant.target in unsqueeze_targets:
+            return input_dequant.args[0]
+        return input_dequant
+
+    def _create_dequant_val(self, bias_node, bias):
+        """Create fake tensor value for dequantized bias output."""
+        bias_val = bias_node.meta.get("val")
+        if bias_val is not None:
+            return bias_val.to(torch.float32)
+        return torch.empty(bias.shape, dtype=torch.float32)
+
+    def _quantize_bias_per_channel(
+        self, graph_module, node, bias, bias_node, bias_scale, dequant_val, is_edge
+    ):
+        """Quantize bias per-channel and insert dequantize node."""
+        qbias = torch.ops.quantized_decomposed.quantize_per_channel.default(
+            bias,
+            bias_scale,
+            torch.zeros(bias_scale.shape, dtype=torch.int32),
+            0,
+            -(2**31),
+            2**31 - 1,
+            torch.int32,
+        )
+        _set_param(self.exported_program, bias_node, qbias)
+
+        dq_per_channel = (
+            exir_ops.edge.quantized_decomposed.dequantize_per_channel.default
+            if is_edge
+            else torch.ops.quantized_decomposed.dequantize_per_channel.default
+        )
+
+        with graph_module.graph.inserting_before(node):
+            bias_dequant = graph_module.graph.call_function(
+                dq_per_channel,
+                (
+                    bias_node,
+                    bias_scale,
+                    torch.zeros(bias_scale.shape, dtype=torch.int32),
+                    0,
+                    -(2**31),
+                    2**31 - 1,
+                    torch.int32,
+                ),
+            )
+            bias_dequant.meta["val"] = dequant_val
+            node.replace_input_with(bias_node, bias_dequant)
+
+    def _quantize_bias_per_tensor(
+        self, graph_module, node, bias, bias_node, bias_scale, dequant_val, is_edge
+    ):
+        """Quantize bias per-tensor and insert dequantize node."""
+        qbias = torch.ops.quantized_decomposed.quantize_per_tensor.default(
+            bias, bias_scale, 0, -(2**31), 2**31 - 1, torch.int32
+        )
+        _set_param(self.exported_program, bias_node, qbias)
+
+        dq_per_tensor = (
+            exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default
+            if is_edge
+            else torch.ops.quantized_decomposed.dequantize_per_tensor.default
+        )
+
+        with graph_module.graph.inserting_before(node):
+            bias_dequant = graph_module.graph.call_function(
+                dq_per_tensor,
+                (bias_node, bias_scale, 0, -(2**31), 2**31 - 1, torch.int32),
+            )
+            bias_dequant.meta["val"] = dequant_val
+            node.replace_input_with(bias_node, bias_dequant)
+
+    def call(self, graph_module: fx.GraphModule) -> PassResult:
+        """Quantize float convolution biases to int32 and dequantize them in-graph.
+
+        The bias scale is the input scale times the weight scale. A weight scale
+        given as a graph node selects per-channel quantization; otherwise
+        per-tensor quantization is used.
+
+        Raises:
+            AssertionError: If the (unsqueeze-unwrapped) conv input is not the
+                dequantize_per_tensor op of the conv's dialect.
+            ValueError: If a per-channel weight scale node is not a buffer.
+        """
+        modified = False
+        for node in graph_module.graph.nodes:
+            if not self._is_conv_node(node):
+                continue
+
+            is_edge = self._is_edge_dialect(node)
+
+            input_dequant, weight_dequant, bias_node = self._get_or_create_bias_node(
+                node
+            )
+            if bias_node is None:
+                continue
+
+            bias = self._get_bias_tensor(bias_node)
+            if bias is None or bias.dtype == torch.int32:
+                continue
+
+            input_dequant = self._unwrap_unsqueeze(input_dequant, is_edge)
+
+            dq_per_tensor = (
+                exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default
+                if is_edge
+                else torch.ops.quantized_decomposed.dequantize_per_tensor.default
+            )
+            assert (
+                input_dequant.target == dq_per_tensor
+            ), f"Expected dequantize_per_tensor, got {input_dequant.target}"
+
+            dequant_val = self._create_dequant_val(bias_node, bias)
+
+            if isinstance(weight_dequant.args[1], torch.fx.node.Node):
+                weight_scale = get_buffer(self.exported_program, weight_dequant.args[1])
+                if weight_scale is None:
+                    # get_buffer() returns None for non-buffer nodes; fail with a
+                    # clear message instead of a TypeError in the multiplication.
+                    raise ValueError(
+                        f"Weight scale {weight_dequant.args[1].name} of "
+                        f"{node.name} is not a buffer"
+                    )
+                bias_scale = input_dequant.args[1] * weight_scale
+                self._quantize_bias_per_channel(
+                    graph_module,
+                    node,
+                    bias,
+                    bias_node,
+                    bias_scale,
+                    dequant_val,
+                    is_edge,
+                )
+            else:
+                weight_scale = weight_dequant.args[1]
+                bias_scale = input_dequant.args[1] * weight_scale
+                self._quantize_bias_per_tensor(
+                    graph_module,
+                    node,
+                    bias,
+                    bias_node,
+                    bias_scale,
+                    dequant_val,
+                    is_edge,
+                )
+
+            modified = True
+        graph_module.recompile()
+        return PassResult(graph_module, modified)
diff --git a/backends/transforms/targets.bzl b/backends/transforms/targets.bzl
index 3cda58f6426..b5299350009 100644
--- a/backends/transforms/targets.bzl
+++ b/backends/transforms/targets.bzl
@@ -204,6 +204,17 @@ def define_common_targets():
         ],
     )
 
+    runtime.python_library(
+        name = "quantize_fused_convbn_bias_pass",
+        srcs = ["quantize_fused_convbn_bias_pass.py"],
+        visibility = ["PUBLIC"],
+        deps = [
+            "//caffe2:torch",
+            "//executorch/exir:pass_base",
+            "//executorch/exir/dialects:lib",
+        ],
+    )
+
     runtime.python_test(
         name = "test_duplicate_dynamic_quant_chain",
         srcs = [
diff --git a/backends/transforms/test/test_quantize_fused_convbn_bias_pass.py b/backends/transforms/test/test_quantize_fused_convbn_bias_pass.py
new file mode 100644
index 00000000000..c27eca5f915
--- /dev/null
+++ b/backends/transforms/test/test_quantize_fused_convbn_bias_pass.py
@@ -0,0 +1,208 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from typing import Tuple
+
+import pytest
+import torch
+from executorch.backends.arm.quantizer.arm_quantizer import (
+    get_symmetric_quantization_config,
+    TOSAQuantizer,
+)
+from executorch.backends.arm.test import common
+from executorch.backends.arm.test.tester.test_pipeline import PassPipeline
+from executorch.backends.arm.tosa import TosaSpecification
+from executorch.backends.nxp.backend.neutron_target_spec import NeutronTargetSpec
+from executorch.backends.nxp.quantizer.neutron_quantizer import NeutronQuantizer
+from executorch.backends.nxp.quantizer.utils import calibrate_and_quantize
+from executorch.backends.transforms.quantize_fused_convbn_bias_pass import (
+    QuantizeFusedConvBnBiasPass,
+)
+from executorch.backends.xnnpack.test.tester.tester import Quantize
+from executorch.exir import to_edge
+from executorch.exir.dialects._ops import ops as exir_ops
+from torch import nn
+from torch.export import export
+
+
+input_t = Tuple[torch.Tensor]
+
+
+class ConvBnNoBias(nn.Module):
+    """Conv2d with bias=False followed by BatchNorm. QAT fusion introduces a bias."""
+
+    def __init__(self, per_channel: bool = True) -> None:
+        super().__init__()
+        self.conv = nn.Conv2d(3, 16, kernel_size=3, bias=False)
+        self.bn = nn.BatchNorm2d(16)
+
+    def get_inputs(self) -> input_t:
+        return (torch.randn(1, 3, 32, 32),)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return self.bn(self.conv(x))
+
+
+class ConvBnReluNoBias(nn.Module):
+    """Conv2d with bias=False, BatchNorm, and ReLU."""
+
+    def __init__(self) -> None:
+        super().__init__()
+        self.conv = nn.Conv2d(3, 16, kernel_size=3, bias=False)
+        self.bn = nn.BatchNorm2d(16)
+        self.relu = nn.ReLU()
+
+    def get_inputs(self) -> input_t:
+        return (torch.randn(1, 3, 32, 32),)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return self.relu(self.bn(self.conv(x)))
+
+
+class Conv1dBnNoBias(nn.Module):
+    """Conv1d with bias=False followed by BatchNorm."""
+
+    def __init__(self) -> None:
+        super().__init__()
+        self.conv = nn.Conv1d(3, 8, kernel_size=3, bias=False)
+        self.bn = nn.BatchNorm1d(8)
+
+    def get_inputs(self) -> input_t:
+        return (torch.randn(2, 3, 16),)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return self.bn(self.conv(x))
+
+
+# --- ARM (TOSA) tests ---
+
+arm_models = {
+    "conv2d_bn_no_bias_per_channel": (ConvBnNoBias(), True),
+    "conv2d_bn_no_bias_per_tensor": (ConvBnNoBias(), False),
+    "conv2d_bn_relu_no_bias_per_channel": (ConvBnReluNoBias(), True),
+    "conv2d_bn_relu_no_bias_per_tensor": (ConvBnReluNoBias(), False),
+    "conv1d_bn_no_bias_per_channel": (Conv1dBnNoBias(), True),
+    "conv1d_bn_no_bias_per_tensor": (Conv1dBnNoBias(), False),
+}
+
+
+@common.parametrize("test_data", arm_models)
+def test_quantize_fused_convbn_bias_arm_qat(test_data) -> None:
+    """
+    Test that QuantizeFusedConvBnBiasPass correctly quantizes the bias
+    introduced by BatchNorm fusion during QAT when the original conv has bias=False.
+    Uses the ARM TOSA quantizer.
+    """
+    model, per_channel = test_data
+    pipeline = PassPipeline[input_t](
+        model,
+        model.get_inputs(),
+        quantize=True,
+        passes_with_exported_program=[QuantizeFusedConvBnBiasPass],
+    )
+
+    quantizer = TOSAQuantizer(TosaSpecification.create_from_string("TOSA-1.0+INT"))
+    pipeline.change_args(
+        "quantize",
+        Quantize(
+            quantizer=quantizer,
+            quantization_config=get_symmetric_quantization_config(
+                is_qat=True, is_per_channel=per_channel
+            ),
+        ),
+    )
+
+    pipeline.pop_stage("run_method_and_compare_outputs")
+    pipeline.run()
+
+
+# --- NXP (Neutron) tests ---
+
+
+def _run_nxp_qat_pass(model: nn.Module, use_edge: bool = True) -> None:
+    """Quantize a model with NXP's NeutronQuantizer in QAT mode, optionally convert to
+    edge, and verify that QuantizeFusedConvBnBiasPass quantizes the fused bias."""
+    example_input = model.get_inputs()
+
+    target_spec = NeutronTargetSpec(
+        target="imxrt700", neutron_converter_flavor="SDK_25_12"
+    )
+    quantizer = NeutronQuantizer(target_spec, is_qat=True)
+
+    exported = export(model, example_input, strict=True)
+    quantized_model = calibrate_and_quantize(
+        model=exported,
+        calibration_inputs=[example_input],
+        quantizer=quantizer,
+        is_qat=True,
+    )
+
+    exported_program = export(quantized_model, example_input, strict=True)
+
+    if use_edge:
+        edge_program_manager = to_edge(exported_program)
+        exported_program = edge_program_manager.exported_program()
+        conv_targets = (exir_ops.edge.aten.convolution.default,)
+        dequant_targets = (
+            exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
+            exir_ops.edge.quantized_decomposed.dequantize_per_channel.default,
+        )
+    else:
+        conv_targets = (
+            torch.ops.aten.convolution.default,
+            torch.ops.aten.conv2d.default,
+        )
+        dequant_targets = (
+            torch.ops.quantized_decomposed.dequantize_per_tensor.default,
+            torch.ops.quantized_decomposed.dequantize_per_channel.default,
+        )
+
+    pass_instance = QuantizeFusedConvBnBiasPass(exported_program)
+    result = pass_instance.call(exported_program.graph_module)
+
+    assert result.modified
+
+    # Every convolution bias should now flow through a dequantize node.
+    for node in exported_program.graph_module.graph.nodes:
+        if node.target not in conv_targets:
+            continue
+        bias = node.args[2]
+        assert bias is not None, "Bias should not be None after pass"
+        assert (
+            bias.target in dequant_targets
+        ), f"Bias should be dequantized, got {bias.target}"
+
+
+@pytest.mark.parametrize(
+    "model",
+    [
+        pytest.param(ConvBnNoBias(), id="conv2d_bn_no_bias"),
+        pytest.param(ConvBnReluNoBias(), id="conv2d_bn_relu_no_bias"),
+    ],
+)
+def test_quantize_fused_convbn_bias_nxp_qat(model: nn.Module) -> None:
+    """
+    Test that QuantizeFusedConvBnBiasPass correctly quantizes the bias
+    introduced by BatchNorm fusion during QAT when the original conv has bias=False.
+    Uses the NXP Neutron quantizer with edge dialect.
+    """
+    _run_nxp_qat_pass(model, use_edge=True)
+
+
+@pytest.mark.parametrize(
+    "model",
+    [
+        pytest.param(ConvBnNoBias(), id="conv2d_bn_no_bias"),
+        pytest.param(ConvBnReluNoBias(), id="conv2d_bn_relu_no_bias"),
+    ],
+)
+def test_quantize_fused_convbn_bias_nxp_qat_aten(model: nn.Module) -> None:
+    """
+    Test that QuantizeFusedConvBnBiasPass correctly quantizes the bias
+    on aten-dialect graphs (without edge conversion).
+    Uses the NXP Neutron quantizer.
+    """
+    _run_nxp_qat_pass(model, use_edge=False)