@@ -276,8 +276,13 @@ def add(img: np.ndarray, value: ValueType, inplace: bool = False) -> np.ndarray:
276276
277277def normalize_numpy (img : np .ndarray , mean : float | np .ndarray , denominator : float | np .ndarray ) -> np .ndarray :
278278 img = img .astype (np .float32 , copy = False )
279+ # Ensure mean and denominator are float32 to avoid dtype promotion
280+ if isinstance (mean , np .ndarray ):
281+ mean = mean .astype (np .float32 , copy = False )
282+ if isinstance (denominator , np .ndarray ):
283+ denominator = denominator .astype (np .float32 , copy = False )
279284 img -= mean
280- return img * denominator
285+ return ( img * denominator ). astype ( np . float32 , copy = False )
281286
282287
283288@preserve_channel_dim
@@ -286,12 +291,6 @@ def normalize_opencv(img: np.ndarray, mean: float | np.ndarray, denominator: flo
286291 mean_img = np .zeros_like (img , dtype = np .float32 )
287292 denominator_img = np .zeros_like (img , dtype = np .float32 )
288293
289- # If mean or denominator are scalar, convert them to arrays
290- if isinstance (mean , (float , int )):
291- mean = np .full (img .shape , mean , dtype = np .float32 )
292- if isinstance (denominator , (float , int )):
293- denominator = np .full (img .shape , denominator , dtype = np .float32 )
294-
295294 # Ensure the shapes match for broadcasting
296295 mean_img = (mean_img + mean ).astype (np .float32 , copy = False )
297296 denominator_img = denominator_img + denominator
@@ -307,27 +306,39 @@ def normalize_lut(img: np.ndarray, mean: float | np.ndarray, denominator: float
307306 num_channels = get_num_channels (img )
308307
309308 if isinstance (denominator , (float , int )) and isinstance (mean , (float , int )):
310- lut = (np .arange (0 , max_value + 1 , dtype = np .float32 ) - mean ) * denominator
309+ lut = (( np .arange (0 , max_value + 1 , dtype = np .float32 ) - mean ) * denominator ). astype ( np . float32 )
311310 return cv2 .LUT (img , lut )
312311
313- if isinstance (denominator , np .ndarray ) and denominator .shape != ():
314- denominator = denominator .reshape (- 1 , 1 )
315-
312+ # Convert to float32 if needed
316313 if isinstance (mean , np .ndarray ):
317- mean = mean .reshape (- 1 , 1 )
314+ mean = mean .astype (np .float32 , copy = False )
315+ if isinstance (denominator , np .ndarray ):
316+ denominator = denominator .astype (np .float32 , copy = False )
317+
318+ # Vectorized LUT creation - shape: (256, num_channels)
319+ arange_vals = np .arange (0 , max_value + 1 , dtype = np .float32 )
320+ luts = (arange_vals [:, np .newaxis ] - mean ) * denominator
318321
319- luts = (np .arange (0 , max_value + 1 , dtype = np .float32 ) - mean ) * denominator
322+ # Pre-allocate result array
323+ result = np .empty_like (img , dtype = np .float32 )
324+ for i in range (num_channels ):
325+ result [..., i ] = cv2 .LUT (img [..., i ], luts [:, i ])
320326
321- return np . stack ([ cv2 . LUT ( img [..., i ], luts [ i ]) for i in range ( num_channels )], axis = - 1 )
327+ return result
322328
323329
def normalize(img: np.ndarray, mean: ValueType, denominator: ValueType) -> np.ndarray:
    """Normalize ``img`` with the backend chosen from its dtype.

    ``mean`` and ``denominator`` are first passed through ``convert_value``
    together with the image's channel count (presumably to bring them into a
    per-channel-compatible form -- confirm against ``convert_value``).

    Args:
        img: Input image.
        mean: Value(s) to subtract.
        denominator: Value(s) to multiply by after subtraction.

    Returns:
        The result of the selected backend: ``normalize_lut`` for uint8,
        ``normalize_numpy`` for float32, ``normalize_opencv`` otherwise.
    """
    channel_count = get_num_channels(img)
    denominator = convert_value(denominator, channel_count)
    mean = convert_value(mean, channel_count)

    if img.dtype == np.uint8:
        backend = normalize_lut
    elif img.dtype == np.float32:
        backend = normalize_numpy
    else:
        # Every other dtype falls back to the OpenCV implementation.
        backend = normalize_opencv

    return backend(img, mean, denominator)
332343
333344
@@ -474,6 +485,66 @@ def multiply_add(img: np.ndarray, factor: ValueType, value: ValueType, inplace:
474485 return multiply_add_opencv (img , factor , value )
475486
476487
488+ def _compute_image_stats_opencv (img : np .ndarray ) -> tuple [float , float ]:
489+ """Compute global mean and std for an image."""
490+ eps = 1e-4
491+ if img .ndim > 3 :
492+ # For 4D/5D arrays (video/volume), OpenCV returns global mean/std directly
493+ mean , std = cv2 .meanStdDev (img )
494+ return float (mean [0 , 0 ]), float (std [0 , 0 ]) + eps
495+ # For 3D images, use numpy for accurate global statistics
496+ return float (img .mean ()), float (img .std ()) + eps
497+
498+
499+ def _compute_per_channel_stats_opencv (img : np .ndarray , spatial_axes : tuple [int , ...]) -> tuple [np .ndarray , np .ndarray ]:
500+ """Compute per-channel mean and std."""
501+ eps = 1e-4
502+ if img .ndim > 3 :
503+ # For 4D/5D arrays, compute per-channel statistics using numpy
504+ mean = img .mean (axis = spatial_axes )
505+ std = img .std (axis = spatial_axes ) + eps
506+ else :
507+ # For 3D arrays, use OpenCV
508+ mean , std = cv2 .meanStdDev (img )
509+ mean = mean [:, 0 ]
510+ std = std [:, 0 ] + eps
511+ return mean , std
512+
513+
514+ def _normalize_mean_std_opencv (img_f : np .ndarray , mean : float | np .ndarray , std : float | np .ndarray ) -> np .ndarray :
515+ """Apply mean-std normalization using OpenCV or NumPy based on dimensionality."""
516+ if img_f .ndim > 3 :
517+ # Use NumPy operations for 4D/5D (faster)
518+ normalized_img = (img_f - mean ) / std
519+ else :
520+ # Use OpenCV for 3D
521+ if img_f .shape [- 1 ] > MAX_OPENCV_WORKING_CHANNELS :
522+ mean = np .full_like (img_f , mean )
523+ std = np .full_like (img_f , std )
524+ normalized_img = cv2 .divide (cv2 .subtract (img_f , mean , dtype = cv2 .CV_32F ), std , dtype = cv2 .CV_32F )
525+ return np .clip (normalized_img , - 20 , 20 , out = normalized_img )
526+
527+
528+ def _normalize_min_max_per_channel_opencv (img : np .ndarray , spatial_axes : tuple [int , ...]) -> np .ndarray :
529+ """Apply per-channel min-max normalization."""
530+ eps = 1e-4
531+
532+ img_min = img .min (axis = spatial_axes )
533+ img_max = img .max (axis = spatial_axes )
534+
535+ if img .shape [- 1 ] > MAX_OPENCV_WORKING_CHANNELS :
536+ img_min = np .full_like (img , img_min )
537+ img_max = np .full_like (img , img_max )
538+
539+ # Use NumPy operations for 4D/5D (faster), OpenCV for 3D
540+ if img .ndim > 3 :
541+ normalized_img = (img - img_min ) / (img_max - img_min + eps )
542+ else :
543+ normalized_img = cv2 .divide (cv2 .subtract (img , img_min ), (img_max - img_min + eps ), dtype = cv2 .CV_32F )
544+
545+ return np .clip (normalized_img , - 20 , 20 , out = normalized_img )
546+
547+
477548@preserve_channel_dim
478549def normalize_per_image_opencv (
479550 img : np .ndarray ,
@@ -508,47 +579,27 @@ def normalize_per_image_opencv(
508579 - For images with >4 channels, falls back to array operations as OpenCV has limitations
509580 - Single channel images treated as "image" normalization when "image_per_channel" is specified
510581 """
511- img = img .astype (np .float32 , copy = False )
512- eps = 1e-4
582+ # Handle single-channel edge case
583+ if img .shape [- 1 ] == 1 and normalization == "image_per_channel" :
584+ normalization = "image"
585+ if img .shape [- 1 ] == 1 and normalization == "min_max_per_channel" :
586+ normalization = "min_max"
513587
514- if normalization == "image" or (img .shape [- 1 ] == 1 and normalization == "image_per_channel" ):
515- mean = img .mean ().item ()
516- std = img .std ().item () + eps
517- if img .shape [- 1 ] > MAX_OPENCV_WORKING_CHANNELS :
518- mean = np .full_like (img , mean )
519- std = np .full_like (img , std )
520- normalized_img = cv2 .divide (cv2 .subtract (img , mean ), std )
521- return np .clip (normalized_img , - 20 , 20 , out = normalized_img )
588+ if normalization == "image" :
589+ mean , std = _compute_image_stats_opencv (img )
590+ img_f = img .astype (np .float32 , copy = False )
591+ return _normalize_mean_std_opencv (img_f , mean , std )
522592
523593 if normalization == "image_per_channel" :
524- mean , std = cv2 .meanStdDev (img )
525- mean = mean [:, 0 ]
526- std = std [:, 0 ]
527-
528- if img .shape [- 1 ] > MAX_OPENCV_WORKING_CHANNELS :
529- mean = np .full_like (img , mean )
530- std = np .full_like (img , std )
594+ mean , std = _compute_per_channel_stats_opencv (img , spatial_axes )
595+ img_f = img .astype (np .float32 , copy = False )
596+ return _normalize_mean_std_opencv (img_f , mean , std )
531597
532- normalized_img = cv2 .divide (cv2 .subtract (img , mean ), std , dtype = cv2 .CV_32F )
533- return np .clip (normalized_img , - 20 , 20 , out = normalized_img )
534-
535- if normalization == "min_max" or (img .shape [- 1 ] == 1 and normalization == "min_max_per_channel" ):
598+ if normalization == "min_max" :
536599 return cv2 .normalize (img , None , alpha = 0 , beta = 1 , norm_type = cv2 .NORM_MINMAX , dtype = cv2 .CV_32F )
537600
538601 if normalization == "min_max_per_channel" :
539- img_min = img .min (axis = spatial_axes )
540- img_max = img .max (axis = spatial_axes )
541-
542- if img .shape [- 1 ] > MAX_OPENCV_WORKING_CHANNELS :
543- img_min = np .full_like (img , img_min )
544- img_max = np .full_like (img , img_max )
545-
546- return np .clip (
547- cv2 .divide (cv2 .subtract (img , img_min ), (img_max - img_min + eps ), dtype = cv2 .CV_32F ),
548- - 20 ,
549- 20 ,
550- out = img ,
551- )
602+ return _normalize_min_max_per_channel_opencv (img , spatial_axes )
552603
553604 raise ValueError (f"Unknown normalization method: { normalization } " )
554605
@@ -656,33 +707,52 @@ def normalize_per_image_lut(
656707 num_channels = get_num_channels (img )
657708
658709 if normalization == "image" or (img .shape [- 1 ] == 1 and normalization == "image_per_channel" ):
659- mean = img .mean ()
660- std = img .std () + eps
661- lut = (np .arange (0 , max_value + 1 , dtype = np .float32 ) - mean ) / std
662- return cv2 .LUT (img , lut ).clip (- 20 , 20 ).astype (np .float32 )
710+ if img .ndim > 3 :
711+ # For 4D/5D arrays (video/volume), OpenCV returns global mean/std directly
712+ mean , std = cv2 .meanStdDev (img )
713+ mean = mean [0 , 0 ]
714+ std = std [0 , 0 ] + eps
715+ else :
716+ # For 3D images, use numpy for accurate global statistics
717+ mean = img .mean ()
718+ std = img .std () + eps
719+
720+ lut = ((np .arange (0 , max_value + 1 , dtype = np .float32 ) - mean ) / std ).astype (np .float32 )
721+ return cv2 .LUT (img , lut ).clip (- 20 , 20 )
663722
664723 if normalization == "image_per_channel" :
665- pixel_mean = img .mean (axis = spatial_axes )
666- pixel_std = img .std (axis = spatial_axes ) + eps
667- luts = [
668- (np .arange (0 , max_value + 1 , dtype = np .float32 ) - pixel_mean [c ]) / pixel_std [c ] for c in range (num_channels )
669- ]
670- return np .stack ([cv2 .LUT (img [..., i ], luts [i ]).clip (- 20 , 20 ) for i in range (num_channels )], axis = - 1 )
724+ pixel_mean = img .mean (axis = spatial_axes ).astype (np .float32 )
725+ pixel_std = img .std (axis = spatial_axes ).astype (np .float32 ) + np .float32 (eps )
726+
727+ # Create all LUTs at once using vectorized operations
728+ arange_vals = np .arange (0 , max_value + 1 , dtype = np .float32 )
729+ # LUTs shape will be (256, num_channels)
730+ luts = (arange_vals [:, np .newaxis ] - pixel_mean ) / pixel_std
731+
732+ result = np .empty_like (img , dtype = np .float32 )
733+ for i in range (num_channels ):
734+ result [..., i ] = cv2 .LUT (img [..., i ], luts [:, i ])
735+ return result .clip (- 20 , 20 )
671736
672737 if normalization == "min_max" or (img .shape [- 1 ] == 1 and normalization == "min_max_per_channel" ):
673738 img_min = img .min ()
674739 img_max = img .max ()
675- lut = (np .arange (0 , max_value + 1 , dtype = np .float32 ) - img_min ) / (img_max - img_min + eps )
676- return cv2 .LUT (img , lut ).clip (- 20 , 20 ). astype ( np . float32 )
740+ lut = (( np .arange (0 , max_value + 1 , dtype = np .float32 ) - img_min ) / (img_max - img_min + eps )). astype ( np . float32 )
741+ return cv2 .LUT (img , lut ).clip (- 20 , 20 )
677742
678743 if normalization == "min_max_per_channel" :
679744 img_min = img .min (axis = spatial_axes )
680745 img_max = img .max (axis = spatial_axes )
681- luts = [
682- (np .arange (0 , max_value + 1 , dtype = np .float32 ) - img_min [c ]) / (img_max [c ] - img_min [c ] + eps )
683- for c in range (num_channels )
684- ]
685- return np .stack ([cv2 .LUT (img [..., i ], luts [i ]) for i in range (num_channels )], axis = - 1 ).astype (np .float32 )
746+
747+ # Create all LUTs at once using vectorized operations
748+ arange_vals = np .arange (0 , max_value + 1 , dtype = np .float32 )
749+ # LUTs shape will be (256, num_channels)
750+ luts = ((arange_vals [:, np .newaxis ] - img_min ) / (img_max - img_min + eps )).astype (np .float32 )
751+
752+ result = np .empty_like (img , dtype = np .float32 )
753+ for i in range (num_channels ):
754+ result [..., i ] = cv2 .LUT (img [..., i ], luts [:, i ])
755+ return result .clip (- 20 , 20 )
686756
687757 raise ValueError (f"Unknown normalization method: { normalization } " )
688758
@@ -705,8 +775,8 @@ def normalize_per_image(img: np.ndarray, normalization: NormalizationType) -> np
705775 Normalized image as float32 array with values clipped to [-20, 20] range.
706776
707777 Notes:
708- - For uint8 images (except "image_per_channel "), uses LUT method for maximum speed
709- - For other dtypes, uses OpenCV implementation for good performance
778+ - For uint8 images (except "min_max "), uses LUT method for maximum speed
779+ - For other dtypes, uses OpenCV or NumPy implementation for good performance
710780 - Automatically determines spatial axes based on input dimensions
711781 """
712782 # Determine spatial axes based on input dimensions
@@ -719,17 +789,27 @@ def normalize_per_image(img: np.ndarray, normalization: NormalizationType) -> np
719789 else :
720790 raise ValueError (f"Unsupported image dimensions: { img .ndim } . Expected 3, 4, or 5 dimensions." )
721791
722- if img .dtype == np .uint8 and (
723- (normalization != "image_per_channel" and img .ndim == 3 )
724- or (normalization == "min_max_per_channel" and img .ndim > 3 )
725- or (normalization == "image_per_channel" and img .ndim > 3 )
726- ):
792+ # Optimized routing based on benchmarks
793+
794+ # Route uint8 images
795+ if img .dtype == np .uint8 :
796+ # Use LUT for everything except min_max (where OpenCV is 3x faster)
797+ if normalization == "min_max" :
798+ return normalize_per_image_opencv (img , normalization , spatial_axes )
799+ # LUT is fastest for "image", "image_per_channel", and "min_max_per_channel"
727800 return normalize_per_image_lut (img , normalization , spatial_axes )
728801
729- # For ndim > 3, use numpy implementation as OpenCV doesn't handle batch dimensions well
802+ # Route float32 images
803+ if img .dtype == np .float32 :
804+ if normalization == "image" :
805+ # NumPy is 1.5x faster for "image" normalization
806+ return normalize_per_image_numpy (img , normalization , spatial_axes )
807+ # OpenCV is fastest or equal for all other normalizations
808+ return normalize_per_image_opencv (img , normalization , spatial_axes )
809+
810+ # Default fallback: OpenCV for single images, NumPy for videos/volumes
730811 if img .ndim > 3 :
731812 return normalize_per_image_numpy (img , normalization , spatial_axes )
732-
733813 return normalize_per_image_opencv (img , normalization , spatial_axes )
734814
735815
0 commit comments