diff --git a/exir/passes/BUCK b/exir/passes/BUCK index 954f1cfdb4f..4647388b388 100644 --- a/exir/passes/BUCK +++ b/exir/passes/BUCK @@ -381,6 +381,14 @@ fbcode_target(_kind = runtime.python_library, ], ) +fbcode_target(_kind = runtime.python_library, + name = "device_copy_ops_registry", + srcs = ["_device_copy_ops_registry.py"], + deps = [ + "//caffe2:torch", + ], +) + fbcode_target(_kind = runtime.python_library, name = "memory_format_ops_pass", srcs = [ diff --git a/exir/passes/_device_copy_ops_registry.py b/exir/passes/_device_copy_ops_registry.py new file mode 100644 index 00000000000..a62b88d4234 --- /dev/null +++ b/exir/passes/_device_copy_ops_registry.py @@ -0,0 +1,58 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +Registry for device copy ops used to insert explicit H2D (host-to-device) +and D2H (device-to-host) data transfer operations at delegate boundaries. + +These ops are inserted by PropagateDevicePass when enable_non_cpu_memory_planning +is True, making the graph functional by explicitly transferring data between +CPU and device memory. + +Follows the same registration pattern as dim_order_ops_registry.py. +""" + +import torch +from torch.library import impl, Library + +lib = Library("et_copy", "DEF") + +# _h2d_copy: copies a CPU tensor to device memory. +# At tracing time, this is a clone (both on CPU). At runtime, the out tensor +# is memory-planned on device, and the kernel calls +# DeviceAllocator::copy_host_to_device. +lib.define("_h2d_copy(Tensor self) -> Tensor") +lib.define("_h2d_copy.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)") + +# _d2h_copy: copies a device tensor to CPU memory. +# At tracing time, this is a clone (both on CPU). At runtime, the self tensor +# has device memory, and the kernel calls DeviceAllocator::copy_device_to_host. +lib.define("_d2h_copy(Tensor self) -> Tensor") +lib.define("_d2h_copy.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)") + + +@impl(lib, "_h2d_copy", "CompositeImplicitAutograd") +def _h2d_copy_impl(self: torch.Tensor) -> torch.Tensor: + # During tracing, both tensors are on CPU. Just clone to represent the transfer. + return self.clone() + + +@impl(lib, "_h2d_copy.out", "CompositeImplicitAutograd") +def _h2d_copy_out_impl(self: torch.Tensor, *, out: torch.Tensor) -> torch.Tensor: + out.copy_(self) + return out + + +@impl(lib, "_d2h_copy", "CompositeImplicitAutograd") +def _d2h_copy_impl(self: torch.Tensor) -> torch.Tensor: + # During tracing, both tensors are on CPU. Just clone to represent the transfer. + return self.clone() + + +@impl(lib, "_d2h_copy.out", "CompositeImplicitAutograd") +def _d2h_copy_out_impl(self: torch.Tensor, *, out: torch.Tensor) -> torch.Tensor: + out.copy_(self) + return out diff --git a/exir/tests/TARGETS b/exir/tests/TARGETS index 322f72c870a..21493a69644 100644 --- a/exir/tests/TARGETS +++ b/exir/tests/TARGETS @@ -504,3 +504,14 @@ python_unittest( "//executorch/exir/passes:propagate_device_pass", ], ) + +python_unittest( + name = "device_copy_ops", + srcs = [ + "test_device_copy_ops.py", + ], + deps = [ + "//caffe2:torch", + "//executorch/exir/passes:device_copy_ops_registry", + ], +) diff --git a/exir/tests/test_device_copy_ops.py b/exir/tests/test_device_copy_ops.py new file mode 100644 index 00000000000..805159d9d81 --- /dev/null +++ b/exir/tests/test_device_copy_ops.py @@ -0,0 +1,73 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import unittest + +# Import the registry to register the ops +import executorch.exir.passes._device_copy_ops_registry # noqa: F401 + +import torch + + +class DeviceCopyOpsRegistryTest(unittest.TestCase): + """Tests that et_copy._h2d_copy and et_copy._d2h_copy ops are correctly + registered and produce expected outputs during tracing (CPU-only).""" + + def test_h2d_copy_functional(self): + """_h2d_copy should return a clone of the input tensor.""" + x = torch.randn(2, 3) + result = torch.ops.et_copy._h2d_copy(x) + self.assertEqual(result.shape, x.shape) + self.assertEqual(result.dtype, x.dtype) + self.assertTrue(torch.equal(result, x)) + # Should be a new tensor, not the same object + self.assertFalse(result.data_ptr() == x.data_ptr()) + + def test_d2h_copy_functional(self): + """_d2h_copy should return a clone of the input tensor.""" + x = torch.randn(4, 5) + result = torch.ops.et_copy._d2h_copy(x) + self.assertEqual(result.shape, x.shape) + self.assertEqual(result.dtype, x.dtype) + self.assertTrue(torch.equal(result, x)) + self.assertFalse(result.data_ptr() == x.data_ptr()) + + def test_h2d_copy_out_variant(self): + """_h2d_copy.out should copy data into the provided out tensor.""" + x = torch.randn(3, 3) + out = torch.empty(3, 3) + result = torch.ops.et_copy._h2d_copy.out(x, out=out) + self.assertTrue(result is out) + self.assertTrue(torch.equal(out, x)) + + def test_d2h_copy_out_variant(self): + """_d2h_copy.out should copy data into the provided out tensor.""" + x = torch.randn(2, 4) + out = torch.empty(2, 4) + result = torch.ops.et_copy._d2h_copy.out(x, out=out) + self.assertTrue(result is out) + self.assertTrue(torch.equal(out, x)) + + def test_h2d_copy_preserves_dtype(self): + """_h2d_copy should work with various dtypes.""" + for dtype in [torch.float32, torch.float16, torch.int32, torch.int64]: + x = torch.ones(2, 2, dtype=dtype) + result = torch.ops.et_copy._h2d_copy(x) + self.assertEqual(result.dtype, dtype) + self.assertTrue(torch.equal(result, x)) + + def test_h2d_copy_scalar_tensor(self): + """_h2d_copy should handle 0-dim tensors.""" + x = torch.tensor(3.14) + result = torch.ops.et_copy._h2d_copy(x) + self.assertEqual(result.shape, torch.Size([])) + self.assertTrue(torch.equal(result, x)) + + def test_d2h_copy_empty_tensor(self): + """_d2h_copy should handle empty tensors.""" + x = torch.empty(0, 3) + result = torch.ops.et_copy._d2h_copy(x) + self.assertEqual(result.shape, torch.Size([0, 3]))