diff --git a/backends/cuda/CMakeLists.txt b/backends/cuda/CMakeLists.txt index 0ce48d85e92..4668e48b91e 100644 --- a/backends/cuda/CMakeLists.txt +++ b/backends/cuda/CMakeLists.txt @@ -184,7 +184,9 @@ install( ) # CUDA backend implementation -set(_aoti_cuda_backend_sources runtime/cuda_backend.cpp) +set(_aoti_cuda_backend_sources runtime/cuda_backend.cpp + runtime/cuda_mutable_state.cpp +) if(_cuda_is_msvc_toolchain) # MSVC links aoti_cuda_backend into portable_lib without relying on C++ # symbols exported from aoti_cuda_shims.dll. @@ -236,3 +238,13 @@ install( EXPORT ExecuTorchTargets DESTINATION lib ) + +if(BUILD_TESTING) + include(${EXECUTORCH_ROOT}/tools/cmake/Test.cmake) + + et_cxx_test( + test_cuda_mutable_state SOURCES runtime/test/test_cuda_mutable_state.cpp + EXTRA_LIBS aoti_cuda_backend + ) + target_compile_definitions(test_cuda_mutable_state PRIVATE CUDA_AVAILABLE=1) +endif() diff --git a/backends/cuda/runtime/TARGETS b/backends/cuda/runtime/TARGETS index f62780b29c2..1cdd430a020 100644 --- a/backends/cuda/runtime/TARGETS +++ b/backends/cuda/runtime/TARGETS @@ -1,4 +1,6 @@ load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime") +load("@fbcode_macros//build_defs:cpp_unittest.bzl", "cpp_unittest") +load("@fbcode_macros//build_defs/lib:re_test_utils.bzl", "re_test_utils") load("//tools/build/buck:nvcc_flags.bzl", "get_nvcc_arch_args") oncall("executorch") @@ -105,9 +107,11 @@ runtime.cxx_library( name = "cuda_backend", srcs = [ "cuda_backend.cpp", + "cuda_mutable_state.cpp", ], headers = [ "cuda_delegate_handle.h", + "cuda_mutable_state.h", ], # @lint-ignore BUCKLINT: Avoid `link_whole=True` (https://fburl.com/avoid-link-whole) link_whole = True, @@ -135,3 +139,26 @@ runtime.cxx_library( ("cuda", None, "cuda-lazy"), ], ) + +cpp_unittest( + name = "test_cuda_mutable_state", + srcs = [ + "test/test_cuda_mutable_state.cpp", + ], + deps = [ + ":cuda_backend", + "//executorch/backends/aoti:aoti_common_slim", + "//executorch/backends/aoti/slim/core:slimtensor", + "//executorch/backends/aoti/slim/factory:from_blob", + "//executorch/runtime/core:core", + "//executorch/runtime/platform:platform", + ], + external_deps = [ + ("cuda", None, "cuda-lazy"), + ], + preprocessor_flags = ["-DCUDA_AVAILABLE=1"], + keep_gpu_sections = True, + remote_execution = re_test_utils.remote_execution( + platform = "gpu-remote-execution", + ), +) diff --git a/backends/cuda/runtime/cuda_backend.cpp b/backends/cuda/runtime/cuda_backend.cpp index b0a06c8e8a0..1939f09358b 100644 --- a/backends/cuda/runtime/cuda_backend.cpp +++ b/backends/cuda/runtime/cuda_backend.cpp @@ -44,6 +44,7 @@ #include #include #include +#include #include #include #include @@ -436,6 +437,10 @@ class ET_EXPERIMENTAL CudaBackend final kCudaGraphWarmupSteps); } + // Record whether this AOTI build exposes the constant-management symbols + // needed for per-session mutable-buffer rebinding (CUDA V2 multi-session). + mutable_state_note_handle(handle); + return (DelegateHandle*)handle; // Return the handle post-processing } @@ -539,6 +544,12 @@ class ET_EXPERIMENTAL CudaBackend final } } + // CUDA V2 multi-session: if a logical session is active on this thread, + // rebind this container's mutable constants (KV/conv/recurrent) to the + // session's own GPU buffers before running. No-op for + // single-session/legacy. + ET_CHECK_OK_OR_RETURN_ERROR(mutable_state_rebind_for_execute(handle)); + // --------------------------------------------------------------- // CUDA graph REPLAY path — skip all tensor setup and just replay // --------------------------------------------------------------- @@ -826,6 +837,8 @@ class ET_EXPERIMENTAL CudaBackend final } cuda::CudaDelegateHandle* handle = (cuda::CudaDelegateHandle*)handle_; + mutable_state_forget_handle(handle); + // The CUDA stream is managed by shared_ptr in the handle. // It will be automatically destroyed when the last handle using it // is destroyed. Just reset our reference. diff --git a/backends/cuda/runtime/cuda_mutable_state.cpp b/backends/cuda/runtime/cuda_mutable_state.cpp new file mode 100644 index 00000000000..31b88e1240a --- /dev/null +++ b/backends/cuda/runtime/cuda_mutable_state.cpp @@ -0,0 +1,597 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +#include +#include +#include +#include +#include +#include + +#include + +#include +#include +#include +#include +#include + +namespace executorch { +namespace backends { +namespace cuda { + +namespace aoti = ::executorch::backends::aoti; +namespace slimc10 = ::executorch::backends::aoti::slim::c10; +using ::executorch::backends::aoti::slim::from_blob; +using ::executorch::backends::aoti::slim::SlimTensor; +using ::executorch::runtime::Error; +using ::executorch::runtime::Result; + +namespace { + +// Per-handle descriptor of one mutable constant (AOTI internal name differs per +// compiled method, so this is keyed per delegate handle within a context). +struct Desc { + std::string internal_name; + std::vector sizes; + std::vector strides; + slimc10::ScalarType dtype{slimc10::ScalarType::Float}; + slimc10::Device device{slimc10::DeviceType::CUDA, 0}; + size_t nbytes{0}; +}; + +// Cached user-managed pairs for a (handle, session): SlimTensors wrapping the +// session's GPU buffers (kept alive here) and the flat pairs array AOTI +// rebinds. +struct Bound { + std::vector> tensors; + std::vector pairs; +}; + +// All per-engine/model mutable state. Keyed by context id in Manager. +struct Context { + std::vector fqns; + std::unordered_set fqn_set; + + bool symbols_checked{false}; + bool symbols_available{false}; + + // FQN -> device template (the model's initial mutable contents) + sizes. + std::unordered_map template_ptr; + std::unordered_map template_nbytes; + int64_t total_bytes{0}; + + // Per-handle descriptor table + the union of discovered FQNs (for coverage). + std::unordered_map> + desc; + std::unordered_set discovered_fqns; + Error build_error{Error::Ok}; + + std::unordered_set sessions; + int next_token{0}; + // token -> (fqn -> device buffer) shared across the session's handles. + std::unordered_map> session_buf; + // (handle, token) -> cached wrappers + pairs. + std::unordered_map> bound; +}; + +struct Manager { + std::mutex mu; + std::unordered_map contexts; + std::unordered_map handle_ctx; + MutableStateContext next_ctx{1}; +}; + +Manager& mgr() { + static Manager m; + return m; +} + +// The context whose model is currently being loaded on this thread (so +// note_handle, called from CudaBackend::init, can associate handles). And the +// active (context, session) selected before execute on this thread. +thread_local MutableStateContext tl_loading_ctx = kInvalidMutableContext; +thread_local MutableStateContext tl_active_ctx = kInvalidMutableContext; +thread_local int tl_active_token = kNoMutableSession; + +bool handle_has_symbols(CudaDelegateHandle* h) { + return h->get_num_constants && h->get_constant_name && + h->get_constant_original_fqn && h->extract_constants_map && + h->update_user_managed_constant_buffer_pairs; +} + +bool validate_descriptors(const Context& c) { + bool ok = true; + std::unordered_map first_desc; + for (const auto& handle_descs : c.desc) { + for (const auto& fd : handle_descs.second) { + const std::string& fqn = fd.first; + const Desc& d = fd.second; + auto template_it = c.template_nbytes.find(fqn); + if (template_it == c.template_nbytes.end()) { + ET_LOG( + Error, + "mutable_state: descriptor '%s' has no captured template", + fqn.c_str()); + ok = false; + continue; + } + if (d.nbytes > template_it->second) { + ET_LOG( + Error, + "mutable_state: descriptor '%s' (%zu B) exceeds shared template " + "buffer (%zu B)", + fqn.c_str(), + d.nbytes, + template_it->second); + ok = false; + } + + auto inserted = first_desc.emplace(fqn, &d); + if (!inserted.second) { + const Desc& base = *inserted.first->second; + if (d.dtype != base.dtype || d.device != base.device) { + ET_LOG( + Error, + "mutable_state: descriptor '%s' has incompatible dtype/device " + "across loaded methods", + fqn.c_str()); + ok = false; + } + } + } + } + return ok; +} + +// Build the descriptor table for a handle and capture per-FQN initial +// templates. Caller holds mgr().mu. Runs before any session has rebound this +// container, so the constants still hold the model's initial mutable state. +Error build_descriptors(Context& c, CudaDelegateHandle* h) { + auto container = h->container_handle; + + size_t n = 0; + ET_CHECK_OK_OR_RETURN_ERROR( + h->get_num_constants(container, &n), + "mutable_state: get_num_constants failed"); + std::unordered_map fqn_to_internal; + for (size_t i = 0; i < n; ++i) { + const char* internal = nullptr; + const char* fqn = nullptr; + ET_CHECK_OK_OR_RETURN_ERROR( + h->get_constant_name(container, i, &internal), + "mutable_state: get_constant_name failed"); + ET_CHECK_OK_OR_RETURN_ERROR( + h->get_constant_original_fqn(container, i, &fqn), + "mutable_state: get_constant_original_fqn failed"); + // A successful call may still report an unusable (null/empty) name -- + // that's a method-scoped constant, not an error: skip it (another container + // owns it). A non-OK return code above is a real failure and falls closed. + if (internal && fqn && fqn[0] != '\0') { + fqn_to_internal[fqn] = internal; + } + } + + std::unordered_map extracted; + ET_CHECK_OK_OR_RETURN_ERROR( + h->extract_constants_map( + container, + reinterpret_cast(&extracted), + /*use_inactive=*/false), + "mutable_state: extract_constants_map failed"); + + auto& table = c.desc[h]; + for (const auto& fqn : c.fqns) { + auto it_name = fqn_to_internal.find(fqn); + auto it_t = extracted.find(fqn); + // A mutable FQN not present in this container = a method that does not use + // it (method-scoped). Skip; another container will own it. + if (it_name == fqn_to_internal.end() || it_t == extracted.end()) { + continue; + } + auto* t = reinterpret_cast(it_t->second); + Desc d; + d.internal_name = it_name->second; + d.sizes.assign(t->sizes().begin(), t->sizes().end()); + d.strides.assign(t->strides().begin(), t->strides().end()); + d.dtype = t->dtype(); + d.device = t->device(); + d.nbytes = t->nbytes(); + table.emplace(fqn, std::move(d)); + c.discovered_fqns.insert(fqn); + + if (c.template_ptr.find(fqn) == c.template_ptr.end()) { + // If a later FQN fails during this build, already captured templates are + // released by mutable_state_destroy_context(). + void* tpl = nullptr; + if (cudaMalloc(&tpl, t->nbytes()) != cudaSuccess) { + ET_LOG(Error, "mutable_state: cudaMalloc template '%s'", fqn.c_str()); + return Error::Internal; + } + if (cudaMemcpy( + tpl, t->data_ptr(), t->nbytes(), cudaMemcpyDeviceToDevice) != + cudaSuccess) { + ET_LOG(Error, "mutable_state: cudaMemcpy template '%s'", fqn.c_str()); + cudaFree(tpl); + return Error::Internal; + } + c.template_ptr[fqn] = tpl; + c.template_nbytes[fqn] = t->nbytes(); + c.total_bytes += static_cast(t->nbytes()); + } + } + return Error::Ok; +} + +// Allocate a session's GPU buffers, cloned from the initial templates. Caller +// holds mgr().mu. Allocates PER FQN so a buffer is created for any template +// discovered after the session's first allocation. +Error ensure_session_buffers(Context& c, int token) { + auto& buf = c.session_buf[token]; + for (const auto& kv : c.template_ptr) { + const std::string& fqn = kv.first; + if (buf.find(fqn) != buf.end()) { + continue; // already allocated for this session + } + void* tpl = kv.second; + size_t nbytes = c.template_nbytes[fqn]; + void* p = nullptr; + if (cudaMalloc(&p, nbytes) != cudaSuccess) { + ET_LOG( + Error, "mutable_state: cudaMalloc session buffer '%s'", fqn.c_str()); + return Error::Internal; + } + if (cudaMemcpy(p, tpl, nbytes, cudaMemcpyDeviceToDevice) != cudaSuccess) { + ET_LOG( + Error, "mutable_state: cudaMemcpy session buffer '%s'", fqn.c_str()); + cudaFree(p); + return Error::Internal; + } + buf[fqn] = p; + } + return Error::Ok; +} + +// Build the cached wrappers + pairs for (handle, token). Caller holds mgr().mu. +Error ensure_bound(Context& c, CudaDelegateHandle* h, int token) { + if (c.bound[h].find(token) != c.bound[h].end()) { + return Error::Ok; + } + Bound b; + auto& buf = c.session_buf[token]; + for (const auto& fd : c.desc[h]) { + const std::string& fqn = fd.first; + const Desc& d = fd.second; + auto buf_it = buf.find(fqn); + if (buf_it == buf.end() || buf_it->second == nullptr) { + // Every descriptor for this handle must have a backing session buffer; + // a null bind would silently corrupt state. + ET_LOG(Error, "mutable_state: no session buffer for '%s'", fqn.c_str()); + return Error::Internal; + } + auto template_it = c.template_nbytes.find(fqn); + if (template_it == c.template_nbytes.end() || + d.nbytes > template_it->second) { + ET_LOG( + Error, + "mutable_state: descriptor '%s' (%zu B) exceeds shared template " + "buffer (%zu B)", + fqn.c_str(), + d.nbytes, + template_it == c.template_nbytes.end() ? 0 : template_it->second); + return Error::Internal; + } + void* ptr = buf_it->second; + auto st = std::make_unique(from_blob( + ptr, + ::executorch::runtime::makeArrayRef(d.sizes.data(), d.sizes.size()), + ::executorch::runtime::makeArrayRef(d.strides.data(), d.strides.size()), + d.dtype, + d.device)); + aoti::AOTInductorConstantMapEntry entry; + entry.name = d.internal_name.c_str(); + entry.handle = reinterpret_cast(st.get()); + b.pairs.push_back(entry); + b.tensors.push_back(std::move(st)); + } + c.bound[h].emplace(token, std::move(b)); + return Error::Ok; +} + +void free_session_buffers(Context& c, int token) { + auto it = c.session_buf.find(token); + if (it != c.session_buf.end()) { + for (auto& kv : it->second) { + if (kv.second) { + cudaFree(kv.second); + } + } + c.session_buf.erase(it); + } + for (auto& hb : c.bound) { + hb.second.erase(token); + } + c.sessions.erase(token); +} + +} // namespace + +MutableStateContext mutable_state_create_context() { + auto& m = mgr(); + std::lock_guard g(m.mu); + MutableStateContext id = m.next_ctx++; + m.contexts[id]; // default-construct + return id; +} + +void mutable_state_destroy_context(MutableStateContext ctx) { + auto& m = mgr(); + std::lock_guard g(m.mu); + auto it = m.contexts.find(ctx); + if (it == m.contexts.end()) { + return; + } + Context& c = it->second; + for (auto& kv : c.template_ptr) { + if (kv.second) { + cudaFree(kv.second); + } + } + for (auto& sb : c.session_buf) { + for (auto& kv : sb.second) { + if (kv.second) { + cudaFree(kv.second); + } + } + } + // Drop handle->ctx associations for this context. + for (auto hit = m.handle_ctx.begin(); hit != m.handle_ctx.end();) { + hit = (hit->second == ctx) ? m.handle_ctx.erase(hit) : std::next(hit); + } + m.contexts.erase(it); +} + +void mutable_state_begin_load(MutableStateContext ctx) { + if (tl_loading_ctx != kInvalidMutableContext) { + auto& m = mgr(); + std::lock_guard g(m.mu); + auto active = m.contexts.find(tl_loading_ctx); + if (active != m.contexts.end()) { + active->second.build_error = Error::InvalidState; + } + auto nested = m.contexts.find(ctx); + if (nested != m.contexts.end()) { + nested->second.build_error = Error::InvalidState; + } + ET_LOG(Error, "mutable_state: nested load scopes are not supported"); + tl_loading_ctx = kInvalidMutableContext; + return; + } + tl_loading_ctx = ctx; +} + +void mutable_state_end_load() { + tl_loading_ctx = kInvalidMutableContext; +} + +void mutable_state_register_fqns( + MutableStateContext ctx, + const std::vector& fqns) { + auto& m = mgr(); + std::lock_guard g(m.mu); + auto it = m.contexts.find(ctx); + if (it == m.contexts.end()) { + return; + } + it->second.fqns = fqns; + it->second.fqn_set.clear(); + it->second.fqn_set.insert(fqns.begin(), fqns.end()); +} + +bool mutable_state_available(MutableStateContext ctx) { + auto& m = mgr(); + std::lock_guard g(m.mu); + auto it = m.contexts.find(ctx); + return it != m.contexts.end() && it->second.build_error == Error::Ok && + it->second.symbols_available; +} + +int64_t mutable_state_bytes_per_session(MutableStateContext ctx) { + auto& m = mgr(); + std::lock_guard g(m.mu); + auto it = m.contexts.find(ctx); + return it == m.contexts.end() ? 0 : it->second.total_bytes; +} + +Error mutable_state_validate_coverage(MutableStateContext ctx) { + auto& m = mgr(); + std::lock_guard g(m.mu); + auto it = m.contexts.find(ctx); + if (it == m.contexts.end()) { + return Error::InvalidArgument; + } + Context& c = it->second; + if (c.build_error != Error::Ok) { + return c.build_error; + } + if (!c.symbols_available) { + return Error::NotSupported; + } + bool ok = true; + for (const auto& fqn : c.fqns) { + if (c.discovered_fqns.find(fqn) == c.discovered_fqns.end()) { + ET_LOG( + Error, + "mutable_state: declared mutable buffer '%s' not found in any loaded " + "method's constants (FQN mismatch?)", + fqn.c_str()); + ok = false; + } + } + ok = validate_descriptors(c) && ok; + if (!ok) { + c.build_error = Error::InvalidProgram; + return Error::InvalidProgram; + } + return Error::Ok; +} + +Result mutable_state_create_session(MutableStateContext ctx) { + auto& m = mgr(); + std::lock_guard g(m.mu); + auto it = m.contexts.find(ctx); + if (it == m.contexts.end()) { + return Error::InvalidArgument; + } + Context& c = it->second; + if (c.build_error != Error::Ok) { + return c.build_error; + } + if (!c.symbols_available) { + ET_LOG( + Error, "mutable_state: rebinding unavailable; cannot create session"); + return Error::NotSupported; + } + int token = c.next_token++; + c.sessions.insert(token); + return token; +} + +void mutable_state_destroy_session(MutableStateContext ctx, int token) { + auto& m = mgr(); + std::lock_guard g(m.mu); + auto it = m.contexts.find(ctx); + if (it == m.contexts.end()) { + return; // context already torn down; nothing to free + } + free_session_buffers(it->second, token); +} + +void mutable_state_set_active(MutableStateContext ctx, int token) { + tl_active_ctx = ctx; + tl_active_token = token; +} + +void mutable_state_note_handle(CudaDelegateHandle* handle) { + MutableStateContext ctx = tl_loading_ctx; + if (ctx == kInvalidMutableContext) { + return; // not loading within a managed context (single-session path) + } + auto& m = mgr(); + std::lock_guard g(m.mu); + auto it = m.contexts.find(ctx); + if (it == m.contexts.end()) { + return; + } + Context& c = it->second; + m.handle_ctx[handle] = ctx; + bool ok = handle_has_symbols(handle); + c.symbols_available = c.symbols_checked ? (c.symbols_available && ok) : ok; + c.symbols_checked = true; + // Build this method's descriptor table + capture initial templates now, while + // the container still holds the model's initial mutable state and before any + // session rebinds. Requires FQNs registered before load_method. + if (ok && !c.fqns.empty() && c.desc.find(handle) == c.desc.end()) { + Error e = build_descriptors(c, handle); + if (e != Error::Ok) { + c.build_error = e; + } + } +} + +void mutable_state_forget_handle(CudaDelegateHandle* handle) { + auto& m = mgr(); + std::lock_guard g(m.mu); + auto hit = m.handle_ctx.find(handle); + if (hit == m.handle_ctx.end()) { + return; + } + auto cit = m.contexts.find(hit->second); + if (cit != m.contexts.end()) { + cit->second.desc.erase(handle); + cit->second.bound.erase(handle); + } + m.handle_ctx.erase(hit); +} + +Error mutable_state_rebind_for_execute(CudaDelegateHandle* handle) { + if (tl_active_token == kNoMutableSession) { + return Error::Ok; // single-session / legacy: nothing to rebind + } + auto& m = mgr(); + std::lock_guard g(m.mu); + + auto hit = m.handle_ctx.find(handle); + if (hit == m.handle_ctx.end()) { + ET_LOG( + Error, + "mutable_state: active session set but handle has no context (load " + "scope missed?)"); + return Error::Internal; + } + MutableStateContext ctx = hit->second; + if (ctx != tl_active_ctx) { + ET_LOG( + Error, + "mutable_state: active context mismatch (caller set a different context " + "active than the one executing)"); + return Error::Internal; + } + auto cit = m.contexts.find(ctx); + if (cit == m.contexts.end()) { + return Error::Internal; + } + Context& c = cit->second; + if (c.build_error != Error::Ok) { + return c.build_error; + } + if (!c.symbols_available) { + ET_LOG( + Error, "mutable_state: active session set but rebinding unavailable"); + return Error::NotSupported; + } + const int token = tl_active_token; + if (c.sessions.find(token) == c.sessions.end()) { + ET_LOG(Error, "mutable_state: active session token was not created"); + return Error::InvalidArgument; + } + if (handle->cuda_graph_state.phase != CudaGraphPhase::Disabled) { + ET_LOG( + Error, + "mutable_state: per-session rebinding is not supported with CUDA graph"); + return Error::NotSupported; + } + if (c.desc.find(handle) == c.desc.end()) { + ET_LOG( + Error, + "mutable_state: no descriptors for handle (note_handle missed?)"); + return Error::Internal; + } + ET_CHECK_OK_OR_RETURN_ERROR(ensure_session_buffers(c, token)); + ET_CHECK_OK_OR_RETURN_ERROR(ensure_bound(c, handle, token)); + + const Bound& b = c.bound[handle][token]; + if (b.pairs.empty()) { + return Error::Ok; + } + ET_CHECK_OK_OR_RETURN_ERROR( + handle->update_user_managed_constant_buffer_pairs( + handle->container_handle, + b.pairs.data(), + b.pairs.size(), + /*use_inactive=*/false, + /*validate_full_update=*/false), + "mutable_state: update_user_managed_constant_buffer_pairs failed"); + return Error::Ok; +} + +} // namespace cuda +} // namespace backends +} // namespace executorch diff --git a/backends/cuda/runtime/cuda_mutable_state.h b/backends/cuda/runtime/cuda_mutable_state.h new file mode 100644 index 00000000000..e7ce80b88a5 --- /dev/null +++ b/backends/cuda/runtime/cuda_mutable_state.h @@ -0,0 +1,109 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include +#include +#include + +#include +#include + +// CUDA-PRIVATE per-session mutable-state management. This is intentionally NOT +// a generic ExecuTorch (Module/Method/BackendInterface) API: it is the +// CUDA/AOTI implementation of "one loaded model, many logical contexts" and is +// consumed only by CUDA-specific LLM engines (e.g. Qwen35MoEEngine). The public +// serving abstraction stays LLMEngine/LLMSession. +// +// State is keyed by a CONTEXT (one per loaded model/engine), NOT +// process-global, so multiple models (e.g. Qwen + Gemma) and repeated engine +// lifecycles in one process stay isolated. An engine: creates a context, scopes +// its model load (begin/end) so the backend associates each delegate handle +// with the context, registers the model's mutable FQNs, creates sessions, +// selects an active session before each execute, and destroys the context on +// teardown. + +namespace executorch { +namespace backends { +namespace cuda { + +struct CudaDelegateHandle; // defined in cuda_delegate_handle.h + +// Opaque per-engine context id (0 = invalid). +using MutableStateContext = int; +constexpr MutableStateContext kInvalidMutableContext = 0; + +// Active-session sentinel: execute() rebinds nothing (single-session / legacy). +constexpr int kNoMutableSession = -1; + +// --- Engine-facing API (call from the CUDA-specific LLM engine) ------------- + +// Create / destroy a context. destroy frees all of the context's sessions, +// templates, descriptors, and handle associations (safe to call once at engine +// teardown; sessions destroyed afterward become no-ops). +MutableStateContext mutable_state_create_context(); +void mutable_state_destroy_context(MutableStateContext ctx); + +// Scope a model load to a context: call begin BEFORE load_method and end AFTER, +// so the delegate handles initialized during the load are associated with +// `ctx`. Nesting is not supported (one load at a time per thread). +void mutable_state_begin_load(MutableStateContext ctx); +void mutable_state_end_load(); + +// Declare the context's per-session mutable-state FQNs (from the model's +// get_mutable_buffer_metadata). Call before begin_load/load_method. +void mutable_state_register_fqns( + MutableStateContext ctx, + const std::vector& fqns); + +// True if the context's loaded delegate(s) expose the AOTI constant-management +// symbols required for per-session rebinding. If false, the caller MUST run +// single-session. +bool mutable_state_available(MutableStateContext ctx); + +// Bytes one session adds (sum of mutable-buffer sizes), 0 if not yet known. +int64_t mutable_state_bytes_per_session(MutableStateContext ctx); + +// Validate every declared FQN was discovered in some loaded method's constants. +// Call after loading all methods; non-Ok must abort multi-session serving. +::executorch::runtime::Error mutable_state_validate_coverage( + MutableStateContext ctx); + +// Create / destroy a logical session within a context. create returns a token +// (>= 0); buffers are allocated lazily on the session's first execute. +::executorch::runtime::Result mutable_state_create_session( + MutableStateContext ctx); +void mutable_state_destroy_session(MutableStateContext ctx, int token); + +// Select the active (context, session) for subsequent Module::execute calls ON +// THIS THREAD. Set before execute, reset token to kNoMutableSession after; the +// engine must hold its serialization lock across set + execute + read-out. +void mutable_state_set_active(MutableStateContext ctx, int token); + +// --- CudaBackend-internal hooks (called from cuda_backend.cpp) --------------- + +// From CudaBackend::init: associate this handle with the context currently +// being loaded (begin_load), record symbol availability, and build the +// descriptor table + capture initial templates from the still-initial +// constants. +void mutable_state_note_handle(CudaDelegateHandle* handle); + +// From CudaBackend::destroy: remove any cached manager state keyed by this +// handle before its address can be reused by a later allocation. +void mutable_state_forget_handle(CudaDelegateHandle* handle); + +// From CudaBackend::execute, before running: if a session is active on this +// thread for this handle's context, rebind the container's mutable constants to +// the session's buffers. No-op (Ok) when no session is active. +::executorch::runtime::Error mutable_state_rebind_for_execute( + CudaDelegateHandle* handle); + +} // namespace cuda +} // namespace backends +} // namespace executorch diff --git a/backends/cuda/runtime/test/test_cuda_mutable_state.cpp b/backends/cuda/runtime/test/test_cuda_mutable_state.cpp new file mode 100644 index 00000000000..d9718eadeab --- /dev/null +++ b/backends/cuda/runtime/test/test_cuda_mutable_state.cpp @@ -0,0 +1,527 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +#include +#include +#include +#include + +namespace cu = ::executorch::backends::cuda; +namespace aoti = ::executorch::backends::aoti; +namespace slim = ::executorch::backends::aoti::slim; +namespace slimc10 = ::executorch::backends::aoti::slim::c10; +using ::executorch::runtime::Error; + +namespace { + +Error fake_get_num_constants( + aoti::AOTInductorModelContainerHandle, + size_t* num_constants) { + *num_constants = 0; + return Error::Ok; +} + +Error fake_get_constant_name( + aoti::AOTInductorModelContainerHandle, + size_t, + const char**) { + return Error::Ok; +} + +Error fake_get_constant_original_fqn( + aoti::AOTInductorModelContainerHandle, + size_t, + const char**) { + return Error::Ok; +} + +Error fake_extract_constants_map( + aoti::AOTInductorModelContainerHandle, + aoti::AOTInductorConstantMapHandle, + bool) { + return Error::Ok; +} + +Error fake_update_user_managed_pairs( + aoti::AOTInductorModelContainerHandle, + const aoti::AOTInductorConstantMapEntry*, + size_t, + bool, + bool) { + return Error::Ok; +} + +struct FakeContainer { + std::vector internal_names; + std::vector fqns; + std::unordered_map extracted; + size_t update_calls = 0; + size_t last_num_pairs = 0; + std::string last_name; + void* last_bound_data = nullptr; + size_t last_bound_nbytes = 0; +}; + +Error fake_container_get_num_constants( + aoti::AOTInductorModelContainerHandle container, + size_t* num_constants) { + auto* c = reinterpret_cast(container); + *num_constants = c->internal_names.size(); + return Error::Ok; +} + +Error fake_container_get_constant_name( + aoti::AOTInductorModelContainerHandle container, + size_t idx, + const char** name) { + auto* c = reinterpret_cast(container); + *name = + idx < c->internal_names.size() ? c->internal_names[idx].c_str() : nullptr; + return Error::Ok; +} + +Error fake_container_get_constant_original_fqn( + aoti::AOTInductorModelContainerHandle container, + size_t idx, + const char** fqn) { + auto* c = reinterpret_cast(container); + *fqn = idx < c->fqns.size() ? c->fqns[idx].c_str() : nullptr; + return Error::Ok; +} + +Error fake_container_extract_constants_map( + aoti::AOTInductorModelContainerHandle container, + aoti::AOTInductorConstantMapHandle map_handle, + bool) { + auto* c = reinterpret_cast(container); + auto* out = reinterpret_cast< + std::unordered_map*>(map_handle); + *out = c->extracted; + return Error::Ok; +} + +Error fake_container_update_user_managed_pairs( + aoti::AOTInductorModelContainerHandle container, + const aoti::AOTInductorConstantMapEntry* pairs, + size_t num_pairs, + bool, + bool) { + auto* c = reinterpret_cast(container); + c->update_calls++; + c->last_num_pairs = num_pairs; + if (num_pairs > 0) { + c->last_name = pairs[0].name; + auto* t = reinterpret_cast(pairs[0].handle); + c->last_bound_data = t->data_ptr(); + c->last_bound_nbytes = t->nbytes(); + } + return Error::Ok; +} + +cu::CudaDelegateHandle fake_symbol_handle() { + cu::CudaDelegateHandle handle{}; + handle.get_num_constants = fake_get_num_constants; + handle.get_constant_name = fake_get_constant_name; + handle.get_constant_original_fqn = fake_get_constant_original_fqn; + handle.extract_constants_map = fake_extract_constants_map; + handle.update_user_managed_constant_buffer_pairs = + fake_update_user_managed_pairs; + return handle; +} + +cu::CudaDelegateHandle fake_container_handle(FakeContainer* container) { + cu::CudaDelegateHandle handle{}; + handle.container_handle = + reinterpret_cast(container); + handle.get_num_constants = fake_container_get_num_constants; + handle.get_constant_name = fake_container_get_constant_name; + handle.get_constant_original_fqn = fake_container_get_constant_original_fqn; + handle.extract_constants_map = fake_container_extract_constants_map; + handle.update_user_managed_constant_buffer_pairs = + fake_container_update_user_managed_pairs; + return handle; +} + +bool cuda_device_available() { + int device_count = 0; + const cudaError_t err = cudaGetDeviceCount(&device_count); + return err == cudaSuccess && device_count > 0; +} + +std::unique_ptr make_device_tensor( + const std::vector& values, + void** device_ptr) { + *device_ptr = nullptr; + cudaError_t err = cudaMalloc(device_ptr, values.size() * sizeof(float)); + if (err != cudaSuccess) { + ADD_FAILURE() << "cudaMalloc failed: " << cudaGetErrorString(err); + return nullptr; + } + err = cudaMemcpy( + *device_ptr, + values.data(), + values.size() * sizeof(float), + cudaMemcpyHostToDevice); + if (err != cudaSuccess) { + ADD_FAILURE() << "cudaMemcpy failed: " << cudaGetErrorString(err); + cudaFree(*device_ptr); + *device_ptr = nullptr; + return nullptr; + } + return std::make_unique(slim::from_blob( + *device_ptr, + {static_cast(values.size())}, + slimc10::ScalarType::Float, + slimc10::Device(slimc10::DeviceType::CUDA, 0))); +} + +std::unique_ptr make_cpu_tensor(std::vector& values) { + return std::make_unique(slim::from_blob( + values.data(), + {static_cast(values.size())}, + slimc10::ScalarType::Float, + slimc10::Device(slimc10::DeviceType::CPU, 0))); +} + +} // namespace + +TEST(CudaMutableStateTest, FallClosedDefaults) { + const cu::MutableStateContext bad = 999999; + cu::MutableStateContext c1 = cu::mutable_state_create_context(); + cu::MutableStateContext c2 = cu::mutable_state_create_context(); + + EXPECT_GT(c2, c1); + EXPECT_FALSE(cu::mutable_state_available(c1)); + EXPECT_EQ(cu::mutable_state_bytes_per_session(c1), 0); + EXPECT_EQ(cu::mutable_state_bytes_per_session(bad), 0); + EXPECT_EQ(cu::mutable_state_validate_coverage(bad), Error::InvalidArgument); + EXPECT_EQ(cu::mutable_state_validate_coverage(c1), Error::NotSupported); + + cu::mutable_state_register_fqns(c1, {"a.b", "c.d"}); + EXPECT_EQ(cu::mutable_state_validate_coverage(c1), Error::NotSupported); + EXPECT_EQ( + cu::mutable_state_create_session(bad).error(), Error::InvalidArgument); + EXPECT_EQ(cu::mutable_state_create_session(c1).error(), Error::NotSupported); + + cu::mutable_state_destroy_session(bad, 0); + cu::mutable_state_destroy_context(bad); + cu::mutable_state_destroy_context(c1); + cu::mutable_state_destroy_context(c2); +} + +TEST(CudaMutableStateTest, ForgetHandleDropsAssociation) { + cu::MutableStateContext c = cu::mutable_state_create_context(); + cu::CudaDelegateHandle handle{}; + + cu::mutable_state_begin_load(c); + cu::mutable_state_note_handle(&handle); + cu::mutable_state_end_load(); + + cu::mutable_state_set_active(c, 0); + EXPECT_EQ(cu::mutable_state_rebind_for_execute(&handle), Error::NotSupported); + + cu::mutable_state_forget_handle(&handle); + EXPECT_EQ(cu::mutable_state_rebind_for_execute(&handle), Error::Internal); + + cu::mutable_state_set_active( + cu::kInvalidMutableContext, cu::kNoMutableSession); + cu::mutable_state_destroy_context(c); +} + +TEST(CudaMutableStateTest, RebindRejectsUncreatedSessionToken) { + cu::MutableStateContext c = cu::mutable_state_create_context(); + cu::CudaDelegateHandle handle = fake_symbol_handle(); + + cu::mutable_state_begin_load(c); + cu::mutable_state_note_handle(&handle); + cu::mutable_state_end_load(); + ASSERT_TRUE(cu::mutable_state_available(c)); + ASSERT_EQ(cu::mutable_state_validate_coverage(c), Error::Ok); + + cu::mutable_state_set_active(c, 123); + EXPECT_EQ( + cu::mutable_state_rebind_for_execute(&handle), Error::InvalidArgument); + + auto token = cu::mutable_state_create_session(c); + ASSERT_TRUE(token.ok()); + cu::mutable_state_set_active(c, token.get()); + EXPECT_EQ(cu::mutable_state_rebind_for_execute(&handle), Error::Internal); + + cu::mutable_state_set_active( + cu::kInvalidMutableContext, cu::kNoMutableSession); + cu::mutable_state_destroy_session(c, token.get()); + cu::mutable_state_destroy_context(c); +} + +TEST(CudaMutableStateTest, NestedBeginLoadFailsClosed) { + cu::MutableStateContext c1 = cu::mutable_state_create_context(); + cu::MutableStateContext c2 = cu::mutable_state_create_context(); + cu::CudaDelegateHandle handle = fake_symbol_handle(); + + cu::mutable_state_begin_load(c1); + cu::mutable_state_begin_load(c2); + cu::mutable_state_note_handle(&handle); + cu::mutable_state_end_load(); + + EXPECT_EQ(cu::mutable_state_validate_coverage(c1), Error::InvalidState); + EXPECT_EQ(cu::mutable_state_validate_coverage(c2), Error::InvalidState); + EXPECT_FALSE(cu::mutable_state_available(c1)); + EXPECT_FALSE(cu::mutable_state_available(c2)); + EXPECT_EQ(cu::mutable_state_create_session(c1).error(), Error::InvalidState); + EXPECT_EQ(cu::mutable_state_create_session(c2).error(), Error::InvalidState); + + cu::mutable_state_destroy_context(c1); + cu::mutable_state_destroy_context(c2); +} + +TEST(CudaMutableStateTest, RebindRejectsCudaGraphHandle) { + cu::MutableStateContext c = cu::mutable_state_create_context(); + cu::CudaDelegateHandle handle = fake_symbol_handle(); + + cu::mutable_state_begin_load(c); + cu::mutable_state_note_handle(&handle); + cu::mutable_state_end_load(); + ASSERT_TRUE(cu::mutable_state_available(c)); + ASSERT_EQ(cu::mutable_state_validate_coverage(c), Error::Ok); + + auto token = cu::mutable_state_create_session(c); + ASSERT_TRUE(token.ok()); + + handle.cuda_graph_state.phase = cu::CudaGraphPhase::Warmup; + cu::mutable_state_set_active(c, token.get()); + EXPECT_EQ(cu::mutable_state_rebind_for_execute(&handle), Error::NotSupported); + + cu::mutable_state_set_active( + cu::kInvalidMutableContext, cu::kNoMutableSession); + cu::mutable_state_destroy_session(c, token.get()); + cu::mutable_state_destroy_context(c); +} + +TEST(CudaMutableStateTest, CapturesClonesAndRebindsDeviceBuffer) { + if (!cuda_device_available()) { + GTEST_SKIP() << "CUDA device unavailable"; + } + + void* source_ptr = nullptr; + auto source_tensor = + make_device_tensor({1.0f, 2.0f, 3.0f, 4.0f}, &source_ptr); + ASSERT_NE(source_tensor, nullptr); + ASSERT_NE(source_ptr, nullptr); + + FakeContainer container; + container.internal_names = {"internal_state"}; + container.fqns = {"model.state"}; + container.extracted["model.state"] = + reinterpret_cast(source_tensor.get()); + cu::CudaDelegateHandle handle = fake_container_handle(&container); + + cu::MutableStateContext c = cu::mutable_state_create_context(); + cu::mutable_state_register_fqns(c, {"model.state"}); + cu::mutable_state_begin_load(c); + cu::mutable_state_note_handle(&handle); + cu::mutable_state_end_load(); + + ASSERT_TRUE(cu::mutable_state_available(c)); + EXPECT_EQ(cu::mutable_state_bytes_per_session(c), 4 * sizeof(float)); + ASSERT_EQ(cu::mutable_state_validate_coverage(c), Error::Ok); + + auto token = cu::mutable_state_create_session(c); + ASSERT_TRUE(token.ok()); + cu::mutable_state_set_active(c, token.get()); + EXPECT_EQ(cu::mutable_state_rebind_for_execute(&handle), Error::Ok); + + EXPECT_EQ(container.update_calls, 1u); + EXPECT_EQ(container.last_num_pairs, 1u); + EXPECT_EQ(container.last_name, "internal_state"); + ASSERT_NE(container.last_bound_data, nullptr); + EXPECT_NE(container.last_bound_data, source_ptr); + EXPECT_EQ(container.last_bound_nbytes, 4 * sizeof(float)); + + std::vector cloned(4); + EXPECT_EQ( + cudaMemcpy( + cloned.data(), + container.last_bound_data, + cloned.size() * sizeof(float), + cudaMemcpyDeviceToHost), + cudaSuccess); + EXPECT_EQ(cloned, (std::vector{1.0f, 2.0f, 3.0f, 4.0f})); + + cu::mutable_state_set_active( + cu::kInvalidMutableContext, cu::kNoMutableSession); + cu::mutable_state_destroy_session(c, token.get()); + cu::mutable_state_destroy_context(c); + cudaFree(source_ptr); +} + +TEST(CudaMutableStateTest, SharedFqnAcrossHandlesUsesSameSessionBuffer) { + if (!cuda_device_available()) { + GTEST_SKIP() << "CUDA device unavailable"; + } + + void* prefill_ptr = nullptr; + void* decode_ptr = nullptr; + auto prefill_tensor = make_device_tensor({1.0f, 2.0f}, &prefill_ptr); + auto decode_tensor = make_device_tensor({9.0f, 8.0f}, &decode_ptr); + ASSERT_NE(prefill_tensor, nullptr); + ASSERT_NE(decode_tensor, nullptr); + ASSERT_NE(prefill_ptr, nullptr); + ASSERT_NE(decode_ptr, nullptr); + + FakeContainer prefill_container; + prefill_container.internal_names = {"prefill_internal_kv"}; + prefill_container.fqns = {"model.kv"}; + prefill_container.extracted["model.kv"] = + reinterpret_cast(prefill_tensor.get()); + cu::CudaDelegateHandle prefill_handle = + fake_container_handle(&prefill_container); + + FakeContainer decode_container; + decode_container.internal_names = {"decode_internal_kv"}; + decode_container.fqns = {"model.kv"}; + decode_container.extracted["model.kv"] = + reinterpret_cast(decode_tensor.get()); + cu::CudaDelegateHandle decode_handle = + fake_container_handle(&decode_container); + + cu::MutableStateContext c = cu::mutable_state_create_context(); + cu::mutable_state_register_fqns(c, {"model.kv"}); + cu::mutable_state_begin_load(c); + cu::mutable_state_note_handle(&prefill_handle); + cu::mutable_state_note_handle(&decode_handle); + cu::mutable_state_end_load(); + + ASSERT_TRUE(cu::mutable_state_available(c)); + ASSERT_EQ(cu::mutable_state_validate_coverage(c), Error::Ok); + + auto token = cu::mutable_state_create_session(c); + ASSERT_TRUE(token.ok()); + cu::mutable_state_set_active(c, token.get()); + EXPECT_EQ(cu::mutable_state_rebind_for_execute(&prefill_handle), Error::Ok); + EXPECT_EQ(cu::mutable_state_rebind_for_execute(&decode_handle), Error::Ok); + + ASSERT_NE(prefill_container.last_bound_data, nullptr); + ASSERT_NE(decode_container.last_bound_data, nullptr); + EXPECT_EQ(prefill_container.last_name, "prefill_internal_kv"); + EXPECT_EQ(decode_container.last_name, "decode_internal_kv"); + EXPECT_EQ( + prefill_container.last_bound_data, decode_container.last_bound_data); + EXPECT_NE(prefill_container.last_bound_data, prefill_ptr); + EXPECT_NE(decode_container.last_bound_data, decode_ptr); + + cu::mutable_state_set_active( + cu::kInvalidMutableContext, cu::kNoMutableSession); + cu::mutable_state_destroy_session(c, token.get()); + cu::mutable_state_destroy_context(c); + cudaFree(prefill_ptr); + cudaFree(decode_ptr); +} + +TEST( + CudaMutableStateTest, + ValidateCoverageRejectsLargerDescriptorForSharedFqn) { + if (!cuda_device_available()) { + GTEST_SKIP() << "CUDA device unavailable"; + } + + void* small_ptr = nullptr; + void* large_ptr = nullptr; + auto small_tensor = make_device_tensor({1.0f}, &small_ptr); + auto large_tensor = make_device_tensor({1.0f, 2.0f}, &large_ptr); + ASSERT_NE(small_tensor, nullptr); + ASSERT_NE(large_tensor, nullptr); + ASSERT_NE(small_ptr, nullptr); + ASSERT_NE(large_ptr, nullptr); + + FakeContainer small_container; + small_container.internal_names = {"small_internal"}; + small_container.fqns = {"model.state"}; + small_container.extracted["model.state"] = + reinterpret_cast(small_tensor.get()); + cu::CudaDelegateHandle small_handle = fake_container_handle(&small_container); + + FakeContainer large_container; + large_container.internal_names = {"large_internal"}; + large_container.fqns = {"model.state"}; + large_container.extracted["model.state"] = + reinterpret_cast(large_tensor.get()); + cu::CudaDelegateHandle large_handle = fake_container_handle(&large_container); + + cu::MutableStateContext c = cu::mutable_state_create_context(); + cu::mutable_state_register_fqns(c, {"model.state"}); + cu::mutable_state_begin_load(c); + cu::mutable_state_note_handle(&small_handle); + cu::mutable_state_note_handle(&large_handle); + cu::mutable_state_end_load(); + + ASSERT_TRUE(cu::mutable_state_available(c)); + EXPECT_EQ(cu::mutable_state_validate_coverage(c), Error::InvalidProgram); + EXPECT_FALSE(cu::mutable_state_available(c)); + EXPECT_EQ(cu::mutable_state_create_session(c).error(), Error::InvalidProgram); + EXPECT_EQ(large_container.update_calls, 0u); + + cu::mutable_state_destroy_context(c); + cudaFree(small_ptr); + cudaFree(large_ptr); +} + +TEST(CudaMutableStateTest, ValidateCoverageRejectsDeviceMismatchForSharedFqn) { + if (!cuda_device_available()) { + GTEST_SKIP() << "CUDA device unavailable"; + } + + void* cuda_ptr = nullptr; + auto cuda_tensor = make_device_tensor({1.0f}, &cuda_ptr); + ASSERT_NE(cuda_tensor, nullptr); + ASSERT_NE(cuda_ptr, nullptr); + + std::vector cpu_values = {1.0f}; + auto cpu_tensor = make_cpu_tensor(cpu_values); + ASSERT_NE(cpu_tensor, nullptr); + + FakeContainer cuda_container; + cuda_container.internal_names = {"cuda_internal"}; + cuda_container.fqns = {"model.state"}; + cuda_container.extracted["model.state"] = + reinterpret_cast(cuda_tensor.get()); + cu::CudaDelegateHandle cuda_handle = fake_container_handle(&cuda_container); + + FakeContainer cpu_container; + cpu_container.internal_names = {"cpu_internal"}; + cpu_container.fqns = {"model.state"}; + cpu_container.extracted["model.state"] = + reinterpret_cast(cpu_tensor.get()); + cu::CudaDelegateHandle cpu_handle = fake_container_handle(&cpu_container); + + cu::MutableStateContext c = cu::mutable_state_create_context(); + cu::mutable_state_register_fqns(c, {"model.state"}); + cu::mutable_state_begin_load(c); + cu::mutable_state_note_handle(&cuda_handle); + cu::mutable_state_note_handle(&cpu_handle); + cu::mutable_state_end_load(); + + ASSERT_TRUE(cu::mutable_state_available(c)); + EXPECT_EQ(cu::mutable_state_validate_coverage(c), Error::InvalidProgram); + EXPECT_FALSE(cu::mutable_state_available(c)); + EXPECT_EQ(cu::mutable_state_create_session(c).error(), Error::InvalidProgram); + EXPECT_EQ(cpu_container.update_calls, 0u); + + cu::mutable_state_destroy_context(c); + cudaFree(cuda_ptr); +}