diff --git a/backends/qualcomm/builders/op_arange.py b/backends/qualcomm/builders/op_arange.py
index 0a95d55dca3..882aae24ef1 100644
--- a/backends/qualcomm/builders/op_arange.py
+++ b/backends/qualcomm/builders/op_arange.py
@@ -6,7 +6,6 @@
 from typing import Dict
 
 import executorch.backends.qualcomm.python.PyQnnManagerAdaptor as PyQnnManager
-
 import torch
 
 from .node_visitor import NodeVisitor
@@ -27,7 +26,8 @@ def define_node(
     ) -> PyQnnManager.PyQnnOpWrapper:
         start, end = node.args[0:2]
         step = node.args[2] if len(node.args) > 2 else 1
-        out_tensor = torch.arange(start, end, step)
+        dtype = node.kwargs.get("dtype")
+        out_tensor = torch.arange(start, end, step, dtype=dtype)
 
         # since we can derive the constant value of current op in AoT stage
         # we only build static tensor here for consumers of current node
diff --git a/backends/qualcomm/tests/models.py b/backends/qualcomm/tests/models.py
index 6d4c9eac466..da146818e38 100644
--- a/backends/qualcomm/tests/models.py
+++ b/backends/qualcomm/tests/models.py
@@ -177,7 +177,7 @@ def forward(self, x):
 
 
 class Arange(torch.nn.Module):
-    def __init__(self, start, end, step, dtype):
+    def __init__(self, start, end, step, dtype=None):
         super().__init__()
         self.start = start
         self.end = end
diff --git a/backends/qualcomm/tests/test_passes.py b/backends/qualcomm/tests/test_passes.py
index 8af66c4cbef..cd69a2428c5 100644
--- a/backends/qualcomm/tests/test_passes.py
+++ b/backends/qualcomm/tests/test_passes.py
@@ -7,7 +7,6 @@
     InsertReshapeForReduceOps,
     RemoveRedundancy,
 )
-
 from executorch.exir import to_edge
 from executorch.exir.dialects._ops import ops as exir_ops
 
@@ -149,6 +148,96 @@ def test_mha_to_sha(self):
                     f"Output {i} mismatch: got {out}, expected {ref}",
                 )
 
+    def test_arange_dtype_from_kwargs(self):
+        """Test that op_arange builder respects dtype from node kwargs.
+
+        Each dtype case is captured, run through the QNN preprocess passes,
+        and its arange node is passed to the builder from get_node_visitors;
+        the tensor wrapper stored for that node must report the expected type.
+        """
+        from collections import defaultdict
+
+        import executorch.backends.qualcomm.python.PyQnnManagerAdaptor as PyQnnManager
+        from executorch.backends.qualcomm._passes.qnn_pass_manager import QnnPassManager
+        from executorch.backends.qualcomm.builders.node_visitor_manager import (
+            get_node_visitors,
+        )
+        from executorch.backends.qualcomm.utils.utils import capture_program
+
+        class ArangeWithDtype(torch.nn.Module):
+            def __init__(self, dtype):
+                super().__init__()
+                self.dtype = dtype
+
+            def forward(self, x):
+                return torch.arange(0, 10, 1, dtype=self.dtype) + x
+
+        # Map torch dtypes to expected QNN data types
+        dtype_to_qnn = {
+            torch.float32: PyQnnManager.Qnn_DataType_t.QNN_DATATYPE_FLOAT_32,
+            torch.int64: PyQnnManager.Qnn_DataType_t.QNN_DATATYPE_INT_64,
+            torch.int32: PyQnnManager.Qnn_DataType_t.QNN_DATATYPE_INT_32,
+        }
+
+        test_cases = [
+            (torch.float32, dtype_to_qnn[torch.float32]),
+            (torch.int64, dtype_to_qnn[torch.int64]),
+            (torch.int32, dtype_to_qnn[torch.int32]),
+            (None, dtype_to_qnn[torch.int64]),  # dtype=None: arange infers int64 from int args
+        ]
+
+        for input_dtype, expected_qnn_dtype in test_cases:
+            with self.subTest(input_dtype=input_dtype):
+                mod = ArangeWithDtype(input_dtype)
+                # Use appropriate input tensor dtype
+                if input_dtype == torch.float32:
+                    x = torch.zeros(10, dtype=torch.float32)
+                else:
+                    x = torch.zeros(10, dtype=torch.int64)
+
+                sample_input = (x,)
+                delegated_program = capture_program(mod, sample_input)
+
+                # Transform graph through QNN pass manager
+                graph_module = QnnPassManager().transform_for_preprocess_pipeline(
+                    delegated_program.exported_program
+                )
+
+                # Get node visitors (builders)
+                nodes_to_wrappers = defaultdict(dict)
+                node_visitors = get_node_visitors(
+                    delegated_program.exported_program, enable_tensor_dump=False
+                )
+
+                # Find and process the arange node through the builder
+                for node in graph_module.graph.nodes:
+                    if node.op == "call_function" and "arange" in node.target.__name__:
+                        # Invoke the Arange builder's define_node
+                        node_visitors[node.target.__name__].define_node(
+                            node, nodes_to_wrappers
+                        )
+
+                        # Check that a tensor wrapper was created
+                        self.assertIn(
+                            node.name,
+                            nodes_to_wrappers,
+                            "No tensor wrapper created for arange node",
+                        )
+
+                        # Get the tensor wrapper and verify its dtype
+                        tensor_wrapper = nodes_to_wrappers[node.name][0]
+                        actual_dtype = tensor_wrapper.GetDataType()
+
+                        self.assertEqual(
+                            actual_dtype,
+                            expected_qnn_dtype,
+                            f"QNN dtype mismatch for input_dtype={input_dtype}: "
+                            f"got {actual_dtype}, expected {expected_qnn_dtype}",
+                        )
+                        break
+                else:
+                    self.fail("No arange node found in graph")
+
 
 if __name__ == "__main__":
     unittest.main()