From 4dd18548647d4b6adfdcfba97eec0b7380fd544b Mon Sep 17 00:00:00 2001
From: ssjia
Date: Tue, 17 Feb 2026 12:19:41 -0800
Subject: [PATCH] [ET-VK][qconv] Add apply_relu support to q8ta conv operators
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

The quantized convolution pattern detector correctly identifies ReLU nodes
between conv output and the output quantize node, but the pattern replacement
did not pass this information to the fused q8ta operator. When the pattern
replaced `dequant → conv → relu → quant` with `q8ta_conv2d`, the relu node was
removed from the graph but its effect was not preserved. This silently removed
all conv-relu non-linearity from int8 quantized models.
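
For illustration, a minimal fp32 sketch of the numerical effect (plain aten
ops rather than the exact quantized_decomposed graph; shapes are arbitrary):

    import torch

    x = torch.randn(1, 3, 8, 8)
    w = torch.randn(4, 3, 3, 3)
    b = torch.randn(4)

    y = torch.nn.functional.conv2d(x, w, b, stride=1, padding=1)
    y_relu = torch.relu(y)

    # Before this change the fused op computed y, while the original graph
    # computed y_relu; the two differ wherever the conv output is negative.
    print((y - y_relu).abs().max())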

Add an `activation` parameter throughout the full pipeline:
- Custom op schemas and reference implementations (custom_ops_lib.py)
- Pattern replacement (quantized_convolution.py)
- C++ dispatch logic extracts the activation type and passes it as a spec
  constant (Q8taConv2d.cpp, Q8taConv2dDW.cpp, Q8taConv2dPW.cpp,
  Q8taConv2dIm2Col.cpp)
- GLSL shaders apply a conditional max(value, 0) after dequantization and
  before requantization (q8ta_conv2d.glsl, q8ta_conv2d_dw.glsl,
  q8ta_conv2d_pw.glsl)
- Test operator wrappers updated with proper legacy path handling
  (TestQ8taConv2d.cpp)

Differential Revision: [D93511632](https://our.internmc.facebook.com/intern/diff/D93511632/)

[ghstack-poisoned]
---
 backends/vulkan/custom_ops_lib.py                  |  17 +-
 .../vulkan/patterns/quantized_convolution.py       |   1 +
 .../runtime/graph/ops/glsl/q8ta_conv2d.glsl        |   8 +
 .../graph/ops/glsl/q8ta_conv2d_dw.glsl             |   8 +
 .../graph/ops/glsl/q8ta_conv2d_pw.glsl             |   9 +
 .../runtime/graph/ops/impl/Q8taConv2d.cpp          |  18 +-
 .../runtime/graph/ops/impl/Q8taConv2d.h            |  10 +
 .../runtime/graph/ops/impl/Q8taConv2dDW.cpp        |   9 +-
 .../graph/ops/impl/Q8taConv2dIm2Col.cpp            |   5 +
 .../runtime/graph/ops/impl/Q8taConv2dPW.cpp        |   9 +-
 .../graph/ops/impl/QuantizedConvolution.cpp        |   1 +
 .../test/custom_ops/impl/TestQ8taConv2d.cpp        | 177 +++++++++++-------
 .../test/custom_ops/test_q8ta_conv2d.cpp           |   6 +
 .../test/custom_ops/test_q8ta_conv2d_dw.cpp        |   4 +
 .../test/custom_ops/test_q8ta_conv2d_pw.cpp        |   6 +
 15 files changed, 217 insertions(+), 71 deletions(-)

diff --git a/backends/vulkan/custom_ops_lib.py b/backends/vulkan/custom_ops_lib.py
index f2e4482c9b9..3e77b0c0eea 100644
--- a/backends/vulkan/custom_ops_lib.py
+++ b/backends/vulkan/custom_ops_lib.py
@@ -376,6 +376,7 @@ def q8ta_conv2d(
     padding: list,
     dilation: list,
     groups: int,
+    activation: str,
 ):
     x = torch.ops.quantized_decomposed.dequantize_per_tensor(
         x, input_scale, input_zero_point, -128, 127, x.dtype
@@ -418,6 +419,9 @@ def q8ta_conv2d(
         x, weights, bias, stride, padding, dilation, groups
     )

+    if activation == "relu":
+        out = torch.nn.functional.relu(out)
+
     out = torch.ops.quantized_decomposed.quantize_per_tensor(
         out, output_scale, output_zero_point, -128, 127, torch.int8
     )
@@ -442,7 +446,8 @@ def q8ta_conv2d(
         SymInt[] stride,
         SymInt[] padding,
         SymInt[] dilation,
-        SymInt groups) -> Tensor
+        SymInt groups,
+        str activation) -> Tensor
     """
     )
     lib.impl(name, q8ta_conv2d, "CompositeExplicitAutograd")
@@ -466,7 +471,8 @@ def q8ta_conv2d(
         SymInt[] stride,
         SymInt[] padding,
         SymInt[] dilation,
-        SymInt groups) -> Tensor
+        SymInt groups,
+        str activation) -> Tensor
     """
     )
     lib.impl(name, q8ta_conv2d, "CompositeExplicitAutograd")
@@ -488,6 +494,7 @@ def q8ta_conv2d_dw(
     padding: list,
     dilation: list,
     groups: int,
+    activation: str,
 ):
     x = torch.ops.quantized_decomposed.dequantize_per_tensor(
         x, input_scale, input_zero_point, -128, 127, x.dtype
@@ -514,6 +521,9 @@ def q8ta_conv2d_dw(
         x, weights, bias, stride, padding, dilation, groups
     )

+    if activation == "relu":
+        out = torch.nn.functional.relu(out)
+
     out = torch.ops.quantized_decomposed.quantize_per_tensor(
         out, output_scale, output_zero_point, -128, 127, torch.int8
     )
@@ -538,7 +548,8 @@ def q8ta_conv2d_dw(
         SymInt[] stride,
         SymInt[] padding,
         SymInt[] dilation,
-        SymInt groups) -> Tensor
+        SymInt groups,
+        str activation) -> Tensor
     """
     )
     lib.impl(name, q8ta_conv2d_dw, "CompositeExplicitAutograd")

diff --git a/backends/vulkan/patterns/quantized_convolution.py b/backends/vulkan/patterns/quantized_convolution.py
index 9a6fb69bf87..12ebbd1a382 100644
--- a/backends/vulkan/patterns/quantized_convolution.py
+++ b/backends/vulkan/patterns/quantized_convolution.py
@@ -281,6 +281,7 @@ def make_q8ta_conv2d_custom_op(
             match.padding,
             match.dilation,
             match.groups,
+            "relu" if match.relu_node is not None else "none",
         ),
     )

diff --git a/backends/vulkan/runtime/graph/ops/glsl/q8ta_conv2d.glsl b/backends/vulkan/runtime/graph/ops/glsl/q8ta_conv2d.glsl
index 623de3a5d9a..d693acbab3f 100644
--- a/backends/vulkan/runtime/graph/ops/glsl/q8ta_conv2d.glsl
+++ b/backends/vulkan/runtime/graph/ops/glsl/q8ta_conv2d.glsl
@@ -47,6 +47,7 @@ layout(push_constant) uniform restrict Block {
 layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;

 ${layout_declare_spec_const(C, "int", "apply_bias", "1")}
+${layout_declare_spec_const(C, "int", "activation_type", "0")}

 // Layout specialization constants
 ${layout_declare_spec_const(C, "int", "inp_layout", "CONTIG_LAYOUT_INT")}
@@ -220,6 +221,13 @@ void main() {
     }
   }

+  // Apply ReLU if enabled
+  if (activation_type > 0) {
+    [[unroll]] for (int subtile_w = 0; subtile_w < 4; ++subtile_w) {
+      facc[subtile_w] = max(facc[subtile_w], vec4(0.0));
+    }
+  }
+
   // Compute base output texel index (for subtile_w=0)
   const int base_outp_texel_idx = tensor4d_idx_to_texel_idx(outp, outp_tidx, outp_layout);
   const int out_w_stride = int(outp.strides[0][0]);

diff --git a/backends/vulkan/runtime/graph/ops/glsl/q8ta_conv2d_dw.glsl b/backends/vulkan/runtime/graph/ops/glsl/q8ta_conv2d_dw.glsl
index e6be92e7ba1..7f4d03887df 100644
--- a/backends/vulkan/runtime/graph/ops/glsl/q8ta_conv2d_dw.glsl
+++ b/backends/vulkan/runtime/graph/ops/glsl/q8ta_conv2d_dw.glsl
@@ -44,6 +44,7 @@ layout(push_constant) uniform restrict Block {
 layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;

 ${layout_declare_spec_const(C, "int", "apply_bias", "1")}
+${layout_declare_spec_const(C, "int", "activation_type", "0")}

 // Layout specialization constants
 ${layout_declare_spec_const(C, "int", "inp_layout", "CONTIG_LAYOUT_INT")}
@@ -197,6 +198,13 @@ void main() {
     }
   }

+  // Apply ReLU if enabled
+  if (activation_type > 0) {
+    [[unroll]] for (int subtile_w = 0; subtile_w < 4; ++subtile_w) {
+      facc[subtile_w] = max(facc[subtile_w], vec4(0.0));
+    }
+  }
+
   // Compute base output texel index (for subtile_w=0)
   const int base_outp_texel_idx = tensor4d_idx_to_texel_idx(outp, outp_tidx, outp_layout);
   const int out_w_stride = int(outp.strides[0][0]);

diff --git a/backends/vulkan/runtime/graph/ops/glsl/q8ta_conv2d_pw.glsl b/backends/vulkan/runtime/graph/ops/glsl/q8ta_conv2d_pw.glsl
index e0963dfcf48..ec41d933114 100644
--- a/backends/vulkan/runtime/graph/ops/glsl/q8ta_conv2d_pw.glsl
+++ b/backends/vulkan/runtime/graph/ops/glsl/q8ta_conv2d_pw.glsl
@@ -57,6 +57,7 @@ layout(push_constant) uniform restrict Block {
 layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;

 ${layout_declare_spec_const(C, "int", "apply_bias", "1")}
+${layout_declare_spec_const(C, "int", "activation_type", "0")}
 ${layout_declare_spec_const(C, "int", "conv2d_params_K4_per_group", "1")}

 // Layout specialization constants
@@ -197,6 +198,10 @@ void main() {
           fma(vec4(accum_adjusted), vec4(weight_scales[n4]) * input_scale, vec4(bias[n4]));

+      // Apply ReLU if enabled
+      if (activation_type > 0) {
+        float_out_texel = max(float_out_texel, vec4(0.0));
+      }
       // Requantize to int8
       float_out_texel = round(float_out_texel * output_inv_scale) + output_zp;

@@ -216,6 +221,10 @@ void main() {
           input_zp_vec * weight_sums[n4] + out_accum[m][n4];
       vec4 float_out_texel =
           vec4(accum_adjusted) * vec4(weight_scales[n4] * input_scale);
+      // Apply ReLU if enabled
+      if (activation_type > 0) {
+        float_out_texel = max(float_out_texel, vec4(0.0));
+      }
       // Requantize to int8
       float_out_texel = round(float_out_texel * output_inv_scale) + output_zp;

diff --git a/backends/vulkan/runtime/graph/ops/impl/Q8taConv2d.cpp b/backends/vulkan/runtime/graph/ops/impl/Q8taConv2d.cpp
index 4f047d414f8..33b7005a845 100644
--- a/backends/vulkan/runtime/graph/ops/impl/Q8taConv2d.cpp
+++ b/backends/vulkan/runtime/graph/ops/impl/Q8taConv2d.cpp
@@ -17,6 +17,15 @@

 namespace vkcompute {

+ActivationType activation_type_from_string(const std::string& activation) {
+  if (activation == "none") {
+    return ActivationType::kNone;
+  } else if (activation == "relu") {
+    return ActivationType::kRelu;
+  }
+  VK_THROW("Unknown activation type: ", activation);
+}
+
 bool q8ta_conv2d_check_packed_dim_info(const api::PackedDimInfo& info) {
   return info.packed_dim == WHCN::kChannelsDim &&
       info.packed_dim_block_size == 4 &&
@@ -231,6 +240,7 @@ void add_q8ta_conv2d_node(
     const ValueRef padding,
     const ValueRef dilation,
     const ValueRef groups,
+    const uint32_t activation_type,
     const ValueRef packed_int8_output) {
   (void)packed_int8_input_im2col; // Not used in general shader

@@ -288,9 +298,10 @@ void add_q8ta_conv2d_node(
       graph.buffer_meta_ubo(packed_int8_input),
       graph.create_params_buffer(conv_params)};

-  // Build spec constants: apply_bias + layout constants
+  // Build spec constants: apply_bias, activation_type + layout constants
   vkapi::SpecVarList spec_constants = {
       apply_bias,
+      activation_type,
       // Layout specialization constants
       graph.hashed_layout_of(packed_int8_input),
       graph.hashed_layout_of(packed_int8_output),
@@ -341,8 +352,12 @@ void q8ta_conv2d_general(
   const ValueRef padding = args.at(idx++);
   const ValueRef dilation = args.at(idx++);
   const ValueRef groups = args.at(idx++);
+  const ValueRef activation = args.at(idx++);
   const ValueRef packed_int8_output = args.at(idx++);

+  uint32_t activation_type_val = static_cast<uint32_t>(
+      activation_type_from_string(graph.extract_string(activation)));
+
   QuantizationConfig weight_quant_config(8, kPerChannel, {});

   // Prepack weight using the conv2d weight packing for the general shader
@@ -397,6 +412,7 @@ void q8ta_conv2d_general(
       padding,
       dilation,
       groups,
+      activation_type_val,
       packed_int8_output);
 }

diff --git a/backends/vulkan/runtime/graph/ops/impl/Q8taConv2d.h b/backends/vulkan/runtime/graph/ops/impl/Q8taConv2d.h
index 9686c873c1b..2779a7445a8 100644
--- a/backends/vulkan/runtime/graph/ops/impl/Q8taConv2d.h
+++ b/backends/vulkan/runtime/graph/ops/impl/Q8taConv2d.h
@@ -13,6 +13,13 @@

 namespace vkcompute {

+enum class ActivationType : uint32_t {
+  kNone = 0,
+  kRelu = 1,
+};
+
+ActivationType activation_type_from_string(const std::string& activation);
+
 bool q8ta_conv2d_check_packed_dim_info(const api::PackedDimInfo& info);

 bool q8ta_conv2d_check_4w4c_packed_dim_info(const api::PackedDimInfo& info);
@@ -58,6 +65,7 @@ void add_q8ta_conv2d_dw_node(
     const ValueRef padding,
     const ValueRef dilation,
     const ValueRef groups,
+    const uint32_t activation_type,
     const ValueRef packed_int8_output);

 void add_conv2d_dw_q8ta_q8csw_q8to_4w4c_node(
@@ -97,6 +105,7 @@ void add_q8ta_conv2d_node(
     const ValueRef padding,
     const ValueRef dilation,
     const ValueRef groups,
+    const uint32_t activation_type,
     const ValueRef packed_int8_output);

 void add_q8ta_conv2d_pw_node(
@@ -111,6 +120,7 @@ void add_q8ta_conv2d_pw_node(
     const ValueRef output_zp,
     const ValueRef bias_data,
     const ValueRef packed_bias,
+    const uint32_t activation_type,
     const ValueRef packed_int8_output);

 void q8ta_conv2d_im2col(ComputeGraph& graph, const std::vector<ValueRef>& args);

diff --git a/backends/vulkan/runtime/graph/ops/impl/Q8taConv2dDW.cpp b/backends/vulkan/runtime/graph/ops/impl/Q8taConv2dDW.cpp
index d12bbc0574a..e690ff435a8 100644
--- a/backends/vulkan/runtime/graph/ops/impl/Q8taConv2dDW.cpp
+++ b/backends/vulkan/runtime/graph/ops/impl/Q8taConv2dDW.cpp
@@ -281,6 +281,7 @@ void add_q8ta_conv2d_dw_node(
     const ValueRef padding,
     const ValueRef dilation,
     const ValueRef groups,
+    const uint32_t activation_type,
     const ValueRef packed_int8_output) {
   Conv2DParams conv_params = create_conv2d_params(
       graph,
@@ -334,9 +335,10 @@ void add_q8ta_conv2d_dw_node(
       graph.buffer_meta_ubo(packed_int8_input),
       graph.create_params_buffer(conv_params)};

-  // Build spec constants: apply_bias + layout constants
+  // Build spec constants: apply_bias, activation_type + layout constants
   vkapi::SpecVarList spec_constants = {
       apply_bias,
+      activation_type,
       // Layout specialization constants
       graph.hashed_layout_of(packed_int8_input),
       graph.hashed_layout_of(packed_int8_output),
@@ -385,8 +387,12 @@ void q8ta_conv2d_dw(ComputeGraph& graph, const std::vector<ValueRef>& args) {
   const ValueRef padding = args.at(idx++);
   const ValueRef dilation = args.at(idx++);
   const ValueRef groups = args.at(idx++);
+  const ValueRef activation = args.at(idx++);
   const ValueRef packed_int8_output = args.at(idx++);

+  uint32_t activation_type_val = static_cast<uint32_t>(
+      activation_type_from_string(graph.extract_string(activation)));
+
   QuantizationConfig weight_quant_config(8, kPerChannel, {});

   // Prepack weight using depthwise-specific packing
@@ -432,6 +438,7 @@ void q8ta_conv2d_dw(ComputeGraph& graph, const std::vector<ValueRef>& args) {
       padding,
       dilation,
       groups,
+      activation_type_val,
       packed_int8_output);
 }

diff --git a/backends/vulkan/runtime/graph/ops/impl/Q8taConv2dIm2Col.cpp b/backends/vulkan/runtime/graph/ops/impl/Q8taConv2dIm2Col.cpp
index e89ebc92aba..161b5e8fc24 100644
--- a/backends/vulkan/runtime/graph/ops/impl/Q8taConv2dIm2Col.cpp
+++ b/backends/vulkan/runtime/graph/ops/impl/Q8taConv2dIm2Col.cpp
@@ -197,6 +197,7 @@ void q8ta_conv2d_im2col(
   const ValueRef padding = args.at(idx++);
   const ValueRef dilation = args.at(idx++);
   const ValueRef groups = args.at(idx++);
+  const ValueRef activation = args.at(idx++);
   const ValueRef packed_int8_output = args.at(idx++);

   QuantizationConfig weight_quant_config(8, kPerChannel, {});
@@ -225,6 +226,9 @@ void q8ta_conv2d_im2col(
         prepack_standard(graph, bias_data, utils::kBuffer, utils::kWidthPacked);
   }

+  uint32_t activation_type_val = static_cast<uint32_t>(
+      activation_type_from_string(graph.extract_string(activation)));
+
   // Calculate im2col output sizes
   std::vector<int64_t> im2col_sizes = calculate_q8ta_im2col_sizes(
       &graph, packed_int8_input, packed_int8_output, kernel_size, groups);
@@ -265,6 +269,7 @@ void q8ta_conv2d_im2col(
       output_zp,
       bias_data,
       packed_bias,
+      activation_type_val,
       packed_int8_output);
 }

diff --git a/backends/vulkan/runtime/graph/ops/impl/Q8taConv2dPW.cpp b/backends/vulkan/runtime/graph/ops/impl/Q8taConv2dPW.cpp
index fc883eefeef..b72f5b78f53 100644
--- a/backends/vulkan/runtime/graph/ops/impl/Q8taConv2dPW.cpp
+++ b/backends/vulkan/runtime/graph/ops/impl/Q8taConv2dPW.cpp
@@ -199,6 +199,7 @@ void add_q8ta_conv2d_pw_node(
     const ValueRef output_zp,
     const ValueRef bias_data,
     const ValueRef packed_bias,
+    const uint32_t activation_type,
     const ValueRef packed_int8_output) {
   // Validate packed dim info for input and output tensors
   // To maximize performance, the input tensor must be in 4W4C layout
@@ -242,9 +243,10 @@ void add_q8ta_conv2d_pw_node(
       graph.buffer_meta_ubo(packed_int8_output),
       graph.buffer_meta_ubo(packed_int8_input)};

-  // Build spec constants: apply_bias + layout constants
+  // Build spec constants: apply_bias, activation_type + layout constants
   vkapi::SpecVarList spec_constants = {
       apply_bias,
+      activation_type,
       K4_per_group,
       // Layout specialization constants
       graph.hashed_layout_of(packed_int8_output),
@@ -296,8 +298,12 @@ void q8ta_conv2d_pw(ComputeGraph& graph, const std::vector<ValueRef>& args) {
   (void)args.at(idx++); // padding
   (void)args.at(idx++); // dilation
   (void)args.at(idx++); // groups
+  const ValueRef activation_ref = args.at(idx++);
   const ValueRef packed_int8_output = args.at(idx++);

+  uint32_t activation_type_val = static_cast<uint32_t>(
+      activation_type_from_string(graph.extract_string(activation_ref)));
+
   QuantizationConfig weight_quant_config(8, kPerChannel, {});

   // Prepack weight using pointwise-specific packing
@@ -342,6 +348,7 @@ void q8ta_conv2d_pw(ComputeGraph& graph, const std::vector<ValueRef>& args) {
       output_zp,
       bias_data,
       packed_bias,
+      activation_type_val,
       packed_int8_output);
 }

diff --git a/backends/vulkan/runtime/graph/ops/impl/QuantizedConvolution.cpp b/backends/vulkan/runtime/graph/ops/impl/QuantizedConvolution.cpp
index 1bfff6f1342..ebc276ee347 100644
--- a/backends/vulkan/runtime/graph/ops/impl/QuantizedConvolution.cpp
+++ b/backends/vulkan/runtime/graph/ops/impl/QuantizedConvolution.cpp
@@ -894,6 +894,7 @@ void add_conv2d_q8ta_q8csw_q8to_node(
       padding,
       dilation,
       groups,
+      static_cast<uint32_t>(ActivationType::kNone),
       packed_int8_output);
 }

diff --git a/backends/vulkan/test/custom_ops/impl/TestQ8taConv2d.cpp b/backends/vulkan/test/custom_ops/impl/TestQ8taConv2d.cpp
index 4fed7461ce6..679ac33d11b 100644
--- a/backends/vulkan/test/custom_ops/impl/TestQ8taConv2d.cpp
+++ b/backends/vulkan/test/custom_ops/impl/TestQ8taConv2d.cpp
@@ -32,6 +32,7 @@ void test_q8ta_conv2d_dw(
   const ValueRef padding = args.at(idx++);
   const ValueRef dilation = args.at(idx++);
   const ValueRef groups = args.at(idx++);
+  const ValueRef activation = args.at(idx++);
   const ValueRef layout_int = args.at(idx++);
   const ValueRef impl_selector_str = args.at(idx++);
   const ValueRef fp_output = args.at(idx++);
@@ -59,29 +60,43 @@ void test_q8ta_conv2d_dw(
   add_q8ta_quantize_node(
       graph, fp_input, input_scale, input_zp, packed_int8_input);

-  // Build args for conv operator
-  std::vector<ValueRef> conv_args = {
-      packed_int8_input,
-      input_scale,
-      input_zp,
-      weight_data,
-      weight_sums_data,
-      weight_scales_data,
-      output_scale,
-      output_zp,
-      bias_data,
-      kernel_size,
-      stride,
-      padding,
-      dilation,
-      groups,
-      packed_int8_output};
-
   if (impl_selector == "legacy_4w4c") {
-    // Use the general quantized conv2d operator for legacy path
+    // Legacy path does not support activation
+    std::vector<ValueRef> conv_args = {
+        packed_int8_input,
+        input_scale,
+        input_zp,
+        weight_data,
+        weight_sums_data,
+        weight_scales_data,
+        output_scale,
+        output_zp,
+        bias_data,
+        kernel_size,
+        stride,
+        padding,
+        dilation,
+        groups,
+        packed_int8_output};
     VK_GET_OP_FN("et_vk.conv2d_q8ta_q8csw_q8to.default")(graph, conv_args);
   } else {
-    // Use the dedicated depthwise conv2d operator
+    std::vector<ValueRef> conv_args = {
+        packed_int8_input,
+        input_scale,
+        input_zp,
+        weight_data,
+        weight_sums_data,
+        weight_scales_data,
+        output_scale,
+        output_zp,
+        bias_data,
+        kernel_size,
+        stride,
+        padding,
+        dilation,
+        groups,
+        activation,
+        packed_int8_output};
     VK_GET_OP_FN("et_vk.q8ta_conv2d_dw.default")(graph, conv_args);
   }

@@ -106,6 +121,7 @@ void test_q8ta_conv2d(ComputeGraph& graph, const std::vector<ValueRef>& args) {
   const ValueRef padding = args.at(idx++);
   const ValueRef dilation = args.at(idx++);
   const ValueRef groups = args.at(idx++);
+  const ValueRef activation = args.at(idx++);
   const ValueRef layout_int = args.at(idx++);
   const ValueRef impl_selector_str = args.at(idx++);
   const ValueRef fp_output = args.at(idx++);
@@ -133,36 +149,50 @@ void test_q8ta_conv2d(ComputeGraph& graph, const std::vector<ValueRef>& args) {
   add_q8ta_quantize_node(
       graph, fp_input, input_scale, input_zp, packed_int8_input);

-  // Build args for conv operator
-  std::vector<ValueRef> conv_args = {
-      packed_int8_input,
-      input_scale,
-      input_zp,
-      weight_data,
-      weight_sums_data,
-      weight_scales_data,
-      output_scale,
-      output_zp,
-      bias_data,
-      kernel_size,
-      stride,
-      padding,
-      dilation,
-      groups,
-      packed_int8_output};
-
   if (impl_selector == "legacy_4w4c") {
-    // Use the general quantized conv2d operator for legacy path
+    // Legacy path does not support activation
+    std::vector<ValueRef> conv_args = {
+        packed_int8_input,
+        input_scale,
+        input_zp,
+        weight_data,
+        weight_sums_data,
+        weight_scales_data,
+        output_scale,
+        output_zp,
+        bias_data,
+        kernel_size,
+        stride,
+        padding,
+        dilation,
+        groups,
+        packed_int8_output};
     VK_GET_OP_FN("et_vk.conv2d_q8ta_q8csw_q8to.default")(graph, conv_args);
-  } else if (impl_selector == "im2col") {
-    // Use the im2col-based conv2d operator
-    VK_GET_OP_FN("et_vk.q8ta_conv2d_im2col.default")(graph, conv_args);
-  } else if (impl_selector == "general") {
-    // Use the general q8ta_conv2d operator (no im2col dispatch)
-    VK_GET_OP_FN("et_vk.q8ta_conv2d_general.default")(graph, conv_args);
   } else {
-    // Use the new general q8ta_conv2d operator
-    VK_GET_OP_FN("et_vk.q8ta_conv2d.default")(graph, conv_args);
+    std::vector<ValueRef> conv_args = {
+        packed_int8_input,
+        input_scale,
+        input_zp,
+        weight_data,
+        weight_sums_data,
+        weight_scales_data,
+        output_scale,
+        output_zp,
+        bias_data,
+        kernel_size,
+        stride,
+        padding,
+        dilation,
+        groups,
+        activation,
+        packed_int8_output};
+    if (impl_selector == "im2col") {
+      VK_GET_OP_FN("et_vk.q8ta_conv2d_im2col.default")(graph, conv_args);
+    } else if (impl_selector == "general") {
+      VK_GET_OP_FN("et_vk.q8ta_conv2d_general.default")(graph, conv_args);
+    } else {
+      VK_GET_OP_FN("et_vk.q8ta_conv2d.default")(graph, conv_args);
+    }
   }

   // Dequantize packed int8 output to floating point
@@ -188,6 +218,7 @@ void test_q8ta_conv2d_pw(
   const ValueRef padding = args.at(idx++);
   const ValueRef dilation = args.at(idx++);
   const ValueRef groups = args.at(idx++);
+  const ValueRef activation = args.at(idx++);
   const ValueRef layout_int = args.at(idx++);
   const ValueRef impl_selector_str = args.at(idx++);
   const ValueRef fp_output = args.at(idx++);
@@ -219,27 +250,43 @@ void test_q8ta_conv2d_pw(
   add_q8ta_quantize_node(
       graph, fp_input, input_scale, input_zp, packed_int8_input);

-  // Build args for conv operator
-  std::vector<ValueRef> conv_args = {
-      packed_int8_input,
-      input_scale,
-      input_zp,
-      weight_data,
-      weight_sums_data,
-      weight_scales_data,
-      output_scale,
-      output_zp,
-      bias_data,
-      kernel_size,
-      stride,
-      padding,
-      dilation,
-      groups,
-      packed_int8_output};
-
   if (impl_selector == "legacy_4w4c") {
+    // Legacy path does not support activation
+    std::vector<ValueRef> conv_args = {
+        packed_int8_input,
+        input_scale,
+        input_zp,
+        weight_data,
+        weight_sums_data,
+        weight_scales_data,
+        output_scale,
+        output_zp,
+        bias_data,
+        kernel_size,
+        stride,
+        padding,
+        dilation,
+        groups,
+        packed_int8_output};
     VK_GET_OP_FN("et_vk.conv2d_q8ta_q8csw_q8to.default")(graph, conv_args);
   } else {
+    std::vector<ValueRef> conv_args = {
+        packed_int8_input,
+        input_scale,
+        input_zp,
+        weight_data,
+        weight_sums_data,
+        weight_scales_data,
+        output_scale,
+        output_zp,
+        bias_data,
+        kernel_size,
+        stride,
+        padding,
+        dilation,
+        groups,
+        activation,
+        packed_int8_output};
     VK_GET_OP_FN("et_vk.q8ta_conv2d_pw.default")(graph, conv_args);
   }

diff --git a/backends/vulkan/test/custom_ops/test_q8ta_conv2d.cpp b/backends/vulkan/test/custom_ops/test_q8ta_conv2d.cpp
index 17dd7a0fc53..bc95cc724f5 100644
--- a/backends/vulkan/test/custom_ops/test_q8ta_conv2d.cpp
+++ b/backends/vulkan/test/custom_ops/test_q8ta_conv2d.cpp
@@ -178,6 +178,10 @@ static TestCase create_test_case_from_config(
   test_case.add_input_spec(dilation);
   test_case.add_input_spec(groups);

+  // Activation (none = no activation)
+  ValueSpec activation = ValueSpec::make_string("none");
+  test_case.add_input_spec(activation);
+
   // Add memory layout parameter for the quantized tensors
   ValueSpec layout_int(static_cast<int32_t>(int8_memory_layout));
   test_case.add_input_spec(layout_int);
@@ -455,6 +459,8 @@ static void conv2d_q8ta_q8csw_q8to_reference_impl(TestCase& test_case) {
   const ValueSpec& padding_spec = test_case.inputs()[idx++];
   const ValueSpec& dilation_spec = test_case.inputs()[idx++];
   const ValueSpec& groups_spec = test_case.inputs()[idx++];
+  const ValueSpec& activation_spec = test_case.inputs()[idx++];
+  (void)activation_spec; // Not used in reference implementation
   const ValueSpec& layout_spec = test_case.inputs()[idx++];
   (void)layout_spec; // Not used in reference implementation
   const ValueSpec& impl_selector_spec = test_case.inputs()[idx++];

diff --git a/backends/vulkan/test/custom_ops/test_q8ta_conv2d_dw.cpp b/backends/vulkan/test/custom_ops/test_q8ta_conv2d_dw.cpp
index 7ef73d49802..0734e444d57 100644
--- a/backends/vulkan/test/custom_ops/test_q8ta_conv2d_dw.cpp
+++ b/backends/vulkan/test/custom_ops/test_q8ta_conv2d_dw.cpp
@@ -187,6 +187,10 @@ TestCase create_test_case_from_config(
   test_case.add_input_spec(dilation);
   test_case.add_input_spec(groups);

+  // Activation (none = no activation)
+  ValueSpec activation = ValueSpec::make_string("none");
+  test_case.add_input_spec(activation);
+
   // Add memory layout parameter for the quantized tensors
   ValueSpec layout_int(static_cast<int32_t>(int8_memory_layout));
   test_case.add_input_spec(layout_int);

diff --git a/backends/vulkan/test/custom_ops/test_q8ta_conv2d_pw.cpp b/backends/vulkan/test/custom_ops/test_q8ta_conv2d_pw.cpp
index 6ce6671ec84..83b9f92fb3a 100644
--- a/backends/vulkan/test/custom_ops/test_q8ta_conv2d_pw.cpp
+++ b/backends/vulkan/test/custom_ops/test_q8ta_conv2d_pw.cpp
@@ -179,6 +179,10 @@ static TestCase create_test_case_from_config(
   test_case.add_input_spec(dilation);
   test_case.add_input_spec(groups);

+  // Activation (none = no activation)
+  ValueSpec activation = ValueSpec::make_string("none");
+  test_case.add_input_spec(activation);
+
   // Add memory layout parameter for the quantized tensors
   ValueSpec layout_int(static_cast<int32_t>(int8_memory_layout));
   test_case.add_input_spec(layout_int);
@@ -366,6 +370,8 @@ static void conv2d_q8ta_q8csw_q8to_reference_impl(TestCase& test_case) {
   const ValueSpec& padding_spec = test_case.inputs()[idx++];
   const ValueSpec& dilation_spec = test_case.inputs()[idx++];
   const ValueSpec& groups_spec = test_case.inputs()[idx++];
+  const ValueSpec& activation_spec = test_case.inputs()[idx++];
+  (void)activation_spec; // Not used in reference implementation
   const ValueSpec& layout_spec = test_case.inputs()[idx++];
   (void)layout_spec; // Not used in reference implementation
   const ValueSpec& impl_selector_spec = test_case.inputs()[idx++];