Fix depthwise conv detection for groups==in_channels==1 #17590
JakeStevens wants to merge 2 commits into pytorch:main
Helpful Links: see artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/17590
Note: Links to docs will display an error until the docs builds have been completed.
❌ 5 New Failures, 1 Unrelated Failure
As of commit 7c3bba6 with merge base d6e8ad1:
NEW FAILURES - The following jobs have failed:
BROKEN TRUNK - The following job failed but was present on the merge base: 👉 Rebase onto the `viable/strict` branch to avoid these failures
This comment was automatically generated by Dr. CI and updates every 15 minutes.
@JakeStevens has exported this pull request. If you are a Meta employee, you can view the originating Diff in D93869048.
Summary:
The generic C++ kernel's depthwise NHWC entry points (`quantized_conv2d_nhwc_depthwise_asym8sxsym8s_asym8s_per_tensor_out` and the uint8 variant) were incorrectly delegating to `quantized_conv2d_nhwc`, which assumes a regular weight layout of [OC, KH, KW, IC]. Depthwise NHWC weights use a fundamentally different layout: [*kernel_size, OC] (i.e., [KH, KW, OC] for 2D, [K, OC] for 1D). This caused incorrect memory access patterns and dimension-out-of-range errors for 1D depthwise convolutions.

This diff adds a dedicated `quantized_conv2d_nhwc_depthwise` function that correctly handles the depthwise weight layout, including conv1d support. It also adds a regression test (`test_quantized_conv1d_depthwise_nhwc_out`) that exercises the full NHWC depthwise pipeline at opt_level=4, which triggers both `ReplaceConvWithChannelLastConvPass` and `CompileTimeTypeDispatchPass`.

Reviewed By: mcremon-meta

Differential Revision: D93620973
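To make the layout difference concrete, here is a minimal, unquantized sketch of a 1D depthwise convolution over channels-last input with weights laid out as [K, OC]. It is not the C++ kernel from this diff: the function name `depthwise_conv1d_channels_last` is hypothetical, and padding, dilation, bias, and quantization parameters are deliberately left out.

```python
def depthwise_conv1d_channels_last(x, w, stride=1):
    """Illustrative float reference, NOT the actual kernel.

    x: input for one batch element, shape [L][C] (channels-last).
    w: depthwise weights, shape [K][C] (the [K, OC] layout, with OC == C).
    Returns the output with shape [L_out][C]; no padding or dilation.
    """
    length, channels = len(x), len(x[0])
    kernel = len(w)
    out_len = (length - kernel) // stride + 1
    out = [[0] * channels for _ in range(out_len)]
    for o in range(out_len):
        for c in range(channels):
            acc = 0
            for t in range(kernel):
                # Each channel only sees its own filter column w[t][c].
                # There is no inner loop over input channels, unlike a
                # regular conv with [OC, KH, KW, IC] weights.
                acc += x[o * stride + t][c] * w[t][c]
            out[o][c] = acc
    return out
```

Reading `w` as if it were [OC, K, IC] would index the wrong elements (or run out of dimensions for 1D), which matches the class of failure described in the summary.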
Summary:
GitHub PR 17528 broke `test_silero_vad_16k_quantized_opt3` by removing the `len(weight_shape) == 4` guard from the depthwise detection logic. This caused regular 1D convolutions with `in_channels == groups == 1` (e.g. Silero VAD's learned STFT conv) to be misclassified as depthwise.

This diff fixes the bug and centralizes the depthwise check into a single `is_depthwise_conv(groups, in_channels)` utility in `utils.py` that enforces `groups > 1 and groups == in_channels`. All 7 depthwise detection sites across 4 files now use this shared function:

- **utils.py**: Added `is_depthwise_conv()`, the single source of truth.
- **ops_registrations.py**: 4 meta functions updated (quantized_conv2d_nhwc, per_tensor, asym8s, asym8u).
- **ref_implementations.py**: Updated `quantized_conv2d_nhwc_per_tensor`.
- **replace_ops.py**: Updated `ReplaceConvWithChannelLastConvPass`.
- **type_dispatch.py**: Updated `CompileTimeTypeDispatchPass` (also fixed a pre-existing bug where `groups > 1` was missing entirely).
- **test_ref_implementations.py**: Fixed the test harness to squeeze the IC=1 dim for depthwise weights when converting to channels_last format, matching the actual `replace_ops.py` pipeline behavior.
- **test_replace_ops_passes.py**: Added regression test for the `in_channels==groups==1` case and a positive test for 1D depthwise.

Reviewed By: mcremon-meta

Differential Revision: D93869048
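The shared predicate described above can be sketched as follows. The name, arguments, and condition come from the summary; the type hints, docstring, and exact body are assumptions and may differ from what is in `utils.py`.

```python
def is_depthwise_conv(groups: int, in_channels: int) -> bool:
    """Sketch of the shared check (body assumed from the PR summary).

    A conv is treated as depthwise only when every input channel has its
    own group AND there is more than one group. The `groups > 1` term is
    what keeps a plain conv with in_channels == groups == 1 from being
    misclassified.
    """
    return groups > 1 and groups == in_channels
```

For example, `is_depthwise_conv(1, 1)` is false (the Silero VAD STFT conv case), while `is_depthwise_conv(8, 8)` is true and `is_depthwise_conv(2, 8)` is false (an ordinary grouped conv).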
Force-pushed from b10d817 to 7c3bba6.
Summary:
GitHub PR 17528 broke `test_silero_vad_16k_quantized_opt3` by removing the `len(weight_shape) == 4` guard from the depthwise detection logic. This caused regular 1D convolutions with `in_channels == groups == 1` (e.g. Silero VAD's learned STFT conv) to be misclassified as depthwise.

This diff fixes the bug and centralizes the depthwise check into a single `is_depthwise_conv(groups, in_channels)` utility in `utils.py` that enforces `groups > 1 and groups == in_channels`. All 7 depthwise detection sites across 4 files now use this shared function:

- **utils.py**: Added `is_depthwise_conv()`, the single source of truth.
- **ops_registrations.py**: 4 meta functions updated (quantized_conv2d_nhwc, per_tensor, asym8s, asym8u).
- **ref_implementations.py**: Updated `quantized_conv2d_nhwc_per_tensor`.
- **replace_ops.py**: Updated `ReplaceConvWithChannelLastConvPass`.
- **type_dispatch.py**: Updated `CompileTimeTypeDispatchPass` (also fixed a pre-existing bug where `groups > 1` was missing entirely).
- **test_ref_implementations.py**: Fixed the test harness to squeeze the IC=1 dim for depthwise weights when converting to channels_last format, matching the actual `replace_ops.py` pipeline behavior.
- **test_replace_ops_passes.py**: Added regression test for the `in_channels==groups==1` case and a positive test for 1D depthwise.

Differential Revision: D93869048
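The squeeze step mentioned for the test harness can be illustrated with plain nested lists. This is only a sketch of the shape transformation for the 1D case, assuming a PyTorch-style depthwise weight of shape [OC, 1, K]; the helper name `depthwise_weight_to_channels_last_1d` is hypothetical, and the real harness and `replace_ops.py` operate on tensors rather than lists.

```python
def depthwise_weight_to_channels_last_1d(weight):
    """Convert a depthwise weight from [OC][1][K] to [K][OC].

    Illustrative only: drops the singleton IC dimension, then moves the
    kernel axis first so the result matches the [*kernel_size, OC] layout.
    """
    for per_channel in weight:
        if len(per_channel) != 1:
            raise ValueError("expected IC == 1 for a depthwise weight")
    squeezed = [per_channel[0] for per_channel in weight]  # [OC][K]
    out_channels = len(squeezed)
    kernel = len(squeezed[0])
    # Transpose [OC][K] -> [K][OC].
    return [[squeezed[oc][t] for oc in range(out_channels)]
            for t in range(kernel)]
```

Without the squeeze, the converted weight would keep a trailing IC=1 axis and no longer match the [K, OC] shape the depthwise path expects.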