From f3f627a0fa3abb6c94efd0b5b5d34bf7b35837ff Mon Sep 17 00:00:00 2001 From: ssjia Date: Tue, 17 Feb 2026 12:19:47 -0800 Subject: [PATCH] [ET-VK] Add fused q8ta_relu unary operator for int8x4 tensors This adds a fused quantized unary operator (ReLU) that operates directly on int8x4 packed buffer tensors, avoiding the overhead of separate dequantize-relu-requantize dispatches. The implementation follows the same pattern as q8ta_binary: a single GLSL compute shader dequantizes int8x4 blocks to float, applies the unary operation, and requantizes back to int8x4 in one dispatch. The shader uses the OPERATOR macro for parameterization so additional unary ops can be added as YAML variants without new shader code. Components added: - GLSL shader (q8ta_unary.glsl) and YAML config with relu variant - C++ operator implementation (Q8taUnary.cpp/h) registering et_vk.q8ta_relu.default - Export graph fusion pattern (quantized_unary.py) that detects dequant->relu->quant sequences and replaces them with the fused op - Custom op definition (q8ta_relu in custom_ops_lib.py) for the export pipeline - Test harness (TestQ8taUnary.cpp, test_q8ta_unary.cpp) with reference implementation and coverage across multiple shapes and quantized layouts This diff was authored with Claude. Differential Revision: [D93511629](https://our.internmc.facebook.com/intern/diff/D93511629/) [ghstack-poisoned] --- backends/vulkan/custom_ops_lib.py | 35 ++ backends/vulkan/op_registry.py | 14 +- backends/vulkan/patterns/BUCK | 1 + backends/vulkan/patterns/__init__.py | 2 + backends/vulkan/patterns/quantized_unary.py | 121 +++++++ .../runtime/graph/ops/glsl/q8ta_unary.glsl | 82 +++++ .../runtime/graph/ops/glsl/q8ta_unary.yaml | 12 + .../runtime/graph/ops/impl/Q8taUnary.cpp | 124 +++++++ .../vulkan/runtime/graph/ops/impl/Q8taUnary.h | 29 ++ .../test/custom_ops/impl/TestQ8taUnary.cpp | 70 ++++ .../test/custom_ops/test_q8ta_unary.cpp | 311 ++++++++++++++++++ 11 files changed, 800 insertions(+), 1 deletion(-) create mode 100644 backends/vulkan/patterns/quantized_unary.py create mode 100644 backends/vulkan/runtime/graph/ops/glsl/q8ta_unary.glsl create mode 100644 backends/vulkan/runtime/graph/ops/glsl/q8ta_unary.yaml create mode 100644 backends/vulkan/runtime/graph/ops/impl/Q8taUnary.cpp create mode 100644 backends/vulkan/runtime/graph/ops/impl/Q8taUnary.h create mode 100644 backends/vulkan/test/custom_ops/impl/TestQ8taUnary.cpp create mode 100644 backends/vulkan/test/custom_ops/test_q8ta_unary.cpp diff --git a/backends/vulkan/custom_ops_lib.py b/backends/vulkan/custom_ops_lib.py index 3e77b0c0eea..e371338e904 100644 --- a/backends/vulkan/custom_ops_lib.py +++ b/backends/vulkan/custom_ops_lib.py @@ -616,6 +616,41 @@ def q8ta_add_impl( lib.impl(name, q8ta_add_impl, "CompositeExplicitAutograd") q8ta_add_op = getattr(getattr(torch.ops, namespace), name) +######################## +## q8ta_relu ## +######################## + + +def q8ta_relu_impl( + input: torch.Tensor, + input_scale: float, + input_zero_point: int, + output_scale: float, + output_zero_point: int, +): + # Dequantize input to float + dequant = torch.ops.quantized_decomposed.dequantize_per_tensor( + input, input_scale, input_zero_point, -128, 127, input.dtype + ) + + # Apply ReLU + result = torch.nn.functional.relu(dequant) + + # Quantize the result back to int8 + quantized_result = torch.ops.quantized_decomposed.quantize_per_tensor( + result, output_scale, output_zero_point, -128, 127, torch.int8 + ) + + return quantized_result + + +name = "q8ta_relu" +lib.define( + f"{name}(Tensor input, float input_scale, int input_zero_point, float output_scale, int output_zero_point) -> Tensor" +) +lib.impl(name, q8ta_relu_impl, "CompositeExplicitAutograd") +q8ta_relu_op = getattr(getattr(torch.ops, namespace), name) + ############################# ## select_as_symint ## ############################# diff --git a/backends/vulkan/op_registry.py b/backends/vulkan/op_registry.py index 55a92335bc7..721297dea37 100644 --- a/backends/vulkan/op_registry.py +++ b/backends/vulkan/op_registry.py @@ -514,7 +514,19 @@ def register_q8ta_add(): # ============================================================================= -# Reduce.cpp +# Q8taUnary.cpp +# ============================================================================= + + +@update_features(exir_ops.edge.et_vk.q8ta_relu.default) +def register_q8ta_relu(): + return OpFeatures( + inputs_storage=utils.PACKED_INT8_BUFFER, + supports_resize=True, + ) + + +# ============================================================================= # ============================================================================= diff --git a/backends/vulkan/patterns/BUCK b/backends/vulkan/patterns/BUCK index a7153b30967..711000f74ca 100644 --- a/backends/vulkan/patterns/BUCK +++ b/backends/vulkan/patterns/BUCK @@ -13,6 +13,7 @@ fbcode_target(_kind = runtime.python_library, "quantized_linear.py", "quantized_convolution.py", "quantized_binary.py", + "quantized_unary.py", "sdpa.py", "select_as_symint.py", ], diff --git a/backends/vulkan/patterns/__init__.py b/backends/vulkan/patterns/__init__.py index 9b875def944..050680b024d 100644 --- a/backends/vulkan/patterns/__init__.py +++ b/backends/vulkan/patterns/__init__.py @@ -12,6 +12,8 @@ import executorch.backends.vulkan.patterns.quantized_linear # noqa +import executorch.backends.vulkan.patterns.quantized_unary # noqa + import executorch.backends.vulkan.patterns.rope # noqa import executorch.backends.vulkan.patterns.sdpa # noqa diff --git a/backends/vulkan/patterns/quantized_unary.py b/backends/vulkan/patterns/quantized_unary.py new file mode 100644 index 00000000000..28dc84b7997 --- /dev/null +++ b/backends/vulkan/patterns/quantized_unary.py @@ -0,0 +1,121 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Optional + +import executorch.backends.vulkan.utils as utils + +import torch + +from executorch.backends.vulkan.patterns.pattern_registry import ( + PatternMatch, + register_pattern_detector, + register_pattern_replacement, +) + +from executorch.exir import ExportedProgram +from executorch.exir.dialects._ops import ops as exir_ops + + +class QuantizedUnaryMatch(PatternMatch): + def __init__(self, unary_node: torch.fx.Node) -> None: + self.anchor_node = unary_node + self.match_found = False + self.all_nodes = [self.anchor_node] + + # The unary op takes a single input which must be a dequantize node + if len(unary_node.args) < 1: + return + + input_node = unary_node.args[0] + assert isinstance(input_node, torch.fx.Node) + + if not utils.is_dequant_node(input_node): + return + + self.dequantize_input_node = input_node + + # Extract quantization parameters for the input + self.quantize_input_node = self.dequantize_input_node.args[0] + self.input_scales_node = self.dequantize_input_node.args[1] + self.input_zeros_node = self.dequantize_input_node.args[2] + + self.all_nodes.append(self.dequantize_input_node) + + # The unary op output must have exactly one user: a quantize node + self.output_node = self.anchor_node + + if len(self.output_node.users) != 1: + return + + cur_node = list(self.output_node.users)[0] + + if not utils.is_quant_node(cur_node): + return + + self.quantize_output_node = cur_node + self.output_scales_node = self.quantize_output_node.args[1] + self.output_zeros_node = self.quantize_output_node.args[2] + + self.all_nodes.append(self.quantize_output_node) + + self.match_found = True + + +# Unary operation anchor nodes that we support +unary_anchor_nodes = { + exir_ops.edge.aten.relu.default, +} + + +@register_pattern_detector("quantized_unary") +def find_quantized_unary_patterns( + node: torch.fx.Node, +) -> Optional[QuantizedUnaryMatch]: + if node.target not in unary_anchor_nodes: + return None + + matched_pattern = QuantizedUnaryMatch(node) + if matched_pattern.match_found: + return matched_pattern + + return None + + +## +## Pattern Replacement +## + + +@register_pattern_replacement("quantized_unary") +def make_q8ta_unary_custom_op( + ep: ExportedProgram, + graph_module: torch.fx.GraphModule, + match: QuantizedUnaryMatch, +): + op_target = None + if match.anchor_node.target == exir_ops.edge.aten.relu.default: + op_target = exir_ops.edge.et_vk.q8ta_relu.default + else: + raise NotImplementedError( + f"Unsupported unary operation: {match.anchor_node.target}" + ) + + with graph_module.graph.inserting_before(match.output_node): + qunary_node = graph_module.graph.create_node( + "call_function", + op_target, + args=( + match.quantize_input_node, + match.input_scales_node, + match.input_zeros_node, + match.output_scales_node, + match.output_zeros_node, + ), + ) + + qunary_node.meta["val"] = match.output_node.meta["val"] + match.quantize_output_node.replace_all_uses_with(qunary_node) diff --git a/backends/vulkan/runtime/graph/ops/glsl/q8ta_unary.glsl b/backends/vulkan/runtime/graph/ops/glsl/q8ta_unary.glsl new file mode 100644 index 00000000000..e97d6d47877 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/q8ta_unary.glsl @@ -0,0 +1,82 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#version 450 core + +#define PRECISION ${PRECISION} + +${define_active_storage_type("buffer")} + +#define op(X) ${OPERATOR} + +layout(std430) buffer; + +#include "indexing.glslh" +#include "common.glslh" +#include "block_indexing.glslh" +#include "block_int8x4_load.glslh" +#include "block_int8x4_store.glslh" + +// Output buffer: packed int8x4 values +${layout_declare_tensor(B, "w", "t_out", "int", "buffer")} +// Input buffer: packed int8x4 values +${layout_declare_tensor(B, "r", "t_in", "int", "buffer")} + +// Metadata for output and input tensors +${layout_declare_ubo(B, "BufferMetadata", "out_meta")} +${layout_declare_ubo(B, "BufferMetadata", "in_meta")} + +layout(push_constant) uniform restrict Block { + float input_scale; + int input_zp; + float output_inv_scale; + int output_zp; +}; + +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; + +${layout_declare_spec_const(C, "int", "out_layout", "CONTIG_LAYOUT_INT")} +${layout_declare_spec_const(C, "int", "in_layout", "CONTIG_LAYOUT_INT")} +${layout_declare_spec_const(C, "int", "block_config", "0")} + +// Generate loading functions for input buffer +define_load_int8x4_buffer_fns(t_in) + +// Generate storing functions for output buffer +define_store_int8x4_buffer_fns(t_out) + +void main() { + // Buffer storage: use linear dispatch + const uint contig_block_idx = gl_GlobalInvocationID.x; + TensorIndex4D tidx = contiguous_block_idx_to_tensor4d_idx_with_block_config( + out_meta, contig_block_idx, block_config); + + if (out_of_bounds(tidx, out_meta)) { + return; + } + + const int block_outer_dim = get_block_outer_dim(block_config); + + // Load int8x4 block from input + ivec4 in_block = load_int8x4_block_from_t_in( + in_meta, tidx, in_layout, block_outer_dim); + + ivec4 out_block; + + for (int row = 0; row < 4; row++) { + vec4 in_texel = unpack_and_dequantize( + in_block[row], input_scale, input_zp); + + vec4 out_texel = op(in_texel); + out_block[row] = quantize_and_pack(out_texel, output_inv_scale, output_zp); + } + + // Store to output buffer + store_int8x4_block_to_t_out( + out_meta, tidx, out_layout, block_outer_dim, out_block); +} diff --git a/backends/vulkan/runtime/graph/ops/glsl/q8ta_unary.yaml b/backends/vulkan/runtime/graph/ops/glsl/q8ta_unary.yaml new file mode 100644 index 00000000000..257f6a44205 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/q8ta_unary.yaml @@ -0,0 +1,12 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +q8ta_unary: + parameter_names_with_default_values: + OPERATOR: X + shader_variants: + - NAME: q8ta_relu_buffer + OPERATOR: max(X, vec4(0.0)) diff --git a/backends/vulkan/runtime/graph/ops/impl/Q8taUnary.cpp b/backends/vulkan/runtime/graph/ops/impl/Q8taUnary.cpp new file mode 100644 index 00000000000..f8b606f3dfa --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/impl/Q8taUnary.cpp @@ -0,0 +1,124 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +#include +#include + +namespace vkcompute { + +void resize_q8ta_unary_node( + ComputeGraph* graph, + const std::vector& args, + const std::vector& extra_args) { + (void)extra_args; + const ValueRef out = args.at(0).refs.at(0); + const ValueRef self = args.at(1).refs.at(0); + graph->virtual_resize(out, graph->sizes_of(self)); +} + +// +// Dispatch nodes +// + +void add_q8ta_unary_node( + ComputeGraph& graph, + const ValueRef packed_int8_input, + const ValueRef input_scale, + const ValueRef input_zp, + const ValueRef output_scale, + const ValueRef output_zp, + const ValueRef packed_int8_output, + const std::string& op_name) { + const api::PackedDimInfo& output_info = + graph.packed_dim_info_of(packed_int8_output); + const api::PackedDimInfo& input_info = + graph.packed_dim_info_of(packed_int8_input); + + VK_CHECK_COND(input_info.packed_dim == output_info.packed_dim); + VK_CHECK_COND( + input_info.packed_dim_block_size == output_info.packed_dim_block_size); + + float input_scale_val = graph.extract_scalar(input_scale); + int32_t input_zp_val = graph.extract_scalar(input_zp); + + float output_inv_scale_val = 1.0f / graph.extract_scalar(output_scale); + int32_t output_zp_val = graph.extract_scalar(output_zp); + + std::string kernel_name = "q8ta_" + op_name; + add_storage_type_suffix( + kernel_name, graph.storage_type_of(packed_int8_output)); + + vkapi::ParamsBindList param_buffers; + param_buffers.append(graph.buffer_meta_ubo(packed_int8_output)); + param_buffers.append(graph.buffer_meta_ubo(packed_int8_input)); + + std::vector push_constants = { + PushConstantDataInfo(&input_scale_val, sizeof(input_scale_val)), + PushConstantDataInfo(&input_zp_val, sizeof(input_zp_val)), + PushConstantDataInfo(&output_inv_scale_val, sizeof(output_inv_scale_val)), + PushConstantDataInfo(&output_zp_val, sizeof(output_zp_val)), + }; + + const BlockConfig block_config = + create_block_config_for_tensor(graph, packed_int8_output); + + const ValueRef block_config_ref = + static_cast(block_config.as_packed_int()); + + graph.execute_nodes().emplace_back(new DynamicDispatchNode( + graph, + VK_KERNEL_FROM_STR(kernel_name), + pick_linear_global_wg_with_block_config, + pick_square_local_wg_with_block_config, + // Inputs and Outputs + {{packed_int8_output, vkapi::kWrite}, {packed_int8_input, vkapi::kRead}}, + // Shader params buffers + param_buffers, + // Push Constants + push_constants, + // Specialization Constants + {graph.hashed_layout_of(packed_int8_output), + graph.hashed_layout_of(packed_int8_input), + block_config.as_packed_int()}, + // Resize args + {block_config_ref}, + // Resizing Logic + resize_q8ta_unary_node)); +} + +// +// High level operator impl +// + +void q8ta_relu(ComputeGraph& graph, const std::vector& args) { + int32_t idx = 0; + const ValueRef packed_int8_input = args.at(idx++); + const ValueRef input_scale = args.at(idx++); + const ValueRef input_zp = args.at(idx++); + const ValueRef output_scale = args.at(idx++); + const ValueRef output_zp = args.at(idx++); + const ValueRef packed_int8_output = args.at(idx++); + + add_q8ta_unary_node( + graph, + packed_int8_input, + input_scale, + input_zp, + output_scale, + output_zp, + packed_int8_output, + "relu"); +} + +REGISTER_OPERATORS { + VK_REGISTER_OP(et_vk.q8ta_relu.default, q8ta_relu); +} + +} // namespace vkcompute diff --git a/backends/vulkan/runtime/graph/ops/impl/Q8taUnary.h b/backends/vulkan/runtime/graph/ops/impl/Q8taUnary.h new file mode 100644 index 00000000000..2b68fa53c22 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/impl/Q8taUnary.h @@ -0,0 +1,29 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include + +namespace vkcompute { + +// +// Unary operations for int8x4 tensors +// + +void add_q8ta_unary_node( + ComputeGraph& graph, + const ValueRef packed_int8_input, + const ValueRef input_scale, + const ValueRef input_zp, + const ValueRef output_scale, + const ValueRef output_zp, + const ValueRef packed_int8_output, + const std::string& op_name); + +} // namespace vkcompute diff --git a/backends/vulkan/test/custom_ops/impl/TestQ8taUnary.cpp b/backends/vulkan/test/custom_ops/impl/TestQ8taUnary.cpp new file mode 100644 index 00000000000..6212216686f --- /dev/null +++ b/backends/vulkan/test/custom_ops/impl/TestQ8taUnary.cpp @@ -0,0 +1,70 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +#include +#include + +namespace vkcompute { + +void q8ta_unary_test(ComputeGraph& graph, const std::vector& args) { + int32_t idx = 0; + const ValueRef fp_input = args.at(idx++); + const ValueRef input_scale = args.at(idx++); + const ValueRef input_zp = args.at(idx++); + const ValueRef output_scale = args.at(idx++); + const ValueRef output_zp = args.at(idx++); + const ValueRef quant_layout_int = args.at(idx++); + const ValueRef fp_output = args.at(idx++); + + int32_t layout_value = graph.extract_scalar(quant_layout_int); + utils::GPUMemoryLayout quant_layout = + static_cast(layout_value); + + // Create temporary tensor for quantized input + TmpTensor packed_int8_input( + &graph, + graph.sizes_of(fp_input), + vkapi::kInt8x4, + utils::kBuffer, + quant_layout); + + // Create temporary tensor for quantized output + TmpTensor packed_int8_output( + &graph, + graph.sizes_of(fp_input), + vkapi::kInt8x4, + utils::kBuffer, + quant_layout); + + // Quantize: FP -> int8x4 + add_q8ta_quantize_node( + graph, fp_input, input_scale, input_zp, packed_int8_input); + + // Unary op: int8x4 -> int8x4 + add_q8ta_unary_node( + graph, + packed_int8_input, + input_scale, + input_zp, + output_scale, + output_zp, + packed_int8_output, + "relu"); + + // Dequantize: int8x4 -> FP + add_q8ta_dequantize_node( + graph, packed_int8_output, output_scale, output_zp, fp_output); +} + +REGISTER_OPERATORS { + VK_REGISTER_OP(test_etvk.q8ta_unary_test.default, q8ta_unary_test); +} + +} // namespace vkcompute diff --git a/backends/vulkan/test/custom_ops/test_q8ta_unary.cpp b/backends/vulkan/test/custom_ops/test_q8ta_unary.cpp new file mode 100644 index 00000000000..bc184c6c182 --- /dev/null +++ b/backends/vulkan/test/custom_ops/test_q8ta_unary.cpp @@ -0,0 +1,311 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. + +#include +#include +#include +#include +#include +#include +#include "utils.h" + +#include + +// #define DEBUG_MODE + +using namespace executorch::vulkan::prototyping; +using namespace vkcompute; + +static constexpr int64_t kRefDimSizeLimit = 512; + +struct Q8taUnaryConfig { + std::vector shape; + std::string test_case_name = "placeholder"; + std::string op_name = "q8ta_unary_test"; +}; + +TestCase create_test_case_from_config( + const Q8taUnaryConfig& config, + utils::StorageType storage_type, + vkapi::ScalarType input_dtype, + utils::GPUMemoryLayout fp_memory_layout, + utils::GPUMemoryLayout quant_layout) { + TestCase test_case; + + std::string shape_str = shape_string(config.shape); + std::string test_name = config.test_case_name + " I=" + shape_str + " " + + repr_str(storage_type, fp_memory_layout) + "->" + + repr_str(utils::kBuffer, quant_layout); + test_case.set_name(test_name); + + std::string operator_name = "test_etvk." + config.op_name + ".default"; + test_case.set_operator_name(operator_name); + + // Input tensor (float) + ValueSpec input_tensor( + config.shape, + input_dtype, + storage_type, + fp_memory_layout, + DataGenType::RANDOM); + + float scale_val = 0.007112; + ValueSpec input_scale(scale_val); + + int32_t zero_point_val = 0; + ValueSpec input_zero_point(zero_point_val); + + // For relu, output scale and zero point can differ from input + float output_scale_val = 0.007112; + ValueSpec output_scale(output_scale_val); + + int32_t output_zp_val = 0; + ValueSpec output_zero_point(output_zp_val); + + int32_t layout_int = static_cast(quant_layout); + ValueSpec layout_spec(layout_int); + + // Output tensor (float) - same shape as input + ValueSpec output_tensor( + config.shape, + input_dtype, + storage_type, + fp_memory_layout, + DataGenType::ZEROS); + + test_case.add_input_spec(input_tensor); + test_case.add_input_spec(input_scale); + test_case.add_input_spec(input_zero_point); + test_case.add_input_spec(output_scale); + test_case.add_input_spec(output_zero_point); + test_case.add_input_spec(layout_spec); + test_case.add_output_spec(output_tensor); + + test_case.set_abs_tolerance(scale_val + 1e-4); + + test_case.set_shader_filter({ + "nchw_to", + "to_nchw", + "q8ta_quantize", + "q8ta_dequantize", + }); + + return test_case; +} + +std::vector generate_q8ta_unary_easy_cases() { + std::vector test_cases; + + Q8taUnaryConfig config = { + {1, 16, 16, 16}, + "ACCU", + }; + + std::vector fp_layouts = { + utils::kWidthPacked, + utils::kChannelsPacked, + }; + + std::vector quant_layouts = { + utils::kPackedInt8_4W, + utils::kPackedInt8_4C, + utils::kPackedInt8_4W4C, + utils::kPackedInt8_4H4W, + utils::kPackedInt8_4C1W, + }; + + std::vector storage_types = {utils::kBuffer}; + std::vector float_types = {vkapi::kFloat}; + + for (const auto& fp_layout : fp_layouts) { + for (const auto& quant_layout : quant_layouts) { + for (const auto& storage_type : storage_types) { + for (const auto& input_dtype : float_types) { + test_cases.push_back(create_test_case_from_config( + config, storage_type, input_dtype, fp_layout, quant_layout)); + } + } + } + } + + return test_cases; +} + +std::vector generate_q8ta_unary_test_cases() { + std::vector test_cases; + + std::vector> shapes = { + {1, 3, 16, 16}, + {1, 8, 32, 32}, + {1, 16, 24, 24}, + {1, 32, 12, 12}, + {1, 1, 64, 64}, + {1, 3, 64, 64}, + {1, 4, 16, 16}, + + {1, 8, 20, 20}, + {1, 16, 14, 14}, + {1, 8, 28, 28}, + + // Odd tensor sizes + {1, 3, 15, 15}, + {1, 13, 31, 31}, + {1, 17, 23, 23}, + + // Larger tensors + {1, 64, 128, 128}, + {1, 32, 64, 64}, + {1, 128, 56, 56}, + {1, 128, 128, 128}, + }; + + std::vector fp_layouts = { + utils::kWidthPacked, + utils::kChannelsPacked, + }; + + std::vector quant_layouts = { + utils::kPackedInt8_4W, + utils::kPackedInt8_4C, + utils::kPackedInt8_4W4C, + utils::kPackedInt8_4H4W, + utils::kPackedInt8_4C1W, + }; + + std::vector storage_types = {utils::kBuffer}; + + for (const auto& shape : shapes) { + std::string prefix = "ACCU"; + for (const auto& dim : shape) { + if (dim > kRefDimSizeLimit) { + prefix = "PERF"; + break; + } + } + + for (const auto& fp_layout : fp_layouts) { + for (const auto& quant_layout : quant_layouts) { + for (const auto& storage_type : storage_types) { + Q8taUnaryConfig config; + config.shape = shape; + config.test_case_name = prefix; + + test_cases.push_back(create_test_case_from_config( + config, storage_type, vkapi::kFloat, fp_layout, quant_layout)); + } + } + } + } + + return test_cases; +} + +// Reference implementation: quantize -> relu -> dequantize +void q8ta_unary_reference_impl(TestCase& test_case) { + int32_t idx = 0; + const ValueSpec& input_spec = test_case.inputs()[idx++]; + const ValueSpec& input_scale_spec = test_case.inputs()[idx++]; + const ValueSpec& input_zp_spec = test_case.inputs()[idx++]; + const ValueSpec& output_scale_spec = test_case.inputs()[idx++]; + const ValueSpec& output_zp_spec = test_case.inputs()[idx++]; + const ValueSpec& layout_spec = test_case.inputs()[idx++]; + (void)layout_spec; + + ValueSpec& output_spec = test_case.outputs()[0]; + + auto input_sizes = input_spec.get_tensor_sizes(); + + int64_t num_elements = 1; + for (const auto& dim : input_sizes) { + num_elements *= dim; + } + + for (const auto& dim : input_sizes) { + if (dim > kRefDimSizeLimit) { + throw std::invalid_argument( + "One or more dimensions exceed the allowed limit for reference " + "implementation."); + } + } + + if (input_spec.dtype != vkapi::kFloat) { + throw std::invalid_argument("Unsupported dtype"); + } + + auto& input_data = input_spec.get_float_data(); + + float input_scale = input_scale_spec.get_float_value(); + int32_t input_zp = input_zp_spec.get_int_value(); + float output_scale = output_scale_spec.get_float_value(); + int32_t output_zp = output_zp_spec.get_int_value(); + int32_t quant_min = -128; + int32_t quant_max = 127; + + auto& ref_data = output_spec.get_ref_float_data(); + ref_data.resize(num_elements); + + for (int64_t i = 0; i < num_elements; ++i) { + float input_val = input_data[i]; + + // Quantize with input scale/zp + float quantized_float = std::round(input_val / input_scale) + input_zp; + quantized_float = std::max(quantized_float, static_cast(quant_min)); + quantized_float = std::min(quantized_float, static_cast(quant_max)); + int32_t quantized_int = static_cast(quantized_float); + + // Dequantize to float + float dequantized = (quantized_int - input_zp) * input_scale; + + // Apply ReLU + float activated = std::max(dequantized, 0.0f); + + // Requantize with output scale/zp + float requantized_float = std::round(activated / output_scale) + output_zp; + requantized_float = + std::max(requantized_float, static_cast(quant_min)); + requantized_float = + std::min(requantized_float, static_cast(quant_max)); + int32_t requantized_int = static_cast(requantized_float); + + // Dequantize back to float for comparison + ref_data[i] = (requantized_int - output_zp) * output_scale; + } +} + +int main(int argc, char* argv[]) { + set_debugging(false); + set_print_output(false); +#ifdef DEBUG_MODE + set_print_latencies(false); +#else + set_print_latencies(false); +#endif + set_use_gpu_timestamps(true); + + print_performance_header(); + std::cout << "Q8TA Unary (ReLU) Operation Prototyping Framework" << std::endl; + print_separator(); + + ReferenceComputeFunc ref_fn = q8ta_unary_reference_impl; + + auto results = execute_test_cases( +#ifdef DEBUG_MODE + generate_q8ta_unary_easy_cases, +#else + generate_q8ta_unary_test_cases, +#endif + "Q8taUnary", +#ifdef DEBUG_MODE + 0, + 1, +#else + 3, + 10, +#endif + ref_fn); + + return 0; +}