Skip to content

Commit 849b733

Browse files
[slimtensor] Add item() methods with CUDA and CPU support for scalar value extraction (#16840)
This PR was created by the merge bot to help merge the original PR into the main branch. ghstack PR number: #16445 by @Gasoonjia ^ Please use this as the source of truth for the PR details, comments, and reviews ghstack PR base: https://github.com/pytorch/executorch/tree/gh/gasoonjia/87/base ghstack PR head: https://github.com/pytorch/executorch/tree/gh/gasoonjia/87/head Merge bot PR base: https://github.com/pytorch/executorch/tree/gh/gasoonjia/86/orig Merge bot PR head: https://github.com/pytorch/executorch/tree/gh/gasoonjia/87/orig Differential Revision: [D90034340](https://our.internmc.facebook.com/intern/diff/D90034340/) @diff-train-skip-merge --------- Co-authored-by: gasoonjia <gasoonjia@icloud.com>
1 parent e84440b commit 849b733

2 files changed

Lines changed: 133 additions & 0 deletions

File tree

backends/aoti/slim/core/SlimTensor.h

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -504,6 +504,34 @@ class SlimTensor {
504504
return *this;
505505
}
506506

507+
/**
508+
* Extract the scalar value from a tensor with exactly 1 element.
509+
* Automatically handles CUDA tensors by copying data to CPU.
510+
*
511+
* @tparam T The type to extract (must match tensor dtype).
512+
* @return The scalar value.
513+
*/
514+
template <typename T>
515+
T item() const {
516+
ET_CHECK_MSG(
517+
this->numel() == 1,
518+
"item() requires tensor to have exactly 1 element, got %zu",
519+
this->numel());
520+
521+
T result;
522+
if (this->is_cpu()) {
523+
result = *static_cast<const T*>(this->data_ptr());
524+
} else {
525+
#if defined(CUDA_AVAILABLE)
526+
DeviceTraits<c10::DeviceType::CUDA>::memcpy(
527+
&result, this->data_ptr(), sizeof(T), CPU_DEVICE, this->device());
528+
#else
529+
ET_CHECK_MSG(false, "item(): CUDA tensor but CUDA support not available");
530+
#endif
531+
}
532+
return result;
533+
}
534+
507535
private:
508536
SlimTensor _clone_impl(
509537
c10::IntArrayRef sizes,

backends/aoti/slim/core/test/test_slimtensor_basic.cpp

Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -365,6 +365,111 @@ TEST(SlimTensorBasicTest, CopyConstructor) {
365365
EXPECT_EQ(copy.dtype(), c10::ScalarType::Float);
366366
}
367367

368+
// =============================================================================
369+
// Item Tests (Device-Parameterized)
370+
// =============================================================================
371+
372+
// Helper to set value in storage (handles both CPU and CUDA)
373+
template <typename T>
374+
void set_storage_value(
375+
Storage& storage,
376+
const T& value,
377+
const c10::Device& dev) {
378+
if (dev.is_cpu()) {
379+
*static_cast<T*>(storage->data()) = value;
380+
} else {
381+
#if defined(CUDA_AVAILABLE)
382+
DeviceTraits<c10::DeviceType::CUDA>::memcpy(
383+
storage->data(), &value, sizeof(T), dev, CPU_DEVICE);
384+
#endif
385+
}
386+
}
387+
388+
// Template function for testing item<T>() with explicit type
389+
template <typename T>
390+
void test_item_typed(
391+
const c10::Device& dev,
392+
c10::ScalarType dtype,
393+
T input_value,
394+
T expected_value) {
395+
std::vector<int64_t> sizes = {1};
396+
std::vector<int64_t> strides = {1};
397+
Storage storage(new MaybeOwningStorage(dev, sizeof(T)));
398+
set_storage_value(storage, input_value, dev);
399+
400+
SlimTensor tensor(
401+
std::move(storage), makeArrayRef(sizes), makeArrayRef(strides), dtype);
402+
403+
T result = tensor.item<T>();
404+
if constexpr (std::is_floating_point_v<T>) {
405+
EXPECT_FLOAT_EQ(result, expected_value);
406+
} else {
407+
EXPECT_EQ(result, expected_value);
408+
}
409+
}
410+
411+
// Tests for item<T>() with explicit type
// Round-trips 42.5f through a 1-element Float tensor on the test device.
TEST_P(SlimTensorBasicDeviceTest, ItemTypedFloat) {
  test_item_typed<float>(device(), c10::ScalarType::Float, 42.5f, 42.5f);
}
415+
416+
// Round-trips a 32-bit integer through a 1-element Int tensor.
TEST_P(SlimTensorBasicDeviceTest, ItemTypedInt) {
  test_item_typed<int32_t>(device(), c10::ScalarType::Int, 123, 123);
}
419+
420+
// Round-trips a value wider than 32 bits to confirm the full int64_t
// payload survives the (possibly cross-device) copy.
TEST_P(SlimTensorBasicDeviceTest, ItemTypedLong) {
  test_item_typed<int64_t>(
      device(), c10::ScalarType::Long, 9876543210LL, 9876543210LL);
}
424+
425+
// Round-trips a 16-bit integer through a 1-element Short tensor.
TEST_P(SlimTensorBasicDeviceTest, ItemTypedShort) {
  test_item_typed<int16_t>(device(), c10::ScalarType::Short, 1234, 1234);
}
428+
429+
// Uses a negative value to confirm sign is preserved for int8_t.
TEST_P(SlimTensorBasicDeviceTest, ItemTypedChar) {
  test_item_typed<int8_t>(device(), c10::ScalarType::Char, -42, -42);
}
432+
433+
// Round-trips a bool through a 1-element Bool tensor.
TEST_P(SlimTensorBasicDeviceTest, ItemTypedBool) {
  test_item_typed<bool>(device(), c10::ScalarType::Bool, true, true);
}
436+
437+
// Can't reuse test_item_typed() because we need to cast to float explicitly
// for comparison.
TEST_P(SlimTensorBasicDeviceTest, ItemTypedBFloat16) {
  // 3.14f is only approximately representable in bfloat16, but input and
  // expected go through the identical conversion, so the comparison holds.
  c10::BFloat16 input{3.14f};
  c10::BFloat16 expected{3.14f};
  std::vector<int64_t> sizes = {1};
  std::vector<int64_t> strides = {1};
  Storage storage(new MaybeOwningStorage(device(), sizeof(c10::BFloat16)));
  set_storage_value(storage, input, device());

  SlimTensor tensor(
      std::move(storage),
      makeArrayRef(sizes),
      makeArrayRef(strides),
      c10::ScalarType::BFloat16);

  // Compare after widening both sides to float.
  c10::BFloat16 result = tensor.item<c10::BFloat16>();
  EXPECT_FLOAT_EQ(static_cast<float>(result), static_cast<float>(expected));
}
456+
457+
// Test item() fails on non-scalar tensor (numel > 1)
TEST_P(SlimTensorBasicDeviceTest, ItemFailsOnNonScalarTensor) {
  std::vector<int64_t> sizes = {2, 3};
  std::vector<int64_t> strides = {3, 1};
  Storage storage = make_storage(6 * sizeof(float));

  SlimTensor tensor(
      std::move(storage),
      makeArrayRef(sizes),
      makeArrayRef(strides),
      c10::ScalarType::Float);

  EXPECT_EQ(tensor.numel(), 6u);
  // item() aborts via ET_CHECK_MSG when numel != 1; the empty regex
  // matches any death-test output.
  EXPECT_DEATH(tensor.item<float>(), "");
}
472+
368473
// CPU-only test for DataPtrWithOffset (requires reading data back)
369474
TEST(SlimTensorBasicTest, DataPtrWithOffset) {
370475
std::vector<int64_t> sizes = {2, 3};

0 commit comments

Comments
 (0)