diff --git a/backends/vulkan/runtime/graph/ops/glsl/linear_q8ta_q8csw_nv_cm2.glsl b/backends/vulkan/runtime/graph/ops/glsl/linear_q8ta_q8csw_nv_cm2.glsl new file mode 100644 index 00000000000..90adc695ad7 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/linear_q8ta_q8csw_nv_cm2.glsl @@ -0,0 +1,183 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +/* + * Quantized int8 linear shader using GL_NV_cooperative_matrix2 extension. + * + * Uses float16 cooperative matrices with dequantization during load. + * + * Computes: output = dequant(input) @ dequant(weight)^T + bias + * where dequant(input) = (input - zero_point) * scale + * and dequant(weight) = weight * weight_scale[channel] + */ + +#version 450 core + +#extension GL_EXT_control_flow_attributes : enable +#extension GL_EXT_shader_16bit_storage : require +#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require +#extension GL_EXT_shader_explicit_arithmetic_types_int8 : require +#extension GL_EXT_shader_explicit_arithmetic_types_int16 : require +#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require +#extension GL_KHR_memory_scope_semantics : enable +#extension GL_KHR_shader_subgroup_basic : enable +#extension GL_KHR_cooperative_matrix : enable +#extension GL_NV_cooperative_matrix2 : enable +#extension GL_EXT_buffer_reference : enable + +${define_required_extensions("buffer", DTYPE)} + +#define PRECISION ${PRECISION} + +#define T ${buffer_scalar_type(DTYPE)} +#define VEC4_T ${buffer_gvec_type(DTYPE, 4)} + +// Block sizes for cooperative matrix - 16x16x16 +#define BM 16 +#define BN 16 +#define BK 16 + +layout(std430) buffer; + +${layout_declare_tensor(B, "w", "t_output", DTYPE, "buffer", is_scalar_array=True)} +${layout_declare_tensor(B, "r", "t_input", "int8", "buffer", is_scalar_array=True)} +${layout_declare_tensor(B, "r", "t_weight", "int8", "buffer", is_scalar_array=True)} +${layout_declare_tensor(B, "r", "t_weight_sums", "int", "buffer", is_scalar_array=True)} +${layout_declare_tensor(B, "r", "t_weight_scales", DTYPE, "buffer", is_scalar_array=True)} +${layout_declare_tensor(B, "r", "t_bias", DTYPE, "buffer", is_scalar_array=True)} + +${layout_declare_spec_const(C, "int", "apply_bias", "0")} + +${layout_declare_ubo(B, "ivec4", "output_sizes")} +${layout_declare_ubo(B, "ivec4", "input_sizes")} + +layout(push_constant) uniform restrict Block { + float input_scale; + int input_zp; +}; + +// Workgroup size: 32 threads (1 subgroup for Subgroup scope) +layout(local_size_x = 32, local_size_y = 1, local_size_z = 1) in; + +// Matrix types using float16 with float16 accumulator +#define MAT_TYPE float16_t +#define ACC_TYPE float16_t + +// Input block size: 16 int8 values packed as ivec4 (4 ints × 4 int8 per int) +#define INPUT_BLOCK_K 16 + +// Buffer reference type for input dequantization decode functor +// Matches the packed structure from q8ta_quantize.glsl: ivec4 with 4 int8 per int +layout(buffer_reference, std430, buffer_reference_align = 4) buffer decodeInputBuf { + ivec4 packed; // 4 ints, each containing 4 packed int8 values = 16 int8 total +}; + +// Input decode: unpack int8 from ivec4, dequantize as (val - zp) * scale +MAT_TYPE decodeInputFunc(const in decodeInputBuf bl, const in uint blockCoords[2], const in uint coordInBlock[2]) { + uint idx = coordInBlock[1]; + int packed_int = bl.packed[idx >> 2]; + int8_t val = 
int8_t((packed_int >> (int(idx & 3) * 8)) & 0xFF);
+  return MAT_TYPE((float(val) - float(input_zp)) * input_scale);
+}
+
+// Weight decode: dequantize as val * per-channel scale
+layout(buffer_reference, std430, buffer_reference_align = 1) buffer decodeWeightBuf {
+  int8_t v;
+};
+
+MAT_TYPE decodeWeightFunc(const in decodeWeightBuf bl, const in uint blockCoords[2], const in uint coordInBlock[2]) {
+  uint out_channel = blockCoords[0] + coordInBlock[0];
+  return MAT_TYPE(float(bl.v) * float(t_weight_scales[out_channel]));
+}
+
+void main() {
+  // Get dimensions
+  const uint M = uint(output_sizes.y); // batch/rows
+  const uint N = uint(output_sizes.x); // output features
+  const uint K = uint(input_sizes.x); // input features
+
+  // Each workgroup handles one BM x BN tile
+  const uint ir = gl_WorkGroupID.x; // row tile index
+  const uint ic = gl_WorkGroupID.y; // column tile index
+
+  // Early exit if out of bounds
+  if (ir * BM >= M || ic * BN >= N) {
+    return;
+  }
+
+  // Create tensor layouts with clamping for boundary handling
+  tensorLayoutNV<2, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutA =
+      createTensorLayoutNV(2, gl_CooperativeMatrixClampModeConstantNV);
+  tensorLayoutNV<2, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutB =
+      createTensorLayoutNV(2, gl_CooperativeMatrixClampModeConstantNV);
+  tensorLayoutNV<2, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutD =
+      createTensorLayoutNV(2, gl_CooperativeMatrixClampModeConstantNV);
+
+  // Create transpose view for loading weights
+  tensorViewNV<2, false, 1, 0> tensorViewTranspose = createTensorViewNV(2, false, 1, 0);
+
+  // Set dimensions and strides for input (A matrix)
+  // Block size is 16 int8 values packed as ivec4, so stride is in blocks (K / INPUT_BLOCK_K)
+  tensorLayoutA = setTensorLayoutDimensionNV(tensorLayoutA, M, K);
+  tensorLayoutA = setTensorLayoutStrideNV(tensorLayoutA, K / INPUT_BLOCK_K, 1);
+  tensorLayoutA = setTensorLayoutBlockSizeNV(tensorLayoutA, 1, INPUT_BLOCK_K);
+
+  tensorLayoutB = setTensorLayoutDimensionNV(tensorLayoutB, N, K);
+  tensorLayoutB = setTensorLayoutStrideNV(tensorLayoutB, K, 1);
+
+  tensorLayoutD = setTensorLayoutDimensionNV(tensorLayoutD, M, N);
+  tensorLayoutD = setTensorLayoutStrideNV(tensorLayoutD, N, 1);
+
+  // Initialize accumulator with bias (broadcast across rows) or zeros
+  coopmat<ACC_TYPE, gl_ScopeSubgroup, BM, BN, gl_MatrixUseAccumulator> sum;
+
+  if (apply_bias == 1) {
+    // Bias layout: stride 0 in row dim broadcasts same bias values across rows
+    tensorLayoutNV<2, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutBias =
+        createTensorLayoutNV(2, gl_CooperativeMatrixClampModeConstantNV);
+    tensorLayoutBias = setTensorLayoutDimensionNV(tensorLayoutBias, BM, N);
+    tensorLayoutBias = setTensorLayoutStrideNV(tensorLayoutBias, 0, 1);
+
+    // Load as T first then convert to ACC_TYPE (buffer stores T, not ACC_TYPE)
+    coopmat<T, gl_ScopeSubgroup, BM, BN, gl_MatrixUseAccumulator> bias_tmp;
+    coopMatLoadTensorNV(bias_tmp, t_bias, 0,
+        sliceTensorLayoutNV(tensorLayoutBias, 0, BM, ic * BN, BN));
+    sum = coopmat<ACC_TYPE, gl_ScopeSubgroup, BM, BN, gl_MatrixUseAccumulator>(bias_tmp);
+  } else {
+    sum = coopmat<ACC_TYPE, gl_ScopeSubgroup, BM, BN, gl_MatrixUseAccumulator>(ACC_TYPE(0.0f));
+  }
+
+  // Loop over K dimension
+  const uint k_iters = (K + BK - 1) / BK;
+
+  [[dont_unroll]]
+  for (uint block_k = 0, i = 0; i < k_iters; block_k += BK, ++i) {
+    // Cooperative matrices for A and B
+    coopmat<MAT_TYPE, gl_ScopeSubgroup, BM, BK, gl_MatrixUseA> mat_a;
+    coopmat<MAT_TYPE, gl_ScopeSubgroup, BK, BN, gl_MatrixUseB> mat_b;
+
+    // Load A tile with input dequantization
+    coopMatLoadTensorNV(mat_a, t_input, 0,
+        sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK), decodeInputFunc);
+
+    // Load B tile with transpose and weight dequantization
+    coopMatLoadTensorNV(mat_b, t_weight, 0,
sliceTensorLayoutNV(tensorLayoutB, ic * BN, BN, block_k, BK), tensorViewTranspose, decodeWeightFunc);
+
+    // Multiply and accumulate
+    sum = coopMatMulAdd(mat_a, mat_b, sum);
+  }
+
+  // Convert accumulator to output type
+  coopmat<T, gl_ScopeSubgroup, BM, BN, gl_MatrixUseAccumulator> result =
+      coopmat<T, gl_ScopeSubgroup, BM, BN, gl_MatrixUseAccumulator>(sum);
+
+  // Store result
+  coopMatStoreTensorNV(result, t_output, 0,
+      sliceTensorLayoutNV(tensorLayoutD, ir * BM, BM, ic * BN, BN));
+}
diff --git a/backends/vulkan/runtime/graph/ops/glsl/linear_q8ta_q8csw_nv_cm2.yaml b/backends/vulkan/runtime/graph/ops/glsl/linear_q8ta_q8csw_nv_cm2.yaml
new file mode 100644
index 00000000000..c7aa9c88e6d
--- /dev/null
+++ b/backends/vulkan/runtime/graph/ops/glsl/linear_q8ta_q8csw_nv_cm2.yaml
@@ -0,0 +1,35 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# Quantized int8 matrix multiplication shader using GL_NV_cooperative_matrix2
+# extension for optimized performance on NVIDIA GPUs with tensor cores.
+#
+# This shader performs int8 activation x int8 weight linear layer with
+# dequantization to produce floating-point output.
+#
+# Computes: output = dequantize(input @ weight^T) + bias
+
+linear_q8ta_q8csw_nv_cm2:
+  parameter_names_with_default_values:
+    DTYPE: float
+    OUTPUT_STORAGE: buffer
+    PACKED_INT8_INPUT_STORAGE: buffer
+    WEIGHT_STORAGE: buffer
+    # Tile sizes for cooperative matrix operations
+    TILE_ROWS: 16
+    TILE_COLS: 16
+    # Use Vulkan 1.3 and SPIR-V 1.6 for GL_NV_cooperative_matrix2 support
+    VK_VERSION: "1.3"
+    SPV_VERSION: "1.6"
+  generate_variant_forall:
+    DTYPE:
+      - VALUE: float
+      - VALUE: half
+  shader_variants:
+    - NAME: linear_q8ta_q8csw_nv_cm2_buffer_buffer_buffer_float
+      DTYPE: float
+    - NAME: linear_q8ta_q8csw_nv_cm2_buffer_buffer_buffer_half
+      DTYPE: half
diff --git a/backends/vulkan/runtime/graph/ops/impl/QuantizedLinearNvCoopMat.cpp b/backends/vulkan/runtime/graph/ops/impl/QuantizedLinearNvCoopMat.cpp
new file mode 100644
index 00000000000..a84a864a682
--- /dev/null
+++ b/backends/vulkan/runtime/graph/ops/impl/QuantizedLinearNvCoopMat.cpp
@@ -0,0 +1,342 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include
+
+#include
+#include
+#include
+#include
+#include
+
+namespace vkcompute {
+
+//
+// Shader dispatch utilities
+//
+
+void resize_linear_q8ta_q8csw_nv_cm2_node(
+    ComputeGraph* graph,
+    const std::vector<ArgGroup>& args,
+    const std::vector<ValueRef>& extra_args) {
+  const ValueRef output = args.at(0).refs.at(0);
+  const ValueRef input = args.at(1).refs.at(0);
+  const ValueRef weight_data = extra_args.at(0);
+
+  std::vector<int64_t> input_sizes = graph->sizes_of(input);
+  std::vector<int64_t> weight_sizes = graph->sizes_of(weight_data);
+
+  // input: [M, K], weight: [N, K] -> output: [M, N]
+  const int64_t M = utils::val_at(-2, input_sizes);
+  const int64_t N = utils::val_at(-2, weight_sizes);
+
+  std::vector<int64_t> new_out_sizes(input_sizes.size());
+  if (input_sizes.size() == 2) {
+    new_out_sizes.at(0) = M;
+    new_out_sizes.at(1) = N;
+  } else {
+    new_out_sizes.at(0) = input_sizes.at(0);
+    new_out_sizes.at(1) = M;
+    new_out_sizes.at(2) = N;
+  }
+
+  graph->virtual_resize(output, new_out_sizes);
+}
+
+utils::uvec3 linear_q8ta_q8csw_nv_cm2_global_wg_size(
+    ComputeGraph* graph,
+    const vkapi::ShaderInfo& shader,
+    const std::vector<ArgGroup>& args,
+    const std::vector<ValueRef>& resize_args) {
+  (void)shader;
+  (void)resize_args;
+
+  const ValueRef output = args.at(0).refs.at(0);
+
+  std::vector<int64_t> out_sizes = graph->sizes_of(output);
+  // Width dimension (N = out_features)
+  const uint32_t N = utils::val_at(-1, out_sizes);
+  // Height dimension (M = batch size)
+  const uint32_t M = utils::val_at(-2, out_sizes);
+
+  // NV cooperative matrix 2 shader uses BM=16 x BN=16 tiles for int8
+  const uint32_t BM = 16;
+  const uint32_t BN = 16;
+
+  const uint32_t blocks_m = utils::div_up(M, BM);
+  const uint32_t blocks_n = utils::div_up(N, BN);
+
+  // Each workgroup (32 threads = 1 subgroup) processes one BM x BN tile
+  return {blocks_m * 32, blocks_n, 1};
+}
+
+utils::uvec3 linear_q8ta_q8csw_nv_cm2_local_wg_size(
+    ComputeGraph* graph,
+    const vkapi::ShaderInfo& shader,
+    const utils::uvec3& global_workgroup_size,
+    const std::vector<ArgGroup>& args,
+    const std::vector<ValueRef>& resize_args) {
+  (void)graph;
+  (void)shader;
+  (void)global_workgroup_size;
+  (void)args;
+  (void)resize_args;
+
+  // NV cooperative matrix 2 with Subgroup scope uses 32 threads (1 subgroup)
+  return {32, 1, 1};
+}
+
+//
+// Prepacking
+//
+
+ValueRef prepack_int8_linear_weight_nv_cm2(
+    ComputeGraph& graph,
+    const QuantizationConfig& weight_quant_config,
+    const ValueRef weight_data) {
+  VK_CHECK_COND(weight_quant_config.nbits == 8);
+
+  std::vector<int64_t> weight_sizes = graph.sizes_of(weight_data);
+  const int64_t ndim = graph.dim_of(weight_data);
+
+  // Weight tensor has shape [N, K] (out_features, in_features)
+  const int64_t K = weight_sizes.at(ndim - 1);
+  const int64_t N = weight_sizes.at(ndim - 2);
+
+  // Calculate output sizes for prepacked weight
+  // Output layout: [K, N4 * 4] where N4 = ceil(N / 4)
+  const int64_t N4 = utils::div_up(N, int64_t(4));
+
+  utils::StorageType storage_type = utils::kBuffer;
+
+  std::vector<int64_t> packed_weight_sizes = {K, N4 * 4};
+
+  ValueRef packed_weight = graph.add_tensor(
+      packed_weight_sizes, vkapi::kInt, storage_type, utils::kWidthPacked);
+
+  // Store original sizes for the shader
+  utils::ivec2 orig_sizes = {
+      utils::safe_downcast<int32_t>(K), utils::safe_downcast<int32_t>(N)};
+
+  utils::uvec3 global_wg_size = {
+      utils::safe_downcast<uint32_t>(N4),
+      utils::safe_downcast<uint32_t>(K),
+      1u};
+
+  std::string kernel_name = "pack_q8_linear_weight";
+  add_storage_type_suffix(kernel_name, storage_type);
+
+  graph.prepack_nodes().emplace_back(new PrepackNode(
+      graph,
VK_KERNEL_FROM_STR(kernel_name),
+      global_wg_size,
+      graph.create_local_wg_size(global_wg_size),
+      // Inputs and Outputs
+      weight_data,
+      packed_weight,
+      // UBOs
+      {},
+      // Specialization Constants
+      {},
+      // Push Constants
+      {graph.sizes_pc_of(packed_weight),
+       PushConstantDataInfo(&orig_sizes, sizeof(utils::ivec2))}));
+
+  return packed_weight;
+}
+
+//
+// Linear Dispatch
+//
+
+void add_linear_q8ta_q8csw_nv_cm2_node(
+    ComputeGraph& graph,
+    const ValueRef packed_int_input,
+    const ValueRef weight_data,
+    const ValueRef packed_weight,
+    const ValueRef packed_weight_sums,
+    const ValueRef packed_weight_scales,
+    const ValueRef bias_data,
+    const ValueRef packed_bias,
+    const ValueRef input_scale,
+    const ValueRef input_zero_point,
+    const ValueRef output) {
+  std::string kernel_name = "linear_q8ta_q8csw_nv_cm2";
+  add_storage_type_suffix(kernel_name, graph.storage_type_of(output));
+  add_storage_type_suffix(kernel_name, graph.storage_type_of(packed_int_input));
+  add_storage_type_suffix(kernel_name, graph.storage_type_of(packed_weight));
+  add_dtype_suffix(kernel_name, graph.dtype_of(output));
+
+  vkapi::ParamsBindList param_buffers = {
+      graph.sizes_ubo(output), graph.sizes_ubo(packed_int_input)};
+
+  // Extract input_scale and input_zp for push constants
+  float in_scale = graph.extract_scalar<float>(input_scale);
+  int32_t in_zp = graph.extract_scalar<int32_t>(input_zero_point);
+
+  struct PushConstants {
+    float input_scale;
+    int32_t input_zp;
+  } push_data = {in_scale, in_zp};
+
+  std::vector<PushConstantDataInfo> push_constants = {
+      PushConstantDataInfo(&push_data, sizeof(push_data)),
+  };
+
+  int32_t apply_bias = graph.val_is_not_none(bias_data) ? 1 : 0;
+
+  graph.execute_nodes().emplace_back(new DynamicDispatchNode(
+      graph,
+      VK_KERNEL_FROM_STR(kernel_name),
+      linear_q8ta_q8csw_nv_cm2_global_wg_size,
+      linear_q8ta_q8csw_nv_cm2_local_wg_size,
+      // Inputs and Outputs
+      {{output, vkapi::kWrite},
+       {{packed_int_input,
+         packed_weight,
+         packed_weight_sums,
+         packed_weight_scales,
+         packed_bias},
+        vkapi::kRead}},
+      // Shader params buffers
+      param_buffers,
+      // Push Constants
+      push_constants,
+      // Specialization Constants
+      {apply_bias},
+      // Resize args
+      {weight_data},
+      // Resizing Logic
+      resize_linear_q8ta_q8csw_nv_cm2_node));
+}
+
+//
+// High-level operator implementation
+//
+
+void linear_q8ta_q8csw_nv_cm2_impl(
+    ComputeGraph& graph,
+    const ValueRef input_tensor,
+    const ValueRef input_scale,
+    const ValueRef input_zero_point,
+    const ValueRef weight_data,
+    const ValueRef weight_sums,
+    const ValueRef weight_scales,
+    const ValueRef bias,
+    const ValueRef output) {
+  // Check that VK_NV_cooperative_matrix2 extension is available
+  VK_CHECK_COND(
+      graph.context()->adapter_ptr()->supports_nv_cooperative_matrix2(),
+      "linear_q8ta_q8csw_nv_cm2 requires VK_NV_cooperative_matrix2 extension "
+      "which is not available on this device.");
+
+  std::vector<int64_t> input_sizes = graph.sizes_of(input_tensor);
+  VK_CHECK_COND(
+      input_sizes.size() == 2 || input_sizes.size() == 3,
+      "Input must be 2D or 3D tensor");
+
+  // Prepack weight data - just upload to GPU buffer without transformation
+  // Weight is in [N, K] format from input
+  const ValueRef packed_weight = prepack_standard(
+      graph, weight_data, utils::kBuffer, utils::kWidthPacked);
+
+  // Prepack weight_sums
+  const ValueRef packed_weight_sums = prepack_standard(
+      graph, weight_sums, utils::kBuffer, utils::kWidthPacked);
+
+  // Prepack weight_scales
+  const ValueRef packed_weight_scales = prepack_standard(
+      graph, weight_scales, utils::kBuffer,
utils::kWidthPacked);
+
+  // Prepack bias
+  // Create a dummy tensor to fill the binding slot of the bias tensor if it is
+  // not provided. This helps simplify dispatch logic and makes it so that
+  // fewer shader variants need to be generated.
+  TmpTensor dummy_bias(
+      &graph, {}, graph.dtype_of(output), utils::kBuffer, utils::kWidthPacked);
+  ValueRef packed_bias = dummy_bias.vref;
+  if (graph.val_is_not_none(bias)) {
+    packed_bias = prepack_standard(
+        graph, bias, utils::kBuffer, utils::kWidthPacked);
+  }
+
+  // Check if input is float type and quantize to int8 if needed
+  ValueRef quantized_input = input_tensor;
+  vkapi::ScalarType input_dtype = graph.dtype_of(input_tensor);
+  if (input_dtype == vkapi::kFloat || input_dtype == vkapi::kHalf) {
+    // Create quantized int8 output tensor
+    TmpTensor quantized_tensor(
+        &graph,
+        input_sizes,
+        vkapi::kInt8x4,
+        utils::kBuffer,
+        utils::kPackedInt8_4W);
+
+    // Add quantization node to convert float input to int8
+    add_q8ta_quantize_node(
+        graph,
+        input_tensor,
+        input_scale,
+        input_zero_point,
+        quantized_tensor.vref);
+
+    quantized_input = quantized_tensor.vref;
+  }
+
+  add_linear_q8ta_q8csw_nv_cm2_node(
+      graph,
+      quantized_input,
+      weight_data,
+      packed_weight,
+      packed_weight_sums,
+      packed_weight_scales,
+      bias,
+      packed_bias,
+      input_scale,
+      input_zero_point,
+      output);
+}
+
+//
+// Registered operator entry point
+//
+
+void linear_q8ta_q8csw_nv_cm2(
+    ComputeGraph& graph,
+    const std::vector<ValueRef>& args) {
+  int32_t idx = 0;
+  const ValueRef input_tensor = args.at(idx++);
+  const ValueRef input_scale = args.at(idx++);
+  const ValueRef input_zero_point = args.at(idx++);
+  const ValueRef weight_data = args.at(idx++);
+  const ValueRef weight_sums = args.at(idx++);
+  const ValueRef weight_scales = args.at(idx++);
+  const ValueRef bias = args.at(idx++);
+  const ValueRef output = args.at(idx++);
+
+  linear_q8ta_q8csw_nv_cm2_impl(
+      graph,
+      input_tensor,
+      input_scale,
+      input_zero_point,
+      weight_data,
+      weight_sums,
+      weight_scales,
+      bias,
+      output);
+}
+
+REGISTER_OPERATORS {
+  VK_REGISTER_OP(
+      et_vk.linear_q8ta_q8csw_nv_cm2.default,
+      linear_q8ta_q8csw_nv_cm2);
+}
+
+} // namespace vkcompute
diff --git a/backends/vulkan/test/custom_ops/q8csw_linear.cpp b/backends/vulkan/test/custom_ops/q8csw_linear.cpp
index 4aa6f00d3f5..a2f8468bc98 100644
--- a/backends/vulkan/test/custom_ops/q8csw_linear.cpp
+++ b/backends/vulkan/test/custom_ops/q8csw_linear.cpp
@@ -395,7 +395,8 @@ void linear_q8ta_q8csw_reference_impl(TestCase& test_case) {
 
       // Convert accumulated integer result to float and apply scales
       // Final result = (int_sum - zero_point_correction) * input_scale *
-      // weight_scale + bias zero_point_correction = input_zero_point *
+      // weight_scale + bias
+      // zero_point_correction = input_zero_point *
       // sum_of_weights_for_this_output_channel
       int32_t zero_point_correction = input_zero_point * weight_sum;
       int32_t accum_adjusted = int_sum - zero_point_correction;
diff --git a/backends/vulkan/test/custom_ops/targets.bzl b/backends/vulkan/test/custom_ops/targets.bzl
index 1cf3e5b43cb..0bf7c5d390e 100644
--- a/backends/vulkan/test/custom_ops/targets.bzl
+++ b/backends/vulkan/test/custom_ops/targets.bzl
@@ -102,3 +102,4 @@ def define_common_targets(is_fbcode = False):
     define_custom_op_test_binary("test_q8ta_conv2d_dw")
     define_custom_op_test_binary("q8ta_q8ta_q8to_add")
     define_custom_op_test_binary("test_fp_linear")
+    define_custom_op_test_binary("test_q8csw_linear")
diff --git a/backends/vulkan/test/custom_ops/test_q8csw_linear.cpp
b/backends/vulkan/test/custom_ops/test_q8csw_linear.cpp new file mode 100644 index 00000000000..66f4745e93e --- /dev/null +++ b/backends/vulkan/test/custom_ops/test_q8csw_linear.cpp @@ -0,0 +1,431 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. + +#include +#include +#include +#include +#include "utils.h" + +#include + +using namespace executorch::vulkan::prototyping; + +using namespace vkcompute; + +static constexpr int64_t kRefDimSizeLimit = 300; + +// Global operator selector: 0 = use existing kernel, 1 = use experimental NV CM2 kernel +static int g_operator_selector = 0; +static bool g_operator_selector_set = false; + +// Linear configuration struct +struct LinearConfig { + int64_t M; // Batch size / number of rows in input + int64_t K; // Input features / columns in input, rows in weight + int64_t N; // Output features / columns in weight + bool has_bias = false; + std::string test_case_name = "placeholder"; + std::string op_name = "linear_q8ta_q8csw"; +}; + +// Utility function to create a test case from a LinearConfig +TestCase create_test_case_from_config( + const LinearConfig& config, + utils::StorageType storage_type, + vkapi::ScalarType input_dtype) { + TestCase test_case; + + // Create a descriptive name for the test case + std::string storage_str = + (storage_type == utils::kTexture3D) ? "Texture3D" : "Buffer"; + std::string dtype_str = (input_dtype == vkapi::kFloat) ? "Float" : "Half"; + + std::string test_name = + config.test_case_name + "_" + storage_str + "_" + dtype_str; + test_case.set_name(test_name); + + // Set the operator name for the test case + std::string operator_name = "et_vk." 
+ config.op_name + ".default";
+  test_case.set_operator_name(operator_name);
+
+  // Derive sizes from M, K, N
+  std::vector<int64_t> input_size = {config.M, config.K};
+  std::vector<int64_t> weight_size = {config.N, config.K};
+
+  // Input tensor (float/half) - [M, K]
+  ValueSpec input_tensor(
+      input_size,
+      input_dtype,
+      storage_type,
+      utils::kWidthPacked,
+      DataGenType::RANDOM);
+
+  if (debugging()) {
+    print_valuespec_data(input_tensor, "input_tensor");
+  }
+
+  // TODO(hongbinghu): input_scale_value is not applied correctly
+  float input_scale_val = 1.0f;
+  ValueSpec input_scale(input_scale_val);
+
+  int32_t input_zero_point_val = 0; // Use 0 zero-point for per-tensor quantization
+  ValueSpec input_zero_point(input_zero_point_val);
+
+  // Quantized weight tensor (int8) - [N, K]
+  ValueSpec quantized_weight(
+      weight_size,
+      vkapi::kChar, // int8 for quantized weights
+      storage_type,
+      utils::kWidthPacked,
+      DataGenType::RANDINT8);
+  quantized_weight.set_constant(true);
+
+  if (debugging()) {
+    print_valuespec_data(quantized_weight, "weight_tensor");
+  }
+
+  // Output tensor (float/half) - [M, N]
+  ValueSpec output(
+      {config.M, config.N},
+      input_dtype,
+      storage_type,
+      utils::kWidthPacked,
+      DataGenType::ZEROS);
+
+  // Both existing and experimental kernels use the same input structure:
+  // Args: input_tensor, input_scale, input_zero_point, weight, weight_sums, weight_scales, bias, output
+
+  // Weight quantization scales (float/half, per-channel)
+  ValueSpec weight_scales(
+      {config.N}, // Per output feature
+      input_dtype,
+      storage_type,
+      utils::kWidthPacked,
+      DataGenType::RANDOM);
+  weight_scales.set_constant(true);
+
+  ValueSpec weight_sums(
+      {config.N}, // Per output feature
+      vkapi::kInt,
+      storage_type,
+      utils::kWidthPacked,
+      DataGenType::RANDINT8);
+  weight_sums.set_constant(true);
+
+  // Compute weight_sums data based on quantized weights
+  int64_t in_features = config.K;
+  int64_t out_features = config.N;
+  compute_weight_sums(weight_sums, quantized_weight, out_features, in_features);
+
+  // Bias (optional, float/half) - [N]
+  ValueSpec bias(
+      {config.N}, // Per output feature
+      input_dtype,
+      storage_type,
+      utils::kWidthPacked,
+      config.has_bias ? DataGenType::RANDOM : DataGenType::ZEROS);
+  bias.set_constant(true);
+  if (!config.has_bias) {
+    bias.set_none(true);
+  }
+
+  test_case.add_input_spec(input_tensor);
+  test_case.add_input_spec(input_scale);
+  test_case.add_input_spec(input_zero_point);
+  test_case.add_input_spec(quantized_weight);
+  test_case.add_input_spec(weight_sums);
+  test_case.add_input_spec(weight_scales);
+  test_case.add_input_spec(bias);
+  test_case.add_output_spec(output);
+
+  test_case.set_abs_tolerance(5.0f);
+  test_case.set_rel_tolerance(0.020f);
+
+  return test_case;
+}
+
+// Generate test cases for quantized linear operation
+std::vector<TestCase> generate_quantized_linear_test_cases() {
+  std::vector<TestCase> test_cases;
+
+  std::vector<LinearConfig> configs;
+  configs = {
+      // Small tests, with and without bias
+      {2, 32, 32, false},
+      {1, 16, 16, false},
+      {4, 256, 128, false},
+      {4, 256, 128, true},
+      // No-bias tests
+      {2, 32, 32},
+      {4, 128, 64},
+      {4, 256, 128},
+      {2, 32, 32},
+      {32, 64, 32},
+      {32, 128, 64},
+      {2, 256, 128},
+      // Bias tests (larger)
+      {32, 64, 32, true},
+      {32, 128, 64, true},
+      {32, 256, 128, true},
+  };
+
+  // Only use buffer storage for NV CM2 kernel
+  std::vector<utils::StorageType> storage_types = {utils::kBuffer};
+
+  for (auto config : configs) {
+    std::string prefix =
+        (config.M < kRefDimSizeLimit && config.K < kRefDimSizeLimit &&
+         config.N < kRefDimSizeLimit)
+            ?
"correctness_" + : "performance_"; + std::string generated_test_case_name = prefix + std::to_string(config.M) + + "_" + std::to_string(config.K) + "_" + std::to_string(config.N); + if (!config.has_bias) { + generated_test_case_name += "_no_bias"; + } + + config.test_case_name = generated_test_case_name; + + for (const auto& storage_type : storage_types) { + // Check for int8 dot product support + if (!vkcompute::api::context() + ->adapter_ptr() + ->supports_int8_dot_product()) { + std::cout << "Skipping test: int8 dot product not supported" + << std::endl; + continue; + } + + if (g_operator_selector_set) { + // Use the operator specified by the command line + if (g_operator_selector == 1) { + config.op_name = "linear_q8ta_q8csw_nv_cm2"; + } else { + config.op_name = "linear_q8ta_q8csw"; + } + test_cases.push_back( + create_test_case_from_config(config, storage_type, vkapi::kFloat)); + } else { + // Run both selectors + config.op_name = "linear_q8ta_q8csw"; + test_cases.push_back( + create_test_case_from_config(config, storage_type, vkapi::kFloat)); + + config.op_name = "linear_q8ta_q8csw_nv_cm2"; + test_cases.push_back( + create_test_case_from_config(config, storage_type, vkapi::kFloat)); + } + } + } + + return test_cases; +} + +// Reference implementation for q8ta_q8csw linear +void linear_q8ta_q8csw_reference_impl(TestCase& test_case) { + // Extract input specifications (common for both kernels) + int32_t idx = 0; + const ValueSpec& input_spec = test_case.inputs()[idx++]; + const ValueSpec& input_scale_spec = test_case.inputs()[idx++]; + const ValueSpec& input_zeros_spec = test_case.inputs()[idx++]; + const ValueSpec& weight_spec = test_case.inputs()[idx++]; + const ValueSpec& weight_sums_spec = test_case.inputs()[idx++]; + (void)weight_sums_spec; + const ValueSpec& weight_scales_spec = test_case.inputs()[idx++]; + const ValueSpec& bias_spec = test_case.inputs()[idx++]; + + // Extract output specification (mutable reference) + ValueSpec& output_spec = test_case.outputs()[0]; + + // Get tensor dimensions + auto input_sizes = input_spec.get_tensor_sizes(); // [batch_size, in_features] + auto weight_sizes = + weight_spec.get_tensor_sizes(); // [out_features, in_features] + + int64_t batch_size = input_sizes[0]; + int64_t in_features = input_sizes[1]; + int64_t out_features = weight_sizes[0]; + + // Skip for large tensors since computation time will be extremely slow + if (batch_size > kRefDimSizeLimit || in_features > kRefDimSizeLimit || + out_features > kRefDimSizeLimit) { + throw std::invalid_argument( + "One or more dimensions (batch_size, in_features, out_features) exceed the allowed limit for reference implementation."); + } + + if (input_spec.dtype != vkapi::kFloat) { + throw std::invalid_argument("Unsupported dtype"); + } + + // Get raw data pointers (common for both kernels) + auto& input_data = input_spec.get_float_data(); + auto& weight_data = weight_spec.get_int8_data(); + auto& weight_scales_data = weight_scales_spec.get_float_data(); + auto& bias_data = bias_spec.get_float_data(); + + // Calculate number of output elements + int64_t num_output_elements = batch_size * out_features; + + auto& ref_data = output_spec.get_ref_float_data(); + ref_data.resize(num_output_elements); + + // Extract quantization parameters + const float input_scale = input_scale_spec.get_float_value(); + const int32_t input_zero_point = input_zeros_spec.get_int_value(); + + // Perform quantized linear transformation (matrix multiplication) + // Both kernels (existing and experimental) use the same 
reference implementation:
+  // integer accumulation with zero-point correction
+  for (int64_t b = 0; b < batch_size; ++b) {
+    for (int64_t out_f = 0; out_f < out_features; ++out_f) {
+      int32_t int_sum = 0;
+      int32_t weight_sum = 0;
+
+      for (int64_t in_f = 0; in_f < in_features; ++in_f) {
+        int64_t input_idx = b * in_features + in_f;
+        int64_t weight_idx = out_f * in_features + in_f;
+
+        // Quantize input to int8
+        float quant_input_f =
+            std::round(input_data[input_idx] / input_scale) + input_zero_point;
+        quant_input_f = std::min(std::max(quant_input_f, -128.0f), 127.0f);
+        int8_t quantized_input = static_cast<int8_t>(quant_input_f);
+
+        int8_t quantized_weight = weight_data[weight_idx];
+
+        int_sum += static_cast<int32_t>(quantized_input) *
+            static_cast<int32_t>(quantized_weight);
+        weight_sum += static_cast<int32_t>(quantized_weight);
+      }
+
+      // Apply zero-point correction and scales
+      int32_t zero_point_correction = input_zero_point * weight_sum;
+      int32_t accum_adjusted = int_sum - zero_point_correction;
+      float result = accum_adjusted * input_scale * weight_scales_data[out_f];
+
+      // Add bias and store result
+      if (!bias_spec.is_none()) {
+        result += bias_data[out_f];
+      }
+      int64_t output_idx = b * out_features + out_f;
+      ref_data[output_idx] = result;
+    }
+  }
+}
+
+void reference_impl(TestCase& test_case) {
+  linear_q8ta_q8csw_reference_impl(test_case);
+}
+
+int64_t quantized_linear_flop_calculator(const TestCase& test_case) {
+  int input_idx = 0;
+  int weight_idx = 3;
+
+  // Get input and weight dimensions
+  const auto& input_sizes = test_case.inputs()[input_idx].get_tensor_sizes();
+  const auto& weight_sizes = test_case.inputs()[weight_idx].get_tensor_sizes();
+
+  int64_t batch_size = input_sizes[0];
+  int64_t in_features = input_sizes[1];
+  int64_t out_features = weight_sizes[0];
+
+  // Calculate FLOPs for quantized linear operation
+  // Each output element requires:
+  // - in_features multiply-accumulate operations
+  // - Additional operations for quantization/dequantization
+  int64_t output_elements = batch_size * out_features;
+  int64_t ops_per_output = in_features;
+
+  // Add quantization overhead (approximate)
+  // - Dequantize input: 1 op per input element used
+  // - Dequantize weight: 1 op per weight element used
+  // - Add bias: 1 op per output element
+  int64_t quantization_ops = ops_per_output + 1; // Simplified estimate
+
+  int64_t flop = output_elements * (ops_per_output + quantization_ops);
+
+  return flop;
+}
+
+void print_usage(const char* program_name) {
+  std::cout << "Usage: " << program_name << " [options]" << std::endl;
+  std::cout << "Options:" << std::endl;
+  std::cout << "  --operator_selector <0|1>  Select operator implementation:"
+            << std::endl;
+  std::cout << "                             0 = existing kernel (linear_q8ta_q8csw)"
+            << std::endl;
+  std::cout << "                             1 = NV CM2 kernel (linear_q8ta_q8csw_nv_cm2)"
+            << std::endl;
+  std::cout << "  --help                     Show this help message"
+            << std::endl;
+}
+
+int main(int argc, char* argv[]) {
+  // Parse command line arguments
+  for (int i = 1; i < argc; ++i) {
+    std::string arg = argv[i];
+    if (arg == "--operator_selector" && i + 1 < argc) {
+      g_operator_selector = std::stoi(argv[++i]);
+      g_operator_selector_set = true;
+      if (g_operator_selector != 0 && g_operator_selector != 1) {
+        std::cerr << "Error: operator_selector must be 0 or 1" << std::endl;
+        return 1;
+      }
+    } else if (arg == "--help" || arg == "-h") {
+      print_usage(argv[0]);
+      return 0;
+    }
+  }
+
+  set_debugging(false);
+  set_print_output(false);
+  set_print_latencies(true);
+  set_use_gpu_timestamps(true);
+
+
print_performance_header(); + std::cout << "Quantized Linear Operation Test Framework" << std::endl; + if (g_operator_selector_set) { + std::cout << "Operator selector: " << g_operator_selector; + if (g_operator_selector == 0) { + std::cout << " (existing kernel: linear_q8ta_q8csw)" << std::endl; + } else { + std::cout << " (NV CM2 kernel: linear_q8ta_q8csw_nv_cm2)" + << std::endl; + } + } else { + std::cout << "Operator selector: not set, running both kernels" << std::endl; + } + print_separator(); + + // Check for NV CM2 support if using experimental kernel + if (g_operator_selector == 1) { + if (!vkcompute::api::context() + ->adapter_ptr() + ->supports_nv_cooperative_matrix2()) { + std::cerr + << "Error: Experimental NV CM2 kernel requires VK_NV_cooperative_matrix2 extension" + << std::endl; + std::cerr << "This extension is not supported on this device." + << std::endl; + return 1; + } + std::cout << "VK_NV_cooperative_matrix2 extension is supported." + << std::endl; + } + + ReferenceComputeFunc ref_fn = reference_impl; + + auto results = execute_test_cases( + generate_quantized_linear_test_cases, + quantized_linear_flop_calculator, + "QuantizedLinearNvCoopMat", + 1, + 3, + ref_fn); + + return 0; +}