diff --git a/backends/vulkan/custom_ops_lib.py b/backends/vulkan/custom_ops_lib.py
index 3e77b0c0eea..e371338e904 100644
--- a/backends/vulkan/custom_ops_lib.py
+++ b/backends/vulkan/custom_ops_lib.py
@@ -616,6 +616,41 @@ def q8ta_add_impl(
 lib.impl(name, q8ta_add_impl, "CompositeExplicitAutograd")
 q8ta_add_op = getattr(getattr(torch.ops, namespace), name)
 
+########################
+## q8ta_relu ##
+########################
+
+
+def q8ta_relu_impl(
+    input: torch.Tensor,
+    input_scale: float,
+    input_zero_point: int,
+    output_scale: float,
+    output_zero_point: int,
+):
+    # Dequantize input to float
+    dequant = torch.ops.quantized_decomposed.dequantize_per_tensor(
+        input, input_scale, input_zero_point, -128, 127, input.dtype
+    )
+
+    # Apply ReLU
+    result = torch.nn.functional.relu(dequant)
+
+    # Quantize the result back to int8
+    quantized_result = torch.ops.quantized_decomposed.quantize_per_tensor(
+        result, output_scale, output_zero_point, -128, 127, torch.int8
+    )
+
+    return quantized_result
+
+
+name = "q8ta_relu"
+lib.define(
+    f"{name}(Tensor input, float input_scale, int input_zero_point, float output_scale, int output_zero_point) -> Tensor"
+)
+lib.impl(name, q8ta_relu_impl, "CompositeExplicitAutograd")
+q8ta_relu_op = getattr(getattr(torch.ops, namespace), name)
+
 #############################
 ## select_as_symint ##
 #############################
diff --git a/backends/vulkan/op_registry.py b/backends/vulkan/op_registry.py
index 55a92335bc7..721297dea37 100644
--- a/backends/vulkan/op_registry.py
+++ b/backends/vulkan/op_registry.py
@@ -514,7 +514,20 @@ def register_q8ta_add():
 
 
 # =============================================================================
-# Reduce.cpp
+# Q8taUnary.cpp
+# =============================================================================
+
+
+@update_features(exir_ops.edge.et_vk.q8ta_relu.default)
+def register_q8ta_relu():
+    return OpFeatures(
+        inputs_storage=utils.PACKED_INT8_BUFFER,
+        supports_resize=True,
+    )
+
+
+# =============================================================================
+# Reduce.cpp
 # =============================================================================
 
 
diff --git a/backends/vulkan/patterns/BUCK b/backends/vulkan/patterns/BUCK
index a7153b30967..711000f74ca 100644
--- a/backends/vulkan/patterns/BUCK
+++ b/backends/vulkan/patterns/BUCK
@@ -13,6 +13,7 @@ fbcode_target(_kind = runtime.python_library,
         "quantized_linear.py",
         "quantized_convolution.py",
         "quantized_binary.py",
+        "quantized_unary.py",
         "sdpa.py",
         "select_as_symint.py",
     ],
diff --git a/backends/vulkan/patterns/__init__.py b/backends/vulkan/patterns/__init__.py
index 9b875def944..050680b024d 100644
--- a/backends/vulkan/patterns/__init__.py
+++ b/backends/vulkan/patterns/__init__.py
@@ -12,6 +12,8 @@
 
 import executorch.backends.vulkan.patterns.quantized_linear  # noqa
 
+import executorch.backends.vulkan.patterns.quantized_unary  # noqa
+
 import executorch.backends.vulkan.patterns.rope  # noqa
 
 import executorch.backends.vulkan.patterns.sdpa  # noqa
diff --git a/backends/vulkan/patterns/quantized_unary.py b/backends/vulkan/patterns/quantized_unary.py
new file mode 100644
index 00000000000..28dc84b7997
--- /dev/null
+++ b/backends/vulkan/patterns/quantized_unary.py
@@ -0,0 +1,121 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from typing import Optional
+
+import executorch.backends.vulkan.utils as utils
+
+import torch
+
+from executorch.backends.vulkan.patterns.pattern_registry import (
+    PatternMatch,
+    register_pattern_detector,
+    register_pattern_replacement,
+)
+
+from executorch.exir import ExportedProgram
+from executorch.exir.dialects._ops import ops as exir_ops
+
+
+class QuantizedUnaryMatch(PatternMatch):
+    def __init__(self, unary_node: torch.fx.Node) -> None:
+        self.anchor_node = unary_node
+        self.match_found = False
+        self.all_nodes = [self.anchor_node]
+
+        # The unary op takes a single input, which must be a dequantize node
+        if len(unary_node.args) < 1:
+            return
+
+        input_node = unary_node.args[0]
+        assert isinstance(input_node, torch.fx.Node)
+
+        if not utils.is_dequant_node(input_node):
+            return
+
+        self.dequantize_input_node = input_node
+
+        # Extract the quantized input tensor and its quantization parameters
+        self.quantize_input_node = self.dequantize_input_node.args[0]
+        self.input_scales_node = self.dequantize_input_node.args[1]
+        self.input_zeros_node = self.dequantize_input_node.args[2]
+
+        self.all_nodes.append(self.dequantize_input_node)
+
+        # The unary op output must have exactly one user: a quantize node
+        self.output_node = self.anchor_node
+
+        if len(self.output_node.users) != 1:
+            return
+
+        cur_node = list(self.output_node.users)[0]
+
+        if not utils.is_quant_node(cur_node):
+            return
+
+        self.quantize_output_node = cur_node
+        self.output_scales_node = self.quantize_output_node.args[1]
+        self.output_zeros_node = self.quantize_output_node.args[2]
+
+        self.all_nodes.append(self.quantize_output_node)
+
+        self.match_found = True
+
+
+# Unary operation anchor nodes that we support
+unary_anchor_nodes = {
+    exir_ops.edge.aten.relu.default,
+}
+
+
+@register_pattern_detector("quantized_unary")
+def find_quantized_unary_patterns(
+    node: torch.fx.Node,
+) -> Optional[QuantizedUnaryMatch]:
+    if node.target not in unary_anchor_nodes:
+        return None
+
+    matched_pattern = QuantizedUnaryMatch(node)
+    if matched_pattern.match_found:
+        return matched_pattern
+
+    return None
+
+
+##
+## Pattern Replacement
+##
+
+
+@register_pattern_replacement("quantized_unary")
+def make_q8ta_unary_custom_op(
+    ep: ExportedProgram,
+    graph_module: torch.fx.GraphModule,
+    match: QuantizedUnaryMatch,
+):
+    op_target = None
+    if match.anchor_node.target == exir_ops.edge.aten.relu.default:
+        op_target = exir_ops.edge.et_vk.q8ta_relu.default
+    else:
+        raise NotImplementedError(
+            f"Unsupported unary operation: {match.anchor_node.target}"
+        )
+
+    with graph_module.graph.inserting_before(match.output_node):
+        qunary_node = graph_module.graph.create_node(
+            "call_function",
+            op_target,
+            args=(
+                match.quantize_input_node,
+                match.input_scales_node,
+                match.input_zeros_node,
+                match.output_scales_node,
+                match.output_zeros_node,
+            ),
+        )
+
+    qunary_node.meta["val"] = match.output_node.meta["val"]
+    match.quantize_output_node.replace_all_uses_with(qunary_node)
diff --git a/backends/vulkan/runtime/graph/ops/glsl/q8ta_unary.glsl b/backends/vulkan/runtime/graph/ops/glsl/q8ta_unary.glsl
new file mode 100644
index 00000000000..e97d6d47877
--- /dev/null
+++ b/backends/vulkan/runtime/graph/ops/glsl/q8ta_unary.glsl
@@ -0,0 +1,82 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */ + +#version 450 core + +#define PRECISION ${PRECISION} + +${define_active_storage_type("buffer")} + +#define op(X) ${OPERATOR} + +layout(std430) buffer; + +#include "indexing.glslh" +#include "common.glslh" +#include "block_indexing.glslh" +#include "block_int8x4_load.glslh" +#include "block_int8x4_store.glslh" + +// Output buffer: packed int8x4 values +${layout_declare_tensor(B, "w", "t_out", "int", "buffer")} +// Input buffer: packed int8x4 values +${layout_declare_tensor(B, "r", "t_in", "int", "buffer")} + +// Metadata for output and input tensors +${layout_declare_ubo(B, "BufferMetadata", "out_meta")} +${layout_declare_ubo(B, "BufferMetadata", "in_meta")} + +layout(push_constant) uniform restrict Block { + float input_scale; + int input_zp; + float output_inv_scale; + int output_zp; +}; + +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; + +${layout_declare_spec_const(C, "int", "out_layout", "CONTIG_LAYOUT_INT")} +${layout_declare_spec_const(C, "int", "in_layout", "CONTIG_LAYOUT_INT")} +${layout_declare_spec_const(C, "int", "block_config", "0")} + +// Generate loading functions for input buffer +define_load_int8x4_buffer_fns(t_in) + +// Generate storing functions for output buffer +define_store_int8x4_buffer_fns(t_out) + +void main() { + // Buffer storage: use linear dispatch + const uint contig_block_idx = gl_GlobalInvocationID.x; + TensorIndex4D tidx = contiguous_block_idx_to_tensor4d_idx_with_block_config( + out_meta, contig_block_idx, block_config); + + if (out_of_bounds(tidx, out_meta)) { + return; + } + + const int block_outer_dim = get_block_outer_dim(block_config); + + // Load int8x4 block from input + ivec4 in_block = load_int8x4_block_from_t_in( + in_meta, tidx, in_layout, block_outer_dim); + + ivec4 out_block; + + for (int row = 0; row < 4; row++) { + vec4 in_texel = unpack_and_dequantize( + in_block[row], input_scale, input_zp); + + vec4 out_texel = op(in_texel); + out_block[row] = quantize_and_pack(out_texel, output_inv_scale, output_zp); + } + + // Store to output buffer + store_int8x4_block_to_t_out( + out_meta, tidx, out_layout, block_outer_dim, out_block); +} diff --git a/backends/vulkan/runtime/graph/ops/glsl/q8ta_unary.yaml b/backends/vulkan/runtime/graph/ops/glsl/q8ta_unary.yaml new file mode 100644 index 00000000000..257f6a44205 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/q8ta_unary.yaml @@ -0,0 +1,12 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +q8ta_unary: + parameter_names_with_default_values: + OPERATOR: X + shader_variants: + - NAME: q8ta_relu_buffer + OPERATOR: max(X, vec4(0.0)) diff --git a/backends/vulkan/runtime/graph/ops/impl/Q8taUnary.cpp b/backends/vulkan/runtime/graph/ops/impl/Q8taUnary.cpp new file mode 100644 index 00000000000..f8b606f3dfa --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/impl/Q8taUnary.cpp @@ -0,0 +1,124 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */
+
+#include <executorch/backends/vulkan/runtime/graph/ops/OperatorRegistry.h>
+
+#include <executorch/backends/vulkan/runtime/graph/ops/impl/Q8taUnary.h>
+#include <executorch/backends/vulkan/runtime/graph/ops/utils/ShaderNameUtils.h>
+
+namespace vkcompute {
+
+void resize_q8ta_unary_node(
+    ComputeGraph* graph,
+    const std::vector<ArgGroup>& args,
+    const std::vector<ValueRef>& extra_args) {
+  (void)extra_args;
+  const ValueRef out = args.at(0).refs.at(0);
+  const ValueRef self = args.at(1).refs.at(0);
+  graph->virtual_resize(out, graph->sizes_of(self));
+}
+
+//
+// Dispatch nodes
+//
+
+void add_q8ta_unary_node(
+    ComputeGraph& graph,
+    const ValueRef packed_int8_input,
+    const ValueRef input_scale,
+    const ValueRef input_zp,
+    const ValueRef output_scale,
+    const ValueRef output_zp,
+    const ValueRef packed_int8_output,
+    const std::string& op_name) {
+  const api::PackedDimInfo& output_info =
+      graph.packed_dim_info_of(packed_int8_output);
+  const api::PackedDimInfo& input_info =
+      graph.packed_dim_info_of(packed_int8_input);
+
+  VK_CHECK_COND(input_info.packed_dim == output_info.packed_dim);
+  VK_CHECK_COND(
+      input_info.packed_dim_block_size == output_info.packed_dim_block_size);
+
+  float input_scale_val = graph.extract_scalar<float>(input_scale);
+  int32_t input_zp_val = graph.extract_scalar<int32_t>(input_zp);
+
+  float output_inv_scale_val = 1.0f / graph.extract_scalar<float>(output_scale);
+  int32_t output_zp_val = graph.extract_scalar<int32_t>(output_zp);
+
+  std::string kernel_name = "q8ta_" + op_name;
+  add_storage_type_suffix(
+      kernel_name, graph.storage_type_of(packed_int8_output));
+
+  vkapi::ParamsBindList param_buffers;
+  param_buffers.append(graph.buffer_meta_ubo(packed_int8_output));
+  param_buffers.append(graph.buffer_meta_ubo(packed_int8_input));
+
+  std::vector<PushConstantDataInfo> push_constants = {
+      PushConstantDataInfo(&input_scale_val, sizeof(input_scale_val)),
+      PushConstantDataInfo(&input_zp_val, sizeof(input_zp_val)),
+      PushConstantDataInfo(&output_inv_scale_val, sizeof(output_inv_scale_val)),
+      PushConstantDataInfo(&output_zp_val, sizeof(output_zp_val)),
+  };
+
+  const BlockConfig block_config =
+      create_block_config_for_tensor(graph, packed_int8_output);
+
+  const ValueRef block_config_ref =
+      static_cast<ValueRef>(block_config.as_packed_int());
+
+  graph.execute_nodes().emplace_back(new DynamicDispatchNode(
+      graph,
+      VK_KERNEL_FROM_STR(kernel_name),
+      pick_linear_global_wg_with_block_config,
+      pick_square_local_wg_with_block_config,
+      // Inputs and Outputs
+      {{packed_int8_output, vkapi::kWrite}, {packed_int8_input, vkapi::kRead}},
+      // Shader params buffers
+      param_buffers,
+      // Push Constants
+      push_constants,
+      // Specialization Constants
+      {graph.hashed_layout_of(packed_int8_output),
+       graph.hashed_layout_of(packed_int8_input),
+       block_config.as_packed_int()},
+      // Resize args
+      {block_config_ref},
+      // Resizing Logic
+      resize_q8ta_unary_node));
+}
+
+//
+// High level operator impl
+//
+
+void q8ta_relu(ComputeGraph& graph, const std::vector<ValueRef>& args) {
+  int32_t idx = 0;
+  const ValueRef packed_int8_input = args.at(idx++);
+  const ValueRef input_scale = args.at(idx++);
+  const ValueRef input_zp = args.at(idx++);
+  const ValueRef output_scale = args.at(idx++);
+  const ValueRef output_zp = args.at(idx++);
+  const ValueRef packed_int8_output = args.at(idx++);
+
+  add_q8ta_unary_node(
+      graph,
+      packed_int8_input,
+      input_scale,
+      input_zp,
+      output_scale,
+      output_zp,
+      packed_int8_output,
+      "relu");
+}
+
+REGISTER_OPERATORS {
+  VK_REGISTER_OP(et_vk.q8ta_relu.default, q8ta_relu);
+}
+
+} // namespace vkcompute
diff --git a/backends/vulkan/runtime/graph/ops/impl/Q8taUnary.h b/backends/vulkan/runtime/graph/ops/impl/Q8taUnary.h
new file mode 100644
index 00000000000..2b68fa53c22
--- /dev/null
+++ b/backends/vulkan/runtime/graph/ops/impl/Q8taUnary.h
@@ -0,0 +1,29 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#pragma once
+
+#include <executorch/backends/vulkan/runtime/graph/ComputeGraph.h>
+
+namespace vkcompute {
+
+//
+// Unary operations for int8x4 tensors
+//
+
+void add_q8ta_unary_node(
+    ComputeGraph& graph,
+    const ValueRef packed_int8_input,
+    const ValueRef input_scale,
+    const ValueRef input_zp,
+    const ValueRef output_scale,
+    const ValueRef output_zp,
+    const ValueRef packed_int8_output,
+    const std::string& op_name);
+
+} // namespace vkcompute
diff --git a/backends/vulkan/test/custom_ops/impl/TestQ8taUnary.cpp b/backends/vulkan/test/custom_ops/impl/TestQ8taUnary.cpp
new file mode 100644
index 00000000000..6212216686f
--- /dev/null
+++ b/backends/vulkan/test/custom_ops/impl/TestQ8taUnary.cpp
@@ -0,0 +1,70 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <executorch/backends/vulkan/runtime/graph/ops/OperatorRegistry.h>
+
+#include <executorch/backends/vulkan/runtime/graph/ops/impl/Q8taQuantizeDequantize.h>
+#include <executorch/backends/vulkan/runtime/graph/ops/impl/Q8taUnary.h>
+
+namespace vkcompute {
+
+void q8ta_unary_test(ComputeGraph& graph, const std::vector<ValueRef>& args) {
+  int32_t idx = 0;
+  const ValueRef fp_input = args.at(idx++);
+  const ValueRef input_scale = args.at(idx++);
+  const ValueRef input_zp = args.at(idx++);
+  const ValueRef output_scale = args.at(idx++);
+  const ValueRef output_zp = args.at(idx++);
+  const ValueRef quant_layout_int = args.at(idx++);
+  const ValueRef fp_output = args.at(idx++);
+
+  int32_t layout_value = graph.extract_scalar<int32_t>(quant_layout_int);
+  utils::GPUMemoryLayout quant_layout =
+      static_cast<utils::GPUMemoryLayout>(layout_value);
+
+  // Create temporary tensor for quantized input
+  TmpTensor packed_int8_input(
+      &graph,
+      graph.sizes_of(fp_input),
+      vkapi::kInt8x4,
+      utils::kBuffer,
+      quant_layout);
+
+  // Create temporary tensor for quantized output
+  TmpTensor packed_int8_output(
+      &graph,
+      graph.sizes_of(fp_input),
+      vkapi::kInt8x4,
+      utils::kBuffer,
+      quant_layout);
+
+  // Quantize: FP -> int8x4
+  add_q8ta_quantize_node(
+      graph, fp_input, input_scale, input_zp, packed_int8_input);
+
+  // Unary op: int8x4 -> int8x4
+  add_q8ta_unary_node(
+      graph,
+      packed_int8_input,
+      input_scale,
+      input_zp,
+      output_scale,
+      output_zp,
+      packed_int8_output,
+      "relu");
+
+  // Dequantize: int8x4 -> FP
+  add_q8ta_dequantize_node(
+      graph, packed_int8_output, output_scale, output_zp, fp_output);
+}
+
+REGISTER_OPERATORS {
+  VK_REGISTER_OP(test_etvk.q8ta_unary_test.default, q8ta_unary_test);
+}
+
+} // namespace vkcompute
diff --git a/backends/vulkan/test/custom_ops/test_q8ta_unary.cpp b/backends/vulkan/test/custom_ops/test_q8ta_unary.cpp
new file mode 100644
index 00000000000..bc184c6c182
--- /dev/null
+++ b/backends/vulkan/test/custom_ops/test_q8ta_unary.cpp
@@ -0,0 +1,307 @@
+// Copyright (c) Meta Platforms, Inc. and affiliates.
+// All rights reserved.
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#include <cmath>
+#include <cstdint>
+#include <iostream>
+#include <stdexcept>
+#include <string>
+#include <vector>
+#include "utils.h"
+
+#include <executorch/backends/vulkan/runtime/api/api.h>
+
+// #define DEBUG_MODE
+
+using namespace executorch::vulkan::prototyping;
+using namespace vkcompute;
+
+static constexpr int64_t kRefDimSizeLimit = 512;
+
+struct Q8taUnaryConfig {
+  std::vector<int64_t> shape;
+  std::string test_case_name = "placeholder";
+  std::string op_name = "q8ta_unary_test";
+};
+
+TestCase create_test_case_from_config(
+    const Q8taUnaryConfig& config,
+    utils::StorageType storage_type,
+    vkapi::ScalarType input_dtype,
+    utils::GPUMemoryLayout fp_memory_layout,
+    utils::GPUMemoryLayout quant_layout) {
+  TestCase test_case;
+
+  std::string shape_str = shape_string(config.shape);
+  std::string test_name = config.test_case_name + " I=" + shape_str + " " +
+      repr_str(storage_type, fp_memory_layout) + "->" +
+      repr_str(utils::kBuffer, quant_layout);
+  test_case.set_name(test_name);
+
+  std::string operator_name = "test_etvk." + config.op_name + ".default";
+  test_case.set_operator_name(operator_name);
+
+  // Input tensor (float)
+  ValueSpec input_tensor(
+      config.shape,
+      input_dtype,
+      storage_type,
+      fp_memory_layout,
+      DataGenType::RANDOM);
+
+  float scale_val = 0.007112;
+  ValueSpec input_scale(scale_val);
+
+  int32_t zero_point_val = 0;
+  ValueSpec input_zero_point(zero_point_val);
+
+  // For relu, output scale and zero point can differ from input
+  float output_scale_val = 0.007112;
+  ValueSpec output_scale(output_scale_val);
+
+  int32_t output_zp_val = 0;
+  ValueSpec output_zero_point(output_zp_val);
+
+  int32_t layout_int = static_cast<int32_t>(quant_layout);
+  ValueSpec layout_spec(layout_int);
+
+  // Output tensor (float) - same shape as input
+  ValueSpec output_tensor(
+      config.shape,
+      input_dtype,
+      storage_type,
+      fp_memory_layout,
+      DataGenType::ZEROS);
+
+  test_case.add_input_spec(input_tensor);
+  test_case.add_input_spec(input_scale);
+  test_case.add_input_spec(input_zero_point);
+  test_case.add_input_spec(output_scale);
+  test_case.add_input_spec(output_zero_point);
+  test_case.add_input_spec(layout_spec);
+  test_case.add_output_spec(output_tensor);
+
+  test_case.set_abs_tolerance(scale_val + 1e-4);
+
+  test_case.set_shader_filter({
+      "nchw_to",
+      "to_nchw",
+      "q8ta_quantize",
+      "q8ta_dequantize",
+  });
+
+  return test_case;
+}
+
+std::vector<TestCase> generate_q8ta_unary_easy_cases() {
+  std::vector<TestCase> test_cases;
+
+  Q8taUnaryConfig config = {
+      {1, 16, 16, 16},
+      "ACCU",
+  };
+
+  std::vector<utils::GPUMemoryLayout> fp_layouts = {
+      utils::kWidthPacked,
+      utils::kChannelsPacked,
+  };
+
+  std::vector<utils::GPUMemoryLayout> quant_layouts = {
+      utils::kPackedInt8_4W,
+      utils::kPackedInt8_4C,
+      utils::kPackedInt8_4W4C,
+      utils::kPackedInt8_4H4W,
+      utils::kPackedInt8_4C1W,
+  };
+
+  std::vector<utils::StorageType> storage_types = {utils::kBuffer};
+  std::vector<vkapi::ScalarType> float_types = {vkapi::kFloat};
+
+  for (const auto& fp_layout : fp_layouts) {
+    for (const auto& quant_layout : quant_layouts) {
+      for (const auto& storage_type : storage_types) {
+        for (const auto& input_dtype : float_types) {
+          test_cases.push_back(create_test_case_from_config(
+              config, storage_type, input_dtype, fp_layout, quant_layout));
+        }
+      }
+    }
+  }
+
+  return test_cases;
+}
+
+std::vector<TestCase> generate_q8ta_unary_test_cases() {
+  std::vector<TestCase> test_cases;
+
+  std::vector<std::vector<int64_t>> shapes = {
+      {1, 3, 16, 16},
+      {1, 8, 32, 32},
+      {1, 16, 24, 24},
+      {1, 32, 12, 12},
+      {1, 1, 64, 64},
+      {1, 3, 64, 64},
+      {1, 4, 16, 16},
+
+      {1, 8, 20, 20},
+      {1, 16, 14, 14},
+      {1, 8, 28, 28},
+
+      // Odd tensor sizes
+      {1, 3, 15, 15},
+      {1, 13, 31, 31},
+      {1, 17, 23, 23},
+
+      // Larger tensors
+      {1, 64, 128, 128},
+      {1, 32, 64, 64},
+      {1, 128, 56, 56},
+      {1, 128, 128, 128},
+  };
+
+  std::vector<utils::GPUMemoryLayout> fp_layouts = {
+      utils::kWidthPacked,
+      utils::kChannelsPacked,
+  };
+
+  std::vector<utils::GPUMemoryLayout> quant_layouts = {
+      utils::kPackedInt8_4W,
+      utils::kPackedInt8_4C,
+      utils::kPackedInt8_4W4C,
+      utils::kPackedInt8_4H4W,
+      utils::kPackedInt8_4C1W,
+  };
+
+  std::vector<utils::StorageType> storage_types = {utils::kBuffer};
+
+  for (const auto& shape : shapes) {
+    std::string prefix = "ACCU";
+    for (const auto& dim : shape) {
+      if (dim > kRefDimSizeLimit) {
+        prefix = "PERF";
+        break;
+      }
+    }
+
+    for (const auto& fp_layout : fp_layouts) {
+      for (const auto& quant_layout : quant_layouts) {
+        for (const auto& storage_type : storage_types) {
+          Q8taUnaryConfig config;
+          config.shape = shape;
+          config.test_case_name = prefix;
+
+          test_cases.push_back(create_test_case_from_config(
+              config, storage_type, vkapi::kFloat, fp_layout, quant_layout));
+        }
+      }
+    }
+  }
+
+  return test_cases;
+}
+
+// Reference: quantize -> dequantize -> relu -> requantize -> dequantize
+void q8ta_unary_reference_impl(TestCase& test_case) {
+  int32_t idx = 0;
+  const ValueSpec& input_spec = test_case.inputs()[idx++];
+  const ValueSpec& input_scale_spec = test_case.inputs()[idx++];
+  const ValueSpec& input_zp_spec = test_case.inputs()[idx++];
+  const ValueSpec& output_scale_spec = test_case.inputs()[idx++];
+  const ValueSpec& output_zp_spec = test_case.inputs()[idx++];
+  const ValueSpec& layout_spec = test_case.inputs()[idx++];
+  (void)layout_spec;
+
+  ValueSpec& output_spec = test_case.outputs()[0];
+
+  auto input_sizes = input_spec.get_tensor_sizes();
+
+  int64_t num_elements = 1;
+  for (const auto& dim : input_sizes) {
+    num_elements *= dim;
+  }
+
+  for (const auto& dim : input_sizes) {
+    if (dim > kRefDimSizeLimit) {
+      throw std::invalid_argument(
+          "One or more dimensions exceed the allowed limit for reference "
+          "implementation.");
+    }
+  }
+
+  if (input_spec.dtype != vkapi::kFloat) {
+    throw std::invalid_argument("Unsupported dtype");
+  }
+
+  auto& input_data = input_spec.get_float_data();
+
+  float input_scale = input_scale_spec.get_float_value();
+  int32_t input_zp = input_zp_spec.get_int_value();
+  float output_scale = output_scale_spec.get_float_value();
+  int32_t output_zp = output_zp_spec.get_int_value();
+  int32_t quant_min = -128;
+  int32_t quant_max = 127;
+
+  auto& ref_data = output_spec.get_ref_float_data();
+  ref_data.resize(num_elements);
+
+  for (int64_t i = 0; i < num_elements; ++i) {
+    float input_val = input_data[i];
+
+    // Quantize with input scale/zp
+    float quantized_float = std::round(input_val / input_scale) + input_zp;
+    quantized_float = std::max(quantized_float, static_cast<float>(quant_min));
+    quantized_float = std::min(quantized_float, static_cast<float>(quant_max));
+    int32_t quantized_int = static_cast<int32_t>(quantized_float);
+
+    // Dequantize to float
+    float dequantized = (quantized_int - input_zp) * input_scale;
+
+    // Apply ReLU
+    float activated = std::max(dequantized, 0.0f);
+
+    // Requantize with output scale/zp
+    float requantized_float = std::round(activated / output_scale) + output_zp;
+    requantized_float =
+        std::max(requantized_float, static_cast<float>(quant_min));
+    requantized_float =
+        std::min(requantized_float, static_cast<float>(quant_max));
+    int32_t requantized_int = static_cast<int32_t>(requantized_float);
+
+    // Dequantize back to float for comparison
+    ref_data[i] = (requantized_int - output_zp) * output_scale;
+  }
+}
+
+int main(int argc, char* argv[]) {
+  set_debugging(false);
+  set_print_output(false);
+  set_print_latencies(false);
+  set_use_gpu_timestamps(true);
+
+  print_performance_header();
+  std::cout << "Q8TA Unary (ReLU) Operation Prototyping Framework" << std::endl;
+  print_separator();
+
+  ReferenceComputeFunc ref_fn = q8ta_unary_reference_impl;
+
+  auto results = execute_test_cases(
+#ifdef DEBUG_MODE
+      generate_q8ta_unary_easy_cases,
+#else
+      generate_q8ta_unary_test_cases,
+#endif
+      "Q8taUnary",
+#ifdef DEBUG_MODE
+      0,
+      1,
+#else
+      3,
+      10,
+#endif
+      ref_fn);
+
+  return 0;
+}
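
For reference, a minimal eager-mode sketch of the numerics the new op should reproduce: `et_vk.q8ta_relu` fuses the dequantize -> relu -> quantize chain that `find_quantized_unary_patterns` matches, so calling it on an int8 tensor should agree exactly with the unfused subgraph. This assumes the two registration imports below are sufficient (`custom_ops_lib` registers `et_vk.q8ta_relu` via the `CompositeExplicitAutograd` impl above; `torch.ao.quantization.fx._decomposed` registers the `quantized_decomposed` ops), and the shape and quantization parameters are illustrative.

```python
# Sketch: check et_vk.q8ta_relu against the unfused subgraph it replaces.
import torch
import torch.ao.quantization.fx._decomposed  # noqa: F401 (registers quantized_decomposed ops)
import executorch.backends.vulkan.custom_ops_lib  # noqa: F401 (registers et_vk ops)

x = torch.randn(1, 8, 16, 16)
scale, zp = 0.007112, 0  # illustrative per-tensor affine parameters

# Quantize the float input to int8, mirroring the pattern's input quantize node.
x_q8 = torch.ops.quantized_decomposed.quantize_per_tensor(
    x, scale, zp, -128, 127, torch.int8
)

# Fused op: dequantize -> ReLU -> requantize in a single call.
y_q8 = torch.ops.et_vk.q8ta_relu(x_q8, scale, zp, scale, zp)

# Unfused reference: the exact subgraph the pattern detector matches.
ref = torch.ops.quantized_decomposed.quantize_per_tensor(
    torch.relu(
        torch.ops.quantized_decomposed.dequantize_per_tensor(
            x_q8, scale, zp, -128, 127, torch.int8
        )
    ),
    scale, zp, -128, 127, torch.int8,
)
assert torch.equal(y_q8, ref)
```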