diff --git a/backends/qualcomm/_passes/canonicalize_conv.py b/backends/qualcomm/_passes/canonicalize_conv.py index 8836ed44328..08d34ced8ba 100644 --- a/backends/qualcomm/_passes/canonicalize_conv.py +++ b/backends/qualcomm/_passes/canonicalize_conv.py @@ -115,86 +115,98 @@ def call(self, graph_module: torch.fx.GraphModule): ) with graph_module.graph.inserting_after(qdq_node_after_unsqueeze): - filter_arg = node.args[1] - filter_node = ( - filter_arg - if filter_arg.op == "placeholder" - else node.args[1].args[0] - ) - filter_node.meta["val"] = ( - filter_node.meta["val"].unsqueeze(2).contiguous() - ) - filter_tensor = get_parameter( - filter_node, self.edge_program - ).unsqueeze(2) - set_parameter( - ( - torch.nn.Parameter(filter_tensor) - if filter_tensor.dtype == torch.float - else filter_tensor - ), - filter_node, - self.edge_program, - ) - - num_args = len(node.args) - - bias_node = node.args[2] if num_args > 2 else None - stride = [1] + node.args[3] if num_args > 3 else [1, 1] - padding = [0] + node.args[4] if num_args > 4 else [0, 0] - if node.target == torch.ops.aten.conv1d.default: - dilation = [1] + node.args[5] if num_args > 5 else [1, 1] - groups = node.args[6] if num_args > 6 else 1 - conv_args = ( - qdq_node_after_unsqueeze, - node.args[1], - bias_node, - stride, - padding, - dilation, - groups, + # conv2d must be inserted before conv1d in the graph to preserve correct + # topological ordering. This is required due to conv-bn fusion: when conv1d + # has no bias, the fused bias (from batchnorm) is introduced as a new node, + # and its corresponding dq (dequantize) node must appear before conv2d in + # the execution order. 
+ with graph_module.graph.inserting_before(node): + filter_arg = node.args[1] + filter_node = ( + filter_arg + if filter_arg.op == "placeholder" + else node.args[1].args[0] ) - else: - output_padding = ( - [0] + node.args[5] if num_args > 5 else [0, 0] + filter_node.meta["val"] = filter_node.meta["val"].unsqueeze( + 2 ) - groups = node.args[6] if num_args > 6 else 1 - dilation = [1] + node.args[7] if num_args > 7 else [1, 1] - conv_args = ( - qdq_node_after_unsqueeze, - node.args[1], - bias_node, - stride, - padding, - output_padding, - groups, - dilation, - ) - conv2d_node = graph.create_node( - "call_function", - self.conv1d_op_map[node.target], - conv_args, - ) - conv2d_node.meta = copy_meta( - node.meta, lambda m: {**m, "val": m["val"].unsqueeze(2)} - ) - qdq_node_after_conv2d = append_qdq( - graph_module=graph_module, - node=conv2d_node, - qdq_node=list(node.users)[0], - ) - - with graph_module.graph.inserting_after(qdq_node_after_conv2d): - squeeze_op = torch.ops.aten.squeeze_copy.dims - squeeze_node = graph.create_node( - "call_function", - squeeze_op, + filter_tensor = get_parameter( + filter_node, self.edge_program + ).unsqueeze(2) + set_parameter( ( - qdq_node_after_conv2d, - [2], + torch.nn.Parameter(filter_tensor) + if filter_tensor.dtype == torch.float + else filter_tensor ), + filter_node, + self.edge_program, + ) + + num_args = len(node.args) + + bias_node = node.args[2] if num_args > 2 else None + stride = [1] + node.args[3] if num_args > 3 else [1, 1] + padding = [0] + node.args[4] if num_args > 4 else [0, 0] + if node.target == torch.ops.aten.conv1d.default: + dilation = ( + [1] + node.args[5] if num_args > 5 else [1, 1] + ) + groups = node.args[6] if num_args > 6 else 1 + conv_args = ( + qdq_node_after_unsqueeze, + node.args[1], + bias_node, + stride, + padding, + dilation, + groups, + ) + else: + output_padding = ( + [0] + node.args[5] if num_args > 5 else [0, 0] + ) + groups = node.args[6] if num_args > 6 else 1 + dilation = ( + [1] + 
node.args[7] if num_args > 7 else [1, 1] + ) + conv_args = ( + qdq_node_after_unsqueeze, + node.args[1], + bias_node, + stride, + padding, + output_padding, + groups, + dilation, + ) + conv2d_node = graph.create_node( + "call_function", + self.conv1d_op_map[node.target], + conv_args, + ) + conv2d_node.meta = copy_meta( + node.meta, lambda m: {**m, "val": m["val"].unsqueeze(2)} ) - squeeze_node.meta = copy_meta(node.meta) + qdq_node_after_conv2d = append_qdq( + graph_module=graph_module, + node=conv2d_node, + qdq_node=list(node.users)[0], + ) + + with graph_module.graph.inserting_after( + qdq_node_after_conv2d + ): + squeeze_op = torch.ops.aten.squeeze_copy.dims + squeeze_node = graph.create_node( + "call_function", + squeeze_op, + ( + qdq_node_after_conv2d, + [2], + ), + ) + squeeze_node.meta = copy_meta(node.meta) for user in node.users.copy(): user.replace_input_with(node, squeeze_node) diff --git a/backends/qualcomm/tests/models.py b/backends/qualcomm/tests/models.py index da6b4bec66c..08a425147ad 100644 --- a/backends/qualcomm/tests/models.py +++ b/backends/qualcomm/tests/models.py @@ -490,6 +490,24 @@ def example_inputs(self): } +class Conv1dBn(torch.nn.Module): + def __init__(self, bias=True): + super().__init__() + self.conv = torch.nn.Conv1d( + in_channels=2048, + out_channels=2048, + kernel_size=15, + groups=2048, + bias=bias, + ) + self.batch_norm = torch.nn.BatchNorm1d(2048) + + def forward(self, x): + x = self.conv(x) + x = self.batch_norm(x) + return x + + class Conv1dSequential(torch.nn.Module): def __init__(self, bias=True): super().__init__() diff --git a/backends/qualcomm/tests/test_qnn_delegate.py b/backends/qualcomm/tests/test_qnn_delegate.py index 4e23f43c2ea..52b7c9eff9c 100644 --- a/backends/qualcomm/tests/test_qnn_delegate.py +++ b/backends/qualcomm/tests/test_qnn_delegate.py @@ -377,6 +377,13 @@ def test_qnn_backend_conv1d(self): with self.subTest(i=i): self.lower_module_and_test_output(module, sample_input) + def 
test_qnn_conv1d_batch_norm(self): + modules = [Conv1dBn(), Conv1dBn(bias=False)] # noqa: F405 + sample_input = (torch.randn([1, 2048, 858]),) + for i, module in enumerate(modules): + with self.subTest(i=i): + self.lower_module_and_test_output(module, sample_input) + def test_qnn_backend_conv2d(self): modules = [Conv2dSequential(), Conv2dSequential(bias=False)] # noqa: F405 sample_input = (torch.randn([1, 1, 3, 3]),) @@ -2637,6 +2644,14 @@ def test_qnn_backend_conv1d(self): module = self.get_qdq_module(module, sample_input) self.lower_module_and_test_output(module, sample_input) + def test_qnn_conv1d_batch_norm(self): + modules = [Conv1dBn(), Conv1dBn(bias=False)] # noqa: F405 + sample_input = (torch.randn([1, 2048, 858]),) + for i, module in enumerate(modules): + with self.subTest(i=i): + module = self.get_qdq_module(module, sample_input) + self.lower_module_and_test_output(module, sample_input) + def test_qnn_backend_conv2d(self): modules = [Conv2dSequential(), Conv2dSequential(bias=False)] # noqa: F405 sample_input = (torch.randn([1, 1, 3, 3]),) @@ -6870,6 +6885,11 @@ class MLLMSpecs: tok_embedding_pte_size: float decoder_pte_size: float + @dataclass(frozen=True) + class ALMSpecs(MLLMSpecs): + audio_path: str + golden_audio_feature: str + @dataclass(frozen=True) class VLMSpecs(MLLMSpecs): image_path: str @@ -6877,6 +6897,18 @@ class VLMSpecs(MLLMSpecs): # TODO: refactor to support different backends def setUp(self): + self.alm_specs = { + "granite_speech_3_3-2b": TestExampleMultimodalityScript.ALMSpecs( + max_seq_len=512, + sm8650_token_rate=5, + sm8750_token_rate=8, + encoder_pte_size=900_000_000, # 900MB + tok_embedding_pte_size=240_000_000, # 240MB + decoder_pte_size=3_000_000_000, # 3GB + audio_path="https://huggingface.co/ibm-granite/granite-speech-3.3-2b/resolve/main/10226_10111_000000.wav?download=true", # Audio content: after his nap,... 
+ golden_audio_feature="after his nap,", + ), + } self.vlm_specs = { "smolvlm_500m_instruct": TestExampleMultimodalityScript.VLMSpecs( max_seq_len=128, @@ -6900,6 +6932,96 @@ def setUp(self): ), } + def test_static_asr(self): + if not self.required_envs([self.model_name]): + self.skipTest("missing required envs") + + if self.enable_x86_64: + # Running on host is extremely slow for large models, so we skip this check to avoid timeouts. + # Please verify the output on the actual device instead. + self.skipTest( + "Skipping the check for the static ASR model on x86 due to long execution time." + ) + + alm_specs: TestExampleMultimodalityScript.ALMSpecs = self.alm_specs[ + self.model_name + ] + prompt = "can you transcribe the speech into a written format?" + audio_path = alm_specs.audio_path + cmds = [ + "python", + f"{self.executorch_root}/examples/qualcomm/oss_scripts/llama/llama.py", + "--artifact", + self.artifact_dir, + "--build_folder", + self.build_folder, + "--model", + self.model, + "--ip", + self.ip, + "--port", + str(self.port), + "--prompt", + prompt, + "--audio_path", + audio_path, + "--temperature", + "0", + "--decoder_model", + f"{self.model_name}", + "--model_mode", + "kv", + "--max_seq_len", + f"{alm_specs.max_seq_len}", + ] + if self.compile_only: + cmds.extend(["--compile_only"]) + elif self.device: + cmds.extend(["--device", self.device]) + if self.host: + cmds.extend(["--host", self.host]) + if self.pre_gen_pte: + cmds.extend(["--pre_gen_pte", self.pre_gen_pte]) + + p = subprocess.Popen(cmds, stdout=subprocess.DEVNULL) + with Listener((self.ip, self.port)) as listener: + conn = listener.accept() + p.communicate() + msg = json.loads(conn.recv()) + if "Error" in msg: + self.fail(msg["Error"]) + else: + if not self.compile_only: + model_out = msg["result"][0] + self.assertTrue( + alm_specs.golden_audio_feature in model_out.lower(), + f"Expected Output contains feature: '{alm_specs.golden_audio_feature}' Actual Output: '{model_out}'", + ) + 
print(f"Audio Path: {audio_path}") + print(f"Query: {prompt}") + print(f"Answer: {model_out}") + + encoder_pte_size = msg["audio_encoder_pte_size"] + tok_embedding_pte_size = msg["tok_embedding_pte_size"] + decoder_pte_size = msg["pte_size"] + self.assertLessEqual(encoder_pte_size, alm_specs.encoder_pte_size) + self.assertLessEqual( + tok_embedding_pte_size, alm_specs.tok_embedding_pte_size + ) + self.assertLessEqual(decoder_pte_size, alm_specs.decoder_pte_size) + print(f"Encoder PTE Size: {encoder_pte_size} bytes") + print(f"Token Embedding PTE Size: {tok_embedding_pte_size} bytes") + print(f"Text Decoder PTE Size: {decoder_pte_size} bytes") + + attr_name = f"{self.model.lower()}_token_rate" + if not self.compile_only and hasattr(alm_specs, attr_name): + device_inference_speed = msg["inference_speed"] + expected_inference_speed = getattr(alm_specs, attr_name) + print(f"Prompt Evaluation: {device_inference_speed} tokens/second") + self.assertGreaterEqual( + device_inference_speed, expected_inference_speed + ) + def test_static_vlm(self): if not self.required_envs([self.model_name]): self.skipTest("missing required envs") @@ -6964,7 +7086,7 @@ def test_static_vlm(self): print(f"Query: {prompt}") print(f"Answer: {model_out}") if not self.enable_x86_64: - encoder_pte_size = msg["encoder_pte_size"] + encoder_pte_size = msg["vision_encoder_pte_size"] tok_embedding_pte_size = msg["tok_embedding_pte_size"] decoder_pte_size = msg["pte_size"] self.assertLessEqual(encoder_pte_size, vlm_specs.encoder_pte_size) diff --git a/examples/models/granite_speech/BUCK b/examples/models/granite_speech/BUCK new file mode 100644 index 00000000000..9660c0cad90 --- /dev/null +++ b/examples/models/granite_speech/BUCK @@ -0,0 +1,24 @@ +load("@fbcode_macros//build_defs:build_file_migration.bzl", "fbcode_target", "non_fbcode_target") +load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime") + +oncall("executorch") + +fbcode_target(_kind = runtime.python_library, + name = 
"granite_speech", + srcs = [ + "__init__.py", + "convert_weights.py", + ], + _is_external_target = True, + base_module = "executorch.examples.models.granite_speech", + resources = { + "config/2b_config.json": "config/2b_config.json", + }, + deps = [ + "//caffe2:torch", + "//executorch/examples/models/llama:llama2_model", + "fbcode//pytorch/torchtune:lib", + "fbsource//third-party/pypi/safetensors:safetensors", + ], + visibility = ["PUBLIC"], +) diff --git a/examples/models/granite_speech/__init__.py b/examples/models/granite_speech/__init__.py new file mode 100644 index 00000000000..8adefab4ed2 --- /dev/null +++ b/examples/models/granite_speech/__init__.py @@ -0,0 +1,16 @@ +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from executorch.examples.models.granite_speech.convert_weights import convert_weights +from executorch.examples.models.llama.model import Llama2Model + + +class GraniteSpeechModel(Llama2Model): + def __init__(self, **kwargs): + super().__init__(**kwargs) + + +__all__ = [ + "GraniteSpeechModel", + "convert_weights", +] diff --git a/examples/models/granite_speech/config/2b_config.json b/examples/models/granite_speech/config/2b_config.json new file mode 100644 index 00000000000..f96e84f8f03 --- /dev/null +++ b/examples/models/granite_speech/config/2b_config.json @@ -0,0 +1,19 @@ +{ + "dim": 2048, + "attention_qkv_bias": false, + "attention_multiplier": 0.015625, + "bos_idx": 0, + "embedding_scale_factor": 12.0, + "eos_idx": 0, + "act_fn": "silu", + "hidden_dim": 8192, + "n_heads": 32, + "n_layers": 40, + "n_kv_heads": 8, + "norm_eps": 1e-05, + "rope_theta": 10000000.0, + "vocab_size": 49160, + "use_hf_rope": false, + "residual_multiplier": 0.22, + "logits_scaling": 8.0 +} diff --git a/examples/models/granite_speech/convert_weights.py b/examples/models/granite_speech/convert_weights.py new file mode 100644 index 00000000000..1f3b1a5b731 --- /dev/null +++ 
b/examples/models/granite_speech/convert_weights.py @@ -0,0 +1,111 @@ +import argparse + +import json +import os +from typing import Dict + +import torch +from safetensors.torch import load_file + +from torchtune.models.convert_weights import get_mapped_key + + +# Weight mappings from Granite-Speech's checkpoint to ExecuTorch's transformer parameters. +_GRANITE_TO_EXECUTORCH = { + "language_model.model.embed_tokens.weight": "tok_embeddings.weight", + "language_model.model.norm.weight": "norm.weight", + "language_model.model.layers.{}.self_attn.k_proj.weight": "layers.{}.attention.wk.weight", + "language_model.model.layers.{}.self_attn.q_proj.weight": "layers.{}.attention.wq.weight", + "language_model.model.layers.{}.self_attn.v_proj.weight": "layers.{}.attention.wv.weight", + "language_model.model.layers.{}.self_attn.o_proj.weight": "layers.{}.attention.wo.weight", + "language_model.model.layers.{}.input_layernorm.weight": "layers.{}.attention_norm.weight", + "language_model.model.layers.{}.post_attention_layernorm.weight": "layers.{}.ffn_norm.weight", + "language_model.model.layers.{}.mlp.gate_proj.weight": "layers.{}.feed_forward.w1.weight", + "language_model.model.layers.{}.mlp.down_proj.weight": "layers.{}.feed_forward.w2.weight", + "language_model.model.layers.{}.mlp.up_proj.weight": "layers.{}.feed_forward.w3.weight", +} + + +def granite_to_executorch( + state_dict: Dict[str, torch.Tensor] +) -> Dict[str, torch.Tensor]: + """ + Convert the state dict so that it matches what ExecuTorch's transformer definition expects. 
+ """ + converted_state_dict = {} + for key, value in state_dict.items(): + try: + new_key = get_mapped_key(key, _GRANITE_TO_EXECUTORCH) + converted_state_dict[new_key] = value + except: + # only preserve parameters of text decoder + pass + + converted_state_dict["output.weight"] = converted_state_dict[ + "tok_embeddings.weight" + ] + return converted_state_dict + + +def load_checkpoint_from_safetensors(input_dir: str) -> Dict: + index_path = os.path.join(input_dir, "model.safetensors.index.json") + if os.path.exists(index_path): + # Sharded checkpoint. + with open(index_path, "r") as f: + index = json.load(f) + weight_map = index["weight_map"] + checkpoint_shards = sorted(set(weight_map.values())) + + # Load all the shards into memory + shard_to_weights = {} + for shard in checkpoint_shards: + shard_to_weights[shard] = load_file(os.path.join(input_dir, shard)) + + # Merge tensors into consolidated state dict. + merged_state_dict = {} + for weight_name, shard in weight_map.items(): + tensor = shard_to_weights[shard][weight_name] + merged_state_dict[weight_name] = tensor + return merged_state_dict + else: + # Single checkpoint. 
+ state_dict = load_file(os.path.join(input_dir, "model.safetensors")) + return state_dict + + +def load_checkpoint(input_dir: str) -> Dict: + pytorch_path = os.path.join(input_dir, "pytorch_model.bin") + if os.path.exists(pytorch_path): + print("Loading checkpoint from PyTorch .bin file") + return torch.load(pytorch_path, map_location="cpu", weights_only=True) + print("Loading checkpoint from safetensors directory") + return load_checkpoint_from_safetensors(input_dir) + + +def convert_weights(input_dir: str, output_file: str) -> None: + print("Loading checkpoint...") + sd = load_checkpoint(input_dir) + print("Converting checkpoint...") + sd = granite_to_executorch(sd) + print("Saving checkpoint...") + torch.save(sd, output_file) + print("Done.") + + +def main(): + parser = argparse.ArgumentParser( + description="Convert Granite-Speech weights to ExecuTorch transformer format." + ) + parser.add_argument( + "input_dir", + type=str, + help="Path to directory containing safetensor checkpoint files, or PyTorch checkpoint file.", + ) + parser.add_argument("output", type=str, help="Path to the output checkpoint") + + args = parser.parse_args() + convert_weights(args.input_dir, args.output) + + +if __name__ == "__main__": + main() diff --git a/examples/qualcomm/oss_scripts/llama/README.md b/examples/qualcomm/oss_scripts/llama/README.md index fb926e9f613..794b4c11b33 100644 --- a/examples/qualcomm/oss_scripts/llama/README.md +++ b/examples/qualcomm/oss_scripts/llama/README.md @@ -4,8 +4,8 @@ **Video Tutorial:** [Build Along: Run LLMs Locally on Qualcomm Hardware Using ExecuTorch](https://www.youtube.com/watch?v=41PKDlGM3oU) -This file provides you the instructions to run LLM Decoder model and VLM model with different parameters via Qualcomm HTP backend. We currently support the following models: -- LLM +This file provides you the instructions to run LLM Decoder model, VLM model, and ALM model with different parameters via Qualcomm HTP backend. 
We currently support the following models: +- Large language models 1. LLAMA2 Stories 110M 1. LLAMA3.2 1B @@ -21,10 +21,13 @@ This file provides you the instructions to run LLM Decoder model and VLM model w 1. QWEN3 0.6B / 1.7B 1. SmolLM2 135M 1. SmolLM3 3B -- VLM +- Vision-Language Models 1. SmolVLM 500M 1. InternVL3 1B +- Audio-Language models + + 1. Granite-speech-3.3-2b We offer the following modes to execute the model: @@ -215,7 +218,7 @@ Multimodal models extend LLM by processing multiple input modalities (vision, au **Current Support Status:** - **Vision-Language Models (VLM)**: Fully supported -- **Audio-Language Models (ALM)**: Coming soon +- **Audio-Language Models (ALM)**: Fully supported ### Multimodal Architecture @@ -228,7 +231,7 @@ Multimodal inference follows these key stages: 1. **Modality-Specific Encoding** - **Vision**: Images are processed through a vision encoder to generate visual embeddings - - **Audio**: Audio waveforms are processed through an audio encoder *(future support)* + - **Audio**: Audio waveforms are processed through an audio encoder to generate audio embeddings - **Text**: Text prompts are tokenized and embedded 2. **Embedding Fusion** @@ -242,12 +245,105 @@ Multimodal inference follows these key stages: --- +## Audio-Language Model (ALM) Support + +Audio-Language Models (ALMs) combine speech/audio processing and natural language processing to understand and generate text based on audio inputs. ALMs in this framework consist of: + +### Dependencies + +ALM models require the `soundfile` package for audio loading: + +```bash +pip install soundfile +``` + + +- **[Audio Encoder](model/audio_encoder.py)**: Processes raw audio waveforms into audio embeddings (e.g., CTC encoder for Granite-speech) + - **Projector** (included in audio encoder): Aligns audio embeddings with the language model's embedding space. 
+- **[Language Decoder](model/static_llama.py)**: Reuse static llama to generate text based on fused audio and text embeddings. + +### Instructions + +#### Granite-speech-3.3-2b + +Default example using hybrid mode. +```bash +python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL_NUM} -m ${SOC_MODEL} --decoder_model granite_speech_3_3-2b --model_mode hybrid --prefill_ar_len 128 --max_seq_len 1024 --prompt "can you transcribe the speech into a written format?" --audio_path "https://huggingface.co/ibm-granite/granite-speech-3.3-2b/resolve/main/10226_10111_000000.wav?download=true" +``` + +### Specifying Custom Audio + +You can specify a custom audio file for ALM models using the `--audio_path` flag: +- **HTTP/HTTPS URLs**: Direct links to audio on the web + - Example: `"https://huggingface.co/ibm-granite/granite-speech-3.3-2b/resolve/main/10226_10111_000000.wav?download=true"` +- **HuggingFace repository filenames**: Files that exist in the model's HuggingFace repository are automatically downloaded + - Example: `"10226_10111_000000.wav"` (auto-downloaded from `ibm-granite/granite-speech-3.3-2b`) +- **Local file paths**: Absolute or relative paths to `.wav` files on your system + - Example: `"/path/to/your/audio.wav"` + +**Default behavior:** +If `--audio_path` is not specified, the system will automatically use the default audio file defined in the model's configuration file (`encoder/encoder_config.py`). 
+ +#### Audio Preprocessing + +The audio encoder configuration is defined in `encoder/encoder_config.py`: + +```python +# In encoder/encoder_config.py +@dataclass(init=False, frozen=True) +class GraniteSpeechEncoder(AudioModalityConfig): + encoder_class = GraniteSpeechCTCEncoderWrapper + audio_seq_len = 171 + audio_url = "https://huggingface.co/ibm-granite/granite-speech-3.3-2b/resolve/main/10226_10111_000000.wav?download=true" # Default audio (content: "After his nap, ...") + quant_recipe = GraniteSpeechEncoderQuantRecipe +``` + +- **audio_seq_len**: Number of audio tokens generated by the encoder. + +The audio is automatically: +1. Loaded from the specified file path or downloaded from HuggingFace +2. Read as a waveform using `soundfile` and converted to a float tensor of shape `[1, T]` +3. Processed by the HuggingFace `AutoProcessor` to produce mel-filterbank features of shape `(1, 844, 160)` +4. Passed through the CTC encoder and QFormer projector to produce audio embeddings of shape `[1, audio_seq_len, hidden_dim]` + +### Using Pre-Generated PTE Files + +If you have already compiled an ALM model, you can run inference with pre-generated PTE files: + +```bash +python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL_NUM} -m ${SOC_MODEL} --decoder_model granite_speech_3_3-2b --model_mode hybrid --prefill_ar_len 128 --max_seq_len 1024 --prompt "can you transcribe the speech into a written format?" --audio_path "https://huggingface.co/ibm-granite/granite-speech-3.3-2b/resolve/main/10226_10111_000000.wav?download=true" --pre_gen_pte ${FOLDER_TO_PRE_GEN_PTE} +``` + +### ALM Processing Details + +The ALM inference pipeline consists of: + +1. 
**Audio Encoding Phase** + - Input audio waveform is loaded and preprocessed into mel-filterbank features: `(1, 844, 160)` + - CTC encoder extracts acoustic features using Conformer blocks with block-wise local attention + - QFormer projector aligns audio embeddings to the language model dimension: `[batch, audio_seq_len, hidden_dim]` + +2. **Text Tokenization Phase** + - User prompt is tokenized into text tokens + - Text tokens are embedded: `[batch, text_seq_len, hidden_dim]` + +3. **Embedding Fusion Phase** + - Audio and text embeddings are concatenated according to the model's template + - The `