Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 29 additions & 1 deletion backends/cuda/cuda_partitioner.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from executorch.backends.cuda.cuda_backend import CudaBackend # usort: skip
from executorch.exir._warnings import experimental
from executorch.exir.backend.compile_spec_schema import CompileSpec
from executorch.exir.passes.propagate_device_pass import TARGET_DEVICE_COMPILE_SPEC_KEY


@final
Expand All @@ -19,7 +20,34 @@
class CudaPartitioner(AotiPartitioner):
"""
CUDA partitioner driven by AOTInductor backend.

This partitioner adds a target_device compile spec to enable device info
propagation. The PropagateDevicePass will read this spec and mark delegate
output tensors with CUDA device type, which flows through to serialization.
"""

def __init__(
    self,
    compile_spec: List[CompileSpec],
) -> None:
    """
    Initialize the CUDA partitioner.

    Args:
        compile_spec: List of compile specs for the backend. To specify a
            target CUDA device, include a CompileSpec with key
            "target_device" (e.g., value "cuda:1"). If not provided,
            defaults to "cuda:0".
    """
    # Inject a default target_device spec so PropagateDevicePass can tag
    # delegate IO tensors with the CUDA device during serialization.
    has_target_device = any(
        spec.key == TARGET_DEVICE_COMPILE_SPEC_KEY for spec in compile_spec
    )
    if not has_target_device:
        # Build a new list rather than appending in place, so the caller's
        # compile_spec list is never mutated as a side effect.
        compile_spec = [
            *compile_spec,
            CompileSpec(TARGET_DEVICE_COMPILE_SPEC_KEY, b"cuda:0"),
        ]
    super().__init__(CudaBackend.__name__, compile_spec)
20 changes: 20 additions & 0 deletions backends/cuda/runtime/cuda_backend.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -404,6 +404,26 @@ class ET_EXPERIMENTAL CudaBackend final
n_outputs,
args.size())

// Verify device info on all memory-planned, ET-driven IO tensors.
// All input and output tensors should have device_type = CUDA, which
// is set during serialization by PropagateDevicePass based on the
// target_device compile spec from CudaPartitioner.
//
// Note: At this stage, the tensor memory is still on CPU. The device_type
// is metadata indicating where the tensor *should* reside. The backend
// is responsible for copying data to the actual CUDA device.
for (size_t i = 0; i < n_inputs + n_outputs; i++) {
auto* tensor = &(args[i]->toTensor());
auto device_type = tensor->unsafeGetTensorImpl()->device_type();
ET_CHECK_OR_RETURN_ERROR(
device_type == executorch::runtime::etensor::DeviceType::CUDA,
InvalidArgument,
"Tensor %zu expected device_type=CUDA (1), got %d. "
"Device info may not be properly propagated from CudaPartitioner.",
i,
static_cast<int>(device_type));
}

// NOTE: ExecuTorch tensors may be on CPU or GPU due to the skip-copy
// optimization. We need to create GPU copies for CUDA kernel execution
// using SlimTensor.
Expand Down
69 changes: 69 additions & 0 deletions backends/cuda/tests/test_cuda_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -325,3 +325,72 @@ def test_triton_kernel_mode_off(self):
edge_program_manager,
"SDPA kernel export with triton_kernel_mode=OFF failed",
)

def test_device_info_propagated_to_cuda_delegate_outputs(self):
    """
    Test that device info is correctly propagated from export to serialization
    for CUDA delegate outputs.

    This verifies the device propagation flow:
    1. CudaPartitioner adds target_device="cuda:0" CompileSpec
    2. PropagateDevicePass sets TensorSpec.device = CUDA for delegate outputs
    3. Emitter serializes device info into ExtraTensorInfo.device_type
    4. Serialized tensors have device_type = DeviceType.CUDA

    Note: At this stage, the tensor memory is still on CPU. The CUDA backend
    will copy data to GPU device at runtime. Device info tagging is the first
    step toward full device-aware memory allocation.
    """
    from executorch.exir import schema

    class AddModule(torch.nn.Module):
        def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
            return x + y

    module = AddModule()
    module.eval()
    inputs = (torch.randn(2, 3), torch.randn(2, 3))

    # Export to CUDA with full pipeline
    edge_program_manager = self._export_to_cuda_with_lower(module, inputs)
    self.assertIsNotNone(edge_program_manager, "CUDA export failed")

    # Convert to ExecuTorch and access the serialized program
    et_prog = edge_program_manager.to_executorch()
    program = et_prog._emitter_output.program

    # Get the execution plan and verify delegate exists
    plan = program.execution_plan[0]
    self.assertGreater(
        len(plan.delegates),
        0,
        "Expected at least one delegate in the execution plan",
    )

    # Partition serialized tensors by device type.
    cpu_tensors = []
    cuda_tensors = []

    for value in plan.values:
        if isinstance(value.val, schema.Tensor):
            tensor = value.val
            if (
                tensor.extra_tensor_info is not None
                and tensor.extra_tensor_info.device_type == schema.DeviceType.CUDA
            ):
                cuda_tensors.append(tensor)
            else:
                # Either no extra_tensor_info or device_type is CPU (default)
                cpu_tensors.append(tensor)

    # Every input and output tensor should currently be tagged CUDA.
    # The failure message describes what went wrong, not the success
    # condition (the previous message read "All tensors are on CUDA
    # device.." even when the assertion failed).
    self.assertEqual(
        len(cpu_tensors),
        0,
        f"Expected no CPU-tagged tensors, found {len(cpu_tensors)}",
    )
    self.assertGreater(
        len(cuda_tensors),
        3,
        "Expected CUDA tensors for delegate outputs",
    )
Loading