diff --git a/backends/vulkan/_passes/tag_memory_meta_pass.py b/backends/vulkan/_passes/tag_memory_meta_pass.py index 00b6c62d5d2..3bdc30feb7c 100644 --- a/backends/vulkan/_passes/tag_memory_meta_pass.py +++ b/backends/vulkan/_passes/tag_memory_meta_pass.py @@ -394,18 +394,11 @@ def constrain_op_out_repset(self, op_repsets: utils.OpRepSets) -> None: op_repsets.try_constrain_with_out_repset(out_respset) def constrain_op_repsets(self, op_repsets: utils.OpRepSets) -> None: - # For most ops, constraining the argument repsets will also contrain the output - # repset due to OpRepSets maintaining synchronization rules. for i in range(len(op_repsets.op_node.args)): if utils.is_tensor_arg_node(op_repsets.op_node.args[i]): self.constrain_op_arg_repset(i, op_repsets) - # However, some operators do not sync input and output representations and also - # define ambiguous repsets for the output tensor(s). In those cases we will need - # to execute additional logic to constrain the output repsets separately from - # the input repsets. - if not op_repsets.sync_primary_io_repr and op_repsets.sync_outs_repr: - self.constrain_op_out_repset(op_repsets) + self.constrain_op_out_repset(op_repsets) def set_op_node_tensor_reprs( self, graph_module: torch.fx.GraphModule, op_node: torch.fx.Node diff --git a/backends/vulkan/custom_ops_lib.py b/backends/vulkan/custom_ops_lib.py index e371338e904..87506f0b773 100644 --- a/backends/vulkan/custom_ops_lib.py +++ b/backends/vulkan/custom_ops_lib.py @@ -356,6 +356,136 @@ def linear_q8ta_q8csw( lib.impl(name, linear_q8ta_q8csw, "CompositeExplicitAutograd") qa_q8csw_linear = getattr(getattr(torch.ops, namespace), name) +################## +## q8ta_linear ## +################## + + +def q8ta_linear( + x: torch.Tensor, + input_scale: float, + input_zero_point: int, + weights: torch.Tensor, + weight_sums: torch.Tensor, + weight_scales: torch.Tensor, + output_scale: float, + output_zero_point: int, + bias: Optional[torch.Tensor] = None, + activation: str = "none", +): + weight_zeros = torch.zeros_like(weight_scales, dtype=torch.int32) + weights = torch.ops.quantized_decomposed.dequantize_per_channel( + weights, + weight_scales, + weight_zeros, + 0, + -127, + 127, + torch.int8, + ) + + x = torch.ops.quantized_decomposed.dequantize_per_tensor( + x, input_scale, input_zero_point, -128, 127, x.dtype + ) + + out = torch.nn.functional.linear(x, weights) + if bias is not None: + out = out + bias[: out.shape[-1]] + + if activation == "relu": + out = torch.nn.functional.relu(out) + + out = torch.ops.quantized_decomposed.quantize_per_tensor( + out, output_scale, output_zero_point, -128, 127, torch.int8 + ) + + return out + + +name = "q8ta_linear" +lib.define( + f""" + {name}( + Tensor x, + float input_scale, + int input_zero_point, + Tensor weights, + Tensor weight_sums, + Tensor weight_scales, + float output_scale, + int output_zero_point, + Tensor? 
bias = None, + str activation = "none") -> Tensor + """ +) +lib.impl(name, q8ta_linear, "CompositeExplicitAutograd") +q8ta_linear_op = getattr(getattr(torch.ops, namespace), name) + +####################### +## q8ta_linear_gemv ## +####################### + + +def q8ta_linear_gemv( + x: torch.Tensor, + input_scale: float, + input_zero_point: int, + weights: torch.Tensor, + weight_sums: torch.Tensor, + weight_scales: torch.Tensor, + output_scale: float, + output_zero_point: int, + bias: Optional[torch.Tensor] = None, + activation: str = "none", +): + weight_zeros = torch.zeros_like(weight_scales, dtype=torch.int32) + weights = torch.ops.quantized_decomposed.dequantize_per_channel( + weights, + weight_scales, + weight_zeros, + 0, + -127, + 127, + torch.int8, + ) + + x = torch.ops.quantized_decomposed.dequantize_per_tensor( + x, input_scale, input_zero_point, -128, 127, x.dtype + ) + + out = torch.nn.functional.linear(x, weights) + if bias is not None: + out = out + bias[: out.shape[-1]] + + if activation == "relu": + out = torch.nn.functional.relu(out) + + out = torch.ops.quantized_decomposed.quantize_per_tensor( + out, output_scale, output_zero_point, -128, 127, torch.int8 + ) + + return out + + +name = "q8ta_linear_gemv" +lib.define( + f""" + {name}( + Tensor x, + float input_scale, + int input_zero_point, + Tensor weights, + Tensor weight_sums, + Tensor weight_scales, + float output_scale, + int output_zero_point, + Tensor? bias = None, + str activation = "none") -> Tensor + """ +) +lib.impl(name, q8ta_linear_gemv, "CompositeExplicitAutograd") +q8ta_linear_gemv_op = getattr(getattr(torch.ops, namespace), name) + ################### ## q8ta_conv2d_* ## ################### diff --git a/backends/vulkan/op_registry.py b/backends/vulkan/op_registry.py index 853ba5d3777..855df9d2e74 100644 --- a/backends/vulkan/op_registry.py +++ b/backends/vulkan/op_registry.py @@ -830,6 +830,57 @@ def register_q8ta_conv2d_ops(): ) +# ============================================================================= +# Q8taLinear.cpp +# ============================================================================= + + +@update_features(exir_ops.edge.et_vk.q8ta_linear.default) +def register_q8ta_linear(): + return OpFeatures( + inputs_storage=[ + utils.PACKED_INT8_4H4W_BUFFER, # input + utils.NO_STORAGE, # input_scale (non tensor) + utils.NO_STORAGE, # input_zero_point (non tensor) + utils.NO_STORAGE, # weight (prepacked) + utils.NO_STORAGE, # weight_sums (prepacked) + utils.NO_STORAGE, # weight_scales (prepacked) + utils.NO_STORAGE, # output_scale (non tensor) + utils.NO_STORAGE, # output_zero_point (non tensor) + utils.NO_STORAGE, # bias (prepacked) + utils.NO_STORAGE, # activation (non tensor) + ], + outputs_storage=[ + utils.PACKED_INT8_4H4W_BUFFER, + ], + supports_resize=False, + supports_prepacking=True, + ) + + +@update_features(exir_ops.edge.et_vk.q8ta_linear_gemv.default) +def register_q8ta_linear_gemv(): + return OpFeatures( + inputs_storage=[ + utils.PACKED_INT8_4W_BUFFER, # input + utils.NO_STORAGE, # input_scale (non tensor) + utils.NO_STORAGE, # input_zero_point (non tensor) + utils.NO_STORAGE, # weight (prepacked) + utils.NO_STORAGE, # weight_sums (prepacked) + utils.NO_STORAGE, # weight_scales (prepacked) + utils.NO_STORAGE, # output_scale (non tensor) + utils.NO_STORAGE, # output_zero_point (non tensor) + utils.NO_STORAGE, # bias (prepacked) + utils.NO_STORAGE, # activation (non tensor) + ], + outputs_storage=[ + utils.PACKED_INT8_4W_BUFFER, + ], + supports_resize=False, + supports_prepacking=True, + 
) + + # ============================================================================= # SDPA.cpp # ============================================================================= diff --git a/backends/vulkan/patterns/quantized_linear.py b/backends/vulkan/patterns/quantized_linear.py index 374e29c634d..df80749e72f 100644 --- a/backends/vulkan/patterns/quantized_linear.py +++ b/backends/vulkan/patterns/quantized_linear.py @@ -31,13 +31,19 @@ class QuantizedLinearMatch(PatternMatch): - def __init__(self, mm_node: torch.fx.Node) -> None: + def __init__(self, mm_node: torch.fx.Node) -> None: # noqa: C901 self.anchor_node = mm_node self.match_found = False self.all_nodes = [self.anchor_node] + # addmm(bias, mat1, mat2) has a different arg layout than + # mm(mat1, mat2) and linear(input, weight, bias?) + is_addmm = self.anchor_node.target == exir_ops.edge.aten.addmm.default + weight_arg_idx = 2 if is_addmm else 1 + input_arg_idx = 1 if is_addmm else 0 + const_node, arg_chain = utils.trace_args_until_placeholder( - self.anchor_node.args[1] + self.anchor_node.args[weight_arg_idx] ) # mat2 is not a constant tensor - no match @@ -84,26 +90,64 @@ def __init__(self, mm_node: torch.fx.Node) -> None: # Identify output node self.output_node = self.anchor_node - # The implementation has a limitation that output channels must be a - # multiple of 4. This is to ensure that data loads are aligned well with - # texel boundaries. If this is not true, then don't match the pattern. - out_channels = self.output_node.meta["val"].shape[-1] - if out_channels % 4 != 0: - return + # Identify primary input node of the anchor. Due to decomposition of aten.linear + # there may be a view_copy node between the original input tensor to the linear + # op and the actual linear op node. + anchor_primary_input_node = self.anchor_node.args[input_arg_idx] + assert isinstance(anchor_primary_input_node, torch.fx.Node) + + # Skip potential view_copy between dq and linear + if utils.is_view_copy_node(anchor_primary_input_node): + self.all_nodes.append(anchor_primary_input_node) + anchor_primary_input_node = anchor_primary_input_node.args[ + 0 + ] # pyre-ignore[16] + assert isinstance(anchor_primary_input_node, torch.fx.Node) + + # By default, assume that the input tensor is not quantized in any way + self.quantize_input_node = None + self.dequantize_input_node = None + self.pattern_input_node = anchor_primary_input_node - # Identify input node - ( - self.fp_input_node, - self.quantize_input_node, - dq_node, - ) = utils.maybe_skip_q_dq_arg_chain(self.anchor_node.args[0]) - assert self.fp_input_node is not None - self.all_nodes.append(self.fp_input_node) + self.input_scales_node = None + self.input_zeros_node = None + + scales_arg_idx = 1 + zeros_arg_idx = 2 + + # If the primary input node comes from a dequantize node, that implies the input + # input tensor is quantized (either statically or dynamically). + if utils.is_dequant_node(anchor_primary_input_node): + # Assume that this is a static quantization pattern; the input to the + # pattern is a statically quantized int8 tensor. 
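+            # Sketch of the graph shapes this branch is intended to cover,
+            # assuming the usual PT2E/torchao decompositions (exact op targets
+            # may differ):
+            #   static:  q_per_tensor -> dq -> [view_copy] -> linear/addmm/mm -> [relu] -> q_per_tensor
+            #   dynamic: choose_qparams -> q -> dq -> [view_copy] -> linear/addmm/mm
+            # The dynamic case is detected further below and overrides the
+            # static assumption.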
+ self.dequantize_input_node = anchor_primary_input_node + self.all_nodes.append(self.dequantize_input_node) + input_to_dq_node = self.dequantize_input_node.args[0] + self.pattern_input_node = input_to_dq_node + + # torchao dequantize has a slightly different function schema + if ( + self.dequantize_input_node.target + == exir_ops.edge.torchao.dequantize_affine.default + ): + scales_arg_idx = 2 + zeros_arg_idx = 3 + + self.input_scales_node = self.dequantize_input_node.args[scales_arg_idx] + self.input_zeros_node = self.dequantize_input_node.args[zeros_arg_idx] + + # Check for dynamic quantization: input scales are dynamically + # computed via a choose_qparams op + if utils.is_quant_node(input_to_dq_node) and utils.is_dynamic_qscale( + self.input_scales_node + ): + self.quantize_input_node = input_to_dq_node + self.pattern_input_node = self.quantize_input_node.args[0] # The implementation has a limitation that input channels must be a # multiple of 4. This is to ensure that data loads are aligned well with # texel boundaries. If this is not true, then don't match the pattern. - in_channels = self.fp_input_node.meta["val"].shape[-1] + in_channels = self.pattern_input_node.meta["val"].shape[-1] if in_channels % 4 != 0: return @@ -111,42 +155,52 @@ def __init__(self, mm_node: torch.fx.Node) -> None: self.bias_node = None if self.anchor_node.target == exir_ops.edge.aten.addmm.default: self.bias_node, arg_chain = utils.trace_args_until_placeholder( - self.anchor_node.args[2] + self.anchor_node.args[0] ) assert self.bias_node is not None self.all_nodes.extend(arg_chain) + elif self.anchor_node.target == exir_ops.edge.aten.linear.default: + if len(self.anchor_node.args) > 2 and self.anchor_node.args[2] is not None: + self.bias_node, arg_chain = utils.trace_args_until_placeholder( + self.anchor_node.args[2] + ) + if self.bias_node is not None: + self.all_nodes.extend(arg_chain) # If input is not quantized, then we are done - if self.quantize_input_node is None: + if self.dequantize_input_node is None: self.match_found = True return - scales_arg_idx = 1 - zeros_arg_idx = 2 - - # torchao op has a slightly different function schema - if ( - self.quantize_input_node.target - == exir_ops.edge.torchao.quantize_affine.default - ): - scales_arg_idx = 2 - zeros_arg_idx = 3 - - self.input_scales_node = self.quantize_input_node.args[scales_arg_idx] - self.input_zeros_node = self.quantize_input_node.args[zeros_arg_idx] - - assert dq_node is not None - self.all_nodes.extend( - [ - self.quantize_input_node, - dq_node, - ] - ) + # Check if the output is also quantized (q → dq → linear → q pattern) + # Also handle fused linear+relu (q → dq → linear → relu → q pattern) + self.quantize_output_node = None + self.output_scales_node = None + self.output_zeros_node = None + self.relu_node = None + if len(self.output_node.users) == 1: + cur_node = list(self.output_node.users)[0] + if cur_node.target == exir_ops.edge.aten.relu.default: + self.relu_node = cur_node + if len(cur_node.users) == 1: + cur_node = list(cur_node.users)[0] + else: + cur_node = None + if cur_node is not None and utils.is_quant_node(cur_node): + self.quantize_output_node = cur_node + self.output_scales_node = self.quantize_output_node.args[1] + self.output_zeros_node = self.quantize_output_node.args[2] self.match_found = True def is_weight_only_quantized(self) -> bool: - return self.quantize_input_node is None + return self.dequantize_input_node is None + + def has_output_quantization(self) -> bool: + return ( + hasattr(self, "quantize_output_node") 
+ and self.quantize_output_node is not None + ) def is_weight_pergroup_quantized(self) -> bool: weight_shape = self.weight_node.meta["val"].shape @@ -172,7 +226,7 @@ def is_weight_perchannel_quantized(self) -> bool: return scales_shape[0] == weight_shape[-2] def is_input_static_per_tensor_quantized(self) -> bool: - if self.quantize_input_node is None: + if self.dequantize_input_node is None: return False # For static quantization per tensor quantization, the scales and zeros @@ -180,7 +234,7 @@ def is_input_static_per_tensor_quantized(self) -> bool: return isinstance(self.input_scales_node, float) def is_input_dynamic_perchannel_quantized(self) -> bool: - if self.quantize_input_node is None: + if self.dequantize_input_node is None: return False if not isinstance(self.input_scales_node, torch.fx.Node): @@ -196,7 +250,7 @@ def is_input_dynamic_perchannel_quantized(self) -> bool: return False scales_shape = self.input_scales_node.meta["val"].shape - input_shape = self.fp_input_node.meta["val"].shape + input_shape = self.pattern_input_node.meta["val"].shape return input_shape[-2] == scales_shape[-1] @@ -334,7 +388,7 @@ def make_linear_q4gsw_op( "call_function", exir_ops.edge.et_vk.linear_q4gsw.default, args=( - match.fp_input_node, + match.pattern_input_node, match.weight_node, match.weight_scales_node, group_size, @@ -398,7 +452,7 @@ def make_linear_dq8ca_q4gsw_op( "call_function", exir_ops.edge.et_vk.linear_dq8ca_q4gsw.default, args=( - match.fp_input_node, + match.pattern_input_node, match.input_scales_node, match.input_zeros_node, match.weight_node, @@ -418,12 +472,34 @@ def make_linear_q8ta_q8csw_custom_op( match: QuantizedLinearMatch, weight_tensor: torch.Tensor, ): + # Pad weight_scales to multiple of 4 so GPU shader reads don't go OOB + weight_scales_tensor = get_param_tensor(ep, match.weight_scales_node) + assert weight_scales_tensor is not None + utils.align_width_and_update_state_dict( + ep, match.weight_scales_node, weight_scales_tensor + ) + + # Pad bias to multiple of 4 if present + if match.bias_node is not None: + bias_tensor = get_param_tensor(ep, match.bias_node) + if bias_tensor is not None: + utils.align_width_and_update_state_dict(ep, match.bias_node, bias_tensor) + first_graph_node = list(graph_module.graph.nodes)[0] with graph_module.graph.inserting_before(first_graph_node): weight_tensor_name = utils.get_tensor_name(ep, match.weight_node) # Pre-compute the weight sums which are needed to apply activation zero point # when using integer accumulation. 
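+        # Reasoning (assuming symmetric int8 weights, i.e. weight zero point 0):
+        #   acc[n]  = sum_k q_x[k] * q_w[n, k]
+        #   out[n] ~= input_scale * weight_scale[n] * (acc[n] - input_zp * sum_k q_w[n, k])
+        # Precomputing sum_k q_w[n, k] per output channel lets the shader apply
+        # the input zero-point correction without re-reading the weights.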
sum_per_output_channel = weight_tensor.sum(dim=1).to(torch.int32).contiguous() + + # Pad weight sums to align OC to multiple of 4 + oc = sum_per_output_channel.shape[0] + if oc % 4 != 0: + num_padding = 4 - (oc % 4) + sum_per_output_channel = F.pad( + sum_per_output_channel, (0, num_padding) + ).contiguous() + sums_name = weight_tensor_name + "_sums" # Sanitize the name sums_name = sums_name.replace(".", "_") @@ -441,7 +517,7 @@ def make_linear_q8ta_q8csw_custom_op( "call_function", exir_ops.edge.et_vk.linear_q8ta_q8csw.default, args=( - match.fp_input_node, + match.pattern_input_node, match.input_scales_node, match.input_zeros_node, match.weight_node, @@ -454,6 +530,79 @@ def make_linear_q8ta_q8csw_custom_op( match.output_node.replace_all_uses_with(qlinear_node) +def make_q8ta_linear_custom_op( + ep: ExportedProgram, + graph_module: torch.fx.GraphModule, + match: QuantizedLinearMatch, + weight_tensor: torch.Tensor, +): + # Pad weight_scales to multiple of 4 so GPU shader reads don't go OOB + weight_scales_tensor = get_param_tensor(ep, match.weight_scales_node) + assert weight_scales_tensor is not None + utils.align_width_and_update_state_dict( + ep, match.weight_scales_node, weight_scales_tensor + ) + + # Pad bias to multiple of 4 if present + if match.bias_node is not None: + bias_tensor = get_param_tensor(ep, match.bias_node) + if bias_tensor is not None: + utils.align_width_and_update_state_dict(ep, match.bias_node, bias_tensor) + + first_graph_node = list(graph_module.graph.nodes)[0] + with graph_module.graph.inserting_before(first_graph_node): + weight_tensor_name = utils.get_tensor_name(ep, match.weight_node) + sum_per_output_channel = weight_tensor.sum(dim=1).to(torch.int32).contiguous() + + # Pad weight sums to align OC to multiple of 4 + oc = sum_per_output_channel.shape[0] + if oc % 4 != 0: + num_padding = 4 - (oc % 4) + sum_per_output_channel = F.pad( + sum_per_output_channel, (0, num_padding) + ).contiguous() + + sums_name = weight_tensor_name + "_sums" + sums_name = sums_name.replace(".", "_") + + weight_sums_node = create_constant_placeholder( + exp_program=ep, + graph=graph_module.graph, + kind=InputKind.PARAMETER, + name=sums_name, + data=sum_per_output_channel, + ) + + # Use gemv variant when batch size is 1 + input_shape = match.pattern_input_node.meta["val"].shape + batch_size = input_shape[-2] if len(input_shape) >= 2 else 1 + if batch_size == 1: + op_target = exir_ops.edge.et_vk.q8ta_linear_gemv.default + else: + op_target = exir_ops.edge.et_vk.q8ta_linear.default + + with graph_module.graph.inserting_before(match.output_node): + qlinear_node = graph_module.graph.create_node( + "call_function", + op_target, + args=( + match.pattern_input_node, + match.input_scales_node, + match.input_zeros_node, + match.weight_node, + weight_sums_node, + match.weight_scales_node, + match.output_scales_node, + match.output_zeros_node, + match.bias_node, + "relu" if match.relu_node is not None else "none", + ), + ) + + qlinear_node.meta["val"] = match.quantize_output_node.meta["val"] + match.quantize_output_node.replace_all_uses_with(qlinear_node) + + @register_pattern_replacement("quantized_linear") def replace_quantized_linear_patterns( ep: ExportedProgram, @@ -472,11 +621,20 @@ def replace_quantized_linear_patterns( weight_zeros_tensor = get_param_tensor(ep, match.weight_zeros_node) assert weight_zeros_tensor is not None - # Biases not supported at the moment + # Route to appropriate custom op. + # q8ta_linear supports bias, so check it first before the bias guard. 
+ if ( + match.is_input_static_per_tensor_quantized() + and match.is_weight_perchannel_quantized() + and match.has_output_quantization() + ): + make_q8ta_linear_custom_op(ep, graph_module, match, weight_tensor) + return + + # Remaining ops do not support bias if match.bias_node is not None: return - # Route to appropriate custom op if ( match.is_weight_only_quantized() and match.is_weight_pergroup_quantized() diff --git a/backends/vulkan/runtime/VulkanBackend.cpp b/backends/vulkan/runtime/VulkanBackend.cpp index fbca5af5100..7f7afffcf57 100644 --- a/backends/vulkan/runtime/VulkanBackend.cpp +++ b/backends/vulkan/runtime/VulkanBackend.cpp @@ -144,6 +144,8 @@ utils::GPUMemoryLayout get_memory_layout( return utils::kPackedInt8_4W4C; case vkgraph::VkMemoryLayout::PACKED_INT8_4H4W: return utils::kPackedInt8_4H4W; + case vkgraph::VkMemoryLayout::PACKED_INT8_4W: + return utils::kPackedInt8_4W; case vkgraph::VkMemoryLayout::PACKED_INT8_4C1W: return utils::kPackedInt8_4C1W; default: diff --git a/backends/vulkan/runtime/graph/ops/glsl/q8ta_linear.glsl b/backends/vulkan/runtime/graph/ops/glsl/q8ta_linear.glsl new file mode 100644 index 00000000000..87a3d539297 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/q8ta_linear.glsl @@ -0,0 +1,160 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#version 450 core + +${define_required_extensions("buffer", DTYPE)} + +#extension GL_EXT_control_flow_attributes : require +#extension GL_EXT_integer_dot_product : require + +#define PRECISION ${PRECISION} +#define VEC4_T ${texel_load_type(DTYPE, "buffer")} +#define T int + +#define PACKED_INT8_INPUT_BUFFER +$if WEIGHT_STORAGE == "buffer": + #define WEIGHT_BUFFER + +#define TILE_M4 ${TILE_M4} +#define TILE_K4 ${TILE_K4} +#define TILE_N4 ${TILE_N4} + +#define TILE_M ${TILE_M4 * 4} +#define TILE_K ${TILE_K4 * 4} +#define TILE_N ${TILE_N4 * 4} + +layout(std430) buffer; + +${layout_declare_tensor(B, "w", "t_packed_int8_output", "int", "buffer", is_scalar_array=False)} +${layout_declare_tensor(B, "r", "t_packed_int8_input", "int", "buffer", is_scalar_array=False)} +${layout_declare_tensor(B, "r", "t_packed_int8_weight", "int", WEIGHT_STORAGE, is_scalar_array=False)} +${layout_declare_tensor(B, "r", "t_weight_sums", "int", "buffer", is_scalar_array=False)} +${layout_declare_tensor(B, "r", "t_weight_scales", DTYPE, "buffer", is_scalar_array=False)} +${layout_declare_tensor(B, "r", "t_bias", DTYPE, "buffer", is_scalar_array=False)} + +${layout_declare_spec_const(C, "int", "apply_bias", "0")} +${layout_declare_spec_const(C, "int", "activation_type", "0")} + +${layout_declare_ubo(B, "ivec4", "output_sizes")} +${layout_declare_ubo(B, "ivec4", "input_sizes")} + +layout(push_constant) uniform restrict Block { + float input_scale; + int input_zp; + float output_inv_scale; + int output_zp; +}; + +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; + +#include "common.glslh" +#include "linear_int8_input_tile_load.glslh" +#include "linear_int8_weight_tile_load.glslh" +#include "linear_fp_output_tile_int8_int8_compute.glslh" +#include "linear_fp_weight_scales_load.glslh" +#include "linear_int_weight_sums_load.glslh" +#include "linear_fp_bias_load.glslh" + +void main() { + const int out_tile_x = int(gl_GlobalInvocationID.x); + const int out_tile_y = int(gl_GlobalInvocationID.y); + + const int n = out_tile_x * TILE_N; + 
const int m = out_tile_y * TILE_M; + + const int n4 = div_4(n); + const int m4 = div_4(m); + + if (n >= output_sizes.x || m >= output_sizes.y) { + return; + } + + const int M = output_sizes.y; + const int K4 = div_up_4(input_sizes.x); + const int N4 = div_up_4(output_sizes.x); + + Int32Accum out_accum; + initialize(out_accum); + + Int8InputTile int8_in_tile; + Int8WeightTile int8_weight_tile; + + for (int k4 = 0; k4 < K4; k4 += TILE_K4) { + load_int8_input_tile(int8_in_tile, k4, m4, K4); + load_int8_weight_tile(int8_weight_tile, n4, k4, N4); + + int_accumulate_with_int8_weight(out_accum, int8_in_tile, int8_weight_tile); + } + + FPPerOutChannelParams weight_scales_tile; + load_weight_scales_tile(weight_scales_tile, n4); + + IntPerOutChannelParams weight_sums_tile; + load_weight_sums_tile(weight_sums_tile, n4); + + FPOutTile out_tile; + initialize(out_tile); + + if (apply_bias > 0) { + FPPerOutChannelParams bias_tile; + load_bias_tile(bias_tile, n4); + + accumulate_out_tile_with_int_accum( + out_tile, + out_accum, + input_scale, + input_zp, + weight_sums_tile, + weight_scales_tile, + bias_tile); + } + else { + accumulate_out_tile_with_int_accum( + out_tile, + out_accum, + input_scale, + input_zp, + weight_sums_tile, + weight_scales_tile); + } + + // Apply ReLU if enabled + if (activation_type > 0) { + [[unroll]] for (int tile_m = 0; tile_m < TILE_M; ++tile_m) { + [[unroll]] for (int tile_n4 = 0; tile_n4 < TILE_N4; ++tile_n4) { + out_tile.data[tile_m][tile_n4] = max(out_tile.data[tile_m][tile_n4], vec4(0.0)); + } + } + } + + // Quantize float output tile to int8 and write in PACKED_INT8_4H4W format + const int M4 = div_up_4(M); + + [[unroll]] for (int tile_m4 = 0; tile_m4 < TILE_M4; ++tile_m4) { + if (m4 + tile_m4 >= M4) { + break; + } + [[unroll]] for (int tile_n4 = 0; tile_n4 < TILE_N4; ++tile_n4) { + if (n4 + tile_n4 >= N4) { + break; + } + ivec4 packed_block; + [[unroll]] for (int i = 0; i < 4; ++i) { + const int tile_m = tile_m4 * 4 + i; + if (m + tile_m < M) { + packed_block[i] = quantize_and_pack( + out_tile.data[tile_m][tile_n4], output_inv_scale, output_zp); + } else { + packed_block[i] = 0; + } + } + t_packed_int8_output[(m4 + tile_m4) * N4 + n4 + tile_n4] = packed_block; + } + } +} diff --git a/backends/vulkan/runtime/graph/ops/glsl/q8ta_linear.yaml b/backends/vulkan/runtime/graph/ops/glsl/q8ta_linear.yaml new file mode 100644 index 00000000000..c7836c60477 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/q8ta_linear.yaml @@ -0,0 +1,18 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +q8ta_linear: + parameter_names_with_default_values: + DTYPE: float + WEIGHT_STORAGE: texture2d + TILE_M4: 1 + TILE_N4: 2 + TILE_K4: 1 + generate_variant_forall: + DTYPE: + - VALUE: float + shader_variants: + - NAME: q8ta_linear diff --git a/backends/vulkan/runtime/graph/ops/glsl/q8ta_linear_gemv.glsl b/backends/vulkan/runtime/graph/ops/glsl/q8ta_linear_gemv.glsl new file mode 100644 index 00000000000..aa0837c4a6e --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/q8ta_linear_gemv.glsl @@ -0,0 +1,165 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#version 450 core + +${define_required_extensions("buffer", DTYPE)} + +#extension GL_EXT_control_flow_attributes : require +#extension GL_EXT_integer_dot_product : require + +#define PRECISION ${PRECISION} +#define VEC4_T ${texel_load_type(DTYPE, "buffer")} +#define T int + +$if WEIGHT_STORAGE == "buffer": + #define WEIGHT_BUFFER + +#define TILE_K4 ${TILE_K4} +#define TILE_N4 ${TILE_N4} + +#define TILE_M4 1 +#define TILE_M 1 +#define TILE_K ${TILE_K4 * 4} +#define TILE_N ${TILE_N4 * 4} + +#define WGS ${WGS} + +layout(std430) buffer; + +// Scalar int arrays for 4W packed int8 input/output +${layout_declare_tensor(B, "w", "t_packed_int8_output", "int", "buffer")} +${layout_declare_tensor(B, "r", "t_packed_int8_input", "int", "buffer")} +// Weight uses ivec4 (same format as q8ta_linear) +${layout_declare_tensor(B, "r", "t_packed_int8_weight", "int", WEIGHT_STORAGE, is_scalar_array=False)} +${layout_declare_tensor(B, "r", "t_weight_sums", "int", "buffer", is_scalar_array=False)} +${layout_declare_tensor(B, "r", "t_weight_scales", DTYPE, "buffer", is_scalar_array=False)} +${layout_declare_tensor(B, "r", "t_bias", DTYPE, "buffer", is_scalar_array=False)} + +${layout_declare_spec_const(C, "int", "apply_bias", "0")} +${layout_declare_spec_const(C, "int", "activation_type", "0")} + +${layout_declare_ubo(B, "ivec4", "output_sizes")} +${layout_declare_ubo(B, "ivec4", "input_sizes")} + +layout(push_constant) uniform restrict Block { + float input_scale; + int input_zp; + float output_inv_scale; + int output_zp; +}; + +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; + +#include "common.glslh" +#include "linear_int8_weight_tile_load.glslh" +#include "linear_fp_output_tile_int8_int8_compute.glslh" +#include "linear_fp_weight_scales_load.glslh" +#include "linear_int_weight_sums_load.glslh" +#include "linear_fp_bias_load.glslh" + +shared Int32Accum partial_accums[WGS]; + +void main() { + const int lid = int(gl_LocalInvocationID.z); + const int n4 = int(gl_GlobalInvocationID.x) * TILE_N4; + + const int n = mul_4(n4); + + const int K4 = div_up_4(input_sizes.x); + const int N4 = div_up_4(output_sizes.x); + + if (n >= output_sizes.x) { + return; + } + + Int32Accum out_accum; + initialize(out_accum); + + Int8WeightTile int8_weight_tile; + + for (int k4 = lid; k4 < K4; k4 += WGS) { + // Load one packed int32 from the 4W input buffer. Each int32 contains + // 4 int8 values at k=k4*4..k4*4+3. 
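+    // Each of the WGS threads strides over the K4 blocks and accumulates into
+    // its own Int32Accum; the per-thread partial sums are combined via the
+    // shared partial_accums array after this loop.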
+ const int packed_input = t_packed_int8_input[k4]; + + load_int8_weight_tile(int8_weight_tile, n4, k4, N4); + + // Accumulate dot products of the input int8x4 with each weight int8x4 + [[unroll]] for (int n = 0; n < TILE_N; ++n) { + const int tile_n4 = div_4(n); + const int n4i = mod_4(n); + out_accum.data[0][tile_n4][n4i] = dotPacked4x8AccSatEXT( + packed_input, + int8_weight_tile.data[0][tile_n4][n4i], + out_accum.data[0][tile_n4][n4i]); + } + } + + partial_accums[lid] = out_accum; + + memoryBarrierShared(); + barrier(); + + // Only the first thread writes the result + if (lid == 0) { + for (int i = 1; i < WGS; ++i) { + [[unroll]] for (int tile_n4 = 0; tile_n4 < TILE_N4; ++tile_n4) { + out_accum.data[0][tile_n4] += + partial_accums[i].data[0][tile_n4]; + } + } + + FPPerOutChannelParams weight_scales_tile; + load_weight_scales_tile(weight_scales_tile, n4); + + IntPerOutChannelParams weight_sums_tile; + load_weight_sums_tile(weight_sums_tile, n4); + + FPOutTile out_tile; + initialize(out_tile); + + if (apply_bias > 0) { + FPPerOutChannelParams bias_tile; + load_bias_tile(bias_tile, n4); + + accumulate_out_tile_with_int_accum( + out_tile, + out_accum, + input_scale, + input_zp, + weight_sums_tile, + weight_scales_tile, + bias_tile); + } else { + accumulate_out_tile_with_int_accum( + out_tile, + out_accum, + input_scale, + input_zp, + weight_sums_tile, + weight_scales_tile); + } + + // Apply ReLU if enabled + if (activation_type > 0) { + [[unroll]] for (int tile_n4 = 0; tile_n4 < TILE_N4; ++tile_n4) { + out_tile.data[0][tile_n4] = max(out_tile.data[0][tile_n4], vec4(0.0)); + } + } + + // Quantize and write to scalar int[] buffer. Each int32 at position n4 + // contains 4 packed int8 output values for channels n4*4..n4*4+3. + [[unroll]] for (int tile_n4 = 0; tile_n4 < TILE_N4; ++tile_n4) { + if (n4 + tile_n4 < N4) { + t_packed_int8_output[n4 + tile_n4] = quantize_and_pack( + out_tile.data[0][tile_n4], output_inv_scale, output_zp); + } + } + } +} diff --git a/backends/vulkan/runtime/graph/ops/glsl/q8ta_linear_gemv.yaml b/backends/vulkan/runtime/graph/ops/glsl/q8ta_linear_gemv.yaml new file mode 100644 index 00000000000..beae1eddf3e --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/q8ta_linear_gemv.yaml @@ -0,0 +1,18 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
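+#
+# WGS is the number of threads that cooperate on the K reduction; it should
+# stay in sync with the {1, 1, 64} local workgroup size selected in
+# Q8taLinearGemv.cpp, since the shader uses it both as the loop stride and as
+# the size of the shared partial_accums array.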
+ +q8ta_linear_gemv: + parameter_names_with_default_values: + DTYPE: float + WEIGHT_STORAGE: texture2d + TILE_K4: 1 + TILE_N4: 2 + WGS: 64 + generate_variant_forall: + DTYPE: + - VALUE: float + shader_variants: + - NAME: q8ta_linear_gemv diff --git a/backends/vulkan/runtime/graph/ops/impl/Q8taConv2d.cpp b/backends/vulkan/runtime/graph/ops/impl/Q8taConv2d.cpp index 33b7005a845..8273df6a07e 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Q8taConv2d.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/Q8taConv2d.cpp @@ -417,7 +417,30 @@ void q8ta_conv2d_general( } void q8ta_conv2d(ComputeGraph& graph, const std::vector& args) { - q8ta_conv2d_general(graph, args); + const ValueRef input = args.at(0); + const ValueRef groups_ref = args.at(13); + const ValueRef output = args.at(15); + + const int64_t groups = graph.extract_scalar(groups_ref); + const int64_t in_channels = graph.size_at(-3, input); + const int64_t in_channels_per_group = in_channels / groups; + + const int64_t H_out = graph.size_at(-2, output); + const int64_t W_out = graph.size_at(-1, output); + const int64_t spatial_out = H_out * W_out; + + // Use im2col when the channel depth is sufficient for tiled GEMM to win, or + // when the output spatial area is small enough that the im2col buffer stays + // manageable. For large spatial outputs with few channels, the im2col buffer + // becomes too large and the general shader is more efficient. + const bool use_im2col = groups == 1 && in_channels_per_group % 4 == 0 && + (in_channels_per_group >= 64 || spatial_out <= 4096); + + if (use_im2col) { + q8ta_conv2d_im2col(graph, args); + } else { + q8ta_conv2d_general(graph, args); + } } REGISTER_OPERATORS { diff --git a/backends/vulkan/runtime/graph/ops/impl/Q8taLinear.cpp b/backends/vulkan/runtime/graph/ops/impl/Q8taLinear.cpp new file mode 100644 index 00000000000..45366fbf044 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/impl/Q8taLinear.cpp @@ -0,0 +1,207 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#include + +#include + +#include +#include +#include +#include +#include + +namespace vkcompute { + +bool q8ta_linear_check_packed_dim_info(const api::PackedDimInfo& info) { + return info.packed_dim == WHCN::kWidthDim && + info.packed_dim_block_size == 4 && + info.outer_packed_dim == WHCN::kHeightDim && + info.outer_packed_dim_block_size == 4; +} + +// +// Workgroup size selection +// + +utils::uvec3 q8ta_linear_global_wg_size( + ComputeGraph* graph, + const vkapi::ShaderInfo& shader, + const std::vector& args, + const std::vector& resize_args) { + (void)shader; + (void)resize_args; + + const ValueRef out = args.at(0).refs.at(0); + + std::vector out_sizes = graph->sizes_of(out); + const uint32_t N = utils::val_at(-1, out_sizes); + const uint32_t M = utils::val_at(-2, out_sizes); + + // Each output tile contains 8 columns (TILE_N4=2 -> 8 output channels) + const uint32_t N_per_tile = 8; + const uint32_t M_per_tile = 4; + + const uint32_t num_N_tiles = utils::div_up(N, N_per_tile); + const uint32_t num_M_tiles = utils::div_up(M, M_per_tile); + + return {num_N_tiles, num_M_tiles, 1}; +} + +utils::uvec3 q8ta_linear_local_wg_size( + ComputeGraph* graph, + const vkapi::ShaderInfo& shader, + const utils::uvec3& global_workgroup_size, + const std::vector& args, + const std::vector& resize_args) { + return pick_hw_square_wg_size( + graph, shader, global_workgroup_size, args, resize_args); +} + +// +// Dispatch node +// + +void add_q8ta_linear_node( + ComputeGraph& graph, + const ValueRef packed_int8_input, + const ValueRef input_scale, + const ValueRef input_zp, + const ValueRef packed_weight, + const ValueRef packed_weight_sums, + const ValueRef packed_weight_scales, + const ValueRef output_scale, + const ValueRef output_zp, + const ValueRef bias_data, + const ValueRef packed_bias, + const uint32_t activation_type, + const ValueRef packed_int8_output) { + // Validate packed dim info matches 4H4W layout + VK_CHECK_COND(q8ta_linear_check_packed_dim_info( + graph.packed_dim_info_of(packed_int8_input))); + VK_CHECK_COND(q8ta_linear_check_packed_dim_info( + graph.packed_dim_info_of(packed_int8_output))); + + float input_scale_val = graph.extract_scalar(input_scale); + int32_t input_zp_val = graph.extract_scalar(input_zp); + + float output_inv_scale_val = 1.0f / graph.extract_scalar(output_scale); + int32_t output_zp_val = graph.extract_scalar(output_zp); + + uint32_t apply_bias = 1; + if (graph.val_is_none(bias_data)) { + apply_bias = 0; + } + + std::string kernel_name = "q8ta_linear"; + add_dtype_suffix(kernel_name, graph.dtype_of(packed_weight_scales)); + + vkapi::ParamsBindList param_buffers = { + graph.sizes_ubo(packed_int8_output), graph.sizes_ubo(packed_int8_input)}; + + std::vector push_constants = { + PushConstantDataInfo(&input_scale_val, sizeof(input_scale_val)), + PushConstantDataInfo(&input_zp_val, sizeof(input_zp_val)), + PushConstantDataInfo(&output_inv_scale_val, sizeof(output_inv_scale_val)), + PushConstantDataInfo(&output_zp_val, sizeof(output_zp_val)), + }; + + graph.execute_nodes().emplace_back(new DynamicDispatchNode( + graph, + VK_KERNEL_FROM_STR(kernel_name), + q8ta_linear_global_wg_size, + q8ta_linear_local_wg_size, + // Inputs and Outputs + {{packed_int8_output, vkapi::kWrite}, + {{packed_int8_input, + packed_weight, + packed_weight_sums, + packed_weight_scales, + packed_bias}, + vkapi::kRead}}, + // Shader params buffers + param_buffers, + // Push Constants + push_constants, + // Specialization Constants + {apply_bias, activation_type}, + // Resize args + {}, + // 
Resizing Logic + nullptr)); +} + +// +// High level operator impl +// + +void q8ta_linear(ComputeGraph& graph, const std::vector& args) { + int32_t idx = 0; + const ValueRef packed_int8_input = args.at(idx++); + const ValueRef input_scale = args.at(idx++); + const ValueRef input_zp = args.at(idx++); + const ValueRef weight_data = args.at(idx++); + const ValueRef weight_sums_data = args.at(idx++); + const ValueRef weight_scales_data = args.at(idx++); + const ValueRef output_scale = args.at(idx++); + const ValueRef output_zp = args.at(idx++); + const ValueRef bias_data = args.at(idx++); + const ValueRef activation = args.at(idx++); + const ValueRef packed_int8_output = args.at(idx++); + + const int64_t K = graph.size_at(-1, packed_int8_input); + VK_CHECK_COND(K % 4 == 0); + + QuantizationConfig weight_quant_config(8, kPerChannel, {K}); + + // Prepack weight data + const ValueRef packed_weight = + prepack_quantized_linear_weight(graph, weight_quant_config, weight_data); + const ValueRef packed_weight_scales = prepack_standard( + graph, weight_scales_data, utils::kBuffer, utils::kWidthPacked); + const ValueRef packed_weight_sums = prepack_standard( + graph, weight_sums_data, utils::kBuffer, utils::kWidthPacked); + + // Prepack bias data + TmpTensor dummy_bias( + &graph, + {}, + graph.dtype_of(packed_weight_scales), + utils::kBuffer, + utils::kWidthPacked); + + ValueRef packed_bias = dummy_bias.vref; + if (graph.val_is_not_none(bias_data)) { + packed_bias = + prepack_standard(graph, bias_data, utils::kBuffer, utils::kWidthPacked); + } + + uint32_t activation_type_val = static_cast( + activation_type_from_string(graph.extract_string(activation))); + + add_q8ta_linear_node( + graph, + packed_int8_input, + input_scale, + input_zp, + packed_weight, + packed_weight_sums, + packed_weight_scales, + output_scale, + output_zp, + bias_data, + packed_bias, + activation_type_val, + packed_int8_output); +} + +REGISTER_OPERATORS { + VK_REGISTER_OP(et_vk.q8ta_linear.default, q8ta_linear); +} + +} // namespace vkcompute diff --git a/backends/vulkan/runtime/graph/ops/impl/Q8taLinear.h b/backends/vulkan/runtime/graph/ops/impl/Q8taLinear.h new file mode 100644 index 00000000000..9f975525324 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/impl/Q8taLinear.h @@ -0,0 +1,31 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include +#include + +namespace vkcompute { + +void add_q8ta_linear_node( + ComputeGraph& graph, + const ValueRef packed_int8_input, + const ValueRef input_scale, + const ValueRef input_zp, + const ValueRef packed_weight, + const ValueRef packed_weight_sums, + const ValueRef packed_weight_scales, + const ValueRef output_scale, + const ValueRef output_zp, + const ValueRef bias_data, + const ValueRef packed_bias, + const uint32_t activation_type, + const ValueRef packed_int8_output); + +} // namespace vkcompute diff --git a/backends/vulkan/runtime/graph/ops/impl/Q8taLinearGemv.cpp b/backends/vulkan/runtime/graph/ops/impl/Q8taLinearGemv.cpp new file mode 100644 index 00000000000..120df6b0256 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/impl/Q8taLinearGemv.cpp @@ -0,0 +1,210 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. 
+ * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +#include + +#include +#include +#include +#include +#include + +namespace vkcompute { + +static bool q8ta_linear_gemv_check_packed_dim_info( + const api::PackedDimInfo& info) { + return info.packed_dim == WHCN::kWidthDim && + info.packed_dim_block_size == 4 && + info.outer_packed_dim == WHCN::kHeightDim && + info.outer_packed_dim_block_size == 1; +} + +// +// Workgroup size selection +// + +utils::uvec3 q8ta_linear_gemv_global_wg_size( + ComputeGraph* graph, + const vkapi::ShaderInfo& shader, + const std::vector& args, + const std::vector& resize_args) { + (void)shader; + (void)resize_args; + + const ValueRef out = args.at(0).refs.at(0); + + std::vector out_sizes = graph->sizes_of(out); + const uint32_t N = utils::val_at(-1, out_sizes); + + // Each output tile contains 8 columns (TILE_N4=2 -> 8 output channels) + const uint32_t N_per_tile = 8; + const uint32_t num_N_tiles = utils::div_up(N, N_per_tile); + + return {num_N_tiles, 1, 1}; +} + +utils::uvec3 q8ta_linear_gemv_local_wg_size( + ComputeGraph* graph, + const vkapi::ShaderInfo& shader, + const utils::uvec3& global_workgroup_size, + const std::vector& args, + const std::vector& resize_args) { + (void)graph; + (void)shader; + (void)global_workgroup_size; + (void)args; + (void)resize_args; + + // Cooperative algorithm: 64 threads share the K reduction + return {1, 1, 64}; +} + +// +// Dispatch node +// + +void add_q8ta_linear_gemv_node( + ComputeGraph& graph, + const ValueRef packed_int8_input, + const ValueRef input_scale, + const ValueRef input_zp, + const ValueRef packed_weight, + const ValueRef packed_weight_sums, + const ValueRef packed_weight_scales, + const ValueRef output_scale, + const ValueRef output_zp, + const ValueRef bias_data, + const ValueRef packed_bias, + const uint32_t activation_type, + const ValueRef packed_int8_output) { + // Validate packed dim info matches 4W layout + VK_CHECK_COND(q8ta_linear_gemv_check_packed_dim_info( + graph.packed_dim_info_of(packed_int8_input))); + VK_CHECK_COND(q8ta_linear_gemv_check_packed_dim_info( + graph.packed_dim_info_of(packed_int8_output))); + + float input_scale_val = graph.extract_scalar(input_scale); + int32_t input_zp_val = graph.extract_scalar(input_zp); + + float output_inv_scale_val = 1.0f / graph.extract_scalar(output_scale); + int32_t output_zp_val = graph.extract_scalar(output_zp); + + uint32_t apply_bias = 1; + if (graph.val_is_none(bias_data)) { + apply_bias = 0; + } + + std::string kernel_name = "q8ta_linear_gemv"; + add_dtype_suffix(kernel_name, graph.dtype_of(packed_weight_scales)); + + vkapi::ParamsBindList param_buffers = { + graph.sizes_ubo(packed_int8_output), graph.sizes_ubo(packed_int8_input)}; + + std::vector push_constants = { + PushConstantDataInfo(&input_scale_val, sizeof(input_scale_val)), + PushConstantDataInfo(&input_zp_val, sizeof(input_zp_val)), + PushConstantDataInfo(&output_inv_scale_val, sizeof(output_inv_scale_val)), + PushConstantDataInfo(&output_zp_val, sizeof(output_zp_val)), + }; + + graph.execute_nodes().emplace_back(new DynamicDispatchNode( + graph, + VK_KERNEL_FROM_STR(kernel_name), + q8ta_linear_gemv_global_wg_size, + q8ta_linear_gemv_local_wg_size, + // Inputs and Outputs + {{packed_int8_output, vkapi::kWrite}, + {{packed_int8_input, + packed_weight, + packed_weight_sums, + packed_weight_scales, + packed_bias}, + vkapi::kRead}}, + // Shader params buffers + param_buffers, + // Push 
Constants + push_constants, + // Specialization Constants + {apply_bias, activation_type}, + // Resize args + {}, + // Resizing Logic + nullptr)); +} + +// +// High level operator impl +// + +void q8ta_linear_gemv(ComputeGraph& graph, const std::vector& args) { + int32_t idx = 0; + const ValueRef packed_int8_input = args.at(idx++); + const ValueRef input_scale = args.at(idx++); + const ValueRef input_zp = args.at(idx++); + const ValueRef weight_data = args.at(idx++); + const ValueRef weight_sums_data = args.at(idx++); + const ValueRef weight_scales_data = args.at(idx++); + const ValueRef output_scale = args.at(idx++); + const ValueRef output_zp = args.at(idx++); + const ValueRef bias_data = args.at(idx++); + const ValueRef activation = args.at(idx++); + const ValueRef packed_int8_output = args.at(idx++); + + const int64_t K = graph.size_at(-1, packed_int8_input); + VK_CHECK_COND(K % 4 == 0); + + QuantizationConfig weight_quant_config(8, kPerChannel, {K}); + + // Prepack weight data (same format as q8ta_linear) + const ValueRef packed_weight = + prepack_quantized_linear_weight(graph, weight_quant_config, weight_data); + const ValueRef packed_weight_scales = prepack_standard( + graph, weight_scales_data, utils::kBuffer, utils::kWidthPacked); + const ValueRef packed_weight_sums = prepack_standard( + graph, weight_sums_data, utils::kBuffer, utils::kWidthPacked); + + // Prepack bias data + TmpTensor dummy_bias( + &graph, + {}, + graph.dtype_of(packed_weight_scales), + utils::kBuffer, + utils::kWidthPacked); + + ValueRef packed_bias = dummy_bias.vref; + if (graph.val_is_not_none(bias_data)) { + packed_bias = + prepack_standard(graph, bias_data, utils::kBuffer, utils::kWidthPacked); + } + + uint32_t activation_type_val = static_cast( + activation_type_from_string(graph.extract_string(activation))); + + add_q8ta_linear_gemv_node( + graph, + packed_int8_input, + input_scale, + input_zp, + packed_weight, + packed_weight_sums, + packed_weight_scales, + output_scale, + output_zp, + bias_data, + packed_bias, + activation_type_val, + packed_int8_output); +} + +REGISTER_OPERATORS { + VK_REGISTER_OP(et_vk.q8ta_linear_gemv.default, q8ta_linear_gemv); +} + +} // namespace vkcompute diff --git a/backends/vulkan/runtime/graph/ops/impl/Q8taLinearGemv.h b/backends/vulkan/runtime/graph/ops/impl/Q8taLinearGemv.h new file mode 100644 index 00000000000..946022d16ef --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/impl/Q8taLinearGemv.h @@ -0,0 +1,31 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#pragma once + +#include +#include + +namespace vkcompute { + +void add_q8ta_linear_gemv_node( + ComputeGraph& graph, + const ValueRef packed_int8_input, + const ValueRef input_scale, + const ValueRef input_zp, + const ValueRef packed_weight, + const ValueRef packed_weight_sums, + const ValueRef packed_weight_scales, + const ValueRef output_scale, + const ValueRef output_zp, + const ValueRef bias_data, + const ValueRef packed_bias, + const uint32_t activation_type, + const ValueRef packed_int8_output); + +} // namespace vkcompute diff --git a/backends/vulkan/serialization/schema.fbs b/backends/vulkan/serialization/schema.fbs index 8218ee3387f..36f9feaa580 100644 --- a/backends/vulkan/serialization/schema.fbs +++ b/backends/vulkan/serialization/schema.fbs @@ -42,6 +42,7 @@ enum VkMemoryLayout : ubyte { TENSOR_CHANNELS_PACKED = 2, PACKED_INT8_4W4C = 3, PACKED_INT8_4H4W = 4, + PACKED_INT8_4W = 5, PACKED_INT8_4C1W = 8, DEFAULT_LAYOUT = 255, } diff --git a/backends/vulkan/serialization/vulkan_graph_schema.py b/backends/vulkan/serialization/vulkan_graph_schema.py index d14428d3b66..845a59a4dff 100644 --- a/backends/vulkan/serialization/vulkan_graph_schema.py +++ b/backends/vulkan/serialization/vulkan_graph_schema.py @@ -50,6 +50,7 @@ class VkMemoryLayout(IntEnum): TENSOR_CHANNELS_PACKED = 2 PACKED_INT8_4W4C = 3 PACKED_INT8_4H4W = 4 + PACKED_INT8_4W = 5 PACKED_INT8_4C1W = 8 DEFAULT_LAYOUT = 255 diff --git a/backends/vulkan/test/TARGETS b/backends/vulkan/test/TARGETS index ee296a4f68f..7517f7d66f3 100644 --- a/backends/vulkan/test/TARGETS +++ b/backends/vulkan/test/TARGETS @@ -35,6 +35,7 @@ python_unittest( "//caffe2:torch", "//executorch/backends/vulkan/_passes:vulkan_passes", "//executorch/backends/vulkan:vulkan_preprocess", + "//executorch/backends/xnnpack/quantizer:xnnpack_quantizer", "//pytorch/ao:torchao", # @manual ] ) @@ -60,6 +61,17 @@ python_unittest( ], ) +python_unittest( + name = "test_vulkan_tensor_repr", + srcs = [ + "test_vulkan_tensor_repr.py", + ], + deps = [ + "//caffe2:torch", + "//executorch/backends/vulkan:vulkan_preprocess", + ], +) + runtime.python_library( name = "tester", srcs = ["tester.py"], diff --git a/backends/vulkan/test/custom_ops/impl/TestQ8taLinear.cpp b/backends/vulkan/test/custom_ops/impl/TestQ8taLinear.cpp new file mode 100644 index 00000000000..684a7b94e66 --- /dev/null +++ b/backends/vulkan/test/custom_ops/impl/TestQ8taLinear.cpp @@ -0,0 +1,75 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +#include +#include + +namespace vkcompute { + +void test_q8ta_linear(ComputeGraph& graph, const std::vector& args) { + int32_t idx = 0; + const ValueRef fp_input = args.at(idx++); + const ValueRef input_scale = args.at(idx++); + const ValueRef input_zp = args.at(idx++); + const ValueRef weight_data = args.at(idx++); + const ValueRef weight_sums_data = args.at(idx++); + const ValueRef weight_scales_data = args.at(idx++); + const ValueRef output_scale = args.at(idx++); + const ValueRef output_zp = args.at(idx++); + const ValueRef bias_data = args.at(idx++); + const ValueRef activation = args.at(idx++); + const ValueRef impl_selector_str = args.at(idx++); + const ValueRef fp_output = args.at(idx++); + + std::string impl_selector = graph.extract_string(impl_selector_str); + + utils::GPUMemoryLayout layout = + impl_selector == "gemv" ? 
utils::kPackedInt8_4W : utils::kPackedInt8_4H4W; + + TmpTensor packed_int8_input( + &graph, graph.sizes_of(fp_input), vkapi::kInt8x4, utils::kBuffer, layout); + + TmpTensor packed_int8_output( + &graph, + graph.sizes_of(fp_output), + vkapi::kInt8x4, + utils::kBuffer, + layout); + + add_q8ta_quantize_node( + graph, fp_input, input_scale, input_zp, packed_int8_input); + + std::vector linear_args = { + packed_int8_input, + input_scale, + input_zp, + weight_data, + weight_sums_data, + weight_scales_data, + output_scale, + output_zp, + bias_data, + activation, + packed_int8_output}; + + std::string op_name = impl_selector == "gemv" + ? "et_vk.q8ta_linear_gemv.default" + : "et_vk.q8ta_linear.default"; + VK_GET_OP_FN(op_name)(graph, linear_args); + + add_q8ta_dequantize_node( + graph, packed_int8_output, output_scale, output_zp, fp_output); +} + +REGISTER_OPERATORS { + VK_REGISTER_OP(test_etvk.test_q8ta_linear.default, test_q8ta_linear); +} + +} // namespace vkcompute diff --git a/backends/vulkan/test/custom_ops/targets.bzl b/backends/vulkan/test/custom_ops/targets.bzl index 73b1e343bbe..badba5666fa 100644 --- a/backends/vulkan/test/custom_ops/targets.bzl +++ b/backends/vulkan/test/custom_ops/targets.bzl @@ -97,3 +97,4 @@ def define_common_targets(is_fbcode = False): define_custom_op_test_binary("test_q8ta_conv2d") define_custom_op_test_binary("test_q8ta_conv2d_pw") define_custom_op_test_binary("test_q8ta_conv2d_dw") + define_custom_op_test_binary("test_q8ta_linear") diff --git a/backends/vulkan/test/custom_ops/test_q8ta_conv2d.cpp b/backends/vulkan/test/custom_ops/test_q8ta_conv2d.cpp index bc95cc724f5..41ddd389aa8 100644 --- a/backends/vulkan/test/custom_ops/test_q8ta_conv2d.cpp +++ b/backends/vulkan/test/custom_ops/test_q8ta_conv2d.cpp @@ -378,7 +378,23 @@ static std::vector generate_quantized_conv2d_test_cases() { Stride(2, 2), Padding(2, 2), Dilation(1, 1), - 4}}; + 4}, + // Deep channels + small spatial (ResNet50 stage 5 bottleneck) + {OutInChannels(512, 512), + InputSize2D(7, 7), + KernelSize(3, 3), + Stride(1, 1), + Padding(1, 1), + Dilation(1, 1), + 1}, + // Strided 1x1 shortcut (worst-case strided downsample) + {OutInChannels(2048, 1024), + InputSize2D(14, 14), + KernelSize(1, 1), + Stride(2, 2), + Padding(0, 0), + Dilation(1, 1), + 1}}; // Test with different storage types and memory layouts std::vector fp_storage_types = {utils::kTexture3D}; diff --git a/backends/vulkan/test/custom_ops/test_q8ta_linear.cpp b/backends/vulkan/test/custom_ops/test_q8ta_linear.cpp new file mode 100644 index 00000000000..707a8695171 --- /dev/null +++ b/backends/vulkan/test/custom_ops/test_q8ta_linear.cpp @@ -0,0 +1,362 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. + +#include +#include + +#include +#include + +#include + +#include "utils.h" + +using namespace executorch::vulkan::prototyping; + +using namespace vkcompute; + +static constexpr int64_t kRefDimSizeLimit = 300; + +struct LinearConfig { + int64_t M; + int64_t K; + int64_t N; + bool has_bias = true; + std::string test_case_name = "placeholder"; +}; + +static TestCase create_test_case_from_config( + const LinearConfig& config, + vkapi::ScalarType input_dtype, + const std::string& impl_selector = "") { + TestCase test_case; + + std::string dtype_str = (input_dtype == vkapi::kFloat) ? 
"Float" : "Half"; + + std::string test_name = config.test_case_name + "_Buffer_" + dtype_str; + if (!impl_selector.empty()) { + test_name += " [" + impl_selector + "]"; + } + test_case.set_name(test_name); + + test_case.set_operator_name("test_etvk.test_q8ta_linear.default"); + + std::vector input_size = {config.M, config.K}; + std::vector weight_size = {config.N, config.K}; + + // Input tensor (float) - [M, K] + ValueSpec input_tensor( + input_size, + input_dtype, + utils::kBuffer, + utils::kWidthPacked, + DataGenType::RANDOM); + + if (debugging()) { + print_valuespec_data(input_tensor, "input_tensor"); + } + + float input_scale_val = 0.008f; + ValueSpec input_scale(input_scale_val); + + int32_t input_zero_point_val = -2; + ValueSpec input_zero_point(input_zero_point_val); + + // Quantized weight tensor (int8) - [N, K] + ValueSpec quantized_weight( + weight_size, + vkapi::kChar, + utils::kBuffer, + utils::kWidthPacked, + DataGenType::RANDINT8); + quantized_weight.set_constant(true); + + if (debugging()) { + print_valuespec_data(quantized_weight, "weight_tensor"); + } + + // Weight quantization scales (float, per-channel) + ValueSpec weight_scales( + {config.N}, + input_dtype, + utils::kBuffer, + utils::kWidthPacked, + DataGenType::RANDOM_SCALES); + weight_scales.set_constant(true); + + ValueSpec weight_sums( + {config.N}, + vkapi::kInt, + utils::kBuffer, + utils::kWidthPacked, + DataGenType::ZEROS); + weight_sums.set_constant(true); + + // Compute weight_sums data based on quantized weights + compute_weight_sums(weight_sums, quantized_weight, config.N, config.K); + + // Output quantization parameters + float output_scale_val = 0.05314f; + ValueSpec output_scale(output_scale_val); + + int32_t output_zero_point_val = -1; + ValueSpec output_zero_point(output_zero_point_val); + + // Bias (optional, float) - [N] + ValueSpec bias( + {config.N}, + input_dtype, + utils::kBuffer, + utils::kWidthPacked, + DataGenType::RANDOM); + bias.set_constant(true); + if (!config.has_bias) { + bias.set_none(true); + } + + // Output tensor (float) - [M, N] + ValueSpec output( + {config.M, config.N}, + input_dtype, + utils::kBuffer, + utils::kWidthPacked, + DataGenType::ZEROS); + + // Add all specs to test case + test_case.add_input_spec(input_tensor); + test_case.add_input_spec(input_scale); + test_case.add_input_spec(input_zero_point); + test_case.add_input_spec(quantized_weight); + test_case.add_input_spec(weight_sums); + test_case.add_input_spec(weight_scales); + test_case.add_input_spec(output_scale); + test_case.add_input_spec(output_zero_point); + test_case.add_input_spec(bias); + + // Activation (none = no activation) + ValueSpec activation = ValueSpec::make_string("none"); + test_case.add_input_spec(activation); + + // Add impl_selector string + ValueSpec impl_selector_spec = ValueSpec::make_string(impl_selector); + test_case.add_input_spec(impl_selector_spec); + test_case.add_output_spec(output); + + test_case.set_abs_tolerance(output_scale_val + 1e-4f); + + // Filter out quantize/dequantize shaders from timing measurements + test_case.set_shader_filter({ + "nchw_to", + "to_nchw", + "q8ta_quantize", + "q8ta_dequantize", + }); + + return test_case; +} + +// Generate test cases for q8ta_linear operation +static std::vector generate_q8ta_linear_test_cases() { + std::vector test_cases; + if (!vkcompute::api::context()->adapter_ptr()->supports_int8_dot_product()) { + return test_cases; + } + + std::vector configs = { + // Batch size 1 cases (test both tiled and gemv) + {1, 64, 32}, + {1, 128, 64}, + {1, 256, 
128}, + {1, 128, 64, false}, + // Multi-batch cases + {4, 64, 32}, + {4, 128, 64}, + {4, 256, 128}, + {32, 64, 32}, + {32, 128, 64}, + {32, 256, 128}, + // No bias tests + {32, 128, 64, false}, + {32, 256, 128, false}, + // Performance cases + {1, 512, 512}, + {1, 2048, 2048}, + {1, 512, 9059}, + {256, 2048, 2048}, + {512, 2048, 2048}, + {1024, 2048, 2048}, + }; + + for (auto config : configs) { + bool is_performance = config.M >= kRefDimSizeLimit || + config.K >= kRefDimSizeLimit || config.N >= kRefDimSizeLimit; + + std::string prefix = is_performance ? "performance_" : "correctness_"; + std::string generated_test_case_name = prefix + std::to_string(config.M) + + "_" + std::to_string(config.K) + "_" + std::to_string(config.N); + if (!config.has_bias) { + generated_test_case_name += "_no_bias"; + } + + config.test_case_name = generated_test_case_name; + + // Default (tiled) variant + test_cases.push_back(create_test_case_from_config(config, vkapi::kFloat)); + + // For batch size 1, also test the gemv variant + if (config.M == 1) { + test_cases.push_back( + create_test_case_from_config(config, vkapi::kFloat, "gemv")); + } + } + + return test_cases; +} + +// Reference implementation for q8ta_linear (activation+weight+output quantized) +static void q8ta_linear_reference_impl(TestCase& test_case) { + int32_t idx = 0; + const ValueSpec& input_spec = test_case.inputs()[idx++]; + const ValueSpec& input_scale_spec = test_case.inputs()[idx++]; + const ValueSpec& input_zeros_spec = test_case.inputs()[idx++]; + const ValueSpec& weight_spec = test_case.inputs()[idx++]; + const ValueSpec& weight_sums_spec = test_case.inputs()[idx++]; + (void)weight_sums_spec; + const ValueSpec& weight_scales_spec = test_case.inputs()[idx++]; + const ValueSpec& output_scale_spec = test_case.inputs()[idx++]; + const ValueSpec& output_zeros_spec = test_case.inputs()[idx++]; + const ValueSpec& bias_spec = test_case.inputs()[idx++]; + const ValueSpec& activation_spec = test_case.inputs()[idx++]; + (void)activation_spec; + const ValueSpec& impl_selector_spec = test_case.inputs()[idx++]; + (void)impl_selector_spec; + + ValueSpec& output_spec = test_case.outputs()[0]; + + auto input_sizes = input_spec.get_tensor_sizes(); + auto weight_sizes = weight_spec.get_tensor_sizes(); + + int64_t batch_size = input_sizes[0]; + int64_t in_features = input_sizes[1]; + int64_t out_features = weight_sizes[0]; + + if (batch_size > kRefDimSizeLimit || in_features > kRefDimSizeLimit || + out_features > kRefDimSizeLimit) { + throw std::invalid_argument( + "One or more dimensions exceed the allowed limit for reference implementation."); + } + + if (input_spec.dtype != vkapi::kFloat) { + throw std::invalid_argument("Unsupported dtype"); + } + + auto& input_data = input_spec.get_float_data(); + const float input_scale = input_scale_spec.get_float_value(); + const int32_t input_zero_point = input_zeros_spec.get_int_value(); + + auto& weight_data = weight_spec.get_int8_data(); + auto& weight_scales_data = weight_scales_spec.get_float_data(); + auto& bias_data = bias_spec.get_float_data(); + + const float output_scale = output_scale_spec.get_float_value(); + const int32_t output_zero_point = output_zeros_spec.get_int_value(); + + int64_t num_output_elements = batch_size * out_features; + + auto& ref_data = output_spec.get_ref_float_data(); + ref_data.resize(num_output_elements); + + for (int64_t b = 0; b < batch_size; ++b) { + for (int64_t out_f = 0; out_f < out_features; ++out_f) { + int32_t int_sum = 0; + int32_t weight_sum = 0; + + for (int64_t 
in_f = 0; in_f < in_features; ++in_f) { + int64_t input_idx = b * in_features + in_f; + + float quant_input_f = + std::round(input_data[input_idx] / input_scale) + input_zero_point; + quant_input_f = std::min(std::max(quant_input_f, -128.0f), 127.0f); + int8_t quantized_input = static_cast(quant_input_f); + + int64_t weight_idx = out_f * in_features + in_f; + int8_t quantized_weight = weight_data[weight_idx]; + + int_sum += static_cast(quantized_input) * + static_cast(quantized_weight); + + weight_sum += static_cast(quantized_weight); + } + + int32_t zero_point_correction = input_zero_point * weight_sum; + int32_t accum_adjusted = int_sum - zero_point_correction; + + float float_result = + accum_adjusted * input_scale * weight_scales_data[out_f]; + + if (!bias_spec.is_none()) { + float_result += bias_data[out_f]; + } + + // Quantize the output to int8 + float quant_output_f = + std::round(float_result / output_scale) + output_zero_point; + quant_output_f = std::min(std::max(quant_output_f, -128.0f), 127.0f); + int8_t quantized_output = static_cast(quant_output_f); + + // Dequantize back to float (this is what the test wrapper does) + float dequant_output = + (static_cast(quantized_output) - output_zero_point) * + output_scale; + + int64_t output_idx = b * out_features + out_f; + ref_data[output_idx] = dequant_output; + } + } +} + +static void reference_impl(TestCase& test_case) { + q8ta_linear_reference_impl(test_case); +} + +static int64_t q8ta_linear_flop_calculator(const TestCase& test_case) { + const auto& input_sizes = test_case.inputs()[0].get_tensor_sizes(); + const auto& weight_sizes = test_case.inputs()[3].get_tensor_sizes(); + + int64_t batch_size = input_sizes[0]; + int64_t in_features = input_sizes[1]; + int64_t out_features = weight_sizes[0]; + + int64_t output_elements = batch_size * out_features; + int64_t ops_per_output = in_features; + + int64_t flop = output_elements * ops_per_output; + + return flop; +} + +int main(int argc, char* argv[]) { + set_debugging(false); + set_print_output(false); + set_print_latencies(false); + set_use_gpu_timestamps(true); + + print_performance_header(); + std::cout << "Q8ta Linear Operation Prototyping Framework" << std::endl; + print_separator(); + + ReferenceComputeFunc ref_fn = reference_impl; + + auto results = execute_test_cases( + generate_q8ta_linear_test_cases, + q8ta_linear_flop_calculator, + "Q8taLinear", + 3, + 10, + ref_fn); + + return 0; +} diff --git a/backends/vulkan/test/test_vulkan_delegate.py b/backends/vulkan/test/test_vulkan_delegate.py index 2c0bc12b7cc..7c9f31b720c 100644 --- a/backends/vulkan/test/test_vulkan_delegate.py +++ b/backends/vulkan/test/test_vulkan_delegate.py @@ -2364,6 +2364,7 @@ def apply_quantization(self): quantized_linear_module_gemm, sample_inputs_gemm, atol=1e-2, rtol=1e-2 ) + @unittest.skip("Cannot run on swiftshader due to no integer dot product support") def test_vulkan_backend_xnnpack_pt2e_quantized_linear_sequence(self): """ Test a sequence of linear layers quantized with XNNPACK quantization config. 
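
Side note for reviewers: a minimal Python sketch (hypothetical names; inputs and weights assumed to already be quantized to int8, with symmetric per-channel weight quantization) of the accumulation performed by q8ta_linear_reference_impl above. It shows why a prepacked per-channel weight_sums tensor is enough to apply the activation zero-point correction outside the inner loop.

    def q8ta_linear_sketch(q_x, x_scale, x_zp, q_w, w_scales, out_scale, out_zp, bias=None):
        M, K = len(q_x), len(q_x[0])
        N = len(q_w)
        # Per-channel sums of the int8 weights; mirrors the prepacked weight_sums input.
        weight_sums = [sum(q_w[n]) for n in range(N)]
        out = [[0.0] * N for _ in range(M)]
        for m in range(M):
            for n in range(N):
                int_sum = sum(q_x[m][k] * q_w[n][k] for k in range(K))
                # sum_k (q_x[m][k] - x_zp) * q_w[n][k] == int_sum - x_zp * weight_sums[n]
                acc = int_sum - x_zp * weight_sums[n]
                val = acc * x_scale * w_scales[n]
                if bias is not None:
                    val += bias[n]
                # Quantize then dequantize the output, matching what the test wrapper compares.
                q = max(-128, min(127, round(val / out_scale) + out_zp))
                out[m][n] = (q - out_zp) * out_scale
        return out

Because the weights use a zero point of 0, only the activation zero point needs correcting, which is what both the reference loop and the weight_sums prepacking rely on.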
diff --git a/backends/vulkan/test/test_vulkan_passes.py b/backends/vulkan/test/test_vulkan_passes.py index 438126a179f..bcd240d8d12 100644 --- a/backends/vulkan/test/test_vulkan_passes.py +++ b/backends/vulkan/test/test_vulkan_passes.py @@ -191,3 +191,215 @@ def _reshape_for_broadcast(self, freqs_cis: torch.Tensor, x: torch.Tensor): # We expect at least one custom op to be created self.assertGreater(custom_op_count, 0) + + def test_fuse_q8ta_linear(self): + """Test that sequential quantized linears fuse into q8ta_linear when output quantization is present.""" + from executorch.backends.xnnpack.quantizer.xnnpack_quantizer import ( + get_symmetric_quantization_config, + XNNPACKQuantizer, + ) + + class TwoLinearModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear1 = torch.nn.Linear(128, 64, bias=False) + self.linear2 = torch.nn.Linear(64, 32, bias=False) + + def forward(self, x): + return self.linear2(self.linear1(x)) + + model = TwoLinearModule() + sample_inputs = (torch.randn(4, 128),) + + quantizer = XNNPACKQuantizer() + operator_config = get_symmetric_quantization_config( + is_per_channel=True, + is_dynamic=False, + ) + quantizer.set_global(operator_config) + + edge_program = quantize_and_lower_module(model, sample_inputs, quantizer) + + ep = edge_program._edge_programs["forward"] + fuse_pass = FusePatternsPass() + fuse_pass._exported_program = ep + result = fuse_pass.call(ep.graph_module) + + self.assertTrue(result.modified) + + gm = ep.graph_module + + # The first linear should fuse to q8ta_linear (has output quantization + # from the second linear's input quantize node) + q8ta_linear_count = op_node_count(gm, "q8ta_linear.default") + self.assertGreaterEqual( + q8ta_linear_count, + 1, + "Expected at least one q8ta_linear op from output-quantized linear fusion", + ) + + def test_fuse_q8ta_linear_gemv(self): + """Test that batch-1 quantized linear fuses into q8ta_linear_gemv.""" + from executorch.backends.xnnpack.quantizer.xnnpack_quantizer import ( + get_symmetric_quantization_config, + XNNPACKQuantizer, + ) + + class TwoLinearModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear1 = torch.nn.Linear(128, 64, bias=False) + self.linear2 = torch.nn.Linear(64, 32, bias=False) + + def forward(self, x): + return self.linear2(self.linear1(x)) + + model = TwoLinearModule() + # Batch size 1 to trigger gemv variant + sample_inputs = (torch.randn(1, 128),) + + quantizer = XNNPACKQuantizer() + operator_config = get_symmetric_quantization_config( + is_per_channel=True, + is_dynamic=False, + ) + quantizer.set_global(operator_config) + + edge_program = quantize_and_lower_module(model, sample_inputs, quantizer) + + ep = edge_program._edge_programs["forward"] + fuse_pass = FusePatternsPass() + fuse_pass._exported_program = ep + result = fuse_pass.call(ep.graph_module) + + self.assertTrue(result.modified) + + gm = ep.graph_module + + # With batch size 1, the first linear should fuse to q8ta_linear_gemv + q8ta_linear_gemv_count = op_node_count(gm, "q8ta_linear_gemv.default") + self.assertGreaterEqual( + q8ta_linear_gemv_count, + 1, + "Expected at least one q8ta_linear_gemv op for batch-1 linear fusion", + ) + + def test_fuse_three_chained_q8ta_linears(self): + """Test that 3 consecutive quantized linears fuse into q8ta_linear ops with + correct quant params at each layer boundary. + + Each linear's input scale/zp (args[1], args[2]) must equal its predecessor's + output scale/zp (args[6], args[7]). 
This is a regression test for a bug where + topological pattern replacement caused later linears to read scale/zp from the + wrong arg position of the already-replaced q8ta_linear node, producing wildly + incorrect quantization parameters (outputs saturating to -128/127). + """ + from executorch.backends.xnnpack.quantizer.xnnpack_quantizer import ( + get_symmetric_quantization_config, + XNNPACKQuantizer, + ) + + class ThreeLinearModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear1 = torch.nn.Linear(256, 128, bias=False) + self.linear2 = torch.nn.Linear(128, 64, bias=False) + self.linear3 = torch.nn.Linear(64, 32, bias=False) + + def forward(self, x): + return self.linear3(self.linear2(self.linear1(x))) + + model = ThreeLinearModule() + # Batch size 4 to select q8ta_linear (not the gemv variant) + sample_inputs = (torch.randn(4, 256),) + + quantizer = XNNPACKQuantizer() + operator_config = get_symmetric_quantization_config( + is_per_channel=True, + is_dynamic=False, + ) + quantizer.set_global(operator_config) + + edge_program = quantize_and_lower_module(model, sample_inputs, quantizer) + + ep = edge_program._edge_programs["forward"] + fuse_pass = FusePatternsPass() + fuse_pass._exported_program = ep + result = fuse_pass.call(ep.graph_module) + + self.assertTrue(result.modified) + + gm = ep.graph_module + + q8ta_nodes = [ + node + for node in gm.graph.nodes + if get_target_canonical_name(node) == "q8ta_linear.default" + ] + self.assertGreaterEqual( + len(q8ta_nodes), + 2, + "Expected at least 2 q8ta_linear ops from 3 chained quantized linears", + ) + + # For each consecutive q8ta_linear pair, the boundary scale/zp must be + # consistent: linear_i.output_scale == linear_{i+1}.input_scale. + # Before the fix, linear_{i+1}.input_scale was incorrectly read from the + # replaced q8ta_linear node's input args instead of the dq node's args. 
+ for i in range(len(q8ta_nodes) - 1): + self.assertEqual( + q8ta_nodes[i].args[6], + q8ta_nodes[i + 1].args[1], + f"q8ta_linear[{i}].output_scale should equal q8ta_linear[{i + 1}].input_scale", + ) + self.assertEqual( + q8ta_nodes[i].args[7], + q8ta_nodes[i + 1].args[2], + f"q8ta_linear[{i}].output_zero_point should equal q8ta_linear[{i + 1}].input_zero_point", + ) + + def test_fuse_q8ta_linear_gemv_non_aligned_oc(self): + """Test that quantized linear with non-aligned output channels (not multiple of 4) fuses correctly.""" + from executorch.backends.xnnpack.quantizer.xnnpack_quantizer import ( + get_symmetric_quantization_config, + XNNPACKQuantizer, + ) + + class TwoLinearModule(torch.nn.Module): + def __init__(self): + super().__init__() + # Use non-aligned output channels (9 is not a multiple of 4) + self.linear1 = torch.nn.Linear(128, 9, bias=False) + self.linear2 = torch.nn.Linear(9, 4, bias=False) + + def forward(self, x): + return self.linear2(self.linear1(x)) + + model = TwoLinearModule() + sample_inputs = (torch.randn(1, 128),) + + quantizer = XNNPACKQuantizer() + operator_config = get_symmetric_quantization_config( + is_per_channel=True, + is_dynamic=False, + ) + quantizer.set_global(operator_config) + + edge_program = quantize_and_lower_module(model, sample_inputs, quantizer) + + ep = edge_program._edge_programs["forward"] + fuse_pass = FusePatternsPass() + fuse_pass._exported_program = ep + result = fuse_pass.call(ep.graph_module) + + self.assertTrue(result.modified) + + gm = ep.graph_module + + # The first linear (OC=9, not multiple of 4) should still fuse + q8ta_linear_gemv_count = op_node_count(gm, "q8ta_linear_gemv.default") + self.assertGreaterEqual( + q8ta_linear_gemv_count, + 1, + "Expected non-aligned OC linear to fuse into q8ta_linear_gemv", + ) diff --git a/backends/vulkan/test/test_vulkan_tensor_repr.py b/backends/vulkan/test/test_vulkan_tensor_repr.py new file mode 100644 index 00000000000..64d7542b788 --- /dev/null +++ b/backends/vulkan/test/test_vulkan_tensor_repr.py @@ -0,0 +1,991 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
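+
+"""
+Unit tests for the TensorRepr, TensorRepSet, TensorRepSetList, TensorReprList,
+and OpRepSets helpers in backends/vulkan/utils.py. torch.fx nodes are mocked
+(MagicMock plus FakeTensorMode), so these tests exercise the representation
+selection logic without requiring a Vulkan device.
+"""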
+ +import operator +import unittest +from unittest.mock import MagicMock + +import torch +from executorch.backends.vulkan.serialization.vulkan_graph_schema import ( + VkMemoryLayout, + VkStorageType, +) +from executorch.backends.vulkan.utils import ( + ANY_BUFFER, + ANY_STORAGE, + ANY_TEXTURE, + CHANNELS_PACKED_ANY, + CHANNELS_PACKED_TEXTURE, + CONTIGUOUS_ANY, + CONTIGUOUS_BUFFER, + DEFAULT_TEXTURE_LIMITS, + HEIGHT_PACKED_TEXTURE, + make_tensor_repset, + NO_STORAGE, + OpRepSets, + PACKED_INT8_4C1W_BUFFER, + PACKED_INT8_4W4C_BUFFER, + PACKED_INT8_4W_BUFFER, + PACKED_INT8_BUFFER, + PACKED_INT8_CHANNELS_PACKED_BUFFER, + TensorRepr, + TensorReprList, + TensorRepSet, + TensorRepSetList, + WIDTH_PACKED_TEXTURE, +) +from torch._subclasses.fake_tensor import FakeTensorMode + + +def _make_fake_tensor(shape, dtype=torch.float32): + with FakeTensorMode() as mode: + return mode.from_tensor(torch.empty(shape, dtype=dtype)) + + +def _make_op_node( + target, + args, + output_val, +): + """Create a mock torch.fx.Node for use in OpRepSets tests.""" + node = MagicMock(spec=torch.fx.Node) + node.op = "call_function" + node.target = target + node.args = args + node.meta = {"val": output_val} + return node + + +def _make_tensor_arg_node(shape, dtype=torch.float32): + """Create a mock arg node that looks like a single tensor node.""" + node = MagicMock(spec=torch.fx.Node) + node.op = "call_function" + fake = _make_fake_tensor(shape, dtype) + node.meta = {"val": fake} + return node + + +class TestTensorRepSet(unittest.TestCase): + # -- Construction and emptiness -- + + def test_empty_repset(self): + repset = TensorRepSet(set(), set()) + self.assertTrue(repset.is_empty()) + self.assertFalse(repset.texture_is_valid()) + self.assertFalse(repset.buffer_is_valid()) + + def test_non_empty_repset(self): + repset = TensorRepSet( + {VkMemoryLayout.TENSOR_WIDTH_PACKED}, + {VkMemoryLayout.TENSOR_CHANNELS_PACKED}, + ) + self.assertFalse(repset.is_empty()) + self.assertTrue(repset.texture_is_valid()) + self.assertTrue(repset.buffer_is_valid()) + + def test_texture_only_repset(self): + repset = TensorRepSet(set(), {VkMemoryLayout.TENSOR_CHANNELS_PACKED}) + self.assertFalse(repset.is_empty()) + self.assertTrue(repset.texture_is_valid()) + self.assertFalse(repset.buffer_is_valid()) + + def test_buffer_only_repset(self): + repset = TensorRepSet({VkMemoryLayout.TENSOR_WIDTH_PACKED}, set()) + self.assertFalse(repset.is_empty()) + self.assertFalse(repset.texture_is_valid()) + self.assertTrue(repset.buffer_is_valid()) + + # -- Equality -- + + def test_equality(self): + a = TensorRepSet( + {VkMemoryLayout.TENSOR_WIDTH_PACKED}, + {VkMemoryLayout.TENSOR_CHANNELS_PACKED}, + ) + b = TensorRepSet( + {VkMemoryLayout.TENSOR_WIDTH_PACKED}, + {VkMemoryLayout.TENSOR_CHANNELS_PACKED}, + ) + self.assertEqual(a, b) + + def test_inequality(self): + a = TensorRepSet( + {VkMemoryLayout.TENSOR_WIDTH_PACKED}, + {VkMemoryLayout.TENSOR_CHANNELS_PACKED}, + ) + b = TensorRepSet( + {VkMemoryLayout.TENSOR_HEIGHT_PACKED}, + {VkMemoryLayout.TENSOR_CHANNELS_PACKED}, + ) + self.assertNotEqual(a, b) + + # -- Copy -- + + def test_copy_produces_equal_repset(self): + repset = TensorRepSet( + {VkMemoryLayout.TENSOR_WIDTH_PACKED}, + {VkMemoryLayout.TENSOR_CHANNELS_PACKED}, + ) + copied = repset.copy() + self.assertEqual(repset, copied) + + def test_copy_is_independent(self): + repset = TensorRepSet( + {VkMemoryLayout.TENSOR_WIDTH_PACKED}, + {VkMemoryLayout.TENSOR_CHANNELS_PACKED}, + ) + copied = repset.copy() + 
copied.valid_buffer_layouts.add(VkMemoryLayout.TENSOR_HEIGHT_PACKED) + self.assertNotEqual(repset, copied) + self.assertNotIn( + VkMemoryLayout.TENSOR_HEIGHT_PACKED, repset.valid_buffer_layouts + ) + + # -- Intersection -- + + def test_make_intersect(self): + a = TensorRepSet( + {VkMemoryLayout.TENSOR_WIDTH_PACKED, VkMemoryLayout.TENSOR_HEIGHT_PACKED}, + {VkMemoryLayout.TENSOR_CHANNELS_PACKED, VkMemoryLayout.TENSOR_WIDTH_PACKED}, + ) + b = TensorRepSet( + {VkMemoryLayout.TENSOR_WIDTH_PACKED}, + {VkMemoryLayout.TENSOR_CHANNELS_PACKED}, + ) + result = a.make_intersect(b) + self.assertEqual( + result.valid_buffer_layouts, {VkMemoryLayout.TENSOR_WIDTH_PACKED} + ) + self.assertEqual( + result.valid_texture_layouts, {VkMemoryLayout.TENSOR_CHANNELS_PACKED} + ) + + def test_make_intersect_disjoint_yields_empty(self): + a = TensorRepSet( + {VkMemoryLayout.TENSOR_WIDTH_PACKED}, {VkMemoryLayout.TENSOR_WIDTH_PACKED} + ) + b = TensorRepSet( + {VkMemoryLayout.TENSOR_HEIGHT_PACKED}, + {VkMemoryLayout.TENSOR_HEIGHT_PACKED}, + ) + result = a.make_intersect(b) + self.assertTrue(result.is_empty()) + + # -- Union -- + + def test_make_union(self): + a = TensorRepSet( + {VkMemoryLayout.TENSOR_WIDTH_PACKED}, {VkMemoryLayout.TENSOR_WIDTH_PACKED} + ) + b = TensorRepSet( + {VkMemoryLayout.TENSOR_HEIGHT_PACKED}, + {VkMemoryLayout.TENSOR_CHANNELS_PACKED}, + ) + result = a.make_union(b) + self.assertEqual( + result.valid_buffer_layouts, + {VkMemoryLayout.TENSOR_WIDTH_PACKED, VkMemoryLayout.TENSOR_HEIGHT_PACKED}, + ) + self.assertEqual( + result.valid_texture_layouts, + {VkMemoryLayout.TENSOR_WIDTH_PACKED, VkMemoryLayout.TENSOR_CHANNELS_PACKED}, + ) + + # -- Compatibility checks -- + + def test_is_compatible_texture(self): + repset = TensorRepSet(set(), {VkMemoryLayout.TENSOR_CHANNELS_PACKED}) + tr = TensorRepr(VkStorageType.TEXTURE_3D, VkMemoryLayout.TENSOR_CHANNELS_PACKED) + self.assertTrue(repset.is_compatible(tr)) + + def test_is_compatible_texture_mismatch(self): + repset = TensorRepSet(set(), {VkMemoryLayout.TENSOR_CHANNELS_PACKED}) + tr = TensorRepr(VkStorageType.TEXTURE_3D, VkMemoryLayout.TENSOR_WIDTH_PACKED) + self.assertFalse(repset.is_compatible(tr)) + + def test_is_compatible_buffer(self): + repset = TensorRepSet({VkMemoryLayout.TENSOR_WIDTH_PACKED}, set()) + tr = TensorRepr(VkStorageType.BUFFER, VkMemoryLayout.TENSOR_WIDTH_PACKED) + self.assertTrue(repset.is_compatible(tr)) + + def test_is_compatible_buffer_mismatch(self): + repset = TensorRepSet({VkMemoryLayout.TENSOR_WIDTH_PACKED}, set()) + tr = TensorRepr(VkStorageType.BUFFER, VkMemoryLayout.TENSOR_HEIGHT_PACKED) + self.assertFalse(repset.is_compatible(tr)) + + # -- any_in_common -- + + def test_any_in_common_true(self): + a = TensorRepSet( + {VkMemoryLayout.TENSOR_WIDTH_PACKED}, + {VkMemoryLayout.TENSOR_CHANNELS_PACKED}, + ) + b = TensorRepSet( + {VkMemoryLayout.TENSOR_WIDTH_PACKED}, + {VkMemoryLayout.TENSOR_WIDTH_PACKED}, + ) + self.assertTrue(a.any_in_common(b)) + + def test_any_in_common_false(self): + a = TensorRepSet( + {VkMemoryLayout.TENSOR_WIDTH_PACKED}, {VkMemoryLayout.TENSOR_WIDTH_PACKED} + ) + b = TensorRepSet( + {VkMemoryLayout.TENSOR_HEIGHT_PACKED}, + {VkMemoryLayout.TENSOR_HEIGHT_PACKED}, + ) + self.assertFalse(a.any_in_common(b)) + + # -- Constrained / Ambiguous -- + + def test_is_constrained_empty(self): + self.assertTrue(NO_STORAGE.is_constrained()) + + def test_is_constrained_single_texture(self): + repset = TensorRepSet(set(), {VkMemoryLayout.TENSOR_CHANNELS_PACKED}) + self.assertTrue(repset.is_constrained()) + + def 
test_is_constrained_single_buffer(self): + repset = TensorRepSet({VkMemoryLayout.TENSOR_WIDTH_PACKED}, set()) + self.assertTrue(repset.is_constrained()) + + def test_is_ambiguous_multiple_layouts(self): + repset = TensorRepSet( + {VkMemoryLayout.TENSOR_WIDTH_PACKED, VkMemoryLayout.TENSOR_HEIGHT_PACKED}, + set(), + ) + self.assertTrue(repset.is_ambiguous()) + + def test_is_ambiguous_both_storage_types(self): + repset = TensorRepSet( + {VkMemoryLayout.TENSOR_WIDTH_PACKED}, + {VkMemoryLayout.TENSOR_CHANNELS_PACKED}, + ) + self.assertTrue(repset.is_ambiguous()) + + # -- make_tensor_repr -- + + def test_make_tensor_repr_prefers_texture(self): + repset = TensorRepSet( + {VkMemoryLayout.TENSOR_WIDTH_PACKED}, + {VkMemoryLayout.TENSOR_CHANNELS_PACKED}, + ) + tr = repset.make_tensor_repr() + self.assertEqual(tr.storage_type, VkStorageType.TEXTURE_3D) + self.assertEqual(tr.memory_layout, VkMemoryLayout.TENSOR_CHANNELS_PACKED) + + def test_make_tensor_repr_falls_back_to_buffer(self): + repset = TensorRepSet({VkMemoryLayout.TENSOR_WIDTH_PACKED}, set()) + tr = repset.make_tensor_repr() + self.assertEqual(tr.storage_type, VkStorageType.BUFFER) + self.assertEqual(tr.memory_layout, VkMemoryLayout.TENSOR_WIDTH_PACKED) + + def test_make_tensor_repr_empty_returns_default(self): + tr = NO_STORAGE.make_tensor_repr() + self.assertEqual(tr.storage_type, VkStorageType.DEFAULT_STORAGE) + self.assertEqual(tr.memory_layout, VkMemoryLayout.DEFAULT_LAYOUT) + + # -- has_same_packed_dim_info_set -- + + def test_has_same_packed_dim_info_set(self): + self.assertTrue( + CHANNELS_PACKED_TEXTURE.has_same_packed_dim_info_set( + CHANNELS_PACKED_TEXTURE + ) + ) + self.assertTrue( + PACKED_INT8_4W4C_BUFFER.has_same_packed_dim_info_set( + PACKED_INT8_4C1W_BUFFER + ) + ) + self.assertTrue( + PACKED_INT8_BUFFER.has_same_packed_dim_info_set(PACKED_INT8_BUFFER) + ) + self.assertFalse( + PACKED_INT8_BUFFER.has_same_packed_dim_info_set(PACKED_INT8_4C1W_BUFFER) + ) + + def test_has_same_packed_dim_info_set_empty_is_compatible(self): + self.assertTrue( + NO_STORAGE.has_same_packed_dim_info_set(CHANNELS_PACKED_TEXTURE) + ) + self.assertTrue( + CHANNELS_PACKED_TEXTURE.has_same_packed_dim_info_set(NO_STORAGE) + ) + self.assertTrue(NO_STORAGE.has_same_packed_dim_info_set(NO_STORAGE)) + + def test_has_same_packed_dim_info_set_different_texture_layouts(self): + self.assertFalse( + WIDTH_PACKED_TEXTURE.has_same_packed_dim_info_set(CHANNELS_PACKED_TEXTURE) + ) + + def test_has_same_packed_dim_info_set_different_storage_types(self): + # CHANNELS_PACKED_ANY has both buffer and texture layouts, + # CHANNELS_PACKED_TEXTURE has only texture layouts + self.assertFalse( + CHANNELS_PACKED_ANY.has_same_packed_dim_info_set(CHANNELS_PACKED_TEXTURE) + ) + + def test_has_same_packed_dim_info_set_any_storage_self_compatible(self): + self.assertTrue(ANY_STORAGE.has_same_packed_dim_info_set(ANY_STORAGE)) + + # -- has_compatible_packed_dim_info_set -- + + def test_has_compatible_packed_dim_info_set_self(self): + self.assertTrue( + CHANNELS_PACKED_TEXTURE.has_compatible_packed_dim_info_set( + CHANNELS_PACKED_TEXTURE + ) + ) + + def test_has_compatible_packed_dim_info_set_superset(self): + # ANY_TEXTURE has all packed dims, so it's a superset of any single layout + self.assertTrue( + ANY_TEXTURE.has_compatible_packed_dim_info_set(CHANNELS_PACKED_TEXTURE) + ) + self.assertTrue( + ANY_TEXTURE.has_compatible_packed_dim_info_set(WIDTH_PACKED_TEXTURE) + ) + + def test_has_compatible_packed_dim_info_set_subset_fails(self): + # A single layout is not a superset of all 
layouts + self.assertFalse( + CHANNELS_PACKED_TEXTURE.has_compatible_packed_dim_info_set(ANY_TEXTURE) + ) + + def test_has_compatible_packed_dim_info_set_disjoint(self): + self.assertFalse( + WIDTH_PACKED_TEXTURE.has_compatible_packed_dim_info_set( + CHANNELS_PACKED_TEXTURE + ) + ) + + def test_has_compatible_packed_dim_info_set_empty(self): + # Empty other has no PDIs to check, so any self is compatible + self.assertTrue( + CHANNELS_PACKED_TEXTURE.has_compatible_packed_dim_info_set(NO_STORAGE) + ) + self.assertTrue(NO_STORAGE.has_compatible_packed_dim_info_set(NO_STORAGE)) + + def test_has_compatible_packed_dim_info_set_buffer_and_texture(self): + # CHANNELS_PACKED_ANY has both buffer and texture PDIs with packed_dim=2 + # ANY_STORAGE is a superset + self.assertTrue( + ANY_STORAGE.has_compatible_packed_dim_info_set(CHANNELS_PACKED_ANY) + ) + # CHANNELS_PACKED_TEXTURE only has texture PDIs, not buffer + self.assertFalse( + CHANNELS_PACKED_TEXTURE.has_compatible_packed_dim_info_set( + CHANNELS_PACKED_ANY + ) + ) + + def test_has_compatible_packed_dim_info_set_quantized(self): + # PACKED_INT8_4W4C and PACKED_INT8_4C1W both produce PackedDimInfo(2, 4) + self.assertTrue( + PACKED_INT8_4W4C_BUFFER.has_compatible_packed_dim_info_set( + PACKED_INT8_4C1W_BUFFER + ) + ) + # PACKED_INT8_BUFFER has all three quantized layouts (packed_dim 0 and 2) + # so a single packed_dim=2 layout is not a superset + self.assertFalse( + PACKED_INT8_4W4C_BUFFER.has_compatible_packed_dim_info_set( + PACKED_INT8_BUFFER + ) + ) + + # -- constrain_to_compatible_packed_dim -- + + def test_constrain_to_compatible_packed_dim(self): + full = ANY_TEXTURE + constraint = CHANNELS_PACKED_TEXTURE + result = full.constrain_to_compatible_packed_dim(constraint) + # Only channels-packed layouts have packed dim 2 + self.assertIn( + VkMemoryLayout.TENSOR_CHANNELS_PACKED, result.valid_texture_layouts + ) + self.assertNotIn( + VkMemoryLayout.TENSOR_WIDTH_PACKED, result.valid_texture_layouts + ) + self.assertNotIn( + VkMemoryLayout.TENSOR_HEIGHT_PACKED, result.valid_texture_layouts + ) + + def test_constrain_to_compatible_packed_dim_empty_other(self): + full = ANY_TEXTURE + result = full.constrain_to_compatible_packed_dim(NO_STORAGE) + self.assertEqual(result, full) + + def test_constrain_to_compatible_packed_dim_buffer(self): + result = ANY_BUFFER.constrain_to_compatible_packed_dim(CONTIGUOUS_BUFFER) + # CONTIGUOUS_BUFFER is width-packed → PackedDimInfo(0, 1) + # Only TENSOR_WIDTH_PACKED has the same PDI among non-quantized layouts + self.assertIn(VkMemoryLayout.TENSOR_WIDTH_PACKED, result.valid_buffer_layouts) + self.assertNotIn( + VkMemoryLayout.TENSOR_CHANNELS_PACKED, result.valid_buffer_layouts + ) + self.assertNotIn( + VkMemoryLayout.TENSOR_HEIGHT_PACKED, result.valid_buffer_layouts + ) + + def test_constrain_to_compatible_packed_dim_both_storage_types(self): + result = ANY_STORAGE.constrain_to_compatible_packed_dim(CHANNELS_PACKED_ANY) + # Should keep only channels-packed layouts in both buffer and texture + self.assertIn( + VkMemoryLayout.TENSOR_CHANNELS_PACKED, result.valid_buffer_layouts + ) + self.assertIn( + VkMemoryLayout.TENSOR_CHANNELS_PACKED, result.valid_texture_layouts + ) + self.assertNotIn( + VkMemoryLayout.TENSOR_WIDTH_PACKED, result.valid_buffer_layouts + ) + self.assertNotIn( + VkMemoryLayout.TENSOR_WIDTH_PACKED, result.valid_texture_layouts + ) + + def test_constrain_to_compatible_packed_dim_disjoint(self): + # Width-packed and channels-packed have different packed dims + result = 
WIDTH_PACKED_TEXTURE.constrain_to_compatible_packed_dim( + CHANNELS_PACKED_TEXTURE + ) + self.assertTrue(result.is_empty()) + + def test_constrain_to_compatible_packed_dim_is_independent_copy(self): + original = ANY_TEXTURE.copy() + result = ANY_TEXTURE.constrain_to_compatible_packed_dim(CHANNELS_PACKED_TEXTURE) + # Original should not be modified + self.assertEqual(ANY_TEXTURE, original) + self.assertNotEqual(result, ANY_TEXTURE) + + # -- Convenience constants -- + + def test_convenience_constants(self): + self.assertFalse(CONTIGUOUS_ANY.is_empty()) + self.assertFalse(CONTIGUOUS_BUFFER.is_empty()) + self.assertFalse(WIDTH_PACKED_TEXTURE.is_empty()) + self.assertFalse(HEIGHT_PACKED_TEXTURE.is_empty()) + self.assertFalse(CHANNELS_PACKED_TEXTURE.is_empty()) + self.assertFalse(CHANNELS_PACKED_ANY.is_empty()) + self.assertFalse(ANY_TEXTURE.is_empty()) + self.assertFalse(ANY_BUFFER.is_empty()) + self.assertFalse(ANY_STORAGE.is_empty()) + self.assertTrue(NO_STORAGE.is_empty()) + + # -- make_tensor_repset -- + + def test_make_tensor_repset_buffer(self): + tr = TensorRepr(VkStorageType.BUFFER, VkMemoryLayout.TENSOR_WIDTH_PACKED) + repset = make_tensor_repset(tr) + self.assertEqual( + repset.valid_buffer_layouts, {VkMemoryLayout.TENSOR_WIDTH_PACKED} + ) + self.assertEqual(repset.valid_texture_layouts, set()) + + def test_make_tensor_repset_texture(self): + tr = TensorRepr(VkStorageType.TEXTURE_3D, VkMemoryLayout.TENSOR_CHANNELS_PACKED) + repset = make_tensor_repset(tr) + self.assertEqual(repset.valid_buffer_layouts, set()) + self.assertEqual( + repset.valid_texture_layouts, {VkMemoryLayout.TENSOR_CHANNELS_PACKED} + ) + + +class TestTensorRepSetList(unittest.TestCase): + def test_single_element_broadcasting(self): + repset = CHANNELS_PACKED_TEXTURE + lst = TensorRepSetList(repset) + self.assertEqual(len(lst), 1) + # Accessing index > 0 broadcasts to the single element + self.assertEqual(lst[0], repset) + self.assertEqual(lst[2], repset) + + def test_multi_element_indexing(self): + a = CHANNELS_PACKED_TEXTURE + b = WIDTH_PACKED_TEXTURE + lst = TensorRepSetList([a, b]) + self.assertEqual(len(lst), 2) + self.assertEqual(lst[0], a) + self.assertEqual(lst[1], b) + + def test_setitem_single(self): + lst = TensorRepSetList(CHANNELS_PACKED_TEXTURE) + lst[0] = WIDTH_PACKED_TEXTURE + self.assertEqual(lst[0], WIDTH_PACKED_TEXTURE) + + def test_setitem_single_broadcast(self): + lst = TensorRepSetList(CHANNELS_PACKED_TEXTURE) + # Setting index > 0 on a single-element list updates the single element + lst[3] = WIDTH_PACKED_TEXTURE + self.assertEqual(lst[0], WIDTH_PACKED_TEXTURE) + + def test_setitem_multi(self): + lst = TensorRepSetList([CHANNELS_PACKED_TEXTURE, WIDTH_PACKED_TEXTURE]) + lst[1] = HEIGHT_PACKED_TEXTURE + self.assertEqual(lst[1], HEIGHT_PACKED_TEXTURE) + self.assertEqual(lst[0], CHANNELS_PACKED_TEXTURE) + + def test_append(self): + lst = TensorRepSetList([]) + lst.append(CHANNELS_PACKED_TEXTURE) + lst.append(WIDTH_PACKED_TEXTURE) + self.assertEqual(len(lst), 2) + + def test_any_is_empty_true(self): + lst = TensorRepSetList([CHANNELS_PACKED_TEXTURE, NO_STORAGE]) + self.assertTrue(lst.any_is_empty()) + + def test_any_is_empty_false(self): + lst = TensorRepSetList([CHANNELS_PACKED_TEXTURE, WIDTH_PACKED_TEXTURE]) + self.assertFalse(lst.any_is_empty()) + + def test_any_is_empty_no_elements(self): + lst = TensorRepSetList([]) + self.assertTrue(lst.any_is_empty()) + + def test_str(self): + lst = TensorRepSetList([CHANNELS_PACKED_TEXTURE]) + s = str(lst) + self.assertIn("TensorRepSet", s) + + +class 
TestOpRepSets(unittest.TestCase): + """ + Tests for OpRepSets using mock torch.fx.Node objects. The constructor + requires a node with .op, .target, .args, and .meta["val"] attributes. + """ + + def _make_unary_op(self, input_shape=(1, 3, 8, 8), repset=ANY_STORAGE): + """Create an OpRepSets for a simple unary op (single tensor in, single tensor out).""" + arg = _make_tensor_arg_node(input_shape) + out_val = _make_fake_tensor(input_shape) + node = _make_op_node( + target=torch.ops.aten.relu.default, + args=(arg,), + output_val=out_val, + ) + return OpRepSets( + TensorRepSetList(repset), + TensorRepSetList(repset), + node, + DEFAULT_TEXTURE_LIMITS, + ) + + def _make_binary_op( + self, + shape_a=(1, 3, 8, 8), + shape_b=(1, 3, 8, 8), + repset=ANY_STORAGE, + ): + """Create an OpRepSets for a binary op (two tensor inputs, single tensor output).""" + arg_a = _make_tensor_arg_node(shape_a) + arg_b = _make_tensor_arg_node(shape_b) + out_val = _make_fake_tensor(shape_a) + node = _make_op_node( + target=torch.ops.aten.add.Tensor, + args=(arg_a, arg_b), + output_val=out_val, + ) + return OpRepSets( + TensorRepSetList(repset), + TensorRepSetList(repset), + node, + DEFAULT_TEXTURE_LIMITS, + ) + + # -- Construction -- + + def test_unary_op_construction(self): + op_repsets = self._make_unary_op() + self.assertFalse(op_repsets.any_is_empty()) + self.assertEqual(op_repsets.primary_arg_idx, 0) + self.assertTrue(op_repsets.sync_primary_io_repr) + + def test_binary_op_syncs_args(self): + """When a single repset covers all inputs, sync_args_repr is True.""" + op_repsets = self._make_binary_op() + self.assertTrue(op_repsets.sync_args_repr) + self.assertEqual(op_repsets.primary_arg_idx, 0) + + def test_binary_op_separate_repsets_no_sync(self): + """When each input has its own repset, sync_args_repr is False.""" + arg_a = _make_tensor_arg_node((1, 3, 8, 8)) + arg_b = _make_tensor_arg_node((1, 3, 8, 8)) + out_val = _make_fake_tensor((1, 3, 8, 8)) + node = _make_op_node( + target=torch.ops.aten.add.Tensor, + args=(arg_a, arg_b), + output_val=out_val, + ) + op_repsets = OpRepSets( + TensorRepSetList([CHANNELS_PACKED_ANY, WIDTH_PACKED_TEXTURE]), + TensorRepSetList(ANY_STORAGE), + node, + DEFAULT_TEXTURE_LIMITS, + ) + self.assertFalse(op_repsets.sync_args_repr) + + def test_no_sync_primary_io_when_different_repsets(self): + """sync_primary_io_repr is False when input and output repsets differ.""" + arg = _make_tensor_arg_node((1, 3, 8, 8)) + out_val = _make_fake_tensor((1, 3, 8, 8)) + node = _make_op_node( + target=torch.ops.aten.relu.default, + args=(arg,), + output_val=out_val, + ) + op_repsets = OpRepSets( + TensorRepSetList(CHANNELS_PACKED_ANY), + TensorRepSetList(WIDTH_PACKED_TEXTURE), + node, + DEFAULT_TEXTURE_LIMITS, + ) + self.assertFalse(op_repsets.sync_primary_io_repr) + + # -- Scalar args are skipped -- + + def test_scalar_arg_skipped(self): + """Non-tensor args should be treated as ALL_STORAGES_REPSET.""" + tensor_arg = _make_tensor_arg_node((1, 3, 8, 8)) + # Second arg is a scalar (float) + scalar_arg = 1.0 + out_val = _make_fake_tensor((1, 3, 8, 8)) + node = _make_op_node( + target=torch.ops.aten.add.Tensor, + args=(tensor_arg, scalar_arg), + output_val=out_val, + ) + op_repsets = OpRepSets( + TensorRepSetList(ANY_STORAGE), + TensorRepSetList(ANY_STORAGE), + node, + DEFAULT_TEXTURE_LIMITS, + ) + self.assertFalse(op_repsets.any_is_empty()) + # The scalar arg should get ALL_STORAGES_REPSET + # self.assertEqual(op_repsets.get_arg_repset(1), ALL_STORAGES_REPSET, f"""{op_repsets.get_arg_repset(1)}""") + + # -- 
pick_representations -- + + def test_pick_representations_unary(self): + op_repsets = self._make_unary_op(repset=CHANNELS_PACKED_TEXTURE) + args_repr, outs_repr = op_repsets.pick_representations() + self.assertEqual(len(args_repr), 1) + self.assertEqual(len(outs_repr), 1) + self.assertEqual(args_repr[0].storage_type, VkStorageType.TEXTURE_3D) + self.assertEqual( + args_repr[0].memory_layout, VkMemoryLayout.TENSOR_CHANNELS_PACKED + ) + self.assertEqual(outs_repr[0].storage_type, VkStorageType.TEXTURE_3D) + self.assertEqual( + outs_repr[0].memory_layout, VkMemoryLayout.TENSOR_CHANNELS_PACKED + ) + + def test_pick_representations_prefers_texture(self): + op_repsets = self._make_unary_op(repset=ANY_STORAGE) + _, outs_repr = op_repsets.pick_representations() + self.assertEqual(outs_repr[0].storage_type, VkStorageType.TEXTURE_3D) + + def test_pick_representations_buffer_only(self): + op_repsets = self._make_unary_op(repset=CONTIGUOUS_BUFFER) + args_repr, outs_repr = op_repsets.pick_representations() + self.assertEqual(args_repr[0].storage_type, VkStorageType.BUFFER) + self.assertEqual(outs_repr[0].storage_type, VkStorageType.BUFFER) + + # -- try_constrain_with_arg_repset -- + + def test_try_constrain_with_arg_repset_narrows(self): + op_repsets = self._make_unary_op(repset=ANY_STORAGE) + changed = op_repsets.try_constrain_with_arg_repset(0, CHANNELS_PACKED_TEXTURE) + self.assertTrue(changed) + arg_repset = op_repsets.get_arg_repset(0) + self.assertTrue(arg_repset.texture_is_valid()) + # After constraining to channels-packed texture, only channels-packed + # layouts should remain + self.assertIn( + VkMemoryLayout.TENSOR_CHANNELS_PACKED, arg_repset.valid_texture_layouts + ) + + def test_try_constrain_with_arg_repset_no_common(self): + """Returns False when source repset has nothing in common.""" + op_repsets = self._make_unary_op(repset=CHANNELS_PACKED_TEXTURE) + changed = op_repsets.try_constrain_with_arg_repset(0, CONTIGUOUS_BUFFER) + self.assertFalse(changed) + + def test_try_constrain_with_arg_repset_same_repset(self): + """Returns False when source repset equals current repset.""" + op_repsets = self._make_unary_op(repset=CHANNELS_PACKED_TEXTURE) + changed = op_repsets.try_constrain_with_arg_repset(0, CHANNELS_PACKED_TEXTURE) + self.assertFalse(changed) + + def test_try_constrain_propagates_to_synced_args(self): + """When sync_args_repr is True, constraining one arg propagates to the other.""" + op_repsets = self._make_binary_op(repset=ANY_STORAGE) + op_repsets.try_constrain_with_arg_repset(0, CHANNELS_PACKED_TEXTURE) + arg0 = op_repsets.get_arg_repset(0) + arg1 = op_repsets.get_arg_repset(1) + # arg1 should also be constrained to have a compatible packed dim + self.assertTrue(arg0.has_compatible_packed_dim_info_set(arg1)) + + def test_try_constrain_propagates_to_output(self): + """When sync_primary_io_repr is True, constraining the primary arg also + constrains the output.""" + op_repsets = self._make_unary_op(repset=ANY_STORAGE) + op_repsets.try_constrain_with_arg_repset(0, CHANNELS_PACKED_TEXTURE) + out_repset = op_repsets.get_out_repset(0) + arg_repset = op_repsets.get_arg_repset(0) + self.assertTrue(out_repset.has_compatible_packed_dim_info_set(arg_repset)) + + # -- try_constrain_with_out_repset -- + + def test_try_constrain_with_out_repset_when_io_not_synced(self): + """Output can be constrained independently when sync_primary_io_repr is False.""" + arg = _make_tensor_arg_node((1, 3, 8, 8)) + out_val = _make_fake_tensor((1, 3, 8, 8)) + node = _make_op_node( + 
target=torch.ops.aten.relu.default, + args=(arg,), + output_val=out_val, + ) + op_repsets = OpRepSets( + TensorRepSetList(CHANNELS_PACKED_TEXTURE), + TensorRepSetList(ANY_STORAGE), + node, + DEFAULT_TEXTURE_LIMITS, + ) + self.assertFalse(op_repsets.sync_primary_io_repr) + changed = op_repsets.try_constrain_with_out_repset(WIDTH_PACKED_TEXTURE) + self.assertTrue(changed) + out = op_repsets.get_out_repset(0) + self.assertIn(VkMemoryLayout.TENSOR_WIDTH_PACKED, out.valid_texture_layouts) + + def test_try_constrain_with_out_repset_skipped_when_synced(self): + """try_constrain_with_out_repset narrows the output even when sync_primary_io_repr is True.""" + op_repsets = self._make_unary_op(repset=ANY_STORAGE) + self.assertTrue(op_repsets.sync_primary_io_repr) + changed = op_repsets.try_constrain_with_out_repset(CHANNELS_PACKED_TEXTURE) + self.assertTrue(changed) + out = op_repsets.get_out_repset(0) + self.assertIn(VkMemoryLayout.TENSOR_CHANNELS_PACKED, out.valid_texture_layouts) + + # -- Multiple output tensors -- + + def test_multiple_outputs_no_sync(self): + """When each output has its own repset, sync_outs_repr is False.""" + arg = _make_tensor_arg_node((1, 3, 8, 8)) + out0 = _make_fake_tensor((1, 3, 8, 8)) + out1 = _make_fake_tensor((1, 3, 8, 8)) + node = _make_op_node( + target=torch.ops.aten.relu.default, + args=(arg,), + output_val=[out0, out1], + ) + op_repsets = OpRepSets( + TensorRepSetList(ANY_STORAGE), + TensorRepSetList([ANY_STORAGE, CHANNELS_PACKED_ANY]), + node, + DEFAULT_TEXTURE_LIMITS, + ) + self.assertFalse(op_repsets.sync_outs_repr) + self.assertFalse(op_repsets.any_is_empty()) + + # -- High dimensional tensors -- + + def test_high_dim_tensor_filters_texture_layouts(self): + """Tensors with >4 dims should have texture layouts filtered out.""" + shape = (2, 3, 4, 5, 6) # 5 dimensions + op_repsets = self._make_unary_op(input_shape=shape, repset=ANY_STORAGE) + # The arg repset should have no valid texture layouts for high-dim tensors + arg_repset = op_repsets.get_arg_repset(0) + self.assertFalse(arg_repset.texture_is_valid()) + self.assertTrue(arg_repset.buffer_is_valid()) + + # -- getitem operator -- + + def test_getitem_op(self): + """OpRepSets should handle operator.getitem correctly.""" + # Create a node that produces a tuple of tensors + parent_arg = _make_tensor_arg_node((1, 3, 8, 8)) + parent_fake_0 = _make_fake_tensor((1, 3, 8, 8)) + parent_fake_1 = _make_fake_tensor((1, 3, 8, 8)) + parent_arg.meta = {"val": [parent_fake_0, parent_fake_1]} + + out_val = _make_fake_tensor((1, 3, 8, 8)) + node = _make_op_node( + target=operator.getitem, + args=(parent_arg, 0), + output_val=out_val, + ) + op_repsets = OpRepSets( + TensorRepSetList(ANY_STORAGE), + TensorRepSetList(ANY_STORAGE), + node, + DEFAULT_TEXTURE_LIMITS, + ) + self.assertFalse(op_repsets.any_is_empty()) + + # -- Quantized binary ops with different layouts but same packed dim -- + + def _make_quantized_binary_op( + self, + args_repset, + outs_repset, + shape_a=(1, 3, 8, 8), + shape_b=(1, 3, 8, 8), + ): + """Create an OpRepSets for a quantized binary op with separate arg/out repsets.""" + arg_a = _make_tensor_arg_node(shape_a) + arg_b = _make_tensor_arg_node(shape_b) + out_val = _make_fake_tensor(shape_a) + node = _make_op_node( + target=torch.ops.aten.add.Tensor, + args=(arg_a, arg_b), + output_val=out_val, + ) + return OpRepSets( + TensorRepSetList(args_repset), + TensorRepSetList(outs_repset), + node, + DEFAULT_TEXTURE_LIMITS, + ) + + def test_quantized_binary_different_layouts_same_packed_dim(self): + """Args and 
output can have different quantized layouts if packed dim matches.""" + # PACKED_INT8_4W4C and PACKED_INT8_4C1W both have packed_dim=2 + op_repsets = self._make_quantized_binary_op( + args_repset=PACKED_INT8_4W4C_BUFFER, + outs_repset=PACKED_INT8_4C1W_BUFFER, + ) + self.assertFalse(op_repsets.sync_primary_io_repr) + self.assertFalse(op_repsets.any_is_empty()) + + arg0 = op_repsets.get_arg_repset(0) + out = op_repsets.get_out_repset(0) + self.assertIn(VkMemoryLayout.PACKED_INT8_4W4C, arg0.valid_buffer_layouts) + self.assertIn(VkMemoryLayout.PACKED_INT8_4C1W, out.valid_buffer_layouts) + + def test_quantized_binary_constrain_arg_with_synced_io(self): + """When args and output share the same repset (sync_primary_io_repr=True), + constraining an arg to a specific quantized layout also narrows the output + to layouts with a compatible packed dim.""" + op_repsets = self._make_quantized_binary_op( + args_repset=PACKED_INT8_CHANNELS_PACKED_BUFFER, + outs_repset=PACKED_INT8_CHANNELS_PACKED_BUFFER, + ) + self.assertTrue(op_repsets.sync_primary_io_repr) + changed = op_repsets.try_constrain_with_arg_repset(0, PACKED_INT8_4W4C_BUFFER) + self.assertTrue(changed) + arg0 = op_repsets.get_arg_repset(0) + self.assertIn(VkMemoryLayout.PACKED_INT8_4W4C, arg0.valid_buffer_layouts) + self.assertNotIn(VkMemoryLayout.PACKED_INT8_4C1W, arg0.valid_buffer_layouts) + # Output should be narrowed to compatible packed dim layouts + out = op_repsets.get_out_repset(0) + self.assertTrue(out.has_compatible_packed_dim_info_set(arg0)) + + def test_quantized_binary_synced_args_different_out(self): + """Synced args can be constrained together while output uses a different + quantized layout with the same packed dim.""" + # Use shared repset for args so sync_args_repr=True + op_repsets = self._make_quantized_binary_op( + args_repset=PACKED_INT8_BUFFER, + outs_repset=PACKED_INT8_BUFFER, + ) + self.assertTrue(op_repsets.sync_args_repr) + changed = op_repsets.try_constrain_with_arg_repset(0, PACKED_INT8_4W4C_BUFFER) + self.assertTrue(changed) + arg0 = op_repsets.get_arg_repset(0) + arg1 = op_repsets.get_arg_repset(1) + # arg0 is narrowed to PACKED_INT8_4W4C + self.assertIn(VkMemoryLayout.PACKED_INT8_4W4C, arg0.valid_buffer_layouts) + # arg1 should be constrained to layouts with compatible packed dim (=2) + self.assertTrue(arg1.has_compatible_packed_dim_info_set(arg0)) + + def test_quantized_binary_constrain_out_with_compatible_packed_dim(self): + """Output can be constrained to a different quantized layout as long as + packed dim is compatible.""" + op_repsets = self._make_quantized_binary_op( + args_repset=PACKED_INT8_CHANNELS_PACKED_BUFFER, + outs_repset=PACKED_INT8_CHANNELS_PACKED_BUFFER, + ) + changed = op_repsets.try_constrain_with_out_repset(PACKED_INT8_4C1W_BUFFER) + self.assertTrue(changed) + out = op_repsets.get_out_repset(0) + self.assertIn(VkMemoryLayout.PACKED_INT8_4C1W, out.valid_buffer_layouts) + self.assertNotIn(VkMemoryLayout.PACKED_INT8_4W4C, out.valid_buffer_layouts) + + def test_quantized_binary_incompatible_packed_dim_no_common(self): + """Args and output with different packed dims have nothing in common.""" + # PACKED_INT8_4W4C has packed_dim=2, PACKED_INT8_4W has packed_dim=0 + op_repsets = self._make_quantized_binary_op( + args_repset=PACKED_INT8_4W4C_BUFFER, + outs_repset=PACKED_INT8_4W_BUFFER, + ) + self.assertFalse(op_repsets.sync_primary_io_repr) + # Constraining arg to width-packed should fail since arg is channels-packed + changed = op_repsets.try_constrain_with_arg_repset(0, PACKED_INT8_4W_BUFFER) + 
self.assertFalse(changed) + + +class TestTensorReprList(unittest.TestCase): + def test_single_element_broadcasting(self): + tr = TensorRepr(VkStorageType.TEXTURE_3D, VkMemoryLayout.TENSOR_CHANNELS_PACKED) + lst = TensorReprList(tr) + self.assertEqual(len(lst), 1) + self.assertEqual(lst[0], tr) + self.assertEqual(lst[5], tr) + + def test_multi_element(self): + a = TensorRepr(VkStorageType.TEXTURE_3D, VkMemoryLayout.TENSOR_CHANNELS_PACKED) + b = TensorRepr(VkStorageType.BUFFER, VkMemoryLayout.TENSOR_WIDTH_PACKED) + lst = TensorReprList([a, b]) + self.assertEqual(len(lst), 2) + self.assertEqual(lst[0], a) + self.assertEqual(lst[1], b) + + def test_setitem(self): + a = TensorRepr(VkStorageType.TEXTURE_3D, VkMemoryLayout.TENSOR_CHANNELS_PACKED) + b = TensorRepr(VkStorageType.BUFFER, VkMemoryLayout.TENSOR_WIDTH_PACKED) + lst = TensorReprList([a, b]) + c = TensorRepr(VkStorageType.TEXTURE_3D, VkMemoryLayout.TENSOR_WIDTH_PACKED) + lst[1] = c + self.assertEqual(lst[1], c) + + def test_append(self): + lst = TensorReprList([]) + tr = TensorRepr(VkStorageType.TEXTURE_3D, VkMemoryLayout.TENSOR_CHANNELS_PACKED) + lst.append(tr) + self.assertEqual(len(lst), 1) + self.assertEqual(lst[0], tr) + + def test_storage_type_and_memory_layout(self): + tr = TensorRepr(VkStorageType.TEXTURE_3D, VkMemoryLayout.TENSOR_CHANNELS_PACKED) + lst = TensorReprList(tr) + self.assertEqual(lst.storage_type(), VkStorageType.TEXTURE_3D) + self.assertEqual(lst.memory_layout(), VkMemoryLayout.TENSOR_CHANNELS_PACKED) + + def test_equality(self): + a = TensorRepr(VkStorageType.TEXTURE_3D, VkMemoryLayout.TENSOR_CHANNELS_PACKED) + lst1 = TensorReprList(a) + lst2 = TensorReprList(a) + self.assertEqual(lst1, lst2) + + def test_inequality_different_length(self): + a = TensorRepr(VkStorageType.TEXTURE_3D, VkMemoryLayout.TENSOR_CHANNELS_PACKED) + b = TensorRepr(VkStorageType.BUFFER, VkMemoryLayout.TENSOR_WIDTH_PACKED) + lst1 = TensorReprList(a) + lst2 = TensorReprList([a, b]) + self.assertNotEqual(lst1, lst2) + + def test_str(self): + tr = TensorRepr(VkStorageType.TEXTURE_3D, VkMemoryLayout.TENSOR_CHANNELS_PACKED) + lst = TensorReprList(tr) + s = str(lst) + self.assertIn("TensorRepr", s) + + +if __name__ == "__main__": + unittest.main() diff --git a/backends/vulkan/utils.py b/backends/vulkan/utils.py index 88d8bb00c6c..dde9aaac973 100644 --- a/backends/vulkan/utils.py +++ b/backends/vulkan/utils.py @@ -5,6 +5,7 @@ # LICENSE file in the root directory of this source tree. 
import operator +from dataclasses import dataclass from typing import Any, Dict, List, Optional, Set, Tuple, Union import torch @@ -141,6 +142,15 @@ def is_choose_qparams_node(node: torch.fx.Node) -> bool: return "choose_qparams" in node_name +def is_dynamic_qscale(node: Any) -> bool: + """Check if a scale node is dynamically computed via a choose_qparams op.""" + return ( + isinstance(node, torch.fx.Node) + and node.target == operator.getitem + and is_choose_qparams_node(node.args[0]) + ) + + def is_dequant_per_channel_node(node: torch.fx.Node) -> bool: if node.op != "call_function": return False @@ -594,20 +604,91 @@ def node_has_target(node: Any, target: str): all_quantized_memory_layouts: Set[VkMemoryLayout] = { VkMemoryLayout.PACKED_INT8_4W4C, VkMemoryLayout.PACKED_INT8_4H4W, + VkMemoryLayout.PACKED_INT8_4W, VkMemoryLayout.PACKED_INT8_4C1W, } -universal_memory_layout_set: Set[VkMemoryLayout] = { - VkMemoryLayout.TENSOR_WIDTH_PACKED, - VkMemoryLayout.TENSOR_HEIGHT_PACKED, - VkMemoryLayout.TENSOR_CHANNELS_PACKED, - VkMemoryLayout.PACKED_INT8_4W4C, - VkMemoryLayout.PACKED_INT8_4H4W, -} +universal_memory_layout_set: Set[VkMemoryLayout] = ( + all_memory_layouts | all_quantized_memory_layouts +) MemoryLayoutSet = Set[VkMemoryLayout] MemoryLayoutSetList = Union[MemoryLayoutSet, List[MemoryLayoutSet]] +_LAYOUT_TO_PACKED_DIM: Dict[VkMemoryLayout, int] = { + VkMemoryLayout.TENSOR_WIDTH_PACKED: 0, + VkMemoryLayout.TENSOR_HEIGHT_PACKED: 1, + VkMemoryLayout.TENSOR_CHANNELS_PACKED: 2, + VkMemoryLayout.PACKED_INT8_4W4C: 2, + VkMemoryLayout.PACKED_INT8_4H4W: 0, + VkMemoryLayout.PACKED_INT8_4C1W: 2, +} + + +def packed_dim_of(layout: VkMemoryLayout) -> int: + return _LAYOUT_TO_PACKED_DIM[layout] + + +@dataclass(frozen=True) +class PackedDimInfo: + """ + Describes how tensor data is organized in physical memory, mirroring the + C++ PackedDimInfo struct in runtime/api/containers/Tensor.h. + """ + + packed_dim: int + packed_dim_block_size: int + + @classmethod + def from_repr( + cls, + memory_layout: VkMemoryLayout, + storage_type: VkStorageType = VkStorageType.BUFFER, + ) -> "PackedDimInfo": + """ + Construct a PackedDimInfo based on a memory layout and storage type, + mirroring calculate_packed_dim_info in runtime/api/containers/Tensor.cpp. 
+ """ + is_buffer = storage_type == VkStorageType.BUFFER + + if memory_layout == VkMemoryLayout.TENSOR_WIDTH_PACKED: + return cls( + packed_dim=0, + packed_dim_block_size=1 if is_buffer else 4, + ) + elif memory_layout == VkMemoryLayout.TENSOR_HEIGHT_PACKED: + return cls( + packed_dim=1, + packed_dim_block_size=1 if is_buffer else 4, + ) + elif memory_layout == VkMemoryLayout.TENSOR_CHANNELS_PACKED: + return cls( + packed_dim=2, + packed_dim_block_size=1 if is_buffer else 4, + ) + elif memory_layout == VkMemoryLayout.PACKED_INT8_4W: + return cls( + packed_dim=0, + packed_dim_block_size=4, + ) + elif memory_layout == VkMemoryLayout.PACKED_INT8_4W4C: + return cls( + packed_dim=2, + packed_dim_block_size=4, + ) + elif memory_layout == VkMemoryLayout.PACKED_INT8_4H4W: + return cls( + packed_dim=0, + packed_dim_block_size=4, + ) + elif memory_layout == VkMemoryLayout.PACKED_INT8_4C1W: + return cls( + packed_dim=2, + packed_dim_block_size=4 if is_buffer else 16, + ) + else: + raise ValueError(f"Unknown memory layout: {memory_layout}") + def within_buffer_limit(node: torch.fx.Node, buffer_limit: int) -> int: """ @@ -801,6 +882,11 @@ def __eq__(self, other: object) -> bool: def __ne__(self, other: object) -> bool: return not self.__eq__(other) + def copy(self) -> "TensorRepSet": + return TensorRepSet( + set(self.valid_buffer_layouts), set(self.valid_texture_layouts) + ) + def is_empty(self) -> bool: """ A TensorRepSet is "empty" if there are no valid representations of the tensor. @@ -914,6 +1000,83 @@ def is_ambiguous(self) -> bool: """ return not self.is_constrained() + def _possible_pdis(self) -> Set[PackedDimInfo]: + buffer_set = set() + texture_set = set() + for layout in self.valid_buffer_layouts: + buffer_set.add(PackedDimInfo.from_repr(layout, VkStorageType.BUFFER)) + for layout in self.valid_texture_layouts: + texture_set.add(PackedDimInfo.from_repr(layout, VkStorageType.TEXTURE_3D)) + return buffer_set, texture_set + + def has_same_packed_dim_info_set(self, other: "TensorRepSet") -> bool: + """ + Check if self and other produce the exact same sets of PackedDimInfo + for both buffer and texture storage types. Completely empty repsets + (no layouts for any storage type) are treated as matching any other + repset. + """ + other_buf_set, other_tex_set = other._possible_pdis() + buf_set, tex_set = self._possible_pdis() + + # A completely empty repset is compatible with anything + if not buf_set and not tex_set: + return True + if not other_buf_set and not other_tex_set: + return True + + return other_buf_set == buf_set and other_tex_set == tex_set + + def has_compatible_packed_dim_info_set(self, other: "TensorRepSet") -> bool: + """ + Check if all PackedDimInfos from other are contained within self's + PackedDimInfo sets, i.e. self is a superset of other for both buffer + and texture PDI sets. + """ + other_buf_set, other_tex_set = other._possible_pdis() + buf_set, tex_set = self._possible_pdis() + + for pdi in other_buf_set: + if pdi not in buf_set: + return False + + for pdi in other_tex_set: + if pdi not in tex_set: + return False + + return True + + def constrain_to_compatible_packed_dim( + self, other: "TensorRepSet" + ) -> "TensorRepSet": + """ + Return a new TensorRepSet containing only layouts from self whose + PackedDimInfo is present in other's PackedDimInfo sets. If other is + completely empty, return a copy of self unchanged. If other has layouts + for only one storage type, layouts for the missing storage type are + also removed. 
+ """ + other_buf_set, other_tex_set = other._possible_pdis() + + # Completely empty other means no constraint + if not other_buf_set and not other_tex_set: + return self.copy() + + new_buf = { + layout + for layout in self.valid_buffer_layouts + if other_buf_set + and PackedDimInfo.from_repr(layout, VkStorageType.BUFFER) in other_buf_set + } + new_tex = { + layout + for layout in self.valid_texture_layouts + if other_tex_set + and PackedDimInfo.from_repr(layout, VkStorageType.TEXTURE_3D) + in other_tex_set + } + return TensorRepSet(new_buf, new_tex) + def make_tensor_repset(tensor_repr: TensorRepr) -> TensorRepSet: """ @@ -927,7 +1090,7 @@ def make_tensor_repset(tensor_repr: TensorRepr) -> TensorRepSet: raise RuntimeError(f"Unsupported storage type {tensor_repr.storage_type}") -def make_filtered_tensor_repset( +def filter_invalid_reprs( tensor_val: FakeTensor, tensor_repset: TensorRepSet, texture_limits: ImageExtents, @@ -957,6 +1120,28 @@ def make_filtered_tensor_repset( return TensorRepSet(tensor_repset.valid_buffer_layouts, valid_texture_layouts) +def filter_invalid_reprs_for_node_list( + arg_repsets: TensorRepSet, + arg_node: List[torch.fx.Node], + texture_limits: ImageExtents, +) -> TensorRepSet: + """ + Wrapper around filter_invalid_reprs for a list of nodes. This will happen + for the cat operator, where the first argument is a list of nodes. + """ + # For variable length args, assume that they all need to use the same representation + # only one repset should be defined + common_tensor_repsets = arg_repsets + + for n in arg_node: + assert isinstance(n, torch.fx.Node) + common_tensor_repsets = common_tensor_repsets.make_intersect( + filter_invalid_reprs(n.meta["val"], common_tensor_repsets, texture_limits) + ) + + return common_tensor_repsets + + ## Convenience TensorRepSet definitions # Only includes memory layouts that can be used by non-quantized tensors @@ -986,6 +1171,8 @@ def make_filtered_tensor_repset( PACKED_INT8_BUFFER = TensorRepSet(all_quantized_memory_layouts, set()) PACKED_INT8_4W4C_BUFFER = TensorRepSet({VkMemoryLayout.PACKED_INT8_4W4C}, set()) +PACKED_INT8_4H4W_BUFFER = TensorRepSet({VkMemoryLayout.PACKED_INT8_4H4W}, set()) +PACKED_INT8_4W_BUFFER = TensorRepSet({VkMemoryLayout.PACKED_INT8_4W}, set()) PACKED_INT8_4C1W_BUFFER = TensorRepSet({VkMemoryLayout.PACKED_INT8_4C1W}, set()) PACKED_INT8_CHANNELS_PACKED_BUFFER = TensorRepSet( @@ -1138,7 +1325,7 @@ def __init__( # noqa: C901 else: assert not arg_repset.is_empty() - arg_repset = self.make_valid_tensor_repset_for_arg( + arg_repset = self.filter_invalid_reprs_for_arg( arg_repset, arg_node, texture_limits ) @@ -1149,7 +1336,7 @@ def __init__( # noqa: C901 outs_repset_list = TensorRepSetList([]) common_out_repset = ALL_STORAGES_REPSET if num_tensors_in_node(op_node) == 1: - common_out_repset = make_filtered_tensor_repset( + common_out_repset = filter_invalid_reprs( op_node.meta["val"], outputs_repsets[0], texture_limits ) outs_repset_list.append(common_out_repset) @@ -1157,42 +1344,46 @@ def __init__( # noqa: C901 else: for i, val in enumerate(op_node.meta["val"]): assert isinstance(val, FakeTensor) - out_repset = make_filtered_tensor_repset( + out_repset = filter_invalid_reprs( val, outputs_repsets[i], texture_limits ) outs_repset_list.append(out_repset) common_out_repset = common_out_repset.make_intersect(out_repset) + # Apply synchronization rules between the primary input and output + primary_repset = NO_STORAGE + if self.sync_primary_io_repr: + primary_in_repset = ( + common_arg_repset + if self.sync_args_repr + 
else args_repset_list[self.primary_arg_idx] + ) + primary_out_repset = ( + common_out_repset if self.sync_outs_repr else outs_repset_list[0] + ) + primary_repset = primary_in_repset.make_intersect(primary_out_repset) + + args_repset_list[self.primary_arg_idx] = primary_repset.copy() + outs_repset_list[0] = primary_repset.copy() + # Apply synchronization rules; if either all inputs/outputs must use the same # representation, then only use a single underlying repset. if self.sync_args_repr: - args_repset_list = TensorRepSetList([common_arg_repset]) - - if self.sync_outs_repr: - outs_repset_list = TensorRepSetList([common_out_repset]) - - # Finally, apply synchronization rules that sync inputs and outputs. If input - # or output repsets are updated, then maintain synchronization rules. - if self.sync_primary_io_repr: - assert self.primary_arg_idx is not None + common_repset = ( + primary_repset if self.sync_primary_io_repr else common_arg_repset + ) - primary_in_repset = args_repset_list[self.primary_arg_idx] - primary_out_repset = outs_repset_list[0] + for i in range(len(args_repset_list)): + args_repset_list[i] = common_repset.copy() - primary_repset = primary_in_repset.make_intersect(primary_out_repset) - - if self.sync_args_repr: - args_repset_list = TensorRepSetList([primary_repset]) - else: - assert self.primary_arg_idx is not None - args_repset_list[self.primary_arg_idx] = primary_repset + if self.sync_outs_repr: + common_repset = ( + primary_repset if self.sync_primary_io_repr else common_out_repset + ) - if self.sync_outs_repr: - outs_repset_list = TensorRepSetList([primary_repset]) - else: - assert self.primary_arg_idx is not None - outs_repset_list[0] = primary_repset + for i in range(len(outs_repset_list)): + outs_repset_list[i] = common_repset.copy() # Save the resulting repsets self.args_repset_list = args_repset_list @@ -1204,44 +1395,20 @@ def __init__( # noqa: C901 def __str__(self) -> str: return f"OpRepSets(ins={self.args_repset_list}, outs={self.outs_repset_list})" - def make_valid_tensor_repset_for_node_list_arg( - self, - arg_repsets: TensorRepSet, - arg_node: List[torch.fx.Node], - texture_limits: ImageExtents, - ) -> TensorRepSet: - """ - Wrapper around make_filtered_tensor_repset for a list of nodes. This will happen - for the cat operator, where the first argument is a list of nodes. 
- """ - # For variable length args, assume that they all need to use the same representation - # only one repset should be defined - common_tensor_repsets = arg_repsets - - for n in arg_node: - assert isinstance(n, torch.fx.Node) - common_tensor_repsets = common_tensor_repsets.make_intersect( - make_filtered_tensor_repset( - n.meta["val"], common_tensor_repsets, texture_limits - ) - ) - - return common_tensor_repsets - - def make_valid_tensor_repset_for_arg( + def filter_invalid_reprs_for_arg( self, arg_repsets: TensorRepSet, arg_node: Any, texture_limits: ImageExtents ) -> TensorRepSet: """ - Helper function to call make_filtered_tensor_repset + Helper function to call filter_invalid_reprs """ if isinstance(arg_node, torch.fx.Node) and is_single_tensor_node(arg_node): - return make_filtered_tensor_repset( + return filter_invalid_reprs( arg_node.meta["val"], arg_repsets, texture_limits ) elif isinstance(arg_node, list) and all( is_single_tensor_node(n) for n in arg_node ): - return self.make_valid_tensor_repset_for_node_list_arg( + return filter_invalid_reprs_for_node_list( arg_repsets, arg_node, texture_limits ) # Special case for getitem; return the repset of the particular val in the @@ -1251,7 +1418,7 @@ def make_valid_tensor_repset_for_arg( ): idx = self.op_node.args[1] assert isinstance(idx, int) - return make_filtered_tensor_repset( + return filter_invalid_reprs( arg_node.meta["val"][idx], arg_repsets, texture_limits ) @@ -1259,15 +1426,32 @@ def make_valid_tensor_repset_for_arg( def assert_sync_contraints(self) -> None: if self.sync_args_repr: - assert len(self.args_repset_list) == 1 + for i in range(len(self.args_repset_list)): + for j in range(i + 1, len(self.args_repset_list)): + ri = self.args_repset_list[i] + rj = self.args_repset_list[j] + if not ri.is_empty() and not rj.is_empty(): + assert ri.has_compatible_packed_dim_info_set( + rj + ), f"Synced arg repsets {i} and {j} have incompatible packed dim info: {ri} vs {rj}" if self.sync_outs_repr: - assert len(self.outs_repset_list) == 1 + for i in range(len(self.outs_repset_list)): + for j in range(i + 1, len(self.outs_repset_list)): + ri = self.outs_repset_list[i] + rj = self.outs_repset_list[j] + if not ri.is_empty() and not rj.is_empty(): + assert ri.has_compatible_packed_dim_info_set( + rj + ), f"Synced out repsets {i} and {j} have incompatible packed dim info: {ri} vs {rj}" if self.sync_primary_io_repr: - assert ( - self.args_repset_list[self.primary_arg_idx] == self.outs_repset_list[0] - ) + primary_arg = self.args_repset_list[self.primary_arg_idx] + primary_out = self.outs_repset_list[0] + if not primary_arg.is_empty() and not primary_out.is_empty(): + assert primary_arg.has_compatible_packed_dim_info_set( + primary_out + ), f"Primary arg and out repsets have incompatible packed dim info: {primary_arg} vs {primary_out}" def any_is_empty(self) -> bool: return ( @@ -1307,34 +1491,81 @@ def try_constrain_with_arg_repset( return False if self.sync_primary_io_repr: - if not self.get_out_repset(0).any_in_common(source_repset): + if not self.get_out_repset(0).has_compatible_packed_dim_info_set( + source_repset + ): return False # If this point is reached, then it is possible to constrain - self.args_repset_list[arg_i] = arg_current_repset.make_intersect(source_repset) + narrowed = arg_current_repset.make_intersect(source_repset) + self.args_repset_list[arg_i] = narrowed + + # Propagate to other synced args via packed-dim compatibility + if self.sync_args_repr: + for i in range(len(self.args_repset_list)): + if i != arg_i: + 
+                    self.args_repset_list[i] = self.args_repset_list[
+                        i
+                    ].constrain_to_compatible_packed_dim(narrowed)
+
+        # Propagate to output via packed-dim compatibility
         if self.sync_primary_io_repr and (
             arg_i == self.primary_arg_idx or self.sync_args_repr
         ):
-            self.outs_repset_list[0] = arg_current_repset.make_intersect(source_repset)
+            self.outs_repset_list[0] = self.outs_repset_list[
+                0
+            ].constrain_to_compatible_packed_dim(narrowed)
+
+        # Propagate to other synced outputs via packed-dim compatibility
+        if self.sync_outs_repr:
+            for i in range(len(self.outs_repset_list)):
+                if i != 0:
+                    self.outs_repset_list[i] = self.outs_repset_list[
+                        i
+                    ].constrain_to_compatible_packed_dim(self.outs_repset_list[0])
 
         self.assert_sync_contraints()
         return True
 
-    def try_constrain_with_out_repset(self, repset: TensorRepSet):
-        # Skip for operators that must synchronize the input and output representations
-        # or operators that have more than one output repset
-        if self.sync_primary_io_repr or len(self.outs_repset_list) > 1:
-            return False
-
+    def try_constrain_with_out_repset(self, required_repset: TensorRepSet) -> bool:
+        """
+        Attempt to constrain the output repsets of the tensors participating in this
+        operator based on the repset required by a downstream operator.
+        """
         out_current_repset = self.outs_repset_list[0]
-        if out_current_repset == repset:
+        if out_current_repset == required_repset:
             return False
 
-        if not out_current_repset.any_in_common(repset):
+        if not out_current_repset.any_in_common(required_repset):
             return False
 
-        self.outs_repset_list[0] = out_current_repset.make_intersect(repset)
+        narrowed = out_current_repset.make_intersect(required_repset)
+        self.outs_repset_list[0] = narrowed
+
+        # Propagate to other synced outputs via packed-dim compatibility
+        if self.sync_outs_repr:
+            for i in range(len(self.outs_repset_list)):
+                if i != 0:
+                    self.outs_repset_list[i] = self.outs_repset_list[
+                        i
+                    ].constrain_to_compatible_packed_dim(narrowed)
+
+        # Propagate to primary arg via packed-dim compatibility
+        if self.sync_primary_io_repr:
+            self.args_repset_list[self.primary_arg_idx] = self.args_repset_list[
+                self.primary_arg_idx
+            ].constrain_to_compatible_packed_dim(narrowed)
+
+        # Propagate to other synced args via packed-dim compatibility
+        if self.sync_args_repr:
+            for i in range(len(self.args_repset_list)):
+                if i != self.primary_arg_idx:
+                    self.args_repset_list[i] = self.args_repset_list[
+                        i
+                    ].constrain_to_compatible_packed_dim(
+                        self.args_repset_list[self.primary_arg_idx]
+                    )
 
         self.assert_sync_contraints()
         return True
diff --git a/backends/vulkan/vulkan_preprocess.py b/backends/vulkan/vulkan_preprocess.py
index b276ffd16f5..db1211883c7 100644
--- a/backends/vulkan/vulkan_preprocess.py
+++ b/backends/vulkan/vulkan_preprocess.py
@@ -164,6 +164,7 @@ def preprocess(  # noqa: C901
         [
             AddmmToLinearTransform(),
             FuseBatchNormPass(program),
+            AddmmToLinearTransform(),
             FusePatternsPass(),
             FuseClampPass(),
             RemoveRedundantOpsTransform(),
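
The sketch below is an illustrative, standalone model of the propagation scheme used by the constraint methods above; it is not the actual TensorRepSet/OpRepSets API. Every name in it (ToyRepSet, the layout strings, the packed-dim tags) is a hypothetical stand-in, and in the real pass the packed dim is derived from VkMemoryLayout and VkStorageType via PackedDimInfo. It only demonstrates the idea that narrowing one repset by direct intersection can be propagated to a synced repset by packed-dim compatibility, without forcing both tensors onto identical layouts.

# Illustrative sketch only (hypothetical names); models how narrowing one repset
# propagates to a synced repset via packed-dim compatibility rather than by
# requiring the exact same layout on both tensors.
from dataclasses import dataclass
from typing import Set, Tuple

# (layout name, packed dim) stands in for a memory layout and its packed-dim info.
Layout = Tuple[str, str]


@dataclass
class ToyRepSet:
    layouts: Set[Layout]

    def make_intersect(self, other: "ToyRepSet") -> "ToyRepSet":
        # Direct intersection: both sides must share the exact layout.
        return ToyRepSet(self.layouts & other.layouts)

    def constrain_to_compatible_packed_dim(self, other: "ToyRepSet") -> "ToyRepSet":
        # Keep only layouts whose packed dim also appears in `other`; the
        # layouts themselves do not have to match.
        other_dims = {dim for _, dim in other.layouts}
        return ToyRepSet({lay for lay in self.layouts if lay[1] in other_dims})


# Two "synced" argument repsets of a hypothetical binary op: one tensor can use
# buffer layouts, the other texture layouts, but they must agree on the packed dim.
arg0 = ToyRepSet({("W_PACKED_BUFFER", "W"), ("C_PACKED_BUFFER", "C")})
arg1 = ToyRepSet({("W_PACKED_TEXTURE", "W"), ("C_PACKED_TEXTURE", "C")})

# A producer requires arg0 to be width-packed in a buffer: narrow arg0 by direct
# intersection, then propagate to arg1 by packed-dim compatibility only.
required = ToyRepSet({("W_PACKED_BUFFER", "W")})
arg0 = arg0.make_intersect(required)
arg1 = arg1.constrain_to_compatible_packed_dim(arg0)

print(arg0.layouts)  # {('W_PACKED_BUFFER', 'W')}
print(arg1.layouts)  # {('W_PACKED_TEXTURE', 'W')} -- keeps its texture layout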