Skip to content

Commit a807e51

Browse files
committed
[ET Device Support] Emitter reads non_const_buffer_device from graph meta
Enable serializing non_const_buffer_device into the PTE file. Differential Revision: [D97850707](https://our.internmc.facebook.com/intern/diff/D97850707/) ghstack-source-id: 357060887 Pull Request resolved: #18472
1 parent b2022bc commit a807e51

2 files changed

Lines changed: 188 additions & 0 deletions

File tree

exir/emit/_emitter.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2073,4 +2073,9 @@ def plan(self) -> ExecutionPlan:
20732073
self.module.meta["non_const_buffer_sizes"],
20742074
),
20752075
container_meta_type=self.container_meta_type,
2076+
# non_const_buffer_device is set by apply_algo in memory_planning.py
2077+
# when device tensors are present. None for CPU-only programs.
2078+
non_const_buffer_device=self.module.meta.get(
2079+
"non_const_buffer_device", None
2080+
),
20762081
)

exir/emit/test/test_emit.py

Lines changed: 183 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2643,3 +2643,186 @@ def forward(self, a, b):
26432643
0,
26442644
"No tensor should have CUDA device when model runs entirely on CPU",
26452645
)
2646+
2647+
def test_emit_non_const_buffer_device_populated_for_device_tensors(self) -> None:
    """Verify that non_const_buffer_device is emitted into ExecutionPlan when
    device-aware memory planning is enabled and non-CPU tensors are present."""
    from executorch.exir.backend.canonical_partitioners.pattern_op_partitioner import (
        generate_pattern_op_partitions,
    )
    from executorch.exir.backend.compile_spec_schema import CompileSpec
    from executorch.exir.backend.partitioner import (
        DelegationSpec,
        Partitioner,
        PartitionResult,
    )
    from executorch.exir.backend.test.backend_with_compiler_demo import (
        BackendWithCompilerDemo,
    )
    from executorch.exir.passes.propagate_device_pass import (
        TARGET_DEVICE_COMPILE_SPEC_KEY,
    )
    from torch.fx.passes.operator_support import any_chain, OperatorSupportBase

    class AddSupport(OperatorSupportBase):
        # Accept only edge add.Tensor nodes so the partitioner tags them.
        def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
            supported_targets = [exir_ops.edge.aten.add.Tensor]
            return node.op == "call_function" and node.target in supported_targets

    class DevicePartitioner(Partitioner):
        # Delegates tagged partitions to the demo backend with a cuda:0
        # target-device compile spec, which triggers device propagation.
        def __init__(self):
            super().__init__()
            specs = [
                CompileSpec("max_value", bytes([4])),
                CompileSpec(TARGET_DEVICE_COMPILE_SPEC_KEY, b"cuda:0"),
            ]
            self.delegation_spec = DelegationSpec(
                BackendWithCompilerDemo.__name__, specs
            )

        def partition(self, exported_program) -> PartitionResult:
            partition_tags = {}
            found = generate_pattern_op_partitions(
                exported_program.graph_module,
                op_support=any_chain(AddSupport()),
            )
            for partition in found:
                tag = f"tag{partition.id}"
                for node in partition.nodes:
                    node.meta["delegation_tag"] = tag
                    partition_tags[tag] = self.delegation_spec
            return PartitionResult(
                tagged_exported_program=exported_program,
                partition_tags=partition_tags,
            )

    class Model(torch.nn.Module):
        def forward(self, a, b):
            return torch.add(a, b)

    example_inputs = (torch.randn(2, 2), torch.randn(2, 2))
    edge_prog = to_edge(
        export(Model(), example_inputs),
        compile_config=EdgeCompileConfig(_check_ir_validity=False),
    )
    delegated = edge_prog.to_backend(DevicePartitioner())
    et_prog = delegated.to_executorch(
        config=ExecutorchBackendConfig(enable_non_cpu_memory_planning=True),
    )

    plan = et_prog._emitter_output.program.execution_plan[0]
    self.assertIsNotNone(
        plan.non_const_buffer_device,
        "non_const_buffer_device should be set when device tensors are present "
        "and enable_non_cpu_memory_planning is True",
    )
    self.assertGreater(len(plan.non_const_buffer_device), 0)
    # Every emitted entry should reflect the cuda:0 target device.
    for entry in plan.non_const_buffer_device:
        self.assertEqual(entry.device_type, schema.DeviceType.CUDA)
        self.assertEqual(entry.device_index, 0)
def test_emit_non_const_buffer_device_none_for_cpu_only(self) -> None:
    """When all tensors are on CPU, non_const_buffer_device should be None
    even with enable_non_cpu_memory_planning=True."""

    class Model(torch.nn.Module):
        def forward(self, a, b):
            return torch.add(a, b)

    # No delegation and no device compile spec: everything stays on CPU.
    example_inputs = (torch.randn(2, 2), torch.randn(2, 2))
    edge_prog = to_edge(
        export(Model(), example_inputs),
        compile_config=EdgeCompileConfig(_check_ir_validity=False),
    )
    et_prog = edge_prog.to_executorch(
        config=ExecutorchBackendConfig(enable_non_cpu_memory_planning=True),
    )

    plan = et_prog._emitter_output.program.execution_plan[0]
    self.assertIsNone(
        plan.non_const_buffer_device,
        "non_const_buffer_device should be None for CPU-only programs",
    )
def test_emit_non_const_buffer_device_none_when_flag_disabled(self) -> None:
    """Even with device tensors, non_const_buffer_device should be None when
    enable_non_cpu_memory_planning is False (default)."""
    from executorch.exir.backend.canonical_partitioners.pattern_op_partitioner import (
        generate_pattern_op_partitions,
    )
    from executorch.exir.backend.compile_spec_schema import CompileSpec
    from executorch.exir.backend.partitioner import (
        DelegationSpec,
        Partitioner,
        PartitionResult,
    )
    from executorch.exir.backend.test.backend_with_compiler_demo import (
        BackendWithCompilerDemo,
    )
    from executorch.exir.passes.propagate_device_pass import (
        TARGET_DEVICE_COMPILE_SPEC_KEY,
    )
    from torch.fx.passes.operator_support import any_chain, OperatorSupportBase

    class AddSupport(OperatorSupportBase):
        # Accept only edge add.Tensor nodes so the partitioner tags them.
        def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
            supported_targets = [exir_ops.edge.aten.add.Tensor]
            return node.op == "call_function" and node.target in supported_targets

    class DevicePartitioner(Partitioner):
        # Same cuda:0 delegation as the positive test; only the backend
        # config flag differs below.
        def __init__(self):
            super().__init__()
            specs = [
                CompileSpec("max_value", bytes([4])),
                CompileSpec(TARGET_DEVICE_COMPILE_SPEC_KEY, b"cuda:0"),
            ]
            self.delegation_spec = DelegationSpec(
                BackendWithCompilerDemo.__name__, specs
            )

        def partition(self, exported_program) -> PartitionResult:
            partition_tags = {}
            found = generate_pattern_op_partitions(
                exported_program.graph_module,
                op_support=any_chain(AddSupport()),
            )
            for partition in found:
                tag = f"tag{partition.id}"
                for node in partition.nodes:
                    node.meta["delegation_tag"] = tag
                    partition_tags[tag] = self.delegation_spec
            return PartitionResult(
                tagged_exported_program=exported_program,
                partition_tags=partition_tags,
            )

    class Model(torch.nn.Module):
        def forward(self, a, b):
            return torch.add(a, b)

    example_inputs = (torch.randn(2, 2), torch.randn(2, 2))
    edge_prog = to_edge(
        export(Model(), example_inputs),
        compile_config=EdgeCompileConfig(_check_ir_validity=False),
    )
    delegated = edge_prog.to_backend(DevicePartitioner())
    # Default: enable_non_cpu_memory_planning=False
    et_prog = delegated.to_executorch()

    plan = et_prog._emitter_output.program.execution_plan[0]
    self.assertIsNone(
        plan.non_const_buffer_device,
        "non_const_buffer_device should be None when "
        "enable_non_cpu_memory_planning is False",
    )

0 commit comments

Comments
 (0)