Skip to content

Commit 10ab53b

Browse files
committed
Add bf16 and float16 support for Parakeet
Enable both export script and runner to support bfloat16 and float16. Changing CI to run on bfloat16 by default. Also added a `parakeet-cuda-debug` mode
1 parent b928496 commit 10ab53b

5 files changed

Lines changed: 172 additions & 48 deletions

File tree

.ci/scripts/export_model_artifact.sh

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -162,6 +162,7 @@ if [ "$MODEL_NAME" = "parakeet" ]; then
162162
python examples/models/parakeet/export_parakeet_tdt.py \
163163
--backend "$DEVICE" \
164164
--output-dir "${OUTPUT_DIR}" \
165+
--dtype bf16
165166

166167
test -f "${OUTPUT_DIR}/model.pte"
167168
# CUDA saves named data to separate .ptd file, Metal embeds in .pte

Makefile

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@
8888
#
8989
# ==============================================================================
9090

91-
.PHONY: voxtral-cuda voxtral-cpu voxtral-metal whisper-cuda whisper-cuda-debug whisper-cpu whisper-metal parakeet-cuda parakeet-cpu parakeet-metal llama-cpu llava-cpu gemma3-cuda gemma3-cpu clean help
91+
.PHONY: voxtral-cuda voxtral-cpu voxtral-metal whisper-cuda whisper-cuda-debug whisper-cpu whisper-metal parakeet-cuda parakeet-cuda-debug parakeet-cpu parakeet-metal llama-cpu llava-cpu gemma3-cuda gemma3-cpu clean help
9292

9393
help:
9494
@echo "This Makefile adds targets to build runners for various models on various backends. Run using \`make <target>\`. Available targets:"
@@ -100,6 +100,7 @@ help:
100100
@echo " whisper-cpu - Build Whisper runner with CPU backend"
101101
@echo " whisper-metal - Build Whisper runner with Metal backend (macOS only)"
102102
@echo " parakeet-cuda - Build Parakeet runner with CUDA backend"
103+
@echo " parakeet-cuda-debug - Build Parakeet runner with CUDA backend (debug mode)"
103104
@echo " parakeet-cpu - Build Parakeet runner with CPU backend"
104105
@echo " parakeet-metal - Build Parakeet runner with Metal backend (macOS only)"
105106
@echo " llama-cpu - Build Llama runner with CPU backend"
@@ -180,6 +181,15 @@ parakeet-cuda:
180181
@echo "✓ Build complete!"
181182
@echo " Binary: cmake-out/examples/models/parakeet/parakeet_runner"
182183

184+
parakeet-cuda-debug:
185+
@echo "==> Building and installing ExecuTorch with CUDA (debug mode)..."
186+
cmake --workflow --preset llm-debug-cuda
187+
@echo "==> Building Parakeet runner with CUDA (debug mode)..."
188+
cd examples/models/parakeet && cmake --workflow --preset parakeet-cuda-debug
189+
@echo ""
190+
@echo "✓ Build complete!"
191+
@echo " Binary: cmake-out/examples/models/parakeet/parakeet_runner"
192+
183193
parakeet-cpu:
184194
@echo "==> Building and installing ExecuTorch..."
185195
cmake --workflow --preset llm-release

examples/models/parakeet/CMakePresets.json

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,20 @@
2929
"list": ["Linux", "Windows"]
3030
}
3131
},
32+
{
33+
"name": "parakeet-cuda-debug",
34+
"displayName": "Parakeet runner (CUDA, Debug)",
35+
"inherits": ["parakeet-base"],
36+
"cacheVariables": {
37+
"CMAKE_BUILD_TYPE": "Debug",
38+
"EXECUTORCH_BUILD_CUDA": "ON"
39+
},
40+
"condition": {
41+
"type": "inList",
42+
"string": "${hostSystemName}",
43+
"list": ["Linux", "Windows"]
44+
}
45+
},
3246
{
3347
"name": "parakeet-metal",
3448
"displayName": "Parakeet runner (Metal)",
@@ -56,6 +70,12 @@
5670
"configurePreset": "parakeet-cuda",
5771
"targets": ["parakeet_runner"]
5872
},
73+
{
74+
"name": "parakeet-cuda-debug",
75+
"displayName": "Build Parakeet runner (CUDA, Debug)",
76+
"configurePreset": "parakeet-cuda-debug",
77+
"targets": ["parakeet_runner"]
78+
},
5979
{
6080
"name": "parakeet-metal",
6181
"displayName": "Build Parakeet runner (Metal)",
@@ -92,6 +112,20 @@
92112
}
93113
]
94114
},
115+
{
116+
"name": "parakeet-cuda-debug",
117+
"displayName": "Configure and build Parakeet runner (CUDA, Debug)",
118+
"steps": [
119+
{
120+
"type": "configure",
121+
"name": "parakeet-cuda-debug"
122+
},
123+
{
124+
"type": "build",
125+
"name": "parakeet-cuda-debug"
126+
}
127+
]
128+
},
95129
{
96130
"name": "parakeet-metal",
97131
"displayName": "Configure and build Parakeet runner (Metal)",

examples/models/parakeet/export_parakeet_tdt.py

Lines changed: 18 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -295,11 +295,15 @@ def forward(
295295
return mel, mel_len
296296

297297

298-
def export_all(model):
298+
def export_all(model, dtype=torch.float):
299299
"""Export all model components.
300300
301301
The maximum audio duration is determined by the model's internal
302302
max_audio_length (~50 seconds for Parakeet with max_audio_length=5000).
303+
304+
Args:
305+
model: The NeMo ASR model to export.
306+
dtype: Data type for floating-point tensors (default: torch.float).
303307
"""
304308
programs = {}
305309

@@ -316,7 +320,8 @@ def export_all(model):
316320

317321
preprocessor_wrapper = PreprocessorWrapper(model.preprocessor)
318322
preprocessor_wrapper.eval()
319-
sample_audio = torch.randn(max_audio_samples)
323+
# Preprocessor always uses float32 - runner converts output to encoder's dtype
324+
sample_audio = torch.randn(max_audio_samples, dtype=torch.float)
320325
sample_length = torch.tensor([sample_audio.shape[0]], dtype=torch.int64)
321326
# The preprocessor uses different code paths when CUDA is available, which include
322327
# data-dependent conditionals that torch.export cannot handle. Force CPU path.
@@ -337,7 +342,7 @@ def export_all(model):
337342
feat_in = getattr(model.encoder, "_feat_in", 128)
338343
# Use max_mel_frames as example to ensure Dim.AUTO infers the full range.
339344
# Smaller examples cause Dim.AUTO to infer narrow bounds.
340-
audio_signal = torch.randn(1, feat_in, max_mel_frames)
345+
audio_signal = torch.randn(1, feat_in, max_mel_frames, dtype=dtype)
341346
length = torch.tensor([max_mel_frames], dtype=torch.int64)
342347
encoder_with_proj = EncoderWithProjection(model.encoder, model.joint)
343348
encoder_with_proj.eval()
@@ -359,8 +364,8 @@ def export_all(model):
359364
decoder_step = DecoderStep(model.decoder, model.joint)
360365
decoder_step.eval()
361366
token = torch.tensor([[0]], dtype=torch.long)
362-
h = torch.zeros(num_layers, 1, pred_hidden)
363-
c = torch.zeros(num_layers, 1, pred_hidden)
367+
h = torch.zeros(num_layers, 1, pred_hidden, dtype=dtype)
368+
c = torch.zeros(num_layers, 1, pred_hidden, dtype=dtype)
364369
programs["decoder_step"] = export(
365370
decoder_step,
366371
(token, h, c),
@@ -371,8 +376,8 @@ def export_all(model):
371376
joint_hidden = model.joint.joint_hidden
372377
num_token_classes = model.tokenizer.vocab_size + 1 # +1 for blank
373378

374-
f_proj = torch.randn(1, 1, joint_hidden)
375-
g_proj = torch.randn(1, 1, joint_hidden)
379+
f_proj = torch.randn(1, 1, joint_hidden, dtype=dtype)
380+
g_proj = torch.randn(1, 1, joint_hidden, dtype=dtype)
376381
programs["joint"] = export(
377382
JointWithArgmax(model.joint, num_token_classes),
378383
(f_proj, g_proj),
@@ -572,7 +577,12 @@ def main():
572577
model = model.to(torch.float16)
573578

574579
print("\nExporting components...")
575-
programs, metadata = export_all(model)
580+
export_dtype = (
581+
torch.bfloat16
582+
if args.dtype == "bf16"
583+
else torch.float16 if args.dtype == "fp16" else torch.float
584+
)
585+
programs, metadata = export_all(model, dtype=export_dtype)
576586

577587
et = lower_to_executorch(programs, metadata=metadata, backend=args.backend)
578588

examples/models/parakeet/main.cpp

Lines changed: 108 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -26,11 +26,13 @@
2626
#include "types.h"
2727

2828
#include <executorch/extension/llm/runner/llm_runner_helper.h>
29+
#include <executorch/extension/llm/runner/util.h>
2930
#include <executorch/extension/llm/runner/wav_loader.h>
3031
#include <executorch/extension/llm/tokenizers/third-party/llama.cpp-unicode/include/unicode.h>
3132
#include <executorch/extension/module/module.h>
3233
#include <executorch/extension/tensor/tensor_ptr_maker.h>
3334
#include <executorch/runtime/core/evalue.h>
35+
#include <executorch/runtime/core/exec_aten/util/scalar_type_util.h>
3436
#include <executorch/runtime/platform/log.h>
3537
#ifdef ET_BUILD_METAL
3638
#include <executorch/backends/apple/metal/runtime/stats.h>
@@ -108,6 +110,39 @@ TimestampOutputMode parse_timestamp_output_mode(const std::string& raw_arg) {
108110
"'. Expected: token, word, segment, all.");
109111
}
110112

113+
// Helper to get expected scalar type for a method input
114+
::executorch::runtime::Result<::executorch::aten::ScalarType>
115+
get_input_scalar_type(
116+
Module& model,
117+
const char* method_name,
118+
size_t input_index) {
119+
auto method_meta_result = model.method_meta(method_name);
120+
if (!method_meta_result.ok()) {
121+
ET_LOG(Error, "Failed to get method metadata for %s", method_name);
122+
return method_meta_result.error();
123+
}
124+
auto method_meta = method_meta_result.get();
125+
if (method_meta.num_inputs() <= input_index) {
126+
ET_LOG(
127+
Error,
128+
"Method %s has %zu inputs, but requested index %zu",
129+
method_name,
130+
method_meta.num_inputs(),
131+
input_index);
132+
return ::executorch::runtime::Error::InvalidArgument;
133+
}
134+
auto input_meta_result = method_meta.input_tensor_meta(input_index);
135+
if (input_meta_result.error() != ::executorch::runtime::Error::Ok) {
136+
ET_LOG(
137+
Error,
138+
"Failed to get input tensor metadata for %s[%zu]",
139+
method_name,
140+
input_index);
141+
return input_meta_result.error();
142+
}
143+
return input_meta_result.get().scalar_type();
144+
}
145+
111146
std::vector<Token> greedy_decode_executorch(
112147
Module& model,
113148
const ::executorch::aten::Tensor& f_proj,
@@ -118,27 +153,49 @@ std::vector<Token> greedy_decode_executorch(
118153
int64_t max_symbols_per_step = 10) {
119154
std::vector<Token> hypothesis;
120155

121-
// Shape: [1, time_steps, joint_hidden]
122-
auto f_proj_sizes = f_proj.sizes();
123-
int64_t time_steps = f_proj_sizes[1];
124-
int64_t proj_dim = f_proj_sizes[2];
156+
// Shape: [1, T, joint_hidden]
157+
size_t proj_dim = static_cast<size_t>(f_proj.sizes()[2]);
125158

126-
// Initialize LSTM state
127-
std::vector<float> h_data(num_rnn_layers * 1 * pred_hidden, 0.0f);
128-
std::vector<float> c_data(num_rnn_layers * 1 * pred_hidden, 0.0f);
159+
// Get expected dtype for decoder_step h and c inputs (indices 1 and 2)
160+
auto h_dtype_result = get_input_scalar_type(model, "decoder_step", 1);
161+
if (!h_dtype_result.ok()) {
162+
return hypothesis;
163+
}
164+
auto c_dtype_result = get_input_scalar_type(model, "decoder_step", 2);
165+
if (!c_dtype_result.ok()) {
166+
return hypothesis;
167+
}
168+
auto h_dtype = h_dtype_result.get();
169+
auto c_dtype = c_dtype_result.get();
170+
171+
ET_LOG(
172+
Info,
173+
"Decoder h dtype: %s, c dtype: %s",
174+
::executorch::runtime::toString(h_dtype),
175+
::executorch::runtime::toString(c_dtype));
176+
177+
// Calculate buffer sizes based on dtype
178+
size_t h_elem_size = ::executorch::runtime::elementSize(h_dtype);
179+
size_t c_elem_size = ::executorch::runtime::elementSize(c_dtype);
180+
size_t num_elements =
181+
static_cast<size_t>(num_rnn_layers) * static_cast<size_t>(pred_hidden);
182+
183+
// Initialize LSTM state with zeros (using byte buffers for dtype flexibility)
184+
std::vector<uint8_t> h_data(num_elements * h_elem_size, 0);
185+
std::vector<uint8_t> c_data(num_elements * c_elem_size, 0);
129186

130187
auto h = from_blob(
131188
h_data.data(),
132189
{static_cast<::executorch::aten::SizesType>(num_rnn_layers),
133190
1,
134191
static_cast<::executorch::aten::SizesType>(pred_hidden)},
135-
::executorch::aten::ScalarType::Float);
192+
h_dtype);
136193
auto c = from_blob(
137194
c_data.data(),
138195
{static_cast<::executorch::aten::SizesType>(num_rnn_layers),
139196
1,
140197
static_cast<::executorch::aten::SizesType>(pred_hidden)},
141-
::executorch::aten::ScalarType::Float);
198+
c_dtype);
142199

143200
// Prime the decoder with SOS (= blank_id) to match NeMo TDT label-looping:
144201
// - SOS is defined as blank:
@@ -159,41 +216,61 @@ std::vector<Token> greedy_decode_executorch(
159216
auto g_proj_init = init_outputs[0].toTensor();
160217
auto new_h_init = init_outputs[1].toTensor();
161218
auto new_c_init = init_outputs[2].toTensor();
162-
std::memcpy(
163-
h_data.data(),
164-
new_h_init.const_data_ptr<float>(),
165-
h_data.size() * sizeof(float));
166-
std::memcpy(
167-
c_data.data(),
168-
new_c_init.const_data_ptr<float>(),
169-
c_data.size() * sizeof(float));
219+
std::memcpy(h_data.data(), new_h_init.const_data_ptr(), h_data.size());
220+
std::memcpy(c_data.data(), new_c_init.const_data_ptr(), c_data.size());
170221

171-
// Copy g_proj data for reuse
172-
std::vector<float> g_proj_data(
173-
g_proj_init.const_data_ptr<float>(),
174-
g_proj_init.const_data_ptr<float>() + g_proj_init.numel());
222+
// Get expected dtype for joint inputs (f and g at indices 0 and 1)
223+
auto f_dtype_result = get_input_scalar_type(model, "joint", 0);
224+
if (!f_dtype_result.ok()) {
225+
return hypothesis;
226+
}
227+
auto g_dtype_result = get_input_scalar_type(model, "joint", 1);
228+
if (!g_dtype_result.ok()) {
229+
return hypothesis;
230+
}
231+
auto f_dtype = f_dtype_result.get();
232+
auto g_dtype = g_dtype_result.get();
233+
234+
ET_LOG(
235+
Info,
236+
"Joint f dtype: %s, g dtype: %s",
237+
::executorch::runtime::toString(f_dtype),
238+
::executorch::runtime::toString(g_dtype));
239+
240+
size_t f_elem_size = ::executorch::runtime::elementSize(f_dtype);
241+
size_t g_elem_size = ::executorch::runtime::elementSize(g_dtype);
242+
243+
// Copy g_proj data for reuse (using byte buffer for dtype flexibility)
244+
size_t g_proj_num_bytes =
245+
static_cast<size_t>(g_proj_init.numel()) * g_elem_size;
246+
std::vector<uint8_t> g_proj_data(g_proj_num_bytes);
247+
std::memcpy(
248+
g_proj_data.data(), g_proj_init.const_data_ptr(), g_proj_num_bytes);
175249

176250
int64_t t = 0;
177251
int64_t symbols_on_frame = 0;
252+
const uint8_t* f_proj_ptr =
253+
static_cast<const uint8_t*>(f_proj.const_data_ptr());
254+
size_t f_t_num_bytes = proj_dim * f_elem_size;
178255

179256
// Scan over encoder output
180257
while (t < encoder_len) {
181258
// Get encoder frame at time t: f_proj[:, t:t+1, :]
182-
const float* f_proj_ptr = f_proj.const_data_ptr<float>();
259+
std::vector<uint8_t> f_t_data(f_t_num_bytes);
260+
std::memcpy(
261+
f_t_data.data(),
262+
f_proj_ptr + static_cast<size_t>(t) * f_t_num_bytes,
263+
f_t_num_bytes);
183264

184-
std::vector<float> f_t_data(1 * 1 * proj_dim);
185-
for (int64_t d = 0; d < proj_dim; d++) {
186-
f_t_data[d] = f_proj_ptr[t * proj_dim + d];
187-
}
188265
auto f_t = from_blob(
189266
f_t_data.data(),
190267
{1, 1, static_cast<::executorch::aten::SizesType>(proj_dim)},
191-
::executorch::aten::ScalarType::Float);
268+
f_dtype);
192269

193270
auto g_proj = from_blob(
194271
g_proj_data.data(),
195272
{1, 1, static_cast<::executorch::aten::SizesType>(proj_dim)},
196-
::executorch::aten::ScalarType::Float);
273+
g_dtype);
197274

198275
auto joint_result = model.execute(
199276
"joint", std::vector<::executorch::runtime::EValue>{f_t, g_proj});
@@ -230,18 +307,10 @@ std::vector<Token> greedy_decode_executorch(
230307
auto new_c = outputs[2].toTensor();
231308

232309
// Update h, c, and g_proj
310+
std::memcpy(h_data.data(), new_h.const_data_ptr(), h_data.size());
311+
std::memcpy(c_data.data(), new_c.const_data_ptr(), c_data.size());
233312
std::memcpy(
234-
h_data.data(),
235-
new_h.const_data_ptr<float>(),
236-
h_data.size() * sizeof(float));
237-
std::memcpy(
238-
c_data.data(),
239-
new_c.const_data_ptr<float>(),
240-
c_data.size() * sizeof(float));
241-
std::memcpy(
242-
g_proj_data.data(),
243-
new_g_proj.const_data_ptr<float>(),
244-
g_proj_data.size() * sizeof(float));
313+
g_proj_data.data(), new_g_proj.const_data_ptr(), g_proj_data.size());
245314

246315
t += dur;
247316

0 commit comments

Comments
 (0)