Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions exir/passes/BUCK
Original file line number Diff line number Diff line change
Expand Up @@ -381,6 +381,14 @@ fbcode_target(_kind = runtime.python_library,
],
)

fbcode_target(_kind = runtime.python_library,
name = "device_copy_ops_registry",
srcs = ["_device_copy_ops_registry.py"],
deps = [
"//caffe2:torch",
],
)

fbcode_target(_kind = runtime.python_library,
name = "memory_format_ops_pass",
srcs = [
Expand Down
58 changes: 58 additions & 0 deletions exir/passes/_device_copy_ops_registry.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""
Registry for device copy ops used to insert explicit H2D (host-to-device)
and D2H (device-to-host) data transfer operations at delegate boundaries.

These ops are inserted by PropagateDevicePass when enable_non_cpu_memory_planning
is True, making the graph functional by explicitly transferring data between
CPU and device memory.

Follows the same registration pattern as dim_order_ops_registry.py.
"""

import torch
from torch.library import impl, Library

# All et_copy ops live in one torch Library; the schemas below are the
# contract the runtime kernels implement.
lib = Library("et_copy", "DEF")

# _h2d_copy: copies a CPU tensor to device memory.
# At tracing time this is just a clone (both tensors live on CPU). At runtime
# the out tensor is memory-planned on device and the kernel performs the real
# transfer via DeviceAllocator::copy_host_to_device.
lib.define("_h2d_copy(Tensor self) -> Tensor")
lib.define("_h2d_copy.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)")

# _d2h_copy: copies a device tensor to CPU memory.
# At tracing time this is just a clone (both tensors live on CPU). At runtime
# the self tensor holds device memory and the kernel performs the real
# transfer via DeviceAllocator::copy_device_to_host.
lib.define("_d2h_copy(Tensor self) -> Tensor")
lib.define("_d2h_copy.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)")


@impl(lib, "_h2d_copy", "CompositeImplicitAutograd")
def _h2d_copy_impl(self: torch.Tensor) -> torch.Tensor:
    """Tracing-time stand-in for the H2D transfer: a plain clone on CPU."""
    return torch.clone(self)


@impl(lib, "_h2d_copy.out", "CompositeImplicitAutograd")
def _h2d_copy_out_impl(self: torch.Tensor, *, out: torch.Tensor) -> torch.Tensor:
    """Out-variant of _h2d_copy: fill `out` with `self` and return `out`."""
    # Tensor.copy_ returns the tensor it was called on, i.e. `out`.
    return out.copy_(self)


@impl(lib, "_d2h_copy", "CompositeImplicitAutograd")
def _d2h_copy_impl(self: torch.Tensor) -> torch.Tensor:
    """Tracing-time stand-in for the D2H transfer: a plain clone on CPU."""
    return torch.clone(self)


@impl(lib, "_d2h_copy.out", "CompositeImplicitAutograd")
def _d2h_copy_out_impl(self: torch.Tensor, *, out: torch.Tensor) -> torch.Tensor:
    """Out-variant of _d2h_copy: fill `out` with `self` and return `out`."""
    # Tensor.copy_ returns the tensor it was called on, i.e. `out`.
    return out.copy_(self)
11 changes: 11 additions & 0 deletions exir/tests/TARGETS
Original file line number Diff line number Diff line change
Expand Up @@ -504,3 +504,14 @@ python_unittest(
"//executorch/exir/passes:propagate_device_pass",
],
)

python_unittest(
name = "device_copy_ops",
srcs = [
"test_device_copy_ops.py",
],
deps = [
"//caffe2:torch",
"//executorch/exir/passes:device_copy_ops_registry",
],
)
73 changes: 73 additions & 0 deletions exir/tests/test_device_copy_ops.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import unittest

# Import the registry to register the ops
import executorch.exir.passes._device_copy_ops_registry # noqa: F401

import torch


class DeviceCopyOpsRegistryTest(unittest.TestCase):
"""Tests that et_copy._h2d_copy and et_copy._d2h_copy ops are correctly
registered and produce expected outputs during tracing (CPU-only)."""

def test_h2d_copy_functional(self):
"""_h2d_copy should return a clone of the input tensor."""
x = torch.randn(2, 3)
result = torch.ops.et_copy._h2d_copy(x)
self.assertEqual(result.shape, x.shape)
self.assertEqual(result.dtype, x.dtype)
self.assertTrue(torch.equal(result, x))
# Should be a new tensor, not the same object
self.assertFalse(result.data_ptr() == x.data_ptr())

def test_d2h_copy_functional(self):
"""_d2h_copy should return a clone of the input tensor."""
x = torch.randn(4, 5)
result = torch.ops.et_copy._d2h_copy(x)
self.assertEqual(result.shape, x.shape)
self.assertEqual(result.dtype, x.dtype)
self.assertTrue(torch.equal(result, x))
self.assertFalse(result.data_ptr() == x.data_ptr())

def test_h2d_copy_out_variant(self):
"""_h2d_copy.out should copy data into the provided out tensor."""
x = torch.randn(3, 3)
out = torch.empty(3, 3)
result = torch.ops.et_copy._h2d_copy.out(x, out=out)
self.assertTrue(result is out)
self.assertTrue(torch.equal(out, x))

def test_d2h_copy_out_variant(self):
"""_d2h_copy.out should copy data into the provided out tensor."""
x = torch.randn(2, 4)
out = torch.empty(2, 4)
result = torch.ops.et_copy._d2h_copy.out(x, out=out)
self.assertTrue(result is out)
self.assertTrue(torch.equal(out, x))

def test_h2d_copy_preserves_dtype(self):
"""_h2d_copy should work with various dtypes."""
for dtype in [torch.float32, torch.float16, torch.int32, torch.int64]:
x = torch.ones(2, 2, dtype=dtype)
result = torch.ops.et_copy._h2d_copy(x)
self.assertEqual(result.dtype, dtype)
self.assertTrue(torch.equal(result, x))

def test_h2d_copy_scalar_tensor(self):
"""_h2d_copy should handle 0-dim tensors."""
x = torch.tensor(3.14)
result = torch.ops.et_copy._h2d_copy(x)
self.assertEqual(result.shape, torch.Size([]))
self.assertTrue(torch.equal(result, x))

def test_d2h_copy_empty_tensor(self):
"""_d2h_copy should handle empty tensors."""
x = torch.empty(0, 3)
result = torch.ops.et_copy._d2h_copy(x)
self.assertEqual(result.shape, torch.Size([0, 3]))
Loading