diff --git a/exir/passes/convert_constant_dim_order_pass.py b/exir/passes/convert_constant_dim_order_pass.py index e2c2046dbb8..b9984d4db44 100644 --- a/exir/passes/convert_constant_dim_order_pass.py +++ b/exir/passes/convert_constant_dim_order_pass.py @@ -1,7 +1,11 @@ +import logging + import torch from torch.export import ExportedProgram from torch.export.graph_signature import InputKind +logger = logging.getLogger(__name__) + def _should_transform(tensor: torch.Tensor) -> bool: """ @@ -17,7 +21,7 @@ def _update_placeholder_meta( exported_program: ExportedProgram, target: str, kind: InputKind, -): +) -> None: input_spec = next( ( spec @@ -27,7 +31,8 @@ def _update_placeholder_meta( None, ) if input_spec is None: - raise RuntimeError(f"Missing input spec for lifted tensor {target}.") + logger.warning(f"Missing input spec for constant {target}") + return placeholder_node = next( ( @@ -37,8 +42,9 @@ def _update_placeholder_meta( ), None, ) - if input_spec is None: - raise RuntimeError(f"Missing placeholder for {input_spec.arg.name}.") + if placeholder_node is None: + logger.warning(f"Missing placeholder node for constant {target}") + return placeholder_node.meta["val"] = placeholder_node.meta["val"].contiguous()