diff --git a/backends/vulkan/patterns/quantized_convolution.py b/backends/vulkan/patterns/quantized_convolution.py index 93140e15341..9a6fb69bf87 100644 --- a/backends/vulkan/patterns/quantized_convolution.py +++ b/backends/vulkan/patterns/quantized_convolution.py @@ -226,6 +226,16 @@ def make_q8ta_conv2d_custom_op( sum_per_output_channel = ( weight_tensor.sum(dim=1).to(torch.int32).contiguous() ) + # Pad weight sums to align OC to multiple of 4, matching the alignment + # applied to weight, weight_scales, and bias above. Without this, the + # GPU shader would read out-of-bounds when OC is not a multiple of 4. + oc = sum_per_output_channel.shape[0] + if oc % 4 != 0: + num_padding = 4 - (oc % 4) + sum_per_output_channel = torch.nn.functional.pad( + sum_per_output_channel, (0, num_padding) + ).contiguous() + sums_name = qweight_tensor_name + "_sums" # Sanitize the name sums_name = sums_name.replace(".", "_") diff --git a/backends/vulkan/test/custom_ops/test_q8ta_conv2d_pw.cpp b/backends/vulkan/test/custom_ops/test_q8ta_conv2d_pw.cpp index 51095c649b6..6ce6671ec84 100644 --- a/backends/vulkan/test/custom_ops/test_q8ta_conv2d_pw.cpp +++ b/backends/vulkan/test/custom_ops/test_q8ta_conv2d_pw.cpp @@ -210,6 +210,28 @@ static std::vector generate_quantized_conv2d_pw_test_cases() { } std::vector configs = { + // OC < 4 cases to test edge cases with partial output channel blocks + {OutInChannels(1, 16), + InputSize2D(8, 8), + KernelSize(1, 1), + Stride(1, 1), + Padding(0, 0), + Dilation(1, 1), + 1}, + {OutInChannels(2, 16), + InputSize2D(8, 8), + KernelSize(1, 1), + Stride(1, 1), + Padding(0, 0), + Dilation(1, 1), + 1}, + {OutInChannels(3, 16), + InputSize2D(8, 8), + KernelSize(1, 1), + Stride(1, 1), + Padding(0, 0), + Dilation(1, 1), + 1}, // Pointwise convolutions: kernel size 1x1 {OutInChannels(32, 3), InputSize2D(64, 64), diff --git a/backends/vulkan/test/custom_ops/utils.cpp b/backends/vulkan/test/custom_ops/utils.cpp index b23c288a58f..2a50e7b5ec1 100644 --- 
a/backends/vulkan/test/custom_ops/utils.cpp +++ b/backends/vulkan/test/custom_ops/utils.cpp @@ -2064,7 +2064,11 @@ void compute_weight_sums( auto& weight_sums_data = weight_sums.get_int32_data(); auto& quantized_weight_data = quantized_weight.get_int8_data(); - weight_sums_data.resize(out_features); + // Don't resize down - the buffer may be pre-allocated with aligned size. + // Only resize up if needed. + if (weight_sums_data.size() < static_cast<size_t>(out_features)) { + weight_sums_data.resize(out_features); + } // For each output feature, compute the sum of quantized weights for (int64_t out_f = 0; out_f < out_features; ++out_f) {