diff --git a/backends/nxp/BUCK b/backends/nxp/BUCK
index ab22cc35466..8355874db92 100644
--- a/backends/nxp/BUCK
+++ b/backends/nxp/BUCK
@@ -55,6 +55,7 @@ fbcode_target(_kind = runtime.python_library,
     deps = [
         ":aten_passes",
         "//caffe2:torch",
+        "//executorch/backends/transforms:quantize_fused_convbn_bias_pass",
        "//pytorch/ao:torchao",  # @manual
     ],
 )
diff --git a/backends/nxp/quantizer/utils.py b/backends/nxp/quantizer/utils.py
index 459f31ec7da..a34610eb257 100644
--- a/backends/nxp/quantizer/utils.py
+++ b/backends/nxp/quantizer/utils.py
@@ -29,6 +29,10 @@
     prepare_pt2e,
     prepare_qat_pt2e,
 )
+from executorch.backends.transforms.quantize_fused_convbn_bias_pass import (
+    QuantizeFusedConvBnBiasAtenPass,
+)
+
 from torchao.quantization.pt2e.quantizer.quantizer import Q_ANNOTATION_KEY, Quantizer
@@ -162,14 +166,13 @@ def find_sequential_partitions_aten(
 def calibrate_and_quantize(
-    model: ExportedProgram | fx.GraphModule,
+    model: ExportedProgram,
     calibration_inputs: Iterable[tuple[torch.Tensor, ...]],
     quantizer: Quantizer,
     is_qat: bool = False,
 ) -> fx.GraphModule:
     """Quantize the provided model.
-
-    :param model: Aten model (or it's GraphModule representation) to quantize.
+    :param model: Aten exported model to quantize.
     :param calibration_inputs: Either a tuple of calibration input tensors where each element
         corresponds to a model input. Or an iterator over such tuples.
     :param quantizer: Quantizer to use.
@@ -179,8 +182,7 @@ def calibrate_and_quantize(
     :return: Quantized GraphModule.
     """
-    if isinstance(model, ExportedProgram):
-        model = model.module()
+    model = model.module()
 
     if is_qat:
         m = prepare_qat_pt2e(model, quantizer)
@@ -192,4 +194,5 @@ def calibrate_and_quantize(
         m(*data)
 
     m = convert_pt2e(m)
+    QuantizeFusedConvBnBiasAtenPass()(m)
     return m
diff --git a/backends/nxp/tests/BUCK b/backends/nxp/tests/BUCK
index 24da34357db..66ec9ba1f9b 100644
--- a/backends/nxp/tests/BUCK
+++ b/backends/nxp/tests/BUCK
@@ -53,3 +53,18 @@ fbcode_target(_kind = python_pytest,
         ":models",
     ]
 )
+
+fbcode_target(_kind = python_pytest,
+    name = "test_batch_norm_fusion",
+    srcs = [
+        "test_batch_norm_fusion.py",
+    ],
+    deps = [
+        "//caffe2:torch",
+        "//executorch/backends/nxp:neutron_backend",
+        ":executorch_pipeline",
+        ":models",
+        "fbsource//third-party/pypi/pytest:pytest",
+        "fbsource//third-party/pypi/numpy:numpy",
+    ],
+)
diff --git a/backends/nxp/tests/executors.py b/backends/nxp/tests/executors.py
index f9156b0b86e..319f372b5fa 100644
--- a/backends/nxp/tests/executors.py
+++ b/backends/nxp/tests/executors.py
@@ -3,6 +3,8 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+from __future__ import annotations
+
 import warnings
 from typing import Callable, Dict, Union
@@ -36,7 +38,10 @@
 try:
     import tensorflow.lite as tflite
 except ModuleNotFoundError:
-    import tflite_runtime.interpreter as tflite
+    try:
+        import tflite_runtime.interpreter as tflite
+    except ModuleNotFoundError:
+        tflite = None
 
 
 class EdgeProgramExecutor:
@@ -85,7 +90,7 @@ def __init__(
         saved_model_name="model.tflite",
         delegate_path=None,
         num_threads=None,
-        op_resolver_type=tflite.experimental.OpResolverType.AUTO,
+        op_resolver_type=None,
     ):
         """
         Construct TFLiteExecutor used to quickly run inference on TFLite model.
@@ -105,6 +110,8 @@
             https://www.tensorflow.org/api_docs/python/tf/lite/Interpreter for details.
             Default value is tflite.experimental.OpResolverType.AUTO.
""" + if op_resolver_type is None: + op_resolver_type = tflite.experimental.OpResolverType.AUTO assert model_path is not None or model_content is not None assert model_path is None or model_content is None @@ -310,9 +317,12 @@ def convert_run_compare( tflite_input_preprocess: TFLiteIOPreprocess = TFLiteIOPreprocess(), # noqa B008 tflite_output_preprocess: TFLiteIOPreprocess = TFLiteIOPreprocess(), # noqa B008 conversion_config: ConversionConfig = ConversionConfig(), # noqa B008 - tflite_op_resolver_type=tflite.experimental.OpResolverType.AUTO, + tflite_op_resolver_type=None, ) -> (TFLiteExecutor, EdgeProgramExecutor): + if tflite_op_resolver_type is None: + tflite_op_resolver_type = tflite.experimental.OpResolverType.AUTO + if tfl_model is None: NodeFormatInference(edge_program).identify_node_formats() tfl_model, _ = EdgeProgramToIRConverter().convert_program( diff --git a/backends/nxp/tests/test_batch_norm_fusion.py b/backends/nxp/tests/test_batch_norm_fusion.py index eeb4b03d7a6..71d3620bd3e 100644 --- a/backends/nxp/tests/test_batch_norm_fusion.py +++ b/backends/nxp/tests/test_batch_norm_fusion.py @@ -25,6 +25,8 @@ from executorch.backends.nxp.tests.executors import OverrideTargetSupportCheck from torch import nn +from executorch.backends.nxp.tests.models import ConvBNModule + @pytest.fixture(autouse=True) def reseed_model_per_test_run(): @@ -231,3 +233,24 @@ def unsupported_target(*_): # Accept all input arguments and return `False`. node.op == "call_function" and "batch_norm" in node.target.__name__ for node in nodes ) +@pytest.mark.parametrize( + "conv_module", + ["conv2d"], +) +def test_biasless_convbn_fusion_qat( + conv_module, +): + if conv_module.startswith("conv1d"): + input_shape = (1, 3, 32) + elif conv_module.startswith("conv2d"): + input_shape = (1, 3, 32, 32) + else: # conv3d + input_shape = (1, 3, 32, 32, 32) + + model = ConvBNModule(conv_module, conv_bias=False, bn_affine=True) + + edge_program = to_quantized_edge_program( + model, input_shape, use_qat=True, use_neutron_for_format_conversion=False + ).exported_program() + + assert any("lowered_module" in node.name for node in edge_program.graph.nodes) diff --git a/backends/transforms/BUCK b/backends/transforms/BUCK index f5029903c21..86039c8db76 100644 --- a/backends/transforms/BUCK +++ b/backends/transforms/BUCK @@ -1,6 +1,29 @@ load("@fbcode_macros//build_defs:build_file_migration.bzl", "fbcode_target", "non_fbcode_target") +load("@fbcode_macros//build_defs:python_pytest.bzl", "python_pytest") load(":targets.bzl", "define_common_targets") oncall("executorch") fbcode_target(_kind = define_common_targets,) + +fbcode_target(_kind = python_pytest, + name = "test_quantize_fused_convbn_bias_pass", + srcs = [ + "test/test_quantize_fused_convbn_bias_pass.py", + ], + deps = [ + "//caffe2:torch", + ":quantize_fused_convbn_bias_pass", + "//executorch/backends/arm/quantizer:arm_quantizer", + "//executorch/backends/arm/test:arm_tester_lib", + "//executorch/backends/arm/test:arm_tester_serialize", + "//executorch/backends/arm/test:common", + "//executorch/backends/arm/tosa:tosa", + "//executorch/backends/nxp:quantizer", + "//executorch/backends/nxp:neutron_backend", + "//executorch/backends/xnnpack/test/tester:tester", + "//executorch/exir:lib", + "//executorch/kernels/quantized:custom_ops_generated_lib", + "fbsource//third-party/pypi/pytest:pytest", + ], +) diff --git a/backends/transforms/quantize_fused_convbn_bias_pass.py b/backends/transforms/quantize_fused_convbn_bias_pass.py new file mode 100644 index 00000000000..1f1887923a9 
--- /dev/null
+++ b/backends/transforms/quantize_fused_convbn_bias_pass.py
@@ -0,0 +1,365 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+
+import torch
+from executorch.exir.dialects._ops import ops as exir_ops
+from executorch.exir.pass_base import ExportPass
+from torch import fx
+from torch._export.utils import (
+    get_buffer,
+    get_lifted_tensor_constant,
+    get_param,
+    is_lifted_tensor_constant,
+    is_param,
+)
+from torch._guards import detect_fake_mode
+from torch.export.exported_program import InputKind, InputSpec, TensorArgument
+from torch.fx.passes.infra.pass_base import PassBase, PassResult
+
+
+# --- ExportedProgram param helpers ---
+
+
+def _set_param_ep(exported_program, node_or_name, tensor, insert_before=None):
+    """Set or create a parameter in an exported program.
+
+    If node_or_name is a Node, updates the existing parameter or constant value.
+    If node_or_name is a string, creates a new parameter placeholder.
+    """
+    fake_mode = detect_fake_mode(
+        tuple(
+            node.meta["val"]
+            for node in exported_program.graph.nodes
+            if node.op == "placeholder"
+        )
+    )
+
+    if isinstance(node_or_name, fx.Node):
+        node = node_or_name
+        if node.name in exported_program.graph_signature.inputs_to_parameters:
+            name = exported_program.graph_signature.inputs_to_parameters[node.name]
+            exported_program.state_dict[name] = torch.nn.Parameter(
+                tensor, requires_grad=False
+            )
+        elif (
+            node.name
+            in exported_program.graph_signature.inputs_to_lifted_tensor_constants
+        ):
+            name = exported_program.graph_signature.inputs_to_lifted_tensor_constants[
+                node.name
+            ]
+            exported_program.constants[name] = tensor
+        else:
+            raise ValueError(
+                f"Node {node.name} is not a parameter or lifted tensor constant"
+            )
+        node.meta["val"] = fake_mode.from_tensor(tensor, static_shapes=True)
+        node.meta["val"].constant = tensor
+        return node
+
+    # Create a new parameter from string name
+    name = node_or_name
+    graph = exported_program.graph_module.graph
+    placeholders = [n for n in graph.nodes if n.op == "placeholder"]
+    input_name = f"arg_{name}"
+    with graph.inserting_before(placeholders[0]):
+        new_placeholder = graph.placeholder(input_name)
+    exported_program.graph_signature.input_specs.insert(
+        0,
+        InputSpec(
+            kind=InputKind.PARAMETER,
+            arg=TensorArgument(name=input_name),
+            target=name,
+            persistent=None,
+        ),
+    )
+    exported_program.state_dict[name] = torch.nn.Parameter(tensor, requires_grad=False)
+    new_placeholder.meta["val"] = fake_mode.from_tensor(tensor, static_shapes=True)
+    new_placeholder.meta["val"].constant = tensor
+    return new_placeholder
+
+
+def _get_bias_tensor_ep(exported_program, bias_node):
+    """Extract bias tensor from parameter or lifted constant in an ExportedProgram."""
+    if is_param(exported_program, bias_node):
+        return get_param(exported_program, bias_node)
+    elif is_lifted_tensor_constant(exported_program, bias_node):
+        return get_lifted_tensor_constant(exported_program, bias_node)
+    return None
+
+
+# --- GraphModule param helpers ---
+
+
+def _get_tensor_from_node(graph_module, node):
+    """Get tensor from a get_attr node on a GraphModule."""
+    if node is None or node.op != "get_attr":
+        return None
+    target_atoms = node.target.split(".")
+    attr = graph_module
+    for atom in target_atoms:
+        if not hasattr(attr, atom):
+            return None
+        attr = getattr(attr, atom)
+    return attr
+
+
+def _set_param_gm(graph_module, node_or_name, tensor, insert_before=None):
+    """Set or create a parameter on a GraphModule using get_attr nodes.
+
+    If node_or_name is a Node, updates the existing parameter tensor.
+    If node_or_name is a string, creates a new get_attr node.
+    """
+    if isinstance(node_or_name, fx.Node):
+        node = node_or_name
+        target_atoms = node.target.split(".")
+        parent = graph_module
+        for atom in target_atoms[:-1]:
+            parent = getattr(parent, atom)
+        setattr(
+            parent,
+            target_atoms[-1],
+            torch.nn.Parameter(tensor, requires_grad=False),
+        )
+        if "val" in node.meta:
+            fake_mode = detect_fake_mode(
+                tuple(
+                    n.meta["val"] for n in graph_module.graph.nodes if "val" in n.meta
+                )
+            )
+            if fake_mode is not None:
+                node.meta["val"] = fake_mode.from_tensor(tensor, static_shapes=True)
+            else:
+                node.meta["val"] = tensor
+        return node
+
+    # Create new get_attr node
+    name = node_or_name
+    graph_module.register_parameter(
+        name, torch.nn.Parameter(tensor, requires_grad=False)
+    )
+    with graph_module.graph.inserting_before(insert_before):
+        new_node = graph_module.graph.get_attr(name)
+    fake_mode = detect_fake_mode(
+        tuple(n.meta["val"] for n in graph_module.graph.nodes if "val" in n.meta)
+    )
+    if fake_mode is not None:
+        new_node.meta["val"] = fake_mode.from_tensor(tensor, static_shapes=True)
+    else:
+        new_node.meta["val"] = tensor
+    return new_node
+
+
+# --- Shared core logic ---
+
+
+def _quantize_fused_conv_bias(
+    graph_module,
+    conv_targets,
+    unsqueeze_targets,
+    dq_per_tensor,
+    dq_per_channel,
+    get_bias_tensor,
+    set_param,
+    get_weight_scale_tensor,
+    default_zero_bias=False,
+):
+    """Core logic for quantizing biases introduced by BatchNorm fusion/QAT.
+
+    BatchNorm fusion or QAT introduces a bias to conv layers that originally had
+    bias=False. Since the bias is added after the quantizer runs, it lacks proper
+    quantize->dequantize nodes. This function adds them.
+
+    Args:
+        graph_module: The graph module to transform.
+        conv_targets: Tuple of conv op targets to match.
+        unsqueeze_targets: Tuple of unsqueeze op targets to unwrap.
+        dq_per_tensor: The dequantize_per_tensor op for this dialect.
+        dq_per_channel: The dequantize_per_channel op for this dialect.
+        get_bias_tensor: Callable(node) -> Optional[Tensor].
+        set_param: Callable(node_or_name, tensor, insert_before=None) -> Node.
+        get_weight_scale_tensor: Callable(node) -> Tensor.
+        default_zero_bias: If True, create zero bias for conv nodes without bias.
+
+    Returns:
+        True if any modifications were made.
+ """ + modified = False + for node in graph_module.graph.nodes: + if node.target not in conv_targets: + continue + + input_dequant, weight_dequant, bias_node, *_ = node.args + + if bias_node is None: + if default_zero_bias: + channel = node.meta["val"].shape[1] + bias_node = set_param( + node.name + "_default_zero_bias", + torch.zeros(channel), + insert_before=node, + ) + args = list(node.args) + args[2] = bias_node + node.args = tuple(args) + else: + continue + + bias = get_bias_tensor(bias_node) + if bias is None or bias.dtype == torch.int32: + continue + + if input_dequant.target in unsqueeze_targets: + input_dequant = input_dequant.args[0] + + assert ( + input_dequant.target == dq_per_tensor + ), f"Expected dequantize_per_tensor, got {input_dequant.target}" + + bias_val = bias_node.meta.get("val") + dequant_val = ( + bias_val.to(torch.float32) + if bias_val is not None + else torch.empty(bias.shape, dtype=torch.float32) + ) + + if isinstance(weight_dequant.args[1], torch.fx.node.Node): + weight_scale = get_weight_scale_tensor(weight_dequant.args[1]) + bias_scale = input_dequant.args[1] * weight_scale + + bias_zp = torch.zeros(bias_scale.shape, dtype=torch.int32) + qbias = torch.ops.quantized_decomposed.quantize_per_channel.default( + bias, + bias_scale, + bias_zp, + 0, + -(2**31), + 2**31 - 1, + torch.int32, + ) + set_param(bias_node, qbias) + + scale_node = set_param( + node.name + "_bias_scale", bias_scale, insert_before=node + ) + zp_node = set_param( + node.name + "_bias_zero_point", bias_zp, insert_before=node + ) + + with graph_module.graph.inserting_before(node): + bias_dequant = graph_module.graph.call_function( + dq_per_channel, + ( + bias_node, + scale_node, + zp_node, + 0, + -(2**31), + 2**31 - 1, + torch.int32, + ), + ) + bias_dequant.meta["val"] = dequant_val + node.replace_input_with(bias_node, bias_dequant) + else: + weight_scale = weight_dequant.args[1] + bias_scale = input_dequant.args[1] * weight_scale + + qbias = torch.ops.quantized_decomposed.quantize_per_tensor.default( + bias, bias_scale, 0, -(2**31), 2**31 - 1, torch.int32 + ) + set_param(bias_node, qbias) + + with graph_module.graph.inserting_before(node): + bias_dequant = graph_module.graph.call_function( + dq_per_tensor, + (bias_node, bias_scale, 0, -(2**31), 2**31 - 1, torch.int32), + ) + bias_dequant.meta["val"] = dequant_val + node.replace_input_with(bias_node, bias_dequant) + + modified = True + + graph_module.recompile() + return modified + + +# --- Pass classes --- + + +class QuantizeFusedConvBnBiasPass(ExportPass): + """Quantize biases introduced by BatchNorm fusion/QAT on edge dialect graphs. + + Works on ExportedPrograms after to_edge() conversion. 
+ """ + + def __init__(self, exported_program, default_zero_bias=False) -> None: + super().__init__() + self.exported_program = exported_program + self.default_zero_bias = default_zero_bias + + def call(self, graph_module: fx.GraphModule) -> PassResult: + ep = self.exported_program + modified = _quantize_fused_conv_bias( + graph_module, + conv_targets=(exir_ops.edge.aten.convolution.default,), + unsqueeze_targets=(exir_ops.edge.aten.unsqueeze_copy.default,), + dq_per_tensor=exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default, + dq_per_channel=exir_ops.edge.quantized_decomposed.dequantize_per_channel.default, + get_bias_tensor=lambda node: _get_bias_tensor_ep(ep, node), + set_param=lambda n, t, insert_before=None: _set_param_ep(ep, n, t), + get_weight_scale_tensor=lambda node: get_buffer(ep, node), + default_zero_bias=self.default_zero_bias, + ) + return PassResult(graph_module, modified) + + +class QuantizeFusedConvBnBiasAtenPass(PassBase): + """Quantize biases introduced by BatchNorm fusion/QAT on aten dialect graphs. + + Operates on a GraphModule. If the graph_module came from an ExportedProgram + (params are placeholder nodes), pass the exported_program so params can be + resolved. If operating on a plain GraphModule (params are get_attr nodes), + exported_program can be omitted. + """ + + def __init__(self, exported_program=None, default_zero_bias=False) -> None: + self.exported_program = exported_program + self.default_zero_bias = default_zero_bias + + def call(self, graph_module: fx.GraphModule) -> PassResult: + ep = self.exported_program + if ep is not None: + get_bias = lambda node: _get_bias_tensor_ep(ep, node) + set_param = lambda n, t, insert_before=None: _set_param_ep(ep, n, t) + get_scale = lambda node: get_buffer(ep, node) + else: + get_bias = lambda node: _get_tensor_from_node(graph_module, node) + set_param = lambda n, t, insert_before=None: _set_param_gm( + graph_module, n, t, insert_before + ) + get_scale = lambda node: _get_tensor_from_node(graph_module, node) + + modified = _quantize_fused_conv_bias( + graph_module, + conv_targets=( + torch.ops.aten.convolution.default, + torch.ops.aten.conv2d.default, + torch.ops.aten.conv_transpose2d.input, + ), + unsqueeze_targets=( + torch.ops.aten.unsqueeze_copy.default, + torch.ops.aten.unsqueeze.default, + ), + dq_per_tensor=torch.ops.quantized_decomposed.dequantize_per_tensor.default, + dq_per_channel=torch.ops.quantized_decomposed.dequantize_per_channel.default, + get_bias_tensor=get_bias, + set_param=set_param, + get_weight_scale_tensor=get_scale, + default_zero_bias=self.default_zero_bias, + ) + return PassResult(graph_module, modified) diff --git a/backends/transforms/targets.bzl b/backends/transforms/targets.bzl index 3cda58f6426..b5299350009 100644 --- a/backends/transforms/targets.bzl +++ b/backends/transforms/targets.bzl @@ -204,6 +204,17 @@ def define_common_targets(): ], ) + runtime.python_library( + name = "quantize_fused_convbn_bias_pass", + srcs = ["quantize_fused_convbn_bias_pass.py"], + visibility = ["PUBLIC"], + deps = [ + "//caffe2:torch", + "//executorch/exir:pass_base", + "//executorch/exir/dialects:lib", + ], + ) + runtime.python_test( name = "test_duplicate_dynamic_quant_chain", srcs = [ diff --git a/backends/transforms/test/test_quantize_fused_convbn_bias_pass.py b/backends/transforms/test/test_quantize_fused_convbn_bias_pass.py new file mode 100644 index 00000000000..883b218b8bd --- /dev/null +++ b/backends/transforms/test/test_quantize_fused_convbn_bias_pass.py @@ -0,0 +1,282 @@ +# Copyright 
(c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree.. + +from typing import Tuple + +import pytest +import torch +from executorch.backends.arm.quantizer.arm_quantizer import ( + get_symmetric_quantization_config, + TOSAQuantizer, +) +from executorch.backends.arm.test import common +from executorch.backends.arm.test.tester.test_pipeline import PassPipeline +from executorch.backends.arm.tosa import TosaSpecification +from executorch.backends.nxp.backend.neutron_target_spec import NeutronTargetSpec +from executorch.backends.nxp.quantizer.neutron_quantizer import NeutronQuantizer +from executorch.backends.nxp.quantizer.utils import calibrate_and_quantize +from executorch.backends.transforms.quantize_fused_convbn_bias_pass import ( + QuantizeFusedConvBnBiasAtenPass, + QuantizeFusedConvBnBiasPass, +) +from executorch.backends.xnnpack.test.tester.tester import Quantize +from executorch.exir import to_edge +from executorch.exir.dialects._ops import ops as exir_ops +from torch import nn +from torch.export import export +from torchao.quantization.pt2e import move_exported_model_to_eval +from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_qat_pt2e + + +input_t = Tuple[torch.Tensor] + + +class ConvBnNoBias(nn.Module): + """Conv2d with bias=False followed by BatchNorm. QAT fusion introduces a bias.""" + + def __init__(self, per_channel: bool = True) -> None: + super().__init__() + self.conv = nn.Conv2d(3, 16, kernel_size=3, bias=False) + self.bn = nn.BatchNorm2d(16) + + def get_inputs(self) -> input_t: + return (torch.randn(1, 3, 32, 32),) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.bn(self.conv(x)) + + +class ConvBnReluNoBias(nn.Module): + """Conv2d with bias=False, BatchNorm, and ReLU.""" + + def __init__(self) -> None: + super().__init__() + self.conv = nn.Conv2d(3, 16, kernel_size=3, bias=False) + self.bn = nn.BatchNorm2d(16) + self.relu = nn.ReLU() + + def get_inputs(self) -> input_t: + return (torch.randn(1, 3, 32, 32),) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.relu(self.bn(self.conv(x))) + + +class Conv1dBnNoBias(nn.Module): + """Conv1d with bias=False followed by BatchNorm.""" + + def __init__(self) -> None: + super().__init__() + self.conv = nn.Conv1d(3, 8, kernel_size=3, bias=False) + self.bn = nn.BatchNorm1d(8) + + def get_inputs(self) -> input_t: + return (torch.randn(2, 3, 16),) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.bn(self.conv(x)) + + +# --- Shared helpers --- + + +def _qat_prepare_convert(model, per_channel): + """QAT prepare -> calibrate -> convert_pt2e, returns GraphModule with get_attr nodes.""" + quantizer = TOSAQuantizer(TosaSpecification.create_from_string("TOSA-1.0+INT")) + quantizer.set_global( + get_symmetric_quantization_config(is_qat=True, is_per_channel=per_channel) + ) + example_input = model.get_inputs() + exported = export(model, example_input, strict=True).module() + prepared = prepare_qat_pt2e(exported, quantizer) + prepared(*example_input) + move_exported_model_to_eval(prepared) + converted = convert_pt2e(prepared) + return converted + + +def _assert_bias_dequantized(graph, conv_targets, dequant_targets): + """Assert every conv's bias flows through a dequantize node.""" + conv_count = 0 + for node in graph.nodes: + if node.target not in conv_targets: + continue + conv_count += 1 + bias = node.args[2] + assert bias is 
not None, "Bias should not be None after pass" + assert ( + bias.target in dequant_targets + ), f"Bias should be dequantized, got {bias.target}" + assert conv_count > 0, "Expected at least one convolution node" + + +# --- ARM (TOSA) tests --- + +arm_models = { + "conv2d_bn_no_bias_per_channel": (ConvBnNoBias(), True), + "conv2d_bn_no_bias_per_tensor": (ConvBnNoBias(), False), + "conv2d_bn_relu_no_bias_per_channel": (ConvBnReluNoBias(), True), + "conv2d_bn_relu_no_bias_per_tensor": (ConvBnReluNoBias(), False), + "conv1d_bn_no_bias_per_channel": (Conv1dBnNoBias(), True), + "conv1d_bn_no_bias_per_tensor": (Conv1dBnNoBias(), False), +} + + +@common.parametrize("test_data", arm_models) +def test_quantize_fused_convbn_bias_arm_qat(test_data) -> None: + """ + Test that QuantizeFusedConvBnBiasPass correctly quantizes the bias + introduced by BatchNorm fusion during QAT when the original conv has bias=False. + Uses the ARM TOSA quantizer. + """ + model, per_channel = test_data + pipeline = PassPipeline[input_t]( + model, + model.get_inputs(), + quantize=True, + passes_with_exported_program=[QuantizeFusedConvBnBiasPass], + ) + + quantizer = TOSAQuantizer(TosaSpecification.create_from_string("TOSA-1.0+INT")) + pipeline.change_args( + "quantize", + Quantize( + quantizer=quantizer, + quantization_config=get_symmetric_quantization_config( + is_qat=True, is_per_channel=per_channel + ), + ), + ) + + pipeline.pop_stage("run_method_and_compare_outputs") + pipeline.run() + + +# --- NXP (Neutron) tests --- + + +def _run_nxp_qat_pass(model: nn.Module, use_edge: bool = True) -> None: + """Quantize a model with NXP's NeutronQuantizer in QAT mode, optionally convert to + edge, and verify that the fused bias is quantized via calibrate_and_quantize.""" + example_input = model.get_inputs() + + target_spec = NeutronTargetSpec( + target="imxrt700", neutron_converter_flavor="SDK_25_12" + ) + quantizer = NeutronQuantizer(target_spec, is_qat=True) + + exported = export(model, example_input, strict=True) + quantized_gm = calibrate_and_quantize( + model=exported, + calibration_inputs=[example_input], + quantizer=quantizer, + is_qat=True, + ) + + if use_edge: + re_exported = export(quantized_gm, example_input, strict=True) + edge_program_manager = to_edge(re_exported) + edge_program = edge_program_manager.exported_program() + graph = edge_program.graph_module.graph + conv_targets = (exir_ops.edge.aten.convolution.default,) + dequant_targets = ( + exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default, + exir_ops.edge.quantized_decomposed.dequantize_per_channel.default, + ) + else: + graph = quantized_gm.graph + conv_targets = ( + torch.ops.aten.convolution.default, + torch.ops.aten.conv2d.default, + ) + dequant_targets = ( + torch.ops.quantized_decomposed.dequantize_per_tensor.default, + torch.ops.quantized_decomposed.dequantize_per_channel.default, + ) + + _assert_bias_dequantized(graph, conv_targets, dequant_targets) + + +@pytest.mark.parametrize( + "model", + [ + pytest.param(ConvBnNoBias(), id="conv2d_bn_no_bias"), + pytest.param(ConvBnReluNoBias(), id="conv2d_bn_relu_no_bias"), + ], +) +def test_quantize_fused_convbn_bias_nxp_qat(model: nn.Module) -> None: + """ + Test that QuantizeFusedConvBnBiasPass correctly quantizes the bias + introduced by BatchNorm fusion during QAT when the original conv has bias=False. + Uses the NXP Neutron quantizer with edge dialect. 
+ """ + _run_nxp_qat_pass(model, use_edge=True) + + +@pytest.mark.parametrize( + "model", + [ + pytest.param(ConvBnNoBias(), id="conv2d_bn_no_bias"), + pytest.param(ConvBnReluNoBias(), id="conv2d_bn_relu_no_bias"), + ], +) +def test_quantize_fused_convbn_bias_nxp_qat_aten(model: nn.Module) -> None: + """ + Test that QuantizeFusedConvBnBiasPass correctly quantizes the bias + on aten-dialect graphs (without edge conversion). + Uses the NXP Neutron quantizer. + """ + _run_nxp_qat_pass(model, use_edge=False) + + +# --- Direct aten pass tests (no NXP dependency) --- + +_aten_conv_targets = ( + torch.ops.aten.convolution.default, + torch.ops.aten.conv2d.default, +) +_aten_dequant_targets = ( + torch.ops.quantized_decomposed.dequantize_per_tensor.default, + torch.ops.quantized_decomposed.dequantize_per_channel.default, +) + +_aten_direct_models = { + "conv2d_bn_per_channel": (ConvBnNoBias, True), + "conv2d_bn_per_tensor": (ConvBnNoBias, False), + "conv2d_bn_relu_per_channel": (ConvBnReluNoBias, True), + "conv2d_bn_relu_per_tensor": (ConvBnReluNoBias, False), +} + + +@common.parametrize("test_data", _aten_direct_models) +def test_aten_pass_direct(test_data) -> None: + """QuantizeFusedConvBnBiasAtenPass on GraphModule (get_attr nodes, no EP).""" + model_cls, per_channel = test_data + gm = _qat_prepare_convert(model_cls(), per_channel) + QuantizeFusedConvBnBiasAtenPass()(gm) + _assert_bias_dequantized(gm.graph, _aten_conv_targets, _aten_dequant_targets) + + +@common.parametrize("test_data", _aten_direct_models) +def test_aten_pass_with_exported_program(test_data) -> None: + """QuantizeFusedConvBnBiasAtenPass on graph_module from EP (placeholder nodes).""" + model_cls, per_channel = test_data + model = model_cls() + gm = _qat_prepare_convert(model, per_channel) + ep = export(gm, model.get_inputs(), strict=True) + QuantizeFusedConvBnBiasAtenPass(ep)(ep.graph_module) + _assert_bias_dequantized( + ep.graph_module.graph, _aten_conv_targets, _aten_dequant_targets + ) + + +def test_aten_pass_idempotent() -> None: + """Running the pass twice doesn't break.""" + model = ConvBnNoBias() + gm = _qat_prepare_convert(model, per_channel=True) + QuantizeFusedConvBnBiasAtenPass()(gm) + QuantizeFusedConvBnBiasAtenPass()(gm) + _assert_bias_dequantized(gm.graph, _aten_conv_targets, _aten_dequant_targets)