Add bfloat16 support for Parakeet #16821
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/16821
Note: Links to docs will display an error until the docs builds have been completed.
❌ 1 New Failure, 39 Pending, 1 Unrelated Failure as of commit 578aea9 with merge base b928496.
NEW FAILURE - one job has failed.
FLAKY - one job failed but was likely due to flakiness present on trunk.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
Force-pushed ef8cdb8 to 83ed829
Force-pushed 10ab53b to 0fffa68
Gasoonjia left a comment:
Thx man for such efficient support!
  return input_meta_result.get().scalar_type();
}

std::vector<Token> greedy_decode_executorch(
I'm a little bit lost in the logic here: based on my understanding, greedy_decode_executorch should only be used for the joint and decoder, not the encoder. And the encoder takes the preprocessor's output as input, which is always fp32 now. Since we export the encoder at the target dtype (e.g. bf16), how does the bf16 encoder consume the fp32 preprocessor output?
> and encoder takes preprocessor's output as input, which is always fp32 now

Not true: if we pass dtype == bfloat16, the preprocessor takes a float input and gives a bfloat16 result.
Hmm, not sure if I misunderstood something, but the export script says:
# Preprocessor always uses float32 - runner converts output to encoder's dtype
but I didn't find the code in the runner for that type conversion.
If the preprocessor gives bfloat16 output directly:
- perhaps update the export doc?
- how does the preprocessor know that? The export configs for the preprocessor seem to be the same across different dtypes.
Let me dive a bit deeper into the model. I can take a look at the graph to see if it changes the dtype of the output
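For reference, here is a minimal sketch of how such a graph dump can be produced with torch.export; the TinyFeaturizer below is a stand-in module, not the actual Parakeet preprocessor wrapper from this PR:

import torch

class TinyFeaturizer(torch.nn.Module):
    """Stand-in for the preprocessor: multiply audio by a registered window buffer."""
    def __init__(self):
        super().__init__()
        self.register_buffer("window", torch.hann_window(16))

    def forward(self, audio):
        return audio * self.window

mod = TinyFeaturizer().to(torch.bfloat16)          # Module.to converts the buffer to bf16
ep = torch.export.export(mod, (torch.randn(16),))  # sample audio stays float32
print(ep.graph)  # buffers show up as placeholders, along with any dtype casts the trace captured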
Graph:
graph():
%b_preprocessor_dtype_sentinel_tensor : [num_users=0] = placeholder[target=b_preprocessor_dtype_sentinel_tensor]
%b_preprocessor_featurizer_window : [num_users=1] = placeholder[target=b_preprocessor_featurizer_window]
%b_preprocessor_featurizer_fb : [num_users=1] = placeholder[target=b_preprocessor_featurizer_fb]
%audio : [num_users=2] = placeholder[target=audio]
%length : [num_users=1] = placeholder[target=length]
%sym_size_int_4 : [num_users=1] = call_function[target=torch.ops.aten.sym_size.int](args = (%audio, 0), kwargs = {})
%unsqueeze : [num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%audio, 0), kwargs = {})
%submod_5 : [num_users=1] = get_attr[target=submod_1]
%to : [num_users=1] = call_function[target=torch.ops.higher_order.wrap_with_set_grad_enabled](args = (False, %submod_5, %unsqueeze), kwargs = {})
%getitem_4 : [num_users=1] = call_function[target=operator.getitem](args = (%to, 0), kwargs = {})
%submod_6 : [num_users=1] = get_attr[target=submod_2]
%wrap_with_set_grad_enabled : [num_users=2] = call_function[target=torch.ops.higher_order.wrap_with_set_grad_enabled](args = (False, %submod_6, %length, %sym_size_int_4, %getitem_4, %b_preprocessor_featurizer_window, %b_preprocessor_featurizer_fb), kwargs = {})
%masked_fill_2 : [num_users=1] = call_function[target=operator.getitem](args = (%wrap_with_set_grad_enabled, 0), kwargs = {})
%where : [num_users=1] = call_function[target=operator.getitem](args = (%wrap_with_set_grad_enabled, 1), kwargs = {})
%submod_7 : [num_users=1] = get_attr[target=submod_3]
%to_6 : [num_users=1] = call_function[target=torch.ops.higher_order.wrap_with_set_grad_enabled](args = (False, %submod_7, %masked_fill_2), kwargs = {})
%getitem_5 : [num_users=1] = call_function[target=operator.getitem](args = (%to_6, 0), kwargs = {})
return (getitem_5, where)
Looking at the graph, the key is in those placeholder buffers:
- b_preprocessor_featurizer_window - the window function (e.g., Hann window)
- b_preprocessor_featurizer_fb - the mel filterbank matrix
When we call model.to(torch.bfloat16), it converts all parameters and buffers in the model, including the preprocessor's window and filterbank tensors.
The mel spectrogram computation does something like:
# Simplified view of what happens inside:
windowed = audio_frame * window # window is bf16 → result is bf16
spectrum = torch.fft.rfft(windowed) # stays bf16
mel = spectrum @ filterbank # filterbank is bf16 → result is bf16

PyTorch's type promotion rules mean that when our bf16 audio interacts with bf16 buffers, the output stays bf16.
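To make that concrete, here is a small self-contained sketch (the Featurizer below is a stand-in for NeMo's preprocessor, and the FFT step is replaced by a placeholder tensor, since the point is only the dtype behavior of buffers under Module.to and type promotion):

import torch

class Featurizer(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # Registered buffers (not Parameters) are also converted by Module.to(dtype)
        self.register_buffer("window", torch.hann_window(400))
        self.register_buffer("fb", torch.randn(201, 80))

feat = Featurizer().to(torch.bfloat16)
print(feat.window.dtype, feat.fb.dtype)            # torch.bfloat16 torch.bfloat16

frame = torch.randn(400, dtype=torch.bfloat16)     # bf16 audio frame
windowed = frame * feat.window                     # bf16 * bf16 -> bf16
spectrum = torch.randn(201, dtype=torch.bfloat16)  # placeholder for the FFT magnitudes
mel = spectrum @ feat.fb                           # bf16 @ bf16 stays bf16
print(mel.dtype)                                   # torch.bfloat16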
@manuelcandales @larryliu0820 it would be good to test bf16 for Metal too.
Force-pushed 0fffa68 to e6a0602
Gasoonjia left a comment:
Thanks for the detailed update! LGTM now!
nit: mind changing PR title to remove float16?
preprocessor_wrapper = PreprocessorWrapper(model.preprocessor)
preprocessor_wrapper.eval()
sample_audio = torch.randn(max_audio_samples)
# Preprocessor always uses float32 - runner converts output to encoder's dtype
Maybe remove this comment? It looks like torch.export handles the dtype change instead of the runner.
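One hedged way to confirm that torch.export, not the runner, carries the dtype (again against a tiny stand-in module rather than the real export script): the exported program's state_dict records the buffers at whatever dtype they held at export time.

import torch

class TinyPreproc(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.register_buffer("fb", torch.randn(201, 80))   # stand-in mel filterbank

    def forward(self, x):
        return x @ self.fb

mod = TinyPreproc().to(torch.bfloat16)
ep = torch.export.export(mod, (torch.randn(4, 201, dtype=torch.bfloat16),))
for name, t in ep.state_dict.items():
    print(name, t.dtype)   # the filterbank buffer is recorded as torch.bfloat16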
Force-pushed e6a0602 to 68950d9
Enable both the export script and the runner to support bfloat16 and float16. Changing CI to run on bfloat16 by default. Also added a `parakeet-cuda-debug` mode.
Force-pushed 68950d9 to 578aea9
Enable both the export script and the runner to support bfloat16. Changing CI to run on bfloat16 by default.
Also added a `parakeet-cuda-debug` mode.