29 changes: 18 additions & 11 deletions backends/nxp/quantizer/utils.py
@@ -13,6 +13,9 @@
from typing import Any, Dict, List, Tuple, Type

import torch
from executorch.backends.transforms.quantize_fused_convbn_bias_pass import (
QuantizeFusedConvBnBiasPass,
)
from torch import fx
from torch._ops import OpOverload
from torch.export import ExportedProgram
@@ -162,15 +165,15 @@ def find_sequential_partitions_aten(


def calibrate_and_quantize(
model: ExportedProgram | fx.GraphModule,
model: ExportedProgram,
calibration_inputs: Iterable[tuple[torch.Tensor, ...]],
quantizer: Quantizer,
is_qat: bool = False,
) -> fx.GraphModule:
"""Quantize the provided model.

:param model: Aten model (or it's GraphModule representation) to quantize.
:param calibration_inputs: Either a tuple of calibration input tensors where each element corresponds to a model
:param model: Aten exported model to quantize.
:param calibration_inputs: Either a tuple of calibration input tensors where each element corresponds to a module
input. Or an iterator over such tuples.
:param quantizer: Quantizer to use.
:param is_qat: Whether quantization is done using Quantization Aware Training (QAT) or not.
@@ -179,17 +182,21 @@ def calibrate_and_quantize(
:return: Quantized GraphModule.
"""

if isinstance(model, ExportedProgram):
model = model.module()
module = model.module()

if is_qat:
m = prepare_qat_pt2e(model, quantizer)
m = move_exported_model_to_eval(m)
module = prepare_qat_pt2e(module, quantizer)
module = move_exported_model_to_eval(module)
else:
m = prepare_pt2e(model, quantizer)
module = prepare_pt2e(module, quantizer)

for data in calibration_inputs:
m(*data)
m = convert_pt2e(m)
module(*data)
module = convert_pt2e(module)

# Without this export, conv bias is not in the graph_signature.
model = torch.export.export(module, calibration_inputs[0], strict=True)
Contributor: So you want to try to drop needing this?

bias_quant_pass = QuantizeFusedConvBnBiasPass(model)
model = bias_quant_pass(model.graph_module)

return m
return model.graph_module
Contributor: Is the change to use graph_module the desired outcome, or is it currently just needed for the pass as-is?
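
For orientation, a minimal usage sketch of calibrate_and_quantize after this change. The toy model and the elided quantizer construction are illustrative assumptions, not part of this PR; in the NXP flow the quantizer would be the Neutron quantizer instance.

import torch

from executorch.backends.nxp.quantizer.utils import calibrate_and_quantize


class TinyConvBN(torch.nn.Module):
    # Toy model used only for this sketch.
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 8, kernel_size=3, bias=False)
        self.bn = torch.nn.BatchNorm2d(8)

    def forward(self, x):
        return self.bn(self.conv(x))


example_inputs = (torch.randn(1, 3, 32, 32),)

# The function now takes an ExportedProgram (the GraphModule overload was removed).
exported = torch.export.export(TinyConvBN().eval(), example_inputs, strict=True)

# The first calibration tuple is reused for the internal re-export
# (calibration_inputs[0]), so pass an indexable container of input tuples.
calibration_inputs = [example_inputs]

quantizer = ...  # construction omitted: any pt2e Quantizer, e.g. the NXP Neutron quantizer

quantized_gm = calibrate_and_quantize(
    exported,
    calibration_inputs,
    quantizer,
    is_qat=False,
)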

54 changes: 53 additions & 1 deletion backends/nxp/tests/test_batch_norm_fusion.py
@@ -11,6 +11,9 @@
from executorch.backends.nxp.aten_passes.neutron_aten_pass_manager import (
NeutronAtenPassManager,
)
from executorch.backends.nxp.backend.edge_program_converter import (
EdgeProgramToIRConverter,
)
from executorch.backends.nxp.backend.ir.converter.node_converters.ops_converters import (
AddMMConverter,
MMConverter,
@@ -22,8 +25,15 @@
neutron_target_spec,
to_quantized_edge_program,
)
from executorch.backends.nxp.tests.executors import OverrideTargetSupportCheck
from executorch.backends.nxp.tests.executors import (
convert_run_compare,
OverrideTargetSupportCheck,
ToChannelFirstPreprocess,
ToChannelLastPreprocess,
)
from executorch.backends.nxp.tests.models import ConvBNModule
from torch import nn
from torch.export import ExportedProgram


@pytest.fixture(autouse=True)
@@ -231,3 +241,45 @@ def unsupported_target(*_): # Accept all input arguments and return `False`.
node.op == "call_function" and "batch_norm" in node.target.__name__
for node in nodes
)


@pytest.mark.parametrize(
"conv_module",
["conv2d", "conv2d_t"],
)
def test_biasless_convbn_fusion_qat(
mocker,
conv_module,
):
if conv_module.startswith("conv1d"):
input_shape = (1, 3, 32)
elif conv_module.startswith("conv2d"):
input_shape = (1, 3, 32, 32)
else: # conv3d
input_shape = (1, 3, 32, 32, 32)

model = ConvBNModule(conv_module, conv_bias=False, bn_affine=True)

converter_spy = mocker.spy(EdgeProgramToIRConverter, "convert_program")
edge_program = to_quantized_edge_program(
model, input_shape, use_qat=True, use_neutron_for_format_conversion=False
).exported_program()

assert any("lowered_module" in node.name for node in edge_program.graph.nodes)

# Capture generated model
tflite_flatbuffers_model, io_formats = converter_spy.spy_return

# Capture converted program
exported_program: ExportedProgram = converter_spy.call_args.args[1]

input_data = (np.random.random(input_shape).astype(np.float32) * 50).astype(np.int8)

convert_run_compare(
exported_program,
tflite_input_preprocess=ToChannelLastPreprocess(),
tfl_model=tflite_flatbuffers_model,
tflite_output_preprocess=ToChannelFirstPreprocess(),
input_data=input_data,
atol=1.0,
)
23 changes: 23 additions & 0 deletions backends/transforms/BUCK
@@ -1,6 +1,29 @@
load("@fbcode_macros//build_defs:build_file_migration.bzl", "fbcode_target", "non_fbcode_target")
load("@fbcode_macros//build_defs:python_pytest.bzl", "python_pytest")
load(":targets.bzl", "define_common_targets")

oncall("executorch")

fbcode_target(_kind = define_common_targets,)

fbcode_target(_kind = python_pytest,
name = "test_quantize_fused_convbn_bias_pass",
srcs = [
"test/test_quantize_fused_convbn_bias_pass.py",
],
deps = [
"//caffe2:torch",
":quantize_fused_convbn_bias_pass",
"//executorch/backends/arm/quantizer:arm_quantizer",
"//executorch/backends/arm/test:arm_tester_lib",
"//executorch/backends/arm/test:arm_tester_serialize",
"//executorch/backends/arm/test:common",
"//executorch/backends/arm/tosa:tosa",
"//executorch/backends/nxp:quantizer",
"//executorch/backends/nxp:neutron_backend",
"//executorch/backends/xnnpack/test/tester:tester",
"//executorch/exir:lib",
"//executorch/kernels/quantized:custom_ops_generated_lib",
"fbsource//third-party/pypi/pytest:pytest",
],
)