Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions backends/vulkan/patterns/quantized_convolution.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,6 +226,16 @@ def make_q8ta_conv2d_custom_op(
sum_per_output_channel = (
weight_tensor.sum(dim=1).to(torch.int32).contiguous()
)
# Pad the weight sums so that OC is aligned to a multiple of 4, matching the
# alignment applied to weight, weight_scales, and bias above. Without this, the
# GPU shader would read out of bounds when OC is not a multiple of 4.
oc = sum_per_output_channel.shape[0]
if oc % 4 != 0:
num_padding = 4 - (oc % 4)
sum_per_output_channel = torch.nn.functional.pad(
sum_per_output_channel, (0, num_padding)
).contiguous()

sums_name = qweight_tensor_name + "_sums"
# Sanitize the name
sums_name = sums_name.replace(".", "_")
Expand Down
22 changes: 22 additions & 0 deletions backends/vulkan/test/custom_ops/test_q8ta_conv2d_pw.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,28 @@ static std::vector<TestCase> generate_quantized_conv2d_pw_test_cases() {
}

std::vector<Conv2dConfig> configs = {
// OC < 4: exercise the edge case of a partial output-channel block
{OutInChannels(1, 16),
InputSize2D(8, 8),
KernelSize(1, 1),
Stride(1, 1),
Padding(0, 0),
Dilation(1, 1),
1},
{OutInChannels(2, 16),
InputSize2D(8, 8),
KernelSize(1, 1),
Stride(1, 1),
Padding(0, 0),
Dilation(1, 1),
1},
{OutInChannels(3, 16),
InputSize2D(8, 8),
KernelSize(1, 1),
Stride(1, 1),
Padding(0, 0),
Dilation(1, 1),
1},
// Pointwise convolutions: kernel size 1x1
{OutInChannels(32, 3),
InputSize2D(64, 64),
Expand Down
6 changes: 5 additions & 1 deletion backends/vulkan/test/custom_ops/utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2064,7 +2064,11 @@ void compute_weight_sums(
auto& weight_sums_data = weight_sums.get_int32_data();
auto& quantized_weight_data = quantized_weight.get_int8_data();

weight_sums_data.resize(out_features);
// Never shrink the buffer: it may have been pre-allocated with an aligned
// size. Grow it only when it holds fewer than out_features elements.
if (weight_sums_data.size() < static_cast<size_t>(out_features)) {
weight_sums_data.resize(out_features);
}

// For each output feature, compute the sum of quantized weights
for (int64_t out_f = 0; out_f < out_features; ++out_f) {
Expand Down
Loading