@@ -2643,3 +2643,186 @@ def forward(self, a, b):
            0,
            "No tensor should have CUDA device when model runs entirely on CPU",
        )

    def test_emit_non_const_buffer_device_populated_for_device_tensors(self) -> None:
        """Verify that non_const_buffer_device is emitted into ExecutionPlan when
        device-aware memory planning is enabled and non-CPU tensors are present."""
        from executorch.exir.backend.canonical_partitioners.pattern_op_partitioner import (
            generate_pattern_op_partitions,
        )
        from executorch.exir.backend.compile_spec_schema import CompileSpec
        from executorch.exir.backend.partitioner import (
            DelegationSpec,
            Partitioner,
            PartitionResult,
        )
        from executorch.exir.backend.test.backend_with_compiler_demo import (
            BackendWithCompilerDemo,
        )
        from executorch.exir.passes.propagate_device_pass import (
            TARGET_DEVICE_COMPILE_SPEC_KEY,
        )
        from torch.fx.passes.operator_support import any_chain, OperatorSupportBase
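
        # Mechanism under test (a sketch, inferred from the imports and asserts
        # below): the partitioner tags aten.add nodes for delegation and attaches
        # a compile spec keyed by TARGET_DEVICE_COMPILE_SPEC_KEY with value
        # b"cuda:0"; the propagate-device pass should pick that up and mark the
        # delegate's non-const tensors as CUDA, which the emitter then records in
        # ExecutionPlan.non_const_buffer_device.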
        class AddSupport(OperatorSupportBase):
            def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
                return node.op == "call_function" and node.target in [
                    exir_ops.edge.aten.add.Tensor,
                ]

        class DevicePartitioner(Partitioner):
            def __init__(self):
                super().__init__()
                self.delegation_spec = DelegationSpec(
                    BackendWithCompilerDemo.__name__,
                    [
                        CompileSpec("max_value", bytes([4])),
                        CompileSpec(TARGET_DEVICE_COMPILE_SPEC_KEY, b"cuda:0"),
                    ],
                )

            def partition(self, exported_program) -> PartitionResult:
                partition_tags = {}
                partition_list = generate_pattern_op_partitions(
                    exported_program.graph_module,
                    op_support=any_chain(AddSupport()),
                )
                for partition in partition_list:
                    for node in partition.nodes:
                        tag = f"tag{partition.id}"
                        node.meta["delegation_tag"] = tag
                        partition_tags[tag] = self.delegation_spec
                return PartitionResult(
                    tagged_exported_program=exported_program,
                    partition_tags=partition_tags,
                )

        class Model(torch.nn.Module):
            def forward(self, a, b):
                return torch.add(a, b)

        model = Model()
        inputs = (torch.randn(2, 2), torch.randn(2, 2))

        edge = to_edge(
            export(model, inputs),
            compile_config=EdgeCompileConfig(_check_ir_validity=False),
        )
        lowered = edge.to_backend(DevicePartitioner())
        et_prog = lowered.to_executorch(
            config=ExecutorchBackendConfig(enable_non_cpu_memory_planning=True),
        )
        program = et_prog._emitter_output.program
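
        # Every delegated tensor was tagged cuda:0, so each emitted entry should
        # report device_type CUDA and device_index 0 (presumably parsed from the
        # b"cuda:0" compile spec above).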
        plan = program.execution_plan[0]
        self.assertIsNotNone(
            plan.non_const_buffer_device,
            "non_const_buffer_device should be set when device tensors are present "
            "and enable_non_cpu_memory_planning is True",
        )
        self.assertGreater(len(plan.non_const_buffer_device), 0)
        for entry in plan.non_const_buffer_device:
            self.assertEqual(entry.device_type, schema.DeviceType.CUDA)
            self.assertEqual(entry.device_index, 0)

    def test_emit_non_const_buffer_device_none_for_cpu_only(self) -> None:
        """When all tensors are on CPU, non_const_buffer_device should be None
        even with enable_non_cpu_memory_planning=True."""
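
        # No partitioner and no device compile spec here, so every tensor stays
        # on CPU and the field should be left unset rather than emitted empty.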
        class Model(torch.nn.Module):
            def forward(self, a, b):
                return torch.add(a, b)

        model = Model()
        inputs = (torch.randn(2, 2), torch.randn(2, 2))

        edge = to_edge(
            export(model, inputs),
            compile_config=EdgeCompileConfig(_check_ir_validity=False),
        )
        et_prog = edge.to_executorch(
            config=ExecutorchBackendConfig(enable_non_cpu_memory_planning=True),
        )
        program = et_prog._emitter_output.program

        plan = program.execution_plan[0]
        self.assertIsNone(
            plan.non_const_buffer_device,
            "non_const_buffer_device should be None for CPU-only programs",
        )

    def test_emit_non_const_buffer_device_none_when_flag_disabled(self) -> None:
        """Even with device tensors, non_const_buffer_device should be None when
        enable_non_cpu_memory_planning is False (default)."""
        from executorch.exir.backend.canonical_partitioners.pattern_op_partitioner import (
            generate_pattern_op_partitions,
        )
        from executorch.exir.backend.compile_spec_schema import CompileSpec
        from executorch.exir.backend.partitioner import (
            DelegationSpec,
            Partitioner,
            PartitionResult,
        )
        from executorch.exir.backend.test.backend_with_compiler_demo import (
            BackendWithCompilerDemo,
        )
        from executorch.exir.passes.propagate_device_pass import (
            TARGET_DEVICE_COMPILE_SPEC_KEY,
        )
        from torch.fx.passes.operator_support import any_chain, OperatorSupportBase
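
        # Same delegation setup as the first test (add ops tagged cuda:0); the
        # only difference is that to_executorch() below runs with the default
        # config, i.e. enable_non_cpu_memory_planning=False.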
        class AddSupport(OperatorSupportBase):
            def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
                return node.op == "call_function" and node.target in [
                    exir_ops.edge.aten.add.Tensor,
                ]

        class DevicePartitioner(Partitioner):
            def __init__(self):
                super().__init__()
                self.delegation_spec = DelegationSpec(
                    BackendWithCompilerDemo.__name__,
                    [
                        CompileSpec("max_value", bytes([4])),
                        CompileSpec(TARGET_DEVICE_COMPILE_SPEC_KEY, b"cuda:0"),
                    ],
                )

            def partition(self, exported_program) -> PartitionResult:
                partition_tags = {}
                partition_list = generate_pattern_op_partitions(
                    exported_program.graph_module,
                    op_support=any_chain(AddSupport()),
                )
                for partition in partition_list:
                    for node in partition.nodes:
                        tag = f"tag{partition.id}"
                        node.meta["delegation_tag"] = tag
                        partition_tags[tag] = self.delegation_spec
                return PartitionResult(
                    tagged_exported_program=exported_program,
                    partition_tags=partition_tags,
                )

        class Model(torch.nn.Module):
            def forward(self, a, b):
                return torch.add(a, b)

        model = Model()
        inputs = (torch.randn(2, 2), torch.randn(2, 2))

        edge = to_edge(
            export(model, inputs),
            compile_config=EdgeCompileConfig(_check_ir_validity=False),
        )
        lowered = edge.to_backend(DevicePartitioner())
        # Default: enable_non_cpu_memory_planning=False
        et_prog = lowered.to_executorch()
        program = et_prog._emitter_output.program
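
        # The flag gates emission: even though the delegate was tagged cuda:0,
        # device-aware memory planning never ran, so nothing should be recorded.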
        plan = program.execution_plan[0]
        self.assertIsNone(
            plan.non_const_buffer_device,
            "non_const_buffer_device should be None when "
            "enable_non_cpu_memory_planning is False",
        )