4 changes: 2 additions & 2 deletions backends/qualcomm/builders/op_arange.py
@@ -6,7 +6,6 @@
from typing import Dict

import executorch.backends.qualcomm.python.PyQnnManagerAdaptor as PyQnnManager

import torch

from .node_visitor import NodeVisitor
@@ -27,7 +26,8 @@ def define_node(
) -> PyQnnManager.PyQnnOpWrapper:
start, end = node.args[0:2]
step = node.args[2] if len(node.args) > 2 else 1
out_tensor = torch.arange(start, end, step)
dtype = node.kwargs.get("dtype")
Collaborator: How about using node.meta["val"].dtype? I think it will not be None.

out_tensor = torch.arange(start, end, step, dtype=dtype)

# since we can derive the constant value of current op in AoT stage
# we only build static tensor here for consumers of current node
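
A minimal sketch of the Collaborator's suggestion, assuming the node carries a FakeTensor in node.meta["val"] whose dtype reflects the resolved output dtype; the fallback to the kwarg is illustrative and not part of this change:

    # Sketch: prefer the dtype recorded on the node's meta value,
    # falling back to the explicit kwarg if the meta entry is missing.
    meta_val = node.meta.get("val")
    dtype = meta_val.dtype if meta_val is not None else node.kwargs.get("dtype")
    out_tensor = torch.arange(start, end, step, dtype=dtype)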
2 changes: 1 addition & 1 deletion backends/qualcomm/tests/models.py
@@ -177,7 +177,7 @@ def forward(self, x):


class Arange(torch.nn.Module):
def __init__(self, start, end, step, dtype):
def __init__(self, start, end, step, dtype=None):
super().__init__()
self.start = start
self.end = end
86 changes: 85 additions & 1 deletion backends/qualcomm/tests/test_passes.py
@@ -7,7 +7,6 @@
InsertReshapeForReduceOps,
RemoveRedundancy,
)

from executorch.exir import to_edge
from executorch.exir.dialects._ops import ops as exir_ops

@@ -149,6 +148,91 @@ def test_mha_to_sha(self):
f"Output {i} mismatch: got {out}, expected {ref}",
)

def test_arange_dtype_from_kwargs(self):
Collaborator: Could you reuse the original unit test for the arange op (def test_qnn_backend_arange) and add more dtype coverage there instead?
"""Test that op_arange builder respects dtype from node kwargs."""
from collections import defaultdict

import executorch.backends.qualcomm.python.PyQnnManagerAdaptor as PyQnnManager
from executorch.backends.qualcomm._passes.qnn_pass_manager import QnnPassManager
from executorch.backends.qualcomm.builders.node_visitor_manager import (
get_node_visitors,
)
from executorch.backends.qualcomm.utils.utils import capture_program

class ArangeWithDtype(torch.nn.Module):
def __init__(self, dtype):
super().__init__()
self.dtype = dtype

def forward(self, x):
return torch.arange(0, 10, 1, dtype=self.dtype) + x

# Map torch dtypes to expected QNN data types
dtype_to_qnn = {
torch.float32: PyQnnManager.Qnn_DataType_t.QNN_DATATYPE_FLOAT_32,
torch.int64: PyQnnManager.Qnn_DataType_t.QNN_DATATYPE_INT_64,
torch.int32: PyQnnManager.Qnn_DataType_t.QNN_DATATYPE_INT_32,
}

test_cases = [
(torch.float32, dtype_to_qnn[torch.float32]),
(torch.int64, dtype_to_qnn[torch.int64]),
(torch.int32, dtype_to_qnn[torch.int32]),
(None, dtype_to_qnn[torch.int64]), # None uses torch default (int64)
]

for input_dtype, expected_qnn_dtype in test_cases:
with self.subTest(input_dtype=input_dtype):
mod = ArangeWithDtype(input_dtype)
# Use appropriate input tensor dtype
if input_dtype == torch.float32:
x = torch.zeros(10, dtype=torch.float32)
else:
x = torch.zeros(10, dtype=torch.int64)

sample_input = (x,)
delegated_program = capture_program(mod, sample_input)

# Transform graph through QNN pass manager
graph_module = QnnPassManager().transform_for_preprocess_pipeline(
delegated_program.exported_program
)

# Get node visitors (builders)
nodes_to_wrappers = defaultdict(dict)
node_visitors = get_node_visitors(
delegated_program.exported_program, enable_tensor_dump=False
)

# Find and process the arange node through the builder
for node in graph_module.graph.nodes:
if node.op == "call_function" and "arange" in node.target.__name__:
# Invoke the Arange builder's define_node
node_visitors[node.target.__name__].define_node(
node, nodes_to_wrappers
)

# Check that a tensor wrapper was created
self.assertIn(
node.name,
nodes_to_wrappers,
f"No tensor wrapper created for arange node",
)

# Get the tensor wrapper and verify its dtype
tensor_wrapper = nodes_to_wrappers[node.name][0]
actual_dtype = tensor_wrapper.GetDataType()

self.assertEqual(
actual_dtype,
expected_qnn_dtype,
f"QNN dtype mismatch for input_dtype={input_dtype}: "
f"got {actual_dtype}, expected {expected_qnn_dtype}",
)
break
else:
self.fail("No arange node found in graph")
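
As a rough sketch of the Collaborator's suggestion above, the existing operator-level test could instead be extended with dtype coverage. The Arange module comes from backends/qualcomm/tests/models.py; lower_module_and_test_output is assumed to be the suite's usual lowering helper, and the start/end/step values and input shape are illustrative:

    def test_qnn_backend_arange(self):
        # Sketch: cover the default dtype (None) plus explicit dtypes.
        for dtype in (None, torch.float32, torch.int32):
            with self.subTest(dtype=dtype):
                module = Arange(start=1, end=11, step=1, dtype=dtype)
                sample_input = (torch.randn(10),)
                self.lower_module_and_test_output(module, sample_input)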


if __name__ == "__main__":
unittest.main()