From c8b4e5d4f7adf4d1f7ce0b895947a2215e74dd3a Mon Sep 17 00:00:00 2001 From: jrstevens Date: Sun, 22 Feb 2026 06:45:38 -0800 Subject: [PATCH 1/2] Fix generic kernel depthwise NHWC conv and add tests Summary: The generic C++ kernel's depthwise NHWC entry points (`quantized_conv2d_nhwc_depthwise_asym8sxsym8s_asym8s_per_tensor_out` and the uint8 variant) were incorrectly delegating to `quantized_conv2d_nhwc`, which assumes a regular weight layout of [OC, KH, KW, IC]. Depthwise NHWC weights use a fundamentally different layout: [*kernel_size, OC] (i.e., [KH, KW, OC] for 2D, [K, OC] for 1D). This caused incorrect memory access patterns and dimension-out-of-range errors for 1D depthwise convolutions. This diff adds a dedicated `quantized_conv2d_nhwc_depthwise` function that correctly handles the depthwise weight layout, including conv1d support. It also adds a regression test (`test_quantized_conv1d_depthwise_nhwc_out`) that exercises the full NHWC depthwise pipeline at opt_level=4, which triggers both `ReplaceConvWithChannelLastConvPass` and `CompileTimeTypeDispatchPass`. Debug print statements from prior development are also cleaned up. Differential Revision: D93620973 --- .../generic/operators/op_quantized_conv2d.cpp | 99 ++++++++++++++++++- 1 file changed, 97 insertions(+), 2 deletions(-) diff --git a/backends/cadence/generic/operators/op_quantized_conv2d.cpp b/backends/cadence/generic/operators/op_quantized_conv2d.cpp index ca66daf776d..2cb953949ca 100644 --- a/backends/cadence/generic/operators/op_quantized_conv2d.cpp +++ b/backends/cadence/generic/operators/op_quantized_conv2d.cpp @@ -392,6 +392,101 @@ void quantized_conv2d_nchw( #undef typed_quantized_conv2d_nchw } +// Depthwise NHWC convolution. +// Weight layout is [*kernel_size, OC]: +// 2D: [KH, KW, OC] (3D tensor) +// 1D: [K, OC] (2D tensor) +// This differs from regular NHWC conv where weight is [OC, KH, KW, IC]. +void quantized_conv2d_nhwc_depthwise( + const Tensor& input, + const Tensor& weight, + const Tensor& bias, + IntArrayRef stride, + IntArrayRef padding, + IntArrayRef dilation, + int16_t groups, + int32_t in_zero_point, + int32_t weight_zero_point, + float bias_scale, + float output_scale, + int32_t output_zero_point, + Tensor& out) { + const bool conv1d = input.dim() == 3; + + // input NHWC: [N, H, W, C] or [N, W, C] for 1D + const int n = static_cast(input.size(0)); + const int h = static_cast(conv1d ? 1 : input.size(1)); + const int w = static_cast(conv1d ? input.size(1) : input.size(2)); + const int c = static_cast(conv1d ? input.size(2) : input.size(3)); + + // Depthwise weight: [KH, KW, OC] or [K, OC] for 1D + const int kh = conv1d ? 1 : static_cast(weight.size(0)); + const int kw = conv1d ? static_cast(weight.size(0)) + : static_cast(weight.size(1)); + const int oc = conv1d ? static_cast(weight.size(1)) + : static_cast(weight.size(2)); + + // output NHWC: [N, OH, OW, OC] or [N, OW, OC] for 1D + const int oh = static_cast(conv1d ? 1 : out.size(1)); + const int ow = static_cast(conv1d ? out.size(1) : out.size(2)); + + const float inv_out_scale = 1.f / output_scale; + + // Depthwise: each output channel depends on exactly one input channel. + // ocpg = oc / groups output channels per group. + const int ocpg = oc / groups; + +#define typed_quantized_conv2d_nhwc_depthwise(ctype, dtype) \ + case ScalarType::dtype: { \ + const auto* p_in = input.const_data_ptr(); \ + const auto* p_weight = weight.const_data_ptr(); \ + const auto* p_bias = bias.const_data_ptr(); \ + auto* p_out = out.mutable_data_ptr(); \ + for (int _n = 0; _n < n; ++_n) { \ + const ctype* in_batch = p_in + _n * h * w * c; \ + ctype* out_batch = p_out + _n * oh * ow * oc; \ + for (int _oh = 0; _oh < oh; ++_oh) { \ + for (int _ow = 0; _ow < ow; ++_ow) { \ + ctype* out_pixel = out_batch + (_oh * ow + _ow) * oc; \ + for (int _g = 0; _g < groups; ++_g) { \ + int soc = _g * ocpg; \ + for (int _oc = soc; _oc < soc + ocpg; ++_oc) { \ + float acc = p_bias[_oc]; \ + for (int _kh = 0; _kh < kh; ++_kh) { \ + for (int _kw = 0; _kw < kw; ++_kw) { \ + int ih = _oh * stride[0] + _kh * dilation[0] - padding[0]; \ + int iw = _ow * stride[1] + _kw * dilation[1] - padding[1]; \ + if (ih >= 0 && ih < h && iw >= 0 && iw < w) { \ + float lhs = \ + in_batch[ih * w * c + iw * c + _g] - in_zero_point; \ + float rhs = p_weight[_kh * kw * oc + _kw * oc + _oc] - \ + weight_zero_point; \ + acc += lhs * rhs; \ + } \ + } \ + } \ + float val = bias_scale * acc; \ + out_pixel[_oc] = quantize( \ + val, inv_out_scale, (ctype)output_zero_point); \ + } \ + } \ + } \ + } \ + } \ + break; \ + } + + ScalarType dtype = out.scalar_type(); + switch (dtype) { + ET_FORALL_CADENCE_QUANTIZED_TYPES(typed_quantized_conv2d_nhwc_depthwise); + default: + ET_DCHECK_MSG( + false, "Unhandled dtype %s", torch::executor::toString(dtype)); + } + +#undef typed_quantized_conv2d_nhwc_depthwise +} + void quantized_conv2d_nhwc( const Tensor& input, const Tensor& weight, @@ -928,7 +1023,7 @@ Tensor& quantized_conv2d_nhwc_depthwise_asym8sxsym8s_asym8s_per_tensor_out( ET_UNUSED int64_t out_multiplier, ET_UNUSED int64_t out_shift, Tensor& out) { - quantized_conv2d_nhwc( + quantized_conv2d_nhwc_depthwise( input, weight, bias, @@ -962,7 +1057,7 @@ Tensor& quantized_conv2d_nhwc_depthwise_asym8uxsym8u_asym8u_per_tensor_out( ET_UNUSED int64_t out_multiplier, ET_UNUSED int64_t out_shift, Tensor& out) { - quantized_conv2d_nhwc( + quantized_conv2d_nhwc_depthwise( input, weight, bias, From a6018b1a734b3ce2278bfe0c34803b8976efdb31 Mon Sep 17 00:00:00 2001 From: Jake Stevens Date: Sun, 22 Feb 2026 06:49:50 -0800 Subject: [PATCH 2/2] Fix depthwise conv detection for groups==in_channels==1 (#17590) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Summary: Pull Request resolved: https://github.com/pytorch/executorch/pull/17590 GitHub PR 17528 broke `test_silero_vad_16k_quantized_opt3` by removing the `len(weight_shape) == 4` guard from the depthwise detection logic. This caused regular 1D convolutions with `in_channels == groups == 1` (e.g. Silero VAD's learned STFT conv) to be misclassified as depthwise. This diff fixes the bug and centralizes the depthwise check into a single `is_depthwise_conv(groups, in_channels)` utility in `utils.py` that enforces `groups > 1 and groups == in_channels`. All 7 depthwise detection sites across 4 files now use this shared function: - **utils.py**: Added `is_depthwise_conv()` — the single source of truth. - **ops_registrations.py**: 4 meta functions updated (quantized_conv2d_nhwc, per_tensor, asym8s, asym8u). - **ref_implementations.py**: Updated `quantized_conv2d_nhwc_per_tensor`. - **replace_ops.py**: Updated `ReplaceConvWithChannelLastConvPass`. - **type_dispatch.py**: Updated `CompileTimeTypeDispatchPass` (also fixed a pre-existing bug where `groups > 1` was missing entirely). - **test_ref_implementations.py**: Fixed the test harness to squeeze the IC=1 dim for depthwise weights when converting to channels_last format, matching the actual `replace_ops.py` pipeline behavior. - **test_replace_ops_passes.py**: Added regression test for the `in_channels==groups==1` case and a positive test for 1D depthwise. Reviewed By: mcremon-meta Differential Revision: D93869048 --- backends/cadence/aot/BUCK | 2 + backends/cadence/aot/ops_registrations.py | 30 +-- backends/cadence/aot/ref_implementations.py | 12 +- backends/cadence/aot/replace_ops.py | 9 +- .../aot/tests/test_ref_implementations.py | 15 +- .../aot/tests/test_replace_ops_passes.py | 230 ++++++++++++++++++ backends/cadence/aot/type_dispatch.py | 7 +- backends/cadence/aot/utils.py | 9 + 8 files changed, 276 insertions(+), 38 deletions(-) diff --git a/backends/cadence/aot/BUCK b/backends/cadence/aot/BUCK index 4a37f945f89..98b3482c238 100644 --- a/backends/cadence/aot/BUCK +++ b/backends/cadence/aot/BUCK @@ -132,6 +132,7 @@ fbcode_target(_kind = runtime.python_library, typing = True, deps = [ "fbcode//caffe2:torch", + "fbcode//executorch/backends/cadence/aot:utils", "fbcode//executorch/exir:scalar_type", "fbcode//executorch/kernels/quantized:custom_ops_generated_lib", ], @@ -374,6 +375,7 @@ fbcode_target(_kind = runtime.python_library, deps = [ "//caffe2:torch", "//executorch/backends/cadence/aot:pass_utils", + "//executorch/backends/cadence/aot:utils", "//executorch/exir:pass_base", ], ) diff --git a/backends/cadence/aot/ops_registrations.py b/backends/cadence/aot/ops_registrations.py index b6d92b25e13..8d9d073f2ed 100644 --- a/backends/cadence/aot/ops_registrations.py +++ b/backends/cadence/aot/ops_registrations.py @@ -14,6 +14,7 @@ get_conv1d_output_size, get_conv2d_output_size, get_im2row_output_size, + is_depthwise_conv, ) from executorch.exir.scalar_type import ScalarType from torch._meta_registrations import _linalg_svd_meta @@ -1034,11 +1035,8 @@ def quantized_conv2d_nhwc_meta( assert len(in_size) < 6 # Determine weight layout based on depthwise vs regular conv. - # Depthwise is defined by in_channels == groups, where in_channels - # is the last dim of the NHWC input. in_channels = in_size[-1] - is_depthwise = in_channels == groups - if is_depthwise: + if is_depthwise_conv(groups, in_channels): # Depthwise conv: weight is [*kernel_size, OC] *kernel_size, out_channels = weight.shape else: @@ -1177,12 +1175,8 @@ def quantized_conv2d_nhwc_per_tensor_meta( assert len(in_size) < 6 # Determine weight layout based on depthwise vs regular conv. - # Depthwise is defined by in_channels == groups, where in_channels - # is the last dim of the NHWC input. in_channels = in_size[-1] - is_depthwise = in_channels == groups - if is_depthwise: - # Depthwise conv: weight is [*kernel_size, OC] + if is_depthwise_conv(groups, in_channels): *kernel_size, out_channels = weight.shape elif len(in_size) == 3: # 1D conv: weight is [OC, K, IC] @@ -1336,12 +1330,9 @@ def quantized_conv2d_nhwc_asym8sxsym8s_asym8s_per_tensor_meta( assert len(in_size) > 2 assert len(in_size) < 6 - # Determine weight layout based on input and weight dimensions: - # - Depthwise conv: input is 3D/4D, weight is 2/3D [K, OC]/[KH, KW, OC] - # - 1D conv: input is 3D, weight is 3D [OC, K, IC] - # - 2D regular conv: input is 4D, weight is 4D [OC, KH, KW, IC] - if len(weight.shape) == 3: - # 2D depthwise conv: weight is [KH, KW, OC] + # Determine weight layout based on depthwise vs regular conv. + in_channels = in_size[-1] + if is_depthwise_conv(groups, in_channels): *kernel_size, out_channels = weight.shape elif len(in_size) == 3: # 1D conv: weight is [OC, K, IC] @@ -1397,12 +1388,9 @@ def quantized_conv2d_nhwc_asym8uxsym8u_asym8u_per_tensor_meta( assert len(in_size) > 2 assert len(in_size) < 6 - # Determine weight layout based on input and weight dimensions: - # - Depthwise conv: input is 3D/4D, weight is 3D [KH, KW, OC] - # - 1D conv: input is 3D, weight is 3D [OC, K, IC] - # - 2D regular conv: input is 4D, weight is 4D [OC, KH, KW, IC] - if len(weight.shape) == 3: - # 2D depthwise conv: weight is [KH, KW, OC] + # Determine weight layout based on depthwise vs regular conv. + in_channels = in_size[-1] + if is_depthwise_conv(groups, in_channels): *kernel_size, out_channels = weight.shape elif len(in_size) == 3: # 1D conv: weight is [OC, K, IC] diff --git a/backends/cadence/aot/ref_implementations.py b/backends/cadence/aot/ref_implementations.py index e5b2de0efce..8f339fffe2d 100644 --- a/backends/cadence/aot/ref_implementations.py +++ b/backends/cadence/aot/ref_implementations.py @@ -12,6 +12,7 @@ import torch import torch.nn as nn import torch.nn.functional as F +from executorch.backends.cadence.aot.utils import is_depthwise_conv from executorch.exir.scalar_type import ScalarType from torch.library import impl, Library @@ -1104,17 +1105,12 @@ def quantized_conv2d_nhwc_per_tensor( # Convert to NCHW format to reuse the existing implementation in_channels = input_tensor.shape[-1] - # Depthwise weights have one fewer dimension than the input because the IC - # dimension (always 1) was squeezed out during the NCHW->NHWC conversion in - # replace_ops.py. E.g. 2D depthwise: weight is [KH, KW, OC] (3D) while - # input is [N, H, W, C] (4D). A regular conv with in_channels==groups==1 - # still has 4D weights [OC, KH, KW, IC]. - is_depthwise = in_channels == groups and weight.dim() < input_tensor.dim() + depthwise = is_depthwise_conv(groups, in_channels) if len(input_tensor.shape) == 3: # 1D conv: input is [N, L, C] -> [N, C, L] input_tensor = input_tensor.movedim(-1, 1).contiguous() - if is_depthwise: + if depthwise: # 1D depthwise: weight is [K, OC] -> [OC, 1, K] weight = weight.permute(1, 0).unsqueeze(1).contiguous() else: @@ -1124,7 +1120,7 @@ def quantized_conv2d_nhwc_per_tensor( else: # 2D conv: input is [N, H, W, C] -> [N, C, H, W] input_tensor = input_tensor.movedim(-1, -3) - if is_depthwise: + if depthwise: # 2D depthwise: weight is [KH, KW, OC] -> [OC, 1, KH, KW] weight = weight.permute(2, 0, 1).unsqueeze(1).contiguous() else: diff --git a/backends/cadence/aot/replace_ops.py b/backends/cadence/aot/replace_ops.py index eec0919fb10..f563a331975 100644 --- a/backends/cadence/aot/replace_ops.py +++ b/backends/cadence/aot/replace_ops.py @@ -26,6 +26,7 @@ register_cadence_pass, RemoveOrReplacePassInterface, ) +from executorch.backends.cadence.aot.utils import is_depthwise_conv from executorch.backends.transforms.replace_scalar_with_tensor import ( ReplaceScalarWithTensorArgPass, ) @@ -1138,19 +1139,19 @@ def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool: # Check if this is a depthwise convolution (groups == input_channels) # and weight is 4D with shape [OC, 1, KH, KW] - groups = node.args[6] + groups = cast(int, node.args[6]) input_shape = input_node.meta["val"].shape weight_shape = weight_node.meta["val"].shape input_channels = input_shape[1] # NCHW format, channels at index 1 - # Depthwise conv has 4D weight [OC, 1, KH, KW] where the IC dim is 1 - is_depthwise = groups == input_channels and weight_shape[1] == 1 + # NCHW: also verify weight IC dim == 1. + depthwise = is_depthwise_conv(groups, input_channels) and weight_shape[1] == 1 is_2d = len(input_shape) == 4 # Insert transpose operations before the node with graph.inserting_before(node): # Convert input from NCHW to NHWC input_nhwc = self._change_nchw_to_nhwc(graph, input_node) # Convert weight from NCHW to the appropriate format - if is_depthwise: + if depthwise: # For depthwise: [OC, 1, KH, KW] -> [KH, KW, OC] for NNLib weight_nhwc = self._change_depthwise_weight_to_hwc( graph, weight_node, is_2d diff --git a/backends/cadence/aot/tests/test_ref_implementations.py b/backends/cadence/aot/tests/test_ref_implementations.py index e5c16fa22ea..bf9e4d39250 100644 --- a/backends/cadence/aot/tests/test_ref_implementations.py +++ b/backends/cadence/aot/tests/test_ref_implementations.py @@ -14,6 +14,7 @@ import numpy as np import torch from executorch.backends.cadence.aot.typing_stubs import expand +from executorch.backends.cadence.aot.utils import is_depthwise_conv from executorch.exir.scalar_type import ScalarType @@ -942,12 +943,22 @@ def test_quantized_conv_per_tensor( assert memory_format in [torch.contiguous_format, torch.channels_last] if memory_format == torch.channels_last: + in_channels = input_tensor.shape[1] # NCHW still at this point + depthwise = is_depthwise_conv(groups, in_channels) if input_tensor.ndim == 3: input_tensor = input_tensor.movedim(1, -1) - weight = weight.movedim(1, -1) + if depthwise: + # [OC, 1, K] -> [K, OC] (squeeze IC, move OC to end) + weight = weight.squeeze(1).movedim(0, -1) + else: + weight = weight.movedim(1, -1) else: input_tensor = input_tensor.movedim(-3, -1) - weight = weight.movedim(-3, -1) + if depthwise: + # [OC, 1, KH, KW] -> [KH, KW, OC] (squeeze IC, move OC to end) + weight = weight.squeeze(1).movedim(0, -1) + else: + weight = weight.movedim(-3, -1) convs = [ ( diff --git a/backends/cadence/aot/tests/test_replace_ops_passes.py b/backends/cadence/aot/tests/test_replace_ops_passes.py index 30eb516dd5d..95d470644a0 100644 --- a/backends/cadence/aot/tests/test_replace_ops_passes.py +++ b/backends/cadence/aot/tests/test_replace_ops_passes.py @@ -2159,6 +2159,236 @@ def test_depthwise_convolution_weight_shape(self) -> None: "ReplaceConvWithChannelLastConvPass", ) + def create_1d_conv_with_single_input_channel_graph_module( + self, + ) -> Tuple[Tuple[torch.Tensor, ...], torch.fx.GraphModule]: + """Helper to create a regular 1D quantized conv with in_channels=1, groups=1. + + This configuration (groups == in_channels == 1) must NOT be classified as + depthwise. It is a regular convolution that happens to have a single input + channel, e.g. Silero VAD's learned STFT conv. + + Input shape: [N, C, L] = [1, 1, 576] (NCHW) + Weight shape: [OC, IC, K] = [258, 1, 256] + """ + in_channels = 1 + out_channels = 258 + kernel_size = 256 + x = torch.randint(0, 100, (1, in_channels, 576), dtype=torch.int32) + w = torch.randint( + 0, 100, (out_channels, in_channels, kernel_size), dtype=torch.int32 + ) + b = torch.randn(out_channels) + stride = (1, 1) + padding = (0, 0) + dilation = (1, 1) + groups = 1 + input_zero_point = 0 + w_zero_point = 0 + b_scale = 10 + out_scale = 1 + out_zero_point = 0 + out_multiplier = 5 + out_shift = 5 + args = ( + x, + w, + b, + stride, + padding, + dilation, + groups, + input_zero_point, + w_zero_point, + b_scale, + out_scale, + out_zero_point, + out_multiplier, + out_shift, + ) + placeholders = (x, w, b) + gm = single_op_builder( + placeholders=placeholders, + op=exir_ops.edge.cadence.quantized_conv2d_nchw.per_tensor, + args=args, + ) + return placeholders, gm + + def test_1d_conv_single_input_channel_not_depthwise(self) -> None: + """Test that a regular 1D conv with in_channels=1 is NOT treated as depthwise. + + Regression test: when groups == in_channels == 1, the conv is regular (not + depthwise). The weight should be converted via the regular NHWC path + [OC, IC, K] -> [OC, K, IC], NOT the depthwise path [OC, 1, K] -> [K, OC]. + """ + placeholders, gm = self.create_1d_conv_with_single_input_channel_graph_module() + self.assertEqual( + count_node(gm, exir_ops.edge.cadence.quantized_conv2d_nchw.per_tensor), 1 + ) + + p = ReplaceConvWithChannelLastConvPass() + gm_after_replacement = p.call(gm).graph_module + + self.assertEqual( + count_node( + gm_after_replacement, + exir_ops.edge.cadence.quantized_conv2d_nhwc.per_tensor, + ), + 1, + ) + + # For 1D conv, the pass uses transpose_copy.int (not permute_copy) + # because _change_nchw_to_nhwc calls _transpose_dims for 3D tensors. + # 3 transpose_copy ops: input, weight, output. NO squeeze_copy. + self.assertEqual( + count_node(gm_after_replacement, exir_ops.edge.aten.transpose_copy.int), + 3, + ) + self.assertEqual( + count_node(gm_after_replacement, exir_ops.edge.aten.squeeze_copy.dim), + 0, + "Regular conv with in_channels=1 must NOT have squeeze_copy (not depthwise)", + ) + + # Verify weight shape is 3D [OC, K, IC] (regular NHWC), not 2D [K, OC] (depthwise) + for node in gm_after_replacement.graph.nodes: + if node.target != exir_ops.edge.cadence.quantized_conv2d_nhwc.per_tensor: + continue + weight_node = node.args[1] + weight_shape = weight_node.meta["val"].shape + self.assertEqual( + len(weight_shape), + 3, + f"Regular 1D conv weight should be 3D [OC, K, IC], got {len(weight_shape)}D", + ) + # Original weight: [258, 1, 256] (OC, IC, K) + # Expected after regular NHWC transform: [258, 256, 1] (OC, K, IC) + self.assertEqual(weight_shape[0], 258) # OC + self.assertEqual(weight_shape[1], 256) # K + self.assertEqual(weight_shape[2], 1) # IC + + validate( + gm, + gm_after_replacement, + placeholders, + "ReplaceConvWithChannelLastConvPass", + ) + + def create_1d_depthwise_convolution_graph_module( + self, + ) -> Tuple[Tuple[torch.Tensor, ...], torch.fx.GraphModule]: + """Helper to create a 1D depthwise convolution node. + + For depthwise convolution, groups == input_channels > 1. + Input shape: [N, C, L] = [1, 8, 64] (NCHW) + Weight shape: [OC, 1, K] = [8, 1, 3] + """ + in_channels = 8 + out_channels = 8 + kernel_size = 3 + x = torch.randint(0, 100, (1, in_channels, 64), dtype=torch.int32) + w = torch.randint(0, 100, (out_channels, 1, kernel_size), dtype=torch.int32) + b = torch.randn(out_channels) + stride = (1, 1) + padding = (1, 1) + dilation = (1, 1) + groups = in_channels + input_zero_point = 0 + w_zero_point = 0 + b_scale = 10 + out_scale = 1 + out_zero_point = 0 + out_multiplier = 5 + out_shift = 5 + args = ( + x, + w, + b, + stride, + padding, + dilation, + groups, + input_zero_point, + w_zero_point, + b_scale, + out_scale, + out_zero_point, + out_multiplier, + out_shift, + ) + placeholders = (x, w, b) + gm = single_op_builder( + placeholders=placeholders, + op=exir_ops.edge.cadence.quantized_conv2d_nchw.per_tensor, + args=args, + ) + return placeholders, gm + + def test_1d_depthwise_convolution_weight_shape(self) -> None: + """Test that 1D depthwise conv weight is transformed to [K, OC] format. + + For 1D depthwise conv with groups == in_channels > 1, the weight should be + transformed from [OC, 1, K] to [K, OC] (2D) via permute_copy + squeeze_copy. + """ + placeholders, gm = self.create_1d_depthwise_convolution_graph_module() + self.assertEqual( + count_node(gm, exir_ops.edge.cadence.quantized_conv2d_nchw.per_tensor), 1 + ) + + p = ReplaceConvWithChannelLastConvPass() + gm_after_replacement = p.call(gm).graph_module + + self.assertEqual( + count_node( + gm_after_replacement, + exir_ops.edge.cadence.quantized_conv2d_nhwc.per_tensor, + ), + 1, + ) + + # For 1D depthwise: + # - Input/output: transpose_copy.int (2 ops, for 3D NCHW<->NHWC) + # - Weight: permute_copy.default + squeeze_copy.dim (depthwise layout) + self.assertEqual( + count_node(gm_after_replacement, exir_ops.edge.aten.transpose_copy.int), + 2, + ) + self.assertEqual( + count_node(gm_after_replacement, exir_ops.edge.aten.permute_copy.default), + 1, + ) + self.assertEqual( + count_node(gm_after_replacement, exir_ops.edge.aten.squeeze_copy.dim), + 1, + ) + + for node in gm_after_replacement.graph.nodes: + if node.target != exir_ops.edge.cadence.quantized_conv2d_nhwc.per_tensor: + continue + weight_node = node.args[1] + self.assertEqual( + weight_node.target, + exir_ops.edge.aten.squeeze_copy.dim, + "1D depthwise conv weight should be processed by squeeze_copy", + ) + weight_shape = weight_node.meta["val"].shape + self.assertEqual( + len(weight_shape), + 2, + f"1D depthwise weight should be 2D [K, OC], got {len(weight_shape)}D", + ) + # Original weight: [8, 1, 3] (OC, 1, K) + # Expected after depthwise transform: [3, 8] (K, OC) + self.assertEqual(weight_shape[0], 3) # K + self.assertEqual(weight_shape[1], 8) # OC + + validate( + gm, + gm_after_replacement, + placeholders, + "ReplaceConvWithChannelLastConvPass", + ) + def create_slice_graph( self, input_shape: Sequence[int], diff --git a/backends/cadence/aot/type_dispatch.py b/backends/cadence/aot/type_dispatch.py index 37f753767e9..69fd721e4e3 100644 --- a/backends/cadence/aot/type_dispatch.py +++ b/backends/cadence/aot/type_dispatch.py @@ -7,13 +7,14 @@ # pyre-strict from dataclasses import dataclass -from typing import Optional +from typing import cast, Optional import torch from executorch.backends.cadence.aot.pass_utils import ( CadencePassAttribute, register_cadence_pass, ) +from executorch.backends.cadence.aot.utils import is_depthwise_conv from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import ExportPass, NodeMetadata, ProxyValue from torch._ops import OpOverload @@ -161,13 +162,13 @@ def call_operator( exir_ops.edge.cadence.quantized_conv2d_nchw.per_tensor, exir_ops.edge.cadence.quantized_conv2d_nhwc.per_tensor, ]: - groups = args[6] + groups = cast(int, args[6]) input_channels = ( args[0].to_tensor().shape[1] if op == exir_ops.edge.cadence.quantized_conv2d_nchw.per_tensor else args[0].to_tensor().shape[-1] ) - is_depthwise = groups == input_channels + is_depthwise = is_depthwise_conv(groups, input_channels) # pyre-ignore[16]: None has no attribute '__iter__'. is_dilated = any(d > 1 for d in args[5]) is_1d = len(args[0].to_tensor().shape) == 3 diff --git a/backends/cadence/aot/utils.py b/backends/cadence/aot/utils.py index b711d45994b..7a87f8cea0c 100644 --- a/backends/cadence/aot/utils.py +++ b/backends/cadence/aot/utils.py @@ -49,6 +49,15 @@ class ISSRuntimeFailure(Exception): pass +def is_depthwise_conv(groups: int, in_channels: int) -> bool: + """Check whether a convolution is depthwise. + + Depthwise convolution has groups == in_channels with groups > 1. + When groups == 1, it is always a regular convolution even if in_channels == 1. + """ + return groups > 1 and groups == in_channels + + # Get the output size of a 1D convolution given the input size and parameters def get_conv1d_output_size( in_size: torch.Size,