Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion backends/arm/_passes/arm_pass_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -614,7 +614,7 @@ def _tosa_pipeline(
RewriteMatmulPass(),
RewritePadPass(),
FuseViewCopyTransformPass(),
RemovePermutesAroundElementwiseTosaOps(),
RemovePermutesAroundElementwiseTosaOps(exported_program),
PostponePermuteOpBelowSqueezeOrUnsqueezeLikeView(),
FuseCascadedTransposeOrPermuteOps(),
ConvertPermuteSingletonToViewPass(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,19 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import torch

from executorch.backends.arm._passes.arm_pass_utils import is_param_node
from executorch.backends.arm._passes.insert_table_ops import TableOps
from executorch.backends.transforms.remove_permutes_around_elementwise_ops import (
RemovePermutesAroundElementwiseOps,
)
from executorch.exir import ExportedProgram
from executorch.exir.dialects._ops import ops as exir_ops


class RemovePermutesAroundElementwiseTosaOps(RemovePermutesAroundElementwiseOps):
def __init__(self) -> None:
def __init__(self, exported_program: ExportedProgram) -> None:
super().__init__(
extra_permutable_ops={
*TableOps.unary_table_ops.keys(),
Expand All @@ -20,16 +24,19 @@ def __init__(self) -> None:
exir_ops.backend.tosa.TABLE.default,
}
)
self.exported_program = exported_program

def _is_constant(self, node: torch.fx.Node) -> bool:
# Override fragile string match check with exported program check
return super()._is_constant(node) or is_param_node(self.exported_program, node)

def permute_subgraph(self, subgraph) -> bool:
# Original function will always permute constant nodes which is wrong for table ops
# Remove constant tosa.TABLE edges before running full function
# TABLE lookup inputs are already tied to the table layout.
new_constant_edges_in = set()
for const_node, user_node in subgraph.constant_edges_in:
if user_node.target == exir_ops.backend.tosa.TABLE.default:
continue
else:
new_constant_edges_in.add((const_node, user_node))
new_constant_edges_in.add((const_node, user_node))

subgraph.constant_edges_in = new_constant_edges_in
return super().permute_subgraph(subgraph)
2 changes: 1 addition & 1 deletion backends/arm/test/misc/test_transpose_counts.py
Original file line number Diff line number Diff line change
Expand Up @@ -453,7 +453,7 @@ def forward(self, x: torch.Tensor):
Model4ConvLstmLinearLayerNorm(), (torch.randn(2, 8, 32),), 3
),
"model_5_dwconv_gelu_layernorm_avgpool": TransposeCountCase(
Model5DwConvGeluLayerNormAvgPool(), (torch.randn(1, 8, 16, 16),), 4
Model5DwConvGeluLayerNormAvgPool(), (torch.randn(1, 8, 16, 16),), 2
),
"model_6_gru_linear": TransposeCountCase(
Model6GruLinear(), (torch.randn(2, 16, 8),), 2
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from types import SimpleNamespace
from typing import cast

import torch
from executorch.backends.arm._passes.remove_permutes_around_elementwise_tosa_ops import (
RemovePermutesAroundElementwiseTosaOps,
Expand All @@ -11,12 +14,30 @@
TosaLoweringContext,
TosaSpecification,
)
from executorch.exir import ExportedProgram
from executorch.exir.dialects._ops import ops as exir_ops

TOSA_INT_SPEC = TosaSpecification.create_from_string("TOSA-1.0+INT")
TOSA_FP_SPEC = TosaSpecification.create_from_string("TOSA-1.0+FP")
PERMUTE_TARGET = exir_ops.edge.aten.permute_copy.default
RESCALE_TARGET = exir_ops.backend.tosa.RESCALE.default
TABLE_TARGET = exir_ops.backend.tosa.TABLE.default
MUL_TARGET = exir_ops.edge.aten.mul.Tensor
ADD_TARGET = exir_ops.edge.aten.add.Tensor
ERF_TARGET = exir_ops.edge.aten.erf.default
Comment on lines 22 to +27


def _fake_exported_program() -> ExportedProgram:
return cast(
ExportedProgram,
SimpleNamespace(
graph_signature=SimpleNamespace(
inputs_to_buffers={},
inputs_to_lifted_tensor_constants={},
inputs_to_parameters={},
)
),
)


def _count_nodes(graph_module: torch.fx.GraphModule, target) -> int:
Expand Down Expand Up @@ -52,8 +73,125 @@ def test_remove_permutes_around_rescale_tosa_INT() -> None:
graph_module = torch.fx.GraphModule({}, graph)

with TosaLoweringContext(TOSA_INT_SPEC):
result = RemovePermutesAroundElementwiseTosaOps().call(graph_module)
result = RemovePermutesAroundElementwiseTosaOps(_fake_exported_program()).call(
graph_module
)

assert result.modified
assert _count_nodes(result.graph_module, PERMUTE_TARGET) == 0
assert _count_nodes(result.graph_module, RESCALE_TARGET) == 1


def test_remove_permutes_around_gelu_with_folded_scalar_constants_tosa_FP() -> None:
graph = torch.fx.Graph()
x = graph.placeholder("x")
x.meta["val"] = torch.randn(1, 2, 3, 4)

scalar_constants = []
for i in range(3):
const = graph.placeholder(f"c_scalar_{i}")
const.meta["val"] = torch.randn(1, 1, 1, 1)
scalar_constants.append(const)

permute_in = graph.create_node(
"call_function",
PERMUTE_TARGET,
args=(x, [0, 2, 3, 1]),
)
permute_in.meta["val"] = torch.randn(1, 3, 4, 2)
mul_0 = graph.create_node(
"call_function",
MUL_TARGET,
args=(permute_in, scalar_constants[0]),
)
mul_0.meta["val"] = torch.randn(1, 3, 4, 2)
erf = graph.create_node("call_function", ERF_TARGET, args=(mul_0,))
erf.meta["val"] = torch.randn(1, 3, 4, 2)
add = graph.create_node(
"call_function",
ADD_TARGET,
args=(erf, scalar_constants[1]),
)
add.meta["val"] = torch.randn(1, 3, 4, 2)
mul_1 = graph.create_node(
"call_function",
MUL_TARGET,
args=(add, scalar_constants[2]),
)
mul_1.meta["val"] = torch.randn(1, 3, 4, 2)
mul_2 = graph.create_node(
"call_function",
MUL_TARGET,
args=(permute_in, mul_1),
)
mul_2.meta["val"] = torch.randn(1, 3, 4, 2)
permute_out = graph.create_node(
"call_function",
PERMUTE_TARGET,
args=(mul_2, [0, 3, 1, 2]),
)
permute_out.meta["val"] = torch.randn(1, 2, 3, 4)
graph.output(permute_out)

graph_module = torch.fx.GraphModule({}, graph)

with TosaLoweringContext(TOSA_FP_SPEC):
result = RemovePermutesAroundElementwiseTosaOps(_fake_exported_program()).call(
graph_module
)

assert result.modified
assert _count_nodes(result.graph_module, PERMUTE_TARGET) == 3
assert _count_nodes(result.graph_module, ERF_TARGET) == 1


def test_remove_permutes_skips_stale_shared_boundary_subgraph_tosa_FP() -> None:
graph = torch.fx.Graph()
x = graph.placeholder("x")
x.meta["val"] = torch.randn(1, 16, 16, 8)

channel_const = graph.placeholder("p_layer_norm_weight")
channel_const.meta["val"] = torch.randn(1, 1, 1, 8)

permute_in = graph.create_node(
"call_function",
PERMUTE_TARGET,
args=(x, [0, 3, 1, 2]),
)
permute_in.meta["val"] = torch.randn(1, 8, 16, 16)
first_mul = graph.create_node(
"call_function",
MUL_TARGET,
args=(permute_in, permute_in),
)
first_mul.meta["val"] = torch.randn(1, 8, 16, 16)
shared_permute = graph.create_node(
"call_function",
PERMUTE_TARGET,
args=(first_mul, [0, 2, 3, 1]),
)
shared_permute.meta["val"] = torch.randn(1, 16, 16, 8)
second_mul = graph.create_node(
"call_function",
MUL_TARGET,
args=(shared_permute, channel_const),
)
second_mul.meta["val"] = torch.randn(1, 16, 16, 8)
permute_out = graph.create_node(
"call_function",
PERMUTE_TARGET,
args=(second_mul, [0, 3, 1, 2]),
)
permute_out.meta["val"] = torch.randn(1, 8, 16, 16)
graph.output(permute_out)

graph_module = torch.fx.GraphModule({}, graph)

with TosaLoweringContext(TOSA_FP_SPEC):
result = RemovePermutesAroundElementwiseTosaOps(_fake_exported_program()).call(
graph_module
)

assert result.modified
assert _count_nodes(result.graph_module, PERMUTE_TARGET) == 1
assert second_mul.args[1] is channel_const
26 changes: 26 additions & 0 deletions backends/transforms/remove_permutes_around_elementwise_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -400,6 +400,10 @@ def is_node_permutable(self, node: torch.fx.Node) -> bool:
return self._is_pointwise(node.target)

def permute_subgraph(self, subgraph: Subgraph) -> bool: # noqa: C901
# Ensure that the subgraph's edges have not been modified by an earlier rewrite before applying changes.
if not self._subgraph_edges_are_current(subgraph):
return False

# Validate: every view_copy node's permutation rank must match its
# input tensor rank. A mismatch can occur when a squeeze/unsqueeze
# view is reached via upstream traversal with a permutation that was
Expand Down Expand Up @@ -495,6 +499,28 @@ def permute_subgraph(self, subgraph: Subgraph) -> bool: # noqa: C901

return True

def _subgraph_edges_are_current(self, subgraph: Subgraph) -> bool:
"""Return false if an earlier rewrite invalidated this candidate."""
for inp, out in subgraph.edges_in:
if (
inp.target != exir_ops.edge.aten.permute_copy.default
or inp not in out.all_input_nodes
):
return False

for inp, out in subgraph.edges_out:
if (
out.target != exir_ops.edge.aten.permute_copy.default
or out not in inp.users
):
return False

for const_node, user_node in subgraph.constant_edges_in:
if const_node not in user_node.all_input_nodes:
return False

return True

def update_cat(self, node: torch.fx.Node, start_permute: list[int]) -> None:
dim = get_arg(node, "dim", int)
set_arg(node, "dim", start_permute[dim])
Expand Down
Loading