9 changes: 1 addition & 8 deletions backends/vulkan/_passes/tag_memory_meta_pass.py
@@ -394,18 +394,11 @@ def constrain_op_out_repset(self, op_repsets: utils.OpRepSets) -> None:
op_repsets.try_constrain_with_out_repset(out_respset)

def constrain_op_repsets(self, op_repsets: utils.OpRepSets) -> None:
        # For most ops, constraining the argument repsets will also constrain the output
# repset due to OpRepSets maintaining synchronization rules.
for i in range(len(op_repsets.op_node.args)):
if utils.is_tensor_arg_node(op_repsets.op_node.args[i]):
self.constrain_op_arg_repset(i, op_repsets)

# However, some operators do not sync input and output representations and also
# define ambiguous repsets for the output tensor(s). In those cases we will need
# to execute additional logic to constrain the output repsets separately from
# the input repsets.
if not op_repsets.sync_primary_io_repr and op_repsets.sync_outs_repr:
self.constrain_op_out_repset(op_repsets)
self.constrain_op_out_repset(op_repsets)

def set_op_node_tensor_reprs(
self, graph_module: torch.fx.GraphModule, op_node: torch.fx.Node
130 changes: 130 additions & 0 deletions backends/vulkan/custom_ops_lib.py
@@ -356,6 +356,136 @@ def linear_q8ta_q8csw(
lib.impl(name, linear_q8ta_q8csw, "CompositeExplicitAutograd")
qa_q8csw_linear = getattr(getattr(torch.ops, namespace), name)

##################
## q8ta_linear ##
##################


def q8ta_linear(
x: torch.Tensor,
input_scale: float,
input_zero_point: int,
weights: torch.Tensor,
weight_sums: torch.Tensor,
weight_scales: torch.Tensor,
output_scale: float,
output_zero_point: int,
bias: Optional[torch.Tensor] = None,
activation: str = "none",
):
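    # Reference (eager) implementation: dequantize the int8 activations and the
    # per-channel int8 weights, run a float linear (plus optional bias and ReLU),
    # then requantize the result to int8. weight_sums is not used here; it is
    # presumably consumed by the Vulkan kernel at runtime.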
weight_zeros = torch.zeros_like(weight_scales, dtype=torch.int32)
weights = torch.ops.quantized_decomposed.dequantize_per_channel(
weights,
weight_scales,
weight_zeros,
0,
-127,
127,
torch.int8,
)

x = torch.ops.quantized_decomposed.dequantize_per_tensor(
x, input_scale, input_zero_point, -128, 127, x.dtype
)

out = torch.nn.functional.linear(x, weights)
if bias is not None:
out = out + bias[: out.shape[-1]]

if activation == "relu":
out = torch.nn.functional.relu(out)

out = torch.ops.quantized_decomposed.quantize_per_tensor(
out, output_scale, output_zero_point, -128, 127, torch.int8
)

return out


name = "q8ta_linear"
lib.define(
f"""
{name}(
Tensor x,
float input_scale,
int input_zero_point,
Tensor weights,
Tensor weight_sums,
Tensor weight_scales,
float output_scale,
int output_zero_point,
Tensor? bias = None,
str activation = "none") -> Tensor
"""
)
lib.impl(name, q8ta_linear, "CompositeExplicitAutograd")
q8ta_linear_op = getattr(getattr(torch.ops, namespace), name)
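
# Illustrative usage sketch (not part of the op definition): the shapes, scale,
# and zero point values below are arbitrary example choices, and the call runs
# the composite reference path registered above.
#
#   x_q = torch.ops.quantized_decomposed.quantize_per_tensor(
#       torch.randn(4, 64), 0.05, 0, -128, 127, torch.int8
#   )
#   w_q = torch.randint(-127, 128, (32, 64), dtype=torch.int8)
#   w_scales = torch.full((32,), 0.02)
#   w_sums = w_q.sum(dim=1, dtype=torch.int32)
#   y_q = q8ta_linear_op(
#       x_q, 0.05, 0, w_q, w_sums, w_scales, 0.1, 0, None, "relu"
#   )  # int8 output quantized with output_scale / output_zero_point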

#######################
## q8ta_linear_gemv ##
#######################


def q8ta_linear_gemv(
x: torch.Tensor,
input_scale: float,
input_zero_point: int,
weights: torch.Tensor,
weight_sums: torch.Tensor,
weight_scales: torch.Tensor,
output_scale: float,
output_zero_point: int,
bias: Optional[torch.Tensor] = None,
activation: str = "none",
):
weight_zeros = torch.zeros_like(weight_scales, dtype=torch.int32)
weights = torch.ops.quantized_decomposed.dequantize_per_channel(
weights,
weight_scales,
weight_zeros,
0,
-127,
127,
torch.int8,
)

x = torch.ops.quantized_decomposed.dequantize_per_tensor(
x, input_scale, input_zero_point, -128, 127, x.dtype
)

out = torch.nn.functional.linear(x, weights)
if bias is not None:
out = out + bias[: out.shape[-1]]

if activation == "relu":
out = torch.nn.functional.relu(out)

out = torch.ops.quantized_decomposed.quantize_per_tensor(
out, output_scale, output_zero_point, -128, 127, torch.int8
)

return out


name = "q8ta_linear_gemv"
lib.define(
f"""
{name}(
Tensor x,
float input_scale,
int input_zero_point,
Tensor weights,
Tensor weight_sums,
Tensor weight_scales,
float output_scale,
int output_zero_point,
Tensor? bias = None,
str activation = "none") -> Tensor
"""
)
lib.impl(name, q8ta_linear_gemv, "CompositeExplicitAutograd")
q8ta_linear_gemv_op = getattr(getattr(torch.ops, namespace), name)
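
# The reference implementation above is identical to q8ta_linear; judging by the
# name and the Vulkan registration, the _gemv variant is meant to be selected for
# matrix-vector shaped workloads, so a call mirrors the q8ta_linear example but
# with a single-row activation tensor (e.g. shape (1, 64)).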

###################
## q8ta_conv2d_* ##
###################
51 changes: 51 additions & 0 deletions backends/vulkan/op_registry.py
@@ -830,6 +830,57 @@ def register_q8ta_conv2d_ops():
)


# =============================================================================
# Q8taLinear.cpp
# =============================================================================

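# Note on the storage annotations below (inferred from the field names and from
# tag_memory_meta_pass.py, not from separate documentation): entries in
# inputs_storage / outputs_storage request a specific packed int8 buffer layout
# for activation tensors, while NO_STORAGE marks arguments that are either
# non-tensor scalars or tensors that are prepacked ahead of time.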

@update_features(exir_ops.edge.et_vk.q8ta_linear.default)
def register_q8ta_linear():
return OpFeatures(
inputs_storage=[
utils.PACKED_INT8_4H4W_BUFFER, # input
utils.NO_STORAGE, # input_scale (non tensor)
utils.NO_STORAGE, # input_zero_point (non tensor)
utils.NO_STORAGE, # weight (prepacked)
utils.NO_STORAGE, # weight_sums (prepacked)
utils.NO_STORAGE, # weight_scales (prepacked)
utils.NO_STORAGE, # output_scale (non tensor)
utils.NO_STORAGE, # output_zero_point (non tensor)
utils.NO_STORAGE, # bias (prepacked)
utils.NO_STORAGE, # activation (non tensor)
],
outputs_storage=[
utils.PACKED_INT8_4H4W_BUFFER,
],
supports_resize=False,
supports_prepacking=True,
)


@update_features(exir_ops.edge.et_vk.q8ta_linear_gemv.default)
def register_q8ta_linear_gemv():
return OpFeatures(
inputs_storage=[
utils.PACKED_INT8_4W_BUFFER, # input
utils.NO_STORAGE, # input_scale (non tensor)
utils.NO_STORAGE, # input_zero_point (non tensor)
utils.NO_STORAGE, # weight (prepacked)
utils.NO_STORAGE, # weight_sums (prepacked)
utils.NO_STORAGE, # weight_scales (prepacked)
utils.NO_STORAGE, # output_scale (non tensor)
utils.NO_STORAGE, # output_zero_point (non tensor)
utils.NO_STORAGE, # bias (prepacked)
utils.NO_STORAGE, # activation (non tensor)
],
outputs_storage=[
utils.PACKED_INT8_4W_BUFFER,
],
supports_resize=False,
supports_prepacking=True,
)


# =============================================================================
# SDPA.cpp
# =============================================================================