diff --git a/backends/arm/_passes/arm_pass_manager.py b/backends/arm/_passes/arm_pass_manager.py index 700b58f6c85..66f7a277672 100644 --- a/backends/arm/_passes/arm_pass_manager.py +++ b/backends/arm/_passes/arm_pass_manager.py @@ -614,7 +614,7 @@ def _tosa_pipeline( RewriteMatmulPass(), RewritePadPass(), FuseViewCopyTransformPass(), - RemovePermutesAroundElementwiseTosaOps(), + RemovePermutesAroundElementwiseTosaOps(exported_program), PostponePermuteOpBelowSqueezeOrUnsqueezeLikeView(), FuseCascadedTransposeOrPermuteOps(), ConvertPermuteSingletonToViewPass(), diff --git a/backends/arm/_passes/remove_permutes_around_elementwise_tosa_ops.py b/backends/arm/_passes/remove_permutes_around_elementwise_tosa_ops.py index 72688d17ef2..b241038f7a9 100644 --- a/backends/arm/_passes/remove_permutes_around_elementwise_tosa_ops.py +++ b/backends/arm/_passes/remove_permutes_around_elementwise_tosa_ops.py @@ -3,15 +3,19 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +import torch + +from executorch.backends.arm._passes.arm_pass_utils import is_param_node from executorch.backends.arm._passes.insert_table_ops import TableOps from executorch.backends.transforms.remove_permutes_around_elementwise_ops import ( RemovePermutesAroundElementwiseOps, ) +from executorch.exir import ExportedProgram from executorch.exir.dialects._ops import ops as exir_ops class RemovePermutesAroundElementwiseTosaOps(RemovePermutesAroundElementwiseOps): - def __init__(self) -> None: + def __init__(self, exported_program: ExportedProgram) -> None: super().__init__( extra_permutable_ops={ *TableOps.unary_table_ops.keys(), @@ -20,16 +24,19 @@ def __init__(self) -> None: exir_ops.backend.tosa.TABLE.default, } ) + self.exported_program = exported_program + + def _is_constant(self, node: torch.fx.Node) -> bool: + # Override fragile string match check with exported program check + return super()._is_constant(node) or is_param_node(self.exported_program, node) def permute_subgraph(self, subgraph) -> bool: - # Original function will always permute constant nodes which is wrong for table ops - # Remove constant tosa.TABLE edges before running full function + # TABLE lookup inputs are already tied to the table layout. new_constant_edges_in = set() for const_node, user_node in subgraph.constant_edges_in: if user_node.target == exir_ops.backend.tosa.TABLE.default: continue - else: - new_constant_edges_in.add((const_node, user_node)) + new_constant_edges_in.add((const_node, user_node)) subgraph.constant_edges_in = new_constant_edges_in return super().permute_subgraph(subgraph) diff --git a/backends/arm/test/misc/test_transpose_counts.py b/backends/arm/test/misc/test_transpose_counts.py index 8ce032058bf..7aae0555dc6 100644 --- a/backends/arm/test/misc/test_transpose_counts.py +++ b/backends/arm/test/misc/test_transpose_counts.py @@ -453,7 +453,7 @@ def forward(self, x: torch.Tensor): Model4ConvLstmLinearLayerNorm(), (torch.randn(2, 8, 32),), 3 ), "model_5_dwconv_gelu_layernorm_avgpool": TransposeCountCase( - Model5DwConvGeluLayerNormAvgPool(), (torch.randn(1, 8, 16, 16),), 4 + Model5DwConvGeluLayerNormAvgPool(), (torch.randn(1, 8, 16, 16),), 2 ), "model_6_gru_linear": TransposeCountCase( Model6GruLinear(), (torch.randn(2, 16, 8),), 2 diff --git a/backends/arm/test/passes/test_remove_permutes_around_elementwise_tosa_ops.py b/backends/arm/test/passes/test_remove_permutes_around_elementwise_tosa_ops.py index 341d985134e..b94a70fdec9 100644 --- a/backends/arm/test/passes/test_remove_permutes_around_elementwise_tosa_ops.py +++ b/backends/arm/test/passes/test_remove_permutes_around_elementwise_tosa_ops.py @@ -3,6 +3,9 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +from types import SimpleNamespace +from typing import cast + import torch from executorch.backends.arm._passes.remove_permutes_around_elementwise_tosa_ops import ( RemovePermutesAroundElementwiseTosaOps, @@ -11,12 +14,30 @@ TosaLoweringContext, TosaSpecification, ) +from executorch.exir import ExportedProgram from executorch.exir.dialects._ops import ops as exir_ops TOSA_INT_SPEC = TosaSpecification.create_from_string("TOSA-1.0+INT") +TOSA_FP_SPEC = TosaSpecification.create_from_string("TOSA-1.0+FP") PERMUTE_TARGET = exir_ops.edge.aten.permute_copy.default RESCALE_TARGET = exir_ops.backend.tosa.RESCALE.default TABLE_TARGET = exir_ops.backend.tosa.TABLE.default +MUL_TARGET = exir_ops.edge.aten.mul.Tensor +ADD_TARGET = exir_ops.edge.aten.add.Tensor +ERF_TARGET = exir_ops.edge.aten.erf.default + + +def _fake_exported_program() -> ExportedProgram: + return cast( + ExportedProgram, + SimpleNamespace( + graph_signature=SimpleNamespace( + inputs_to_buffers={}, + inputs_to_lifted_tensor_constants={}, + inputs_to_parameters={}, + ) + ), + ) def _count_nodes(graph_module: torch.fx.GraphModule, target) -> int: @@ -52,8 +73,125 @@ def test_remove_permutes_around_rescale_tosa_INT() -> None: graph_module = torch.fx.GraphModule({}, graph) with TosaLoweringContext(TOSA_INT_SPEC): - result = RemovePermutesAroundElementwiseTosaOps().call(graph_module) + result = RemovePermutesAroundElementwiseTosaOps(_fake_exported_program()).call( + graph_module + ) assert result.modified assert _count_nodes(result.graph_module, PERMUTE_TARGET) == 0 assert _count_nodes(result.graph_module, RESCALE_TARGET) == 1 + + +def test_remove_permutes_around_gelu_with_folded_scalar_constants_tosa_FP() -> None: + graph = torch.fx.Graph() + x = graph.placeholder("x") + x.meta["val"] = torch.randn(1, 2, 3, 4) + + scalar_constants = [] + for i in range(3): + const = graph.placeholder(f"c_scalar_{i}") + const.meta["val"] = torch.randn(1, 1, 1, 1) + scalar_constants.append(const) + + permute_in = graph.create_node( + "call_function", + PERMUTE_TARGET, + args=(x, [0, 2, 3, 1]), + ) + permute_in.meta["val"] = torch.randn(1, 3, 4, 2) + mul_0 = graph.create_node( + "call_function", + MUL_TARGET, + args=(permute_in, scalar_constants[0]), + ) + mul_0.meta["val"] = torch.randn(1, 3, 4, 2) + erf = graph.create_node("call_function", ERF_TARGET, args=(mul_0,)) + erf.meta["val"] = torch.randn(1, 3, 4, 2) + add = graph.create_node( + "call_function", + ADD_TARGET, + args=(erf, scalar_constants[1]), + ) + add.meta["val"] = torch.randn(1, 3, 4, 2) + mul_1 = graph.create_node( + "call_function", + MUL_TARGET, + args=(add, scalar_constants[2]), + ) + mul_1.meta["val"] = torch.randn(1, 3, 4, 2) + mul_2 = graph.create_node( + "call_function", + MUL_TARGET, + args=(permute_in, mul_1), + ) + mul_2.meta["val"] = torch.randn(1, 3, 4, 2) + permute_out = graph.create_node( + "call_function", + PERMUTE_TARGET, + args=(mul_2, [0, 3, 1, 2]), + ) + permute_out.meta["val"] = torch.randn(1, 2, 3, 4) + graph.output(permute_out) + + graph_module = torch.fx.GraphModule({}, graph) + + with TosaLoweringContext(TOSA_FP_SPEC): + result = RemovePermutesAroundElementwiseTosaOps(_fake_exported_program()).call( + graph_module + ) + + assert result.modified + assert _count_nodes(result.graph_module, PERMUTE_TARGET) == 3 + assert _count_nodes(result.graph_module, ERF_TARGET) == 1 + + +def test_remove_permutes_skips_stale_shared_boundary_subgraph_tosa_FP() -> None: + graph = torch.fx.Graph() + x = graph.placeholder("x") + x.meta["val"] = torch.randn(1, 16, 16, 8) + + channel_const = graph.placeholder("p_layer_norm_weight") + channel_const.meta["val"] = torch.randn(1, 1, 1, 8) + + permute_in = graph.create_node( + "call_function", + PERMUTE_TARGET, + args=(x, [0, 3, 1, 2]), + ) + permute_in.meta["val"] = torch.randn(1, 8, 16, 16) + first_mul = graph.create_node( + "call_function", + MUL_TARGET, + args=(permute_in, permute_in), + ) + first_mul.meta["val"] = torch.randn(1, 8, 16, 16) + shared_permute = graph.create_node( + "call_function", + PERMUTE_TARGET, + args=(first_mul, [0, 2, 3, 1]), + ) + shared_permute.meta["val"] = torch.randn(1, 16, 16, 8) + second_mul = graph.create_node( + "call_function", + MUL_TARGET, + args=(shared_permute, channel_const), + ) + second_mul.meta["val"] = torch.randn(1, 16, 16, 8) + permute_out = graph.create_node( + "call_function", + PERMUTE_TARGET, + args=(second_mul, [0, 3, 1, 2]), + ) + permute_out.meta["val"] = torch.randn(1, 8, 16, 16) + graph.output(permute_out) + + graph_module = torch.fx.GraphModule({}, graph) + + with TosaLoweringContext(TOSA_FP_SPEC): + result = RemovePermutesAroundElementwiseTosaOps(_fake_exported_program()).call( + graph_module + ) + + assert result.modified + assert _count_nodes(result.graph_module, PERMUTE_TARGET) == 1 + assert second_mul.args[1] is channel_const diff --git a/backends/transforms/remove_permutes_around_elementwise_ops.py b/backends/transforms/remove_permutes_around_elementwise_ops.py index b992afaeb53..373f84230cb 100644 --- a/backends/transforms/remove_permutes_around_elementwise_ops.py +++ b/backends/transforms/remove_permutes_around_elementwise_ops.py @@ -400,6 +400,10 @@ def is_node_permutable(self, node: torch.fx.Node) -> bool: return self._is_pointwise(node.target) def permute_subgraph(self, subgraph: Subgraph) -> bool: # noqa: C901 + # Ensure that the subgraph's edges have not been modified by an earlier rewrite before applying changes. + if not self._subgraph_edges_are_current(subgraph): + return False + # Validate: every view_copy node's permutation rank must match its # input tensor rank. A mismatch can occur when a squeeze/unsqueeze # view is reached via upstream traversal with a permutation that was @@ -495,6 +499,28 @@ def permute_subgraph(self, subgraph: Subgraph) -> bool: # noqa: C901 return True + def _subgraph_edges_are_current(self, subgraph: Subgraph) -> bool: + """Return false if an earlier rewrite invalidated this candidate.""" + for inp, out in subgraph.edges_in: + if ( + inp.target != exir_ops.edge.aten.permute_copy.default + or inp not in out.all_input_nodes + ): + return False + + for inp, out in subgraph.edges_out: + if ( + out.target != exir_ops.edge.aten.permute_copy.default + or out not in inp.users + ): + return False + + for const_node, user_node in subgraph.constant_edges_in: + if const_node not in user_node.all_input_nodes: + return False + + return True + def update_cat(self, node: torch.fx.Node, start_permute: list[int]) -> None: dim = get_arg(node, "dim", int) set_arg(node, "dim", start_permute[dim])