17 changes: 14 additions & 3 deletions backends/vulkan/custom_ops_lib.py
@@ -376,6 +376,7 @@ def q8ta_conv2d(
padding: list,
dilation: list,
groups: int,
activation: str,
):
x = torch.ops.quantized_decomposed.dequantize_per_tensor(
x, input_scale, input_zero_point, -128, 127, x.dtype
@@ -418,6 +419,9 @@ def q8ta_conv2d(
x, weights, bias, stride, padding, dilation, groups
)

if activation == "relu":
out = torch.nn.functional.relu(out)

out = torch.ops.quantized_decomposed.quantize_per_tensor(
out, output_scale, output_zero_point, -128, 127, torch.int8
)
@@ -442,7 +446,8 @@ def q8ta_conv2d(
SymInt[] stride,
SymInt[] padding,
SymInt[] dilation,
SymInt groups) -> Tensor
SymInt groups,
str activation) -> Tensor
"""
)
lib.impl(name, q8ta_conv2d, "CompositeExplicitAutograd")
@@ -466,7 +471,8 @@ def q8ta_conv2d(
SymInt[] stride,
SymInt[] padding,
SymInt[] dilation,
SymInt groups) -> Tensor
SymInt groups,
str activation) -> Tensor
"""
)
lib.impl(name, q8ta_conv2d, "CompositeExplicitAutograd")
@@ -488,6 +494,7 @@ def q8ta_conv2d_dw(
padding: list,
dilation: list,
groups: int,
activation: str,
):
x = torch.ops.quantized_decomposed.dequantize_per_tensor(
x, input_scale, input_zero_point, -128, 127, x.dtype
@@ -514,6 +521,9 @@ def q8ta_conv2d_dw(
x, weights, bias, stride, padding, dilation, groups
)

if activation == "relu":
out = torch.nn.functional.relu(out)

out = torch.ops.quantized_decomposed.quantize_per_tensor(
out, output_scale, output_zero_point, -128, 127, torch.int8
)
@@ -538,7 +548,8 @@ def q8ta_conv2d_dw(
SymInt[] stride,
SymInt[] padding,
SymInt[] dilation,
SymInt groups) -> Tensor
SymInt groups,
str activation) -> Tensor
"""
)
lib.impl(name, q8ta_conv2d_dw, "CompositeExplicitAutograd")
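Taken together, the Python changes thread a new `activation` string through the reference implementation: dequantize, convolve, optionally apply ReLU, then requantize. Below is a minimal sketch condensing that flow from the hunks above; the argument list is abbreviated, the weight-handling steps elided by the diff are elided here too, and the conv call is assumed to be a standard conv2d.

import torch

def q8ta_conv2d_sketch(
    x, input_scale, input_zero_point,
    weights, bias,
    output_scale, output_zero_point,
    stride, padding, dilation, groups,
    activation: str,
):
    # Dequantize the int8 input for the float reference path.
    x = torch.ops.quantized_decomposed.dequantize_per_tensor(
        x, input_scale, input_zero_point, -128, 127, x.dtype
    )
    # Assumed standard convolution; weights are assumed already dequantized
    # by the steps the diff elides.
    out = torch.nn.functional.conv2d(
        x, weights, bias, stride, padding, dilation, groups
    )
    # New in this change: the fused activation runs before requantization.
    if activation == "relu":
        out = torch.nn.functional.relu(out)
    return torch.ops.quantized_decomposed.quantize_per_tensor(
        out, output_scale, output_zero_point, -128, 127, torch.int8
    )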
1 change: 1 addition & 0 deletions backends/vulkan/patterns/quantized_convolution.py
@@ -281,6 +281,7 @@ def make_q8ta_conv2d_custom_op(
match.padding,
match.dilation,
match.groups,
"relu" if match.relu_node is not None else "none",
),
)

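The pattern matcher sets the new argument from the matched graph: "relu" when the fused pattern captured a relu node, "none" otherwise. A hypothetical eager module of the shape this fusion targets (the class name, channel counts, and kernel size below are illustrative only):

import torch

class ConvRelu(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(8, 16, kernel_size=3, padding=1)

    def forward(self, x):
        # After quantization and pattern matching, a conv + relu pair like
        # this is expected to lower to a single q8ta_conv2d call with
        # activation="relu" instead of a separate relu op.
        return torch.nn.functional.relu(self.conv(x))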
8 changes: 8 additions & 0 deletions backends/vulkan/runtime/graph/ops/glsl/q8ta_conv2d.glsl
@@ -47,6 +47,7 @@ layout(push_constant) uniform restrict Block {
layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;

${layout_declare_spec_const(C, "int", "apply_bias", "1")}
${layout_declare_spec_const(C, "int", "activation_type", "0")}

// Layout specialization constants
${layout_declare_spec_const(C, "int", "inp_layout", "CONTIG_LAYOUT_INT")}
@@ -220,6 +221,13 @@ void main() {
}
}

// Apply ReLU if enabled
if (activation_type > 0) {
[[unroll]] for (int subtile_w = 0; subtile_w < 4; ++subtile_w) {
facc[subtile_w] = max(facc[subtile_w], vec4(0.0));
}
}

// Compute base output texel index (for subtile_w=0)
const int base_outp_texel_idx = tensor4d_idx_to_texel_idx(outp, outp_tidx, outp_layout);
const int out_w_stride = int(outp.strides[0][0]);
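Because `activation_type` is a specialization constant rather than a push constant, the `if (activation_type > 0)` guard is fixed when the pipeline is created; drivers can typically fold the branch away, so shaders specialized with activation 0 ("none") should pay no per-invocation cost for the new check.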
8 changes: 8 additions & 0 deletions backends/vulkan/runtime/graph/ops/glsl/q8ta_conv2d_dw.glsl
@@ -44,6 +44,7 @@ layout(push_constant) uniform restrict Block {
layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;

${layout_declare_spec_const(C, "int", "apply_bias", "1")}
${layout_declare_spec_const(C, "int", "activation_type", "0")}

// Layout specialization constants
${layout_declare_spec_const(C, "int", "inp_layout", "CONTIG_LAYOUT_INT")}
@@ -197,6 +198,13 @@ void main() {
}
}

// Apply ReLU if enabled
if (activation_type > 0) {
[[unroll]] for (int subtile_w = 0; subtile_w < 4; ++subtile_w) {
facc[subtile_w] = max(facc[subtile_w], vec4(0.0));
}
}

// Compute base output texel index (for subtile_w=0)
const int base_outp_texel_idx = tensor4d_idx_to_texel_idx(outp, outp_tidx, outp_layout);
const int out_w_stride = int(outp.strides[0][0]);
9 changes: 9 additions & 0 deletions backends/vulkan/runtime/graph/ops/glsl/q8ta_conv2d_pw.glsl
@@ -57,6 +57,7 @@ layout(push_constant) uniform restrict Block {
layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;

${layout_declare_spec_const(C, "int", "apply_bias", "1")}
${layout_declare_spec_const(C, "int", "activation_type", "0")}
${layout_declare_spec_const(C, "int", "conv2d_params_K4_per_group", "1")}

// Layout specialization constants
@@ -197,6 +198,10 @@ void main() {
fma(vec4(accum_adjusted),
vec4(weight_scales[n4]) * input_scale,
vec4(bias[n4]));
// Apply ReLU if enabled
if (activation_type > 0) {
float_out_texel = max(float_out_texel, vec4(0.0));
}
// Requantize to int8
float_out_texel =
round(float_out_texel * output_inv_scale) + output_zp;
@@ -216,6 +221,10 @@ void main() {
input_zp_vec * weight_sums[n4] + out_accum[m][n4];
vec4 float_out_texel =
vec4(accum_adjusted) * vec4(weight_scales[n4] * input_scale);
// Apply ReLU if enabled
if (activation_type > 0) {
float_out_texel = max(float_out_texel, vec4(0.0));
}
// Requantize to int8
float_out_texel =
round(float_out_texel * output_inv_scale) + output_zp;
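In the pointwise shader the activation is applied per output texel: the float value is reconstructed from the integer accumulator, optionally clamped by ReLU, and only then requantized. A NumPy sketch of that epilogue with illustrative scales and zero points (any final clamp to the int8 range sits outside the hunks shown and is omitted here as well):

import numpy as np

# Hypothetical per-texel values; the shader works on vec4 texels.
accum_adjusted = np.array([120.0, -35.0, 2.5, -400.0], dtype=np.float32)
weight_scale, input_scale = 0.02, 0.05         # illustrative quant scales
output_inv_scale, output_zp = 1.0 / 0.1, -5.0  # illustrative requant params
activation_type = 1                            # spec constant: 0 = none, 1 = relu

float_out_texel = accum_adjusted * (weight_scale * input_scale)
# Fused ReLU runs on the float value, before requantization, matching the
# eager reference (relu between conv2d and quantize_per_tensor).
if activation_type > 0:
    float_out_texel = np.maximum(float_out_texel, 0.0)
# Requantize as in the shader: round(x * output_inv_scale) + output_zp.
int8_out_texel = np.round(float_out_texel * output_inv_scale) + output_zp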
18 changes: 17 additions & 1 deletion backends/vulkan/runtime/graph/ops/impl/Q8taConv2d.cpp
@@ -17,6 +17,15 @@

namespace vkcompute {

ActivationType activation_type_from_string(const std::string& activation) {
if (activation == "none") {
return ActivationType::kNone;
} else if (activation == "relu") {
return ActivationType::kRelu;
}
VK_THROW("Unknown activation type: ", activation);
}

bool q8ta_conv2d_check_packed_dim_info(const api::PackedDimInfo& info) {
return info.packed_dim == WHCN::kChannelsDim &&
info.packed_dim_block_size == 4 &&
@@ -231,6 +240,7 @@ void add_q8ta_conv2d_node(
const ValueRef padding,
const ValueRef dilation,
const ValueRef groups,
const uint32_t activation_type,
const ValueRef packed_int8_output) {
(void)packed_int8_input_im2col; // Not used in general shader

@@ -288,9 +298,10 @@ void add_q8ta_conv2d_node(
graph.buffer_meta_ubo(packed_int8_input),
graph.create_params_buffer(conv_params)};

// Build spec constants: apply_bias + layout constants
// Build spec constants: apply_bias, activation_type + layout constants
vkapi::SpecVarList spec_constants = {
apply_bias,
activation_type,
// Layout specialization constants
graph.hashed_layout_of(packed_int8_input),
graph.hashed_layout_of(packed_int8_output),
@@ -341,8 +352,12 @@ void q8ta_conv2d_general(
const ValueRef padding = args.at(idx++);
const ValueRef dilation = args.at(idx++);
const ValueRef groups = args.at(idx++);
const ValueRef activation = args.at(idx++);
const ValueRef packed_int8_output = args.at(idx++);

uint32_t activation_type_val = static_cast<uint32_t>(
activation_type_from_string(graph.extract_string(activation)));

QuantizationConfig weight_quant_config(8, kPerChannel, {});

// Prepack weight using the conv2d weight packing for the general shader
@@ -397,6 +412,7 @@ void q8ta_conv2d_general(
padding,
dilation,
groups,
activation_type_val,
packed_int8_output);
}

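On the host side, the string argument is extracted once while the graph is built and converted to the integer enum that feeds the specialization constant; unknown strings fail fast via VK_THROW. An illustrative Python analogue of that mapping (not part of the codebase; the authoritative version is the C++ function above):

ACTIVATION_TYPES = {"none": 0, "relu": 1}

def activation_type_from_string(activation: str) -> int:
    # Mirrors the C++ helper: map the op-schema string to the enum value
    # that is baked into the shader as a specialization constant.
    if activation not in ACTIVATION_TYPES:
        raise ValueError(f"Unknown activation type: {activation}")
    return ACTIVATION_TYPES[activation]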
10 changes: 10 additions & 0 deletions backends/vulkan/runtime/graph/ops/impl/Q8taConv2d.h
@@ -13,6 +13,13 @@

namespace vkcompute {

enum class ActivationType : uint32_t {
kNone = 0,
kRelu = 1,
};

ActivationType activation_type_from_string(const std::string& activation);

bool q8ta_conv2d_check_packed_dim_info(const api::PackedDimInfo& info);

bool q8ta_conv2d_check_4w4c_packed_dim_info(const api::PackedDimInfo& info);
@@ -58,6 +65,7 @@ void add_q8ta_conv2d_dw_node(
const ValueRef padding,
const ValueRef dilation,
const ValueRef groups,
const uint32_t activation_type,
const ValueRef packed_int8_output);

void add_conv2d_dw_q8ta_q8csw_q8to_4w4c_node(
@@ -97,6 +105,7 @@ void add_q8ta_conv2d_node(
const ValueRef padding,
const ValueRef dilation,
const ValueRef groups,
const uint32_t activation_type,
const ValueRef packed_int8_output);

void add_q8ta_conv2d_pw_node(
@@ -111,6 +120,7 @@ void add_q8ta_conv2d_pw_node(
const ValueRef output_zp,
const ValueRef bias_data,
const ValueRef packed_bias,
const uint32_t activation_type,
const ValueRef packed_int8_output);

void q8ta_conv2d_im2col(ComputeGraph& graph, const std::vector<ValueRef>& args);
9 changes: 8 additions & 1 deletion backends/vulkan/runtime/graph/ops/impl/Q8taConv2dDW.cpp
@@ -281,6 +281,7 @@ void add_q8ta_conv2d_dw_node(
const ValueRef padding,
const ValueRef dilation,
const ValueRef groups,
const uint32_t activation_type,
const ValueRef packed_int8_output) {
Conv2DParams conv_params = create_conv2d_params(
graph,
@@ -334,9 +335,10 @@ void add_q8ta_conv2d_dw_node(
graph.buffer_meta_ubo(packed_int8_input),
graph.create_params_buffer(conv_params)};

// Build spec constants: apply_bias + layout constants
// Build spec constants: apply_bias, activation_type + layout constants
vkapi::SpecVarList spec_constants = {
apply_bias,
activation_type,
// Layout specialization constants
graph.hashed_layout_of(packed_int8_input),
graph.hashed_layout_of(packed_int8_output),
@@ -385,8 +387,12 @@ void q8ta_conv2d_dw(ComputeGraph& graph, const std::vector<ValueRef>& args) {
const ValueRef padding = args.at(idx++);
const ValueRef dilation = args.at(idx++);
const ValueRef groups = args.at(idx++);
const ValueRef activation = args.at(idx++);
const ValueRef packed_int8_output = args.at(idx++);

uint32_t activation_type_val = static_cast<uint32_t>(
activation_type_from_string(graph.extract_string(activation)));

QuantizationConfig weight_quant_config(8, kPerChannel, {});

// Prepack weight using depthwise-specific packing
@@ -432,6 +438,7 @@ void q8ta_conv2d_dw(ComputeGraph& graph, const std::vector<ValueRef>& args) {
padding,
dilation,
groups,
activation_type_val,
packed_int8_output);
}

5 changes: 5 additions & 0 deletions backends/vulkan/runtime/graph/ops/impl/Q8taConv2dIm2Col.cpp
@@ -197,6 +197,7 @@ void q8ta_conv2d_im2col(
const ValueRef padding = args.at(idx++);
const ValueRef dilation = args.at(idx++);
const ValueRef groups = args.at(idx++);
const ValueRef activation = args.at(idx++);
const ValueRef packed_int8_output = args.at(idx++);

QuantizationConfig weight_quant_config(8, kPerChannel, {});
@@ -225,6 +226,9 @@ void q8ta_conv2d_im2col(
prepack_standard(graph, bias_data, utils::kBuffer, utils::kWidthPacked);
}

uint32_t activation_type_val = static_cast<uint32_t>(
activation_type_from_string(graph.extract_string(activation)));

// Calculate im2col output sizes
std::vector<int64_t> im2col_sizes = calculate_q8ta_im2col_sizes(
&graph, packed_int8_input, packed_int8_output, kernel_size, groups);
@@ -265,6 +269,7 @@ void q8ta_conv2d_im2col(
output_zp,
bias_data,
packed_bias,
activation_type_val,
packed_int8_output);
}

9 changes: 8 additions & 1 deletion backends/vulkan/runtime/graph/ops/impl/Q8taConv2dPW.cpp
@@ -199,6 +199,7 @@ void add_q8ta_conv2d_pw_node(
const ValueRef output_zp,
const ValueRef bias_data,
const ValueRef packed_bias,
const uint32_t activation_type,
const ValueRef packed_int8_output) {
// Validate packed dim info for input and output tensors
// To maximize performance, the input tensor must be in 4W4C layout
@@ -242,9 +243,10 @@ void add_q8ta_conv2d_pw_node(
graph.buffer_meta_ubo(packed_int8_output),
graph.buffer_meta_ubo(packed_int8_input)};

// Build spec constants: apply_bias + layout constants
// Build spec constants: apply_bias, activation_type + layout constants
vkapi::SpecVarList spec_constants = {
apply_bias,
activation_type,
K4_per_group,
// Layout specialization constants
graph.hashed_layout_of(packed_int8_output),
@@ -296,8 +298,12 @@ void q8ta_conv2d_pw(ComputeGraph& graph, const std::vector<ValueRef>& args) {
(void)args.at(idx++); // padding
(void)args.at(idx++); // dilation
(void)args.at(idx++); // groups
const ValueRef activation_ref = args.at(idx++);
const ValueRef packed_int8_output = args.at(idx++);

uint32_t activation_type_val = static_cast<uint32_t>(
activation_type_from_string(graph.extract_string(activation_ref)));

QuantizationConfig weight_quant_config(8, kPerChannel, {});

// Prepack weight using pointwise-specific packing
@@ -342,6 +348,7 @@ void q8ta_conv2d_pw(ComputeGraph& graph, const std::vector<ValueRef>& args) {
output_zp,
bias_data,
packed_bias,
activation_type_val,
packed_int8_output);
}

@@ -894,6 +894,7 @@ void add_conv2d_q8ta_q8csw_q8to_node(
padding,
dilation,
groups,
static_cast<uint32_t>(ActivationType::kNone),
packed_int8_output);
}
}