Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions backends/cadence/aot/BUCK
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,7 @@ fbcode_target(_kind = runtime.python_library,
typing = True,
deps = [
"fbcode//caffe2:torch",
"fbcode//executorch/backends/cadence/aot:utils",
"fbcode//executorch/exir:scalar_type",
"fbcode//executorch/kernels/quantized:custom_ops_generated_lib",
],
Expand Down Expand Up @@ -374,6 +375,7 @@ fbcode_target(_kind = runtime.python_library,
deps = [
"//caffe2:torch",
"//executorch/backends/cadence/aot:pass_utils",
"//executorch/backends/cadence/aot:utils",
"//executorch/exir:pass_base",
],
)
Expand Down
30 changes: 9 additions & 21 deletions backends/cadence/aot/ops_registrations.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
get_conv1d_output_size,
get_conv2d_output_size,
get_im2row_output_size,
is_depthwise_conv,
)
from executorch.exir.scalar_type import ScalarType
from torch._meta_registrations import _linalg_svd_meta
Expand Down Expand Up @@ -1034,11 +1035,8 @@ def quantized_conv2d_nhwc_meta(
assert len(in_size) < 6

# Determine weight layout based on depthwise vs regular conv.
# Depthwise is defined by in_channels == groups, where in_channels
# is the last dim of the NHWC input.
in_channels = in_size[-1]
is_depthwise = in_channels == groups
if is_depthwise:
if is_depthwise_conv(groups, in_channels):
# Depthwise conv: weight is [*kernel_size, OC]
*kernel_size, out_channels = weight.shape
else:
Expand Down Expand Up @@ -1177,12 +1175,8 @@ def quantized_conv2d_nhwc_per_tensor_meta(
assert len(in_size) < 6

# Determine weight layout based on depthwise vs regular conv.
# Depthwise is defined by in_channels == groups, where in_channels
# is the last dim of the NHWC input.
in_channels = in_size[-1]
is_depthwise = in_channels == groups
if is_depthwise:
# Depthwise conv: weight is [*kernel_size, OC]
if is_depthwise_conv(groups, in_channels):
*kernel_size, out_channels = weight.shape
elif len(in_size) == 3:
# 1D conv: weight is [OC, K, IC]
Expand Down Expand Up @@ -1336,12 +1330,9 @@ def quantized_conv2d_nhwc_asym8sxsym8s_asym8s_per_tensor_meta(
assert len(in_size) > 2
assert len(in_size) < 6

# Determine weight layout based on input and weight dimensions:
# - Depthwise conv: input is 3D/4D, weight is 2/3D [K, OC]/[KH, KW, OC]
# - 1D conv: input is 3D, weight is 3D [OC, K, IC]
# - 2D regular conv: input is 4D, weight is 4D [OC, KH, KW, IC]
if len(weight.shape) == 3:
# 2D depthwise conv: weight is [KH, KW, OC]
# Determine weight layout based on depthwise vs regular conv.
in_channels = in_size[-1]
if is_depthwise_conv(groups, in_channels):
*kernel_size, out_channels = weight.shape
elif len(in_size) == 3:
# 1D conv: weight is [OC, K, IC]
Expand Down Expand Up @@ -1397,12 +1388,9 @@ def quantized_conv2d_nhwc_asym8uxsym8u_asym8u_per_tensor_meta(
assert len(in_size) > 2
assert len(in_size) < 6

# Determine weight layout based on input and weight dimensions:
# - Depthwise conv: input is 3D/4D, weight is 3D [KH, KW, OC]
# - 1D conv: input is 3D, weight is 3D [OC, K, IC]
# - 2D regular conv: input is 4D, weight is 4D [OC, KH, KW, IC]
if len(weight.shape) == 3:
# 2D depthwise conv: weight is [KH, KW, OC]
# Determine weight layout based on depthwise vs regular conv.
in_channels = in_size[-1]
if is_depthwise_conv(groups, in_channels):
*kernel_size, out_channels = weight.shape
elif len(in_size) == 3:
# 1D conv: weight is [OC, K, IC]
Expand Down
12 changes: 4 additions & 8 deletions backends/cadence/aot/ref_implementations.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from executorch.backends.cadence.aot.utils import is_depthwise_conv
from executorch.exir.scalar_type import ScalarType
from torch.library import impl, Library

Expand Down Expand Up @@ -1104,17 +1105,12 @@ def quantized_conv2d_nhwc_per_tensor(

# Convert to NCHW format to reuse the existing implementation
in_channels = input_tensor.shape[-1]
# Depthwise weights have one fewer dimension than the input because the IC
# dimension (always 1) was squeezed out during the NCHW->NHWC conversion in
# replace_ops.py. E.g. 2D depthwise: weight is [KH, KW, OC] (3D) while
# input is [N, H, W, C] (4D). A regular conv with in_channels==groups==1
# still has 4D weights [OC, KH, KW, IC].
is_depthwise = in_channels == groups and weight.dim() < input_tensor.dim()
depthwise = is_depthwise_conv(groups, in_channels)

if len(input_tensor.shape) == 3:
# 1D conv: input is [N, L, C] -> [N, C, L]
input_tensor = input_tensor.movedim(-1, 1).contiguous()
if is_depthwise:
if depthwise:
# 1D depthwise: weight is [K, OC] -> [OC, 1, K]
weight = weight.permute(1, 0).unsqueeze(1).contiguous()
else:
Expand All @@ -1124,7 +1120,7 @@ def quantized_conv2d_nhwc_per_tensor(
else:
# 2D conv: input is [N, H, W, C] -> [N, C, H, W]
input_tensor = input_tensor.movedim(-1, -3)
if is_depthwise:
if depthwise:
# 2D depthwise: weight is [KH, KW, OC] -> [OC, 1, KH, KW]
weight = weight.permute(2, 0, 1).unsqueeze(1).contiguous()
else:
Expand Down
9 changes: 5 additions & 4 deletions backends/cadence/aot/replace_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
register_cadence_pass,
RemoveOrReplacePassInterface,
)
from executorch.backends.cadence.aot.utils import is_depthwise_conv
from executorch.backends.transforms.replace_scalar_with_tensor import (
ReplaceScalarWithTensorArgPass,
)
Expand Down Expand Up @@ -1138,19 +1139,19 @@ def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool:

# Check if this is a depthwise convolution (groups == input_channels)
# and weight is 4D with shape [OC, 1, KH, KW]
groups = node.args[6]
groups = cast(int, node.args[6])
input_shape = input_node.meta["val"].shape
weight_shape = weight_node.meta["val"].shape
input_channels = input_shape[1] # NCHW format, channels at index 1
# Depthwise conv has 4D weight [OC, 1, KH, KW] where the IC dim is 1
is_depthwise = groups == input_channels and weight_shape[1] == 1
# NCHW: also verify weight IC dim == 1.
depthwise = is_depthwise_conv(groups, input_channels) and weight_shape[1] == 1
is_2d = len(input_shape) == 4
# Insert transpose operations before the node
with graph.inserting_before(node):
# Convert input from NCHW to NHWC
input_nhwc = self._change_nchw_to_nhwc(graph, input_node)
# Convert weight from NCHW to the appropriate format
if is_depthwise:
if depthwise:
# For depthwise: [OC, 1, KH, KW] -> [KH, KW, OC] for NNLib
weight_nhwc = self._change_depthwise_weight_to_hwc(
graph, weight_node, is_2d
Expand Down
15 changes: 13 additions & 2 deletions backends/cadence/aot/tests/test_ref_implementations.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import numpy as np
import torch
from executorch.backends.cadence.aot.typing_stubs import expand
from executorch.backends.cadence.aot.utils import is_depthwise_conv

from executorch.exir.scalar_type import ScalarType

Expand Down Expand Up @@ -942,12 +943,22 @@ def test_quantized_conv_per_tensor(
assert memory_format in [torch.contiguous_format, torch.channels_last]

if memory_format == torch.channels_last:
in_channels = input_tensor.shape[1] # NCHW still at this point
depthwise = is_depthwise_conv(groups, in_channels)
if input_tensor.ndim == 3:
input_tensor = input_tensor.movedim(1, -1)
weight = weight.movedim(1, -1)
if depthwise:
# [OC, 1, K] -> [K, OC] (squeeze IC, move OC to end)
weight = weight.squeeze(1).movedim(0, -1)
else:
weight = weight.movedim(1, -1)
else:
input_tensor = input_tensor.movedim(-3, -1)
weight = weight.movedim(-3, -1)
if depthwise:
# [OC, 1, KH, KW] -> [KH, KW, OC] (squeeze IC, move OC to end)
weight = weight.squeeze(1).movedim(0, -1)
else:
weight = weight.movedim(-3, -1)

convs = [
(
Expand Down
Loading
Loading