|
3 | 3 | import cv2 |
4 | 4 | from albucore.decorators import contiguous |
5 | 5 | from albucore.functions import float32_io, from_float, to_float, uint8_io |
6 | | -from albucore.utils import NPDTYPE_TO_OPENCV_DTYPE, clip, convert_value, get_opencv_dtype_from_numpy, get_num_channels |
| 6 | +from albucore.utils import NPDTYPE_TO_OPENCV_DTYPE, clip, convert_value, get_opencv_dtype_from_numpy, get_num_channels, is_grayscale_image |
7 | 7 |
|
8 | 8 |
|
9 | 9 | @pytest.mark.parametrize("input_img, dtype, expected", [ |
@@ -257,3 +257,114 @@ def test_get_num_channels(shape, expected_channels, description): |
257 | 257 | """Test get_num_channels for various array dimensions.""" |
258 | 258 | image = np.zeros(shape) |
259 | 259 | assert get_num_channels(image) == expected_channels, f"Failed for {description} with shape {shape}" |
| 260 | + |
| 261 | + |
| 262 | +@pytest.mark.parametrize("shape, has_batch_dim, has_depth_dim, expected_channels, description", [ |
| 263 | + # HW: shape=(100, 200) → channels=1 |
| 264 | + ((100, 200), False, False, 1, "HW: grayscale image"), |
| 265 | +
|
| 266 | + # HWC: shape=(100, 200, 3) → channels=3 |
| 267 | + ((100, 200, 3), False, False, 3, "HWC: RGB image"), |
| 268 | + ((100, 200, 1), False, False, 1, "HWC: single channel image"), |
| 269 | + ((100, 200, 4), False, False, 4, "HWC: RGBA image"), |
| 270 | +
|
| 271 | + # NHW: shape=(10, 100, 200) → channels=1 (batch of grayscale) |
| 272 | + ((10, 100, 200), True, False, 1, "NHW: batch of grayscale images"), |
| 273 | +
|
| 274 | + # NHWC: shape=(10, 100, 200, 3) → channels=3 (batch of RGB) |
| 275 | + ((10, 100, 200, 3), True, False, 3, "NHWC: batch of RGB images"), |
| 276 | + ((10, 100, 200, 1), True, False, 1, "NHWC: batch of single channel images"), |
| 277 | +
|
| 278 | + # DHW: shape=(5, 100, 200) → channels=1 (3D volume) |
| 279 | + ((5, 100, 200), False, True, 1, "DHW: 3D volume"), |
| 280 | +
|
| 281 | + # DHWC: shape=(5, 100, 200, 3) → channels=3 (3D volume with RGB slices) |
| 282 | + ((5, 100, 200, 3), False, True, 3, "DHWC: 3D volume with RGB slices"), |
| 283 | + ((5, 100, 200, 1), False, True, 1, "DHWC: 3D volume with single channel"), |
| 284 | +
|
| 285 | + # DNHW: shape=(5, 10, 100, 200) → channels=1 (batch of 3D volumes) |
| 286 | + ((5, 10, 100, 200), True, True, 1, "DNHW: batch of 3D volumes"), |
| 287 | +
|
| 288 | + # DNHWC: shape=(5, 10, 100, 200, 3) → channels=3 (batch of 3D volumes with RGB) |
| 289 | + ((5, 10, 100, 200, 3), True, True, 3, "DNHWC: batch of 3D volumes with RGB"), |
| 290 | + ((5, 10, 100, 200, 1), True, True, 1, "DNHWC: batch of 3D volumes with single channel"), |
| 291 | +
|
| 292 | + # Additional edge cases |
| 293 | + ((32, 32), False, False, 1, "HW: square grayscale"), |
| 294 | + ((224, 224, 3), False, False, 3, "HWC: standard RGB image size"), |
| 295 | + ((1, 512, 512), True, False, 1, "NHW: single image in batch"), |
| 296 | + ((1, 512, 512, 3), True, False, 3, "NHWC: single RGB image in batch"), |
| 297 | +]) |
| 298 | +def test_get_num_channels_with_dimension_flags(shape, has_batch_dim, has_depth_dim, expected_channels, description): |
| 299 | + """Test get_num_channels with batch and depth dimension flags.""" |
| 300 | + image = np.zeros(shape) |
| 301 | + result = get_num_channels(image, has_batch_dim=has_batch_dim, has_depth_dim=has_depth_dim) |
| 302 | + assert result == expected_channels, f"Failed for {description} with shape {shape}, has_batch_dim={has_batch_dim}, has_depth_dim={has_depth_dim}" |
| 303 | + |
| 304 | + |
| 305 | +@pytest.mark.parametrize("shape, has_batch_dim, has_depth_dim, expected_grayscale, description", [ |
| 306 | + # HW: shape=(100, 200) → grayscale=True |
| 307 | + ((100, 200), False, False, True, "HW: grayscale image"), |
| 308 | +
|
| 309 | + # HWC: shape=(100, 200, 3) → grayscale=False |
| 310 | + ((100, 200, 3), False, False, False, "HWC: RGB image"), |
| 311 | + ((100, 200, 1), False, False, True, "HWC: single channel image"), |
| 312 | + ((100, 200, 4), False, False, False, "HWC: RGBA image"), |
| 313 | +
|
| 314 | + # NHW: shape=(10, 100, 200) → grayscale=True |
| 315 | + ((10, 100, 200), True, False, True, "NHW: batch of grayscale images"), |
| 316 | +
|
| 317 | + # NHWC: shape=(10, 100, 200, 3) → grayscale=False |
| 318 | + ((10, 100, 200, 3), True, False, False, "NHWC: batch of RGB images"), |
| 319 | + ((10, 100, 200, 1), True, False, True, "NHWC: batch of single channel images"), |
| 320 | +
|
| 321 | + # DHW: shape=(5, 100, 200) → grayscale=True |
| 322 | + ((5, 100, 200), False, True, True, "DHW: 3D volume"), |
| 323 | +
|
| 324 | + # DHWC: shape=(5, 100, 200, 3) → grayscale=False |
| 325 | + ((5, 100, 200, 3), False, True, False, "DHWC: 3D volume with RGB slices"), |
| 326 | + ((5, 100, 200, 1), False, True, True, "DHWC: 3D volume with single channel"), |
| 327 | +
|
| 328 | + # DNHW: shape=(5, 10, 100, 200) → grayscale=True |
| 329 | + ((5, 10, 100, 200), True, True, True, "DNHW: batch of 3D volumes"), |
| 330 | +
|
| 331 | + # DNHWC: shape=(5, 10, 100, 200, 3) → grayscale=False |
| 332 | + ((5, 10, 100, 200, 3), True, True, False, "DNHWC: batch of 3D volumes with RGB"), |
| 333 | + ((5, 10, 100, 200, 1), True, True, True, "DNHWC: batch of 3D volumes with single channel"), |
| 334 | +]) |
| 335 | +def test_is_grayscale_image(shape, has_batch_dim, has_depth_dim, expected_grayscale, description): |
| 336 | + """Test is_grayscale_image with various shape combinations.""" |
| 337 | + image = np.zeros(shape) |
| 338 | + result = is_grayscale_image(image, has_batch_dim=has_batch_dim, has_depth_dim=has_depth_dim) |
| 339 | + assert result == expected_grayscale, f"Failed for {description} with shape {shape}, has_batch_dim={has_batch_dim}, has_depth_dim={has_depth_dim}" |
| 340 | + |
| 341 | + |
| 342 | +@pytest.mark.parametrize("shape, has_batch_dim, has_depth_dim", [ |
| 343 | + # Basic 2D cases |
| 344 | + ((100, 200), False, False), |
| 345 | + ((100, 200, 1), False, False), |
| 346 | + ((100, 200, 3), False, False), |
| 347 | + # Batch cases (NHW/NHWC) |
| 348 | + ((10, 100, 200), True, False), |
| 349 | + ((10, 100, 200, 1), True, False), |
| 350 | + ((10, 100, 200, 3), True, False), |
| 351 | + # Depth cases (DHW/DHWC) |
| 352 | + ((5, 100, 200), False, True), |
| 353 | + ((5, 100, 200, 1), False, True), |
| 354 | + ((5, 100, 200, 3), False, True), |
| 355 | + # Batch and depth cases (DNHW/DNHWC) |
| 356 | + ((5, 10, 100, 200), True, True), |
| 357 | + ((5, 10, 100, 200, 1), True, True), |
| 358 | + ((5, 10, 100, 200, 3), True, True), |
| 359 | +]) |
| 360 | +def test_get_num_channels_and_is_grayscale_consistency(shape, has_batch_dim, has_depth_dim): |
| 361 | + """Test that get_num_channels and is_grayscale_image are consistent.""" |
| 362 | + image = np.zeros(shape) |
| 363 | + num_channels = get_num_channels(image, has_batch_dim=has_batch_dim, has_depth_dim=has_depth_dim) |
| 364 | + is_grayscale = is_grayscale_image(image, has_batch_dim=has_batch_dim, has_depth_dim=has_depth_dim) |
| 365 | + |
| 366 | + # is_grayscale should be True if and only if num_channels == 1 |
| 367 | + assert (num_channels == 1) == is_grayscale, ( |
| 368 | + f"Inconsistency for shape {shape}, has_batch_dim={has_batch_dim}, has_depth_dim={has_depth_dim}: " |
| 369 | + f"num_channels={num_channels}, is_grayscale={is_grayscale}" |
| 370 | + ) |
0 commit comments