diff --git a/backends/vulkan/runtime/graph/ops/glsl/q8ta_binary.glsl b/backends/vulkan/runtime/graph/ops/glsl/q8ta_binary.glsl index 60f437fbdce..be93e800436 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/q8ta_binary.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/q8ta_binary.glsl @@ -46,6 +46,7 @@ layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; ${layout_declare_spec_const(C, "int", "out_layout", "CONTIG_LAYOUT_INT")} ${layout_declare_spec_const(C, "int", "in_layout", "CONTIG_LAYOUT_INT")} +${layout_declare_spec_const(C, "int", "other_layout", "CONTIG_LAYOUT_INT")} ${layout_declare_spec_const(C, "int", "block_config", "0")} // Generate loading functions for input buffers @@ -71,7 +72,7 @@ void main() { ivec4 in_block_a = load_int8x4_block_from_t_in_a( in_a_meta, tidx, in_layout, block_outer_dim); ivec4 in_block_b = load_int8x4_block_from_t_in_b( - in_b_meta, tidx, in_layout, block_outer_dim); + in_b_meta, tidx, other_layout, block_outer_dim); ivec4 out_block; diff --git a/backends/vulkan/runtime/graph/ops/impl/Q8taBinary.cpp b/backends/vulkan/runtime/graph/ops/impl/Q8taBinary.cpp index af934b9b521..05bdd9431c8 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Q8taBinary.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/Q8taBinary.cpp @@ -42,6 +42,7 @@ void add_q8ta_binary_node( VK_CHECK_COND(input_a_info.packed_dim == output_info.packed_dim); VK_CHECK_COND(input_b_info.packed_dim == output_info.packed_dim); + VK_CHECK_COND( input_a_info.packed_dim_block_size == output_info.packed_dim_block_size); VK_CHECK_COND( @@ -105,6 +106,7 @@ void add_q8ta_binary_node( // Specialization Constants {graph.hashed_layout_of(packed_int8_output), graph.hashed_layout_of(packed_int8_input_a), + graph.hashed_layout_of(packed_int8_input_b), block_config.as_packed_int()}, // Resize args {block_config_ref},