Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions exir/capture/_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,3 +123,15 @@ class ExecutorchBackendConfig:
# vs. accelerator memory. Default False preserves the legacy behavior
# where all tensors are planned into CPU memory regardless of device.
enable_non_cpu_memory_planning: bool = False

# When True, method-level input tensors that feed directly into a device
# delegate are NOT wrapped with _h2d_copy. The user must provide tensors
# already on the target device. Useful for pipelines where inputs are
# pre-staged on GPU.
skip_h2d_for_method_inputs: bool = False

# When True, device delegate outputs that are directly method outputs
# are NOT wrapped with _d2h_copy. The method outputs stay on device.
# Useful for cross-method GPU pipelines where the next method consumes
# GPU tensors directly.
skip_d2h_for_method_outputs: bool = False
12 changes: 12 additions & 0 deletions exir/passes/propagate_device_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,8 +163,12 @@ class PropagateDevicePass(PassBase):

def __init__(
    self,
    skip_h2d_for_method_inputs: bool = False,
    skip_d2h_for_method_outputs: bool = False,
) -> None:
    """Initialize the pass with optional host/device copy-elision flags.

    Args:
        skip_h2d_for_method_inputs: When True, method-level input
            tensors that feed a device delegate are not wrapped with
            ``_h2d_copy``; the caller must supply tensors already on the
            target device (see the matching flag on
            ``ExecutorchBackendConfig``).
        skip_d2h_for_method_outputs: When True, device delegate outputs
            that are directly method outputs are not wrapped with
            ``_d2h_copy``; method outputs stay on the device (see the
            matching flag on ``ExecutorchBackendConfig``).
    """
    super().__init__()
    # Stored for use by the insertion helpers (_insert_h2d_copies /
    # _insert_d2h_for_getitem) when deciding whether to emit copy nodes.
    self.skip_h2d_for_method_inputs = skip_h2d_for_method_inputs
    self.skip_d2h_for_method_outputs = skip_d2h_for_method_outputs

def _is_placeholder(self, node: torch.fx.Node) -> bool:
"""Check if a node is a graph-level input (placeholder)."""
Expand All @@ -191,6 +195,11 @@ def _insert_h2d_copies(
if not isinstance(arg_spec, TensorSpec):
continue

if self.skip_h2d_for_method_inputs and self._is_placeholder(arg):
_set_device_on_spec(arg_spec, target_device_type, device_index)
changed = True
continue

with graph_module.graph.inserting_before(node):
h2d_node = graph_module.graph.call_function(
torch.ops.et_copy._h2d_copy.default,
Expand Down Expand Up @@ -241,6 +250,9 @@ def _insert_d2h_for_getitem(

_set_device_on_spec(spec, source_spec.device, source_spec.device_index)

if self.skip_d2h_for_method_outputs and self._feeds_directly_to_output(node):
return True

with graph_module.graph.inserting_after(node):
d2h_node = graph_module.graph.call_function(
torch.ops.et_copy._d2h_copy.default,
Expand Down
5 changes: 4 additions & 1 deletion exir/program/_program.py
Original file line number Diff line number Diff line change
Expand Up @@ -849,7 +849,10 @@ def edge_to_executorch_passes(
# there exists an unbacked symint operation.
*config.passes,
SpecPropPass(),
PropagateDevicePass(),
PropagateDevicePass(
skip_h2d_for_method_inputs=config.skip_h2d_for_method_inputs,
skip_d2h_for_method_outputs=config.skip_d2h_for_method_outputs,
),
EdgeToBackendOpsPass(),
RemoveGraphAssertsPass(),
] + pre_memory_planning_passes(config, name)
Expand Down
Loading
Loading