Skip to content

Commit b0201ac

Browse files
authored
[slimtensor] Add common_shims_slim with basic property getters
Differential Revision: D90126254 Pull Request resolved: #16454
1 parent b909c79 commit b0201ac

6 files changed

Lines changed: 633 additions & 1 deletion

File tree

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#include <executorch/backends/aoti/common_shims_slim.h>
10+
11+
namespace executorch {
12+
namespace backends {
13+
namespace aoti {
14+
15+
extern "C" {
16+
17+
// ============================================================
18+
// Basic Property Getters - Implementations
19+
// ============================================================
20+
21+
AOTITorchError aoti_torch_get_data_ptr(Tensor* tensor, void** ret_data_ptr) {
22+
if (tensor == nullptr || ret_data_ptr == nullptr) {
23+
return Error::InvalidArgument;
24+
}
25+
*ret_data_ptr = tensor->data_ptr();
26+
return Error::Ok;
27+
}
28+
29+
AOTITorchError aoti_torch_get_sizes(Tensor* tensor, int64_t** ret_sizes) {
30+
if (tensor == nullptr || ret_sizes == nullptr) {
31+
return Error::InvalidArgument;
32+
}
33+
*ret_sizes = const_cast<int64_t*>(tensor->sizes().data());
34+
return Error::Ok;
35+
}
36+
37+
AOTITorchError aoti_torch_get_strides(Tensor* tensor, int64_t** ret_strides) {
38+
if (tensor == nullptr || ret_strides == nullptr) {
39+
return Error::InvalidArgument;
40+
}
41+
*ret_strides = const_cast<int64_t*>(tensor->strides().data());
42+
return Error::Ok;
43+
}
44+
45+
AOTITorchError aoti_torch_get_dtype(Tensor* tensor, int32_t* ret_dtype) {
46+
if (tensor == nullptr || ret_dtype == nullptr) {
47+
return Error::InvalidArgument;
48+
}
49+
*ret_dtype = static_cast<int32_t>(tensor->dtype());
50+
return Error::Ok;
51+
}
52+
53+
AOTITorchError aoti_torch_get_dim(Tensor* tensor, int64_t* ret_dim) {
54+
if (tensor == nullptr || ret_dim == nullptr) {
55+
return Error::InvalidArgument;
56+
}
57+
*ret_dim = static_cast<int64_t>(tensor->dim());
58+
return Error::Ok;
59+
}
60+
61+
int32_t aoti_torch_layout_strided() {
62+
// Slimtensor only support strided layout, the return value will always be 0,
63+
// a.k.a at::Layout::Strided;
64+
return 0;
65+
}
66+
67+
} // extern "C"
68+
69+
} // namespace aoti
70+
} // namespace backends
71+
} // namespace executorch

backends/aoti/common_shims_slim.h

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#pragma once
10+
11+
#include <executorch/backends/aoti/export.h>
12+
#include <executorch/backends/aoti/slim/core/SlimTensor.h>
13+
#include <executorch/runtime/core/error.h>
14+
#include <cstdint>
15+
16+
namespace executorch {
17+
namespace backends {
18+
namespace aoti {
19+
20+
extern "C" {
21+
22+
// Common using declarations for ExecuTorch types
23+
using executorch::runtime::Error;
24+
25+
// Tensor type definition using SlimTensor
26+
using Tensor = executorch::backends::aoti::slim::SlimTensor;
27+
28+
// Common AOTI type aliases
29+
using AOTIRuntimeError = Error;
30+
using AOTITorchError = Error;
31+
32+
// ============================================================
33+
// Basic Property Getters - Declarations
34+
// ============================================================
35+
36+
AOTI_SHIM_EXPORT AOTITorchError
37+
aoti_torch_get_data_ptr(Tensor* tensor, void** ret_data_ptr);
38+
39+
AOTI_SHIM_EXPORT AOTITorchError
40+
aoti_torch_get_sizes(Tensor* tensor, int64_t** ret_sizes);
41+
42+
AOTI_SHIM_EXPORT AOTITorchError
43+
aoti_torch_get_strides(Tensor* tensor, int64_t** ret_strides);
44+
45+
AOTI_SHIM_EXPORT AOTITorchError
46+
aoti_torch_get_dtype(Tensor* tensor, int32_t* ret_dtype);
47+
48+
AOTI_SHIM_EXPORT AOTITorchError
49+
aoti_torch_get_dim(Tensor* tensor, int64_t* ret_dim);
50+
51+
AOTI_SHIM_EXPORT int32_t aoti_torch_layout_strided();
52+
53+
} // extern "C"
54+
55+
} // namespace aoti
56+
} // namespace backends
57+
} // namespace executorch

backends/aoti/targets.bzl

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,3 +86,21 @@ def define_common_targets():
8686
":delegate_handle",
8787
],
8888
)
89+
90+
# SlimTensor-based common shims library
91+
# Uses SlimTensor for all tensor operations
92+
runtime.cxx_library(
93+
name = "common_shims_slim",
94+
srcs = [
95+
"common_shims_slim.cpp",
96+
],
97+
headers = [
98+
"common_shims_slim.h",
99+
"export.h",
100+
],
101+
visibility = ["@EXECUTORCH_CLIENTS"],
102+
exported_deps = [
103+
"//executorch/runtime/core:core",
104+
"//executorch/backends/aoti/slim/core:slimtensor",
105+
],
106+
)

backends/aoti/tests/TARGETS

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
load("@fbcode_macros//build_defs:cpp_unittest.bzl", "cpp_unittest")
2+
load("@fbcode_macros//build_defs/lib:re_test_utils.bzl", "re_test_utils")
23

34
oncall("executorch")
45

@@ -20,3 +21,27 @@ cpp_unittest(
2021
"//executorch/extension/tensor:tensor",
2122
],
2223
)
24+
25+
cpp_unittest(
26+
name = "test_common_shims_slim",
27+
srcs = [
28+
"test_common_shims_slim.cpp",
29+
],
30+
deps = [
31+
"//executorch/backends/aoti:common_shims_slim",
32+
"//executorch/backends/aoti/slim/core:slimtensor",
33+
"//executorch/backends/aoti/slim/factory:empty",
34+
"//executorch/runtime/core:core",
35+
"//executorch/runtime/platform:platform",
36+
],
37+
external_deps = [
38+
("cuda", None, "cuda-lazy"),
39+
],
40+
preprocessor_flags = [
41+
"-DCUDA_AVAILABLE=1",
42+
],
43+
keep_gpu_sections = True,
44+
remote_execution = re_test_utils.remote_execution(
45+
platform = "gpu-remote-execution",
46+
),
47+
)

0 commit comments

Comments
 (0)