From 8d3b99d479f629a0cee782d9be0bd497427b93b4 Mon Sep 17 00:00:00 2001 From: ssjia Date: Thu, 19 Feb 2026 11:48:54 -0800 Subject: [PATCH] [ET-VK][ez][qconv] Add auto-selection to prefer im2col for q8ta_conv2d MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The q8ta_conv2d operator previously always delegated to the general (sliding window) implementation, even though the im2col implementation is 2-5x faster for non-grouped convolutions with in_channels % 4 == 0. This change adds runtime auto-selection logic that checks the groups parameter, input channel alignment, and a size heuristic, then dispatches to q8ta_conv2d_im2col when groups == 1, in_channels % 4 == 0, and either in_channels >= 64 or the output spatial area (H_out * W_out) is at most 4096; all other cases fall back to the general implementation. On ResNet50 int8, this reduces Vulkan inference latency from 14.2ms to 6.8ms (2.1x speedup) on Samsung Galaxy S24, roughly 30% lower latency than XNNPACK (9.7ms). Also adds performance test cases for deep-channel small-spatial scenarios (512ch 7x7, 1024→2048ch 1x1 stride-2) that stress-test the optimization. Differential Revision: [D93768637](https://our.internmc.facebook.com/intern/diff/D93768637/) [ghstack-poisoned] --- .../runtime/graph/ops/impl/Q8taConv2d.cpp | 25 ++++++++++++++++++- .../test/custom_ops/test_q8ta_conv2d.cpp | 18 ++++++++++++- 2 files changed, 41 insertions(+), 2 deletions(-) diff --git a/backends/vulkan/runtime/graph/ops/impl/Q8taConv2d.cpp b/backends/vulkan/runtime/graph/ops/impl/Q8taConv2d.cpp index 33b7005a845..8273df6a07e 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Q8taConv2d.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/Q8taConv2d.cpp @@ -417,7 +417,30 @@ void q8ta_conv2d_general( } void q8ta_conv2d(ComputeGraph& graph, const std::vector& args) { - q8ta_conv2d_general(graph, args); + const ValueRef input = args.at(0); + const ValueRef groups_ref = args.at(13); + const ValueRef output = args.at(15); + + const int64_t groups = graph.extract_scalar(groups_ref); + const int64_t in_channels = graph.size_at(-3, input); + const int64_t in_channels_per_group = 
in_channels / groups; + + const int64_t H_out = graph.size_at(-2, output); + const int64_t W_out = graph.size_at(-1, output); + const int64_t spatial_out = H_out * W_out; + + // Use im2col when the channel depth is sufficient for tiled GEMM to win, or + // when the output spatial area is small enough that the im2col buffer stays + // manageable. For large spatial outputs with few channels, the im2col buffer + // becomes too large and the general shader is more efficient. + const bool use_im2col = groups == 1 && in_channels_per_group % 4 == 0 && + (in_channels_per_group >= 64 || spatial_out <= 4096); + + if (use_im2col) { + q8ta_conv2d_im2col(graph, args); + } else { + q8ta_conv2d_general(graph, args); + } } REGISTER_OPERATORS { diff --git a/backends/vulkan/test/custom_ops/test_q8ta_conv2d.cpp b/backends/vulkan/test/custom_ops/test_q8ta_conv2d.cpp index bc95cc724f5..41ddd389aa8 100644 --- a/backends/vulkan/test/custom_ops/test_q8ta_conv2d.cpp +++ b/backends/vulkan/test/custom_ops/test_q8ta_conv2d.cpp @@ -378,7 +378,23 @@ static std::vector generate_quantized_conv2d_test_cases() { Stride(2, 2), Padding(2, 2), Dilation(1, 1), - 4}}; + 4}, + // Deep channels + small spatial (ResNet50 stage 5 bottleneck) + {OutInChannels(512, 512), + InputSize2D(7, 7), + KernelSize(3, 3), + Stride(1, 1), + Padding(1, 1), + Dilation(1, 1), + 1}, + // Strided 1x1 shortcut (worst-case strided downsample) + {OutInChannels(2048, 1024), + InputSize2D(14, 14), + KernelSize(1, 1), + Stride(2, 2), + Padding(0, 0), + Dilation(1, 1), + 1}}; // Test with different storage types and memory layouts std::vector fp_storage_types = {utils::kTexture3D};