35 changes: 35 additions & 0 deletions backends/vulkan/custom_ops_lib.py
@@ -616,6 +616,41 @@ def q8ta_add_impl(
lib.impl(name, q8ta_add_impl, "CompositeExplicitAutograd")
q8ta_add_op = getattr(getattr(torch.ops, namespace), name)

########################
## q8ta_relu ##
########################


def q8ta_relu_impl(
input: torch.Tensor,
input_scale: float,
input_zero_point: int,
output_scale: float,
output_zero_point: int,
):
# Dequantize input to float
dequant = torch.ops.quantized_decomposed.dequantize_per_tensor(
input, input_scale, input_zero_point, -128, 127, input.dtype
)

# Apply ReLU
result = torch.nn.functional.relu(dequant)

# Quantize the result back to int8
quantized_result = torch.ops.quantized_decomposed.quantize_per_tensor(
result, output_scale, output_zero_point, -128, 127, torch.int8
)

return quantized_result


name = "q8ta_relu"
lib.define(
f"{name}(Tensor input, float input_scale, int input_zero_point, float output_scale, int output_zero_point) -> Tensor"
)
lib.impl(name, q8ta_relu_impl, "CompositeExplicitAutograd")
q8ta_relu_op = getattr(getattr(torch.ops, namespace), name)

#############################
## select_as_symint ##
#############################
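The composite impl above just round-trips through float: dequantize, apply ReLU, requantize. A minimal sketch of calling the registered op eagerly, assuming the library namespace is `et_vk` (consistent with the `exir_ops.edge.et_vk.q8ta_relu` references later in this PR) and that the `quantized_decomposed` ops are available:

```python
import torch

# Illustrative values only; any int8 tensor with in-range affine
# quantization parameters would do.
x = torch.randn(1, 8, 4, 4)
scale, zp = 0.05, 0
xq = torch.ops.quantized_decomposed.quantize_per_tensor(
    x, scale, zp, -128, 127, torch.int8
)
# ReLU in the quantized domain, via the composite impl registered above.
yq = torch.ops.et_vk.q8ta_relu(xq, scale, zp, scale, zp)
# yq should equal requantizing relu(dequantize(xq)) with the same params.
```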
14 changes: 13 additions & 1 deletion backends/vulkan/op_registry.py
@@ -514,7 +514,19 @@ def register_q8ta_add():


# =============================================================================
# Reduce.cpp
# Q8taUnary.cpp
# =============================================================================


@update_features(exir_ops.edge.et_vk.q8ta_relu.default)
def register_q8ta_relu():
return OpFeatures(
inputs_storage=utils.PACKED_INT8_BUFFER,
supports_resize=True,
)


# =============================================================================
# =============================================================================


1 change: 1 addition & 0 deletions backends/vulkan/patterns/BUCK
@@ -13,6 +13,7 @@ fbcode_target(_kind = runtime.python_library,
"quantized_linear.py",
"quantized_convolution.py",
"quantized_binary.py",
"quantized_unary.py",
"sdpa.py",
"select_as_symint.py",
],
2 changes: 2 additions & 0 deletions backends/vulkan/patterns/__init__.py
@@ -12,6 +12,8 @@

import executorch.backends.vulkan.patterns.quantized_linear # noqa

import executorch.backends.vulkan.patterns.quantized_unary # noqa

import executorch.backends.vulkan.patterns.rope # noqa

import executorch.backends.vulkan.patterns.sdpa # noqa
121 changes: 121 additions & 0 deletions backends/vulkan/patterns/quantized_unary.py
@@ -0,0 +1,121 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Optional

import executorch.backends.vulkan.utils as utils

import torch

from executorch.backends.vulkan.patterns.pattern_registry import (
PatternMatch,
register_pattern_detector,
register_pattern_replacement,
)

from executorch.exir import ExportedProgram
from executorch.exir.dialects._ops import ops as exir_ops


class QuantizedUnaryMatch(PatternMatch):
def __init__(self, unary_node: torch.fx.Node) -> None:
self.anchor_node = unary_node
self.match_found = False
self.all_nodes = [self.anchor_node]

# The unary op takes a single input which must be a dequantize node
if len(unary_node.args) < 1:
return

input_node = unary_node.args[0]
assert isinstance(input_node, torch.fx.Node)

if not utils.is_dequant_node(input_node):
return

self.dequantize_input_node = input_node

# Extract quantization parameters for the input
self.quantize_input_node = self.dequantize_input_node.args[0]
self.input_scales_node = self.dequantize_input_node.args[1]
self.input_zeros_node = self.dequantize_input_node.args[2]

self.all_nodes.append(self.dequantize_input_node)

# The unary op output must have exactly one user: a quantize node
self.output_node = self.anchor_node

if len(self.output_node.users) != 1:
return

cur_node = list(self.output_node.users)[0]

if not utils.is_quant_node(cur_node):
return

self.quantize_output_node = cur_node
self.output_scales_node = self.quantize_output_node.args[1]
self.output_zeros_node = self.quantize_output_node.args[2]

self.all_nodes.append(self.quantize_output_node)

self.match_found = True


# Unary operation anchor nodes that we support
unary_anchor_nodes = {
exir_ops.edge.aten.relu.default,
}


@register_pattern_detector("quantized_unary")
def find_quantized_unary_patterns(
node: torch.fx.Node,
) -> Optional[QuantizedUnaryMatch]:
if node.target not in unary_anchor_nodes:
return None

matched_pattern = QuantizedUnaryMatch(node)
if matched_pattern.match_found:
return matched_pattern

return None


##
## Pattern Replacement
##


@register_pattern_replacement("quantized_unary")
def make_q8ta_unary_custom_op(
ep: ExportedProgram,
graph_module: torch.fx.GraphModule,
match: QuantizedUnaryMatch,
):
op_target = None
if match.anchor_node.target == exir_ops.edge.aten.relu.default:
op_target = exir_ops.edge.et_vk.q8ta_relu.default
else:
raise NotImplementedError(
f"Unsupported unary operation: {match.anchor_node.target}"
)

with graph_module.graph.inserting_before(match.output_node):
qunary_node = graph_module.graph.create_node(
"call_function",
op_target,
args=(
match.quantize_input_node,
match.input_scales_node,
match.input_zeros_node,
match.output_scales_node,
match.output_zeros_node,
),
)

qunary_node.meta["val"] = match.output_node.meta["val"]
match.quantize_output_node.replace_all_uses_with(qunary_node)
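For context, a rough eager-mode analogue (not code from this PR; the detector actually runs on the exported edge-dialect graph) of the three-node chain this pattern matches:

```python
import torch

class QReLU(torch.nn.Module):
    def forward(self, xq: torch.Tensor) -> torch.Tensor:
        # dequant -> relu -> quant: QuantizedUnaryMatch anchors on the
        # relu node, checks that its input is a dequantize, and that its
        # sole user is a quantize.
        x = torch.ops.quantized_decomposed.dequantize_per_tensor(
            xq, 0.05, 0, -128, 127, torch.int8
        )
        x = torch.relu(x)
        return torch.ops.quantized_decomposed.quantize_per_tensor(
            x, 0.05, 0, -128, 127, torch.int8
        )
```

After `make_q8ta_unary_custom_op` runs, the quantize node's users are rewired to the fused `q8ta_relu` node, and the original chain is presumably left to dead-code elimination.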
82 changes: 82 additions & 0 deletions backends/vulkan/runtime/graph/ops/glsl/q8ta_unary.glsl
@@ -0,0 +1,82 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#version 450 core

#define PRECISION ${PRECISION}

${define_active_storage_type("buffer")}

#define op(X) ${OPERATOR}

layout(std430) buffer;

#include "indexing.glslh"
#include "common.glslh"
#include "block_indexing.glslh"
#include "block_int8x4_load.glslh"
#include "block_int8x4_store.glslh"

// Output buffer: packed int8x4 values
${layout_declare_tensor(B, "w", "t_out", "int", "buffer")}
// Input buffer: packed int8x4 values
${layout_declare_tensor(B, "r", "t_in", "int", "buffer")}

// Metadata for output and input tensors
${layout_declare_ubo(B, "BufferMetadata", "out_meta")}
${layout_declare_ubo(B, "BufferMetadata", "in_meta")}

layout(push_constant) uniform restrict Block {
float input_scale;
int input_zp;
float output_inv_scale;
int output_zp;
};

layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;

${layout_declare_spec_const(C, "int", "out_layout", "CONTIG_LAYOUT_INT")}
${layout_declare_spec_const(C, "int", "in_layout", "CONTIG_LAYOUT_INT")}
${layout_declare_spec_const(C, "int", "block_config", "0")}

// Generate loading functions for input buffer
define_load_int8x4_buffer_fns(t_in)

// Generate storing functions for output buffer
define_store_int8x4_buffer_fns(t_out)

void main() {
// Buffer storage: use linear dispatch
const uint contig_block_idx = gl_GlobalInvocationID.x;
TensorIndex4D tidx = contiguous_block_idx_to_tensor4d_idx_with_block_config(
out_meta, contig_block_idx, block_config);

if (out_of_bounds(tidx, out_meta)) {
return;
}

const int block_outer_dim = get_block_outer_dim(block_config);

// Load int8x4 block from input
ivec4 in_block = load_int8x4_block_from_t_in(
in_meta, tidx, in_layout, block_outer_dim);

ivec4 out_block;

for (int row = 0; row < 4; row++) {
vec4 in_texel = unpack_and_dequantize(
in_block[row], input_scale, input_zp);

vec4 out_texel = op(in_texel);
out_block[row] = quantize_and_pack(out_texel, output_inv_scale, output_zp);
}

// Store to output buffer
store_int8x4_block_to_t_out(
out_meta, tidx, out_layout, block_outer_dim, out_block);
}
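`unpack_and_dequantize` and `quantize_and_pack` come from the included `.glslh` headers, which are not part of this diff. As a rough Python analogue of the per-texel math (the four-int8-lanes-per-int32 packing and round-to-nearest behavior are assumptions, not confirmed here):

```python
import struct

def unpack_and_dequantize(packed: int, scale: float, zp: int) -> list:
    # Split one int32 into four signed int8 lanes, then dequantize each.
    lanes = struct.unpack("4b", struct.pack("<i", packed))
    return [(q - zp) * scale for q in lanes]

def quantize_and_pack(vals: list, inv_scale: float, zp: int) -> int:
    # Quantize with a multiply (the shader receives output_inv_scale,
    # avoiding a divide per texel), clamp to int8, repack into one int32.
    lanes = [max(-128, min(127, round(v * inv_scale) + zp)) for v in vals]
    return struct.unpack("<i", struct.pack("4b", *lanes))[0]

def q8ta_relu_texel(packed, in_s, in_zp, out_inv_s, out_zp):
    # Mirrors one iteration of the shader loop with op(X) = max(X, vec4(0.0)).
    texel = [max(v, 0.0) for v in unpack_and_dequantize(packed, in_s, in_zp)]
    return quantize_and_pack(texel, out_inv_s, out_zp)
```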
12 changes: 12 additions & 0 deletions backends/vulkan/runtime/graph/ops/glsl/q8ta_unary.yaml
@@ -0,0 +1,12 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

q8ta_unary:
parameter_names_with_default_values:
OPERATOR: X
shader_variants:
- NAME: q8ta_relu_buffer
OPERATOR: max(X, vec4(0.0))
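The shader codegen substitutes `OPERATOR` into the template's `#define op(X) ${OPERATOR}`, so the `q8ta_relu_buffer` variant compiles with `op(X)` expanding to `max(X, vec4(0.0))`, i.e. ReLU on each dequantized texel. Additional unary ops would presumably slot in as further `shader_variants` entries with their own `OPERATOR` expressions.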