
Commit bcf3ff1

Speedup in normalize
1 parent 38a76fa commit bcf3ff1

File tree

5 files changed: +233 -85 lines changed


.cursor/rules/optimizations.mdc

Lines changed: 64 additions & 0 deletions (new file)

# Performance Optimization Guidelines

## OpenCV LUT (Look-Up Table) Operations

### Critical: Maintain float32 dtype for LUT arrays

When using `cv2.LUT()` with floating-point lookup tables, **always ensure the LUT array is float32, not float64**. This can have a dramatic performance impact, especially on large arrays like videos.

#### The Problem

OpenCV's statistics functions (`cv2.meanStdDev`, etc.) return float64 values. When these are used in LUT creation:

```python
# BAD: Creates float64 LUT due to numpy promotion
mean, std = cv2.meanStdDev(img)  # Returns float64
lut = (np.arange(0, 256, dtype=np.float32) - mean[0, 0]) / std[0, 0]
# lut.dtype is now float64!
```

This causes:

1. `cv2.LUT()` returns a float64 array (slower operations)
2. Subsequent operations (clip, etc.) are slower on float64
3. Often requires `.astype(np.float32)` on the large result array (very expensive)

#### The Solution

Cast the LUT array to float32 after creation:

```python
# GOOD: Maintain float32 throughout
lut = ((np.arange(0, 256, dtype=np.float32) - mean[0, 0]) / std[0, 0]).astype(np.float32)
# lut.dtype is float32
```

#### Performance Impact

For a video of shape (200, 256, 256, 3):

- With float64 LUT: ~111ms (includes expensive astype on result)
- With float32 LUT: ~55ms (2x faster!)
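The comparison above can be reproduced with a rough timing sketch like the one below (a sketch only: numbers vary by hardware and OpenCV build, the frame axis is folded into the row axis so `cv2.LUT` sees an ordinary 3-channel uint8 image, and the float64 table is built explicitly to simulate the accidental promotion):

```python
import time

import cv2
import numpy as np

video = np.random.randint(0, 256, (200, 256, 256, 3), dtype=np.uint8)
img = video.reshape(-1, 256, 3)  # (200 * 256, 256, 3)

mean, std = cv2.meanStdDev(img)
ramp = np.arange(0, 256, dtype=np.float32)
lut64 = ((ramp - mean[0, 0]) / std[0, 0]).astype(np.float64)  # the accidentally promoted table
lut32 = ((ramp - mean[0, 0]) / std[0, 0]).astype(np.float32)  # the fixed table

start = time.perf_counter()
slow = cv2.LUT(img, lut64).clip(-20, 20).astype(np.float32)  # float64 result + expensive cast
t64 = time.perf_counter() - start

start = time.perf_counter()
fast = cv2.LUT(img, lut32).clip(-20, 20)  # stays float32 end to end
t32 = time.perf_counter() - start

print(f"float64 LUT: {t64 * 1000:.1f} ms, float32 LUT: {t32 * 1000:.1f} ms")
```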
#### Best Practices

1. **For uint8 images**: LUT operations are extremely fast and should be preferred when possible
2. **Always check dtype**: Use `.astype(np.float32)` on small LUT arrays (256 elements) rather than large result arrays
3. **Avoid dtype promotion**: Be aware that numpy operations with mixed dtypes promote to the higher precision type
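The promotion in item 3 can be demonstrated directly with a tiny sketch (made-up per-channel statistics; only the small 256-row table is cast, never the full-size result):

```python
import numpy as np

arange_vals = np.arange(0, 256, dtype=np.float32)      # float32 ramp
mean = np.array([120.3, 115.9, 110.1])                 # float64, e.g. from cv2.meanStdDev
denominator = np.array([0.017, 0.018, 0.019])          # float64

luts = (arange_vals[:, np.newaxis] - mean) * denominator
print(luts.dtype)   # float64 -- mixed dtypes promoted the whole table

luts = luts.astype(np.float32)                          # cast the small (256, 3) table instead
print(luts.dtype)   # float32
```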
#### Example: Image Normalization with LUT

```python
def normalize_with_lut(img: np.ndarray) -> np.ndarray:
    """Fast normalization for uint8 images using LUT"""
    # Get statistics
    mean, std = cv2.meanStdDev(img)
    mean = mean[0, 0]
    std = std[0, 0] + 1e-4

    # Create LUT - ensure float32!
    lut = ((np.arange(0, 256, dtype=np.float32) - mean) / std).astype(np.float32)

    # Apply LUT - result will be float32
    return cv2.LUT(img, lut).clip(-20, 20)
```

This optimization applies to any LUT-based operation where floating-point precision is needed.
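A quick sanity check of the helper above (a sketch; assumes `cv2` and `numpy as np` are already imported):

```python
img = np.random.randint(0, 256, (256, 256, 3), dtype=np.uint8)
out = normalize_with_lut(img)
print(out.dtype, out.shape)  # float32 (256, 256, 3)
```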

albucore/functions.py

Lines changed: 154 additions & 74 deletions

@@ -276,8 +276,13 @@ def add(img: np.ndarray, value: ValueType, inplace: bool = False) -> np.ndarray:

 def normalize_numpy(img: np.ndarray, mean: float | np.ndarray, denominator: float | np.ndarray) -> np.ndarray:
     img = img.astype(np.float32, copy=False)
+    # Ensure mean and denominator are float32 to avoid dtype promotion
+    if isinstance(mean, np.ndarray):
+        mean = mean.astype(np.float32, copy=False)
+    if isinstance(denominator, np.ndarray):
+        denominator = denominator.astype(np.float32, copy=False)
     img -= mean
-    return img * denominator
+    return (img * denominator).astype(np.float32, copy=False)


 @preserve_channel_dim
@@ -286,12 +291,6 @@ def normalize_opencv(img: np.ndarray, mean: float | np.ndarray, denominator: flo
     mean_img = np.zeros_like(img, dtype=np.float32)
     denominator_img = np.zeros_like(img, dtype=np.float32)

-    # If mean or denominator are scalar, convert them to arrays
-    if isinstance(mean, (float, int)):
-        mean = np.full(img.shape, mean, dtype=np.float32)
-    if isinstance(denominator, (float, int)):
-        denominator = np.full(img.shape, denominator, dtype=np.float32)
-
     # Ensure the shapes match for broadcasting
     mean_img = (mean_img + mean).astype(np.float32, copy=False)
     denominator_img = denominator_img + denominator
@@ -307,27 +306,39 @@ def normalize_lut(img: np.ndarray, mean: float | np.ndarray, denominator: float
     num_channels = get_num_channels(img)

     if isinstance(denominator, (float, int)) and isinstance(mean, (float, int)):
-        lut = (np.arange(0, max_value + 1, dtype=np.float32) - mean) * denominator
+        lut = ((np.arange(0, max_value + 1, dtype=np.float32) - mean) * denominator).astype(np.float32)
         return cv2.LUT(img, lut)

-    if isinstance(denominator, np.ndarray) and denominator.shape != ():
-        denominator = denominator.reshape(-1, 1)
-
+    # Convert to float32 if needed
     if isinstance(mean, np.ndarray):
-        mean = mean.reshape(-1, 1)
+        mean = mean.astype(np.float32, copy=False)
+    if isinstance(denominator, np.ndarray):
+        denominator = denominator.astype(np.float32, copy=False)
+
+    # Vectorized LUT creation - shape: (256, num_channels)
+    arange_vals = np.arange(0, max_value + 1, dtype=np.float32)
+    luts = (arange_vals[:, np.newaxis] - mean) * denominator

-    luts = (np.arange(0, max_value + 1, dtype=np.float32) - mean) * denominator
+    # Pre-allocate result array
+    result = np.empty_like(img, dtype=np.float32)
+    for i in range(num_channels):
+        result[..., i] = cv2.LUT(img[..., i], luts[:, i])

-    return np.stack([cv2.LUT(img[..., i], luts[i]) for i in range(num_channels)], axis=-1)
+    return result


 def normalize(img: np.ndarray, mean: ValueType, denominator: ValueType) -> np.ndarray:
     num_channels = get_num_channels(img)
     denominator = convert_value(denominator, num_channels)
     mean = convert_value(mean, num_channels)
+
     if img.dtype == np.uint8:
         return normalize_lut(img, mean, denominator)

+    if img.dtype == np.float32:
+        return normalize_numpy(img, mean, denominator)
+
+    # Fallback to OpenCV for other dtypes
     return normalize_opencv(img, mean, denominator)


@@ -474,6 +485,66 @@ def multiply_add(img: np.ndarray, factor: ValueType, value: ValueType, inplace:
     return multiply_add_opencv(img, factor, value)


+def _compute_image_stats_opencv(img: np.ndarray) -> tuple[float, float]:
+    """Compute global mean and std for an image."""
+    eps = 1e-4
+    if img.ndim > 3:
+        # For 4D/5D arrays (video/volume), OpenCV returns global mean/std directly
+        mean, std = cv2.meanStdDev(img)
+        return float(mean[0, 0]), float(std[0, 0]) + eps
+    # For 3D images, use numpy for accurate global statistics
+    return float(img.mean()), float(img.std()) + eps
+
+
+def _compute_per_channel_stats_opencv(img: np.ndarray, spatial_axes: tuple[int, ...]) -> tuple[np.ndarray, np.ndarray]:
+    """Compute per-channel mean and std."""
+    eps = 1e-4
+    if img.ndim > 3:
+        # For 4D/5D arrays, compute per-channel statistics using numpy
+        mean = img.mean(axis=spatial_axes)
+        std = img.std(axis=spatial_axes) + eps
+    else:
+        # For 3D arrays, use OpenCV
+        mean, std = cv2.meanStdDev(img)
+        mean = mean[:, 0]
+        std = std[:, 0] + eps
+    return mean, std
+
+
+def _normalize_mean_std_opencv(img_f: np.ndarray, mean: float | np.ndarray, std: float | np.ndarray) -> np.ndarray:
+    """Apply mean-std normalization using OpenCV or NumPy based on dimensionality."""
+    if img_f.ndim > 3:
+        # Use NumPy operations for 4D/5D (faster)
+        normalized_img = (img_f - mean) / std
+    else:
+        # Use OpenCV for 3D
+        if img_f.shape[-1] > MAX_OPENCV_WORKING_CHANNELS:
+            mean = np.full_like(img_f, mean)
+            std = np.full_like(img_f, std)
+        normalized_img = cv2.divide(cv2.subtract(img_f, mean, dtype=cv2.CV_32F), std, dtype=cv2.CV_32F)
+    return np.clip(normalized_img, -20, 20, out=normalized_img)
+
+
+def _normalize_min_max_per_channel_opencv(img: np.ndarray, spatial_axes: tuple[int, ...]) -> np.ndarray:
+    """Apply per-channel min-max normalization."""
+    eps = 1e-4
+
+    img_min = img.min(axis=spatial_axes)
+    img_max = img.max(axis=spatial_axes)
+
+    if img.shape[-1] > MAX_OPENCV_WORKING_CHANNELS:
+        img_min = np.full_like(img, img_min)
+        img_max = np.full_like(img, img_max)
+
+    # Use NumPy operations for 4D/5D (faster), OpenCV for 3D
+    if img.ndim > 3:
+        normalized_img = (img - img_min) / (img_max - img_min + eps)
+    else:
+        normalized_img = cv2.divide(cv2.subtract(img, img_min), (img_max - img_min + eps), dtype=cv2.CV_32F)
+
+    return np.clip(normalized_img, -20, 20, out=normalized_img)
+
+
 @preserve_channel_dim
 def normalize_per_image_opencv(
     img: np.ndarray,
@@ -508,47 +579,27 @@ def normalize_per_image_opencv(
         - For images with >4 channels, falls back to array operations as OpenCV has limitations
         - Single channel images treated as "image" normalization when "image_per_channel" is specified
     """
-    img = img.astype(np.float32, copy=False)
-    eps = 1e-4
+    # Handle single-channel edge case
+    if img.shape[-1] == 1 and normalization == "image_per_channel":
+        normalization = "image"
+    if img.shape[-1] == 1 and normalization == "min_max_per_channel":
+        normalization = "min_max"

-    if normalization == "image" or (img.shape[-1] == 1 and normalization == "image_per_channel"):
-        mean = img.mean().item()
-        std = img.std().item() + eps
-        if img.shape[-1] > MAX_OPENCV_WORKING_CHANNELS:
-            mean = np.full_like(img, mean)
-            std = np.full_like(img, std)
-        normalized_img = cv2.divide(cv2.subtract(img, mean), std)
-        return np.clip(normalized_img, -20, 20, out=normalized_img)
+    if normalization == "image":
+        mean, std = _compute_image_stats_opencv(img)
+        img_f = img.astype(np.float32, copy=False)
+        return _normalize_mean_std_opencv(img_f, mean, std)

     if normalization == "image_per_channel":
-        mean, std = cv2.meanStdDev(img)
-        mean = mean[:, 0]
-        std = std[:, 0]
-
-        if img.shape[-1] > MAX_OPENCV_WORKING_CHANNELS:
-            mean = np.full_like(img, mean)
-            std = np.full_like(img, std)
+        mean, std = _compute_per_channel_stats_opencv(img, spatial_axes)
+        img_f = img.astype(np.float32, copy=False)
+        return _normalize_mean_std_opencv(img_f, mean, std)

-        normalized_img = cv2.divide(cv2.subtract(img, mean), std, dtype=cv2.CV_32F)
-        return np.clip(normalized_img, -20, 20, out=normalized_img)
-
-    if normalization == "min_max" or (img.shape[-1] == 1 and normalization == "min_max_per_channel"):
+    if normalization == "min_max":
         return cv2.normalize(img, None, alpha=0, beta=1, norm_type=cv2.NORM_MINMAX, dtype=cv2.CV_32F)

     if normalization == "min_max_per_channel":
-        img_min = img.min(axis=spatial_axes)
-        img_max = img.max(axis=spatial_axes)
-
-        if img.shape[-1] > MAX_OPENCV_WORKING_CHANNELS:
-            img_min = np.full_like(img, img_min)
-            img_max = np.full_like(img, img_max)
-
-        return np.clip(
-            cv2.divide(cv2.subtract(img, img_min), (img_max - img_min + eps), dtype=cv2.CV_32F),
-            -20,
-            20,
-            out=img,
-        )
+        return _normalize_min_max_per_channel_opencv(img, spatial_axes)

     raise ValueError(f"Unknown normalization method: {normalization}")

@@ -656,33 +707,52 @@ def normalize_per_image_lut(
     num_channels = get_num_channels(img)

     if normalization == "image" or (img.shape[-1] == 1 and normalization == "image_per_channel"):
-        mean = img.mean()
-        std = img.std() + eps
-        lut = (np.arange(0, max_value + 1, dtype=np.float32) - mean) / std
-        return cv2.LUT(img, lut).clip(-20, 20).astype(np.float32)
+        if img.ndim > 3:
+            # For 4D/5D arrays (video/volume), OpenCV returns global mean/std directly
+            mean, std = cv2.meanStdDev(img)
+            mean = mean[0, 0]
+            std = std[0, 0] + eps
+        else:
+            # For 3D images, use numpy for accurate global statistics
+            mean = img.mean()
+            std = img.std() + eps
+
+        lut = ((np.arange(0, max_value + 1, dtype=np.float32) - mean) / std).astype(np.float32)
+        return cv2.LUT(img, lut).clip(-20, 20)

     if normalization == "image_per_channel":
-        pixel_mean = img.mean(axis=spatial_axes)
-        pixel_std = img.std(axis=spatial_axes) + eps
-        luts = [
-            (np.arange(0, max_value + 1, dtype=np.float32) - pixel_mean[c]) / pixel_std[c] for c in range(num_channels)
-        ]
-        return np.stack([cv2.LUT(img[..., i], luts[i]).clip(-20, 20) for i in range(num_channels)], axis=-1)
+        pixel_mean = img.mean(axis=spatial_axes).astype(np.float32)
+        pixel_std = img.std(axis=spatial_axes).astype(np.float32) + np.float32(eps)
+
+        # Create all LUTs at once using vectorized operations
+        arange_vals = np.arange(0, max_value + 1, dtype=np.float32)
+        # LUTs shape will be (256, num_channels)
+        luts = (arange_vals[:, np.newaxis] - pixel_mean) / pixel_std
+
+        result = np.empty_like(img, dtype=np.float32)
+        for i in range(num_channels):
+            result[..., i] = cv2.LUT(img[..., i], luts[:, i])
+        return result.clip(-20, 20)

     if normalization == "min_max" or (img.shape[-1] == 1 and normalization == "min_max_per_channel"):
         img_min = img.min()
         img_max = img.max()
-        lut = (np.arange(0, max_value + 1, dtype=np.float32) - img_min) / (img_max - img_min + eps)
-        return cv2.LUT(img, lut).clip(-20, 20).astype(np.float32)
+        lut = ((np.arange(0, max_value + 1, dtype=np.float32) - img_min) / (img_max - img_min + eps)).astype(np.float32)
+        return cv2.LUT(img, lut).clip(-20, 20)

     if normalization == "min_max_per_channel":
         img_min = img.min(axis=spatial_axes)
         img_max = img.max(axis=spatial_axes)
-        luts = [
-            (np.arange(0, max_value + 1, dtype=np.float32) - img_min[c]) / (img_max[c] - img_min[c] + eps)
-            for c in range(num_channels)
-        ]
-        return np.stack([cv2.LUT(img[..., i], luts[i]) for i in range(num_channels)], axis=-1).astype(np.float32)
+
+        # Create all LUTs at once using vectorized operations
+        arange_vals = np.arange(0, max_value + 1, dtype=np.float32)
+        # LUTs shape will be (256, num_channels)
+        luts = ((arange_vals[:, np.newaxis] - img_min) / (img_max - img_min + eps)).astype(np.float32)
+
+        result = np.empty_like(img, dtype=np.float32)
+        for i in range(num_channels):
+            result[..., i] = cv2.LUT(img[..., i], luts[:, i])
+        return result.clip(-20, 20)

     raise ValueError(f"Unknown normalization method: {normalization}")

@@ -705,8 +775,8 @@ def normalize_per_image(img: np.ndarray, normalization: NormalizationType) -> np
         Normalized image as float32 array with values clipped to [-20, 20] range.

     Notes:
-        - For uint8 images (except "image_per_channel"), uses LUT method for maximum speed
-        - For other dtypes, uses OpenCV implementation for good performance
+        - For uint8 images (except "min_max"), uses LUT method for maximum speed
+        - For other dtypes, uses OpenCV or NumPy implementation for good performance
         - Automatically determines spatial axes based on input dimensions
     """
     # Determine spatial axes based on input dimensions
@@ -719,17 +789,27 @@ def normalize_per_image(img: np.ndarray, normalization: NormalizationType) -> np
     else:
         raise ValueError(f"Unsupported image dimensions: {img.ndim}. Expected 3, 4, or 5 dimensions.")

-    if img.dtype == np.uint8 and (
-        (normalization != "image_per_channel" and img.ndim == 3)
-        or (normalization == "min_max_per_channel" and img.ndim > 3)
-        or (normalization == "image_per_channel" and img.ndim > 3)
-    ):
+    # Optimized routing based on benchmarks
+
+    # Route uint8 images
+    if img.dtype == np.uint8:
+        # Use LUT for everything except min_max (where OpenCV is 3x faster)
+        if normalization == "min_max":
+            return normalize_per_image_opencv(img, normalization, spatial_axes)
+        # LUT is fastest for "image", "image_per_channel", and "min_max_per_channel"
        return normalize_per_image_lut(img, normalization, spatial_axes)

-    # For ndim > 3, use numpy implementation as OpenCV doesn't handle batch dimensions well
+    # Route float32 images
+    if img.dtype == np.float32:
+        if normalization == "image":
+            # NumPy is 1.5x faster for "image" normalization
+            return normalize_per_image_numpy(img, normalization, spatial_axes)
+        # OpenCV is fastest or equal for all other normalizations
+        return normalize_per_image_opencv(img, normalization, spatial_axes)
+
+    # Default fallback: OpenCV for single images, NumPy for videos/volumes
     if img.ndim > 3:
         return normalize_per_image_numpy(img, normalization, spatial_axes)
-
     return normalize_per_image_opencv(img, normalization, spatial_axes)
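For reference, the new dtype-based routing in `normalize_per_image` can be exercised with a small sketch (hypothetical shapes; the function lives in `albucore/functions.py` as shown above):

```python
import numpy as np

from albucore.functions import normalize_per_image

u8 = np.random.randint(0, 256, (256, 256, 3), dtype=np.uint8)
f32 = np.random.rand(256, 256, 3).astype(np.float32)

out_u8 = normalize_per_image(u8, "image_per_channel")  # uint8 -> LUT path
out_f32 = normalize_per_image(f32, "image")            # float32 + "image" -> NumPy path
print(out_u8.dtype, out_f32.dtype)                      # float32 float32
```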
pyproject.toml

Lines changed: 1 addition & 1 deletion

@@ -5,7 +5,7 @@ requires = [ "setuptools>=45", "wheel" ]

 [project]
 name = "albucore"
-version = "0.0.30"
+version = "0.0.31"

 description = "High-performance image processing functions for deep learning and computer vision."
 readme = "README.md"
