Skip to content

Commit ab2ba20

Browse files
authored
flags support for get_num_channels (#55)
* flags support for get_num_channels * flags support for get_num_channels * refresh tests * update version
1 parent 9498350 commit ab2ba20

File tree

3 files changed

+231
-6
lines changed

3 files changed

+231
-6
lines changed

albucore/utils.py

Lines changed: 118 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -122,12 +122,126 @@ def wrapped_function(img: np.ndarray, *args: P.args, **kwargs: P.kwargs) -> np.n
122122
return wrapped_function
123123

124124

125-
def get_num_channels(image: np.ndarray) -> int:
126-
return image.shape[-1] if image.ndim >= NUM_MULTI_CHANNEL_DIMENSIONS else 1
125+
def get_num_channels(image: np.ndarray, has_batch_dim: bool = False, has_depth_dim: bool = False) -> int:
126+
"""Get the number of channels in an image array.
127127
128+
This function determines the number of channels in an image array by analyzing its shape
129+
and accounting for optional batch and depth dimensions. The function assumes that the
130+
last dimension represents channels when the array has more than 2 spatial dimensions
131+
after accounting for any batch or depth dimensions.
128132
129-
def is_grayscale_image(image: np.ndarray) -> bool:
130-
return get_num_channels(image) == 1
133+
Args:
134+
image: Input image array. Can have various shapes:
135+
- HW: (height, width) - grayscale image
136+
- HWC: (height, width, channels) - multi-channel image
137+
- NHW: (batch, height, width) - batch of grayscale images
138+
- NHWC: (batch, height, width, channels) - batch of multi-channel images
139+
- DHW: (depth, height, width) - 3D grayscale volume
140+
- DHWC: (depth, height, width, channels) - 3D multi-channel volume
141+
- DNHW: (depth, batch, height, width) - batch of 3D grayscale volumes
142+
- DNHWC: (depth, batch, height, width, channels) - batch of 3D multi-channel volumes
143+
has_batch_dim: If True, the first dimension is treated as a batch dimension (N).
144+
has_depth_dim: If True, the first dimension (or second if has_batch_dim is True)
145+
is treated as a depth dimension (D).
146+
147+
Returns:
148+
int: Number of channels in the image. Returns 1 for grayscale images and the size
149+
of the last dimension for multi-channel images.
150+
151+
Examples:
152+
>>> # 2D grayscale image
153+
>>> img = np.zeros((100, 200))
154+
>>> get_num_channels(img)
155+
1
156+
157+
>>> # RGB image
158+
>>> img = np.zeros((100, 200, 3))
159+
>>> get_num_channels(img)
160+
3
161+
162+
>>> # Batch of grayscale images
163+
>>> img = np.zeros((10, 100, 200))
164+
>>> get_num_channels(img, has_batch_dim=True)
165+
1
166+
167+
>>> # Batch of RGB images
168+
>>> img = np.zeros((10, 100, 200, 3))
169+
>>> get_num_channels(img, has_batch_dim=True)
170+
3
171+
172+
>>> # 3D volume
173+
>>> img = np.zeros((5, 100, 200))
174+
>>> get_num_channels(img, has_depth_dim=True)
175+
1
176+
177+
>>> # Batch of 3D volumes with RGB
178+
>>> img = np.zeros((5, 10, 100, 200, 3))
179+
>>> get_num_channels(img, has_batch_dim=True, has_depth_dim=True)
180+
3
181+
182+
Note:
183+
The function assumes that after accounting for batch and depth dimensions,
184+
the remaining dimensions follow the pattern HW (grayscale) or HWC (multi-channel).
185+
"""
186+
# Calculate how many dimensions to skip from the beginning
187+
dims_to_skip = int(has_depth_dim) + int(has_batch_dim)
188+
189+
# After skipping D and/or N dimensions, we should have HW or HWC
190+
remaining_dims = image.ndim - dims_to_skip
191+
192+
# If we have more than 2 spatial dimensions (H, W), the last one is channels
193+
# Otherwise, it's single channel
194+
return image.shape[-1] if remaining_dims > 2 else 1
195+
196+
197+
def is_grayscale_image(image: np.ndarray, has_batch_dim: bool = False, has_depth_dim: bool = False) -> bool:
198+
"""Check if an image array represents a grayscale (single-channel) image.
199+
200+
This function determines whether an image has only one channel by calling get_num_channels
201+
and checking if the result equals 1. It properly handles various array shapes including
202+
batched images and 3D volumes.
203+
204+
Args:
205+
image: Input image array. Can have various shapes as described in get_num_channels.
206+
has_batch_dim: If True, the first dimension is treated as a batch dimension (N).
207+
has_depth_dim: If True, the first dimension (or second if has_batch_dim is True)
208+
is treated as a depth dimension (D).
209+
210+
Returns:
211+
bool: True if the image has only 1 channel (grayscale), False otherwise.
212+
213+
Examples:
214+
>>> # 2D grayscale image
215+
>>> img = np.zeros((100, 200))
216+
>>> is_grayscale_image(img)
217+
True
218+
219+
>>> # RGB image
220+
>>> img = np.zeros((100, 200, 3))
221+
>>> is_grayscale_image(img)
222+
False
223+
224+
>>> # Single channel image with explicit channel dimension
225+
>>> img = np.zeros((100, 200, 1))
226+
>>> is_grayscale_image(img)
227+
True
228+
229+
>>> # Batch of grayscale images
230+
>>> img = np.zeros((10, 100, 200))
231+
>>> is_grayscale_image(img, has_batch_dim=True)
232+
True
233+
234+
>>> # Batch of RGB images
235+
>>> img = np.zeros((10, 100, 200, 3))
236+
>>> is_grayscale_image(img, has_batch_dim=True)
237+
False
238+
239+
See Also:
240+
get_num_channels: For getting the exact number of channels.
241+
is_rgb_image: For checking if an image has exactly 3 channels (RGB).
242+
is_multispectral_image: For checking if an image has channels other than 1 or 3.
243+
"""
244+
return get_num_channels(image, has_batch_dim=has_batch_dim, has_depth_dim=has_depth_dim) == 1
131245

132246

133247
def get_opencv_dtype_from_numpy(value: np.ndarray | int | np.dtype | object) -> int:

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ requires = [ "setuptools>=45", "wheel" ]
55

66
[project]
77
name = "albucore"
8-
version = "0.0.25"
8+
version = "0.0.26"
99

1010
description = "High-performance image processing functions for deep learning and computer vision."
1111
readme = "README.md"

tests/test_utils.py

Lines changed: 112 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import cv2
44
from albucore.decorators import contiguous
55
from albucore.functions import float32_io, from_float, to_float, uint8_io
6-
from albucore.utils import NPDTYPE_TO_OPENCV_DTYPE, clip, convert_value, get_opencv_dtype_from_numpy, get_num_channels
6+
from albucore.utils import NPDTYPE_TO_OPENCV_DTYPE, clip, convert_value, get_opencv_dtype_from_numpy, get_num_channels, is_grayscale_image
77

88

99
@pytest.mark.parametrize("input_img, dtype, expected", [
@@ -257,3 +257,114 @@ def test_get_num_channels(shape, expected_channels, description):
257257
"""Test get_num_channels for various array dimensions."""
258258
image = np.zeros(shape)
259259
assert get_num_channels(image) == expected_channels, f"Failed for {description} with shape {shape}"
260+
261+
262+
@pytest.mark.parametrize("shape, has_batch_dim, has_depth_dim, expected_channels, description", [
263+
# HW: shape=(100, 200) → channels=1
264+
((100, 200), False, False, 1, "HW: grayscale image"),
265+
266+
# HWC: shape=(100, 200, 3) → channels=3
267+
((100, 200, 3), False, False, 3, "HWC: RGB image"),
268+
((100, 200, 1), False, False, 1, "HWC: single channel image"),
269+
((100, 200, 4), False, False, 4, "HWC: RGBA image"),
270+
271+
# NHW: shape=(10, 100, 200) → channels=1 (batch of grayscale)
272+
((10, 100, 200), True, False, 1, "NHW: batch of grayscale images"),
273+
274+
# NHWC: shape=(10, 100, 200, 3) → channels=3 (batch of RGB)
275+
((10, 100, 200, 3), True, False, 3, "NHWC: batch of RGB images"),
276+
((10, 100, 200, 1), True, False, 1, "NHWC: batch of single channel images"),
277+
278+
# DHW: shape=(5, 100, 200) → channels=1 (3D volume)
279+
((5, 100, 200), False, True, 1, "DHW: 3D volume"),
280+
281+
# DHWC: shape=(5, 100, 200, 3) → channels=3 (3D volume with RGB slices)
282+
((5, 100, 200, 3), False, True, 3, "DHWC: 3D volume with RGB slices"),
283+
((5, 100, 200, 1), False, True, 1, "DHWC: 3D volume with single channel"),
284+
285+
# DNHW: shape=(5, 10, 100, 200) → channels=1 (batch of 3D volumes)
286+
((5, 10, 100, 200), True, True, 1, "DNHW: batch of 3D volumes"),
287+
288+
# DNHWC: shape=(5, 10, 100, 200, 3) → channels=3 (batch of 3D volumes with RGB)
289+
((5, 10, 100, 200, 3), True, True, 3, "DNHWC: batch of 3D volumes with RGB"),
290+
((5, 10, 100, 200, 1), True, True, 1, "DNHWC: batch of 3D volumes with single channel"),
291+
292+
# Additional edge cases
293+
((32, 32), False, False, 1, "HW: square grayscale"),
294+
((224, 224, 3), False, False, 3, "HWC: standard RGB image size"),
295+
((1, 512, 512), True, False, 1, "NHW: single image in batch"),
296+
((1, 512, 512, 3), True, False, 3, "NHWC: single RGB image in batch"),
297+
])
298+
def test_get_num_channels_with_dimension_flags(shape, has_batch_dim, has_depth_dim, expected_channels, description):
299+
"""Test get_num_channels with batch and depth dimension flags."""
300+
image = np.zeros(shape)
301+
result = get_num_channels(image, has_batch_dim=has_batch_dim, has_depth_dim=has_depth_dim)
302+
assert result == expected_channels, f"Failed for {description} with shape {shape}, has_batch_dim={has_batch_dim}, has_depth_dim={has_depth_dim}"
303+
304+
305+
@pytest.mark.parametrize("shape, has_batch_dim, has_depth_dim, expected_grayscale, description", [
306+
# HW: shape=(100, 200) → grayscale=True
307+
((100, 200), False, False, True, "HW: grayscale image"),
308+
309+
# HWC: shape=(100, 200, 3) → grayscale=False
310+
((100, 200, 3), False, False, False, "HWC: RGB image"),
311+
((100, 200, 1), False, False, True, "HWC: single channel image"),
312+
((100, 200, 4), False, False, False, "HWC: RGBA image"),
313+
314+
# NHW: shape=(10, 100, 200) → grayscale=True
315+
((10, 100, 200), True, False, True, "NHW: batch of grayscale images"),
316+
317+
# NHWC: shape=(10, 100, 200, 3) → grayscale=False
318+
((10, 100, 200, 3), True, False, False, "NHWC: batch of RGB images"),
319+
((10, 100, 200, 1), True, False, True, "NHWC: batch of single channel images"),
320+
321+
# DHW: shape=(5, 100, 200) → grayscale=True
322+
((5, 100, 200), False, True, True, "DHW: 3D volume"),
323+
324+
# DHWC: shape=(5, 100, 200, 3) → grayscale=False
325+
((5, 100, 200, 3), False, True, False, "DHWC: 3D volume with RGB slices"),
326+
((5, 100, 200, 1), False, True, True, "DHWC: 3D volume with single channel"),
327+
328+
# DNHW: shape=(5, 10, 100, 200) → grayscale=True
329+
((5, 10, 100, 200), True, True, True, "DNHW: batch of 3D volumes"),
330+
331+
# DNHWC: shape=(5, 10, 100, 200, 3) → grayscale=False
332+
((5, 10, 100, 200, 3), True, True, False, "DNHWC: batch of 3D volumes with RGB"),
333+
((5, 10, 100, 200, 1), True, True, True, "DNHWC: batch of 3D volumes with single channel"),
334+
])
335+
def test_is_grayscale_image(shape, has_batch_dim, has_depth_dim, expected_grayscale, description):
336+
"""Test is_grayscale_image with various shape combinations."""
337+
image = np.zeros(shape)
338+
result = is_grayscale_image(image, has_batch_dim=has_batch_dim, has_depth_dim=has_depth_dim)
339+
assert result == expected_grayscale, f"Failed for {description} with shape {shape}, has_batch_dim={has_batch_dim}, has_depth_dim={has_depth_dim}"
340+
341+
342+
@pytest.mark.parametrize("shape, has_batch_dim, has_depth_dim", [
343+
# Basic 2D cases
344+
((100, 200), False, False),
345+
((100, 200, 1), False, False),
346+
((100, 200, 3), False, False),
347+
# Batch cases (NHW/NHWC)
348+
((10, 100, 200), True, False),
349+
((10, 100, 200, 1), True, False),
350+
((10, 100, 200, 3), True, False),
351+
# Depth cases (DHW/DHWC)
352+
((5, 100, 200), False, True),
353+
((5, 100, 200, 1), False, True),
354+
((5, 100, 200, 3), False, True),
355+
# Batch and depth cases (DNHW/DNHWC)
356+
((5, 10, 100, 200), True, True),
357+
((5, 10, 100, 200, 1), True, True),
358+
((5, 10, 100, 200, 3), True, True),
359+
])
360+
def test_get_num_channels_and_is_grayscale_consistency(shape, has_batch_dim, has_depth_dim):
361+
"""Test that get_num_channels and is_grayscale_image are consistent."""
362+
image = np.zeros(shape)
363+
num_channels = get_num_channels(image, has_batch_dim=has_batch_dim, has_depth_dim=has_depth_dim)
364+
is_grayscale = is_grayscale_image(image, has_batch_dim=has_batch_dim, has_depth_dim=has_depth_dim)
365+
366+
# is_grayscale should be True if and only if num_channels == 1
367+
assert (num_channels == 1) == is_grayscale, (
368+
f"Inconsistency for shape {shape}, has_batch_dim={has_batch_dim}, has_depth_dim={has_depth_dim}: "
369+
f"num_channels={num_channels}, is_grayscale={is_grayscale}"
370+
)

0 commit comments

Comments
 (0)