Skip to content
25 changes: 24 additions & 1 deletion backends/vulkan/runtime/graph/ops/impl/Q8taConv2d.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -417,7 +417,30 @@ void q8ta_conv2d_general(
}

void q8ta_conv2d(ComputeGraph& graph, const std::vector<ValueRef>& args) {
q8ta_conv2d_general(graph, args);
const ValueRef input = args.at(0);
const ValueRef groups_ref = args.at(13);
const ValueRef output = args.at(15);

const int64_t groups = graph.extract_scalar<int64_t>(groups_ref);
const int64_t in_channels = graph.size_at<int64_t>(-3, input);
const int64_t in_channels_per_group = in_channels / groups;

const int64_t H_out = graph.size_at<int64_t>(-2, output);
const int64_t W_out = graph.size_at<int64_t>(-1, output);
const int64_t spatial_out = H_out * W_out;

// Use im2col when the channel depth is sufficient for tiled GEMM to win, or
// when the output spatial area is small enough that the im2col buffer stays
// manageable. For large spatial outputs with few channels, the im2col buffer
// becomes too large and the general shader is more efficient.
const bool use_im2col = groups == 1 && in_channels_per_group % 4 == 0 &&
(in_channels_per_group >= 64 || spatial_out <= 4096);

if (use_im2col) {
q8ta_conv2d_im2col(graph, args);
} else {
q8ta_conv2d_general(graph, args);
}
}

REGISTER_OPERATORS {
Expand Down
18 changes: 17 additions & 1 deletion backends/vulkan/test/custom_ops/test_q8ta_conv2d.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -378,7 +378,23 @@ static std::vector<TestCase> generate_quantized_conv2d_test_cases() {
Stride(2, 2),
Padding(2, 2),
Dilation(1, 1),
4}};
4},
// Deep channels + small spatial (ResNet50 stage 5 bottleneck)
{OutInChannels(512, 512),
InputSize2D(7, 7),
KernelSize(3, 3),
Stride(1, 1),
Padding(1, 1),
Dilation(1, 1),
1},
// Strided 1x1 shortcut (worst-case strided downsample)
{OutInChannels(2048, 1024),
InputSize2D(14, 14),
KernelSize(1, 1),
Stride(2, 2),
Padding(0, 0),
Dilation(1, 1),
1}};

// Test with different storage types and memory layouts
std::vector<utils::StorageType> fp_storage_types = {utils::kTexture3D};
Expand Down
Loading