Conversation

@larryliu0820 larryliu0820 (Contributor) commented Jan 23, 2026

Enable both the export script and the runner to support bfloat16, and change CI to run on bfloat16 by default.

Also added a `parakeet-cuda-debug` mode.
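
For context, a minimal sketch of what a bfloat16 export path can look like (the loader helper, module names, and input shapes below are illustrative assumptions, not the exact code in this PR):

import torch
from executorch.exir import to_edge

model = load_parakeet_module().eval()   # assumption: helper that builds one Parakeet module
model = model.to(torch.bfloat16)        # converts all parameters and buffers to bf16
example_inputs = (torch.randn(1, 80, 100, dtype=torch.bfloat16),)  # illustrative shape

exported = torch.export.export(model, example_inputs)
et_program = to_edge(exported).to_executorch()   # lower to an ExecuTorch program
with open("parakeet_bf16.pte", "wb") as f:
    f.write(et_program.buffer)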

@pytorch-bot pytorch-bot bot commented Jan 23, 2026

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/16821

Note: Links to docs will display an error until the docs builds have been completed.

❌ 1 New Failure, 39 Pending, 1 Unrelated Failure

As of commit 578aea9 with merge base b928496:

NEW FAILURE - The following job has failed:

FLAKY - The following job failed but was likely due to flakiness present on trunk:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@meta-cla meta-cla bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Jan 23, 2026
@larryliu0820 larryliu0820 temporarily deployed to upload-benchmark-results January 23, 2026 08:46 — with GitHub Actions Inactive
@larryliu0820 larryliu0820 force-pushed the parakeet_bf16 branch 3 times, most recently from 10ab53b to 0fffa68 on January 23, 2026 18:32
@larryliu0820 larryliu0820 added the release notes: desktop for desktop/laptop workstream label Jan 23, 2026
@Gasoonjia Gasoonjia (Contributor) left a comment

Thx man for such efficient support!

return input_meta_result.get().scalar_type();
}

std::vector<Token> greedy_decode_executorch(
Contributor

I'm a little bit lost in the logic here: based on my understanding, greedy_decode_executorch should only be for the joint and decoder, not the encoder. And the encoder takes the preprocessor's output as input, which is always fp32 now. Since we exported the encoder as the target dtype (e.g. bf16), how do we make the bf16 encoder consume the fp32 preprocessor's output?

Contributor Author

> and encoder takes preprocessor's output as input, which is always fp32 now

Not true: if we pass dtype == bfloat16, the preprocessor will take in a float input and give a bfloat16 result.

@Gasoonjia Gasoonjia (Contributor) Jan 26, 2026

Hmm, not sure if I misunderstood something, but the export script says:
# Preprocessor always uses float32 - runner converts output to encoder's dtype
but I didn't find the code in the runner that does this type conversion.
If the preprocessor gives bfloat16 output directly:

  1. Perhaps update the export doc?
  2. How does the preprocessor know which dtype to produce? The export configs for the preprocessor seem to be the same across different dtypes.

Contributor Author

Let me dive a bit deeper into the model. I can take a look at the graph to see if it changes the dtype of the output.
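
One way to check (a minimal sketch; the wrapper class, example inputs, and variable names are assumptions taken from the export script context, not the exact code): export the preprocessor on its own and print its graph to see whether a cast to bfloat16 appears on the output path.

import torch

wrapper = PreprocessorWrapper(model.preprocessor).eval()   # assumed to exist in the export script
sample_audio = torch.randn(max_audio_samples)
sample_length = torch.tensor([max_audio_samples], dtype=torch.int64)

ep = torch.export.export(wrapper, (sample_audio, sample_length))
print(ep.graph)                    # FX graph, like the dump below
ep.graph_module.print_readable()   # same graph with generated Python source per submodule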

Contributor Author

Graph:

graph():
    %b_preprocessor_dtype_sentinel_tensor : [num_users=0] = placeholder[target=b_preprocessor_dtype_sentinel_tensor]
    %b_preprocessor_featurizer_window : [num_users=1] = placeholder[target=b_preprocessor_featurizer_window]
    %b_preprocessor_featurizer_fb : [num_users=1] = placeholder[target=b_preprocessor_featurizer_fb]
    %audio : [num_users=2] = placeholder[target=audio]
    %length : [num_users=1] = placeholder[target=length]
    %sym_size_int_4 : [num_users=1] = call_function[target=torch.ops.aten.sym_size.int](args = (%audio, 0), kwargs = {})
    %unsqueeze : [num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%audio, 0), kwargs = {})
    %submod_5 : [num_users=1] = get_attr[target=submod_1]
    %to : [num_users=1] = call_function[target=torch.ops.higher_order.wrap_with_set_grad_enabled](args = (False, %submod_5, %unsqueeze), kwargs = {})
    %getitem_4 : [num_users=1] = call_function[target=operator.getitem](args = (%to, 0), kwargs = {})
    %submod_6 : [num_users=1] = get_attr[target=submod_2]
    %wrap_with_set_grad_enabled : [num_users=2] = call_function[target=torch.ops.higher_order.wrap_with_set_grad_enabled](args = (False, %submod_6, %length, %sym_size_int_4, %getitem_4, %b_preprocessor_featurizer_window, %b_preprocessor_featurizer_fb), kwargs = {})
    %masked_fill_2 : [num_users=1] = call_function[target=operator.getitem](args = (%wrap_with_set_grad_enabled, 0), kwargs = {})
    %where : [num_users=1] = call_function[target=operator.getitem](args = (%wrap_with_set_grad_enabled, 1), kwargs = {})
    %submod_7 : [num_users=1] = get_attr[target=submod_3]
    %to_6 : [num_users=1] = call_function[target=torch.ops.higher_order.wrap_with_set_grad_enabled](args = (False, %submod_7, %masked_fill_2), kwargs = {})
    %getitem_5 : [num_users=1] = call_function[target=operator.getitem](args = (%to_6, 0), kwargs = {})
    return (getitem_5, where)

Looking at the graph, the key is in those placeholder buffers:

  • b_preprocessor_featurizer_window - the window function (e.g., Hann window)
  • b_preprocessor_featurizer_fb - the mel filterbank matrix

When we call model.to(torch.bfloat16), it converts all parameters and buffers in the model, including the preprocessor's window and filterbank tensors.
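
A quick standalone check of that behavior (a toy module, not the actual Parakeet preprocessor):

import torch
import torch.nn as nn

class ToyPreprocessor(nn.Module):
    def __init__(self):
        super().__init__()
        # Buffers playing the role of featurizer.window and featurizer.fb
        self.register_buffer("window", torch.hann_window(512))
        self.register_buffer("fb", torch.randn(257, 80))

m = ToyPreprocessor().to(torch.bfloat16)
print(m.window.dtype, m.fb.dtype)   # torch.bfloat16 torch.bfloat16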

The mel spectrogram computation does something like:

# Simplified view of what happens inside:
windowed = audio_frame * window      # window is bf16 → result is bf16
spectrum = torch.fft.rfft(windowed)  # stays bf16
mel = spectrum @ filterbank          # filterbank is bf16 → result is bf16

PyTorch's type promotion rules mean that when our bf16 audio interacts with bf16 buffers, the output stays bf16.
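
A small check of the promotion rule (illustrative only): two bf16 operands stay bf16, while mixing fp32 with bf16 would promote the result back to fp32.

import torch

a = torch.randn(4, dtype=torch.bfloat16)   # stand-in for the cast audio frame
w = torch.randn(4, dtype=torch.bfloat16)   # stand-in for the converted window buffer
print((a * w).dtype)                                       # torch.bfloat16
print(torch.promote_types(torch.float32, torch.bfloat16))  # torch.float32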

@larryliu0820 larryliu0820 temporarily deployed to upload-benchmark-results January 23, 2026 19:46 — with GitHub Actions Inactive
@mergennachin (Contributor)

@manuelcandales @larryliu0820 it would be good to test bf16 for Metal too.

@Gasoonjia Gasoonjia (Contributor) left a comment

Thanks for the detailed update! LGTM now!
nit: mind changing the PR title to remove float16?

preprocessor_wrapper = PreprocessorWrapper(model.preprocessor)
preprocessor_wrapper.eval()
sample_audio = torch.randn(max_audio_samples)
# Preprocessor always uses float32 - runner converts output to encoder's dtype
Contributor

Maybe remove this comment? It looks like torch.export handles the dtype conversion instead of the runner.

@larryliu0820 larryliu0820 temporarily deployed to upload-benchmark-results January 27, 2026 00:28 — with GitHub Actions Inactive
@larryliu0820 larryliu0820 changed the title from "Add bf16 and float16 support for Parakeet" to "Add bfloat16 support for Parakeet" Jan 27, 2026
Enable both export script and runner to support bfloat16 and float16. Changing CI to run on bfloat16 by default.

Also added a `parakeet-cuda-debug` mode
@larryliu0820 larryliu0820 temporarily deployed to upload-benchmark-results January 27, 2026 05:06 — with GitHub Actions Inactive
@larryliu0820 larryliu0820 merged commit 9772b07 into main Jan 27, 2026
322 of 324 checks passed
@larryliu0820 larryliu0820 deleted the parakeet_bf16 branch January 27, 2026 06:44