3 changes: 2 additions & 1 deletion .ci/scripts/export_model_artifact.sh
@@ -161,7 +161,8 @@ if [ "$MODEL_NAME" = "parakeet" ]; then

python examples/models/parakeet/export_parakeet_tdt.py \
--backend "$DEVICE" \
--output-dir "${OUTPUT_DIR}"
--output-dir "${OUTPUT_DIR}" \
--dtype bf16

test -f "${OUTPUT_DIR}/model.pte"
# CUDA saves named data to separate .ptd file, Metal embeds in .pte
12 changes: 11 additions & 1 deletion Makefile
@@ -88,7 +88,7 @@
#
# ==============================================================================

.PHONY: voxtral-cuda voxtral-cpu voxtral-metal whisper-cuda whisper-cuda-debug whisper-cpu whisper-metal parakeet-cuda parakeet-cpu parakeet-metal llama-cpu llava-cpu gemma3-cuda gemma3-cpu clean help
.PHONY: voxtral-cuda voxtral-cpu voxtral-metal whisper-cuda whisper-cuda-debug whisper-cpu whisper-metal parakeet-cuda parakeet-cuda-debug parakeet-cpu parakeet-metal llama-cpu llava-cpu gemma3-cuda gemma3-cpu clean help

help:
@echo "This Makefile adds targets to build runners for various models on various backends. Run using \`make <target>\`. Available targets:"
@@ -100,6 +100,7 @@ help:
@echo " whisper-cpu - Build Whisper runner with CPU backend"
@echo " whisper-metal - Build Whisper runner with Metal backend (macOS only)"
@echo " parakeet-cuda - Build Parakeet runner with CUDA backend"
@echo " parakeet-cuda-debug - Build Parakeet runner with CUDA backend (debug mode)"
@echo " parakeet-cpu - Build Parakeet runner with CPU backend"
@echo " parakeet-metal - Build Parakeet runner with Metal backend (macOS only)"
@echo " llama-cpu - Build Llama runner with CPU backend"
@@ -180,6 +181,15 @@ parakeet-cuda:
@echo "✓ Build complete!"
@echo " Binary: cmake-out/examples/models/parakeet/parakeet_runner"

parakeet-cuda-debug:
@echo "==> Building and installing ExecuTorch with CUDA (debug mode)..."
cmake --workflow --preset llm-debug-cuda
@echo "==> Building Parakeet runner with CUDA (debug mode)..."
cd examples/models/parakeet && cmake --workflow --preset parakeet-cuda-debug
@echo ""
@echo "✓ Build complete!"
@echo " Binary: cmake-out/examples/models/parakeet/parakeet_runner"

parakeet-cpu:
@echo "==> Building and installing ExecuTorch..."
cmake --workflow --preset llm-release
34 changes: 34 additions & 0 deletions examples/models/parakeet/CMakePresets.json
@@ -29,6 +29,20 @@
"list": ["Linux", "Windows"]
}
},
{
"name": "parakeet-cuda-debug",
"displayName": "Parakeet runner (CUDA, Debug)",
"inherits": ["parakeet-base"],
"cacheVariables": {
"CMAKE_BUILD_TYPE": "Debug",
"EXECUTORCH_BUILD_CUDA": "ON"
},
"condition": {
"type": "inList",
"string": "${hostSystemName}",
"list": ["Linux", "Windows"]
}
},
{
"name": "parakeet-metal",
"displayName": "Parakeet runner (Metal)",
@@ -56,6 +70,12 @@
"configurePreset": "parakeet-cuda",
"targets": ["parakeet_runner"]
},
{
"name": "parakeet-cuda-debug",
"displayName": "Build Parakeet runner (CUDA, Debug)",
"configurePreset": "parakeet-cuda-debug",
"targets": ["parakeet_runner"]
},
{
"name": "parakeet-metal",
"displayName": "Build Parakeet runner (Metal)",
@@ -92,6 +112,20 @@
}
]
},
{
"name": "parakeet-cuda-debug",
"displayName": "Configure and build Parakeet runner (CUDA, Debug)",
"steps": [
{
"type": "configure",
"name": "parakeet-cuda-debug"
},
{
"type": "build",
"name": "parakeet-cuda-debug"
}
]
},
{
"name": "parakeet-metal",
"displayName": "Configure and build Parakeet runner (Metal)",
28 changes: 17 additions & 11 deletions examples/models/parakeet/export_parakeet_tdt.py
@@ -295,11 +295,15 @@ def forward(
return mel, mel_len


def export_all(model):
def export_all(model, dtype=torch.float):
"""Export all model components.

The maximum audio duration is determined by the model's internal
max_audio_length (~50 seconds for Parakeet with max_audio_length=5000).

Args:
model: The NeMo ASR model to export.
dtype: Data type for floating-point tensors (default: torch.float).
"""
programs = {}

@@ -316,7 +320,8 @@ def export_all(model):

preprocessor_wrapper = PreprocessorWrapper(model.preprocessor)
preprocessor_wrapper.eval()
sample_audio = torch.randn(max_audio_samples)

sample_audio = torch.randn(max_audio_samples, dtype=torch.float)
sample_length = torch.tensor([sample_audio.shape[0]], dtype=torch.int64)
# The preprocessor uses different code paths when CUDA is available, which include
# data-dependent conditionals that torch.export cannot handle. Force CPU path.
@@ -337,7 +342,7 @@ def export_all(model):
feat_in = getattr(model.encoder, "_feat_in", 128)
# Use max_mel_frames as example to ensure Dim.AUTO infers the full range.
# Smaller examples cause Dim.AUTO to infer narrow bounds.
audio_signal = torch.randn(1, feat_in, max_mel_frames)
audio_signal = torch.randn(1, feat_in, max_mel_frames, dtype=dtype)
length = torch.tensor([max_mel_frames], dtype=torch.int64)
encoder_with_proj = EncoderWithProjection(model.encoder, model.joint)
encoder_with_proj.eval()
@@ -359,8 +364,8 @@ def export_all(model):
decoder_step = DecoderStep(model.decoder, model.joint)
decoder_step.eval()
token = torch.tensor([[0]], dtype=torch.long)
h = torch.zeros(num_layers, 1, pred_hidden)
c = torch.zeros(num_layers, 1, pred_hidden)
h = torch.zeros(num_layers, 1, pred_hidden, dtype=dtype)
c = torch.zeros(num_layers, 1, pred_hidden, dtype=dtype)
programs["decoder_step"] = export(
decoder_step,
(token, h, c),
@@ -371,8 +376,8 @@ def export_all(model):
joint_hidden = model.joint.joint_hidden
num_token_classes = model.tokenizer.vocab_size + 1 # +1 for blank

f_proj = torch.randn(1, 1, joint_hidden)
g_proj = torch.randn(1, 1, joint_hidden)
f_proj = torch.randn(1, 1, joint_hidden, dtype=dtype)
g_proj = torch.randn(1, 1, joint_hidden, dtype=dtype)
programs["joint"] = export(
JointWithArgmax(model.joint, num_token_classes),
(f_proj, g_proj),
@@ -551,9 +556,9 @@ def main():
)
args = parser.parse_args()

# Validate dtype for Metal backend
if args.backend == "metal" and args.dtype == "fp16":
parser.error("Metal backend only supports fp32 and bf16, not fp16")
# Validate dtype
if args.dtype == "fp16":
parser.error("fp16 is not yet supported")

os.makedirs(args.output_dir, exist_ok=True)

@@ -572,7 +577,8 @@ def main():
model = model.to(torch.float16)

print("\nExporting components...")
programs, metadata = export_all(model)
export_dtype = torch.bfloat16 if args.dtype == "bf16" else torch.float
programs, metadata = export_all(model, dtype=export_dtype)

et = lower_to_executorch(programs, metadata=metadata, backend=args.backend)

147 changes: 108 additions & 39 deletions examples/models/parakeet/main.cpp
@@ -26,11 +26,13 @@
#include "types.h"

#include <executorch/extension/llm/runner/llm_runner_helper.h>
#include <executorch/extension/llm/runner/util.h>
#include <executorch/extension/llm/runner/wav_loader.h>
#include <executorch/extension/llm/tokenizers/third-party/llama.cpp-unicode/include/unicode.h>
#include <executorch/extension/module/module.h>
#include <executorch/extension/tensor/tensor_ptr_maker.h>
#include <executorch/runtime/core/evalue.h>
#include <executorch/runtime/core/exec_aten/util/scalar_type_util.h>
#include <executorch/runtime/platform/log.h>
#ifdef ET_BUILD_METAL
#include <executorch/backends/apple/metal/runtime/stats.h>
@@ -108,6 +110,39 @@ TimestampOutputMode parse_timestamp_output_mode(const std::string& raw_arg) {
"'. Expected: token, word, segment, all.");
}

// Helper to get expected scalar type for a method input
::executorch::runtime::Result<::executorch::aten::ScalarType>
get_input_scalar_type(
Module& model,
const char* method_name,
size_t input_index) {
auto method_meta_result = model.method_meta(method_name);
if (!method_meta_result.ok()) {
ET_LOG(Error, "Failed to get method metadata for %s", method_name);
return method_meta_result.error();
}
auto method_meta = method_meta_result.get();
if (method_meta.num_inputs() <= input_index) {
ET_LOG(
Error,
"Method %s has %zu inputs, but requested index %zu",
method_name,
method_meta.num_inputs(),
input_index);
return ::executorch::runtime::Error::InvalidArgument;
}
auto input_meta_result = method_meta.input_tensor_meta(input_index);
if (input_meta_result.error() != ::executorch::runtime::Error::Ok) {
ET_LOG(
Error,
"Failed to get input tensor metadata for %s[%zu]",
method_name,
input_index);
return input_meta_result.error();
}
return input_meta_result.get().scalar_type();
}

std::vector<Token> greedy_decode_executorch(
Contributor:

I'm a little lost in the logic here: based on my understanding, greedy_decode_executorch should only be used for the joint and the decoder, not the encoder, and the encoder takes the preprocessor's output as input, which is always fp32 now. Since we export the encoder in the target dtype (e.g. bf16), how does the bf16 encoder consume the fp32 preprocessor output?

Contributor Author:

> and encoder takes preprocessor's output as input, which is always fp32 now

Not true: if we pass dtype == bfloat16, the preprocessor takes a float input and gives a bfloat16 result.

Contributor (@Gasoonjia, Jan 26, 2026):

Hmm, not sure if I misunderstood something, but the export script says:
# Preprocessor always uses float32 - runner converts output to encoder's dtype
but I didn't find the code in the runner for that type conversion.
If the preprocessor gives bfloat16 output directly:

  1. Perhaps update the export doc comment?
  2. How does the preprocessor know that? The export config for the preprocessor seems to be the same across dtypes.
  2. how does the preprocessor know that? Seems the export configs for preprocessor are the same across different dtype.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let me dive a bit deeper into the model. I can take a look at the graph to see if it changes the dtype of the output

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Graph:

graph():
    %b_preprocessor_dtype_sentinel_tensor : [num_users=0] = placeholder[target=b_preprocessor_dtype_sentinel_tensor]
    %b_preprocessor_featurizer_window : [num_users=1] = placeholder[target=b_preprocessor_featurizer_window]
    %b_preprocessor_featurizer_fb : [num_users=1] = placeholder[target=b_preprocessor_featurizer_fb]
    %audio : [num_users=2] = placeholder[target=audio]
    %length : [num_users=1] = placeholder[target=length]
    %sym_size_int_4 : [num_users=1] = call_function[target=torch.ops.aten.sym_size.int](args = (%audio, 0), kwargs = {})
    %unsqueeze : [num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%audio, 0), kwargs = {})
    %submod_5 : [num_users=1] = get_attr[target=submod_1]
    %to : [num_users=1] = call_function[target=torch.ops.higher_order.wrap_with_set_grad_enabled](args = (False, %submod_5, %unsqueeze), kwargs = {})
    %getitem_4 : [num_users=1] = call_function[target=operator.getitem](args = (%to, 0), kwargs = {})
    %submod_6 : [num_users=1] = get_attr[target=submod_2]
    %wrap_with_set_grad_enabled : [num_users=2] = call_function[target=torch.ops.higher_order.wrap_with_set_grad_enabled](args = (False, %submod_6, %length, %sym_size_int_4, %getitem_4, %b_preprocessor_featurizer_window, %b_preprocessor_featurizer_fb), kwargs = {})
    %masked_fill_2 : [num_users=1] = call_function[target=operator.getitem](args = (%wrap_with_set_grad_enabled, 0), kwargs = {})
    %where : [num_users=1] = call_function[target=operator.getitem](args = (%wrap_with_set_grad_enabled, 1), kwargs = {})
    %submod_7 : [num_users=1] = get_attr[target=submod_3]
    %to_6 : [num_users=1] = call_function[target=torch.ops.higher_order.wrap_with_set_grad_enabled](args = (False, %submod_7, %masked_fill_2), kwargs = {})
    %getitem_5 : [num_users=1] = call_function[target=operator.getitem](args = (%to_6, 0), kwargs = {})
    return (getitem_5, where)

Looking at the graph, the key is in those placeholder buffers:

  • b_preprocessor_featurizer_window - the window function (e.g., Hann window)
  • b_preprocessor_featurizer_fb - the mel filterbank matrix

When we call model.to(torch.bfloat16), it converts all parameters and buffers in the model, including the preprocessor's window and filterbank tensors.

The mel spectrogram computation does something like:

# Simplified view of what happens inside:
windowed = audio_frame * window      # window is bf16 → result is bf16
spectrum = torch.fft.rfft(windowed)  # stays bf16
mel = spectrum @ filterbank          # filterbank is bf16 → result is bf16

PyTorch's type promotion rules mean that when our bf16 audio interacts with bf16 buffers, the output stays bf16.
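
A minimal sketch of that buffer behavior, for illustration only (ToyFeaturizer and its names are hypothetical, not the NeMo preprocessor, and the explicit cast stands in for whatever the %to node in the exported graph does):

import torch

class ToyFeaturizer(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # Registered buffers follow Module.to(dtype), like the window/filterbank above.
        self.register_buffer("window", torch.hann_window(8))

    def forward(self, audio):
        # Cast the incoming audio to the buffer's dtype (standing in for the %to node);
        # after that, bf16 * bf16 stays bf16 under PyTorch's type promotion rules.
        return audio.to(self.window.dtype) * self.window

toy = ToyFeaturizer().to(torch.bfloat16)  # converts the buffer to bf16
out = toy(torch.randn(8))                 # fp32 audio in
print(out.dtype)                          # torch.bfloat16

Under these assumptions a float32 waveform in yields bfloat16 features out, which is consistent with the claim above.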

Module& model,
const ::executorch::aten::Tensor& f_proj,
@@ -118,27 +153,49 @@ std::vector<Token> greedy_decode_executorch(
int64_t max_symbols_per_step = 10) {
std::vector<Token> hypothesis;

// Shape: [1, time_steps, joint_hidden]
auto f_proj_sizes = f_proj.sizes();
int64_t time_steps = f_proj_sizes[1];
int64_t proj_dim = f_proj_sizes[2];
// Shape: [1, T, joint_hidden]
size_t proj_dim = static_cast<size_t>(f_proj.sizes()[2]);

// Initialize LSTM state
std::vector<float> h_data(num_rnn_layers * 1 * pred_hidden, 0.0f);
std::vector<float> c_data(num_rnn_layers * 1 * pred_hidden, 0.0f);
// Get expected dtype for decoder_step h and c inputs (indices 1 and 2)
auto h_dtype_result = get_input_scalar_type(model, "decoder_step", 1);
if (!h_dtype_result.ok()) {
return hypothesis;
}
auto c_dtype_result = get_input_scalar_type(model, "decoder_step", 2);
if (!c_dtype_result.ok()) {
return hypothesis;
}
auto h_dtype = h_dtype_result.get();
auto c_dtype = c_dtype_result.get();

ET_LOG(
Info,
"Decoder h dtype: %s, c dtype: %s",
::executorch::runtime::toString(h_dtype),
::executorch::runtime::toString(c_dtype));

// Calculate buffer sizes based on dtype
size_t h_elem_size = ::executorch::runtime::elementSize(h_dtype);
size_t c_elem_size = ::executorch::runtime::elementSize(c_dtype);
size_t num_elements =
static_cast<size_t>(num_rnn_layers) * static_cast<size_t>(pred_hidden);

// Initialize LSTM state with zeros (using byte buffers for dtype flexibility)
std::vector<uint8_t> h_data(num_elements * h_elem_size, 0);
std::vector<uint8_t> c_data(num_elements * c_elem_size, 0);

auto h = from_blob(
h_data.data(),
{static_cast<::executorch::aten::SizesType>(num_rnn_layers),
1,
static_cast<::executorch::aten::SizesType>(pred_hidden)},
::executorch::aten::ScalarType::Float);
h_dtype);
auto c = from_blob(
c_data.data(),
{static_cast<::executorch::aten::SizesType>(num_rnn_layers),
1,
static_cast<::executorch::aten::SizesType>(pred_hidden)},
::executorch::aten::ScalarType::Float);
c_dtype);

// Prime the decoder with SOS (= blank_id) to match NeMo TDT label-looping:
// - SOS is defined as blank:
@@ -159,41 +216,61 @@ std::vector<Token> greedy_decode_executorch(
auto g_proj_init = init_outputs[0].toTensor();
auto new_h_init = init_outputs[1].toTensor();
auto new_c_init = init_outputs[2].toTensor();
std::memcpy(
h_data.data(),
new_h_init.const_data_ptr<float>(),
h_data.size() * sizeof(float));
std::memcpy(
c_data.data(),
new_c_init.const_data_ptr<float>(),
c_data.size() * sizeof(float));
std::memcpy(h_data.data(), new_h_init.const_data_ptr(), h_data.size());
std::memcpy(c_data.data(), new_c_init.const_data_ptr(), c_data.size());

// Copy g_proj data for reuse
std::vector<float> g_proj_data(
g_proj_init.const_data_ptr<float>(),
g_proj_init.const_data_ptr<float>() + g_proj_init.numel());
// Get expected dtype for joint inputs (f and g at indices 0 and 1)
auto f_dtype_result = get_input_scalar_type(model, "joint", 0);
if (!f_dtype_result.ok()) {
return hypothesis;
}
auto g_dtype_result = get_input_scalar_type(model, "joint", 1);
if (!g_dtype_result.ok()) {
return hypothesis;
}
auto f_dtype = f_dtype_result.get();
auto g_dtype = g_dtype_result.get();

ET_LOG(
Info,
"Joint f dtype: %s, g dtype: %s",
::executorch::runtime::toString(f_dtype),
::executorch::runtime::toString(g_dtype));

size_t f_elem_size = ::executorch::runtime::elementSize(f_dtype);
size_t g_elem_size = ::executorch::runtime::elementSize(g_dtype);

// Copy g_proj data for reuse (using byte buffer for dtype flexibility)
size_t g_proj_num_bytes =
static_cast<size_t>(g_proj_init.numel()) * g_elem_size;
std::vector<uint8_t> g_proj_data(g_proj_num_bytes);
std::memcpy(
g_proj_data.data(), g_proj_init.const_data_ptr(), g_proj_num_bytes);

int64_t t = 0;
int64_t symbols_on_frame = 0;
const uint8_t* f_proj_ptr =
static_cast<const uint8_t*>(f_proj.const_data_ptr());
size_t f_t_num_bytes = proj_dim * f_elem_size;

// Scan over encoder output
while (t < encoder_len) {
// Get encoder frame at time t: f_proj[:, t:t+1, :]
const float* f_proj_ptr = f_proj.const_data_ptr<float>();
std::vector<uint8_t> f_t_data(f_t_num_bytes);
std::memcpy(
f_t_data.data(),
f_proj_ptr + static_cast<size_t>(t) * f_t_num_bytes,
f_t_num_bytes);

std::vector<float> f_t_data(1 * 1 * proj_dim);
for (int64_t d = 0; d < proj_dim; d++) {
f_t_data[d] = f_proj_ptr[t * proj_dim + d];
}
auto f_t = from_blob(
f_t_data.data(),
{1, 1, static_cast<::executorch::aten::SizesType>(proj_dim)},
::executorch::aten::ScalarType::Float);
f_dtype);

auto g_proj = from_blob(
g_proj_data.data(),
{1, 1, static_cast<::executorch::aten::SizesType>(proj_dim)},
::executorch::aten::ScalarType::Float);
g_dtype);

auto joint_result = model.execute(
"joint", std::vector<::executorch::runtime::EValue>{f_t, g_proj});
@@ -230,18 +307,10 @@ std::vector<Token> greedy_decode_executorch(
auto new_c = outputs[2].toTensor();

// Update h, c, and g_proj
std::memcpy(h_data.data(), new_h.const_data_ptr(), h_data.size());
std::memcpy(c_data.data(), new_c.const_data_ptr(), c_data.size());
std::memcpy(
h_data.data(),
new_h.const_data_ptr<float>(),
h_data.size() * sizeof(float));
std::memcpy(
c_data.data(),
new_c.const_data_ptr<float>(),
c_data.size() * sizeof(float));
std::memcpy(
g_proj_data.data(),
new_g_proj.const_data_ptr<float>(),
g_proj_data.size() * sizeof(float));
g_proj_data.data(), new_g_proj.const_data_ptr(), g_proj_data.size());

t += dur;
