def q8ta_linear_gemv(
    x: torch.Tensor,
    input_scale: float,
    input_zero_point: int,
    weights: torch.Tensor,
    weight_sums: torch.Tensor,
    weight_scales: torch.Tensor,
    output_scale: float,
    output_zero_point: int,
    bias: Optional[torch.Tensor] = None,
    activation: str = "none",
):
    """Reference implementation of the ``q8ta_linear_gemv`` custom op.

    The quantized activation and the per-channel quantized weight are both
    dequantized, a floating-point linear layer is evaluated, and the result
    is requantized to int8.

    Args:
        x: Quantized activation tensor; dequantized per tensor using
            ``input_scale`` and ``input_zero_point`` over the range
            [-128, 127].
        input_scale: Per-tensor scale of ``x``.
        input_zero_point: Per-tensor zero point of ``x``.
        weights: int8 weight tensor; dequantized per channel along axis 0
            over the range [-127, 127] with an all-zero zero point.
        weight_sums: Not read by this implementation; present so the
            signature lines up with the registered op schema.
        weight_scales: One scale per weight channel.
        output_scale: Per-tensor scale used to requantize the result.
        output_zero_point: Per-tensor zero point used to requantize.
        bias: Optional tensor added to the linear output.
        activation: ``"relu"`` applies ReLU before requantization; any other
            value leaves the linear output unchanged.

    Returns:
        The requantized int8 result, clamped to [-128, 127].
    """
    decomposed = torch.ops.quantized_decomposed

    # The weight zero point is zero for every channel.
    zero_points = torch.zeros_like(weight_scales, dtype=torch.int32)
    fp_weights = decomposed.dequantize_per_channel(
        weights, weight_scales, zero_points, 0, -127, 127, torch.int8
    )
    fp_input = decomposed.dequantize_per_tensor(
        x, input_scale, input_zero_point, -128, 127, x.dtype
    )

    result = torch.nn.functional.linear(fp_input, fp_weights)
    # Bias is added as a separate step rather than passed to linear().
    if bias is not None:
        result = result + bias
    if activation == "relu":
        result = torch.nn.functional.relu(result)

    return decomposed.quantize_per_tensor(
        result, output_scale, output_zero_point, -128, 127, torch.int8
    )
@update_features(exir_ops.edge.et_vk.q8ta_linear_gemv.default)
def register_q8ta_linear_gemv():
    """Build the ``OpFeatures`` entry for ``et_vk.q8ta_linear_gemv``.

    The activation input and the single output both use
    ``utils.PACKED_INT8_4W_BUFFER``; every other argument is registered with
    ``utils.NO_STORAGE``. Resizing is reported as unsupported and prepacking
    as supported.
    """
    packed_int8 = utils.PACKED_INT8_4W_BUFFER
    no_storage = utils.NO_STORAGE

    # One entry per operator argument, in schema order.
    arg_storage = [
        packed_int8,  # input
        no_storage,  # input_scale (non tensor)
        no_storage,  # input_zero_point (non tensor)
        no_storage,  # weight (prepacked)
        no_storage,  # weight_sums (prepacked)
        no_storage,  # weight_scales (prepacked)
        no_storage,  # output_scale (non tensor)
        no_storage,  # output_zero_point (non tensor)
        no_storage,  # bias (prepacked)
        no_storage,  # activation (non tensor)
    ]

    return OpFeatures(
        inputs_storage=arg_storage,
        outputs_storage=[packed_int8],
        supports_resize=False,
        supports_prepacking=True,
    )
exir_ops.edge.et_vk.q8ta_linear.default + with graph_module.graph.inserting_before(match.output_node): qlinear_node = graph_module.graph.create_node( "call_function", - exir_ops.edge.et_vk.q8ta_linear.default, + op_target, args=( match.quantize_input_node, match.input_scales_node, diff --git a/backends/vulkan/runtime/graph/ops/glsl/q8ta_linear_gemv.glsl b/backends/vulkan/runtime/graph/ops/glsl/q8ta_linear_gemv.glsl new file mode 100644 index 00000000000..aa0837c4a6e --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/q8ta_linear_gemv.glsl @@ -0,0 +1,165 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#version 450 core + +${define_required_extensions("buffer", DTYPE)} + +#extension GL_EXT_control_flow_attributes : require +#extension GL_EXT_integer_dot_product : require + +#define PRECISION ${PRECISION} +#define VEC4_T ${texel_load_type(DTYPE, "buffer")} +#define T int + +$if WEIGHT_STORAGE == "buffer": + #define WEIGHT_BUFFER + +#define TILE_K4 ${TILE_K4} +#define TILE_N4 ${TILE_N4} + +#define TILE_M4 1 +#define TILE_M 1 +#define TILE_K ${TILE_K4 * 4} +#define TILE_N ${TILE_N4 * 4} + +#define WGS ${WGS} + +layout(std430) buffer; + +// Scalar int arrays for 4W packed int8 input/output +${layout_declare_tensor(B, "w", "t_packed_int8_output", "int", "buffer")} +${layout_declare_tensor(B, "r", "t_packed_int8_input", "int", "buffer")} +// Weight uses ivec4 (same format as q8ta_linear) +${layout_declare_tensor(B, "r", "t_packed_int8_weight", "int", WEIGHT_STORAGE, is_scalar_array=False)} +${layout_declare_tensor(B, "r", "t_weight_sums", "int", "buffer", is_scalar_array=False)} +${layout_declare_tensor(B, "r", "t_weight_scales", DTYPE, "buffer", is_scalar_array=False)} +${layout_declare_tensor(B, "r", "t_bias", DTYPE, "buffer", is_scalar_array=False)} + 
void main() {
  // Slot of this invocation among the WGS invocations that split the K
  // reduction; the cooperating set is indexed along the local z axis.
  const int worker = int(gl_LocalInvocationID.z);
  // First packed (x4) output column of the tile selected by global x.
  const int n4 = int(gl_GlobalInvocationID.x) * TILE_N4;

  const int K4 = div_up_4(input_sizes.x);
  const int N4 = div_up_4(output_sizes.x);

  // Tiles that begin past the last output channel have nothing to compute.
  if (mul_4(n4) >= output_sizes.x) {
    return;
  }

  Int32Accum thread_accum;
  initialize(thread_accum);

  Int8WeightTile weight_tile;

  // Each invocation visits every WGS-th packed word along K.
  for (int k4 = worker; k4 < K4; k4 += WGS) {
    // One int32 of the 4W input buffer carries the 4 int8 activations at
    // k = k4*4 .. k4*4+3. Only words [0, K4) are ever read.
    // NOTE(review): nothing here guards against an input holding more than
    // one row of K values -- presumably the caller guarantees that; confirm.
    const int packed_input = t_packed_int8_input[k4];

    load_int8_weight_tile(weight_tile, n4, k4, N4);

    // Dot the packed activations with the packed weights of every output
    // column in the tile, accumulating into that column's running sum.
    [[unroll]] for (int col = 0; col < TILE_N; ++col) {
      const int tile_n4 = div_4(col);
      const int lane = mod_4(col);
      thread_accum.data[0][tile_n4][lane] = dotPacked4x8AccSatEXT(
          packed_input,
          weight_tile.data[0][tile_n4][lane],
          thread_accum.data[0][tile_n4][lane]);
    }
  }

  // Publish this invocation's partial sums, then wait for all the others.
  partial_accums[worker] = thread_accum;

  memoryBarrierShared();
  barrier();

  // Only invocation 0 finishes the reduction and writes the result.
  if (worker != 0) {
    return;
  }

  // thread_accum already holds invocation 0's partial; fold in the rest.
  for (int other = 1; other < WGS; ++other) {
    [[unroll]] for (int tile_n4 = 0; tile_n4 < TILE_N4; ++tile_n4) {
      thread_accum.data[0][tile_n4] +=
          partial_accums[other].data[0][tile_n4];
    }
  }

  FPPerOutChannelParams weight_scales_tile;
  load_weight_scales_tile(weight_scales_tile, n4);

  IntPerOutChannelParams weight_sums_tile;
  load_weight_sums_tile(weight_sums_tile, n4);

  FPOutTile out_tile;
  initialize(out_tile);

  // apply_bias is a specialization constant; the bias tile is only loaded
  // when it is set.
  if (apply_bias > 0) {
    FPPerOutChannelParams bias_tile;
    load_bias_tile(bias_tile, n4);

    accumulate_out_tile_with_int_accum(
        out_tile,
        thread_accum,
        input_scale,
        input_zp,
        weight_sums_tile,
        weight_scales_tile,
        bias_tile);
  } else {
    accumulate_out_tile_with_int_accum(
        out_tile,
        thread_accum,
        input_scale,
        input_zp,
        weight_sums_tile,
        weight_scales_tile);
  }

  // Apply ReLU if enabled
  if (activation_type > 0) {
    [[unroll]] for (int tile_n4 = 0; tile_n4 < TILE_N4; ++tile_n4) {
      out_tile.data[0][tile_n4] = max(out_tile.data[0][tile_n4], vec4(0.0));
    }
  }

  // Quantize and store into the scalar int[] output buffer: the int32 at
  // index n4 packs the 4 int8 outputs for channels n4*4 .. n4*4+3. Columns
  // at or beyond N4 are skipped.
  [[unroll]] for (int tile_n4 = 0; tile_n4 < TILE_N4; ++tile_n4) {
    const int out_idx = n4 + tile_n4;
    if (out_idx < N4) {
      t_packed_int8_output[out_idx] = quantize_and_pack(
          out_tile.data[0][tile_n4], output_inv_scale, output_zp);
    }
  }
}
+ */ + +#include + +#include + +#include +#include +#include +#include +#include + +namespace vkcompute { + +static bool q8ta_linear_gemv_check_packed_dim_info( + const api::PackedDimInfo& info) { + return info.packed_dim == WHCN::kWidthDim && + info.packed_dim_block_size == 4 && + info.outer_packed_dim == WHCN::kHeightDim && + info.outer_packed_dim_block_size == 1; +} + +// +// Workgroup size selection +// + +utils::uvec3 q8ta_linear_gemv_global_wg_size( + ComputeGraph* graph, + const vkapi::ShaderInfo& shader, + const std::vector& args, + const std::vector& resize_args) { + (void)shader; + (void)resize_args; + + const ValueRef out = args.at(0).refs.at(0); + + std::vector out_sizes = graph->sizes_of(out); + const uint32_t N = utils::val_at(-1, out_sizes); + + // Each output tile contains 8 columns (TILE_N4=2 -> 8 output channels) + const uint32_t N_per_tile = 8; + const uint32_t num_N_tiles = utils::div_up(N, N_per_tile); + + return {num_N_tiles, 1, 1}; +} + +utils::uvec3 q8ta_linear_gemv_local_wg_size( + ComputeGraph* graph, + const vkapi::ShaderInfo& shader, + const utils::uvec3& global_workgroup_size, + const std::vector& args, + const std::vector& resize_args) { + (void)graph; + (void)shader; + (void)global_workgroup_size; + (void)args; + (void)resize_args; + + // Cooperative algorithm: 64 threads share the K reduction + return {1, 1, 64}; +} + +// +// Dispatch node +// + +void add_q8ta_linear_gemv_node( + ComputeGraph& graph, + const ValueRef packed_int8_input, + const ValueRef input_scale, + const ValueRef input_zp, + const ValueRef packed_weight, + const ValueRef packed_weight_sums, + const ValueRef packed_weight_scales, + const ValueRef output_scale, + const ValueRef output_zp, + const ValueRef bias_data, + const ValueRef packed_bias, + const uint32_t activation_type, + const ValueRef packed_int8_output) { + // Validate packed dim info matches 4W layout + VK_CHECK_COND(q8ta_linear_gemv_check_packed_dim_info( + graph.packed_dim_info_of(packed_int8_input))); 
+ VK_CHECK_COND(q8ta_linear_gemv_check_packed_dim_info( + graph.packed_dim_info_of(packed_int8_output))); + + float input_scale_val = graph.extract_scalar(input_scale); + int32_t input_zp_val = graph.extract_scalar(input_zp); + + float output_inv_scale_val = 1.0f / graph.extract_scalar(output_scale); + int32_t output_zp_val = graph.extract_scalar(output_zp); + + uint32_t apply_bias = 1; + if (graph.val_is_none(bias_data)) { + apply_bias = 0; + } + + std::string kernel_name = "q8ta_linear_gemv"; + add_dtype_suffix(kernel_name, graph.dtype_of(packed_weight_scales)); + + vkapi::ParamsBindList param_buffers = { + graph.sizes_ubo(packed_int8_output), graph.sizes_ubo(packed_int8_input)}; + + std::vector push_constants = { + PushConstantDataInfo(&input_scale_val, sizeof(input_scale_val)), + PushConstantDataInfo(&input_zp_val, sizeof(input_zp_val)), + PushConstantDataInfo(&output_inv_scale_val, sizeof(output_inv_scale_val)), + PushConstantDataInfo(&output_zp_val, sizeof(output_zp_val)), + }; + + graph.execute_nodes().emplace_back(new DynamicDispatchNode( + graph, + VK_KERNEL_FROM_STR(kernel_name), + q8ta_linear_gemv_global_wg_size, + q8ta_linear_gemv_local_wg_size, + // Inputs and Outputs + {{packed_int8_output, vkapi::kWrite}, + {{packed_int8_input, + packed_weight, + packed_weight_sums, + packed_weight_scales, + packed_bias}, + vkapi::kRead}}, + // Shader params buffers + param_buffers, + // Push Constants + push_constants, + // Specialization Constants + {apply_bias, activation_type}, + // Resize args + {}, + // Resizing Logic + nullptr)); +} + +// +// High level operator impl +// + +void q8ta_linear_gemv(ComputeGraph& graph, const std::vector& args) { + int32_t idx = 0; + const ValueRef packed_int8_input = args.at(idx++); + const ValueRef input_scale = args.at(idx++); + const ValueRef input_zp = args.at(idx++); + const ValueRef weight_data = args.at(idx++); + const ValueRef weight_sums_data = args.at(idx++); + const ValueRef weight_scales_data = args.at(idx++); + const 
ValueRef output_scale = args.at(idx++); + const ValueRef output_zp = args.at(idx++); + const ValueRef bias_data = args.at(idx++); + const ValueRef activation = args.at(idx++); + const ValueRef packed_int8_output = args.at(idx++); + + const int64_t K = graph.size_at(-1, packed_int8_input); + VK_CHECK_COND(K % 4 == 0); + + QuantizationConfig weight_quant_config(8, kPerChannel, {K}); + + // Prepack weight data (same format as q8ta_linear) + const ValueRef packed_weight = + prepack_quantized_linear_weight(graph, weight_quant_config, weight_data); + const ValueRef packed_weight_scales = prepack_standard( + graph, weight_scales_data, utils::kBuffer, utils::kWidthPacked); + const ValueRef packed_weight_sums = prepack_standard( + graph, weight_sums_data, utils::kBuffer, utils::kWidthPacked); + + // Prepack bias data + TmpTensor dummy_bias( + &graph, + {}, + graph.dtype_of(packed_weight_scales), + utils::kBuffer, + utils::kWidthPacked); + + ValueRef packed_bias = dummy_bias.vref; + if (graph.val_is_not_none(bias_data)) { + packed_bias = + prepack_standard(graph, bias_data, utils::kBuffer, utils::kWidthPacked); + } + + uint32_t activation_type_val = static_cast( + activation_type_from_string(graph.extract_string(activation))); + + add_q8ta_linear_gemv_node( + graph, + packed_int8_input, + input_scale, + input_zp, + packed_weight, + packed_weight_sums, + packed_weight_scales, + output_scale, + output_zp, + bias_data, + packed_bias, + activation_type_val, + packed_int8_output); +} + +REGISTER_OPERATORS { + VK_REGISTER_OP(et_vk.q8ta_linear_gemv.default, q8ta_linear_gemv); +} + +} // namespace vkcompute diff --git a/backends/vulkan/runtime/graph/ops/impl/Q8taLinearGemv.h b/backends/vulkan/runtime/graph/ops/impl/Q8taLinearGemv.h new file mode 100644 index 00000000000..946022d16ef --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/impl/Q8taLinearGemv.h @@ -0,0 +1,31 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. 
+ * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include +#include + +namespace vkcompute { + +void add_q8ta_linear_gemv_node( + ComputeGraph& graph, + const ValueRef packed_int8_input, + const ValueRef input_scale, + const ValueRef input_zp, + const ValueRef packed_weight, + const ValueRef packed_weight_sums, + const ValueRef packed_weight_scales, + const ValueRef output_scale, + const ValueRef output_zp, + const ValueRef bias_data, + const ValueRef packed_bias, + const uint32_t activation_type, + const ValueRef packed_int8_output); + +} // namespace vkcompute diff --git a/backends/vulkan/test/custom_ops/impl/TestQ8taLinear.cpp b/backends/vulkan/test/custom_ops/impl/TestQ8taLinear.cpp index d0803fe746b..684a7b94e66 100644 --- a/backends/vulkan/test/custom_ops/impl/TestQ8taLinear.cpp +++ b/backends/vulkan/test/custom_ops/impl/TestQ8taLinear.cpp @@ -25,31 +25,27 @@ void test_q8ta_linear(ComputeGraph& graph, const std::vector& args) { const ValueRef output_zp = args.at(idx++); const ValueRef bias_data = args.at(idx++); const ValueRef activation = args.at(idx++); + const ValueRef impl_selector_str = args.at(idx++); const ValueRef fp_output = args.at(idx++); - // Create temporary packed int8 tensors for input and output - // Input uses 4H4W layout to match the linear shader's ivec4 reading pattern - // where each ivec4 contains data from 4 rows + std::string impl_selector = graph.extract_string(impl_selector_str); + + utils::GPUMemoryLayout layout = + impl_selector == "gemv" ? 
utils::kPackedInt8_4W : utils::kPackedInt8_4H4W; + TmpTensor packed_int8_input( - &graph, - graph.sizes_of(fp_input), - vkapi::kInt8x4, - utils::kBuffer, - utils::kPackedInt8_4H4W); + &graph, graph.sizes_of(fp_input), vkapi::kInt8x4, utils::kBuffer, layout); - // Output uses 4H4W layout to match the linear shader's ivec4 writing pattern TmpTensor packed_int8_output( &graph, graph.sizes_of(fp_output), vkapi::kInt8x4, utils::kBuffer, - utils::kPackedInt8_4H4W); + layout); - // Quantize floating point input to packed int8 add_q8ta_quantize_node( graph, fp_input, input_scale, input_zp, packed_int8_input); - // Call the q8ta_linear operator std::vector linear_args = { packed_int8_input, input_scale, @@ -62,9 +58,12 @@ void test_q8ta_linear(ComputeGraph& graph, const std::vector& args) { bias_data, activation, packed_int8_output}; - VK_GET_OP_FN("et_vk.q8ta_linear.default")(graph, linear_args); - // Dequantize packed int8 output to floating point + std::string op_name = impl_selector == "gemv" + ? "et_vk.q8ta_linear_gemv.default" + : "et_vk.q8ta_linear.default"; + VK_GET_OP_FN(op_name)(graph, linear_args); + add_q8ta_dequantize_node( graph, packed_int8_output, output_scale, output_zp, fp_output); } diff --git a/backends/vulkan/test/custom_ops/test_q8ta_linear.cpp b/backends/vulkan/test/custom_ops/test_q8ta_linear.cpp index faec638059c..707a8695171 100644 --- a/backends/vulkan/test/custom_ops/test_q8ta_linear.cpp +++ b/backends/vulkan/test/custom_ops/test_q8ta_linear.cpp @@ -30,12 +30,16 @@ struct LinearConfig { static TestCase create_test_case_from_config( const LinearConfig& config, - vkapi::ScalarType input_dtype) { + vkapi::ScalarType input_dtype, + const std::string& impl_selector = "") { TestCase test_case; std::string dtype_str = (input_dtype == vkapi::kFloat) ? 
"Float" : "Half"; std::string test_name = config.test_case_name + "_Buffer_" + dtype_str; + if (!impl_selector.empty()) { + test_name += " [" + impl_selector + "]"; + } test_case.set_name(test_name); test_case.set_operator_name("test_etvk.test_q8ta_linear.default"); @@ -136,6 +140,9 @@ static TestCase create_test_case_from_config( ValueSpec activation = ValueSpec::make_string("none"); test_case.add_input_spec(activation); + // Add impl_selector string + ValueSpec impl_selector_spec = ValueSpec::make_string(impl_selector); + test_case.add_input_spec(impl_selector_spec); test_case.add_output_spec(output); test_case.set_abs_tolerance(output_scale_val + 1e-4f); @@ -159,6 +166,12 @@ static std::vector generate_q8ta_linear_test_cases() { } std::vector configs = { + // Batch size 1 cases (test both tiled and gemv) + {1, 64, 32}, + {1, 128, 64}, + {1, 256, 128}, + {1, 128, 64, false}, + // Multi-batch cases {4, 64, 32}, {4, 128, 64}, {4, 256, 128}, @@ -169,6 +182,9 @@ static std::vector generate_q8ta_linear_test_cases() { {32, 128, 64, false}, {32, 256, 128, false}, // Performance cases + {1, 512, 512}, + {1, 2048, 2048}, + {1, 512, 9059}, {256, 2048, 2048}, {512, 2048, 2048}, {1024, 2048, 2048}, @@ -187,7 +203,14 @@ static std::vector generate_q8ta_linear_test_cases() { config.test_case_name = generated_test_case_name; + // Default (tiled) variant test_cases.push_back(create_test_case_from_config(config, vkapi::kFloat)); + + // For batch size 1, also test the gemv variant + if (config.M == 1) { + test_cases.push_back( + create_test_case_from_config(config, vkapi::kFloat, "gemv")); + } } return test_cases; @@ -206,6 +229,10 @@ static void q8ta_linear_reference_impl(TestCase& test_case) { const ValueSpec& output_scale_spec = test_case.inputs()[idx++]; const ValueSpec& output_zeros_spec = test_case.inputs()[idx++]; const ValueSpec& bias_spec = test_case.inputs()[idx++]; + const ValueSpec& activation_spec = test_case.inputs()[idx++]; + (void)activation_spec; + const 
def test_fuse_q8ta_linear_gemv(self):
    """Check that a quantized linear with a single input row becomes gemv.

    A two-layer bias-free linear model is quantized per channel with a
    static symmetric config, lowered, and run through ``FusePatternsPass``.
    The pass must report a modification and the resulting graph must hold
    at least one ``q8ta_linear_gemv.default`` node.
    """
    from executorch.backends.xnnpack.quantizer.xnnpack_quantizer import (
        get_symmetric_quantization_config,
        XNNPACKQuantizer,
    )

    class TwoLinearModule(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.linear1 = torch.nn.Linear(128, 64, bias=False)
            self.linear2 = torch.nn.Linear(64, 32, bias=False)

        def forward(self, x):
            hidden = self.linear1(x)
            return self.linear2(hidden)

    module = TwoLinearModule()
    # A single input row is what selects the gemv variant.
    example_args = (torch.randn(1, 128),)

    xnn_quantizer = XNNPACKQuantizer()
    xnn_quantizer.set_global(
        get_symmetric_quantization_config(
            is_per_channel=True,
            is_dynamic=False,
        )
    )

    lowered = quantize_and_lower_module(module, example_args, xnn_quantizer)
    exported = lowered._edge_programs["forward"]

    # The pass reads the exported program from this attribute, so set it
    # before invoking call() directly.
    pattern_pass = FusePatternsPass()
    pattern_pass._exported_program = exported
    pass_result = pattern_pass.call(exported.graph_module)

    self.assertTrue(pass_result.modified)

    # Only a lower bound is asserted: at least one linear must have been
    # rewritten to the gemv op.
    gemv_nodes = op_node_count(exported.graph_module, "q8ta_linear_gemv.default")
    self.assertGreaterEqual(
        gemv_nodes,
        1,
        "Expected at least one q8ta_linear_gemv op for batch-1 linear fusion",
    )