From 87dd4eb6bccfe0fc1ec050e944a1795c7b911b28 Mon Sep 17 00:00:00 2001 From: Lancelot Blanchard Date: Fri, 9 May 2025 11:42:36 -0400 Subject: [PATCH] Fix bug in LayerInfo.get_kernel_size when kernel sizes are torch.Tensor with single item --- torchinfo/layer_info.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/torchinfo/layer_info.py b/torchinfo/layer_info.py index 3bf2d27..88741d0 100644 --- a/torchinfo/layer_info.py +++ b/torchinfo/layer_info.py @@ -161,9 +161,11 @@ def get_kernel_size(module: nn.Module) -> int | list[int] | None: if hasattr(module, "kernel_size"): k = module.kernel_size kernel_size: int | list[int] - if isinstance(k, Iterable): + if isinstance(k, Iterable) and ( + not isinstance(k, torch.Tensor) or k.ndim > 0 + ): kernel_size = list(k) - elif isinstance(k, int): + elif isinstance(k, (int, torch.Tensor)): kernel_size = int(k) else: raise TypeError(f"kernel_size has an unexpected type: {type(k)}")