From 8e55dc6966d1b4148b3208c0cc891749b6c322d8 Mon Sep 17 00:00:00 2001 From: gasoonjia Date: Tue, 7 Apr 2026 22:31:43 -0700 Subject: [PATCH] [ET Device Support] Add ExecutorchBackendConfig flags for skipping H2D/D2H copies Add skip_h2d_for_method_inputs and skip_d2h_for_method_outputs config flags to ExecutorchBackendConfig. These control whether PropagateDevicePass skips inserting H2D/D2H copy ops at method I/O boundaries: - skip_h2d_for_method_inputs: user provides GPU tensor directly - skip_d2h_for_method_outputs: output stays on device for cross-method pipelines Differential Revision: [D99636778](https://our.internmc.facebook.com/intern/diff/D99636778/) [ghstack-poisoned] --- exir/capture/_config.py | 12 + exir/passes/propagate_device_pass.py | 12 + exir/program/_program.py | 5 +- exir/tests/test_propagate_device_pass.py | 419 +++++++++++++++++++++-- 4 files changed, 413 insertions(+), 35 deletions(-) diff --git a/exir/capture/_config.py b/exir/capture/_config.py index 2d6290bdd0b..4ff70095041 100644 --- a/exir/capture/_config.py +++ b/exir/capture/_config.py @@ -123,3 +123,15 @@ class ExecutorchBackendConfig: # vs. accelerator memory. Default False preserves the legacy behavior # where all tensors are planned into CPU memory regardless of device. enable_non_cpu_memory_planning: bool = False + + # When True, method-level input tensors that feed directly into a device + # delegate are NOT wrapped with _h2d_copy. The user must provide tensors + # already on the target device. Useful for pipelines where inputs are + # pre-staged on GPU. + skip_h2d_for_method_inputs: bool = False + + # When True, device delegate outputs that are directly method outputs + # are NOT wrapped with _d2h_copy. The method outputs stay on device. + # Useful for cross-method GPU pipelines where the next method consumes + # GPU tensors directly. + skip_d2h_for_method_outputs: bool = False diff --git a/exir/passes/propagate_device_pass.py b/exir/passes/propagate_device_pass.py index 5ed0c20b1bb..3d957a9fbb2 100644 --- a/exir/passes/propagate_device_pass.py +++ b/exir/passes/propagate_device_pass.py @@ -163,8 +163,12 @@ class PropagateDevicePass(PassBase): def __init__( self, + skip_h2d_for_method_inputs: bool = False, + skip_d2h_for_method_outputs: bool = False, ) -> None: super().__init__() + self.skip_h2d_for_method_inputs = skip_h2d_for_method_inputs + self.skip_d2h_for_method_outputs = skip_d2h_for_method_outputs def _is_placeholder(self, node: torch.fx.Node) -> bool: """Check if a node is a graph-level input (placeholder).""" @@ -191,6 +195,11 @@ def _insert_h2d_copies( if not isinstance(arg_spec, TensorSpec): continue + if self.skip_h2d_for_method_inputs and self._is_placeholder(arg): + _set_device_on_spec(arg_spec, target_device_type, device_index) + changed = True + continue + with graph_module.graph.inserting_before(node): h2d_node = graph_module.graph.call_function( torch.ops.et_copy._h2d_copy.default, @@ -241,6 +250,9 @@ def _insert_d2h_for_getitem( _set_device_on_spec(spec, source_spec.device, source_spec.device_index) + if self.skip_d2h_for_method_outputs and self._feeds_directly_to_output(node): + return True + with graph_module.graph.inserting_after(node): d2h_node = graph_module.graph.call_function( torch.ops.et_copy._d2h_copy.default, diff --git a/exir/program/_program.py b/exir/program/_program.py index 8f0b983bd04..0f3bb222686 100644 --- a/exir/program/_program.py +++ b/exir/program/_program.py @@ -849,7 +849,10 @@ def edge_to_executorch_passes( # there exists an unbacked symint operation. *config.passes, SpecPropPass(), - PropagateDevicePass(), + PropagateDevicePass( + skip_h2d_for_method_inputs=config.skip_h2d_for_method_inputs, + skip_d2h_for_method_outputs=config.skip_d2h_for_method_outputs, + ), EdgeToBackendOpsPass(), RemoveGraphAssertsPass(), ] + pre_memory_planning_passes(config, name) diff --git a/exir/tests/test_propagate_device_pass.py b/exir/tests/test_propagate_device_pass.py index 8bb2fa1ab42..696c339344b 100644 --- a/exir/tests/test_propagate_device_pass.py +++ b/exir/tests/test_propagate_device_pass.py @@ -7,7 +7,7 @@ import operator import unittest from copy import deepcopy -from typing import Dict, final, List +from typing import Dict, final, List, Optional # Import to register et_copy ops import executorch.exir.passes._device_copy_ops_registry # noqa: F401 @@ -109,18 +109,21 @@ def _lower_model_to_executorch( model: torch.nn.Module, inputs: tuple, partitioner: Partitioner, + et_config: Optional[ExecutorchBackendConfig] = None, ) -> List: """Lower model all the way through to_executorch for E2E tests.""" + if et_config is None: + et_config = ExecutorchBackendConfig(emit_stacktrace=False) ep = export(model, inputs) ep_copied = deepcopy(ep) edge_1 = to_edge(ep, compile_config=EdgeCompileConfig(_check_ir_validity=False)) lowered_1 = edge_1.to_backend(partitioner) - et_1 = lowered_1.to_executorch(ExecutorchBackendConfig(emit_stacktrace=False)) + et_1 = lowered_1.to_executorch(deepcopy(et_config)) gm_1 = et_1.exported_program().graph_module edge_2 = to_edge_transform_and_lower(ep_copied, partitioner=[partitioner]) - et_2 = edge_2.to_executorch(ExecutorchBackendConfig(emit_stacktrace=False)) + et_2 = edge_2.to_executorch(deepcopy(et_config)) gm_2 = et_2.exported_program().graph_module return [ @@ -167,6 +170,119 @@ def _assert_specs_device( if expected_index is not None: self.assertEqual(s.device_index, expected_index) + def _assert_buffer_device( + self, + spec: TensorSpec, + program, + expected_device: DeviceType, + msg: str, + ) -> None: + """Assert the emitted program maps the spec's buffer to the expected device. + + The memory planner assigns each TensorSpec a ``mem_id`` (buffer index). + When ``enable_non_cpu_memory_planning`` is True, non-CPU buffers get an + entry in ``execution_plan[0].non_const_buffer_device``. CPU buffers have + no explicit entry (CPU is the default). + """ + plan = program.execution_plan[0] + mem_id = spec.mem_id + self.assertIsNotNone(mem_id, f"{msg}: spec.mem_id should not be None") + + if expected_device == DeviceType.CPU: + # CPU buffers have no explicit entry in non_const_buffer_device. + if plan.non_const_buffer_device is not None: + for entry in plan.non_const_buffer_device: + self.assertNotEqual( + entry.buffer_idx, + mem_id, + f"{msg}: buffer {mem_id} should be CPU but found " + f"in non_const_buffer_device as {entry.device_type.name}", + ) + else: + self.assertIsNotNone( + plan.non_const_buffer_device, + f"{msg}: non_const_buffer_device should exist for non-CPU buffers", + ) + matching = [ + e for e in plan.non_const_buffer_device if e.buffer_idx == mem_id + ] + self.assertEqual( + len(matching), + 1, + f"{msg}: expected exactly one entry for buffer {mem_id} " + f"in non_const_buffer_device, got {len(matching)}", + ) + self.assertEqual( + matching[0].device_type, + expected_device, + f"{msg}: buffer {mem_id} device type mismatch", + ) + + @staticmethod + def _collect_copy_nodes(gm): + """Classify call_function nodes into H2D, D2H, delegate, and getitem lists.""" + h2d, d2h, delegate, getitem = [], [], [], [] + for node in gm.graph.nodes: + if node.op != "call_function": + continue + if node.target == torch.ops.et_copy._h2d_copy.out: + h2d.append(node) + elif node.target == torch.ops.et_copy._d2h_copy.out: + d2h.append(node) + elif node.target == executorch_call_delegate: + delegate.append(node) + elif node.target == operator.getitem: + getitem.append(node) + return {"h2d": h2d, "d2h": d2h, "delegate": delegate, "getitem": getitem} + + @staticmethod + def _collect_placeholders_by_device(gm): + """Partition placeholder nodes by device type. Returns (cuda_list, cpu_list).""" + cuda, cpu = [], [] + for node in gm.graph.nodes: + if node.op != "placeholder": + continue + spec = node.meta.get("spec") + if isinstance(spec, TensorSpec) and spec.device == DeviceType.CUDA: + cuda.append(node) + elif isinstance(spec, TensorSpec): + cpu.append(node) + return cuda, cpu + + def _collect_delegate_getitems(self, gm): + """Return list of getitem nodes extracting from delegate calls.""" + return [n for n in gm.graph.nodes if self._is_delegate_getitem(n)] + + def _assert_nodes_device( + self, nodes, expected_device, pipeline, label, expected_index=None + ): + """Assert every node's TensorSpec has the expected device.""" + for node in nodes: + spec = node.meta.get("spec") + if isinstance(spec, TensorSpec): + self.assertEqual( + spec.device, + expected_device, + f"[{pipeline}] {label} '{node.name}' should have " + f"{expected_device.name} device spec", + ) + if expected_index is not None: + self.assertEqual(spec.device_index, expected_index) + + def _assert_nodes_buffer_device( + self, nodes, program, expected_device, pipeline, label + ): + """Assert each node's buffer is mapped to the expected device.""" + for node in nodes: + spec = node.meta.get("spec") + if isinstance(spec, TensorSpec): + self._assert_buffer_device( + spec, + program, + expected_device, + f"[{pipeline}] {label} '{node.name}' buffer", + ) + # ---- Integration tests: copy nodes after to_executorch ---- def test_h2d_d2h_nodes_inserted(self): @@ -185,22 +301,11 @@ def forward(self, a, b): model, inputs, DeviceAwarePartitioner("cuda:0") ): with self.subTest(pipeline=pipeline): - h2d_nodes = [] - d2h_nodes = [] - delegate_nodes = [] - getitem_nodes = [] - - for node in gm.graph.nodes: - if node.op != "call_function": - continue - if node.target == torch.ops.et_copy._h2d_copy.out: - h2d_nodes.append(node) - elif node.target == torch.ops.et_copy._d2h_copy.out: - d2h_nodes.append(node) - elif node.target == executorch_call_delegate: - delegate_nodes.append(node) - elif node.target == operator.getitem: - getitem_nodes.append(node) + nodes = self._collect_copy_nodes(gm) + h2d_nodes = nodes["h2d"] + d2h_nodes = nodes["d2h"] + delegate_nodes = nodes["delegate"] + getitem_nodes = nodes["getitem"] # Model has 2 inputs, 1 output → 2 H2D, 1 D2H self.assertEqual( @@ -253,16 +358,9 @@ def forward(self, a, b): model, inputs, DeviceAwarePartitioner("cuda:0") ): with self.subTest(pipeline=pipeline): - h2d_nodes = [] - d2h_nodes = [] - - for node in gm.graph.nodes: - if node.op != "call_function": - continue - if node.target == torch.ops.et_copy._h2d_copy.out: - h2d_nodes.append(node) - elif node.target == torch.ops.et_copy._d2h_copy.out: - d2h_nodes.append(node) + nodes = self._collect_copy_nodes(gm) + h2d_nodes = nodes["h2d"] + d2h_nodes = nodes["d2h"] self.assertGreater( len(h2d_nodes), @@ -337,7 +435,6 @@ def forward(self, a, b): # ---- Integration tests: device consistency after to_executorch ---- - def test_device_consistency_cuda_1(self): """Verify device tags are correct with cuda:1 after to_executorch() to verify device_index propagation through the full pipeline.""" @@ -510,10 +607,11 @@ def __init__(self, specs): # ---- End-to-end tests: verify device info survives to_executorch ---- - def _get_executorch_program(self, model, inputs, partitioner): + def _get_executorch_program(self, model, inputs, partitioner, et_config=None): """Run the full pipeline and return (emitted_program, graph_module) pairs for both export pipelines.""" - from executorch.exir.capture._config import ExecutorchBackendConfig + if et_config is None: + et_config = ExecutorchBackendConfig(emit_stacktrace=False) ep = export(model, inputs) ep_copied = deepcopy(ep) @@ -521,13 +619,13 @@ def _get_executorch_program(self, model, inputs, partitioner): # Pipeline 1: to_edge → to_backend → to_executorch edge_1 = to_edge(ep, compile_config=EdgeCompileConfig(_check_ir_validity=False)) lowered_1 = edge_1.to_backend(partitioner) - et_1 = lowered_1.to_executorch(ExecutorchBackendConfig(emit_stacktrace=False)) + et_1 = lowered_1.to_executorch(deepcopy(et_config)) program_1 = et_1._emitter_output.program gm_1 = et_1.exported_program().graph_module # Pipeline 2: to_edge_transform_and_lower → to_executorch edge_2 = to_edge_transform_and_lower(ep_copied, partitioner=[partitioner]) - et_2 = edge_2.to_executorch(ExecutorchBackendConfig(emit_stacktrace=False)) + et_2 = edge_2.to_executorch(deepcopy(et_config)) program_2 = et_2._emitter_output.program gm_2 = et_2.exported_program().graph_module @@ -614,6 +712,259 @@ def forward(self, a, b): ): continue + # ---- Skip-copy optimization tests ---- + + def test_skip_h2d_for_method_inputs(self): + """When skip_h2d_for_method_inputs=True, placeholder inputs feeding + directly into a device delegate should NOT get _h2d_copy nodes.""" + + class Model(torch.nn.Module): + def forward(self, a, b): + return torch.add(a, b) + + model = Model() + inputs = (torch.randn(2, 2), torch.randn(2, 2)) + et_config = ExecutorchBackendConfig( + emit_stacktrace=False, + skip_h2d_for_method_inputs=True, + enable_non_cpu_memory_planning=True, + ) + + for pipeline, program, gm in self._get_executorch_program( + model, inputs, DeviceAwarePartitioner("cuda:0"), et_config + ): + with self.subTest(pipeline=pipeline): + nodes = self._collect_copy_nodes(gm) + self.assertEqual( + len(nodes["h2d"]), + 0, + f"[{pipeline}] Expected no H2D copy nodes when " + f"skip_h2d_for_method_inputs=True, got {len(nodes['h2d'])}", + ) + self.assertEqual( + len(nodes["d2h"]), + 1, + f"[{pipeline}] Expected 1 D2H copy node for the single " + f"output, got {len(nodes['d2h'])}", + ) + + # Placeholder inputs should be tagged as CUDA since H2D was + # skipped and the pass sets their spec to the target device. + cuda_ph, cpu_ph = self._collect_placeholders_by_device(gm) + self.assertEqual(len(cpu_ph), 0) + self._assert_nodes_device( + cuda_ph, + DeviceType.CUDA, + pipeline, + "Placeholder", + expected_index=0, + ) + + # Verify buffer device mapping: CUDA placeholders should + # have their memory planned on a CUDA buffer. + self._assert_nodes_buffer_device( + cuda_ph, + program, + DeviceType.CUDA, + pipeline, + "Placeholder", + ) + + def test_skip_d2h_for_method_outputs(self): + """When skip_d2h_for_method_outputs=True, delegate outputs that feed + directly to the graph output should NOT get _d2h_copy nodes.""" + + class Model(torch.nn.Module): + def forward(self, a, b): + return torch.add(a, b) + + model = Model() + inputs = (torch.randn(2, 2), torch.randn(2, 2)) + et_config = ExecutorchBackendConfig( + emit_stacktrace=False, + skip_d2h_for_method_outputs=True, + enable_non_cpu_memory_planning=True, + ) + + for pipeline, program, gm in self._get_executorch_program( + model, inputs, DeviceAwarePartitioner("cuda:0"), et_config + ): + with self.subTest(pipeline=pipeline): + nodes = self._collect_copy_nodes(gm) + self.assertEqual( + len(nodes["d2h"]), + 0, + f"[{pipeline}] Expected no D2H copy nodes when " + f"skip_d2h_for_method_outputs=True, got {len(nodes['d2h'])}", + ) + self.assertEqual( + len(nodes["h2d"]), + 2, + f"[{pipeline}] Expected 2 H2D copy nodes for the two " + f"inputs, got {len(nodes['h2d'])}", + ) + + # Delegate getitem nodes feeding to output should stay on + # CUDA since D2H was skipped. + getitems = self._collect_delegate_getitems(gm) + self._assert_nodes_device( + getitems, + DeviceType.CUDA, + pipeline, + "Delegate getitem", + ) + + # Verify buffer device mapping: CUDA getitem outputs should + # have their memory planned on a CUDA buffer. + self._assert_nodes_buffer_device( + getitems, + program, + DeviceType.CUDA, + pipeline, + "Getitem", + ) + + def test_skip_both_h2d_and_d2h(self): + """When both skip flags are True, neither H2D nor D2H copy nodes + should be inserted for a direct input->delegate->output flow.""" + + class Model(torch.nn.Module): + def forward(self, a, b): + return torch.add(a, b) + + model = Model() + inputs = (torch.randn(2, 2), torch.randn(2, 2)) + et_config = ExecutorchBackendConfig( + emit_stacktrace=False, + skip_h2d_for_method_inputs=True, + skip_d2h_for_method_outputs=True, + enable_non_cpu_memory_planning=True, + ) + + for pipeline, program, gm in self._get_executorch_program( + model, inputs, DeviceAwarePartitioner("cuda:0"), et_config + ): + with self.subTest(pipeline=pipeline): + nodes = self._collect_copy_nodes(gm) + self.assertEqual( + len(nodes["h2d"]), + 0, + f"[{pipeline}] Expected no H2D copy nodes when " + f"skip_h2d_for_method_inputs=True, got {len(nodes['h2d'])}", + ) + self.assertEqual( + len(nodes["d2h"]), + 0, + f"[{pipeline}] Expected no D2H copy nodes when " + f"skip_d2h_for_method_outputs=True, got {len(nodes['d2h'])}", + ) + + # Placeholder inputs should be tagged as CUDA since H2D + # was skipped. + cuda_ph, cpu_ph = self._collect_placeholders_by_device(gm) + self.assertEqual(len(cpu_ph), 0) + self._assert_nodes_device( + cuda_ph, + DeviceType.CUDA, + pipeline, + "Placeholder", + expected_index=0, + ) + + # Delegate getitem outputs should stay on CUDA since D2H + # was skipped. + getitems = self._collect_delegate_getitems(gm) + self._assert_nodes_device( + getitems, + DeviceType.CUDA, + pipeline, + "Delegate getitem", + ) + + # Verify buffer device mapping: both input and output + # buffers should be on CUDA. + self._assert_nodes_buffer_device( + cuda_ph, + program, + DeviceType.CUDA, + pipeline, + "Placeholder", + ) + self._assert_nodes_buffer_device( + getitems, + program, + DeviceType.CUDA, + pipeline, + "Getitem", + ) + + def test_skip_h2d_partial_with_intermediate_input(self): + """When skip_h2d_for_method_inputs=True, only placeholder inputs + skip H2D copies. An intermediate (non-placeholder) input feeding + into the delegate should still get an _h2d_copy node.""" + + class Model(torch.nn.Module): + def forward(self, a, b): + c = torch.sin(a) + return torch.add(c, b) + + model = Model() + inputs = (torch.randn(2, 2), torch.randn(2, 2)) + et_config = ExecutorchBackendConfig( + emit_stacktrace=False, + skip_h2d_for_method_inputs=True, + enable_non_cpu_memory_planning=True, + ) + + for pipeline, program, gm in self._get_executorch_program( + model, inputs, DeviceAwarePartitioner("cuda:0"), et_config + ): + with self.subTest(pipeline=pipeline): + # sin(a) is intermediate (not a placeholder), so it still + # gets an H2D copy. Placeholder b is skipped. + nodes = self._collect_copy_nodes(gm) + self.assertEqual( + len(nodes["h2d"]), + 1, + f"[{pipeline}] Expected 1 H2D copy node for the " + f"intermediate input, got {len(nodes['h2d'])}", + ) + self.assertEqual( + len(nodes["d2h"]), + 1, + f"[{pipeline}] Expected 1 D2H copy node for the single " + f"output, got {len(nodes['d2h'])}", + ) + + # Exactly 1 placeholder should be on CUDA (b, which feeds + # directly into the delegate and skips H2D). The other + # placeholder (a) feeds through sin() so it stays CPU. + cuda_ph, cpu_ph = self._collect_placeholders_by_device(gm) + self.assertEqual( + len(cuda_ph), + 1, + f"[{pipeline}] Expected exactly 1 placeholder with CUDA " + f"device spec, got {len(cuda_ph)}", + ) + + # Verify buffer device mapping: the CUDA placeholder's + # buffer should be on CUDA, the CPU placeholder's buffer + # should be on CPU. + self._assert_nodes_buffer_device( + cuda_ph, + program, + DeviceType.CUDA, + pipeline, + "CUDA placeholder", + ) + self._assert_nodes_buffer_device( + cpu_ph, + program, + DeviceType.CPU, + pipeline, + "CPU placeholder", + ) + def test_tensorspec_repr_includes_device(self): spec = TensorSpec(dtype=torch.float32, shape=torch.Size([2, 3])) repr_str = repr(spec)