diff --git a/backends/cuda/CMakeLists.txt b/backends/cuda/CMakeLists.txt index 2befd78b41b..bc8c33ff77c 100644 --- a/backends/cuda/CMakeLists.txt +++ b/backends/cuda/CMakeLists.txt @@ -107,9 +107,10 @@ set(_aoti_cuda_shim_sources runtime/shims/memory.cpp runtime/shims/cuda_guard.cpp ) -# Only build int4mm shim when CUDA language/toolchain is available. +# Only build CUDA-specific shims when CUDA language/toolchain is available. if(CMAKE_CUDA_COMPILER) list(APPEND _aoti_cuda_shim_sources runtime/shims/int4mm.cu) + list(APPEND _aoti_cuda_shim_sources runtime/shims/randint.cu) endif() add_library(aoti_cuda_shims SHARED ${_aoti_cuda_shim_sources}) @@ -150,7 +151,8 @@ endif() # retention. if(_cuda_is_msvc_toolchain) target_link_libraries( - aoti_cuda_shims PRIVATE cuda_platform CUDA::cudart ${CMAKE_DL_LIBS} + aoti_cuda_shims PRIVATE cuda_platform CUDA::cudart CUDA::curand + ${CMAKE_DL_LIBS} ) # Link object library directly so symbols are pulled exactly once while # avoiding duplicate static/object inclusion and interface leakage. @@ -160,7 +162,7 @@ else() aoti_cuda_shims PRIVATE cuda_platform PUBLIC -Wl,--whole-archive aoti_common_shims_slim -Wl,--no-whole-archive - CUDA::cudart ${CMAKE_DL_LIBS} + CUDA::cudart CUDA::curand ${CMAKE_DL_LIBS} ) endif() diff --git a/backends/cuda/cuda_backend.py b/backends/cuda/cuda_backend.py index 661b4f2b960..825687003d1 100644 --- a/backends/cuda/cuda_backend.py +++ b/backends/cuda/cuda_backend.py @@ -145,6 +145,7 @@ def save_data_externally(cls) -> bool: def get_supported_fallback_kernels(cls) -> Dict[str, Any]: return { "at::_ops::_weight_int4pack_mm::call": None, + "aoti_torch_cuda_randint_low_out": None, } @classmethod @@ -170,8 +171,7 @@ def get_custom_passes(cls, compile_specs: List[CompileSpec]) -> List[typing.Any] mode = spec.value.decode("utf-8").upper() if mode not in ["ON", "OFF"]: raise ValueError( - f"Invalid triton_kernel_mode: {mode}. " - f"Expected 'ON' or 'OFF'." + f"Invalid triton_kernel_mode: {mode}. 
Expected 'ON' or 'OFF'." ) triton_kernel_mode = mode passes = [MoveCondPredicateToCpuPass()] diff --git a/backends/cuda/runtime/cuda_backend.cpp b/backends/cuda/runtime/cuda_backend.cpp index f3b16dad583..2aaa29cd51b 100644 --- a/backends/cuda/runtime/cuda_backend.cpp +++ b/backends/cuda/runtime/cuda_backend.cpp @@ -382,7 +382,16 @@ class ET_EXPERIMENTAL CudaBackend final return (DelegateHandle*)handle; // Return the handle post-processing } - // Once per execution + // Execute the AOTI-compiled CUDA kernel for one inference step. + // + // Currently supports both CPU and CUDA memory for IO tensors: + // - Inputs: detected via cudaPointerGetAttributes; CUDA data is wrapped + // in-place (no copy), CPU data is copied to GPU via from_etensor(). + // - Outputs: either copied to ETensor's backing memory (CPU or CUDA), + // or the ETensor is rewired to point at GPU memory (skip-copy mode). + // + // TODO: Once the device tensor pipeline is fully adopted, all IO tensors + // will reside in CUDA memory. Remove the CPU fallback paths. Error execute( BackendExecutionContext& context, DelegateHandle* handle_, @@ -405,14 +414,17 @@ class ET_EXPERIMENTAL CudaBackend final n_outputs, args.size()) - // Verify device info on all memory-planned, ET-driven IO tensors. - // All input and output tensors should have device_type = CUDA, which - // is set during serialization by PropagateDevicePass based on the - // target_device compile spec from CudaPartitioner. + // Verify device metadata on all IO tensors. + // All tensors should have device_type = CUDA, set during serialization + // by PropagateDevicePass based on the target_device compile spec from + // CudaPartitioner. // - // Note: At this stage, the tensor memory is still on CPU. The device_type - // is metadata indicating where the tensor *should* reside. The backend - // is responsible for copying data to the actual CUDA device. 
+ // Note: device_type is metadata — the actual memory location may be + // either CPU (legacy path with H2D copy ops) or CUDA (when device + // memory planning is enabled via enable_non_cpu_memory_planning, + // which allocates delegate IO in CUDA memory). The backend detects + // the actual location via cudaPointerGetAttributes and handles both + // cases. for (size_t i = 0; i < n_inputs + n_outputs; i++) { auto* tensor = &(args[i]->toTensor()); auto device_type = tensor->unsafeGetTensorImpl()->device_type(); @@ -425,26 +437,29 @@ class ET_EXPERIMENTAL CudaBackend final static_cast(device_type)); } - // NOTE: ExecuTorch tensors may be on CPU or GPU due to the skip-copy - // optimization. We need to create GPU copies for CUDA kernel execution - // using SlimTensor. + // Convert ExecuTorch tensors to SlimTensors for AOTI kernel execution. + // Input data may be in CPU or CUDA memory — the backend detects and + // handles both cases automatically (see memory model comment above). std::vector gpu_inputs(n_inputs); std::vector gpu_outputs(n_outputs); // Process input tensors: convert ETensor (CPU) to SlimTensor (GPU) for (size_t i = 0; i < n_inputs; i++) { - auto* cpu_tensor = &(args[i]->toTensor()); + auto* input_tensor = &(args[i]->toTensor()); - // Check if input data is already on GPU (skip-copy optimization for - // inputs) This can happen when the caller has pre-staged data on GPU + // Detect if input data is already in CUDA memory. This occurs when: + // - Device memory planning is enabled (enable_non_cpu_memory_planning), + // which allocates delegate IO in CUDA memory + // - The input is a skip-copy output from a previous method execution + // When detected, the data is wrapped directly — no H2D copy needed. 
cudaPointerAttributes attributes{}; - const void* data_ptr = cpu_tensor->const_data_ptr(); + const void* data_ptr = input_tensor->const_data_ptr(); if (data_ptr != nullptr) { cudaError_t err = cudaPointerGetAttributes(&attributes, data_ptr); if (err == cudaSuccess && attributes.type == cudaMemoryTypeDevice) { // Data is already on GPU - wrap it directly without copy - auto sizes = cpu_tensor->sizes(); - auto strides = cpu_tensor->strides(); + auto sizes = input_tensor->sizes(); + auto strides = input_tensor->strides(); std::vector sizes_vec(sizes.begin(), sizes.end()); std::vector strides_vec(strides.begin(), strides.end()); @@ -452,7 +467,7 @@ class ET_EXPERIMENTAL CudaBackend final const_cast(data_ptr), slim::makeArrayRef(sizes_vec), slim::makeArrayRef(strides_vec), - static_cast(cpu_tensor->scalar_type()), + static_cast(input_tensor->scalar_type()), DEFAULT_CUDA_DEVICE, 0 // storage_offset )); @@ -461,19 +476,22 @@ class ET_EXPERIMENTAL CudaBackend final } } - // Data is on CPU - use from_etensor to copy to GPU + // Data is in CPU memory (legacy path) — copy to GPU via from_etensor. + // TODO: Remove this path once all callers use the device tensor pipeline. gpu_inputs[i] = new SlimTensor( - from_etensor(*cpu_tensor, CPU_DEVICE, DEFAULT_CUDA_DEVICE)); + from_etensor(*input_tensor, CPU_DEVICE, DEFAULT_CUDA_DEVICE)); } - // Process output tensors: create GPU SlimTensors for kernel output. - // Save pre-run handles to detect orphans after run(). + // Allocate GPU SlimTensors for kernel outputs. These are always + // freshly allocated on GPU regardless of the input memory mode. + // Save pre-run handles to detect orphans after run() (the AOTI + // runtime may replace output handles with its own allocations). 
std::vector pre_run_outputs(n_outputs, nullptr); for (size_t i = 0; i < n_outputs; i++) { - auto* cpu_output_tensor = &(args[i + n_inputs]->toTensor()); - auto sizes = cpu_output_tensor->sizes(); - auto strides = cpu_output_tensor->strides(); - auto scalar_type = cpu_output_tensor->scalar_type(); + auto* output_tensor = &(args[i + n_inputs]->toTensor()); + auto sizes = output_tensor->sizes(); + auto strides = output_tensor->strides(); + auto scalar_type = output_tensor->scalar_type(); std::vector sizes_vec(sizes.begin(), sizes.end()); std::vector strides_vec(strides.begin(), strides.end()); @@ -536,13 +554,18 @@ class ET_EXPERIMENTAL CudaBackend final const bool copy_outputs = !should_skip_copy_for_method(handle->method_name); + // Output disposition: copy to ETensor backing memory or keep on GPU. + // When copy_outputs is true (default), results are copied to the + // ETensor's memory (which may be CPU or CUDA planned memory). + // When false (skip-copy optimization), the ETensor is rewired to + // point at the GPU SlimTensor's memory directly. if (copy_outputs) { for (size_t i = 0; i < n_outputs; i++) { - auto* cpu_output_tensor = &(args[i + n_inputs]->toTensor()); + auto* output_tensor = &(args[i + n_inputs]->toTensor()); ET_CHECK_OK_OR_RETURN_ERROR( copy_slimtensor_to_etensor_async( - gpu_outputs[i], cpu_output_tensor, cuda_stream), - "Failed to copy GPU output %zu back to CPU ETensor", + gpu_outputs[i], output_tensor, cuda_stream), + "Failed to copy GPU output %zu back to ETensor", i); delete gpu_outputs[i]; gpu_outputs[i] = nullptr; diff --git a/backends/cuda/runtime/shims/randint.cu b/backends/cuda/runtime/shims/randint.cu new file mode 100644 index 00000000000..967bd1b6941 --- /dev/null +++ b/backends/cuda/runtime/shims/randint.cu @@ -0,0 +1,108 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. 
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <cuda_runtime.h>
+#include <curand.h>
+
+#include <cstdint>
+#include <ctime>
+#include <vector>
+
+#include <executorch/backends/cuda/runtime/shims/randint.h>
+#include <executorch/backends/cuda/runtime/utils.h>
+
+namespace executorch::backends::cuda {
+
+using executorch::runtime::Error;
+
+namespace {
+
+// Transform cuRAND uniform doubles (0, 1] to int64 values in [low, high).
+__global__ void uniform_to_randint_kernel(
+    int64_t* out,
+    const double* uniform,
+    int64_t numel,
+    int64_t low,
+    int64_t range) {
+  int64_t idx = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
+  if (idx < numel) {
+    // uniform is in (0, 1], so (uniform * range) is in (0, range].
+    // Subtract 1 and clamp to get [0, range-1], then add low for [low, high-1].
+    int64_t val = static_cast<int64_t>(uniform[idx] * range);
+    out[idx] = low + (val >= range ? range - 1 : val);
+  }
+}
+
+curandGenerator_t get_or_create_generator() {
+  static curandGenerator_t gen = nullptr;
+  if (gen == nullptr) {
+    curandCreateGenerator(&gen, CURAND_RNG_PSEUDO_DEFAULT);
+    curandSetPseudoRandomGeneratorSeed(
+        gen, static_cast<unsigned long long>(time(nullptr)));
+  }
+  return gen;
+}
+
+} // anonymous namespace
+
+extern "C" {
+
+AOTITorchError aoti_torch_cuda_randint_low_out(
+    SlimTensor* out,
+    int64_t low,
+    int64_t high,
+    const int64_t* size,
+    int64_t size_len_) {
+  ET_CHECK_OR_RETURN_ERROR(
+      out != nullptr,
+      InvalidArgument,
+      "aoti_torch_cuda_randint_low_out: out tensor is null");
+
+  ET_CHECK_OR_RETURN_ERROR(
+      high > low,
+      InvalidArgument,
+      "aoti_torch_cuda_randint_low_out: requires high > low");
+
+  int64_t numel = 1;
+  for (int64_t i = 0; i < size_len_; i++) {
+    numel *= size[i];
+  }
+  if (numel == 0) {
+    return Error::Ok;
+  }
+
+  int64_t range = high - low;
+  int64_t* out_data = static_cast<int64_t*>(out->data_ptr());
+
+  // Allocate temporary buffer for uniform doubles on device.
+  double* d_uniform = nullptr;
+  auto alloc_err = cudaMalloc(&d_uniform, numel * sizeof(double));
+  ET_CHECK_OR_RETURN_ERROR(
+      alloc_err == cudaSuccess,
+      Internal,
+      "aoti_torch_cuda_randint_low_out: cudaMalloc failed (%d)",
+      static_cast<int>(alloc_err));
+
+  // Generate uniform doubles in (0, 1].
+  auto gen = get_or_create_generator();
+  curandGenerateUniformDouble(gen, d_uniform, numel);
+
+  // Transform to integers in [low, high).
+  constexpr int kThreads = 256;
+  int blocks = static_cast<int>((numel + kThreads - 1) / kThreads);
+  uniform_to_randint_kernel<<<blocks, kThreads>>>(
+      out_data, d_uniform, numel, low, range);
+
+  cudaFree(d_uniform);
+
+  return Error::Ok;
+}
+
+} // extern "C"
+
+} // namespace executorch::backends::cuda
diff --git a/backends/cuda/runtime/shims/randint.h b/backends/cuda/runtime/shims/randint.h
new file mode 100644
index 00000000000..7cacc66bfd3
--- /dev/null
+++ b/backends/cuda/runtime/shims/randint.h
@@ -0,0 +1,43 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#pragma once
+
+#include <executorch/backends/aoti/common_shims.h>
+#include <executorch/backends/aoti/slim_tensor/slim_tensor.h>
+
+namespace executorch::backends::cuda {
+
+using executorch::backends::aoti::AOTITorchError;
+using SlimTensor = executorch::backends::aoti::slim::SlimTensor;
+
+extern "C" {
+
+/**
+ * Fills a pre-allocated CUDA tensor with random integers in [low, high).
+ *
+ * Used by AOTI-generated code when the model calls torch.randint or ops
+ * that decompose into randint (e.g. torch.rand_like on some dtypes).
+ *
+ * @param out Pre-allocated output tensor on CUDA (must not be null).
+ * @param low Lower bound (inclusive) of the random range.
+ * @param high Upper bound (exclusive) of the random range.
+ * @param size Pointer to array of output dimension sizes.
+ * @param size_len_ Number of dimensions.
+ * @return AOTITorchError error code (Error::Ok on success).
+ */ +AOTI_SHIM_EXPORT AOTITorchError aoti_torch_cuda_randint_low_out( + SlimTensor* out, + int64_t low, + int64_t high, + const int64_t* size, + int64_t size_len_); + +} // extern "C" + +} // namespace executorch::backends::cuda diff --git a/examples/models/qwen3_5_moe/CMakeLists.txt b/examples/models/qwen3_5_moe/CMakeLists.txt index 6e9e52eef62..c75e0ddcd53 100644 --- a/examples/models/qwen3_5_moe/CMakeLists.txt +++ b/examples/models/qwen3_5_moe/CMakeLists.txt @@ -32,25 +32,24 @@ list(APPEND link_libraries optimized_native_cpu_ops_lib cpublas eigen_blas) executorch_target_link_options_shared_lib(optimized_native_cpu_ops_lib) # Extensions -list( - APPEND - link_libraries - extension_llm_runner - extension_module - extension_data_loader - extension_tensor - extension_flat_tensor +list(APPEND link_libraries extension_module extension_data_loader + extension_tensor extension_flat_tensor ) # CUDA backend (required) find_package(CUDAToolkit REQUIRED) -list(APPEND link_libraries aoti_cuda_backend) +list(APPEND link_libraries aoti_cuda_backend CUDA::cudart) executorch_target_link_options_shared_lib(aoti_cuda_backend) # Tokenizer list(APPEND link_libraries tokenizers::tokenizers) -add_executable(qwen3_5_moe_runner main.cpp) +add_executable( + qwen3_5_moe_runner + main.cpp ${EXECUTORCH_ROOT}/runtime/core/device_allocator.cpp + ${EXECUTORCH_ROOT}/runtime/core/device_memory_buffer.cpp + ${EXECUTORCH_ROOT}/backends/cuda/runtime/cuda_allocator.cpp +) target_include_directories( qwen3_5_moe_runner PUBLIC ${_common_include_directories} ) diff --git a/examples/models/qwen3_5_moe/export.py b/examples/models/qwen3_5_moe/export.py index 7437bc5f461..8128328712c 100644 --- a/examples/models/qwen3_5_moe/export.py +++ b/examples/models/qwen3_5_moe/export.py @@ -398,25 +398,44 @@ def export_and_lower(model, config, args): # -O0 compiles ~8x faster than -O1 with no measurable runtime impact. 
inductor_config.aot_inductor.compile_wrapper_opt_level = "O0" - # Dynamic shapes + # Dynamic shapes for forward method example_tokens = torch.tensor([[0, 1]], dtype=torch.long) example_input_pos = torch.tensor([0, 1], dtype=torch.long) seq_dim = Dim("seq_len", min=1, max=config.max_seq_len - 1) dynamic_shapes = ({1: seq_dim}, {0: seq_dim}) - print("Exporting with torch.export...") + print("Exporting forward method with torch.export...") with torch.no_grad(): - exported = export( + exported_forward = export( model, (example_tokens, example_input_pos), dynamic_shapes=dynamic_shapes, strict=True, ) - print("Export successful!") + print("Forward export successful!") + + # Export sample method by temporarily swapping model.forward + print("Exporting sample method with torch.export...") + original_forward = model.forward + model.forward = model.sample + example_logits = torch.zeros(1, 2, config.vocab_size, dtype=torch.bfloat16) + example_temperature = torch.tensor([0.8], dtype=torch.float32) + sample_dynamic_shapes = ({1: seq_dim}, None) + with torch.no_grad(): + exported_sample = export( + model, + (example_logits, example_temperature), + dynamic_shapes=sample_dynamic_shapes, + strict=True, + ) + model.forward = original_forward + print("Sample export successful!") - # Lower with CUDA backend + # Lower with CUDA backend (multi-method) print("Lowering to ExecuTorch with CUDA...") - compile_specs = [CudaBackend.generate_method_name_compile_spec("forward")] + forward_compile_specs = [CudaBackend.generate_method_name_compile_spec("forward")] + sample_compile_specs = [CudaBackend.generate_method_name_compile_spec("sample")] + metadata = { "get_max_seq_len": config.max_seq_len, "get_vocab_size": config.vocab_size, @@ -426,8 +445,11 @@ def export_and_lower(model, config, args): "enable_dynamic_shape": True, } et_prog = to_edge_transform_and_lower( - exported, - partitioner=[CudaPartitioner(compile_specs)], + {"forward": exported_forward, "sample": exported_sample}, + 
partitioner={ + "forward": [CudaPartitioner(forward_compile_specs)], + "sample": [CudaPartitioner(sample_compile_specs)], + }, compile_config=EdgeCompileConfig( _check_ir_validity=False, _skip_dim_order=True, @@ -439,6 +461,9 @@ def export_and_lower(model, config, args): extract_delegate_segments=True, do_quant_fusion_and_const_prop=True, memory_planning_pass=MemoryPlanningPass(alloc_graph_input=False), + enable_non_cpu_memory_planning=True, + skip_h2d_for_method_inputs=True, + skip_d2h_for_method_outputs=True, ), ) diff --git a/examples/models/qwen3_5_moe/main.cpp b/examples/models/qwen3_5_moe/main.cpp index 266d0e65419..ebef8d0f6b3 100644 --- a/examples/models/qwen3_5_moe/main.cpp +++ b/examples/models/qwen3_5_moe/main.cpp @@ -8,11 +8,15 @@ #include -#include #include -#include +#include +#include #include +#include +#include +#include +#include #include #include @@ -23,55 +27,216 @@ DEFINE_string(prompt, "Hello", "Prompt text."); DEFINE_double(temperature, 0.8, "Sampling temperature (0 = greedy)."); DEFINE_int32(max_new_tokens, 128, "Maximum tokens to generate."); -namespace llm = ::executorch::extension::llm; +using namespace executorch::extension; +using namespace executorch::runtime; +using etensor::DeviceType; +using executorch::aten::ScalarType; + +constexpr auto kDynamicBound = + executorch::aten::TensorShapeDynamism::DYNAMIC_BOUND; int main(int argc, char** argv) { gflags::ParseCommandLineFlags(&argc, &argv, true); - if (FLAGS_model_path.empty()) { - ET_LOG(Error, "Must specify --model_path"); + if (FLAGS_model_path.empty() || FLAGS_tokenizer_path.empty()) { + std::cerr << "Must specify --model_path and --tokenizer_path" << std::endl; return 1; } - if (FLAGS_tokenizer_path.empty()) { - ET_LOG(Error, "Must specify --tokenizer_path"); + + // Load tokenizer. 
+ auto tokenizer = std::make_unique(); + auto tok_status = tokenizer->load(FLAGS_tokenizer_path); + if (tok_status != tokenizers::Error::Ok) { + std::cerr << "Failed to load tokenizer from " << FLAGS_tokenizer_path + << std::endl; return 1; } + // Load module with optional .ptd data file for CUDA backend weights. std::vector data_files; if (!FLAGS_data_path.empty()) { data_files.push_back(FLAGS_data_path); } + Module module( + FLAGS_model_path, data_files, Module::LoadMode::MmapUseMlockIgnoreErrors); - // Load tokenizer - auto tokenizer = std::make_unique(); - auto tok_status = tokenizer->load(FLAGS_tokenizer_path); - if (tok_status != tokenizers::Error::Ok) { - ET_LOG( - Error, - "Failed to load tokenizer from %s", - FLAGS_tokenizer_path.c_str()); + auto forward_load = module.load_method("forward"); + if (forward_load != Error::Ok) { + std::cerr << "Failed to load forward method" << std::endl; + return 1; + } + auto sample_load = module.load_method("sample"); + if (sample_load != Error::Ok) { + std::cerr << "Failed to load sample method" << std::endl; return 1; } - // Create LLM runner - auto runner = llm::create_text_llm_runner( - FLAGS_model_path, std::move(tokenizer), data_files, FLAGS_temperature); - - if (runner == nullptr) { - ET_LOG(Error, "Failed to create runner"); + // Encode prompt. + auto encode_result = tokenizer->encode(FLAGS_prompt); + if (!encode_result.ok()) { + std::cerr << "Failed to encode prompt" << std::endl; return 1; } + auto prompt_tokens = encode_result.get(); + int num_prompt_tokens = static_cast(prompt_tokens.size()); + + // ======================== PREFILL ======================== + + auto prefill_start = std::chrono::high_resolution_clock::now(); + + // Create CUDA tensors directly for the full prompt. 
+ // tokens: shape [1, num_prompt_tokens], dtype Long + std::vector token_data(prompt_tokens.begin(), prompt_tokens.end()); + auto cuda_tokens = make_tensor_ptr( + /* sizes= */ {1, static_cast(num_prompt_tokens)}, + /* data= */ token_data.data(), + /* type= */ ScalarType::Long, + /* dynamism= */ kDynamicBound, + /* deleter= */ nullptr, + /* device_type= */ DeviceType::CUDA); + + // positions: shape [num_prompt_tokens], dtype Long + std::vector pos_data(num_prompt_tokens); + std::iota(pos_data.begin(), pos_data.end(), 0); + auto cuda_pos = make_tensor_ptr( + /* sizes= */ {static_cast(num_prompt_tokens)}, + /* data= */ pos_data.data(), + /* type= */ ScalarType::Long, + /* dynamism= */ kDynamicBound, + /* deleter= */ nullptr, + /* device_type= */ DeviceType::CUDA); - // Generate - llm::GenerationConfig config; - config.temperature = FLAGS_temperature; - config.max_new_tokens = FLAGS_max_new_tokens; + // Temperature tensor: shape [1], dtype Float + float temp_val = static_cast(FLAGS_temperature); + auto cuda_temp = make_tensor_ptr( + /* sizes= */ {1}, + /* data= */ &temp_val, + /* type= */ ScalarType::Float, + /* dynamism= */ kDynamicBound, + /* deleter= */ nullptr, + /* device_type= */ DeviceType::CUDA); - auto error = runner->generate(FLAGS_prompt.c_str(), config); - if (error != executorch::runtime::Error::Ok) { - ET_LOG(Error, "Generation failed"); + // Forward pass — logits stay on CUDA. + auto forward_result = module.execute( + "forward", {/* tokens= */ *cuda_tokens, /* input_pos= */ *cuda_pos}); + if (!forward_result.ok()) { + std::cerr << "Forward (prefill) failed" << std::endl; return 1; } + auto& forward_outputs = forward_result.get(); + + // Sample — input and output both on CUDA. 
+ auto sample_result = module.execute( + "sample", + {/* logits= */ forward_outputs[0], /* temperature= */ *cuda_temp}); + if (!sample_result.ok()) { + std::cerr << "Sample (prefill) failed" << std::endl; + return 1; + } + auto& sample_outputs = sample_result.get(); + + // D2H: copy the single sampled token back to CPU. + auto& prefill_sample_tensor = sample_outputs[0].toTensor(); + auto cpu_first_token = clone_tensor_ptr_to_cpu(TensorPtr( + &prefill_sample_tensor, /* deleter= */ [](executorch::aten::Tensor*) {})); + int64_t cur_token = cpu_first_token->const_data_ptr()[0]; + + auto prefill_end = std::chrono::high_resolution_clock::now(); + double prefill_ms = + std::chrono::duration(prefill_end - prefill_start) + .count(); + double prefill_tps = num_prompt_tokens / (prefill_ms / 1000.0); + + // Print the first generated token. + auto first_decode = tokenizer->decode( + /* prev= */ static_cast(prompt_tokens.back()), + /* cur= */ static_cast(cur_token)); + if (first_decode.ok()) { + std::cout << *first_decode << std::flush; + } + + // ======================== DECODE LOOP ======================== + + int pos = num_prompt_tokens; + int generated = 1; + int64_t prev_token = static_cast(prompt_tokens.back()); + auto decode_start = std::chrono::high_resolution_clock::now(); + + // Carry the CUDA token tensor from the previous sample call so we can + // feed it directly to forward without an extra H2D copy. + EValue cuda_token_ev(prefill_sample_tensor); + + // Qwen EOS token IDs. + const std::set eos_tokens = {151643, 151645}; + + while (generated < FLAGS_max_new_tokens) { + if (eos_tokens.count(cur_token)) + break; + + // Position H2D (single int64 per step). 
+ int64_t pos_val = static_cast(pos); + auto cuda_next_pos = make_tensor_ptr( + /* sizes= */ {1}, + /* data= */ &pos_val, + /* type= */ ScalarType::Long, + /* dynamism= */ kDynamicBound, + /* deleter= */ nullptr, + /* device_type= */ DeviceType::CUDA); + + // Forward — reuse the CUDA token tensor from the previous sample output. + auto fwd = module.execute( + "forward", + {/* tokens= */ cuda_token_ev, /* input_pos= */ *cuda_next_pos}); + if (!fwd.ok()) { + std::cerr << "Forward (decode step " << generated << ") failed" + << std::endl; + return 1; + } + + // Sample — stays on CUDA. + auto smp = module.execute( + "sample", {/* logits= */ fwd.get()[0], /* temperature= */ *cuda_temp}); + if (!smp.ok()) { + std::cerr << "Sample (decode step " << generated << ") failed" + << std::endl; + return 1; + } + + // D2H: extract next token for EOS check. + auto& next_sample_tensor = smp.get()[0].toTensor(); + auto cpu_next_token = clone_tensor_ptr_to_cpu(TensorPtr( + &next_sample_tensor, /* deleter= */ [](executorch::aten::Tensor*) {})); + + prev_token = cur_token; + cur_token = cpu_next_token->const_data_ptr()[0]; + + // Keep the CUDA tensor for the next iteration's forward call. + cuda_token_ev = EValue(next_sample_tensor); + + // Decode and stream output. + auto dec = tokenizer->decode( + /* prev= */ static_cast(prev_token), + /* cur= */ static_cast(cur_token)); + if (dec.ok()) { + std::cout << *dec << std::flush; + } + + pos++; + generated++; + } + + auto decode_end = std::chrono::high_resolution_clock::now(); + double decode_ms = + std::chrono::duration(decode_end - decode_start) + .count(); + double decode_tps = + (generated > 1) ? 
(generated - 1) / (decode_ms / 1000.0) : 0; + + std::cout << std::endl; + std::cout << "Prefill: " << num_prompt_tokens << " tokens, " << prefill_tps + << " tok/s" << std::endl; + std::cout << "Decode: " << generated << " tokens, " << decode_tps << " tok/s" + << std::endl; return 0; } diff --git a/examples/models/qwen3_5_moe/model.py b/examples/models/qwen3_5_moe/model.py index d9f127d9ed1..8c71ac36024 100644 --- a/examples/models/qwen3_5_moe/model.py +++ b/examples/models/qwen3_5_moe/model.py @@ -566,6 +566,23 @@ def forward( x = self.norm(x) return self.lm_head(x) + def sample(self, logits: torch.Tensor, temperature: torch.Tensor) -> torch.Tensor: + """Temperature-based sampling from the last token's logits. + + Uses the Gumbel-max trick with GPU-generated noise (via torch.rand_like). + Mathematically equivalent to multinomial(softmax(logits / T), 1). + + Args: + logits: (B, T, vocab_size) float tensor from forward() + temperature: scalar tensor controlling randomness (0→greedy, higher→more random) + Returns: + next_token: (B, 1) LongTensor with the selected token ID + """ + logits = logits[:, -1, :] / temperature.clamp(min=1e-6) + noise = torch.rand_like(logits) + gumbel = -torch.log(-torch.log(noise + 1e-20) + 1e-20) + return (logits.float() + gumbel).argmax(dim=-1, keepdim=True) + @staticmethod def from_hf_checkpoint(model_dir, max_seq_len=4096): config_path = os.path.join(model_dir, "config.json")