diff --git a/backends/vulkan/runtime/graph/ops/impl/Q8taConv2d.cpp b/backends/vulkan/runtime/graph/ops/impl/Q8taConv2d.cpp
index 33b7005a845..8273df6a07e 100644
--- a/backends/vulkan/runtime/graph/ops/impl/Q8taConv2d.cpp
+++ b/backends/vulkan/runtime/graph/ops/impl/Q8taConv2d.cpp
@@ -417,7 +417,35 @@ void q8ta_conv2d_general(
 }
 
+/*
+ * Entry point for q8ta_conv2d. Routes the call to either the im2col
+ * implementation or the general implementation, based on the group count,
+ * the per-group input channel count, and the output spatial area.
+ */
 void q8ta_conv2d(ComputeGraph& graph, const std::vector<ValueRef>& args) {
-  q8ta_conv2d_general(graph, args);
+  const ValueRef input = args.at(0);
+  const ValueRef groups_ref = args.at(13);
+  const ValueRef output = args.at(15);
+
+  const int64_t groups = graph.extract_scalar<int64_t>(groups_ref);
+  const int64_t in_channels = graph.size_at<int64_t>(-3, input);
+  const int64_t in_channels_per_group = in_channels / groups;
+
+  const int64_t H_out = graph.size_at<int64_t>(-2, output);
+  const int64_t W_out = graph.size_at<int64_t>(-1, output);
+  const int64_t spatial_out = H_out * W_out;
+
+  // Use im2col when the channel depth is sufficient for tiled GEMM to win, or
+  // when the output spatial area is small enough that the im2col buffer stays
+  // manageable. For large spatial outputs with few channels, the im2col buffer
+  // becomes too large and the general shader is more efficient.
+  const bool use_im2col = groups == 1 && in_channels_per_group % 4 == 0 &&
+      (in_channels_per_group >= 64 || spatial_out <= 4096);
+
+  if (use_im2col) {
+    q8ta_conv2d_im2col(graph, args);
+  } else {
+    q8ta_conv2d_general(graph, args);
+  }
 }
 
 REGISTER_OPERATORS {
diff --git a/backends/vulkan/test/custom_ops/test_q8ta_conv2d.cpp b/backends/vulkan/test/custom_ops/test_q8ta_conv2d.cpp
index bc95cc724f5..41ddd389aa8 100644
--- a/backends/vulkan/test/custom_ops/test_q8ta_conv2d.cpp
+++ b/backends/vulkan/test/custom_ops/test_q8ta_conv2d.cpp
@@ -378,7 +378,23 @@ static std::vector generate_quantized_conv2d_test_cases() {
        Stride(2, 2),
        Padding(2, 2),
        Dilation(1, 1),
-       4}};
+       4},
+      // Deep channels + small spatial (ResNet50 stage 5 bottleneck)
+      {OutInChannels(512, 512),
+       InputSize2D(7, 7),
+       KernelSize(3, 3),
+       Stride(1, 1),
+       Padding(1, 1),
+       Dilation(1, 1),
+       1},
+      // Strided 1x1 shortcut (worst-case strided downsample)
+      {OutInChannels(2048, 1024),
+       InputSize2D(14, 14),
+       KernelSize(1, 1),
+       Stride(2, 2),
+       Padding(0, 0),
+       Dilation(1, 1),
+       1}};
 
   // Test with different storage types and memory layouts
   std::vector fp_storage_types = {utils::kTexture3D};