52 changes: 49 additions & 3 deletions backends/vulkan/custom_ops_lib.py
@@ -376,6 +376,7 @@ def q8ta_conv2d(
padding: list,
dilation: list,
groups: int,
activation: str,
):
x = torch.ops.quantized_decomposed.dequantize_per_tensor(
x, input_scale, input_zero_point, -128, 127, x.dtype
@@ -418,6 +419,9 @@ def q8ta_conv2d(
x, weights, bias, stride, padding, dilation, groups
)

if activation == "relu":
out = torch.nn.functional.relu(out)

out = torch.ops.quantized_decomposed.quantize_per_tensor(
out, output_scale, output_zero_point, -128, 127, torch.int8
)
@@ -442,7 +446,8 @@ def q8ta_conv2d(
SymInt[] stride,
SymInt[] padding,
SymInt[] dilation,
SymInt groups) -> Tensor
SymInt groups,
str activation) -> Tensor
"""
)
lib.impl(name, q8ta_conv2d, "CompositeExplicitAutograd")
@@ -466,7 +471,8 @@ def q8ta_conv2d(
SymInt[] stride,
SymInt[] padding,
SymInt[] dilation,
SymInt groups) -> Tensor
SymInt groups,
str activation) -> Tensor
"""
)
lib.impl(name, q8ta_conv2d, "CompositeExplicitAutograd")
@@ -488,6 +494,7 @@ def q8ta_conv2d_dw(
padding: list,
dilation: list,
groups: int,
activation: str,
):
x = torch.ops.quantized_decomposed.dequantize_per_tensor(
x, input_scale, input_zero_point, -128, 127, x.dtype
@@ -514,6 +521,9 @@ def q8ta_conv2d_dw(
x, weights, bias, stride, padding, dilation, groups
)

if activation == "relu":
out = torch.nn.functional.relu(out)

out = torch.ops.quantized_decomposed.quantize_per_tensor(
out, output_scale, output_zero_point, -128, 127, torch.int8
)
@@ -538,7 +548,8 @@ def q8ta_conv2d_dw(
SymInt[] stride,
SymInt[] padding,
SymInt[] dilation,
SymInt groups) -> Tensor
SymInt groups,
str activation) -> Tensor
"""
)
lib.impl(name, q8ta_conv2d_dw, "CompositeExplicitAutograd")
@@ -605,6 +616,41 @@ def q8ta_add_impl(
lib.impl(name, q8ta_add_impl, "CompositeExplicitAutograd")
q8ta_add_op = getattr(getattr(torch.ops, namespace), name)

########################
## q8ta_relu ##
########################


def q8ta_relu_impl(
input: torch.Tensor,
input_scale: float,
input_zero_point: int,
output_scale: float,
output_zero_point: int,
):
# Dequantize input to float
dequant = torch.ops.quantized_decomposed.dequantize_per_tensor(
input, input_scale, input_zero_point, -128, 127, input.dtype
)

# Apply ReLU
result = torch.nn.functional.relu(dequant)

# Quantize the result back to int8
quantized_result = torch.ops.quantized_decomposed.quantize_per_tensor(
result, output_scale, output_zero_point, -128, 127, torch.int8
)

return quantized_result


name = "q8ta_relu"
lib.define(
f"{name}(Tensor input, float input_scale, int input_zero_point, float output_scale, int output_zero_point) -> Tensor"
)
lib.impl(name, q8ta_relu_impl, "CompositeExplicitAutograd")
q8ta_relu_op = getattr(getattr(torch.ops, namespace), name)
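
# For reference only (not part of this PR): a minimal eager-mode sketch of what
# the q8ta_relu composite above computes, written without the
# quantized_decomposed ops so it can run standalone. The helper name and the
# rounding detail are illustrative assumptions, not the authoritative kernel.
def _q8ta_relu_reference_sketch(
    x_int8: torch.Tensor,
    input_scale: float,
    input_zero_point: int,
    output_scale: float,
    output_zero_point: int,
) -> torch.Tensor:
    # Dequantize: (q - zero_point) * scale
    x_fp = (x_int8.to(torch.float32) - input_zero_point) * input_scale
    # ReLU in float
    y_fp = torch.relu(x_fp)
    # Requantize: round(y / scale) + zero_point, clamped to the int8 range
    q = torch.round(y_fp / output_scale) + output_zero_point
    return q.clamp(-128, 127).to(torch.int8)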

#############################
## select_as_symint ##
#############################
28 changes: 13 additions & 15 deletions backends/vulkan/op_registry.py
@@ -514,7 +514,19 @@ def register_q8ta_add():


# =============================================================================
# Reduce.cpp
# Q8taUnary.cpp
# =============================================================================


@update_features(exir_ops.edge.et_vk.q8ta_relu.default)
def register_q8ta_relu():
return OpFeatures(
inputs_storage=utils.PACKED_INT8_BUFFER,
supports_resize=True,
)


# =============================================================================
# =============================================================================


@@ -1221,25 +1233,11 @@ def register_embedding():

@update_features(exir_ops.edge.aten._native_batch_norm_legit_no_training.default)
def register_native_batch_norm_legit_no_training():
def check_batch_norm_node(node: torch.fx.Node) -> bool:
x = node.args[0]
if not isinstance(x, torch.fx.Node):
return False
x_val = x.meta.get("val", None)
if x_val is None:
return False
x_shape = x_val.size()
# Only support 4-D input tensors since this is a restriction enforced by the
# operator implementation.
# TODO(ssjia): Add shape agnostic support for batch norm
return len(x_shape) == 4

return OpFeatures(
inputs_storage=utils.CHANNELS_PACKED_TEXTURE,
inputs_dtypes=utils.FP_T,
supports_prepacking=True,
supports_resize=True,
are_node_inputs_supported_fn=check_batch_norm_node,
)


1 change: 1 addition & 0 deletions backends/vulkan/patterns/BUCK
@@ -13,6 +13,7 @@ fbcode_target(_kind = runtime.python_library,
"quantized_linear.py",
"quantized_convolution.py",
"quantized_binary.py",
"quantized_unary.py",
"sdpa.py",
"select_as_symint.py",
],
2 changes: 2 additions & 0 deletions backends/vulkan/patterns/__init__.py
@@ -12,6 +12,8 @@

import executorch.backends.vulkan.patterns.quantized_linear # noqa

import executorch.backends.vulkan.patterns.quantized_unary # noqa

import executorch.backends.vulkan.patterns.rope # noqa

import executorch.backends.vulkan.patterns.sdpa # noqa
25 changes: 22 additions & 3 deletions backends/vulkan/patterns/quantized_convolution.py
@@ -215,9 +215,27 @@ def make_q8ta_conv2d_custom_op(
with graph_module.graph.inserting_before(first_graph_node):
qweight_tensor_name = utils.get_tensor_name(ep, match.weight_node)
# Pre-compute the weight sums which are needed to apply activation zero point
# when using integer accumulation. For the reshaped 2D weight matrix (IC_per_group * H * W, OC),
# sum over dimension 0 to get sums per output channel
sum_per_output_channel = weight_tensor.sum(dim=1).to(torch.int32).contiguous()
# when using integer accumulation. Sum all weight elements per output channel.
if is_depthwise_conv:
# weight_tensor shape is (H, W, OC); sum over spatial dims (H, W)
sum_per_output_channel = (
weight_tensor.sum(dim=(0, 1)).to(torch.int32).contiguous()
)
else:
# weight_tensor shape is (OC, H*W*IC_per_group); sum over dim 1
sum_per_output_channel = (
weight_tensor.sum(dim=1).to(torch.int32).contiguous()
)
# Pad weight sums to align OC to multiple of 4, matching the alignment
# applied to weight, weight_scales, and bias above. Without this, the
# GPU shader would read out-of-bounds when OC is not a multiple of 4.
oc = sum_per_output_channel.shape[0]
if oc % 4 != 0:
num_padding = 4 - (oc % 4)
sum_per_output_channel = torch.nn.functional.pad(
sum_per_output_channel, (0, num_padding)
).contiguous()
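
# Worked example (illustrative, not part of this PR): with OC = 10 the sums
# tensor has shape (10,); 10 % 4 == 2, so 4 - 2 = 2 trailing zeros are appended
# and the shader can always read weight sums in aligned groups of 4.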

sums_name = qweight_tensor_name + "_sums"
# Sanitize the name
sums_name = sums_name.replace(".", "_")
@@ -263,6 +281,7 @@ def make_q8ta_conv2d_custom_op(
match.padding,
match.dilation,
match.groups,
"relu" if match.relu_node is not None else "none",
),
)

121 changes: 121 additions & 0 deletions backends/vulkan/patterns/quantized_unary.py
@@ -0,0 +1,121 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Optional

import executorch.backends.vulkan.utils as utils

import torch

from executorch.backends.vulkan.patterns.pattern_registry import (
PatternMatch,
register_pattern_detector,
register_pattern_replacement,
)

from executorch.exir import ExportedProgram
from executorch.exir.dialects._ops import ops as exir_ops


class QuantizedUnaryMatch(PatternMatch):
def __init__(self, unary_node: torch.fx.Node) -> None:
self.anchor_node = unary_node
self.match_found = False
self.all_nodes = [self.anchor_node]

# The unary op takes a single input which must be a dequantize node
if len(unary_node.args) < 1:
return

input_node = unary_node.args[0]
assert isinstance(input_node, torch.fx.Node)

if not utils.is_dequant_node(input_node):
return

self.dequantize_input_node = input_node

# Extract quantization parameters for the input
self.quantize_input_node = self.dequantize_input_node.args[0]
self.input_scales_node = self.dequantize_input_node.args[1]
self.input_zeros_node = self.dequantize_input_node.args[2]

self.all_nodes.append(self.dequantize_input_node)

# The unary op output must have exactly one user: a quantize node
self.output_node = self.anchor_node

if len(self.output_node.users) != 1:
return

cur_node = list(self.output_node.users)[0]

if not utils.is_quant_node(cur_node):
return

self.quantize_output_node = cur_node
self.output_scales_node = self.quantize_output_node.args[1]
self.output_zeros_node = self.quantize_output_node.args[2]

self.all_nodes.append(self.quantize_output_node)

self.match_found = True


# Unary operation anchor nodes that we support
unary_anchor_nodes = {
exir_ops.edge.aten.relu.default,
}


@register_pattern_detector("quantized_unary")
def find_quantized_unary_patterns(
node: torch.fx.Node,
) -> Optional[QuantizedUnaryMatch]:
if node.target not in unary_anchor_nodes:
return None

matched_pattern = QuantizedUnaryMatch(node)
if matched_pattern.match_found:
return matched_pattern

return None


##
## Pattern Replacement
##


@register_pattern_replacement("quantized_unary")
def make_q8ta_unary_custom_op(
ep: ExportedProgram,
graph_module: torch.fx.GraphModule,
match: QuantizedUnaryMatch,
):
op_target = None
if match.anchor_node.target == exir_ops.edge.aten.relu.default:
op_target = exir_ops.edge.et_vk.q8ta_relu.default
else:
raise NotImplementedError(
f"Unsupported unary operation: {match.anchor_node.target}"
)

with graph_module.graph.inserting_before(match.output_node):
qunary_node = graph_module.graph.create_node(
"call_function",
op_target,
args=(
match.quantize_input_node,
match.input_scales_node,
match.input_zeros_node,
match.output_scales_node,
match.output_zeros_node,
),
)

qunary_node.meta["val"] = match.output_node.meta["val"]
match.quantize_output_node.replace_all_uses_with(qunary_node)
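
# Illustrative summary (not part of this PR): the detector above matches a
# dequantize_per_tensor -> aten.relu -> quantize_per_tensor chain, and the
# replacement collapses it into a single et_vk.q8ta_relu node that consumes the
# already-quantized input directly, so the tensor stays in int8 form across the
# activation.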
3 changes: 2 additions & 1 deletion backends/vulkan/runtime/graph/ops/glsl/q8ta_binary.glsl
@@ -46,6 +46,7 @@ layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;

${layout_declare_spec_const(C, "int", "out_layout", "CONTIG_LAYOUT_INT")}
${layout_declare_spec_const(C, "int", "in_layout", "CONTIG_LAYOUT_INT")}
${layout_declare_spec_const(C, "int", "other_layout", "CONTIG_LAYOUT_INT")}
${layout_declare_spec_const(C, "int", "block_config", "0")}

// Generate loading functions for input buffers
@@ -71,7 +72,7 @@ void main() {
ivec4 in_block_a = load_int8x4_block_from_t_in_a(
in_a_meta, tidx, in_layout, block_outer_dim);
ivec4 in_block_b = load_int8x4_block_from_t_in_b(
in_b_meta, tidx, in_layout, block_outer_dim);
in_b_meta, tidx, other_layout, block_outer_dim);

ivec4 out_block;

8 changes: 8 additions & 0 deletions backends/vulkan/runtime/graph/ops/glsl/q8ta_conv2d.glsl
@@ -47,6 +47,7 @@ layout(push_constant) uniform restrict Block {
layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;

${layout_declare_spec_const(C, "int", "apply_bias", "1")}
${layout_declare_spec_const(C, "int", "activation_type", "0")}
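// Note (illustrative, not in the original diff): activation_type defaults to 0
// (no activation); any positive value makes the shader apply ReLU below.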

// Layout specialization constants
${layout_declare_spec_const(C, "int", "inp_layout", "CONTIG_LAYOUT_INT")}
@@ -220,6 +221,13 @@ void main() {
}
}

// Apply ReLU if enabled
if (activation_type > 0) {
[[unroll]] for (int subtile_w = 0; subtile_w < 4; ++subtile_w) {
facc[subtile_w] = max(facc[subtile_w], vec4(0.0));
}
}

// Compute base output texel index (for subtile_w=0)
const int base_outp_texel_idx = tensor4d_idx_to_texel_idx(outp, outp_tidx, outp_layout);
const int out_w_stride = int(outp.strides[0][0]);
8 changes: 8 additions & 0 deletions backends/vulkan/runtime/graph/ops/glsl/q8ta_conv2d_dw.glsl
@@ -44,6 +44,7 @@ layout(push_constant) uniform restrict Block {
layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;

${layout_declare_spec_const(C, "int", "apply_bias", "1")}
${layout_declare_spec_const(C, "int", "activation_type", "0")}
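// Note (illustrative, not in the original diff): as in q8ta_conv2d.glsl, 0 means
// no activation and any positive value applies ReLU below.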

// Layout specialization constants
${layout_declare_spec_const(C, "int", "inp_layout", "CONTIG_LAYOUT_INT")}
@@ -197,6 +198,13 @@ void main() {
}
}

// Apply ReLU if enabled
if (activation_type > 0) {
[[unroll]] for (int subtile_w = 0; subtile_w < 4; ++subtile_w) {
facc[subtile_w] = max(facc[subtile_w], vec4(0.0));
}
}

// Compute base output texel index (for subtile_w=0)
const int base_outp_texel_idx = tensor4d_idx_to_texel_idx(outp, outp_tidx, outp_layout);
const int out_w_stride = int(outp.strides[0][0]);