Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
162 changes: 87 additions & 75 deletions backends/qualcomm/_passes/canonicalize_conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,86 +115,98 @@ def call(self, graph_module: torch.fx.GraphModule):
)

with graph_module.graph.inserting_after(qdq_node_after_unsqueeze):
filter_arg = node.args[1]
filter_node = (
filter_arg
if filter_arg.op == "placeholder"
else node.args[1].args[0]
)
filter_node.meta["val"] = (
filter_node.meta["val"].unsqueeze(2).contiguous()
)
filter_tensor = get_parameter(
filter_node, self.edge_program
).unsqueeze(2)
set_parameter(
(
torch.nn.Parameter(filter_tensor)
if filter_tensor.dtype == torch.float
else filter_tensor
),
filter_node,
self.edge_program,
)

num_args = len(node.args)

bias_node = node.args[2] if num_args > 2 else None
stride = [1] + node.args[3] if num_args > 3 else [1, 1]
padding = [0] + node.args[4] if num_args > 4 else [0, 0]
if node.target == torch.ops.aten.conv1d.default:
dilation = [1] + node.args[5] if num_args > 5 else [1, 1]
groups = node.args[6] if num_args > 6 else 1
conv_args = (
qdq_node_after_unsqueeze,
node.args[1],
bias_node,
stride,
padding,
dilation,
groups,
# conv2d must be inserted before conv1d in the graph to preserve correct
# topological ordering. This is required due to conv-bn fusion: when conv1d
# has no bias, the fused bias (from batchnorm) is introduced as a new node,
# and its corresponding dq (dequantize) node must appear before conv2d in
# the execution order.
with graph_module.graph.inserting_before(node):
filter_arg = node.args[1]
filter_node = (
filter_arg
if filter_arg.op == "placeholder"
else node.args[1].args[0]
)
else:
output_padding = (
[0] + node.args[5] if num_args > 5 else [0, 0]
filter_node.meta["val"] = filter_node.meta["val"].unsqueeze(
2
)
groups = node.args[6] if num_args > 6 else 1
dilation = [1] + node.args[7] if num_args > 7 else [1, 1]
conv_args = (
qdq_node_after_unsqueeze,
node.args[1],
bias_node,
stride,
padding,
output_padding,
groups,
dilation,
)
conv2d_node = graph.create_node(
"call_function",
self.conv1d_op_map[node.target],
conv_args,
)
conv2d_node.meta = copy_meta(
node.meta, lambda m: {**m, "val": m["val"].unsqueeze(2)}
)
qdq_node_after_conv2d = append_qdq(
graph_module=graph_module,
node=conv2d_node,
qdq_node=list(node.users)[0],
)

with graph_module.graph.inserting_after(qdq_node_after_conv2d):
squeeze_op = torch.ops.aten.squeeze_copy.dims
squeeze_node = graph.create_node(
"call_function",
squeeze_op,
filter_tensor = get_parameter(
filter_node, self.edge_program
).unsqueeze(2)
set_parameter(
(
qdq_node_after_conv2d,
[2],
torch.nn.Parameter(filter_tensor)
if filter_tensor.dtype == torch.float
else filter_tensor
),
filter_node,
self.edge_program,
)

num_args = len(node.args)

bias_node = node.args[2] if num_args > 2 else None
stride = [1] + node.args[3] if num_args > 3 else [1, 1]
padding = [0] + node.args[4] if num_args > 4 else [0, 0]
if node.target == torch.ops.aten.conv1d.default:
dilation = (
[1] + node.args[5] if num_args > 5 else [1, 1]
)
groups = node.args[6] if num_args > 6 else 1
conv_args = (
qdq_node_after_unsqueeze,
node.args[1],
bias_node,
stride,
padding,
dilation,
groups,
)
else:
output_padding = (
[0] + node.args[5] if num_args > 5 else [0, 0]
)
groups = node.args[6] if num_args > 6 else 1
dilation = (
[1] + node.args[7] if num_args > 7 else [1, 1]
)
conv_args = (
qdq_node_after_unsqueeze,
node.args[1],
bias_node,
stride,
padding,
output_padding,
groups,
dilation,
)
conv2d_node = graph.create_node(
"call_function",
self.conv1d_op_map[node.target],
conv_args,
)
conv2d_node.meta = copy_meta(
node.meta, lambda m: {**m, "val": m["val"].unsqueeze(2)}
)
squeeze_node.meta = copy_meta(node.meta)
qdq_node_after_conv2d = append_qdq(
graph_module=graph_module,
node=conv2d_node,
qdq_node=list(node.users)[0],
)

with graph_module.graph.inserting_after(
qdq_node_after_conv2d
):
squeeze_op = torch.ops.aten.squeeze_copy.dims
squeeze_node = graph.create_node(
"call_function",
squeeze_op,
(
qdq_node_after_conv2d,
[2],
),
)
squeeze_node.meta = copy_meta(node.meta)

for user in node.users.copy():
user.replace_input_with(node, squeeze_node)
Expand Down
18 changes: 18 additions & 0 deletions backends/qualcomm/tests/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -490,6 +490,24 @@ def example_inputs(self):
}


class Conv1dBn(torch.nn.Module):
    """Depthwise Conv1d followed by BatchNorm1d.

    Exercises conv-bn fusion for 1D convolutions in backend lowering
    tests, both with and without a convolution bias.
    """

    def __init__(self, bias=True):
        super().__init__()
        channels = 2048
        # Depthwise: groups == channels gives one filter per channel.
        self.conv = torch.nn.Conv1d(
            in_channels=channels,
            out_channels=channels,
            kernel_size=15,
            groups=channels,
            bias=bias,
        )
        self.batch_norm = torch.nn.BatchNorm1d(channels)

    def forward(self, x):
        return self.batch_norm(self.conv(x))


class Conv1dSequential(torch.nn.Module):
def __init__(self, bias=True):
super().__init__()
Expand Down
124 changes: 123 additions & 1 deletion backends/qualcomm/tests/test_qnn_delegate.py
Original file line number Diff line number Diff line change
Expand Up @@ -377,6 +377,13 @@ def test_qnn_backend_conv1d(self):
with self.subTest(i=i):
self.lower_module_and_test_output(module, sample_input)

def test_qnn_backend_conv1d_batch_norm(self):
    """Lower Conv1d+BatchNorm1d modules (with and without conv bias).

    Renamed to follow the test_qnn_backend_* prefix used by every
    sibling conv test in this class (e.g. test_qnn_backend_conv1d,
    test_qnn_backend_conv2d).
    """
    modules = [Conv1dBn(), Conv1dBn(bias=False)]  # noqa: F405
    sample_input = (torch.randn([1, 2048, 858]),)
    for i, module in enumerate(modules):
        with self.subTest(i=i):
            self.lower_module_and_test_output(module, sample_input)

def test_qnn_backend_conv2d(self):
modules = [Conv2dSequential(), Conv2dSequential(bias=False)] # noqa: F405
sample_input = (torch.randn([1, 1, 3, 3]),)
Expand Down Expand Up @@ -2637,6 +2644,14 @@ def test_qnn_backend_conv1d(self):
module = self.get_qdq_module(module, sample_input)
self.lower_module_and_test_output(module, sample_input)

def test_qnn_backend_conv1d_batch_norm(self):
    """Quantize (QDQ) then lower Conv1d+BatchNorm1d modules.

    Covers both bias and no-bias conv variants. Renamed to follow the
    test_qnn_backend_* prefix used by every sibling conv test in this
    class (e.g. test_qnn_backend_conv1d, test_qnn_backend_conv2d).
    """
    modules = [Conv1dBn(), Conv1dBn(bias=False)]  # noqa: F405
    sample_input = (torch.randn([1, 2048, 858]),)
    for i, module in enumerate(modules):
        with self.subTest(i=i):
            module = self.get_qdq_module(module, sample_input)
            self.lower_module_and_test_output(module, sample_input)

def test_qnn_backend_conv2d(self):
modules = [Conv2dSequential(), Conv2dSequential(bias=False)] # noqa: F405
sample_input = (torch.randn([1, 1, 3, 3]),)
Expand Down Expand Up @@ -6870,13 +6885,30 @@ class MLLMSpecs:
tok_embedding_pte_size: float
decoder_pte_size: float

@dataclass(frozen=True)
class ALMSpecs(MLLMSpecs):
    """Audio-language-model test specs: MLLMSpecs plus audio fixtures."""

    # URL (or path) of the input audio clip passed to the example script
    # via --audio_path (see setUp: a downloadable .wav URL).
    audio_path: str
    # Snippet expected in the lower-cased transcription output
    # (matched with `in model_out.lower()` in test_static_asr).
    golden_audio_feature: str

@dataclass(frozen=True)
class VLMSpecs(MLLMSpecs):
    """Vision-language-model test specs: MLLMSpecs plus image fixtures."""

    # URL (or path) of the input image fed to the VLM pipeline.
    image_path: str
    # Snippet expected in the model's answer — presumably matched the
    # same way as the ASR golden feature; verify against test_static_vlm.
    golden_image_feature: str

# TODO: refactor to support different backends
def setUp(self):
self.alm_specs = {
"granite_speech_3_3-2b": TestExampleMultimodalityScript.ALMSpecs(
max_seq_len=512,
sm8650_token_rate=5,
sm8750_token_rate=8,
encoder_pte_size=900_000_000, # 900MB
tok_embedding_pte_size=240_000_000, # 240MB
decoder_pte_size=3_000_000_000, # 3GB
audio_path="https://huggingface.co/ibm-granite/granite-speech-3.3-2b/resolve/main/10226_10111_000000.wav?download=true", # Audio content: after his nap,...
golden_audio_feature="after his nap,",
),
}
self.vlm_specs = {
"smolvlm_500m_instruct": TestExampleMultimodalityScript.VLMSpecs(
max_seq_len=128,
Expand All @@ -6900,6 +6932,96 @@ def setUp(self):
),
}

def test_static_asr(self):
    """End-to-end static ASR test for the audio-language model (ALM).

    Compiles (and, unless compile_only, runs on-device) the Qualcomm
    llama example script with an audio input, then validates the
    transcription, the PTE sizes, and the token rate that the script
    reports back over a multiprocessing connection.
    """
    if not self.required_envs([self.model_name]):
        self.skipTest("missing required envs")

    if self.enable_x86_64:
        # Running on host is extremely slow for large models, so we skip this check to avoid timeouts.
        # Please verify the output on the actual device instead.
        self.skipTest(
            "Skipping the check for the static ASR model on x86 due to long execution time."
        )

    alm_specs: TestExampleMultimodalityScript.ALMSpecs = self.alm_specs[
        self.model_name
    ]
    prompt = "can you transcribe the speech into a written format?"
    audio_path = alm_specs.audio_path
    # Command line for the example runner; it connects back to this
    # process over (self.ip, self.port) to report its results.
    cmds = [
        "python",
        f"{self.executorch_root}/examples/qualcomm/oss_scripts/llama/llama.py",
        "--artifact",
        self.artifact_dir,
        "--build_folder",
        self.build_folder,
        "--model",
        self.model,
        "--ip",
        self.ip,
        "--port",
        str(self.port),
        "--prompt",
        prompt,
        "--audio_path",
        audio_path,
        "--temperature",
        "0",
        "--decoder_model",
        f"{self.model_name}",
        "--model_mode",
        "kv",
        "--max_seq_len",
        f"{alm_specs.max_seq_len}",
    ]
    if self.compile_only:
        cmds.extend(["--compile_only"])
    elif self.device:
        cmds.extend(["--device", self.device])
        if self.host:
            cmds.extend(["--host", self.host])
    if self.pre_gen_pte:
        cmds.extend(["--pre_gen_pte", self.pre_gen_pte])

    # Launch the script, wait for it to connect back, let it run to
    # completion, then read its JSON result payload. NOTE(review): the
    # accept -> communicate -> recv ordering is load-bearing; do not
    # reorder.
    p = subprocess.Popen(cmds, stdout=subprocess.DEVNULL)
    with Listener((self.ip, self.port)) as listener:
        conn = listener.accept()
        p.communicate()
        msg = json.loads(conn.recv())
        if "Error" in msg:
            self.fail(msg["Error"])
        else:
            if not self.compile_only:
                # Transcription check: the golden snippet must appear
                # in the lower-cased model output.
                model_out = msg["result"][0]
                self.assertTrue(
                    alm_specs.golden_audio_feature in model_out.lower(),
                    f"Expected Output contains feature: '{alm_specs.golden_audio_feature}' Actual Output: '{model_out}'",
                )
                print(f"Audio Path: {audio_path}")
                print(f"Query: {prompt}")
                print(f"Answer: {model_out}")

            # PTE size budgets (bytes) from the spec table built in setUp.
            encoder_pte_size = msg["audio_encoder_pte_size"]
            tok_embedding_pte_size = msg["tok_embedding_pte_size"]
            decoder_pte_size = msg["pte_size"]
            self.assertLessEqual(encoder_pte_size, alm_specs.encoder_pte_size)
            self.assertLessEqual(
                tok_embedding_pte_size, alm_specs.tok_embedding_pte_size
            )
            self.assertLessEqual(decoder_pte_size, alm_specs.decoder_pte_size)
            print(f"Encoder PTE Size: {encoder_pte_size} bytes")
            print(f"Token Embedding PTE Size: {tok_embedding_pte_size} bytes")
            print(f"Text Decoder PTE Size: {decoder_pte_size} bytes")

            # Token-rate floor is per-SoC (e.g. sm8650_token_rate); only
            # checked when the spec table defines one for this target.
            attr_name = f"{self.model.lower()}_token_rate"
            if not self.compile_only and hasattr(alm_specs, attr_name):
                device_inference_speed = msg["inference_speed"]
                expected_inference_speed = getattr(alm_specs, attr_name)
                print(f"Prompt Evaluation: {device_inference_speed} tokens/second")
                self.assertGreaterEqual(
                    device_inference_speed, expected_inference_speed
                )

def test_static_vlm(self):
if not self.required_envs([self.model_name]):
self.skipTest("missing required envs")
Expand Down Expand Up @@ -6964,7 +7086,7 @@ def test_static_vlm(self):
print(f"Query: {prompt}")
print(f"Answer: {model_out}")
if not self.enable_x86_64:
encoder_pte_size = msg["encoder_pte_size"]
encoder_pte_size = msg["vision_encoder_pte_size"]
tok_embedding_pte_size = msg["tok_embedding_pte_size"]
decoder_pte_size = msg["pte_size"]
self.assertLessEqual(encoder_pte_size, vlm_specs.encoder_pte_size)
Expand Down
24 changes: 24 additions & 0 deletions examples/models/granite_speech/BUCK
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
load("@fbcode_macros//build_defs:build_file_migration.bzl", "fbcode_target", "non_fbcode_target")
load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")

oncall("executorch")

# Python library for the granite-speech example model: weight-conversion
# helpers plus the bundled model config JSON.
# NOTE(review): `non_fbcode_target` is loaded above but not used in this
# file — confirm whether it is needed.
fbcode_target(_kind = runtime.python_library,
    name = "granite_speech",
    srcs = [
        "__init__.py",
        "convert_weights.py",
    ],
    _is_external_target = True,
    base_module = "executorch.examples.models.granite_speech",
    # Ship the model config alongside the sources so it can be loaded
    # at runtime as a package resource.
    resources = {
        "config/2b_config.json": "config/2b_config.json",
    },
    deps = [
        "//caffe2:torch",
        "//executorch/examples/models/llama:llama2_model",
        "fbcode//pytorch/torchtune:lib",
        "fbsource//third-party/pypi/safetensors:safetensors",
    ],
    visibility = ["PUBLIC"],
)
Loading
Loading