Skip to content

Commit 0b31c7f

Browse files
committed
[ET Device Support] Propagate device metadata from partitioner result onto TensorSpecs
Pull Request resolved: #18078 Annotate the delegate's input and output tensors with a specific device type. The overall pipeline is: a. The partitioner uses `compile_spec` to determine which device the partitioned blob is running on. b. After lowering the partitioned graph to a backend, the newly introduced propagate_device_pass annotates the input and output tensors of the delegate blob with the target device and the correct device index. ghstack-source-id: 363318415 @exported-using-ghexport Differential Revision: [D95842511](https://our.internmc.facebook.com/intern/diff/D95842511/)
1 parent 19bbeac commit 0b31c7f

8 files changed

Lines changed: 691 additions & 0 deletions

File tree

exir/passes/BUCK

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -439,3 +439,17 @@ fbcode_target(_kind = runtime.python_library,
439439
"//caffe2:torch",
440440
],
441441
)
442+
443+
fbcode_target(_kind = runtime.python_library,
444+
name = "propagate_device_pass",
445+
srcs = [
446+
"propagate_device_pass.py",
447+
],
448+
deps = [
449+
"//caffe2:torch",
450+
"//executorch/exir:delegate",
451+
"//executorch/exir:lowered_backend_module",
452+
"//executorch/exir:schema",
453+
"//executorch/exir:tensor",
454+
],
455+
)
Lines changed: 210 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,210 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
# pyre-strict
8+
9+
import logging
10+
from typing import Optional
11+
12+
import executorch.exir.schema as schema
13+
14+
import torch
15+
from executorch.exir.delegate import executorch_call_delegate
16+
from executorch.exir.lowered_backend_module import LoweredBackendModule
17+
from executorch.exir.tensor import TensorSpec
18+
from torch.fx.passes.infra.pass_base import PassBase, PassResult
19+
20+
logger: logging.Logger = logging.getLogger(__name__)
21+
22+
# CompileSpec key convention for specifying the target device.
# Partitioners that target a specific device should include a CompileSpec entry
# with this key and a value encoding the device string (e.g., b"cuda:0").
# The value is parsed by _parse_device_spec_value below.
TARGET_DEVICE_COMPILE_SPEC_KEY = "target_device"
26+
27+
28+
def _parse_device_spec_value(value: bytes) -> tuple[schema.DeviceType, int]:
    """
    Parse a target_device CompileSpec value (e.g., b"cuda:0") into
    (DeviceType, device_index).

    The type portion is matched case-insensitively against schema.DeviceType
    member names (e.g., "cpu", "cuda"). An absent index defaults to 0.

    Raises:
        ValueError: if the device type is unknown, the index portion is not
            a valid integer, or the index is negative.
    """
    device_str = value.decode("utf-8").strip().lower()
    if ":" in device_str:
        type_str, index_str = device_str.split(":", 1)
        # Surface a spec-specific error instead of int()'s generic
        # "invalid literal" message (e.g. for b"cuda:" or b"cuda:x").
        try:
            device_index = int(index_str)
        except ValueError:
            raise ValueError(
                f"Invalid device index '{index_str}' in device spec "
                f"'{device_str}'; expected an integer (e.g. 'cuda:0')"
            ) from None
        # Device indices identify physical devices; negatives are never valid.
        if device_index < 0:
            raise ValueError(
                f"Device index must be non-negative, got {device_index} "
                f"in device spec '{device_str}'"
            )
    else:
        type_str = device_str
        device_index = 0
    device_type = next(
        (dt for dt in schema.DeviceType if dt.name.lower() == type_str),
        None,
    )
    if device_type is None:
        valid = ", ".join(dt.name for dt in schema.DeviceType)
        raise ValueError(f"Unknown device type '{type_str}'. Valid types: {valid}")
    return device_type, device_index
51+
52+
53+
def _get_lowered_module(
    graph_module: torch.fx.GraphModule,
    delegate_call_node: torch.fx.Node,
) -> Optional[LoweredBackendModule]:
    """
    Resolve the LoweredBackendModule referenced by an executorch_call_delegate
    node, or return None when it cannot be found.

    The first argument of the delegate call is expected to be a get_attr node
    whose target names the lowered-module attribute on graph_module.
    """
    args = delegate_call_node.args
    if not args:
        return None
    first_arg = args[0]
    if not (isinstance(first_arg, torch.fx.Node) and first_arg.op == "get_attr"):
        return None
    candidate = getattr(graph_module, first_arg.target, None)
    return candidate if isinstance(candidate, LoweredBackendModule) else None
72+
73+
74+
def _get_target_device_from_compile_specs(
    lowered_module: LoweredBackendModule,
) -> Optional[tuple[schema.DeviceType, int]]:
    """
    Scan the module's compile specs for TARGET_DEVICE_COMPILE_SPEC_KEY and
    parse the first match into (DeviceType, device_index); None when absent.
    """
    matching_spec = next(
        (
            spec
            for spec in lowered_module.compile_specs
            if spec.key == TARGET_DEVICE_COMPILE_SPEC_KEY
        ),
        None,
    )
    if matching_spec is None:
        return None
    return _parse_device_spec_value(matching_spec.value)
85+
86+
87+
def _set_device_on_spec(
    spec: TensorSpec,
    device_type: schema.DeviceType,
    device_index: int = 0,
) -> None:
    """Annotate *spec* in place with the given device type and index."""
    spec.device, spec.device_index = device_type, device_index
95+
96+
97+
def _tag_specs_with_device(
    specs: object,
    device_type: schema.DeviceType,
    device_index: int = 0,
) -> bool:
    """Apply device annotation to a TensorSpec or a collection of TensorSpecs.

    Args:
        specs: A TensorSpec, a tuple/list of TensorSpecs, or None.
        device_type: The target device type to set.
        device_index: The device index (e.g., 0 for cuda:0, 1 for cuda:1).

    Returns:
        True if any spec was modified, False otherwise.
    """
    # A bare TensorSpec is annotated directly.
    if isinstance(specs, TensorSpec):
        _set_device_on_spec(specs, device_type, device_index)
        return True
    # None and any other unexpected payloads are ignored.
    if not isinstance(specs, (tuple, list)):
        return False
    # Annotate every TensorSpec element; skip anything else.
    tensor_specs = [entry for entry in specs if isinstance(entry, TensorSpec)]
    for tensor_spec in tensor_specs:
        _set_device_on_spec(tensor_spec, device_type, device_index)
    return bool(tensor_specs)
125+
126+
127+
class PropagateDevicePass(PassBase):
    """
    After to_backend, walk the graph and set device metadata on TensorSpecs
    based on partitioner-assigned delegation info.

    Rules:
    1. Delegated nodes: Input and output tensors of a delegate call are marked
       with the target device derived from the delegate's CompileSpec
       (key="target_device").
    2. Non-delegated nodes: Remain on CPU (default).
    3. Getitem nodes that extract from a delegate call inherit the device from
       the delegate call's output spec at the corresponding index.
    """

    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
        """Annotate delegate input/output TensorSpecs with their target device.

        Args:
            graph_module: The post-to_backend graph to annotate in place.

        Returns:
            A PassResult wrapping graph_module; its modified flag is True iff
            at least one TensorSpec was updated.
        """
        changed = False
        for node in graph_module.graph.nodes:
            if node.op == "call_function" and node.target == executorch_call_delegate:
                lowered_module = _get_lowered_module(graph_module, node)
                if lowered_module is None:
                    continue

                result = _get_target_device_from_compile_specs(lowered_module)
                if result is None:
                    # No target_device CompileSpec: the delegate stays on the
                    # default (CPU) device.
                    continue

                target_device_type, device_index = result

                # Tag delegate input tensors.
                # args[0] is the get_attr node for the lowered module; skip it.
                for arg in node.args[1:]:
                    if isinstance(arg, torch.fx.Node):
                        changed |= _tag_specs_with_device(
                            arg.meta.get("spec"),
                            target_device_type,
                            device_index,
                        )

                # Tag delegate output tensors.
                changed |= _tag_specs_with_device(
                    node.meta.get("spec"),
                    target_device_type,
                    device_index,
                )

                logger.debug(
                    "PropagateDevicePass: set device=%s on delegate node %s "
                    "(backend=%s)",
                    target_device_type,
                    node.name,
                    lowered_module.backend_id,
                )

        # Second pass: propagate device through getitem nodes that extract
        # individual outputs from a delegate call.
        for node in graph_module.graph.nodes:
            if node.op == "call_function" and node.target.__name__ == "getitem":
                source_node = node.args[0]
                if (
                    isinstance(source_node, torch.fx.Node)
                    and source_node.op == "call_function"
                    and source_node.target == executorch_call_delegate
                ):
                    spec = node.meta.get("spec")
                    source_specs = source_node.meta.get("spec")
                    idx = node.args[1]
                    # isinstance covers the None case, so no separate
                    # `is not None` checks are needed. The bound check accepts
                    # Python-style negative indices (getitem(-1) means "last
                    # output") but rejects anything out of range; the previous
                    # `idx < len(...)` check let very negative indices through
                    # and source_specs[idx] then raised IndexError.
                    if (
                        isinstance(spec, TensorSpec)
                        and isinstance(source_specs, (tuple, list))
                        and isinstance(idx, int)
                        and -len(source_specs) <= idx < len(source_specs)
                    ):
                        source_spec = source_specs[idx]
                        if isinstance(source_spec, TensorSpec):
                            _set_device_on_spec(
                                spec,
                                source_spec.device,
                                source_spec.device_index,
                            )
                            changed = True

        return PassResult(graph_module, changed)

exir/passes/replace_view_copy_with_view_pass.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,8 @@ def __init__(self, base: TensorSpec, shape: List[int]) -> None:
110110
"mem_offset",
111111
"dtype", # property
112112
"extra_tensor_info", # property
113+
"device",
114+
"device_index",
113115
]
114116

115117
# Make sure _self_fields and _base_fields are disjoint

exir/program/BUCK

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ fbcode_target(_kind = runtime.python_library,
4040
"//executorch/exir/passes:insert_write_back_for_buffers_pass",
4141
"//executorch/exir/passes:lib",
4242
"//executorch/exir/passes:normalize_view_copy_base_pass",
43+
"//executorch/exir/passes:propagate_device_pass",
4344
"//executorch/exir/passes:remove_graph_asserts_pass",
4445
"//executorch/exir/passes:remove_mixed_type_operators",
4546
"//executorch/exir/passes:replace_aten_with_edge_pass",

exir/program/_program.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@
5959
from executorch.exir.passes.normalize_view_copy_base_pass import (
6060
NormalizeViewCopyBasePass,
6161
)
62+
from executorch.exir.passes.propagate_device_pass import PropagateDevicePass
6263
from executorch.exir.passes.quant_fusion_pass import quant_fusion_and_const_prop_pass
6364
from executorch.exir.passes.reinplace import reinplace_pass
6465
from executorch.exir.passes.remove_graph_asserts_pass import (
@@ -848,6 +849,7 @@ def edge_to_executorch_passes(
848849
# there exists an unbacked symint operation.
849850
*config.passes,
850851
SpecPropPass(),
852+
PropagateDevicePass(),
851853
EdgeToBackendOpsPass(),
852854
RemoveGraphAssertsPass(),
853855
] + pre_memory_planning_passes(config, name)

exir/tensor.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -172,6 +172,9 @@ def __init__(
172172
self.init_mem_planning_fields()
173173
self.shape_dynamism: TensorShapeDynamism = determine_tensor_dynanism(self.shape)
174174
self.extra_tensor_info = extra_tensor_info
175+
# device type will be only updated during PropagateDevicePass.
176+
self.device: schema.DeviceType = schema.DeviceType.CPU
177+
self.device_index: int = 0
175178

176179
@property
177180
def allocated_memory(self) -> int:
@@ -254,6 +257,7 @@ def __repr__(self) -> str:
254257
+ f", is_sparse={self.is_sparse}"
255258
+ f", shape_dynamism={self.shape_dynamism}"
256259
+ f", const={self.const}, requires_grad={self.requires_grad}"
260+
+ f", device={self.device.name}:{self.device_index}"
257261
+ ")"
258262
)
259263

exir/tests/TARGETS

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -484,3 +484,23 @@ python_unittest(
484484
"//executorch/backends/xnnpack/quantizer:xnnpack_quantizer",
485485
],
486486
)
487+
488+
python_unittest(
489+
name = "propagate_device_pass",
490+
srcs = [
491+
"test_propagate_device_pass.py",
492+
],
493+
deps = [
494+
"//caffe2:torch",
495+
"//executorch/exir:lib",
496+
"//executorch/exir:schema",
497+
"//executorch/exir:tensor",
498+
"//executorch/exir/backend:backend_api",
499+
"//executorch/exir/backend:compile_spec_schema",
500+
"//executorch/exir/backend:partitioner",
501+
"//executorch/exir/backend/canonical_partitioners:canonical_partitioner_lib",
502+
"//executorch/exir/backend/test:backend_with_compiler_demo",
503+
"//executorch/exir/dialects:lib",
504+
"//executorch/exir/passes:propagate_device_pass",
505+
],
506+
)

0 commit comments

Comments
 (0)