Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 16 additions & 3 deletions backends/aoti/slim/c10/core/Device.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ struct Device final {
}

/// Constructs a Device from a string description.
/// The string must be "cpu" or "cpu:0".
/// The string must be "cpu", "cpu:0", "cuda", or "cuda:N".
/* implicit */ Device(const std::string& device_string)
: Device(DeviceType::CPU) {
ET_CHECK_MSG(!device_string.empty(), "Device string must not be empty");
Expand All @@ -46,11 +46,19 @@ struct Device final {
index_ = -1;
} else if (device_string == "cpu:0" || device_string == "CPU:0") {
type_ = DeviceType::CPU;
index_ = 0;
} else if (device_string == "cuda" || device_string == "CUDA") {
type_ = DeviceType::CUDA;
index_ = 0;
} else if (
device_string.substr(0, 5) == "cuda:" ||
device_string.substr(0, 5) == "CUDA:") {
type_ = DeviceType::CUDA;
index_ = static_cast<DeviceIndex>(device_string.back() - '0');
} else {
ET_CHECK_MSG(
false,
"Invalid device string: %s. Currently only 'cpu' is supported.",
"Invalid device string: %s. Supported: 'cpu', 'cuda', 'cuda:N'.",
device_string.c_str());
}
validate();
Expand Down Expand Up @@ -92,7 +100,12 @@ struct Device final {
return type_ == DeviceType::CPU;
}

/// Returns a string representation of the device (e.g., "cpu" or "cpu:0").
/// Returns true if the device is of CUDA type.
bool is_cuda() const noexcept {
return type_ == DeviceType::CUDA;
}

/// Returns a string representation of the device (e.g., "cpu" or "cuda:0").
std::string str() const {
std::string str = DeviceTypeName(type(), /* lower_case */ true);
if (has_index()) {
Expand Down
8 changes: 6 additions & 2 deletions backends/aoti/slim/c10/core/DeviceType.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,12 @@ namespace executorch::backends::aoti::slim::c10 {
/// Enum representing the type of device.
enum class DeviceType : int8_t {
CPU = 0,
COMPILE_TIME_MAX_DEVICE_TYPES = 1,
CUDA = 1,
COMPILE_TIME_MAX_DEVICE_TYPES = 2,
};

constexpr DeviceType kCPU = DeviceType::CPU;
constexpr DeviceType kCUDA = DeviceType::CUDA;

/// Maximum number of device types at compile time.
constexpr int COMPILE_TIME_MAX_DEVICE_TYPES =
Expand All @@ -36,6 +38,8 @@ inline std::string DeviceTypeName(DeviceType d, bool lower_case = false) {
switch (d) {
case DeviceType::CPU:
return lower_case ? "cpu" : "CPU";
case DeviceType::CUDA:
return lower_case ? "cuda" : "CUDA";
default:
ET_CHECK_MSG(false, "Unknown device type: %d", static_cast<int>(d));
}
Expand All @@ -45,7 +49,7 @@ inline std::string DeviceTypeName(DeviceType d, bool lower_case = false) {
/// @param d The device type to check.
/// @return true if the device type is valid, false otherwise.
inline bool isValidDeviceType(DeviceType d) {
return d == DeviceType::CPU;
return d == DeviceType::CPU || d == DeviceType::CUDA;
}

inline std::ostream& operator<<(std::ostream& stream, DeviceType type) {
Expand Down
87 changes: 72 additions & 15 deletions backends/aoti/slim/c10/core/ScalarType.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,35 +12,64 @@
#include <cstdint>
#include <ostream>

#include <executorch/runtime/core/portable_type/bfloat16.h>
#include <executorch/runtime/platform/assert.h>

namespace executorch::backends::aoti::slim::c10 {

// Import BFloat16 from ExecuTorch's portable_type
using BFloat16 = ::executorch::runtime::etensor::BFloat16;

/// Enum representing the scalar type (dtype) of tensor elements.
/// Note: Enum values must match PyTorch's c10::ScalarType for compatibility.
enum class ScalarType : int8_t {
// Byte = 0,
// Char = 1,
// Short = 2,
// Int = 3,
// Long = 4,
Float = 6,
// Bool = 11,
// BFloat16 = 15,
// Byte = 0, // uint8_t - not currently needed
Char = 1, // int8_t
Short = 2, // int16_t
Int = 3, // int32_t
Long = 4, // int64_t
// Half = 5, // float16 - not currently needed
Float = 6, // float
// Double = 7, // double - not currently needed
// ComplexHalf = 8,
// ComplexFloat = 9,
// ComplexDouble = 10,
Bool = 11, // bool
// QInt8 = 12,
// QUInt8 = 13,
// QInt32 = 14,
BFloat16 = 15, // bfloat16
Undefined = -1,
NumOptions = 7,
};

/// Constant for Float scalar type.
// Type alias constants for convenience
constexpr ScalarType kChar = ScalarType::Char;
constexpr ScalarType kShort = ScalarType::Short;
constexpr ScalarType kInt = ScalarType::Int;
constexpr ScalarType kLong = ScalarType::Long;
constexpr ScalarType kFloat = ScalarType::Float;
constexpr ScalarType kBool = ScalarType::Bool;
constexpr ScalarType kBFloat16 = ScalarType::BFloat16;

/// Returns the size in bytes of a single element of the given scalar type.
/// @param t The scalar type.
/// @return The size in bytes of a single element.
inline size_t elementSize(ScalarType t) {
switch (t) {
case ScalarType::Char:
return sizeof(int8_t);
case ScalarType::Short:
return sizeof(int16_t);
case ScalarType::Int:
return sizeof(int32_t);
case ScalarType::Long:
return sizeof(int64_t);
case ScalarType::Float:
return sizeof(float);
case ScalarType::Bool:
return sizeof(bool);
case ScalarType::BFloat16:
return sizeof(BFloat16);
default:
ET_CHECK_MSG(false, "Unknown ScalarType: %d", static_cast<int>(t));
}
Expand All @@ -51,8 +80,20 @@ inline size_t elementSize(ScalarType t) {
/// @return The name of the scalar type.
inline const char* toString(ScalarType t) {
switch (t) {
case ScalarType::Char:
return "Char";
case ScalarType::Short:
return "Short";
case ScalarType::Int:
return "Int";
case ScalarType::Long:
return "Long";
case ScalarType::Float:
return "Float";
case ScalarType::Bool:
return "Bool";
case ScalarType::BFloat16:
return "BFloat16";
case ScalarType::Undefined:
return "Undefined";
default:
Expand All @@ -64,16 +105,32 @@ inline const char* toString(ScalarType t) {
/// @param t The scalar type to check.
/// @return true if the scalar type is floating point, false otherwise.
inline bool isFloatingType(ScalarType t) {
return t == ScalarType::Float;
return t == ScalarType::Float || t == ScalarType::BFloat16;
}

/// Checks if the scalar type is an integral type (including bool).
/// Checks if the scalar type is an integral type (including bool optionally).
/// @param t The scalar type to check.
/// @param includeBool Whether to consider Bool as integral.
/// @return true if the scalar type is integral, false otherwise.
inline bool isIntegralType(ScalarType t, bool /*includeBool*/) {
(void)t;
return false;
inline bool isIntegralType(ScalarType t, bool includeBool) {
switch (t) {
case ScalarType::Char:
case ScalarType::Short:
case ScalarType::Int:
case ScalarType::Long:
return true;
case ScalarType::Bool:
return includeBool;
default:
return false;
}
}

/// Checks if the scalar type is a boolean type.
/// @param t The scalar type to check.
/// @return true if the scalar type is Bool, false otherwise.
inline bool isBoolType(ScalarType t) {
  switch (t) {
    case ScalarType::Bool:
      return true;
    default:
      return false;
  }
}

inline std::ostream& operator<<(std::ostream& stream, ScalarType scalar_type) {
Expand Down
1 change: 1 addition & 0 deletions backends/aoti/slim/c10/core/targets.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ def define_common_targets():
],
visibility = ["@EXECUTORCH_CLIENTS"],
exported_deps = [
"//executorch/runtime/core/portable_type:portable_type",
"//executorch/runtime/platform:platform",
],
)
Expand Down
122 changes: 122 additions & 0 deletions backends/aoti/slim/c10/core/test/test_device.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -109,3 +109,125 @@ TEST_F(DeviceTest, Hash) {
EXPECT_EQ(hasher(cpu1), hasher(cpu2));
EXPECT_NE(hasher(cpu1), hasher(cpu3));
}

// =============================================================================
// CUDA DeviceType Tests
// =============================================================================

class CUDADeviceTypeTest : public ::testing::Test {};

TEST_F(CUDADeviceTypeTest, CUDAEnumValue) {
  // CUDA must be 1 to stay numerically compatible with PyTorch's
  // c10::DeviceType enum.
  constexpr int kExpectedCudaOrdinal = 1;
  EXPECT_EQ(kExpectedCudaOrdinal, static_cast<int>(DeviceType::CUDA));
}

TEST_F(CUDADeviceTypeTest, DeviceTypeName) {
  // DeviceTypeName honors the lower_case flag for the CUDA type.
  const std::string upper = DeviceTypeName(DeviceType::CUDA, /*lower_case=*/false);
  const std::string lower = DeviceTypeName(DeviceType::CUDA, /*lower_case=*/true);
  EXPECT_EQ(upper, "CUDA");
  EXPECT_EQ(lower, "cuda");
}

TEST_F(CUDADeviceTypeTest, IsValidDeviceType) {
  // The CUDA device type must be accepted by the validity predicate.
  EXPECT_TRUE(isValidDeviceType(kCUDA));
}

TEST_F(CUDADeviceTypeTest, KCUDAConstant) {
  // The shorthand constant must alias the enum value exactly.
  EXPECT_TRUE(kCUDA == DeviceType::CUDA);
}

// =============================================================================
// CUDA Device Tests
// =============================================================================

class CUDADeviceTest : public ::testing::Test {};

TEST_F(CUDADeviceTest, ConstructFromDeviceType) {
  // Constructing from the type alone leaves the index unset (-1).
  const Device dev(DeviceType::CUDA);

  EXPECT_EQ(dev.type(), DeviceType::CUDA);
  EXPECT_TRUE(dev.is_cuda());
  EXPECT_FALSE(dev.is_cpu());
  EXPECT_FALSE(dev.has_index());
  EXPECT_EQ(dev.index(), -1);
}

TEST_F(CUDADeviceTest, ConstructWithIndex) {
  // An explicit device ordinal is stored and reported via has_index().
  const Device dev(DeviceType::CUDA, 0);

  EXPECT_EQ(dev.type(), DeviceType::CUDA);
  EXPECT_TRUE(dev.is_cuda());
  EXPECT_FALSE(dev.is_cpu());
  EXPECT_TRUE(dev.has_index());
  EXPECT_EQ(dev.index(), 0);
}

TEST_F(CUDADeviceTest, ConstructWithNonZeroIndex) {
  // Multi-GPU case: a non-zero ordinal must round-trip unchanged.
  const Device dev(DeviceType::CUDA, 3);

  EXPECT_TRUE(dev.is_cuda());
  EXPECT_TRUE(dev.has_index());
  EXPECT_EQ(dev.index(), 3);
}

TEST_F(CUDADeviceTest, ConstructFromString) {
  // Each accepted spelling maps to (expected cuda-ness, expected index).
  // NOTE(review): only single-digit indices are covered here; the parser
  // appears to read just the final character of "cuda:N" — confirm whether
  // multi-digit ordinals (e.g. "cuda:10") are meant to be supported.
  const std::pair<std::string, int> cases[] = {
      {"cuda", 0},
      {"CUDA", 0},
      {"cuda:0", 0},
      {"cuda:1", 1},
      {"CUDA:2", 2},
  };

  for (const auto& [spec, expected_index] : cases) {
    const Device dev(spec);
    EXPECT_TRUE(dev.is_cuda()) << spec;
    EXPECT_EQ(dev.index(), expected_index) << spec;
  }
}

TEST_F(CUDADeviceTest, Equality) {
  // Devices compare equal iff both type and index match.
  const Device a(DeviceType::CUDA, 0);
  const Device b(DeviceType::CUDA, 0);
  const Device c(DeviceType::CUDA, 1);
  const Device host(DeviceType::CPU, 0);

  EXPECT_EQ(a, b);
  EXPECT_NE(a, c);
  EXPECT_NE(a, host);
}

TEST_F(CUDADeviceTest, Str) {
  // str() omits the index when none is set, else appends ":<index>".
  EXPECT_EQ(Device(DeviceType::CUDA).str(), "cuda");
  EXPECT_EQ(Device(DeviceType::CUDA, 0).str(), "cuda:0");
  EXPECT_EQ(Device(DeviceType::CUDA, 1).str(), "cuda:1");
}

TEST_F(CUDADeviceTest, Hash) {
  // Equal devices hash equal; differing index or type should not collide
  // for these specific values.
  const std::hash<Device> hash_fn;
  const Device a(DeviceType::CUDA, 0);
  const Device b(DeviceType::CUDA, 0);
  const Device c(DeviceType::CUDA, 1);
  const Device host(DeviceType::CPU, 0);

  EXPECT_EQ(hash_fn(a), hash_fn(b));
  EXPECT_NE(hash_fn(a), hash_fn(c));
  EXPECT_NE(hash_fn(a), hash_fn(host));
}
Loading
Loading