Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 13 additions & 1 deletion backends/cuda/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,9 @@ install(
)

# CUDA backend implementation
set(_aoti_cuda_backend_sources runtime/cuda_backend.cpp)
set(_aoti_cuda_backend_sources runtime/cuda_backend.cpp
runtime/cuda_mutable_state.cpp
)
if(_cuda_is_msvc_toolchain)
# MSVC links aoti_cuda_backend into portable_lib without relying on C++
# symbols exported from aoti_cuda_shims.dll.
Expand Down Expand Up @@ -236,3 +238,13 @@ install(
EXPORT ExecuTorchTargets
DESTINATION lib
)

# Register the test_cuda_mutable_state unit test. Everything in this block is
# skipped unless BUILD_TESTING is enabled.
if(BUILD_TESTING)
# Presumably provides the et_cxx_test() helper used below; it is defined
# outside this file — confirm in tools/cmake/Test.cmake.
include(${EXECUTORCH_ROOT}/tools/cmake/Test.cmake)

# Declare the test from its single source file and pass aoti_cuda_backend as
# an extra library (EXTRA_LIBS is assumed to mean "link against" — confirm in
# Test.cmake).
et_cxx_test(
test_cuda_mutable_state SOURCES runtime/test/test_cuda_mutable_state.cpp
EXTRA_LIBS aoti_cuda_backend
)
# Define CUDA_AVAILABLE=1 when compiling this test target only (PRIVATE, so it
# is not propagated to anything that depends on the target).
target_compile_definitions(test_cuda_mutable_state PRIVATE CUDA_AVAILABLE=1)
endif()
27 changes: 27 additions & 0 deletions backends/cuda/runtime/TARGETS
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")
load("@fbcode_macros//build_defs:cpp_unittest.bzl", "cpp_unittest")
load("@fbcode_macros//build_defs/lib:re_test_utils.bzl", "re_test_utils")
load("//tools/build/buck:nvcc_flags.bzl", "get_nvcc_arch_args")
Comment thread
mergennachin marked this conversation as resolved.

oncall("executorch")
Expand Down Expand Up @@ -105,9 +107,11 @@ runtime.cxx_library(
name = "cuda_backend",
srcs = [
"cuda_backend.cpp",
"cuda_mutable_state.cpp",
],
headers = [
"cuda_delegate_handle.h",
"cuda_mutable_state.h",
],
# @lint-ignore BUCKLINT: Avoid `link_whole=True` (https://fburl.com/avoid-link-whole)
link_whole = True,
Expand Down Expand Up @@ -135,3 +139,26 @@ runtime.cxx_library(
("cuda", None, "cuda-lazy"),
],
)

# Unit test target for test/test_cuda_mutable_state.cpp. It depends on the
# :cuda_backend library in this file, which compiles cuda_mutable_state.cpp.
cpp_unittest(
name = "test_cuda_mutable_state",
srcs = [
"test/test_cuda_mutable_state.cpp",
],
deps = [
":cuda_backend",
"//executorch/backends/aoti:aoti_common_slim",
"//executorch/backends/aoti/slim/core:slimtensor",
"//executorch/backends/aoti/slim/factory:from_blob",
"//executorch/runtime/core:core",
"//executorch/runtime/platform:platform",
],
# Same CUDA external dependency tuple that appears elsewhere in this file.
external_deps = [
("cuda", None, "cuda-lazy"),
],
# Defines CUDA_AVAILABLE=1 when compiling the test sources.
preprocessor_flags = ["-DCUDA_AVAILABLE=1"],
# NOTE(review): presumably keeps GPU code sections in the linked test binary —
# confirm against the cpp_unittest macro documentation.
keep_gpu_sections = True,
# Requests the "gpu-remote-execution" platform for running the test
# (presumably because it needs a GPU-equipped host — confirm).
remote_execution = re_test_utils.remote_execution(
platform = "gpu-remote-execution",
),
)
Comment thread
mergennachin marked this conversation as resolved.
13 changes: 13 additions & 0 deletions backends/cuda/runtime/cuda_backend.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
#include <executorch/backends/aoti/utils.h>
#include <executorch/backends/cuda/runtime/cuda_allocator.h>
#include <executorch/backends/cuda/runtime/cuda_delegate_handle.h>
#include <executorch/backends/cuda/runtime/cuda_mutable_state.h>
#include <executorch/backends/cuda/runtime/platform/platform.h>
#include <executorch/backends/cuda/runtime/shims/memory.h>
#include <executorch/backends/cuda/runtime/utils.h>
Expand Down Expand Up @@ -436,6 +437,10 @@ class ET_EXPERIMENTAL CudaBackend final
kCudaGraphWarmupSteps);
}

// Record whether this AOTI build exposes the constant-management symbols
// needed for per-session mutable-buffer rebinding (CUDA V2 multi-session).
mutable_state_note_handle(handle);

return (DelegateHandle*)handle; // Return the handle post-processing
}

Expand Down Expand Up @@ -539,6 +544,12 @@ class ET_EXPERIMENTAL CudaBackend final
}
}

// CUDA V2 multi-session: if a logical session is active on this thread,
// rebind this container's mutable constants (KV/conv/recurrent) to the
// session's own GPU buffers before running. This is a no-op for
// single-session/legacy execution.
ET_CHECK_OK_OR_RETURN_ERROR(mutable_state_rebind_for_execute(handle));

// ---------------------------------------------------------------
// CUDA graph REPLAY path — skip all tensor setup and just replay
// ---------------------------------------------------------------
Expand Down Expand Up @@ -826,6 +837,8 @@ class ET_EXPERIMENTAL CudaBackend final
}
cuda::CudaDelegateHandle* handle = (cuda::CudaDelegateHandle*)handle_;

mutable_state_forget_handle(handle);

// The CUDA stream is managed by shared_ptr in the handle.
// It will be automatically destroyed when the last handle using it
// is destroyed. Just reset our reference.
Expand Down
Loading
Loading