From a1fde3cd967c9fde0aa057bc7e944a18b69411e7 Mon Sep 17 00:00:00 2001 From: ssjia Date: Thu, 19 Feb 2026 11:48:46 -0800 Subject: [PATCH] [ET-VK][q8ta] Add q8ta_linear_gemv op for batch-1 int8 linear Add a cooperative GEMV variant of q8ta_linear optimized for batch size 1. The existing q8ta_linear uses a tiled algorithm with 4H4W packed int8 layout, which is inefficient for single-row inputs because it wastes 3/4 of each ivec4 block. The new q8ta_linear_gemv uses 4W packed int8 layout (scalar int[] buffers) and a cooperative algorithm where 64 threads split the K reduction dimension with shared memory tree reduction. The shader loads one packed int32 (4 int8 values) per thread per K iteration and accumulates dot products against the weight tile using dotPacked4x8AccSatEXT. After reduction, thread 0 applies scales, zero points, bias, and quantizes the output. The pattern matcher in quantized_linear.py selects q8ta_linear_gemv when the input batch dimension is 1, falling back to q8ta_linear for larger batches. Also adds PACKED_INT8_4W (value 5) to the serialization schema to support the 4W memory layout in the export pipeline. Authored with Claude. Differential Revision: [D93768643](https://our.internmc.facebook.com/intern/diff/D93768643/) [ghstack-poisoned] --- backends/vulkan/custom_ops_lib.py | 60 ++++++ backends/vulkan/op_registry.py | 22 ++ backends/vulkan/patterns/quantized_linear.py | 10 +- .../graph/ops/glsl/q8ta_linear_gemv.glsl | 164 ++++++++++++++ .../graph/ops/glsl/q8ta_linear_gemv.yaml | 18 ++ .../runtime/graph/ops/impl/Q8taLinearGemv.cpp | 203 ++++++++++++++++++ .../runtime/graph/ops/impl/Q8taLinearGemv.h | 30 +++ .../test/custom_ops/impl/TestQ8taLinear.cpp | 111 ++++++---- .../test/custom_ops/test_q8ta_linear.cpp | 29 ++- backends/vulkan/test/test_vulkan_passes.py | 46 ++++ 10 files changed, 652 insertions(+), 41 deletions(-) create mode 100644 backends/vulkan/runtime/graph/ops/glsl/q8ta_linear_gemv.glsl create mode 100644 backends/vulkan/runtime/graph/ops/glsl/q8ta_linear_gemv.yaml create mode 100644 backends/vulkan/runtime/graph/ops/impl/Q8taLinearGemv.cpp create mode 100644 backends/vulkan/runtime/graph/ops/impl/Q8taLinearGemv.h diff --git a/backends/vulkan/custom_ops_lib.py b/backends/vulkan/custom_ops_lib.py index f184b4c93db..78bc87bc159 100644 --- a/backends/vulkan/custom_ops_lib.py +++ b/backends/vulkan/custom_ops_lib.py @@ -416,6 +416,66 @@ def q8ta_linear( lib.impl(name, q8ta_linear, "CompositeExplicitAutograd") q8ta_linear_op = getattr(getattr(torch.ops, namespace), name) +####################### +## q8ta_linear_gemv ## +####################### + + +def q8ta_linear_gemv( + x: torch.Tensor, + input_scale: float, + input_zero_point: int, + weights: torch.Tensor, + weight_sums: torch.Tensor, + weight_scales: torch.Tensor, + output_scale: float, + output_zero_point: int, + bias: Optional[torch.Tensor] = None, +): + weight_zeros = torch.zeros_like(weight_scales, dtype=torch.int32) + weights = torch.ops.quantized_decomposed.dequantize_per_channel( + weights, + weight_scales, + weight_zeros, + 0, + -127, + 127, + torch.int8, + ) + + x = torch.ops.quantized_decomposed.dequantize_per_tensor( + x, input_scale, input_zero_point, -128, 127, x.dtype + ) + + out = torch.nn.functional.linear(x, weights) + if bias is not None: + out = out + bias + + out = torch.ops.quantized_decomposed.quantize_per_tensor( + out, output_scale, output_zero_point, -128, 127, torch.int8 + ) + + return out + + +name = "q8ta_linear_gemv" +lib.define( + f""" + {name}( + Tensor x, + float input_scale, + int input_zero_point, + Tensor weights, + Tensor weight_sums, + Tensor weight_scales, + float output_scale, + int output_zero_point, + Tensor? bias = None) -> Tensor + """ +) +lib.impl(name, q8ta_linear_gemv, "CompositeExplicitAutograd") +q8ta_linear_gemv_op = getattr(getattr(torch.ops, namespace), name) + ################### ## q8ta_conv2d_* ## ################### diff --git a/backends/vulkan/op_registry.py b/backends/vulkan/op_registry.py index 26fe665e2f6..811f717a833 100644 --- a/backends/vulkan/op_registry.py +++ b/backends/vulkan/op_registry.py @@ -857,6 +857,28 @@ def register_q8ta_linear(): ) +@update_features(exir_ops.edge.et_vk.q8ta_linear_gemv.default) +def register_q8ta_linear_gemv(): + return OpFeatures( + inputs_storage=[ + utils.PACKED_INT8_4W_BUFFER, # input + utils.NO_STORAGE, # input_scale (non tensor) + utils.NO_STORAGE, # input_zero_point (non tensor) + utils.NO_STORAGE, # weight (prepacked) + utils.NO_STORAGE, # weight_sums (prepacked) + utils.NO_STORAGE, # weight_scales (prepacked) + utils.NO_STORAGE, # output_scale (non tensor) + utils.NO_STORAGE, # output_zero_point (non tensor) + utils.NO_STORAGE, # bias (prepacked) + ], + outputs_storage=[ + utils.PACKED_INT8_4W_BUFFER, + ], + supports_resize=False, + supports_prepacking=True, + ) + + # ============================================================================= # SDPA.cpp # ============================================================================= diff --git a/backends/vulkan/patterns/quantized_linear.py b/backends/vulkan/patterns/quantized_linear.py index 6dbb5b2fe4e..1b6c64af8e3 100644 --- a/backends/vulkan/patterns/quantized_linear.py +++ b/backends/vulkan/patterns/quantized_linear.py @@ -499,10 +499,18 @@ def make_q8ta_linear_custom_op( data=sum_per_output_channel, ) + # Use gemv variant when batch size is 1 + input_shape = match.fp_input_node.meta["val"].shape + batch_size = input_shape[-2] if len(input_shape) >= 2 else 1 + if batch_size == 1: + op_target = exir_ops.edge.et_vk.q8ta_linear_gemv.default + else: + op_target = exir_ops.edge.et_vk.q8ta_linear.default + with graph_module.graph.inserting_before(match.output_node): qlinear_node = graph_module.graph.create_node( "call_function", - exir_ops.edge.et_vk.q8ta_linear.default, + op_target, args=( match.quantize_input_node, match.input_scales_node, diff --git a/backends/vulkan/runtime/graph/ops/glsl/q8ta_linear_gemv.glsl b/backends/vulkan/runtime/graph/ops/glsl/q8ta_linear_gemv.glsl new file mode 100644 index 00000000000..63ca6067734 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/q8ta_linear_gemv.glsl @@ -0,0 +1,164 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#version 450 core + +${define_required_extensions("buffer", DTYPE)} + +#extension GL_EXT_control_flow_attributes : require +#extension GL_EXT_integer_dot_product : require + +#define PRECISION ${PRECISION} +#define VEC4_T ${texel_load_type(DTYPE, "buffer")} +#define T int + +$if WEIGHT_STORAGE == "buffer": + #define WEIGHT_BUFFER + +#define TILE_K4 ${TILE_K4} +#define TILE_N4 ${TILE_N4} + +#define TILE_M4 1 +#define TILE_M 1 +#define TILE_K ${TILE_K4 * 4} +#define TILE_N ${TILE_N4 * 4} + +#define WGS ${WGS} + +layout(std430) buffer; + +// Scalar int arrays for 4W packed int8 input/output +${layout_declare_tensor(B, "w", "t_packed_int8_output", "int", "buffer")} +${layout_declare_tensor(B, "r", "t_packed_int8_input", "int", "buffer")} +// Weight uses ivec4 (same format as q8ta_linear) +${layout_declare_tensor(B, "r", "t_packed_int8_weight", "int", WEIGHT_STORAGE, is_scalar_array=False)} +${layout_declare_tensor(B, "r", "t_weight_sums", "int", "buffer", is_scalar_array=False)} +${layout_declare_tensor(B, "r", "t_weight_scales", DTYPE, "buffer", is_scalar_array=False)} +${layout_declare_tensor(B, "r", "t_bias", DTYPE, "buffer", is_scalar_array=False)} + +${layout_declare_spec_const(C, "int", "apply_bias", "0")} + +${layout_declare_ubo(B, "ivec4", "output_sizes")} +${layout_declare_ubo(B, "ivec4", "input_sizes")} + +layout(push_constant) uniform restrict Block { + float input_scale; + int input_zp; + float output_inv_scale; + int output_zp; +}; + +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; + +#include "common.glslh" +#include "linear_int8_weight_tile_load.glslh" +#include "linear_fp_output_tile_int8_int8_compute.glslh" +#include "linear_fp_weight_scales_load.glslh" +#include "linear_int_weight_sums_load.glslh" +#include "linear_fp_bias_load.glslh" + +shared Int32Accum partial_accums[WGS]; + +void main() { + const int lid = int(gl_LocalInvocationID.z); + const int n4 = int(gl_GlobalInvocationID.x) * TILE_N4; + + const int n = mul_4(n4); + + const int K4 = div_up_4(input_sizes.x); + const int N4 = div_up_4(output_sizes.x); + + if (n >= output_sizes.x) { + return; + } + + Int32Accum out_accum; + initialize(out_accum); + + Int8WeightTile int8_weight_tile; + + for (int k4 = lid; k4 < K4; k4 += WGS) { + // Load one packed int32 from the 4W input buffer. Each int32 contains + // 4 int8 values at k=k4*4..k4*4+3. + const int packed_input = t_packed_int8_input[k4]; + + load_int8_weight_tile(int8_weight_tile, n4, k4, N4); + + // Accumulate dot products of the input int8x4 with each weight int8x4 + [[unroll]] for (int n = 0; n < TILE_N; ++n) { + const int tile_n4 = div_4(n); + const int n4i = mod_4(n); + out_accum.data[0][tile_n4][n4i] = dotPacked4x8AccSatEXT( + packed_input, + int8_weight_tile.data[0][tile_n4][n4i], + out_accum.data[0][tile_n4][n4i]); + } + } + + partial_accums[lid] = out_accum; + + memoryBarrierShared(); + barrier(); + + // Tree reduction to combine partial results + for (int i = WGS / 2; i > 0; i /= 2) { + if (lid < i) { + [[unroll]] for (int tile_n4 = 0; tile_n4 < TILE_N4; ++tile_n4) { + partial_accums[lid].data[0][tile_n4] += + partial_accums[lid + i].data[0][tile_n4]; + } + } + memoryBarrierShared(); + barrier(); + } + + // Only the first thread writes the result + if (lid == 0) { + out_accum = partial_accums[0]; + + FPPerOutChannelParams weight_scales_tile; + load_weight_scales_tile(weight_scales_tile, n4); + + IntPerOutChannelParams weight_sums_tile; + load_weight_sums_tile(weight_sums_tile, n4); + + FPOutTile out_tile; + initialize(out_tile); + + if (apply_bias > 0) { + FPPerOutChannelParams bias_tile; + load_bias_tile(bias_tile, n4); + + accumulate_out_tile_with_int_accum( + out_tile, + out_accum, + input_scale, + input_zp, + weight_sums_tile, + weight_scales_tile, + bias_tile); + } else { + accumulate_out_tile_with_int_accum( + out_tile, + out_accum, + input_scale, + input_zp, + weight_sums_tile, + weight_scales_tile); + } + + // Quantize and write to scalar int[] buffer. Each int32 at position n4 + // contains 4 packed int8 output values for channels n4*4..n4*4+3. + [[unroll]] for (int tile_n4 = 0; tile_n4 < TILE_N4; ++tile_n4) { + if (n4 + tile_n4 < N4) { + t_packed_int8_output[n4 + tile_n4] = quantize_and_pack( + out_tile.data[0][tile_n4], output_inv_scale, output_zp); + } + } + } +} diff --git a/backends/vulkan/runtime/graph/ops/glsl/q8ta_linear_gemv.yaml b/backends/vulkan/runtime/graph/ops/glsl/q8ta_linear_gemv.yaml new file mode 100644 index 00000000000..beae1eddf3e --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/q8ta_linear_gemv.yaml @@ -0,0 +1,18 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +q8ta_linear_gemv: + parameter_names_with_default_values: + DTYPE: float + WEIGHT_STORAGE: texture2d + TILE_K4: 1 + TILE_N4: 2 + WGS: 64 + generate_variant_forall: + DTYPE: + - VALUE: float + shader_variants: + - NAME: q8ta_linear_gemv diff --git a/backends/vulkan/runtime/graph/ops/impl/Q8taLinearGemv.cpp b/backends/vulkan/runtime/graph/ops/impl/Q8taLinearGemv.cpp new file mode 100644 index 00000000000..1d8b2ca55d5 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/impl/Q8taLinearGemv.cpp @@ -0,0 +1,203 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +#include + +#include +#include +#include +#include + +namespace vkcompute { + +static bool q8ta_linear_gemv_check_packed_dim_info( + const api::PackedDimInfo& info) { + return info.packed_dim == WHCN::kWidthDim && + info.packed_dim_block_size == 4 && + info.outer_packed_dim == WHCN::kHeightDim && + info.outer_packed_dim_block_size == 1; +} + +// +// Workgroup size selection +// + +utils::uvec3 q8ta_linear_gemv_global_wg_size( + ComputeGraph* graph, + const vkapi::ShaderInfo& shader, + const std::vector& args, + const std::vector& resize_args) { + (void)shader; + (void)resize_args; + + const ValueRef out = args.at(0).refs.at(0); + + std::vector out_sizes = graph->sizes_of(out); + const uint32_t N = utils::val_at(-1, out_sizes); + + // Each output tile contains 8 columns (TILE_N4=2 -> 8 output channels) + const uint32_t N_per_tile = 8; + const uint32_t num_N_tiles = utils::div_up(N, N_per_tile); + + return {num_N_tiles, 1, 1}; +} + +utils::uvec3 q8ta_linear_gemv_local_wg_size( + ComputeGraph* graph, + const vkapi::ShaderInfo& shader, + const utils::uvec3& global_workgroup_size, + const std::vector& args, + const std::vector& resize_args) { + (void)graph; + (void)shader; + (void)global_workgroup_size; + (void)args; + (void)resize_args; + + // Cooperative algorithm: 64 threads share the K reduction + return {1, 1, 64}; +} + +// +// Dispatch node +// + +void add_q8ta_linear_gemv_node( + ComputeGraph& graph, + const ValueRef packed_int8_input, + const ValueRef input_scale, + const ValueRef input_zp, + const ValueRef packed_weight, + const ValueRef packed_weight_sums, + const ValueRef packed_weight_scales, + const ValueRef output_scale, + const ValueRef output_zp, + const ValueRef bias_data, + const ValueRef packed_bias, + const ValueRef packed_int8_output) { + // Validate packed dim info matches 4W layout + VK_CHECK_COND(q8ta_linear_gemv_check_packed_dim_info( + graph.packed_dim_info_of(packed_int8_input))); + VK_CHECK_COND(q8ta_linear_gemv_check_packed_dim_info( + graph.packed_dim_info_of(packed_int8_output))); + + float input_scale_val = graph.extract_scalar(input_scale); + int32_t input_zp_val = graph.extract_scalar(input_zp); + + float output_inv_scale_val = 1.0f / graph.extract_scalar(output_scale); + int32_t output_zp_val = graph.extract_scalar(output_zp); + + uint32_t apply_bias = 1; + if (graph.val_is_none(bias_data)) { + apply_bias = 0; + } + + std::string kernel_name = "q8ta_linear_gemv"; + add_dtype_suffix(kernel_name, graph.dtype_of(packed_weight_scales)); + + vkapi::ParamsBindList param_buffers = { + graph.sizes_ubo(packed_int8_output), graph.sizes_ubo(packed_int8_input)}; + + std::vector push_constants = { + PushConstantDataInfo(&input_scale_val, sizeof(input_scale_val)), + PushConstantDataInfo(&input_zp_val, sizeof(input_zp_val)), + PushConstantDataInfo(&output_inv_scale_val, sizeof(output_inv_scale_val)), + PushConstantDataInfo(&output_zp_val, sizeof(output_zp_val)), + }; + + graph.execute_nodes().emplace_back(new DynamicDispatchNode( + graph, + VK_KERNEL_FROM_STR(kernel_name), + q8ta_linear_gemv_global_wg_size, + q8ta_linear_gemv_local_wg_size, + // Inputs and Outputs + {{packed_int8_output, vkapi::kWrite}, + {{packed_int8_input, + packed_weight, + packed_weight_sums, + packed_weight_scales, + packed_bias}, + vkapi::kRead}}, + // Shader params buffers + param_buffers, + // Push Constants + push_constants, + // Specialization Constants + {apply_bias}, + // Resize args + {}, + // Resizing Logic + nullptr)); +} + +// +// High level operator impl +// + +void q8ta_linear_gemv(ComputeGraph& graph, const std::vector& args) { + int32_t idx = 0; + const ValueRef packed_int8_input = args.at(idx++); + const ValueRef input_scale = args.at(idx++); + const ValueRef input_zp = args.at(idx++); + const ValueRef weight_data = args.at(idx++); + const ValueRef weight_sums_data = args.at(idx++); + const ValueRef weight_scales_data = args.at(idx++); + const ValueRef output_scale = args.at(idx++); + const ValueRef output_zp = args.at(idx++); + const ValueRef bias_data = args.at(idx++); + const ValueRef packed_int8_output = args.at(idx++); + + const int64_t K = graph.size_at(-1, packed_int8_input); + VK_CHECK_COND(K % 4 == 0); + + QuantizationConfig weight_quant_config(8, kPerChannel, {K}); + + // Prepack weight data (same format as q8ta_linear) + const ValueRef packed_weight = + prepack_quantized_linear_weight(graph, weight_quant_config, weight_data); + const ValueRef packed_weight_scales = prepack_standard( + graph, weight_scales_data, utils::kBuffer, utils::kWidthPacked); + const ValueRef packed_weight_sums = prepack_standard( + graph, weight_sums_data, utils::kBuffer, utils::kWidthPacked); + + // Prepack bias data + TmpTensor dummy_bias( + &graph, + {}, + graph.dtype_of(packed_weight_scales), + utils::kBuffer, + utils::kWidthPacked); + + ValueRef packed_bias = dummy_bias.vref; + if (graph.val_is_not_none(bias_data)) { + packed_bias = + prepack_standard(graph, bias_data, utils::kBuffer, utils::kWidthPacked); + } + + add_q8ta_linear_gemv_node( + graph, + packed_int8_input, + input_scale, + input_zp, + packed_weight, + packed_weight_sums, + packed_weight_scales, + output_scale, + output_zp, + bias_data, + packed_bias, + packed_int8_output); +} + +REGISTER_OPERATORS { + VK_REGISTER_OP(et_vk.q8ta_linear_gemv.default, q8ta_linear_gemv); +} + +} // namespace vkcompute diff --git a/backends/vulkan/runtime/graph/ops/impl/Q8taLinearGemv.h b/backends/vulkan/runtime/graph/ops/impl/Q8taLinearGemv.h new file mode 100644 index 00000000000..802c410c900 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/impl/Q8taLinearGemv.h @@ -0,0 +1,30 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include +#include + +namespace vkcompute { + +void add_q8ta_linear_gemv_node( + ComputeGraph& graph, + const ValueRef packed_int8_input, + const ValueRef input_scale, + const ValueRef input_zp, + const ValueRef packed_weight, + const ValueRef packed_weight_sums, + const ValueRef packed_weight_scales, + const ValueRef output_scale, + const ValueRef output_zp, + const ValueRef bias_data, + const ValueRef packed_bias, + const ValueRef packed_int8_output); + +} // namespace vkcompute diff --git a/backends/vulkan/test/custom_ops/impl/TestQ8taLinear.cpp b/backends/vulkan/test/custom_ops/impl/TestQ8taLinear.cpp index e35f2509a18..0ee2dffd12a 100644 --- a/backends/vulkan/test/custom_ops/impl/TestQ8taLinear.cpp +++ b/backends/vulkan/test/custom_ops/impl/TestQ8taLinear.cpp @@ -24,47 +24,80 @@ void test_q8ta_linear(ComputeGraph& graph, const std::vector& args) { const ValueRef output_scale = args.at(idx++); const ValueRef output_zp = args.at(idx++); const ValueRef bias_data = args.at(idx++); + const ValueRef impl_selector_str = args.at(idx++); const ValueRef fp_output = args.at(idx++); - // Create temporary packed int8 tensors for input and output - // Input uses 4H4W layout to match the linear shader's ivec4 reading pattern - // where each ivec4 contains data from 4 rows - TmpTensor packed_int8_input( - &graph, - graph.sizes_of(fp_input), - vkapi::kInt8x4, - utils::kBuffer, - utils::kPackedInt8_4H4W); - - // Output uses 4H4W layout to match the linear shader's ivec4 writing pattern - TmpTensor packed_int8_output( - &graph, - graph.sizes_of(fp_output), - vkapi::kInt8x4, - utils::kBuffer, - utils::kPackedInt8_4H4W); - - // Quantize floating point input to packed int8 - add_q8ta_quantize_node( - graph, fp_input, input_scale, input_zp, packed_int8_input); - - // Call the q8ta_linear operator - std::vector linear_args = { - packed_int8_input, - input_scale, - input_zp, - weight_data, - weight_sums_data, - weight_scales_data, - output_scale, - output_zp, - bias_data, - packed_int8_output}; - VK_GET_OP_FN("et_vk.q8ta_linear.default")(graph, linear_args); - - // Dequantize packed int8 output to floating point - add_q8ta_dequantize_node( - graph, packed_int8_output, output_scale, output_zp, fp_output); + std::string impl_selector = graph.extract_string(impl_selector_str); + + if (impl_selector == "gemv") { + // Use 4W layout for gemv variant + TmpTensor packed_int8_input( + &graph, + graph.sizes_of(fp_input), + vkapi::kInt8x4, + utils::kBuffer, + utils::kPackedInt8_4W); + + TmpTensor packed_int8_output( + &graph, + graph.sizes_of(fp_output), + vkapi::kInt8x4, + utils::kBuffer, + utils::kPackedInt8_4W); + + add_q8ta_quantize_node( + graph, fp_input, input_scale, input_zp, packed_int8_input); + + std::vector linear_args = { + packed_int8_input, + input_scale, + input_zp, + weight_data, + weight_sums_data, + weight_scales_data, + output_scale, + output_zp, + bias_data, + packed_int8_output}; + VK_GET_OP_FN("et_vk.q8ta_linear_gemv.default")(graph, linear_args); + + add_q8ta_dequantize_node( + graph, packed_int8_output, output_scale, output_zp, fp_output); + } else { + // Default: use 4H4W layout for tiled variant + TmpTensor packed_int8_input( + &graph, + graph.sizes_of(fp_input), + vkapi::kInt8x4, + utils::kBuffer, + utils::kPackedInt8_4H4W); + + TmpTensor packed_int8_output( + &graph, + graph.sizes_of(fp_output), + vkapi::kInt8x4, + utils::kBuffer, + utils::kPackedInt8_4H4W); + + add_q8ta_quantize_node( + graph, fp_input, input_scale, input_zp, packed_int8_input); + + std::vector linear_args = { + packed_int8_input, + input_scale, + input_zp, + weight_data, + weight_sums_data, + weight_scales_data, + output_scale, + output_zp, + bias_data, + packed_int8_output}; + VK_GET_OP_FN("et_vk.q8ta_linear.default")(graph, linear_args); + + add_q8ta_dequantize_node( + graph, packed_int8_output, output_scale, output_zp, fp_output); + } } REGISTER_OPERATORS { diff --git a/backends/vulkan/test/custom_ops/test_q8ta_linear.cpp b/backends/vulkan/test/custom_ops/test_q8ta_linear.cpp index 3c519e3afc6..6952acfc6f5 100644 --- a/backends/vulkan/test/custom_ops/test_q8ta_linear.cpp +++ b/backends/vulkan/test/custom_ops/test_q8ta_linear.cpp @@ -30,12 +30,16 @@ struct LinearConfig { static TestCase create_test_case_from_config( const LinearConfig& config, - vkapi::ScalarType input_dtype) { + vkapi::ScalarType input_dtype, + const std::string& impl_selector = "") { TestCase test_case; std::string dtype_str = (input_dtype == vkapi::kFloat) ? "Float" : "Half"; std::string test_name = config.test_case_name + "_Buffer_" + dtype_str; + if (!impl_selector.empty()) { + test_name += " [" + impl_selector + "]"; + } test_case.set_name(test_name); test_case.set_operator_name("test_etvk.test_q8ta_linear.default"); @@ -131,6 +135,11 @@ static TestCase create_test_case_from_config( test_case.add_input_spec(output_scale); test_case.add_input_spec(output_zero_point); test_case.add_input_spec(bias); + + // Add impl_selector string + ValueSpec impl_selector_spec = ValueSpec::make_string(impl_selector); + test_case.add_input_spec(impl_selector_spec); + test_case.add_output_spec(output); test_case.set_abs_tolerance(output_scale_val + 1e-4f); @@ -154,6 +163,12 @@ static std::vector generate_q8ta_linear_test_cases() { } std::vector configs = { + // Batch size 1 cases (test both tiled and gemv) + {1, 64, 32}, + {1, 128, 64}, + {1, 256, 128}, + {1, 128, 64, false}, + // Multi-batch cases {4, 64, 32}, {4, 128, 64}, {4, 256, 128}, @@ -164,6 +179,9 @@ static std::vector generate_q8ta_linear_test_cases() { {32, 128, 64, false}, {32, 256, 128, false}, // Performance cases + {1, 512, 512}, + {1, 2048, 2048}, + {1, 512, 9059}, {256, 2048, 2048}, {512, 2048, 2048}, {1024, 2048, 2048}, @@ -182,7 +200,14 @@ static std::vector generate_q8ta_linear_test_cases() { config.test_case_name = generated_test_case_name; + // Default (tiled) variant test_cases.push_back(create_test_case_from_config(config, vkapi::kFloat)); + + // For batch size 1, also test the gemv variant + if (config.M == 1) { + test_cases.push_back( + create_test_case_from_config(config, vkapi::kFloat, "gemv")); + } } return test_cases; @@ -201,6 +226,8 @@ static void q8ta_linear_reference_impl(TestCase& test_case) { const ValueSpec& output_scale_spec = test_case.inputs()[idx++]; const ValueSpec& output_zeros_spec = test_case.inputs()[idx++]; const ValueSpec& bias_spec = test_case.inputs()[idx++]; + const ValueSpec& impl_selector_spec = test_case.inputs()[idx++]; + (void)impl_selector_spec; ValueSpec& output_spec = test_case.outputs()[0]; diff --git a/backends/vulkan/test/test_vulkan_passes.py b/backends/vulkan/test/test_vulkan_passes.py index 2141d734d48..3488357d155 100644 --- a/backends/vulkan/test/test_vulkan_passes.py +++ b/backends/vulkan/test/test_vulkan_passes.py @@ -237,3 +237,49 @@ def forward(self, x): 1, "Expected at least one q8ta_linear op from output-quantized linear fusion", ) + + def test_fuse_q8ta_linear_gemv(self): + """Test that batch-1 quantized linear fuses into q8ta_linear_gemv.""" + from executorch.backends.xnnpack.quantizer.xnnpack_quantizer import ( + get_symmetric_quantization_config, + XNNPACKQuantizer, + ) + + class TwoLinearModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear1 = torch.nn.Linear(128, 64, bias=False) + self.linear2 = torch.nn.Linear(64, 32, bias=False) + + def forward(self, x): + return self.linear2(self.linear1(x)) + + model = TwoLinearModule() + # Batch size 1 to trigger gemv variant + sample_inputs = (torch.randn(1, 128),) + + quantizer = XNNPACKQuantizer() + operator_config = get_symmetric_quantization_config( + is_per_channel=True, + is_dynamic=False, + ) + quantizer.set_global(operator_config) + + edge_program = quantize_and_lower_module(model, sample_inputs, quantizer) + + ep = edge_program._edge_programs["forward"] + fuse_pass = FusePatternsPass() + fuse_pass._exported_program = ep + result = fuse_pass.call(ep.graph_module) + + self.assertTrue(result.modified) + + gm = ep.graph_module + + # With batch size 1, the first linear should fuse to q8ta_linear_gemv + q8ta_linear_gemv_count = op_node_count(gm, "et_vk__q8ta_linear_gemv__default") + self.assertGreaterEqual( + q8ta_linear_gemv_count, + 1, + "Expected at least one q8ta_linear_gemv op for batch-1 linear fusion", + )