diff --git a/.ci/scripts/export_model_artifact.sh b/.ci/scripts/export_model_artifact.sh
index f5d7ebeee27..e162fe3b936 100755
--- a/.ci/scripts/export_model_artifact.sh
+++ b/.ci/scripts/export_model_artifact.sh
@@ -161,7 +161,8 @@ if [ "$MODEL_NAME" = "parakeet" ]; then
 
     python examples/models/parakeet/export_parakeet_tdt.py \
         --backend "$DEVICE" \
-        --output-dir "${OUTPUT_DIR}"
+        --output-dir "${OUTPUT_DIR}" \
+        --dtype bf16
 
     test -f "${OUTPUT_DIR}/model.pte"
     # CUDA saves named data to separate .ptd file, Metal embeds in .pte
diff --git a/Makefile b/Makefile
index fe0236238fa..530f091654e 100644
--- a/Makefile
+++ b/Makefile
@@ -88,7 +88,7 @@
 #
 # ==============================================================================
 
-.PHONY: voxtral-cuda voxtral-cpu voxtral-metal whisper-cuda whisper-cuda-debug whisper-cpu whisper-metal parakeet-cuda parakeet-cpu parakeet-metal llama-cpu llava-cpu gemma3-cuda gemma3-cpu clean help
+.PHONY: voxtral-cuda voxtral-cpu voxtral-metal whisper-cuda whisper-cuda-debug whisper-cpu whisper-metal parakeet-cuda parakeet-cuda-debug parakeet-cpu parakeet-metal llama-cpu llava-cpu gemma3-cuda gemma3-cpu clean help
 
 help:
 	@echo "This Makefile adds targets to build runners for various models on various backends. Run using \`make <target>\`. Available targets:"
@@ -100,6 +100,7 @@ help:
 	@echo "  whisper-cpu - Build Whisper runner with CPU backend"
 	@echo "  whisper-metal - Build Whisper runner with Metal backend (macOS only)"
 	@echo "  parakeet-cuda - Build Parakeet runner with CUDA backend"
+	@echo "  parakeet-cuda-debug - Build Parakeet runner with CUDA backend (debug mode)"
 	@echo "  parakeet-cpu - Build Parakeet runner with CPU backend"
 	@echo "  parakeet-metal - Build Parakeet runner with Metal backend (macOS only)"
 	@echo "  llama-cpu - Build Llama runner with CPU backend"
@@ -180,6 +181,15 @@ parakeet-cuda:
 	@echo "✓ Build complete!"
 	@echo "  Binary: cmake-out/examples/models/parakeet/parakeet_runner"
 
+parakeet-cuda-debug:
+	@echo "==> Building and installing ExecuTorch with CUDA (debug mode)..."
+	cmake --workflow --preset llm-debug-cuda
+	@echo "==> Building Parakeet runner with CUDA (debug mode)..."
+	cd examples/models/parakeet && cmake --workflow --preset parakeet-cuda-debug
+	@echo ""
+	@echo "✓ Build complete!"
+	@echo "  Binary: cmake-out/examples/models/parakeet/parakeet_runner"
+
 parakeet-cpu:
 	@echo "==> Building and installing ExecuTorch..."
 	cmake --workflow --preset llm-release
diff --git a/examples/models/parakeet/CMakePresets.json b/examples/models/parakeet/CMakePresets.json
index ea93d257ba7..097e2e026d6 100644
--- a/examples/models/parakeet/CMakePresets.json
+++ b/examples/models/parakeet/CMakePresets.json
@@ -29,6 +29,20 @@
         "list": ["Linux", "Windows"]
       }
     },
+    {
+      "name": "parakeet-cuda-debug",
+      "displayName": "Parakeet runner (CUDA, Debug)",
+      "inherits": ["parakeet-base"],
+      "cacheVariables": {
+        "CMAKE_BUILD_TYPE": "Debug",
+        "EXECUTORCH_BUILD_CUDA": "ON"
+      },
+      "condition": {
+        "type": "inList",
+        "string": "${hostSystemName}",
+        "list": ["Linux", "Windows"]
+      }
+    },
     {
       "name": "parakeet-metal",
       "displayName": "Parakeet runner (Metal)",
@@ -56,6 +70,12 @@
       "configurePreset": "parakeet-cuda",
       "targets": ["parakeet_runner"]
     },
+    {
+      "name": "parakeet-cuda-debug",
+      "displayName": "Build Parakeet runner (CUDA, Debug)",
+      "configurePreset": "parakeet-cuda-debug",
+      "targets": ["parakeet_runner"]
+    },
     {
       "name": "parakeet-metal",
       "displayName": "Build Parakeet runner (Metal)",
@@ -92,6 +112,20 @@
         }
       ]
     },
+    {
+      "name": "parakeet-cuda-debug",
+      "displayName": "Configure and build Parakeet runner (CUDA, Debug)",
+      "steps": [
+        {
+          "type": "configure",
+          "name": "parakeet-cuda-debug"
+        },
+        {
+          "type": "build",
+          "name": "parakeet-cuda-debug"
+        }
+      ]
+    },
     {
       "name": "parakeet-metal",
       "displayName": "Configure and build Parakeet runner (Metal)",
diff --git a/examples/models/parakeet/export_parakeet_tdt.py b/examples/models/parakeet/export_parakeet_tdt.py
index 67085efaca0..c97c01c1bcb 100644
--- a/examples/models/parakeet/export_parakeet_tdt.py
+++ b/examples/models/parakeet/export_parakeet_tdt.py
@@ -295,11 +295,15 @@ def forward(
     return mel, mel_len
 
 
-def export_all(model):
+def export_all(model, dtype=torch.float):
     """Export all model components.
 
     The maximum audio duration is determined by the model's internal
     max_audio_length (~50 seconds for Parakeet with max_audio_length=5000).
+
+    Args:
+        model: The NeMo ASR model to export.
+        dtype: Data type for floating-point tensors (default: torch.float).
     """
     programs = {}
 
@@ -316,7 +320,8 @@ def export_all(model):
     preprocessor_wrapper = PreprocessorWrapper(model.preprocessor)
     preprocessor_wrapper.eval()
-    sample_audio = torch.randn(max_audio_samples)
+
+    sample_audio = torch.randn(max_audio_samples, dtype=torch.float)
     sample_length = torch.tensor([sample_audio.shape[0]], dtype=torch.int64)
     # The preprocessor uses different code paths when CUDA is available, which include
     # data-dependent conditionals that torch.export cannot handle. Force CPU path.
@@ -337,7 +342,7 @@ def export_all(model):
     feat_in = getattr(model.encoder, "_feat_in", 128)
     # Use max_mel_frames as example to ensure Dim.AUTO infers the full range.
     # Smaller examples cause Dim.AUTO to infer narrow bounds.
-    audio_signal = torch.randn(1, feat_in, max_mel_frames)
+    audio_signal = torch.randn(1, feat_in, max_mel_frames, dtype=dtype)
     length = torch.tensor([max_mel_frames], dtype=torch.int64)
     encoder_with_proj = EncoderWithProjection(model.encoder, model.joint)
     encoder_with_proj.eval()
@@ -359,8 +364,8 @@ def export_all(model):
     decoder_step = DecoderStep(model.decoder, model.joint)
     decoder_step.eval()
     token = torch.tensor([[0]], dtype=torch.long)
-    h = torch.zeros(num_layers, 1, pred_hidden)
-    c = torch.zeros(num_layers, 1, pred_hidden)
+    h = torch.zeros(num_layers, 1, pred_hidden, dtype=dtype)
+    c = torch.zeros(num_layers, 1, pred_hidden, dtype=dtype)
     programs["decoder_step"] = export(
         decoder_step,
         (token, h, c),
@@ -371,8 +376,8 @@ def export_all(model):
     joint_hidden = model.joint.joint_hidden
     num_token_classes = model.tokenizer.vocab_size + 1  # +1 for blank
 
-    f_proj = torch.randn(1, 1, joint_hidden)
-    g_proj = torch.randn(1, 1, joint_hidden)
+    f_proj = torch.randn(1, 1, joint_hidden, dtype=dtype)
+    g_proj = torch.randn(1, 1, joint_hidden, dtype=dtype)
     programs["joint"] = export(
         JointWithArgmax(model.joint, num_token_classes),
         (f_proj, g_proj),
@@ -551,9 +556,9 @@ def main():
     )
     args = parser.parse_args()
 
-    # Validate dtype for Metal backend
-    if args.backend == "metal" and args.dtype == "fp16":
-        parser.error("Metal backend only supports fp32 and bf16, not fp16")
+    # Validate dtype
+    if args.dtype == "fp16":
+        parser.error("fp16 is not yet supported")
 
     os.makedirs(args.output_dir, exist_ok=True)
 
@@ -572,7 +577,8 @@ def main():
         model = model.to(torch.float16)
 
     print("\nExporting components...")
-    programs, metadata = export_all(model)
+    export_dtype = torch.bfloat16 if args.dtype == "bf16" else torch.float
+    programs, metadata = export_all(model, dtype=export_dtype)
 
     et = lower_to_executorch(programs, metadata=metadata, backend=args.backend)
diff --git a/examples/models/parakeet/main.cpp b/examples/models/parakeet/main.cpp
index 58bdb764377..cb0df6cf72f 100644
--- a/examples/models/parakeet/main.cpp
+++ b/examples/models/parakeet/main.cpp
@@ -26,11 +26,13 @@
 #include "types.h"
 
 #include
+#include
 #include
 #include
 #include
 #include
 #include
+#include
 #include
 
 #ifdef ET_BUILD_METAL
 #include
@@ -108,6 +110,39 @@ TimestampOutputMode parse_timestamp_output_mode(const std::string& raw_arg) {
       "'. Expected: token, word, segment, all.");
 }
 
+// Helper to get expected scalar type for a method input
+::executorch::runtime::Result<::executorch::aten::ScalarType>
+get_input_scalar_type(
+    Module& model,
+    const char* method_name,
+    size_t input_index) {
+  auto method_meta_result = model.method_meta(method_name);
+  if (!method_meta_result.ok()) {
+    ET_LOG(Error, "Failed to get method metadata for %s", method_name);
+    return method_meta_result.error();
+  }
+  auto method_meta = method_meta_result.get();
+  if (method_meta.num_inputs() <= input_index) {
+    ET_LOG(
+        Error,
+        "Method %s has %zu inputs, but requested index %zu",
+        method_name,
+        method_meta.num_inputs(),
+        input_index);
+    return ::executorch::runtime::Error::InvalidArgument;
+  }
+  auto input_meta_result = method_meta.input_tensor_meta(input_index);
+  if (input_meta_result.error() != ::executorch::runtime::Error::Ok) {
+    ET_LOG(
+        Error,
+        "Failed to get input tensor metadata for %s[%zu]",
+        method_name,
+        input_index);
+    return input_meta_result.error();
+  }
+  return input_meta_result.get().scalar_type();
+}
+
 std::vector greedy_decode_executorch(
     Module& model,
     const ::executorch::aten::Tensor& f_proj,
@@ -118,27 +153,49 @@ std::vector greedy_decode_executorch(
     int64_t max_symbols_per_step = 10) {
   std::vector hypothesis;
 
-  // Shape: [1, time_steps, joint_hidden]
-  auto f_proj_sizes = f_proj.sizes();
-  int64_t time_steps = f_proj_sizes[1];
-  int64_t proj_dim = f_proj_sizes[2];
+  // Shape: [1, T, joint_hidden]
+  size_t proj_dim = static_cast<size_t>(f_proj.sizes()[2]);
 
-  // Initialize LSTM state
-  std::vector<float> h_data(num_rnn_layers * 1 * pred_hidden, 0.0f);
-  std::vector<float> c_data(num_rnn_layers * 1 * pred_hidden, 0.0f);
+  // Get expected dtype for decoder_step h and c inputs (indices 1 and 2)
+  auto h_dtype_result = get_input_scalar_type(model, "decoder_step", 1);
+  if (!h_dtype_result.ok()) {
+    return hypothesis;
+  }
+  auto c_dtype_result = get_input_scalar_type(model, "decoder_step", 2);
+  if (!c_dtype_result.ok()) {
+    return hypothesis;
+  }
+  auto h_dtype = h_dtype_result.get();
+  auto c_dtype = c_dtype_result.get();
+
+  ET_LOG(
+      Info,
+      "Decoder h dtype: %s, c dtype: %s",
+      ::executorch::runtime::toString(h_dtype),
+      ::executorch::runtime::toString(c_dtype));
+
+  // Calculate buffer sizes based on dtype
+  size_t h_elem_size = ::executorch::runtime::elementSize(h_dtype);
+  size_t c_elem_size = ::executorch::runtime::elementSize(c_dtype);
+  size_t num_elements =
+      static_cast<size_t>(num_rnn_layers) * static_cast<size_t>(pred_hidden);
+
+  // Initialize LSTM state with zeros (using byte buffers for dtype flexibility)
+  std::vector<uint8_t> h_data(num_elements * h_elem_size, 0);
+  std::vector<uint8_t> c_data(num_elements * c_elem_size, 0);
   auto h = from_blob(
       h_data.data(),
       {static_cast<::executorch::aten::SizesType>(num_rnn_layers),
        1,
        static_cast<::executorch::aten::SizesType>(pred_hidden)},
-      ::executorch::aten::ScalarType::Float);
+      h_dtype);
   auto c = from_blob(
       c_data.data(),
       {static_cast<::executorch::aten::SizesType>(num_rnn_layers),
        1,
        static_cast<::executorch::aten::SizesType>(pred_hidden)},
-      ::executorch::aten::ScalarType::Float);
+      c_dtype);
 
   // Prime the decoder with SOS (= blank_id) to match NeMo TDT label-looping:
   // - SOS is defined as blank:
@@ -159,41 +216,61 @@
   auto g_proj_init = init_outputs[0].toTensor();
   auto new_h_init = init_outputs[1].toTensor();
   auto new_c_init = init_outputs[2].toTensor();
-  std::memcpy(
-      h_data.data(),
-      new_h_init.const_data_ptr(),
-      h_data.size() * sizeof(float));
-  std::memcpy(
-      c_data.data(),
-      new_c_init.const_data_ptr(),
-      c_data.size() * sizeof(float));
+  std::memcpy(h_data.data(), new_h_init.const_data_ptr(), h_data.size());
+  std::memcpy(c_data.data(), new_c_init.const_data_ptr(), c_data.size());
 
-  // Copy g_proj data for reuse
-  std::vector<float> g_proj_data(
-      g_proj_init.const_data_ptr<float>(),
-      g_proj_init.const_data_ptr<float>() + g_proj_init.numel());
+  // Get expected dtype for joint inputs (f and g at indices 0 and 1)
+  auto f_dtype_result = get_input_scalar_type(model, "joint", 0);
+  if (!f_dtype_result.ok()) {
+    return hypothesis;
+  }
+  auto g_dtype_result = get_input_scalar_type(model, "joint", 1);
+  if (!g_dtype_result.ok()) {
+    return hypothesis;
+  }
+  auto f_dtype = f_dtype_result.get();
+  auto g_dtype = g_dtype_result.get();
+
+  ET_LOG(
+      Info,
+      "Joint f dtype: %s, g dtype: %s",
+      ::executorch::runtime::toString(f_dtype),
+      ::executorch::runtime::toString(g_dtype));
+
+  size_t f_elem_size = ::executorch::runtime::elementSize(f_dtype);
+  size_t g_elem_size = ::executorch::runtime::elementSize(g_dtype);
+
+  // Copy g_proj data for reuse (using byte buffer for dtype flexibility)
+  size_t g_proj_num_bytes =
+      static_cast<size_t>(g_proj_init.numel()) * g_elem_size;
+  std::vector<uint8_t> g_proj_data(g_proj_num_bytes);
+  std::memcpy(
+      g_proj_data.data(), g_proj_init.const_data_ptr(), g_proj_num_bytes);
 
   int64_t t = 0;
   int64_t symbols_on_frame = 0;
+  const uint8_t* f_proj_ptr =
+      static_cast<const uint8_t*>(f_proj.const_data_ptr());
+  size_t f_t_num_bytes = proj_dim * f_elem_size;
 
   // Scan over encoder output
   while (t < encoder_len) {
     // Get encoder frame at time t: f_proj[:, t:t+1, :]
-    const float* f_proj_ptr = f_proj.const_data_ptr<float>();
+    std::vector<uint8_t> f_t_data(f_t_num_bytes);
+    std::memcpy(
+        f_t_data.data(),
+        f_proj_ptr + static_cast<size_t>(t) * f_t_num_bytes,
+        f_t_num_bytes);
 
-    std::vector<float> f_t_data(1 * 1 * proj_dim);
-    for (int64_t d = 0; d < proj_dim; d++) {
-      f_t_data[d] = f_proj_ptr[t * proj_dim + d];
-    }
     auto f_t = from_blob(
         f_t_data.data(),
         {1, 1, static_cast<::executorch::aten::SizesType>(proj_dim)},
-        ::executorch::aten::ScalarType::Float);
+        f_dtype);
 
     auto g_proj = from_blob(
         g_proj_data.data(),
         {1, 1, static_cast<::executorch::aten::SizesType>(proj_dim)},
-        ::executorch::aten::ScalarType::Float);
+        g_dtype);
 
     auto joint_result = model.execute(
         "joint", std::vector<::executorch::runtime::EValue>{f_t, g_proj});
@@ -230,18 +307,10 @@
     auto new_c = outputs[2].toTensor();
 
     // Update h, c, and g_proj
+    std::memcpy(h_data.data(), new_h.const_data_ptr(), h_data.size());
+    std::memcpy(c_data.data(), new_c.const_data_ptr(), c_data.size());
     std::memcpy(
-        h_data.data(),
-        new_h.const_data_ptr(),
-        h_data.size() * sizeof(float));
-    std::memcpy(
-        c_data.data(),
-        new_c.const_data_ptr(),
-        c_data.size() * sizeof(float));
-    std::memcpy(
-        g_proj_data.data(),
-        new_g_proj.const_data_ptr(),
-        g_proj_data.size() * sizeof(float));
+        g_proj_data.data(), new_g_proj.const_data_ptr(), g_proj_data.size());
 
     t += dur;
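
A note on the runner-side pattern above: instead of hard-coding ScalarType::Float, the decode loop now sizes raw byte buffers from the loaded program's own input metadata, so one runner binary serves both fp32 and bf16 exports. Below is a minimal standalone sketch of that pattern, using the same Module, from_blob, and elementSize APIs the diff relies on; the helper name make_zero_state and the unchecked .get() chain are illustrative only (the PR's get_input_scalar_type propagates errors instead), and the header paths are assumed from the ExecuTorch tree.

#include <cstdint>
#include <vector>

#include <executorch/extension/module/module.h>
#include <executorch/extension/tensor/tensor.h>
#include <executorch/runtime/core/exec_aten/util/scalar_type_util.h>

using executorch::extension::Module;
using executorch::extension::TensorPtr;

// Zero-initialized state tensor whose dtype matches what the loaded
// program expects for `method`'s input `index`. `storage` backs the
// returned view and must outlive it. Error handling elided for brevity.
TensorPtr make_zero_state(
    Module& module,
    const char* method,
    size_t index,
    std::vector<executorch::aten::SizesType> sizes,
    std::vector<uint8_t>& storage) {
  // Ask the program, not the host code, which dtype this input uses.
  auto dtype = module.method_meta(method)
                   .get()
                   .input_tensor_meta(index)
                   .get()
                   .scalar_type();
  size_t numel = 1;
  for (auto s : sizes) {
    numel *= static_cast<size_t>(s);
  }
  // Size the buffer in bytes so fp32 (4 bytes) and bf16 (2 bytes) both fit.
  storage.assign(numel * executorch::runtime::elementSize(dtype), 0);
  return executorch::extension::from_blob(
      storage.data(), std::move(sizes), dtype);
}

In the runner this is done inline for decoder_step inputs 1 and 2 (h, c) and joint inputs 0 and 1 (f, g), with explicit error logging and early returns rather than the shortcut above.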
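Likewise, the per-frame encoder slice becomes a byte-stride copy rather than a typed float loop, relying on the contiguous [1, T, proj_dim] layout the decode loop already assumes. A sketch under the same assumptions (slice_frame is a hypothetical helper, not part of the PR):

#include <cstdint>
#include <cstring>
#include <vector>

#include <executorch/extension/tensor/tensor.h>
#include <executorch/runtime/core/exec_aten/util/scalar_type_util.h>

using executorch::extension::TensorPtr;

// Copy frame `t` of a contiguous [1, T, proj_dim] tensor into `scratch`
// and wrap it as a [1, 1, proj_dim] view with the source tensor's dtype.
TensorPtr slice_frame(
    const executorch::aten::Tensor& f_proj,
    int64_t t,
    std::vector<uint8_t>& scratch) {
  const size_t proj_dim = static_cast<size_t>(f_proj.sizes()[2]);
  // The frame stride in bytes follows the element size of the actual dtype.
  const size_t frame_bytes =
      proj_dim * executorch::runtime::elementSize(f_proj.scalar_type());
  const auto* base = static_cast<const uint8_t*>(f_proj.const_data_ptr());
  scratch.resize(frame_bytes);
  std::memcpy(
      scratch.data(), base + static_cast<size_t>(t) * frame_bytes, frame_bytes);
  return executorch::extension::from_blob(
      scratch.data(),
      {1, 1, static_cast<executorch::aten::SizesType>(proj_dim)},
      f_proj.scalar_type());
}

One consequence of the byte-wise buffers: the later std::memcpy calls pass h_data.size() and c_data.size() directly, since those vectors are now sized in bytes rather than in float elements.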