diff --git a/Cargo.lock b/Cargo.lock index 045c72176fd..fd70408ef1b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -799,7 +799,7 @@ dependencies = [ "bitflags", "cexpr", "clang-sys", - "itertools 0.10.5", + "itertools 0.13.0", "log", "prettyplease", "proc-macro2", @@ -1317,7 +1317,7 @@ checksum = "af491d569909a7e4dee0ad7db7f5341fef5c614d5b8ec8cf765732aba3cff681" dependencies = [ "serde", "termcolor", - "unicode-width 0.1.14", + "unicode-width 0.2.2", ] [[package]] @@ -6137,9 +6137,9 @@ version = "0.7.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "044b1fa4f259f4df9ad5078e587b208f5d288a25407575fcddb9face30c7c692" dependencies = [ - "rand 0.8.6", + "rand 0.9.4", "socket2", - "thiserror 1.0.69", + "thiserror 2.0.18", ] [[package]] @@ -6352,7 +6352,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "343d3bd7056eda839b03204e68deff7d1b13aba7af2b2fd16890697274262ee7" dependencies = [ "heck", - "itertools 0.10.5", + "itertools 0.14.0", "log", "multimap", "petgraph", @@ -6384,7 +6384,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "27c6023962132f4b30eb4c172c91ce92d933da334c59c23cddee82358ddafb0b" dependencies = [ "anyhow", - "itertools 0.10.5", + "itertools 0.14.0", "proc-macro2", "quote", "syn 2.0.117", @@ -8812,6 +8812,27 @@ dependencies = [ "tracing-serde", ] +[[package]] +name = "tracing-test" +version = "0.2.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "19a4c448db514d4f24c5ddb9f73f2ee71bfb24c526cf0c570ba142d1119e0051" +dependencies = [ + "tracing-core", + "tracing-subscriber", + "tracing-test-macro", +] + +[[package]] +name = "tracing-test-macro" +version = "0.2.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ad06847b7afb65c7866a36664b75c40b895e318cea4f71299f013fb22965329d" +dependencies = [ + "quote", + "syn 2.0.117", +] + [[package]] name = "try-lock" version = "0.2.5" @@ -10138,6 +10159,8 @@ dependencies = [ "prost 0.14.3", "rand 0.10.1", "rstest", + "tracing", + "tracing-test", "vortex-array", "vortex-buffer", "vortex-error", diff --git a/vortex-turboquant/Cargo.toml b/vortex-turboquant/Cargo.toml index ab3f63583d3..00b1c24982b 100644 --- a/vortex-turboquant/Cargo.toml +++ b/vortex-turboquant/Cargo.toml @@ -20,6 +20,7 @@ workspace = true half = { workspace = true } num-traits = { workspace = true } prost = { workspace = true } +tracing = { workspace = true } vortex-array = { workspace = true } vortex-buffer = { workspace = true } vortex-error = { workspace = true } @@ -32,6 +33,7 @@ vortex-utils = { workspace = true, features = ["dashmap"] } divan = { workspace = true } rand = { workspace = true } rstest = { workspace = true } +tracing-test = "0.2" vortex-file = { workspace = true } vortex-io = { workspace = true } vortex-layout = { workspace = true } diff --git a/vortex-turboquant/benches/encode_decode.rs b/vortex-turboquant/benches/encode_decode.rs index f88c37c347a..9ef41ae2210 100644 --- a/vortex-turboquant/benches/encode_decode.rs +++ b/vortex-turboquant/benches/encode_decode.rs @@ -4,9 +4,9 @@ //! Benchmarks for `turboquant_encode` and `turboquant_decode` across different validity-mask //! shapes. //! -//! The four mask shapes (`AllTrue`, `AllFalse`, dense `Values`, sparse `Values`) exercise the -//! variant-specialized paths added in the mask refactor in `vector/normalize.rs`, -//! `vector/quantize.rs`, and `scalar_fns/decode.rs`. +//! The four mask shapes (`AllTrue`, `AllFalse`, dense `Values`, sparse `Values`) exercise both +//! the encoder's per-row mask dispatch in `vector/quantize.rs` and the variant-specialized mask +//! arms in `scalar_fns/decode.rs`. #![expect(clippy::unwrap_used)] @@ -118,7 +118,7 @@ fn decode(encoded: ArrayRef, ctx: &mut ExecutionCtx) -> ArrayRef { fn config() -> TurboQuantConfig { // 4 bits, 4 SORF rounds, fixed seed: representative defaults from the test fixtures. - TurboQuantConfig::try_new(4, 0xDEADBEEF, 4).unwrap() + TurboQuantConfig::try_new(4, 0xDEADBEEF, 4, None).unwrap() } #[divan::bench(args = MASK_SHAPES)] diff --git a/vortex-turboquant/src/centroids.rs b/vortex-turboquant/src/centroids.rs index 2be60c0ed4e..c00d920c5ca 100644 --- a/vortex-turboquant/src/centroids.rs +++ b/vortex-turboquant/src/centroids.rs @@ -13,14 +13,17 @@ //! The Max-Lloyd algorithm finds optimal quantization centroids that minimize MSE for this //! distribution. //! -//! Centroids are not stored in TurboQuant arrays. They are deterministically derived from -//! `(padded_dim, bit_width)` and cached process-locally. +//! Centroids and their decision boundaries are not stored in TurboQuant arrays. They are +//! deterministically derived from `(block_size, bit_width)` and cached together process-locally as +//! a [`Codebook`]. Each block of a block-decomposed TurboQuant array uses its own codebook sized to +//! that block's width. //! //! The centroid model follows the random orthogonal transform marginal used by the TurboQuant //! paper. This encoder applies a SORF-style structured transform instead of a dense random Gaussian //! or orthogonal matrix, so paper-level error bounds should not be treated as verified for this //! implementation without separate empirical validation. +use std::sync::Arc; use std::sync::LazyLock; use vortex_buffer::Buffer; @@ -29,7 +32,7 @@ use vortex_error::vortex_ensure; use vortex_utils::aliases::dash_map::DashMap; use crate::config::MAX_BIT_WIDTH; -use crate::config::MIN_DIMENSION; +use crate::config::MIN_BLOCK_SIZE; // NB: All of these constants were chosen arbitrarily. @@ -44,38 +47,60 @@ const CONVERGENCE_EPSILON: f64 = 1e-12; /// The trapezoidal rule evaluates the integrand at `INTEGRATION_TRAPEZOIDS + 1` points. const INTEGRATION_TRAPEZOIDS: usize = 1000; -/// Global centroid cache keyed by (dimension, bit_width). -static CENTROID_CACHE: LazyLock>> = LazyLock::new(DashMap::default); +/// Global codebook cache keyed by `(block_size, bit_width)`. +static CODEBOOK_CACHE: LazyLock>> = + LazyLock::new(DashMap::default); -/// Get or compute cached centroids for the given dimension and bit width. +/// A cached scalar-quantization codebook for one `(block_size, bit_width)`. /// -/// Returns `2^bit_width` centroids sorted in ascending order, representing optimal scalar -/// quantization levels for the coordinate distribution after a random orthogonal transform in -/// `dimension`-dimensional space. -pub(crate) fn compute_or_get_centroids(dimension: u32, bit_width: u8) -> VortexResult> { +/// Centroids and boundaries are stored together because the boundaries are a pure function of the +/// centroids and share the same key. Decode reads one side while encode reads the other, and a +/// cache hit hands back a single [`Arc`] regardless of which side the caller needs. +pub(crate) struct Codebook { + /// `2^bit_width` centroids sorted in ascending order. Decode maps each code to its centroid. + pub(crate) centroids: Buffer, + /// Decision boundaries (`centroids.len() - 1` midpoints) consumed by [`find_nearest_centroid`] + /// when encode maps each coordinate to its nearest centroid. + pub(crate) boundaries: Buffer, +} + +/// Get or compute the cached [`Codebook`] for the given block size and bit width. +/// +/// The centroids are `2^bit_width` MSE-optimal quantization levels (sorted ascending) for the +/// coordinate distribution after a random orthogonal transform in `block_size`-dimensional space; +/// the boundaries are their midpoints. Both are cached behind one [`Arc`] so a cache hit is a single +/// reference-count bump. +pub(crate) fn compute_or_get_codebook( + block_size: u32, + bit_width: u8, +) -> VortexResult> { vortex_ensure!( (1..=MAX_BIT_WIDTH).contains(&bit_width), "TurboQuant bit_width must be 1-{}, got {bit_width}", MAX_BIT_WIDTH ); vortex_ensure!( - dimension >= MIN_DIMENSION, - "TurboQuant dimension must be >= {}, got {dimension}", - MIN_DIMENSION + block_size >= MIN_BLOCK_SIZE, + "TurboQuant block size must be >= {MIN_BLOCK_SIZE}, got {block_size}" ); - if let Some(centroids) = CENTROID_CACHE.get(&(dimension, bit_width)) { - return Ok(centroids.clone()); + if let Some(codebook) = CODEBOOK_CACHE.get(&(block_size, bit_width)) { + return Ok(Arc::clone(codebook.value())); } - let centroids = max_lloyd_centroids(dimension, bit_width); - CENTROID_CACHE.insert((dimension, bit_width), centroids.clone()); + let centroids = max_lloyd_centroids(block_size, bit_width); + let boundaries = compute_centroid_boundaries(¢roids); + let codebook = Arc::new(Codebook { + centroids, + boundaries, + }); + CODEBOOK_CACHE.insert((block_size, bit_width), Arc::clone(&codebook)); - Ok(centroids) + Ok(codebook) } // TODO(connor): It would potentially be more performant if this was modelled as const generic -// parameters to functions. +// parameters to functions. Probably not worth the complexity. /// Half-integer exponent: represents `int_part + (if has_half { 0.5 } else { 0.0 })`. /// /// The marginal distribution exponent `(d-3)/2` is always an integer (when `d` is odd) or a @@ -83,19 +108,18 @@ pub(crate) fn compute_or_get_centroids(dimension: u32, bit_width: u8) -> VortexR /// /// This type makes that invariant explicit and avoids floating-point comparison in the hot path. #[derive(Clone, Copy, Debug)] -struct HalfIntExponent { - int_part: i32, +struct HalfUIntExponent { + int_part: u32, has_half: bool, } -impl HalfIntExponent { - /// Compute `(numerator) / 2` as a half-integer exponent. +impl HalfUIntExponent { + /// Compute `numerator / 2` as a half-integer exponent. /// - /// `numerator` is `d - 3` where `d` is the dimension (>= 2), so it can be negative. - fn from_numerator(numerator: i32) -> Self { - // Use Euclidean division to get floor division toward negative infinity. - let int_part = numerator.div_euclid(2); - let has_half = numerator.rem_euclid(2) != 0; + /// `numerator` is `d - 3` where `d` is the block size. + fn from_numerator(numerator: u32) -> Self { + let int_part = numerator / 2; + let has_half = !numerator.is_multiple_of(2); Self { int_part, has_half } } } @@ -111,21 +135,24 @@ impl HalfIntExponent { /// /// Centroids are seeded uniformly on `[±sqrt(bit_width) * sigma]` (where `sigma` is the standard /// deviation of the normal distribution that hypershere dimension values take, and specifically -/// `sigma = 1/sqrt(dimension)`) rather than across the full `[-1, 1]`, which strands most of the +/// `sigma = 1/sqrt(block_size)`) rather than across the full `[-1, 1]`, which strands most of the /// centroids in the near-zero-mass tails. /// /// Note that the `sqrt(bit_width)` is mostly empirically derived, we do not have a theoretical /// basis for choosing this other than the fact that it seems to produce good results. -fn max_lloyd_centroids(dimension: u32, bit_width: u8) -> Buffer { +fn max_lloyd_centroids(block_size: u32, bit_width: u8) -> Buffer { debug_assert!((1..=MAX_BIT_WIDTH).contains(&bit_width)); + // Callers validate `block_size >= MIN_BLOCK_SIZE`; asserted here so the `block_size - 3` + // below cannot underflow if a future caller forgets. + debug_assert!(block_size >= MIN_BLOCK_SIZE); let num_centroids = 1usize << bit_width; // For the marginal distribution on [-1, 1], we use the exponent (d-3)/2. - let exponent = HalfIntExponent::from_numerator(dimension as i32 - 3); + let exponent = HalfUIntExponent::from_numerator(block_size - 3); // The coordinate marginal concentrates around 0 with this standard deviation. - let sigma = 1.0 / f64::from(dimension).sqrt(); - let init_half = (f64::from(bit_width).sqrt() * sigma).min(1.0); + let sigma = 1.0 / (block_size as f64).sqrt(); + let init_half = ((bit_width as f64).sqrt() * sigma).min(1.0); // Initialize centroids uniformly on [-init_half, init_half], where the mass lives, so no cell // starts in a zero-mass region and freezes. @@ -169,7 +196,9 @@ fn max_lloyd_centroids(dimension: u32, bit_width: u8) -> Buffer { /// Returns `E[X | lo <= X <= hi]` where X has PDF proportional to `(1 - x^2)^exponent` on [-1, 1]. /// /// Since there is no closed form for the integrals, we compute this numerically. -fn mean_between_centroids(lo: f64, hi: f64, exponent: HalfIntExponent) -> f64 { +fn mean_between_centroids(lo: f64, hi: f64, exponent: HalfUIntExponent) -> f64 { + // If hi and lo are **very** close to each other, don't bother finding the "correct" conditional + // mean, as the midpoint is probably sufficient. if (hi - lo).abs() < 1e-15 { return (lo + hi) / 2.0; } @@ -205,15 +234,21 @@ fn mean_between_centroids(lo: f64, hi: f64, exponent: HalfIntExponent) -> f64 { /// Uses `powi` + `sqrt` instead of `powf` for the half-integer exponents that arise from `(d-3)/2`. /// This is significantly faster than the general `powf` which goes through /// `exp(exponent * ln(base))`. -fn pdf_unnormalized(x_val: f64, exponent: HalfIntExponent) -> f64 { +fn pdf_unnormalized(x_val: f64, exponent: HalfUIntExponent) -> f64 { let base = (1.0 - x_val * x_val).max(0.0); + #[expect( + clippy::cast_possible_wrap, + reason = "exponent is half a block size and fits i32" + )] + let int_exp = exponent.int_part as i32; + if exponent.has_half { // Half-integer exponent: base^(int_part) * sqrt(base). - base.powi(exponent.int_part) * base.sqrt() + base.powi(int_exp) * base.sqrt() } else { // Integer exponent: use powi directly. - base.powi(exponent.int_part) + base.powi(int_exp) } } @@ -222,7 +257,7 @@ fn pdf_unnormalized(x_val: f64, exponent: HalfIntExponent) -> f64 { /// For `k` centroids, returns `k-1` boundaries. A value below `boundaries[0]` maps to centroid 0, a /// value in `[boundaries[i-1], boundaries[i])` maps to centroid `i`, and a /// value `>= boundaries[k-2]` maps to centroid `k-1`. -pub(crate) fn compute_centroid_boundaries(centroids: &[f32]) -> Vec { +fn compute_centroid_boundaries(centroids: &[f32]) -> Buffer { centroids.windows(2).map(|w| (w[0] + w[1]) * 0.5).collect() } @@ -268,8 +303,8 @@ mod tests { #[case] bits: u8, #[case] expected: usize, ) -> VortexResult<()> { - let centroids = compute_or_get_centroids(dim, bits)?; - assert_eq!(centroids.len(), expected); + let codebook = compute_or_get_codebook(dim, bits)?; + assert_eq!(codebook.centroids.len(), expected); Ok(()) } @@ -280,12 +315,12 @@ mod tests { #[case(128, 4)] #[case(768, 2)] fn centroids_are_sorted(#[case] dim: u32, #[case] bits: u8) -> VortexResult<()> { - let centroids = compute_or_get_centroids(dim, bits)?; - for window in centroids.windows(2) { + let codebook = compute_or_get_codebook(dim, bits)?; + for window in codebook.centroids.windows(2) { assert!( window[0] < window[1], "centroids not sorted: {:?}", - centroids + codebook.centroids ); } Ok(()) @@ -297,7 +332,8 @@ mod tests { #[case(256, 2)] #[case(768, 2)] fn centroids_are_symmetric(#[case] dim: u32, #[case] bits: u8) -> VortexResult<()> { - let centroids = compute_or_get_centroids(dim, bits)?; + let codebook = compute_or_get_codebook(dim, bits)?; + let centroids = &codebook.centroids; let count = centroids.len(); for idx in 0..count / 2 { let diff = (centroids[idx] + centroids[count - 1 - idx]).abs(); @@ -316,8 +352,8 @@ mod tests { #[case(128, 1)] #[case(128, 4)] fn centroids_within_bounds(#[case] dim: u32, #[case] bits: u8) -> VortexResult<()> { - let centroids = compute_or_get_centroids(dim, bits)?; - for &val in centroids.iter() { + let codebook = compute_or_get_codebook(dim, bits)?; + for &val in codebook.centroids.iter() { assert!( (-1.0..=1.0).contains(&val), "centroid out of [-1, 1]: {val}", @@ -327,35 +363,46 @@ mod tests { } #[test] - fn centroids_cached() -> VortexResult<()> { - let c1 = compute_or_get_centroids(128, 2)?; - let c2 = compute_or_get_centroids(128, 2)?; - assert_eq!(c1, c2); + fn codebook_cached() -> VortexResult<()> { + let cb1 = compute_or_get_codebook(128, 2)?; + let cb2 = compute_or_get_codebook(128, 2)?; + // The cache returns the same reference-counted codebook instance. + assert!(Arc::ptr_eq(&cb1, &cb2)); + // There is exactly one fewer boundary than centroids: a midpoint between adjacent levels. + assert_eq!(cb1.boundaries.len(), cb1.centroids.len() - 1); Ok(()) } #[test] fn find_nearest_basic() -> VortexResult<()> { - let centroids = compute_or_get_centroids(128, 2)?; - let boundaries = compute_centroid_boundaries(¢roids); - assert_eq!(find_nearest_centroid(-1.0, &boundaries), 0); + let codebook = compute_or_get_codebook(128, 2)?; + let centroids = &codebook.centroids; + let boundaries = &codebook.boundaries; + assert_eq!(find_nearest_centroid(-1.0, boundaries), 0); #[expect(clippy::cast_possible_truncation)] let last_idx = (centroids.len() - 1) as u8; - assert_eq!(find_nearest_centroid(1.0, &boundaries), last_idx); + assert_eq!(find_nearest_centroid(1.0, boundaries), last_idx); for (idx, &cv) in centroids.iter().enumerate() { #[expect(clippy::cast_possible_truncation)] let expected = idx as u8; - assert_eq!(find_nearest_centroid(cv, &boundaries), expected); + assert_eq!(find_nearest_centroid(cv, boundaries), expected); } Ok(()) } #[test] fn rejects_invalid_params() { - assert!(compute_or_get_centroids(128, 0).is_err()); - assert!(compute_or_get_centroids(128, 9).is_err()); - assert!(compute_or_get_centroids(1, 2).is_err()); - assert!(compute_or_get_centroids(127, 2).is_err()); + assert!(compute_or_get_codebook(128, 0).is_err()); + assert!(compute_or_get_codebook(128, 9).is_err()); + assert!(compute_or_get_codebook(1, 2).is_err()); + assert!(compute_or_get_codebook(63, 2).is_err()); + } + + #[test] + fn codebook_available_for_min_dimension() -> VortexResult<()> { + let codebook = compute_or_get_codebook(64, 2)?; + assert_eq!(codebook.centroids.len(), 4); + Ok(()) } } diff --git a/vortex-turboquant/src/config.rs b/vortex-turboquant/src/config.rs index 57cd8b1e94b..fa40aba7e63 100644 --- a/vortex-turboquant/src/config.rs +++ b/vortex-turboquant/src/config.rs @@ -8,9 +8,13 @@ use vortex_error::vortex_ensure; /// Minimum vector dimension for TurboQuant encoding. /// -/// Note that this is not a theoretical minimum, it is mostly a practical one to limit the total -/// amount of distortion. -pub(crate) const MIN_DIMENSION: u32 = 128; +/// Not a theoretical minimum, just a practical floor to limit total distortion. The minimum +/// per-block width [`MIN_BLOCK_SIZE`] is defined to equal this, so the smallest valid input is a +/// single minimum-width block; the two floors are intentionally tied to the same value. +pub(crate) const MIN_DIMENSION: u32 = 64; + +/// Minimum power-of-two block size. +pub(crate) const MIN_BLOCK_SIZE: u32 = MIN_DIMENSION; /// Maximum supported number of bits per quantized coordinate. pub(crate) const MAX_BIT_WIDTH: u8 = 8; @@ -21,15 +25,39 @@ pub struct TurboQuantConfig { bit_width: u8, seed: u64, num_rounds: u8, + block_sizes: Option>, +} + +impl Default for TurboQuantConfig { + /// Defaults to 8 bits per coordinate, seed 42, 3 SORF rounds, and the encode-time default + /// block decomposition. + fn default() -> Self { + Self { + bit_width: MAX_BIT_WIDTH, + seed: 42, + num_rounds: 3, + block_sizes: None, + } + } } impl TurboQuantConfig { /// Build a TurboQuant configuration. /// + /// When `block_sizes` is `None`, the encoder defaults to a single power-of-two block covering + /// the full dimension. When `Some`, the blocks are validated (non-empty, power-of-two, greater + /// than `MIN_BLOCK_SIZE`, and sum covers all dimensions). + /// /// # Errors /// - /// Returns an error if `bit_width` is outside `1..=8` or `num_rounds` is zero. - pub fn try_new(bit_width: u8, seed: u64, num_rounds: u8) -> VortexResult { + /// Returns an error if `bit_width` is outside `1..=8`, `num_rounds` is zero, or the supplied + /// `block_sizes` violate any of the dimension-independent rules. + pub fn try_new( + bit_width: u8, + seed: u64, + num_rounds: u8, + block_sizes: Option>, + ) -> VortexResult { vortex_ensure!( (1..=MAX_BIT_WIDTH).contains(&bit_width), "TurboQuant bit_width must be 1-{MAX_BIT_WIDTH}, got {bit_width}", @@ -39,10 +67,15 @@ impl TurboQuantConfig { "TurboQuant num_rounds must be > 0, got {num_rounds}" ); + if let Some(block_sizes) = block_sizes.as_deref() { + validate_block_shape(block_sizes)?; + } + Ok(Self { bit_width, seed, num_rounds, + block_sizes, }) } @@ -60,16 +93,11 @@ impl TurboQuantConfig { pub fn num_rounds(&self) -> u8 { self.num_rounds } -} -impl Default for TurboQuantConfig { - /// Defaults to 8 bits per coordinate, seed 42, and 3 SORF rounds. - fn default() -> Self { - Self { - bit_width: MAX_BIT_WIDTH, - seed: 42, - num_rounds: 3, - } + /// User-supplied power-of-two block decomposition, if any. `None` defers block resolution to + /// the encoder, which then picks a single block of the dimension rounded up to a power of two. + pub fn block_sizes(&self) -> Option<&[u32]> { + self.block_sizes.as_deref() } } @@ -77,8 +105,54 @@ impl fmt::Display for TurboQuantConfig { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { write!( f, - "bit_width: {}, seed: {}, num_rounds: {}", + "bit_width: {}, seed: {}, num_rounds: {}, block_sizes: ", self.bit_width, self.seed, self.num_rounds - ) + )?; + + match self.block_sizes.as_deref() { + None => write!(f, "None"), + Some(block_sizes) => { + write!(f, "Some([")?; + for (index, block) in block_sizes.iter().enumerate() { + if index > 0 { + write!(f, ", ")?; + } + write!(f, "{block}")?; + } + write!(f, "])") + } + } } } + +/// Validate the dimension-independent block-shape rules: non-empty, power-of-two, each block at +/// least `MIN_BLOCK_SIZE`. +pub(crate) fn validate_block_shape(block_sizes: &[u32]) -> VortexResult<()> { + vortex_ensure!( + !block_sizes.is_empty(), + "TurboQuant block_sizes must be non-empty" + ); + + for (index, &block) in block_sizes.iter().enumerate() { + vortex_ensure!( + block >= MIN_BLOCK_SIZE, + "TurboQuant block {index} must be >= {MIN_BLOCK_SIZE}, got {block}" + ); + vortex_ensure!( + block.is_power_of_two(), + "TurboQuant block {index} must be a power of two, got {block}" + ); + } + Ok(()) +} + +/// Validate the dimension-dependent rule that the resolved blocks cover every dimension. The +/// encoder (`resolve_block_sizes`) and metadata validation (`validate_tq_metadata`) both call this. +pub(crate) fn validate_block_sum(block_sizes: &[u32], dimensions: u32) -> VortexResult<()> { + let sum: u64 = block_sizes.iter().map(|&block| block as u64).sum(); + vortex_ensure!( + sum >= dimensions as u64, + "TurboQuant block_sizes sum {sum} must be >= dimensions {dimensions}" + ); + Ok(()) +} diff --git a/vortex-turboquant/src/lib.rs b/vortex-turboquant/src/lib.rs index 7aeb60368dd..4aab214a34f 100644 --- a/vortex-turboquant/src/lib.rs +++ b/vortex-turboquant/src/lib.rs @@ -3,55 +3,71 @@ //! TurboQuant vector quantization extension type for Vortex. //! -//! Implements a Stage 1 TurboQuant encoding ([arXiv:2504.19874], [RFC 0033]) for lossy compression -//! of high-dimensional vector data. The extension operates on +//! Implements the TurboQuant encoding ([arXiv:2504.19874]) for lossy compression (quantization) of +//! high-dimensional vector data as a lossy extension type. TurboQuant converts to and from //! [`Vector`](vortex_tensor::vector::Vector) extension arrays, encoding their `FixedSizeList` -//! storage into quantized codes after a structured orthogonal surrogate transform. +//! storage into quantized codes after a random orthogonal transform (via a Structured Orthogonal +//! Random Features matrix, see the `sorf` module for more information). //! //! [arXiv:2504.19874]: https://arxiv.org/abs/2504.19874 -//! [RFC 0033]: https://vortex-data.github.io/rfcs/rfc/0033.html //! //! # Overview //! -//! TurboQuant minimizes mean-squared reconstruction error (1-8 bits per coordinate) -//! using MSE-optimal scalar quantization on coordinates of a transformed unit vector. +//! TurboQuant minimizes mean-squared reconstruction error (1-8 bits per coordinate) using +//! MSE-optimal scalar quantization on coordinates of a transformed unit vector. //! -//! The [`TQEncode`] scalar function first computes and stores the original L2 norm for each vector -//! row, then normalizes each valid nonzero row internally before SORF transform and scalar -//! quantization. The [`TQDecode`] scalar function dequantizes through deterministic centroids, -//! applies the inverse SORF transform, truncates back to the original dimension, and re-applies the -//! stored norm. +//! Each input vector of `dimensions` coordinates is split into a fixed sequence of contiguous +//! power-of-two-sized slices called **blocks**, whose widths are chosen by the user through +//! [`TurboQuantConfig`] and stored verbatim in [`TurboQuantMetadata::block_sizes`]. The TurboQuant +//! algorithm runs on every block independently: each block has its own stored L2 norm, its own SORF +//! matrix seeded by a distinct derived seed, and its own scalar-quantization centroid table sized +//! to that block's width. Block `i` covers input coordinates +//! `[offset_i .. offset_i + block_sizes[i])` with `offset_i = sum(block_sizes[..i])`; a block +//! extending past `dimensions` is zero-padded on encode, and the reconstructed coordinates past +//! `dimensions` are dropped on decode. //! -//! The encoded storage is a row-aligned extension tree: +//! The encoded storage is a row-aligned extension tree of one outer struct holding one inner struct +//! per block: //! //! ```text //! Extension( //! Struct { -//! norms: Primitive, -//! codes: FixedSizeList, padded_dim, vector_validity>, +//! block_0: Struct { +//! norms: Primitive, +//! codes: FixedSizeList, block_sizes[0], vector_validity>, +//! }, +//! ... +//! block_{N-1}: Struct { norms: ..., codes: FixedSizeList }, //! } //! ) //! ``` //! -//! Stored norms are authoritative for future TurboQuant-aware scalar functions. Decoded quantized -//! directions are not guaranteed to have unit norm after scalar quantization and inverse transform. +//! IMPORTANT NOTE: Stored norms are authoritative for future TurboQuant-aware scalar functions. +//! Decoded quantized directions are not guaranteed to have unit norm after scalar quantization and +//! inverse transform. +//! +//! # Limitations +//! +//! The current encoding is intentionally MSE-only. It does not yet implement the paper's QJL +//! residual correction for unbiased inner-product estimation. //! //! # Source map //! //! Implementation details are documented next to the code that owns them: //! -//! - `vector/storage.rs`: physical storage shape, full-length child arrays, and field-level -//! validity for null vectors. -//! - `vector/normalize.rs`: TurboQuant-local normalization and how it differs from the tensor -//! crate's null-row zeroing helper. -//! - `vector/quantize.rs`: SORF transform, centroid lookup, and why invalid rows are skipped rather -//! than quantized. +//! - `config.rs`: the operator-facing [`TurboQuantConfig`] and its bit-width and block-list +//! validation. +//! - `vtable.rs`: the [`TurboQuant`] extension dtype, [`TurboQuantMetadata`] (including the +//! `block_sizes` list), its proto wire format, and storage-dtype validation. +//! - `scalar_fns/`: the [`TQEncode`] and [`TQDecode`] scalar functions and the metadata wire +//! format glue. +//! - `vector/storage.rs`: the row-aligned per-block storage layout and the outer-covers-inner +//! validity coverage rules. +//! - `vector/quantize.rs`: the block-aware encode pipeline (per-block norm, per-block SORF, +//! scalar quantization). //! - `centroids.rs`: deterministic Max-Lloyd centroid computation and process-local caching. -//! - `sorf/`: the Walsh-Hadamard-based structured transform and the stable SplitMix64 sign stream. -//! -//! The current encoding is intentionally MSE-only. It does not yet implement the paper's QJL -//! residual correction for unbiased inner-product estimation, and it still uses internal -//! power-of-2 padding rather than the block decomposition proposed in RFC 0033. +//! - `sorf/`: the Walsh-Hadamard-based structured transform, the stable SplitMix64 sign stream, +//! and the per-block seed derivation that gives each block its own SORF instance. mod centroids; mod config; @@ -66,11 +82,15 @@ pub use scalar_fns::TQEncode; pub use vtable::TurboQuant; pub use vtable::TurboQuantMetadata; -// TODO(connor): We need to somehow make sure that callers call `vortex_tensor::initialize` first. +// TODO(connor): enforce the `vortex_tensor::initialize` ordering at registration time. /// Register the TurboQuant extension type with a Vortex session. +/// +/// Callers must register `vortex_tensor` on the session first so the `Vector` extension type that +/// TurboQuant converts to and from is available. pub fn initialize(session: &vortex_session::VortexSession) { use vortex_array::dtype::session::DTypeSessionExt; use vortex_array::scalar_fn::session::ScalarFnSessionExt; + session.dtypes().register(TurboQuant); session.scalar_fns().register(TQEncode); diff --git a/vortex-turboquant/src/scalar_fns/decode.rs b/vortex-turboquant/src/scalar_fns/decode.rs index 6791a1aef61..ff76110cd2b 100644 --- a/vortex-turboquant/src/scalar_fns/decode.rs +++ b/vortex-turboquant/src/scalar_fns/decode.rs @@ -2,13 +2,27 @@ // SPDX-FileCopyrightText: Copyright the Vortex contributors //! TurboQuant decode scalar function. +//! +//! Reverses the per-row, per-block encode pipeline in [`crate::vector::quantize`]. For each +//! block `i` of each valid input row, the decoder: +//! +//! 1. Reads that block's per-row codes and gathers the matching centroid values from a +//! `2^bit_width`-entry centroid table built for width `block_sizes[i]`. +//! 2. Applies the inverse SORF of width `block_sizes[i]` seeded with the same +//! `derive_block_seed(metadata.seed, i)` the encoder used. +//! 3. Multiplies the rotated coordinates by the per-row block norm stored in that block's +//! `norms` column. +//! 4. Writes the result into a row-aligned scratch buffer of width `sum(block_sizes)` at offsets +//! `[offset_i .. offset_i + block_sizes[i])`, the same offsets the encoder sliced from. +//! +//! Once every block is reconstructed for a row, the first `dimensions` coordinates of the +//! scratch buffer are copied into the output `Vector`, dropping any overspilling coordinates +//! the encoder zero-padded. use std::fmt; use std::fmt::Formatter; use std::sync::Arc; -use num_traits::Float; -use num_traits::FromPrimitive; use vortex_array::ArrayRef; use vortex_array::ExecutionCtx; use vortex_array::IntoArray; @@ -16,7 +30,6 @@ use vortex_array::arrays::FixedSizeListArray; use vortex_array::arrays::PrimitiveArray; use vortex_array::arrays::ScalarFnArray; use vortex_array::dtype::DType; -use vortex_array::dtype::NativePType; use vortex_array::dtype::Nullability; use vortex_array::dtype::extension::ExtDType; use vortex_array::expr::Expression; @@ -29,19 +42,18 @@ use vortex_array::scalar_fn::ScalarFnId; use vortex_array::scalar_fn::ScalarFnVTable; use vortex_array::scalar_fn::TypedScalarFnInstance; use vortex_array::validity::Validity; -use vortex_buffer::BufferMut; -use vortex_error::VortexExpect; use vortex_error::VortexResult; use vortex_error::vortex_ensure; use vortex_error::vortex_err; -use vortex_mask::Mask; use vortex_session::VortexSession; use vortex_tensor::vector::Vector; -use crate::centroids::compute_or_get_centroids; -use crate::sorf::SorfMatrix; +use crate::centroids::compute_or_get_codebook; +use crate::sorf::splitmix64::derive_block_seed; +use crate::sorf::transform::SorfMatrix; +use crate::vector::dequantize::DecodeInputs; +use crate::vector::dequantize::decode_typed; use crate::vector::storage::parse_storage; -use crate::vector::tq_padded_dim; use crate::vtable::TurboQuantMetadata; use crate::vtable::tq_metadata; @@ -153,31 +165,58 @@ impl ScalarFnVTable for TQDecode { /// Decode a `TurboQuant` extension array back into a `Vector` extension array. /// -/// The decoded directions are inverse-transformed, truncated to the original dimension, and -/// multiplied by the stored row norms. The conversion is lossy and does not roundtrip with -/// [`TQEncode`](crate::TQEncode). +/// Decodes each block by looking up centroid values from per-block codes, applying the inverse +/// SORF transform, and scaling by the stored per-row norm. +/// +/// Results are assembled into a scratch buffer of width `sum(block_sizes)`, then truncated to the +/// first `dimensions` coordinates to produce the output `Vector`. pub(crate) fn decode_vector(input: ArrayRef, ctx: &mut ExecutionCtx) -> VortexResult { let parsed = parse_storage(input, ctx)?; - let metadata = parsed.metadata; if parsed.len == 0 { - return build_empty_vector(metadata, parsed.vector_validity); + return build_empty_vector(parsed.metadata, parsed.vector_validity); } - let padded_dim = tq_padded_dim(metadata.dimensions)?; - let transform = SorfMatrix::try_new(padded_dim, metadata.num_rounds as usize, metadata.seed)?; - let padded_dim = u32::try_from(padded_dim) - .map_err(|_| vortex_err!("TurboQuant padded dimension does not fit u32"))?; - - let centroids = compute_or_get_centroids(padded_dim, metadata.bit_width)?; + let metadata = parsed.metadata; + let block_sizes: Vec = metadata + .block_sizes + .iter() + .map(|&b| { + usize::try_from(b).map_err(|_| vortex_err!("TurboQuant block {b} does not fit usize")) + }) + .collect::>>()?; + let total_width: usize = block_sizes.iter().sum(); + + let mut transforms = Vec::with_capacity(block_sizes.len()); + let mut centroids = Vec::with_capacity(block_sizes.len()); + + for (index, (&block, &block_u32)) in block_sizes + .iter() + .zip(metadata.block_sizes.iter()) + .enumerate() + { + let seed_i = derive_block_seed(metadata.seed, index); + + transforms.push(SorfMatrix::try_new( + block, + metadata.num_rounds as usize, + seed_i, + )?); + centroids.push( + compute_or_get_codebook(block_u32, metadata.bit_width)? + .centroids + .clone(), + ); + } match_each_float_ptype!(metadata.element_ptype, |T| { decode_typed::( DecodeInputs { metadata: &metadata, - sorf_matrix: &transform, - centroids: ¢roids, - norms: &parsed.norms, - codes: &parsed.codes, + block_sizes: &block_sizes, + total_width, + sorf_matrices: &transforms, + centroid_tables: ¢roids, + block_storages: &parsed.blocks, }, parsed.vector_validity, parsed.len, @@ -202,116 +241,3 @@ fn build_empty_vector( Vector::try_new_vector_array(fsl.into_array()) }) } - -/// Borrowed bundle of the per-array decode inputs passed to the typed inner loop. -/// -/// Packaged as a struct rather than positional arguments because `decode_typed` runs through -/// [`vortex_array::match_each_float_ptype!`] which expands once per supported element ptype. -/// Each expansion takes the same set of inputs, and the struct keeps the call site short. -struct DecodeInputs<'a> { - /// TurboQuant metadata recovered from the input extension dtype. - metadata: &'a TurboQuantMetadata, - /// SORF transform reconstructed from `metadata.seed` and `metadata.num_rounds`. - sorf_matrix: &'a SorfMatrix, - /// Centroid codebook for `(padded_dim, bit_width)`, in f32. - centroids: &'a [f32], - /// Per-row stored L2 norm of the original input vector, in the element ptype. - norms: &'a PrimitiveArray, - /// Flat per-row centroid indices, `num_vectors * padded_dim` bytes. - codes: &'a PrimitiveArray, -} - -fn decode_typed( - decode: DecodeInputs<'_>, - vector_validity: Validity, - num_vectors: usize, - ctx: &mut ExecutionCtx, -) -> VortexResult -where - T: NativePType + Float + FromPrimitive, -{ - let metadata = decode.metadata; - let dimensions = usize::try_from(metadata.dimensions) - .vortex_expect("dimensions stays representable as usize"); - let padded_dim = decode.sorf_matrix.padded_dim(); - let centroids = decode.centroids; - let norms = decode.norms.as_slice::(); - let codes = decode.codes.as_slice::(); - let mask = vector_validity.execute_mask(num_vectors, ctx)?; - - let output_len = num_vectors - .checked_mul(dimensions) - .ok_or_else(|| vortex_err!("TurboQuant decoded vector length overflow"))?; - let mut output = BufferMut::::with_capacity(output_len); - - let mut decoded = vec![0.0f32; padded_dim]; - let mut inverse = vec![0.0f32; padded_dim]; - - let mut decode_row = |output: &mut BufferMut, i: usize| { - let code_row = &codes[i * padded_dim..][..padded_dim]; - - for (dst, &code) in decoded.iter_mut().zip(code_row.iter()) { - *dst = *centroids - .get(usize::from(code)) - .vortex_expect("TurboQuant code exceeds centroid count"); - } - - decode.sorf_matrix.inverse_transform(&decoded, &mut inverse); - - let norm = norms[i]; - for &value in inverse.iter().take(dimensions) { - // `T::from_f32` is infallible for the supported float ptypes (`f16`, `f32`, - // `f64`): values outside `f16` range saturate to `±inf` rather than returning - // `None`. - let value = T::from_f32(value) - .vortex_expect("from_f32 is infallible for supported float types"); - - // SAFETY: total pushes across all match arms equal `output_len`. - unsafe { output.push_unchecked(value * norm) }; - } - }; - - match &mask { - Mask::AllFalse(_) => { - // SAFETY: `output` was allocated with capacity `output_len`, and this push writes - // exactly `output_len` zero placeholders. - unsafe { output.push_n_unchecked(T::zero(), output_len) }; - } - Mask::AllTrue(_) => { - for i in 0..num_vectors { - decode_row(&mut output, i); - } - } - Mask::Values(values_mask) => { - let mut cursor = 0; - - for &(start, end) in values_mask.slices() { - if start > cursor { - // SAFETY: total pushes across all arms equal `output_len`. - unsafe { output.push_n_unchecked(T::zero(), (start - cursor) * dimensions) }; - } - - for i in start..end { - decode_row(&mut output, i); - } - - cursor = end; - } - - if cursor < num_vectors { - // SAFETY: total pushes across all arms equal `output_len`. - unsafe { output.push_n_unchecked(T::zero(), (num_vectors - cursor) * dimensions) }; - } - } - } - - let elements = PrimitiveArray::new::(output.freeze(), Validity::NonNullable); - let fsl = FixedSizeListArray::try_new( - elements.into_array(), - metadata.dimensions, - vector_validity, - num_vectors, - )?; - - Vector::try_new_vector_array(fsl.into_array()) -} diff --git a/vortex-turboquant/src/scalar_fns/encode.rs b/vortex-turboquant/src/scalar_fns/encode.rs index 29ce7cc580a..f81c51b375f 100644 --- a/vortex-turboquant/src/scalar_fns/encode.rs +++ b/vortex-turboquant/src/scalar_fns/encode.rs @@ -9,12 +9,8 @@ use std::fmt::Formatter; use vortex_array::ArrayRef; use vortex_array::ExecutionCtx; use vortex_array::IntoArray; -use vortex_array::arrays::Extension; use vortex_array::arrays::ExtensionArray; -use vortex_array::arrays::FixedSizeListArray; use vortex_array::arrays::ScalarFnArray; -use vortex_array::arrays::extension::ExtensionArrayExt; -use vortex_array::arrays::scalar_fn::ScalarFnArrayExt; use vortex_array::dtype::DType; use vortex_array::dtype::extension::ExtDType; use vortex_array::expr::Expression; @@ -34,12 +30,11 @@ use super::metadata::deserialize_config; use super::metadata::serialize_config; use crate::TurboQuantConfig; use crate::config::MIN_DIMENSION; -use crate::vector::normalize::tq_normalize_as_l2_denorm; -use crate::vector::quantize::empty_quantization; -use crate::vector::quantize::turboquant_quantize_core; -use crate::vector::storage::build_codes_child; +use crate::config::validate_block_shape; +use crate::config::validate_block_sum; +use crate::vector::quantize::prepare_block_state; +use crate::vector::quantize::turboquant_encode_blocks; use crate::vector::storage::build_storage; -use crate::vector::tq_padded_dim; use crate::vtable::TurboQuant; use crate::vtable::TurboQuantMetadata; use crate::vtable::tq_storage_dtype; @@ -49,7 +44,7 @@ use crate::vtable::tq_storage_dtype; /// `TQEncode` itself is a `ScalarFnVTable` and so its options round-trip through expression /// serialization. /// -/// Unlike `TQDecode`, it deliberately does **not** implement `ScalarFnArrayVTable` since the +/// Unlike `TQDecode`, it deliberately does NOT implement `ScalarFnArrayVTable` since the /// persisted artifact would be the original vector array, not the TurboQuant-quantized array. #[derive(Clone)] pub struct TQEncode; @@ -125,7 +120,7 @@ impl ScalarFnVTable for TQEncode { dimensions >= MIN_DIMENSION, "TurboQuant requires dimension >= {MIN_DIMENSION}, got {dimensions}", ); - tq_padded_dim(dimensions)?; + let block_sizes = resolve_block_sizes(options.block_sizes(), dimensions, false)?; let metadata = TurboQuantMetadata { element_ptype: vector_metadata.element_ptype(), @@ -133,6 +128,7 @@ impl ScalarFnVTable for TQEncode { bit_width: options.bit_width(), seed: options.seed(), num_rounds: options.num_rounds(), + block_sizes, }; let storage_dtype = tq_storage_dtype(&metadata, input_dtype.nullability())?; let ext_dtype = ExtDType::::try_new(metadata, storage_dtype)?.erased(); @@ -166,11 +162,7 @@ impl ScalarFnVTable for TQEncode { } } -/// Lossily encode a `Vector` extension array into a `TurboQuant` extension array. -/// -/// Valid rows are normalized internally before SORF transform and scalar quantization. The original -/// row norms are stored explicitly, and original vector nulls are preserved on the storage struct -/// and both row-aligned child arrays. +/// Encode a `Vector` extension array into a block-decomposed `TurboQuant` extension array. pub(crate) fn encode_vector( input: ArrayRef, config: &TurboQuantConfig, @@ -183,42 +175,89 @@ pub(crate) fn encode_vector( .and_then(|ext_dtype| ext_dtype.metadata_opt::()) .ok_or_else(|| vortex_err!("TurboQuant encode expects a Vector extension array"))?; - let element_ptype = vector_metadata.element_ptype(); - let dimensions = vector_metadata.dimensions(); vortex_ensure!( dimensions >= MIN_DIMENSION, "TurboQuant requires dimension >= {MIN_DIMENSION}, got {dimensions}", ); - let padded_dim = tq_padded_dim(dimensions)?; + let block_sizes = resolve_block_sizes(config.block_sizes(), dimensions, true)?; let vector_validity = input.validity()?; - let l2_denorm = tq_normalize_as_l2_denorm(input, ctx)?; - let normalized = l2_denorm.child_at(0).clone(); - let norms = l2_denorm.child_at(1).clone(); + let state = prepare_block_state( + config.seed(), + config.num_rounds(), + config.bit_width(), + &block_sizes, + )?; - let normalized_ext = normalized - .as_opt::() - .ok_or_else(|| vortex_err!("normalized TurboQuant input must be a Vector extension"))?; - let normalized_fsl: FixedSizeListArray = normalized_ext.storage_array().clone().execute(ctx)?; - - let core = if normalized_fsl.is_empty() { - empty_quantization(padded_dim) - } else { - // SAFETY: `tq_normalize_as_l2_denorm` returned this normalized Vector child. - unsafe { turboquant_quantize_core(&normalized_fsl, config, ctx)? } - }; - let codes = build_codes_child(num_vectors, core, vector_validity.clone())?; + // Encode all blocks independently with the TurboQuant quantization algorithm. + let blocks = + turboquant_encode_blocks(input, &block_sizes, &state, vector_validity.clone(), ctx)?; + let storage = build_storage(blocks, &block_sizes, num_vectors, vector_validity)?; let metadata = TurboQuantMetadata { - element_ptype, + element_ptype: vector_metadata.element_ptype(), dimensions, bit_width: config.bit_width(), seed: config.seed(), num_rounds: config.num_rounds(), + block_sizes, }; - let storage = build_storage(norms, codes, num_vectors, vector_validity)?; Ok(ExtensionArray::try_new_from_vtable(TurboQuant, metadata, storage)?.into_array()) } + +/// Resolve the block list, validate the dim-dependent rules, and emit soft warnings. +/// +/// `warn = false` skips the `tracing::warn!` emission so `return_dtype` can be called from +/// places where logging would be noisy. +fn resolve_block_sizes( + config_block_sizes: Option<&[u32]>, + dimensions: u32, + warn: bool, +) -> VortexResult> { + let block_sizes = match config_block_sizes { + Some(block_sizes) => block_sizes.to_vec(), + None => vec![dimensions.checked_next_power_of_two().ok_or_else(|| { + vortex_err!( + "TurboQuant dimensions {dimensions} overflow u32 when rounded up to a power of two" + ) + })?], + }; + + // Validate the resolved blocks. This covers the default single-block path, which is not + // validated at config-construction time, and re-checks user blocks harmlessly. The + // `sum >= dimensions` coverage rule is enforced by `validate_block_sum` (u64-accumulated). + validate_block_shape(&block_sizes)?; + validate_block_sum(&block_sizes, dimensions)?; + + // TODO(connor): We NEED to make sure that this is propagated to any users. Should we just do + // this unconditionally? + if warn { + let sum: u64 = block_sizes.iter().map(|&block| block as u64).sum(); + let mut covered: u32 = 0; + for (index, &block) in block_sizes.iter().enumerate() { + if covered >= dimensions { + tracing::warn!( + block_index = index, + block = block, + dimensions = dimensions, + "TurboQuant block lies entirely past dimensions; it will only store \ + padding-derived codes" + ); + } + covered = covered.saturating_add(block); + } + + if sum > (dimensions as u64).saturating_mul(2) { + tracing::warn!( + sum = sum, + dimensions = dimensions, + "TurboQuant block_sizes sum exceeds 2 * dimensions; significant padding overhead" + ); + } + } + + Ok(block_sizes) +} diff --git a/vortex-turboquant/src/scalar_fns/metadata.rs b/vortex-turboquant/src/scalar_fns/metadata.rs index f5eddfe51d9..d8a4134e999 100644 --- a/vortex-turboquant/src/scalar_fns/metadata.rs +++ b/vortex-turboquant/src/scalar_fns/metadata.rs @@ -15,14 +15,22 @@ pub(super) struct TQScalarFnMetadata { seed: u64, #[prost(uint32, tag = "3")] num_rounds: u32, + /// Optional user-supplied block decomposition. An empty repeated field on the wire (default) + /// decodes to `block_sizes: None`, and one or more entries decode to `Some(vec![..])`. + #[prost(uint32, repeated, tag = "4")] + block_sizes: Vec, } impl TQScalarFnMetadata { pub(super) fn from_config(config: &TurboQuantConfig) -> Self { Self { - bit_width: u32::from(config.bit_width()), + bit_width: config.bit_width() as u32, seed: config.seed(), - num_rounds: u32::from(config.num_rounds()), + num_rounds: config.num_rounds() as u32, + block_sizes: config + .block_sizes() + .map(<[u32]>::to_vec) + .unwrap_or_default(), } } @@ -31,8 +39,13 @@ impl TQScalarFnMetadata { .map_err(|_| vortex_err!("TurboQuant bit_width does not fit u8"))?; let num_rounds = u8::try_from(self.num_rounds) .map_err(|_| vortex_err!("TurboQuant num_rounds does not fit u8"))?; + let block_sizes = if self.block_sizes.is_empty() { + None + } else { + Some(self.block_sizes.clone()) + }; - TurboQuantConfig::try_new(bit_width, self.seed, num_rounds) + TurboQuantConfig::try_new(bit_width, self.seed, num_rounds, block_sizes) } } @@ -45,3 +58,28 @@ pub(super) fn deserialize_config(metadata: &[u8]) -> VortexResult VortexResult<()> { + let config = TurboQuantConfig::try_new(3, 7, 2, None)?; + let bytes = serialize_config(&config); + let round = deserialize_config(&bytes)?; + assert_eq!(round.block_sizes(), None); + assert_eq!(round, config); + Ok(()) + } + + #[test] + fn serialize_roundtrips_block_sizes_some() -> VortexResult<()> { + let config = TurboQuantConfig::try_new(3, 7, 2, Some(vec![512, 256]))?; + let bytes = serialize_config(&config); + let round = deserialize_config(&bytes)?; + assert_eq!(round.block_sizes(), Some([512, 256].as_slice())); + assert_eq!(round, config); + Ok(()) + } +} diff --git a/vortex-turboquant/src/sorf/mod.rs b/vortex-turboquant/src/sorf/mod.rs index cce477aa906..e6093739be5 100644 --- a/vortex-turboquant/src/sorf/mod.rs +++ b/vortex-turboquant/src/sorf/mod.rs @@ -1,7 +1,5 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: Copyright the Vortex contributors -mod splitmix64; -mod transform; - -pub(crate) use transform::SorfMatrix; +pub(crate) mod splitmix64; +pub(crate) mod transform; diff --git a/vortex-turboquant/src/sorf/splitmix64.rs b/vortex-turboquant/src/sorf/splitmix64.rs index fc3f9073ced..092d5a10386 100644 --- a/vortex-turboquant/src/sorf/splitmix64.rs +++ b/vortex-turboquant/src/sorf/splitmix64.rs @@ -40,9 +40,28 @@ impl SplitMix64 { } } +/// Derive the per-block SORF seed from the global TurboQuant seed and the block index. +/// +/// The derivation offsets `global_seed` by `block_index * SPLITMIX64_INCREMENT` (matching the +/// additive part of one splitmix64 step) and then applies the splitmix64 mixing tail (the two +/// `MUL1` / `MUL2` rounds plus the final xor-shift). `block_index = 0` is therefore the mixing +/// tail applied directly to `global_seed`, not `global_seed` itself. +/// +/// This function is part of the wire contract and MUST NOT change once shipped: the per-block +/// sign mask stream depends on this output exactly. +pub(crate) fn derive_block_seed(global_seed: u64, block_index: usize) -> u64 { + // `usize::MAX <= u64::MAX` on every target this crate supports, so the cast is lossless. + let block_index = block_index as u64; + let mut state = global_seed.wrapping_add(block_index.wrapping_mul(SPLITMIX64_INCREMENT)); + state = (state ^ (state >> 30)).wrapping_mul(SPLITMIX64_MUL1); + state = (state ^ (state >> 27)).wrapping_mul(SPLITMIX64_MUL2); + state ^ (state >> 31) +} + #[cfg(test)] mod tests { use super::SplitMix64; + use super::derive_block_seed; const SPLITMIX64_SEED0_GOLDEN: [u64; 4] = [ 0xE220_A839_7B1D_CDAF, @@ -75,4 +94,35 @@ mod tests { .collect(); assert_eq!(actual, SPLITMIX64_SEED42_GOLDEN); } + + /// Golden values for `derive_block_seed(42, 0..4)` computed by hand from the splitmix64 + /// reference (additive offset of `block_index * INCREMENT`, followed by the two MUL rounds + /// and the final xor-shift). These pin the wire contract. + /// + /// Indices 1, 2, and 3 align with `SPLITMIX64_SEED42_GOLDEN[0..3]` because `SplitMix64`'s + /// `next_u64` increments before mixing: its `k`-th output is `mix(seed + (k + 1) * INCREMENT)`, + /// while `derive_block_seed(seed, k)` is `mix(seed + k * INCREMENT)`. Index 0 is the mixing + /// tail applied directly to `42`, which has no counterpart in the existing stream golden. + const DERIVED_SEED_42_GOLDEN: [u64; 4] = [ + 0xA759_EA27_D472_7622, + 0xBDD7_3226_2FEB_6E95, + 0x28EF_E333_B266_F103, + 0x4752_6757_130F_9F52, + ]; + + #[test] + fn derive_block_seed_matches_splitmix64_stream_at_zero_indices() { + let actual: Vec = (0..DERIVED_SEED_42_GOLDEN.len()) + .map(|i| derive_block_seed(42, i)) + .collect(); + assert_eq!(actual, DERIVED_SEED_42_GOLDEN); + } + + #[test] + fn derive_block_seed_distinct_for_consecutive_indices() { + let mut seeds: Vec = (0..16).map(|i| derive_block_seed(0xDEAD_BEEF, i)).collect(); + seeds.sort_unstable(); + seeds.dedup(); + assert_eq!(seeds.len(), 16, "derive_block_seed produced duplicates"); + } } diff --git a/vortex-turboquant/src/sorf/transform.rs b/vortex-turboquant/src/sorf/transform.rs index 3fa221fe03a..3fca20cbaea 100644 --- a/vortex-turboquant/src/sorf/transform.rs +++ b/vortex-turboquant/src/sorf/transform.rs @@ -38,8 +38,9 @@ //! time using only in-place 2-element butterfly operations. No row of the full n x n Hadamard //! matrix is ever materialized. //! -//! For dimensions that are not powers of 2, the input is zero-padded to the next power of 2 before -//! the transform and truncated afterward. +//! The transform operates on an exact power-of-two width. In the block-decomposed pipeline the +//! block sizer guarantees power-of-two block widths, so any zero-padding of an overspilling block +//! happens at the block-slicing layer (the encoder), not inside this transform. //! //! # Sign representation //! @@ -71,7 +72,11 @@ pub(crate) struct SorfMatrix { /// The number of sign-diagonal + WHT rounds. num_rounds: usize, - /// The padded dimension (next power of 2 >= dimension). + /// The block width the transform operates on, always an exact power of two. + /// + /// In the block-decomposed pipeline the block sizer guarantees power-of-two widths, so no + /// padding happens here; the `padded_` prefix is retained only because this width historically + /// padded the input up to the next power of two. padded_dim: usize, /// Normalization factor: `padded_dim^(-num_rounds/2)`, applied once at the end. @@ -116,17 +121,9 @@ impl SorfMatrix { }) } - /// Returns the padded dimension (next power of 2 >= dim). - /// - /// All `transform`/`inverse_transform` buffers must be this length. - pub(crate) fn padded_dim(&self) -> usize { - self.padded_dim - } - /// Apply the forward orthogonal transform: `output = R(input)`. /// - /// Both `input` and `output` must have length [`padded_dim()`](Self::padded_dim). The caller is - /// responsible for zero-padding input beyond `dim` positions. + /// Both `input` and `output` must have length equal to the matrix's padded dimension. pub(crate) fn transform(&self, input: &[f32], output: &mut [f32]) { debug_assert_eq!(input.len(), self.padded_dim); debug_assert_eq!(output.len(), self.padded_dim); @@ -137,7 +134,7 @@ impl SorfMatrix { /// Apply the inverse orthogonal transform: `output = R⁻¹(input)`. /// - /// Both `input` and `output` must have length `padded_dim()`. + /// Both `input` and `output` must have length equal to the matrix's padded dimension. pub(crate) fn inverse_transform(&self, input: &[f32], output: &mut [f32]) { debug_assert_eq!(input.len(), self.padded_dim); debug_assert_eq!(output.len(), self.padded_dim); @@ -263,7 +260,7 @@ mod tests { } fn rounds_to_usize(num_rounds: u8) -> usize { - usize::from(num_rounds) + num_rounds as usize } #[test] @@ -273,14 +270,13 @@ mod tests { let seed = 42u64; let transform1 = SorfMatrix::try_new(padded_dim, num_rounds, seed)?; let transform2 = SorfMatrix::try_new(padded_dim, num_rounds, seed)?; - let pd = transform1.padded_dim(); - let mut input = vec![0.0f32; pd]; + let mut input = vec![0.0f32; padded_dim]; for i in 0..padded_dim { input[i] = i as f32; } - let mut out1 = vec![0.0f32; pd]; - let mut out2 = vec![0.0f32; pd]; + let mut out1 = vec![0.0f32; padded_dim]; + let mut out2 = vec![0.0f32; padded_dim]; transform1.transform(&input, &mut out1); transform2.transform(&input, &mut out2); @@ -343,8 +339,8 @@ mod tests { fn roundtrip_exact(#[case] dim: u32, #[case] num_rounds: u8) -> VortexResult<()> { let dim = dim_to_usize(dim); let num_rounds = rounds_to_usize(num_rounds); - let transform = SorfMatrix::try_new(dim.next_power_of_two(), num_rounds, 42u64)?; - let padded_dim = transform.padded_dim(); + let padded_dim = dim.next_power_of_two(); + let transform = SorfMatrix::try_new(padded_dim, num_rounds, 42u64)?; let mut input = vec![0.0f32; padded_dim]; for i in 0..dim { @@ -381,8 +377,8 @@ mod tests { fn preserves_norm(#[case] dim: u32, #[case] num_rounds: u8) -> VortexResult<()> { let dim = dim_to_usize(dim); let num_rounds = rounds_to_usize(num_rounds); - let transform = SorfMatrix::try_new(dim.next_power_of_two(), num_rounds, 7u64)?; - let padded_dim = transform.padded_dim(); + let padded_dim = dim.next_power_of_two(); + let transform = SorfMatrix::try_new(padded_dim, num_rounds, 7u64)?; let mut input = vec![0.0f32; padded_dim]; for i in 0..dim { diff --git a/vortex-turboquant/src/tests/blocks.rs b/vortex-turboquant/src/tests/blocks.rs new file mode 100644 index 00000000000..647bb691f96 --- /dev/null +++ b/vortex-turboquant/src/tests/blocks.rs @@ -0,0 +1,424 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +use rstest::rstest; +use tracing_test::traced_test; +use vortex_array::ExecutionCtx; +use vortex_array::VortexSessionExecute; +use vortex_array::arrays::FixedSizeListArray; +use vortex_array::arrays::PrimitiveArray; +use vortex_array::arrays::StructArray; +use vortex_array::arrays::fixed_size_list::FixedSizeListArrayExt; +use vortex_array::arrays::struct_::StructArrayExt; +use vortex_array::validity::Validity; +use vortex_error::VortexResult; + +use super::execute_tq_decode; +use super::execute_tq_encode; +use super::f32_vector_array; +use super::test_session; +use super::turboquant_storage; +use super::vector_array; +use super::vector_values_f32; +use crate::TurboQuantConfig; +use crate::vtable::tq_metadata; + +#[rstest] +#[case::dim_64_default(64, None, vec![64])] +#[case::dim_128_default(128, None, vec![128])] +#[case::dim_768_default(768, None, vec![1024])] +#[case::dim_768_explicit(768, Some(vec![512, 256]), vec![512, 256])] +#[case::dim_384(384, Some(vec![256, 128]), vec![256, 128])] +#[case::dim_1536(1536, Some(vec![1024, 512]), vec![1024, 512])] +#[case::dim_837_with_overspill(837, Some(vec![512, 256, 64, 64]), vec![512, 256, 64, 64])] +fn encode_decode_roundtrip( + #[case] dim: u32, + #[case] config_block_sizes: Option>, + #[case] expected_block_sizes: Vec, +) -> VortexResult<()> { + let session = test_session(); + let mut ctx = session.create_execution_ctx(); + let input = f32_vector_array(dim, 4, 0.125, Validity::NonNullable)?; + let config = TurboQuantConfig::try_new(3, 17, 2, config_block_sizes)?; + + let encoded = execute_tq_encode(input, &config, &mut ctx)?; + let metadata = tq_metadata(encoded.dtype())?; + assert_eq!(metadata.block_sizes, expected_block_sizes); + assert_eq!(metadata.dimensions, dim); + + let decoded = execute_tq_decode(encoded, &mut ctx)?; + let decoded_values = vector_values_f32(decoded, &mut ctx)?; + assert_eq!(decoded_values.len(), 4 * dim as usize); + Ok(()) +} + +#[test] +fn encode_rejects_block_sizes_with_sum_less_than_dim() -> VortexResult<()> { + let session = test_session(); + let mut ctx = session.create_execution_ctx(); + let input = f32_vector_array(128, 1, 1.0, Validity::NonNullable)?; + let config = TurboQuantConfig::try_new(2, 42, 3, Some(vec![64]))?; + assert!(execute_tq_encode(input, &config, &mut ctx).is_err()); + Ok(()) +} + +#[test] +#[traced_test] +fn encode_warns_on_overspilling_final_block() -> VortexResult<()> { + let session = test_session(); + let mut ctx = session.create_execution_ctx(); + // `dim = 65` with `block_sizes = [64, 64, 64]`. The third block (positions 128..192) starts + // entirely past `dim = 65`, so `resolve_block_sizes` should fire its + // "lies entirely past dimensions" warning for `block_index = 2`. + let input = f32_vector_array(65, 1, 1.0, Validity::NonNullable)?; + let config = TurboQuantConfig::try_new(2, 42, 3, Some(vec![64, 64, 64]))?; + let _encoded = execute_tq_encode(input, &config, &mut ctx)?; + assert!(logs_contain("lies entirely past dimensions")); + Ok(()) +} + +#[test] +#[traced_test] +fn encode_warns_on_sum_more_than_double_dimensions() -> VortexResult<()> { + let session = test_session(); + let mut ctx = session.create_execution_ctx(); + // `dim = 128`, `block_sizes = [256, 256]` sums to `512 > 2 * 128`. + let input = f32_vector_array(128, 1, 1.0, Validity::NonNullable)?; + let config = TurboQuantConfig::try_new(2, 42, 3, Some(vec![256, 256]))?; + let _encoded = execute_tq_encode(input, &config, &mut ctx)?; + assert!(logs_contain("exceeds 2 * dimensions")); + Ok(()) +} + +/// Encode and decode a synthetic vector array and confirm that per-block normalized MSE stays +/// below an empirical bound that shrinks as `1 / 2^(2 * bit_width)`. The bound is loose enough +/// to be flake-free but tight enough to catch regressions in the centroid table, the SORF +/// rotation, or the per-block norm round-trip. +/// +/// For a `b`-bit Lloyd-Max scalar quantizer applied to coordinates of a randomly rotated +/// unit-norm vector in dimension `d`, the per-coordinate marginal has variance roughly `1/d` and +/// the per-coordinate MSE is roughly `c / (d * 2^(2b))` for some distribution-dependent +/// constant `c`. Summed over the `d` coordinates of a block and normalized by the block's L2 +/// norm squared, the expected normalized MSE is on the order of `1 / 2^(2b)`. The empirical +/// bound below is around `8 / 2^(2b)`, well above the theoretical floor but well below the +/// `~1.0` you would see if the centroid lookup or inverse SORF were silently broken. +#[rstest] +#[case::two_bit(2u8)] +#[case::four_bit(4u8)] +#[case::six_bit(6u8)] +fn encode_decode_per_block_mse_within_bound(#[case] bit_width: u8) -> VortexResult<()> { + let session = test_session(); + let mut ctx = session.create_execution_ctx(); + let block_sizes: Vec = vec![512, 256]; + let dim: u32 = block_sizes.iter().sum(); + let rows: usize = 16; + + // Generate deterministic pseudo-random f32 inputs without pulling a PRNG dep. The linear + // congruential recurrence is good enough to produce coordinates whose per-block L2 norms + // are non-trivially spread out across rows. + let total = rows * dim as usize; + let mut values = vec![0.0f32; total]; + let mut state: u32 = 0x1234_5678; + for v in values.iter_mut() { + state = state.wrapping_mul(1_103_515_245).wrapping_add(12_345); + #[expect( + clippy::cast_precision_loss, + reason = "f32 precision is sufficient for the synthetic input distribution" + )] + let x = ((state as f32) / (u32::MAX as f32 / 2.0)) - 1.0; + *v = x; + } + + let input = vector_array(dim, &values, Validity::NonNullable)?; + let config = TurboQuantConfig::try_new(bit_width, 7, 3, Some(block_sizes.clone()))?; + let encoded = execute_tq_encode(input.clone(), &config, &mut ctx)?; + let decoded = execute_tq_decode(encoded, &mut ctx)?; + + let original = vector_values_f32(input, &mut ctx)?; + let recovered = vector_values_f32(decoded, &mut ctx)?; + + #[expect( + clippy::cast_precision_loss, + reason = "`1 << (2 * bit_width)` fits a u32 for `bit_width <= 8`" + )] + let quant_levels_sq = (1u32 << (2 * bit_width)) as f32; + let max_normalized_mse = 8.0_f32 / quant_levels_sq; + let dim = dim as usize; + for row in 0..rows { + let mut offset = 0usize; + for (block_index, &block) in block_sizes.iter().enumerate() { + let block = block as usize; + let orig = &original[row * dim + offset..][..block]; + let rec = &recovered[row * dim + offset..][..block]; + let norm_sq: f32 = orig.iter().map(|&x| x * x).sum(); + let err_sq: f32 = orig + .iter() + .zip(rec.iter()) + .map(|(&o, &r)| (o - r).powi(2)) + .sum(); + // Guard against the degenerate zero-norm row (the LCG above will not produce one in + // practice, but the guard makes the invariant explicit). + let normalized_mse = err_sq / norm_sq.max(1e-10); + assert!( + normalized_mse < max_normalized_mse, + "row {row} block {block_index} normalized MSE {normalized_mse} exceeds bound \ + {max_normalized_mse} for bit_width {bit_width}", + ); + offset += block; + } + } + Ok(()) +} + +#[test] +fn same_size_blocks_use_different_seeds() -> VortexResult<()> { + let session = test_session(); + let mut ctx = session.create_execution_ctx(); + // Two identical 128-wide blocks on a 256-dim input. The first 128 coordinates of every row + // match the second 128 coordinates exactly, so if the blocks shared seeds their `codes` + // would also match. They should not. + let mut values = vec![0.0f32; 4 * 256]; + for row in 0..4 { + for j in 0..128 { + let v = ((row * 128 + j) as f32) * 0.001; + values[row * 256 + j] = v; + values[row * 256 + 128 + j] = v; + } + } + let input = vector_array(256, &values, Validity::NonNullable)?; + let config = TurboQuantConfig::try_new(3, 11, 2, Some(vec![128, 128]))?; + let encoded = execute_tq_encode(input, &config, &mut ctx)?; + let outer = turboquant_storage(encoded, &mut ctx)?; + let block_0_codes = block_codes(&outer, 0, &mut ctx)?; + let block_1_codes = block_codes(&outer, 1, &mut ctx)?; + assert_ne!(block_0_codes, block_1_codes); + Ok(()) +} + +/// Multi-block null-row coverage: a null outer row must produce zero placeholders on every +/// block's `norms` and `codes` children while valid rows still roundtrip per-block. +#[test] +fn encode_decode_multi_block_null_rows() -> VortexResult<()> { + let session = test_session(); + let mut ctx = session.create_execution_ctx(); + let dim = 768u32; + let rows = 3usize; + let mut values = vec![0.0f32; rows * dim as usize]; + for (i, v) in values.iter_mut().enumerate() { + *v = ((i % 13) as f32) * 0.1 + 0.05; + } + let validity = Validity::from_iter([true, false, true]); + let input = vector_array(dim, &values, validity)?; + let block_sizes = [512usize, 256]; + let config = TurboQuantConfig::try_new(4, 23, 3, Some(vec![512, 256]))?; + let encoded = execute_tq_encode(input, &config, &mut ctx)?; + + // The null row (row 1) must store zero placeholders in every block's `norms` and `codes`, not + // merely reconstruct to zero on decode. + let outer = turboquant_storage(encoded.clone(), &mut ctx)?; + for (block_index, &block) in block_sizes.iter().enumerate() { + let norms = block_norms(&outer, block_index, &mut ctx)?; + let codes = block_codes(&outer, block_index, &mut ctx)?; + assert_eq!( + norms[1], 0.0, + "block {block_index} null-row norm must be zero" + ); + assert!( + codes[block..2 * block].iter().all(|&c| c == 0), + "block {block_index} null-row codes must be zero" + ); + } + + let decoded = execute_tq_decode(encoded, &mut ctx)?; + let validity = + super::vector_validity(decoded.clone(), &mut ctx)?.execute_mask(rows, &mut ctx)?; + assert!(validity.value(0)); + assert!(!validity.value(1)); + assert!(validity.value(2)); + // The null row's reconstructed coordinates should be zero placeholders. + let values = vector_values_f32(decoded, &mut ctx)?; + let null_row = &values[dim as usize..2 * dim as usize]; + assert!(null_row.iter().all(|&v| v == 0.0)); + Ok(()) +} + +/// Multi-block per-block zero-norm coverage: a row whose first block's slice is entirely zero +/// but whose second block carries energy must reconstruct the second block correctly while +/// leaving the first block at zero (the per-block `norm = 0` placeholder path). +#[test] +fn encode_decode_multi_block_zero_norm_block() -> VortexResult<()> { + let session = test_session(); + let mut ctx = session.create_execution_ctx(); + let dim = 768u32; + let rows = 2usize; + let mut values = vec![0.0f32; rows * dim as usize]; + // Row 0 valid everywhere; row 1 zero in `block_0` (positions 0..512), nonzero in `block_1`. + for (i, v) in values[..dim as usize].iter_mut().enumerate() { + *v = ((i % 13) as f32) * 0.1 + 0.05; + } + for v in values[dim as usize + 512..2 * dim as usize].iter_mut() { + *v = 0.5; + } + let input = vector_array(dim, &values, Validity::NonNullable)?; + let config = TurboQuantConfig::try_new(4, 29, 3, Some(vec![512, 256]))?; + let encoded = execute_tq_encode(input, &config, &mut ctx)?; + let decoded = execute_tq_decode(encoded, &mut ctx)?; + let recovered = vector_values_f32(decoded, &mut ctx)?; + let row1_block0 = &recovered[dim as usize..dim as usize + 512]; + let row1_block1 = &recovered[dim as usize + 512..2 * dim as usize]; + assert!( + row1_block0.iter().all(|&v| v == 0.0), + "zero-norm block expected to reconstruct as zeros" + ); + let block1_energy: f32 = row1_block1.iter().map(|&v| v * v).sum(); + assert!( + block1_energy > 0.0, + "nonzero block expected to recover energy" + ); + Ok(()) +} + +/// A dimension whose next power of two overflows `u32` must produce a clean error from the default +/// (`block_sizes = None`) path rather than panicking. Regression for the overflow guard that the +/// block-decomposition refactor dropped from the old `tq_padded_dim`. +#[test] +fn encode_rejects_dimension_overflow() -> VortexResult<()> { + let session = test_session(); + let mut ctx = session.create_execution_ctx(); + let input = vector_array::(2_147_483_649, &[], Validity::NonNullable)?; + assert!(execute_tq_encode(input, &TurboQuantConfig::default(), &mut ctx).is_err()); + Ok(()) +} + +/// A finite f64 value far above `f32::MAX` casts to inf in the f32 quantization pipeline, making +/// the block norm non-finite. Encode must reject it cleanly rather than emit corrupt codes. +#[test] +fn encode_rejects_non_finite_f64_norm() -> VortexResult<()> { + let session = test_session(); + let mut ctx = session.create_execution_ctx(); + let mut values = vec![0.0f64; 64]; + values[0] = 1e300; + let input = vector_array::(64, &values, Validity::NonNullable)?; + assert!(execute_tq_encode(input, &TurboQuantConfig::default(), &mut ctx).is_err()); + Ok(()) +} + +/// A well-typed input that already contains a NaN or infinite coordinate makes the block norm +/// non-finite; encode must reject it cleanly (the guard is otherwise only reached via f64 +/// overflow, so this pins the direct non-finite-input path). +#[rstest] +#[case::nan(f32::NAN)] +#[case::pos_inf(f32::INFINITY)] +#[case::neg_inf(f32::NEG_INFINITY)] +fn encode_rejects_non_finite_coordinate(#[case] bad: f32) -> VortexResult<()> { + let session = test_session(); + let mut ctx = session.create_execution_ctx(); + let mut values = vec![0.5f32; 64]; + values[0] = bad; + let input = vector_array(64, &values, Validity::NonNullable)?; + assert!(execute_tq_encode(input, &TurboQuantConfig::default(), &mut ctx).is_err()); + Ok(()) +} + +/// Whole-vector reconstruction fidelity over the real `dim` coordinates (overspill padding +/// dropped). Complements `encode_decode_per_block_mse_within_bound` by covering the default +/// single-block path (which the per-block test never exercises) and an overspilling block shape +/// whose final block mixes real coordinates with zero padding. The bound is a loose whole-vector +/// normalized MSE that still catches a broken centroid lookup, inverse SORF, or norm round-trip +/// (those drive normalized MSE toward ~1.0). +#[rstest] +#[case::default_single_block(768, None)] +#[case::overspill(837, Some(vec![512, 256, 64, 64]))] +#[case::two_block_384(384, Some(vec![256, 128]))] +#[case::two_block_1536(1536, Some(vec![1024, 512]))] +#[case::single_64(64, None)] +#[case::single_128(128, None)] +fn encode_decode_real_dim_fidelity( + #[case] dim: u32, + #[case] config_block_sizes: Option>, +) -> VortexResult<()> { + let session = test_session(); + let mut ctx = session.create_execution_ctx(); + let rows: usize = 8; + let bit_width: u8 = 6; + + let total = rows * dim as usize; + let mut values = vec![0.0f32; total]; + let mut state: u32 = 0x9E37_79B9; + for v in values.iter_mut() { + state = state.wrapping_mul(1_103_515_245).wrapping_add(12_345); + #[expect( + clippy::cast_precision_loss, + reason = "f32 precision is sufficient for the synthetic input distribution" + )] + let x = ((state as f32) / (u32::MAX as f32 / 2.0)) - 1.0; + *v = x; + } + + let input = vector_array(dim, &values, Validity::NonNullable)?; + let config = TurboQuantConfig::try_new(bit_width, 7, 3, config_block_sizes)?; + let encoded = execute_tq_encode(input.clone(), &config, &mut ctx)?; + let decoded = execute_tq_decode(encoded, &mut ctx)?; + + let original = vector_values_f32(input, &mut ctx)?; + let recovered = vector_values_f32(decoded, &mut ctx)?; + assert_eq!(recovered.len(), rows * dim as usize); + + #[expect( + clippy::cast_precision_loss, + reason = "`1 << (2 * bit_width)` fits a u32 for `bit_width <= 8`" + )] + let quant_levels_sq = (1u32 << (2 * bit_width)) as f32; + let max_normalized_mse = 16.0_f32 / quant_levels_sq; + let dim = dim as usize; + for row in 0..rows { + let orig = &original[row * dim..][..dim]; + let rec = &recovered[row * dim..][..dim]; + let norm_sq: f32 = orig.iter().map(|&x| x * x).sum(); + let err_sq: f32 = orig + .iter() + .zip(rec.iter()) + .map(|(&o, &r)| (o - r).powi(2)) + .sum(); + let normalized_mse = err_sq / norm_sq.max(1e-10); + assert!( + normalized_mse < max_normalized_mse, + "row {row} whole-vector normalized MSE {normalized_mse} exceeds bound \ + {max_normalized_mse} (dim {dim})", + ); + } + Ok(()) +} + +fn block_codes( + outer: &StructArray, + block_index: usize, + ctx: &mut ExecutionCtx, +) -> VortexResult> { + let inner: StructArray = outer + .unmasked_field_by_name(format!("block_{block_index}"))? + .clone() + .execute(ctx)?; + let codes: FixedSizeListArray = inner + .unmasked_field_by_name("codes")? + .clone() + .execute(ctx)?; + let elements: PrimitiveArray = codes.elements().clone().execute(ctx)?; + Ok(elements.as_slice::().to_vec()) +} + +fn block_norms( + outer: &StructArray, + block_index: usize, + ctx: &mut ExecutionCtx, +) -> VortexResult> { + let inner: StructArray = outer + .unmasked_field_by_name(format!("block_{block_index}"))? + .clone() + .execute(ctx)?; + let norms: PrimitiveArray = inner + .unmasked_field_by_name("norms")? + .clone() + .execute(ctx)?; + Ok(norms.as_slice::().to_vec()) +} diff --git a/vortex-turboquant/src/tests/encode_decode.rs b/vortex-turboquant/src/tests/encode_decode.rs index ed5aab190aa..4096da841d7 100644 --- a/vortex-turboquant/src/tests/encode_decode.rs +++ b/vortex-turboquant/src/tests/encode_decode.rs @@ -7,9 +7,8 @@ use vortex_array::VortexSessionExecute; use vortex_array::arrays::ExtensionArray; use vortex_array::arrays::FixedSizeListArray; use vortex_array::arrays::PrimitiveArray; -use vortex_array::arrays::extension::ExtensionArrayExt; +use vortex_array::arrays::StructArray; use vortex_array::arrays::fixed_size_list::FixedSizeListArrayExt; -use vortex_array::arrays::scalar_fn::ScalarFnArrayExt; use vortex_array::arrays::struct_::StructArrayExt; use vortex_array::dtype::Nullability; use vortex_array::dtype::PType; @@ -27,15 +26,33 @@ use super::vector_element_ptype; use super::vector_validity; use super::vector_values_f32; use crate::TurboQuantConfig; -use crate::centroids::compute_or_get_centroids; -use crate::vector::normalize::tq_normalize_as_l2_denorm; +use crate::centroids::compute_or_get_codebook; #[rstest] -#[case::zero_bits(0, 42, 3)] -#[case::too_many_bits(9, 42, 3)] -#[case::zero_rounds(2, 42, 0)] -fn config_rejects_invalid_values(#[case] bit_width: u8, #[case] seed: u64, #[case] num_rounds: u8) { - assert!(TurboQuantConfig::try_new(bit_width, seed, num_rounds).is_err()); +#[case::zero_bits(0, 42, 3, None)] +#[case::too_many_bits(9, 42, 3, None)] +#[case::zero_rounds(2, 42, 0, None)] +#[case::empty_blocks(2, 42, 3, Some(vec![]))] +#[case::non_power_of_two_block(2, 42, 3, Some(vec![96]))] +#[case::undersized_block(2, 42, 3, Some(vec![32]))] +fn config_rejects_invalid_values( + #[case] bit_width: u8, + #[case] seed: u64, + #[case] num_rounds: u8, + #[case] block_sizes: Option>, +) { + assert!(TurboQuantConfig::try_new(bit_width, seed, num_rounds, block_sizes).is_err()); +} + +#[rstest] +#[case::default_block_sizes(None)] +#[case::single_min_block(Some(vec![64]))] +#[case::two_blocks(Some(vec![512, 256]))] +#[case::four_blocks(Some(vec![512, 256, 64, 64]))] +fn config_accepts_valid_block_shapes(#[case] block_sizes: Option>) -> VortexResult<()> { + let config = TurboQuantConfig::try_new(2, 42, 3, block_sizes.clone())?; + assert_eq!(config.block_sizes(), block_sizes.as_deref()); + Ok(()) } #[test] @@ -52,17 +69,7 @@ fn encode_rejects_non_vector_input() { fn encode_rejects_small_dimensions() -> VortexResult<()> { let session = test_session(); let mut ctx = session.create_execution_ctx(); - let input = f32_vector_array(127, 1, 1.0, Validity::NonNullable)?; - - assert!(execute_tq_encode(input, &TurboQuantConfig::default(), &mut ctx).is_err()); - Ok(()) -} - -#[test] -fn encode_rejects_padded_dimension_overflow() -> VortexResult<()> { - let session = test_session(); - let mut ctx = session.create_execution_ctx(); - let input = vector_array::(2_147_483_649, &[], Validity::NonNullable)?; + let input = f32_vector_array(63, 1, 1.0, Validity::NonNullable)?; assert!(execute_tq_encode(input, &TurboQuantConfig::default(), &mut ctx).is_err()); Ok(()) @@ -70,10 +77,10 @@ fn encode_rejects_padded_dimension_overflow() -> VortexResult<()> { #[test] fn centroid_cache_is_deterministic() -> VortexResult<()> { - let first = compute_or_get_centroids(128, 3)?; - let second = compute_or_get_centroids(128, 3)?; + let first = compute_or_get_codebook(128, 3)?; + let second = compute_or_get_codebook(128, 3)?; - assert_eq!(first.as_slice(), second.as_slice()); + assert_eq!(first.centroids.as_slice(), second.centroids.as_slice()); Ok(()) } @@ -97,22 +104,26 @@ fn encode_stores_norms_and_struct_validity() -> VortexResult<()> { let validity = Validity::from_iter([true, false, true]); let input = f32_vector_array(128, 3, 0.25, validity)?; - let config = TurboQuantConfig::try_new(3, 1, 2)?; + let config = TurboQuantConfig::try_new(3, 1, 2, None)?; let encoded = execute_tq_encode(input, &config, &mut ctx)?; - let storage = turboquant_storage(encoded, &mut ctx)?; - let mask = storage.struct_validity().execute_mask(3, &mut ctx)?; - let norms: PrimitiveArray = storage + let outer = turboquant_storage(encoded, &mut ctx)?; + let outer_mask = outer.struct_validity().execute_mask(3, &mut ctx)?; + let block_0: StructArray = outer + .unmasked_field_by_name("block_0")? + .clone() + .execute(&mut ctx)?; + let norms: PrimitiveArray = block_0 .unmasked_field_by_name("norms")? .clone() .execute(&mut ctx)?; - let codes: FixedSizeListArray = storage + let codes: FixedSizeListArray = block_0 .unmasked_field_by_name("codes")? .clone() .execute(&mut ctx)?; - assert!(mask.value(0)); - assert!(!mask.value(1)); - assert!(mask.value(2)); + assert!(outer_mask.value(0)); + assert!(!outer_mask.value(1)); + assert!(outer_mask.value(2)); assert_eq!(norms.validity()?.nullability(), Nullability::Nullable); assert_eq!(codes.validity()?.nullability(), Nullability::Nullable); @@ -134,45 +145,6 @@ fn encode_stores_norms_and_struct_validity() -> VortexResult<()> { Ok(()) } -#[test] -fn normalize_as_l2_denorm_preserves_child_validity() -> VortexResult<()> { - let session = test_session(); - let mut ctx = session.create_execution_ctx(); - let mut values = vec![0.0f32; 3 * 128]; - values[0] = 3.0; - values[1] = 4.0; - values[128..256].fill(13.0); - values[256] = 1.0; - let input = vector_array(128, &values, Validity::from_iter([true, false, true]))?; - - let l2_denorm = tq_normalize_as_l2_denorm(input, &mut ctx)?; - let normalized = l2_denorm.child_at(0).clone(); - let norms = l2_denorm.child_at(1).clone(); - - let normalized_ext: ExtensionArray = normalized.execute(&mut ctx)?; - let normalized_fsl: FixedSizeListArray = - normalized_ext.storage_array().clone().execute(&mut ctx)?; - let normalized_values: PrimitiveArray = normalized_fsl.elements().clone().execute(&mut ctx)?; - let norms: PrimitiveArray = norms.execute(&mut ctx)?; - let normalized_validity = normalized_fsl.validity()?.execute_mask(3, &mut ctx)?; - let norms_validity = norms.validity()?.execute_mask(3, &mut ctx)?; - - assert!(normalized_validity.value(0)); - assert!(!normalized_validity.value(1)); - assert!(normalized_validity.value(2)); - assert!(norms_validity.value(0)); - assert!(!norms_validity.value(1)); - assert!(norms_validity.value(2)); - assert_eq!(norms.validity()?.nullability(), Nullability::Nullable); - assert_eq!(norms.as_slice::()[0], 5.0); - assert!( - normalized_values.as_slice::()[128..256] - .iter() - .all(|&value| value == 0.0) - ); - Ok(()) -} - #[test] fn encode_decode_preserves_nulls_and_zero_norm_rows() -> VortexResult<()> { let session = test_session(); @@ -202,7 +174,7 @@ fn encode_decode_preserves_nulls_and_zero_norm_rows() -> VortexResult<()> { fn encode_decode_supports_non_f32_inputs(#[case] ptype: PType) -> VortexResult<()> { let session = test_session(); let mut ctx = session.create_execution_ctx(); - let config = TurboQuantConfig::try_new(3, 42, 3)?; + let config = TurboQuantConfig::try_new(3, 42, 3, None)?; match ptype { PType::F16 => { @@ -236,7 +208,7 @@ fn decode_scales_by_stored_norms() -> VortexResult<()> { let mut ctx = session.create_execution_ctx(); let base = f32_vector_array(128, 1, 0.5, Validity::NonNullable)?; let scaled = f32_vector_array(128, 1, 1.0, Validity::NonNullable)?; - let config = TurboQuantConfig::try_new(2, 99, 3)?; + let config = TurboQuantConfig::try_new(2, 99, 3, None)?; let base_values = vector_values_f32( execute_tq_decode(execute_tq_encode(base, &config, &mut ctx)?, &mut ctx)?, diff --git a/vortex-turboquant/src/tests/file.rs b/vortex-turboquant/src/tests/file.rs index e59b7a95c75..5cf5bfa29ad 100644 --- a/vortex-turboquant/src/tests/file.rs +++ b/vortex-turboquant/src/tests/file.rs @@ -39,6 +39,7 @@ fn file_roundtrip_with_initialize_session() -> VortexResult<()> { let metadata = tq_metadata(read.dtype())?; assert_eq!(metadata.dimensions, 128); + assert_eq!(metadata.block_sizes, vec![128]); let decoded = execute_tq_decode_from_metadata(read, &mut ctx)?; let validity = vector_validity(decoded, &mut ctx)?.execute_mask(2, &mut ctx)?; assert!(validity.value(0)); @@ -52,7 +53,7 @@ fn file_roundtrip_lazy_decode_scalar_fn_with_initialize_session() -> VortexResul let session = file_session(&runtime); let mut ctx = session.create_execution_ctx(); let input = f32_vector_array(128, 2, 0.25, Validity::from_iter([true, false]))?; - let config = TurboQuantConfig::try_new(3, 42, 3)?; + let config = TurboQuantConfig::try_new(3, 42, 3, None)?; let encoded = execute_tq_encode(input, &config, &mut ctx)?; let decoded = TQDecode::try_new_array(encoded)?.into_array(); @@ -71,3 +72,30 @@ fn file_roundtrip_lazy_decode_scalar_fn_with_initialize_session() -> VortexResul assert!(!validity.value(1)); Ok(()) } + +#[test] +fn file_roundtrip_multi_block() -> VortexResult<()> { + let runtime = SingleThreadRuntime::default(); + let session = file_session(&runtime); + let mut ctx = session.create_execution_ctx(); + let input = f32_vector_array(768, 2, 0.25, Validity::from_iter([true, false]))?; + let config = TurboQuantConfig::try_new(3, 42, 3, Some(vec![512, 256]))?; + let encoded = execute_tq_encode(input, &config, &mut ctx)?; + + let mut file_bytes = Vec::new(); + VortexWriteOptions::new(session.clone()) + .blocking(&runtime) + .write(&mut file_bytes, encoded.to_array_iterator())?; + + let file = session.open_options().open_buffer(file_bytes)?; + let read = runtime.block_on(async { file.scan()?.into_array_stream()?.read_all().await })?; + + let metadata = tq_metadata(read.dtype())?; + assert_eq!(metadata.dimensions, 768); + assert_eq!(metadata.block_sizes, vec![512, 256]); + let decoded = execute_tq_decode_from_metadata(read, &mut ctx)?; + let validity = vector_validity(decoded, &mut ctx)?.execute_mask(2, &mut ctx)?; + assert!(validity.value(0)); + assert!(!validity.value(1)); + Ok(()) +} diff --git a/vortex-turboquant/src/tests/malformed.rs b/vortex-turboquant/src/tests/malformed.rs index f99f0ee5105..b87b377f924 100644 --- a/vortex-turboquant/src/tests/malformed.rs +++ b/vortex-turboquant/src/tests/malformed.rs @@ -2,6 +2,7 @@ // SPDX-FileCopyrightText: Copyright the Vortex contributors use rstest::rstest; +use vortex_array::ArrayRef; use vortex_array::IntoArray; use vortex_array::VortexSessionExecute; use vortex_array::arrays::ExtensionArray; @@ -21,6 +22,44 @@ use super::vector_validity; use crate::TurboQuant; use crate::TurboQuantMetadata; +fn metadata() -> TurboQuantMetadata { + TurboQuantMetadata { + element_ptype: PType::F32, + dimensions: 128, + bit_width: 1, + seed: 42, + num_rounds: 3, + block_sizes: vec![128], + } +} + +fn build_single_block_tq( + metadata: TurboQuantMetadata, + norms: ArrayRef, + codes: ArrayRef, + rows: usize, + outer_validity: Validity, +) -> ArrayRef { + let inner = StructArray::try_new( + FieldNames::from(["norms", "codes"]), + vec![norms, codes], + rows, + outer_validity.clone(), + ) + .unwrap() + .into_array(); + let outer = StructArray::try_new( + FieldNames::from(["block_0"]), + vec![inner], + rows, + outer_validity, + ) + .unwrap(); + ExtensionArray::try_new_from_vtable(TurboQuant, metadata, outer.into_array()) + .unwrap() + .into_array() +} + #[rstest] #[case::nullable_norms_under_nonnullable_struct( Nullability::NonNullable, @@ -49,13 +88,6 @@ fn decode_accepts_child_nullability_that_covers_struct_validity( ) -> VortexResult<()> { let session = test_session(); let mut ctx = session.create_execution_ctx(); - let metadata = TurboQuantMetadata { - element_ptype: PType::F32, - dimensions: 128, - bit_width: 1, - seed: 42, - num_rounds: 3, - }; let norms = PrimitiveArray::new::(Buffer::copy_from([1.0]), Validity::from(norms_nullability)) .into_array(); @@ -65,19 +97,15 @@ fn decode_accepts_child_nullability_that_covers_struct_validity( 128, Validity::from(codes_nullability), 1, - ) - .unwrap() + )? .into_array(); - let storage = StructArray::try_new( - FieldNames::from(["norms", "codes"]), - vec![norms, codes], + let tq = build_single_block_tq( + metadata(), + norms, + codes, 1, Validity::from(struct_nullability), - ) - .unwrap(); - let tq = ExtensionArray::try_new_from_vtable(TurboQuant, metadata, storage.into_array()) - .unwrap() - .into_array(); + ); execute_tq_decode_from_metadata(tq, &mut ctx)?; Ok(()) @@ -87,27 +115,19 @@ fn decode_accepts_child_nullability_that_covers_struct_validity( fn decode_accepts_struct_mask_with_all_valid_children() -> VortexResult<()> { let session = test_session(); let mut ctx = session.create_execution_ctx(); - let metadata = TurboQuantMetadata { - element_ptype: PType::F32, - dimensions: 128, - bit_width: 1, - seed: 42, - num_rounds: 3, - }; let norms = PrimitiveArray::new::(Buffer::copy_from([1.0, 1.0, 1.0]), Validity::NonNullable) .into_array(); let codes = PrimitiveArray::new::(vec![0u8; 3 * 128], Validity::NonNullable); let codes = FixedSizeListArray::try_new(codes.into_array(), 128, Validity::NonNullable, 3)? .into_array(); - let storage = StructArray::try_new( - FieldNames::from(["norms", "codes"]), - vec![norms, codes], + let tq = build_single_block_tq( + metadata(), + norms, + codes, 3, Validity::from_iter([true, false, true]), - )?; - let tq = ExtensionArray::try_new_from_vtable(TurboQuant, metadata, storage.into_array())? - .into_array(); + ); let decoded = execute_tq_decode_from_metadata(tq, &mut ctx)?; let validity = vector_validity(decoded, &mut ctx)?.execute_mask(3, &mut ctx)?; @@ -121,13 +141,6 @@ fn decode_accepts_struct_mask_with_all_valid_children() -> VortexResult<()> { fn decode_rejects_child_masks_that_disagree_with_struct_validity() -> VortexResult<()> { let session = test_session(); let mut ctx = session.create_execution_ctx(); - let metadata = TurboQuantMetadata { - element_ptype: PType::F32, - dimensions: 128, - bit_width: 1, - seed: 42, - num_rounds: 3, - }; let norms = PrimitiveArray::new::( Buffer::copy_from([1.0, 1.0, 1.0]), Validity::from_iter([true, true, false]), @@ -141,13 +154,47 @@ fn decode_rejects_child_masks_that_disagree_with_struct_validity() -> VortexResu 3, )? .into_array(); - let storage = StructArray::try_new( + let tq = build_single_block_tq( + metadata(), + norms, + codes, + 3, + Validity::from_iter([true, false, true]), + ); + + assert!(execute_tq_decode_from_metadata(tq, &mut ctx).is_err()); + Ok(()) +} + +#[test] +fn decode_rejects_inner_struct_validity_narrower_than_outer() -> VortexResult<()> { + // The outer struct marks all three rows valid, but the inner `block_0` struct marks row 1 + // invalid. `parse_storage` must reject this: each inner block's struct validity must + // *cover* the outer struct's validity, so an inner row marked invalid where the outer is + // valid is a contract violation. + let session = test_session(); + let mut ctx = session.create_execution_ctx(); + let rows = 3; + let norms = + PrimitiveArray::new::(Buffer::copy_from([1.0, 1.0, 1.0]), Validity::NonNullable) + .into_array(); + let codes = PrimitiveArray::new::(vec![0u8; rows * 128], Validity::NonNullable); + let codes = FixedSizeListArray::try_new(codes.into_array(), 128, Validity::NonNullable, rows)? + .into_array(); + let inner = StructArray::try_new( FieldNames::from(["norms", "codes"]), vec![norms, codes], - 3, + rows, Validity::from_iter([true, false, true]), + )? + .into_array(); + let outer = StructArray::try_new( + FieldNames::from(["block_0"]), + vec![inner], + rows, + Validity::NonNullable, )?; - let tq = ExtensionArray::try_new_from_vtable(TurboQuant, metadata, storage.into_array())? + let tq = ExtensionArray::try_new_from_vtable(TurboQuant, metadata(), outer.into_array())? .into_array(); assert!(execute_tq_decode_from_metadata(tq, &mut ctx).is_err()); @@ -155,35 +202,129 @@ fn decode_rejects_child_masks_that_disagree_with_struct_validity() -> VortexResu } #[test] -#[should_panic(expected = "TurboQuant code exceeds centroid count")] -fn decode_panics_on_codes_outside_centroid_table() { +fn decode_rejects_codes_outside_centroid_table() -> VortexResult<()> { let session = test_session(); let mut ctx = session.create_execution_ctx(); - let metadata = TurboQuantMetadata { - element_ptype: PType::F32, - dimensions: 128, - bit_width: 1, - seed: 42, - num_rounds: 3, - }; let norms = PrimitiveArray::new::(Buffer::copy_from([1.0]), Validity::NonNullable).into_array(); let mut codes = vec![0u8; 128]; codes[0] = 2; let codes = PrimitiveArray::new::(codes, Validity::NonNullable); - let codes = FixedSizeListArray::try_new(codes.into_array(), 128, Validity::NonNullable, 1) - .unwrap() + let codes = FixedSizeListArray::try_new(codes.into_array(), 128, Validity::NonNullable, 1)? .into_array(); - let storage = StructArray::try_new( - FieldNames::from(["norms", "codes"]), - vec![norms, codes], + let tq = build_single_block_tq(metadata(), norms, codes, 1, Validity::NonNullable); + + // A code pointing past the centroid table must surface a clean error from the public decode + // path rather than panicking through `vortex_expect`. + assert!(execute_tq_decode_from_metadata(tq, &mut ctx).is_err()); + Ok(()) +} + +#[test] +fn decode_ignores_out_of_range_codes_in_null_rows() -> VortexResult<()> { + let session = test_session(); + let mut ctx = session.create_execution_ctx(); + // Two rows; row 1 is masked out (outer validity false). Its placeholder codes contain an + // out-of-range value (2, where bit_width=1 gives only 2 centroids). Decode must not validate + // codes for masked-out rows, so this decodes cleanly with row 1 as a null placeholder. + let norms = PrimitiveArray::new::(Buffer::copy_from([1.0, 0.0]), Validity::NonNullable) + .into_array(); + let mut codes = vec![0u8; 2 * 128]; + codes[128] = 2; + let codes = PrimitiveArray::new::(codes, Validity::NonNullable); + let codes = FixedSizeListArray::try_new(codes.into_array(), 128, Validity::NonNullable, 2)? + .into_array(); + let tq = build_single_block_tq( + metadata(), + norms, + codes, + 2, + Validity::from_iter([true, false]), + ); + + let decoded = execute_tq_decode_from_metadata(tq, &mut ctx)?; + let validity = vector_validity(decoded, &mut ctx)?.execute_mask(2, &mut ctx)?; + assert!(validity.value(0)); + assert!(!validity.value(1)); + Ok(()) +} + +#[test] +fn decode_rejects_non_finite_stored_norm() -> VortexResult<()> { + let session = test_session(); + let mut ctx = session.create_execution_ctx(); + // A non-finite stored norm would scale the reconstruction to inf/NaN; decode rejects it. + let norms = + PrimitiveArray::new::(Buffer::copy_from([f32::INFINITY]), Validity::NonNullable) + .into_array(); + let codes = PrimitiveArray::new::(vec![0u8; 128], Validity::NonNullable); + let codes = FixedSizeListArray::try_new(codes.into_array(), 128, Validity::NonNullable, 1)? + .into_array(); + let tq = build_single_block_tq(metadata(), norms, codes, 1, Validity::NonNullable); + + assert!(execute_tq_decode_from_metadata(tq, &mut ctx).is_err()); + Ok(()) +} + +#[test] +fn decode_rejects_negative_stored_norm() -> VortexResult<()> { + let session = test_session(); + let mut ctx = session.create_execution_ctx(); + // L2 norms are never negative; a negative stored norm (only from corrupt storage) would + // sign-flip the reconstruction, so decode rejects it. + let norms = + PrimitiveArray::new::(Buffer::copy_from([-1.0]), Validity::NonNullable).into_array(); + let codes = PrimitiveArray::new::(vec![0u8; 128], Validity::NonNullable); + let codes = FixedSizeListArray::try_new(codes.into_array(), 128, Validity::NonNullable, 1)? + .into_array(); + let tq = build_single_block_tq(metadata(), norms, codes, 1, Validity::NonNullable); + + assert!(execute_tq_decode_from_metadata(tq, &mut ctx).is_err()); + Ok(()) +} + +/// Malformed storage in a LATER block (not block_0) must be rejected too, exercising the per-block +/// decode loop's validation beyond the first block: here block_1 carries an out-of-range code. +#[test] +fn decode_rejects_out_of_range_code_in_later_block() -> VortexResult<()> { + let session = test_session(); + let mut ctx = session.create_execution_ctx(); + let metadata = TurboQuantMetadata { + element_ptype: PType::F32, + dimensions: 256, + bit_width: 1, + seed: 42, + num_rounds: 3, + block_sizes: vec![128, 128], + }; + // bit_width = 1 gives 2 centroids, so a code of 2 is out of range. + let make_block = |bad_code: bool| -> VortexResult { + let norms = PrimitiveArray::new::(Buffer::copy_from([1.0]), Validity::NonNullable) + .into_array(); + let mut codes = vec![0u8; 128]; + if bad_code { + codes[0] = 2; + } + let codes = PrimitiveArray::new::(codes, Validity::NonNullable); + let codes = FixedSizeListArray::try_new(codes.into_array(), 128, Validity::NonNullable, 1)? + .into_array(); + Ok(StructArray::try_new( + FieldNames::from(["norms", "codes"]), + vec![norms, codes], + 1, + Validity::NonNullable, + )? + .into_array()) + }; + let outer = StructArray::try_new( + FieldNames::from(["block_0", "block_1"]), + vec![make_block(false)?, make_block(true)?], 1, Validity::NonNullable, - ) - .unwrap(); - let tq = ExtensionArray::try_new_from_vtable(TurboQuant, metadata, storage.into_array()) - .unwrap() - .into_array(); + )?; + let tq = + ExtensionArray::try_new_from_vtable(TurboQuant, metadata, outer.into_array())?.into_array(); - drop(execute_tq_decode_from_metadata(tq, &mut ctx)); + assert!(execute_tq_decode_from_metadata(tq, &mut ctx).is_err()); + Ok(()) } diff --git a/vortex-turboquant/src/tests/metadata.rs b/vortex-turboquant/src/tests/metadata.rs index e0d1042f02f..bfe46fcc063 100644 --- a/vortex-turboquant/src/tests/metadata.rs +++ b/vortex-turboquant/src/tests/metadata.rs @@ -17,9 +17,7 @@ use vortex_error::vortex_err; use crate::TurboQuant; use crate::TurboQuantMetadata; -use crate::vector::storage::CODES_FIELD; -use crate::vector::storage::NORMS_FIELD; -use crate::vector::tq_padded_dim; +use crate::vtable::tq_storage_dtype; #[derive(Clone, PartialEq, Message)] struct MetadataWire { @@ -33,41 +31,28 @@ struct MetadataWire { seed: u64, #[prost(uint32, tag = "5")] num_rounds: u32, -} - -fn tq_storage_dtype( - metadata: &TurboQuantMetadata, - row_nullability: Nullability, -) -> VortexResult { - let padded_dim = u32::try_from(tq_padded_dim(metadata.dimensions)?) - .map_err(|_| vortex_err!("TurboQuant padded dimension does not fit u32"))?; - Ok(DType::Struct( - StructFields::new( - FieldNames::from([NORMS_FIELD, CODES_FIELD]), - vec![ - DType::Primitive(metadata.element_ptype, row_nullability), - DType::FixedSizeList( - Arc::new(DType::Primitive(PType::U8, Nullability::NonNullable)), - padded_dim, - row_nullability, - ), - ], - ), - row_nullability, - )) + #[prost(uint32, repeated, tag = "6")] + block_sizes: Vec, } #[rstest] -#[case::f16(PType::F16)] -#[case::f32(PType::F32)] -#[case::f64(PType::F64)] -fn metadata_serialization_roundtrips(#[case] element_ptype: PType) -> VortexResult<()> { +#[case::f16(PType::F16, vec![128])] +#[case::f32(PType::F32, vec![128])] +#[case::f64(PType::F64, vec![128])] +#[case::two_block(PType::F32, vec![512, 256])] +#[case::four_block(PType::F32, vec![512, 256, 64, 64])] +fn metadata_serialization_roundtrips( + #[case] element_ptype: PType, + #[case] block_sizes: Vec, +) -> VortexResult<()> { + let dimensions = block_sizes.iter().sum::(); let metadata = TurboQuantMetadata { element_ptype, - dimensions: 128, + dimensions, bit_width: 4, seed: 7, num_rounds: 3, + block_sizes, }; let encoded = TurboQuant.serialize_metadata(&metadata)?; @@ -77,6 +62,35 @@ fn metadata_serialization_roundtrips(#[case] element_ptype: PType) -> VortexResu Ok(()) } +/// A pre-block / corrupt array whose on-the-wire `block_sizes` is empty (legacy) or sums below +/// `dimensions` must be rejected by `deserialize_metadata` with a clean error, never a panic. This +/// pins the documented on-disk format break. Built by corrupting a valid serialization so the test +/// does not depend on the `PType` wire discriminant. +#[rstest] +#[case::empty_legacy(vec![])] +#[case::sum_below_dimensions(vec![64])] +fn deserialize_rejects_malformed_block_sizes(#[case] block_sizes: Vec) -> VortexResult<()> { + let valid = TurboQuantMetadata { + element_ptype: PType::F32, + dimensions: 128, + bit_width: 4, + seed: 7, + num_rounds: 3, + block_sizes: vec![128], + }; + let encoded = TurboQuant.serialize_metadata(&valid)?; + let mut wire = MetadataWire::decode(encoded.as_slice()) + .map_err(|e| vortex_err!("decode MetadataWire: {e}"))?; + wire.block_sizes = block_sizes; + + assert!( + TurboQuant + .deserialize_metadata(&wire.encode_to_vec()) + .is_err() + ); + Ok(()) +} + #[test] fn metadata_serialization_uses_ptype_discriminants() -> VortexResult<()> { let metadata = TurboQuantMetadata { @@ -85,6 +99,7 @@ fn metadata_serialization_uses_ptype_discriminants() -> VortexResult<()> { bit_width: 4, seed: 7, num_rounds: 3, + block_sizes: vec![128], }; let encoded = TurboQuant.serialize_metadata(&metadata)?; @@ -92,6 +107,7 @@ fn metadata_serialization_uses_ptype_discriminants() -> VortexResult<()> { assert_eq!(wire.element_ptype, PType::F32 as i32); assert_eq!(wire.dimensions, 128); + assert_eq!(wire.block_sizes, vec![128u32]); Ok(()) } @@ -103,11 +119,13 @@ fn metadata_display_matches_field_order() { bit_width: 4, seed: 7, num_rounds: 3, + block_sizes: vec![128], }; assert_eq!( metadata.to_string(), - "element_ptype: f32, dimensions: 128, bit_width: 4, seed: 7, num_rounds: 3" + "element_ptype: f32, dimensions: 128, bit_width: 4, seed: 7, num_rounds: 3, \ + block_sizes: [128]" ); } @@ -115,16 +133,15 @@ fn metadata_display_matches_field_order() { fn dtype_validation_accepts_expected_storage() -> VortexResult<()> { let metadata = TurboQuantMetadata { element_ptype: PType::F32, - dimensions: 129, + dimensions: 768, bit_width: 2, seed: 42, num_rounds: 3, + block_sizes: vec![512, 256], }; + let storage = tq_storage_dtype(&metadata, Nullability::Nullable)?; - ExtDType::::try_new( - metadata, - tq_storage_dtype(&metadata, Nullability::Nullable)?, - )?; + ExtDType::::try_new(metadata, storage)?; Ok(()) } @@ -132,16 +149,15 @@ fn dtype_validation_accepts_expected_storage() -> VortexResult<()> { fn dtype_validation_accepts_nonnullable_storage() -> VortexResult<()> { let metadata = TurboQuantMetadata { element_ptype: PType::F32, - dimensions: 129, + dimensions: 768, bit_width: 2, seed: 42, num_rounds: 3, + block_sizes: vec![512, 256], }; + let storage = tq_storage_dtype(&metadata, Nullability::NonNullable)?; - ExtDType::::try_new( - metadata, - tq_storage_dtype(&metadata, Nullability::NonNullable)?, - )?; + ExtDType::::try_new(metadata, storage)?; Ok(()) } @@ -153,16 +169,18 @@ fn dtype_validation_rejects_malformed_storage() { bit_width: 2, seed: 42, num_rounds: 3, + block_sizes: vec![128], }; + // Outer struct fields do not match the expected `block_0` schema. let storage = DType::Struct( StructFields::new( FieldNames::from(["norms", "codes"]), vec![ DType::Primitive(PType::F32, Nullability::Nullable), DType::FixedSizeList( - DType::Primitive(PType::U8, Nullability::Nullable).into(), + Arc::new(DType::Primitive(PType::U8, Nullability::NonNullable)), 128, - Nullability::NonNullable, + Nullability::Nullable, ), ], ), diff --git a/vortex-turboquant/src/tests/mod.rs b/vortex-turboquant/src/tests/mod.rs index ffa1db175a7..03c7693010f 100644 --- a/vortex-turboquant/src/tests/mod.rs +++ b/vortex-turboquant/src/tests/mod.rs @@ -37,11 +37,11 @@ use crate::TQEncode; use crate::TurboQuantConfig; use crate::initialize; +mod blocks; mod encode_decode; mod file; mod malformed; mod metadata; -mod parity; mod scalar_fns; fn test_session() -> VortexSession { diff --git a/vortex-turboquant/src/tests/parity.rs b/vortex-turboquant/src/tests/parity.rs deleted file mode 100644 index 4360d90849d..00000000000 --- a/vortex-turboquant/src/tests/parity.rs +++ /dev/null @@ -1,38 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: Copyright the Vortex contributors - -use vortex_array::VortexSessionExecute; -use vortex_array::validity::Validity; -use vortex_error::VortexResult; -use vortex_tensor::encodings::turboquant::TurboQuantConfig as OldTurboQuantConfig; -use vortex_tensor::encodings::turboquant::turboquant_encode; - -use super::execute_tq_decode; -use super::execute_tq_encode; -use super::f32_vector_array; -use super::test_session; -use super::vector_values_f32; -use crate::TurboQuantConfig; - -#[test] -fn encode_decode_matches_old_turboquant_decode() -> VortexResult<()> { - let session = test_session(); - let mut ctx = session.create_execution_ctx(); - let input = f32_vector_array(128, 2, 0.125, Validity::NonNullable)?; - let config = TurboQuantConfig::try_new(3, 42, 3)?; - - let new_encoded = execute_tq_encode(input.clone(), &config, &mut ctx)?; - let new_decoded = execute_tq_decode(new_encoded, &mut ctx)?; - let old_config = OldTurboQuantConfig { - bit_width: config.bit_width(), - seed: config.seed(), - num_rounds: config.num_rounds(), - }; - let old_decoded = turboquant_encode(input, &old_config, &mut ctx)?.execute(&mut ctx)?; - - let new_values = vector_values_f32(new_decoded, &mut ctx)?; - let old_values = vector_values_f32(old_decoded, &mut ctx)?; - - assert_eq!(new_values, old_values); - Ok(()) -} diff --git a/vortex-turboquant/src/tests/scalar_fns.rs b/vortex-turboquant/src/tests/scalar_fns.rs index de125e8a5f1..ed8ded8135e 100644 --- a/vortex-turboquant/src/tests/scalar_fns.rs +++ b/vortex-turboquant/src/tests/scalar_fns.rs @@ -20,7 +20,7 @@ use crate::vtable::tq_metadata; #[test] fn scalar_fn_ids_and_options_roundtrip() -> VortexResult<()> { let session = test_session(); - let config = TurboQuantConfig::try_new(4, 7, 2)?; + let config = TurboQuantConfig::try_new(4, 7, 2, None)?; assert_eq!(TQEncode.id().as_ref(), "vortex.turboquant.encode"); assert_eq!(TQDecode.id().as_ref(), "vortex.turboquant.decode"); @@ -42,12 +42,13 @@ fn scalar_fn_arrays_encode_and_decode_vectors() -> VortexResult<()> { let session = test_session(); let mut ctx = session.create_execution_ctx(); let input = f32_vector_array(128, 2, 0.25, Validity::from_iter([true, false]))?; - let config = TurboQuantConfig::try_new(3, 42, 3)?; + let config = TurboQuantConfig::try_new(3, 42, 3, None)?; let encoded_lazy = TQEncode::try_new_array(input, &config)?; let encoded_metadata = tq_metadata(encoded_lazy.dtype())?; assert_eq!(encoded_metadata.dimensions, 128); assert_eq!(encoded_metadata.bit_width, config.bit_width()); + assert_eq!(encoded_metadata.block_sizes, vec![128]); assert!(encoded_lazy.dtype().as_extension().is::()); let encoded = encoded_lazy.into_array().execute(&mut ctx)?; diff --git a/vortex-turboquant/src/vector/dequantize.rs b/vortex-turboquant/src/vector/dequantize.rs new file mode 100644 index 00000000000..629706cb0c0 --- /dev/null +++ b/vortex-turboquant/src/vector/dequantize.rs @@ -0,0 +1,185 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +//! Block-aware TurboQuant encode pipeline. + +// TODO(connor): More docs! + +use num_traits::Float; +use num_traits::FromPrimitive; +use vortex_array::ArrayRef; +use vortex_array::ExecutionCtx; +use vortex_array::IntoArray; +use vortex_array::arrays::FixedSizeListArray; +use vortex_array::arrays::PrimitiveArray; +use vortex_array::dtype::NativePType; +use vortex_array::validity::Validity; +use vortex_buffer::Buffer; +use vortex_buffer::BufferMut; +use vortex_error::VortexExpect; +use vortex_error::VortexResult; +use vortex_error::vortex_ensure; +use vortex_error::vortex_err; +use vortex_mask::Mask; +use vortex_tensor::vector::Vector; + +use crate::TurboQuantMetadata; +use crate::sorf::transform::SorfMatrix; +use crate::vector::storage::Block; + +/// Borrowed bundle of per-array decode inputs passed to the typed inner loop. +/// +/// Packaged as a struct rather than positional arguments because [`decode_typed`] runs through +/// [`match_each_float_ptype!`] which expands once per supported element ptype. Each expansion +/// takes the same set of inputs, and the struct keeps the call site short. +pub(crate) struct DecodeInputs<'a> { + /// TurboQuant metadata recovered from the input extension dtype. + pub(crate) metadata: &'a TurboQuantMetadata, + + /// Block widths in `usize`, parallel to `metadata.block_sizes`. Cached to avoid repeated + /// `usize::try_from` in the row loop. + pub(crate) block_sizes: &'a [usize], + + /// Sum of `block_sizes`. The decode loop's row-aligned scratch buffer is this wide. + pub(crate) total_width: usize, + + /// One `SorfMatrix` per block, seeded via `derive_block_seed(metadata.seed, i)`. + pub(crate) sorf_matrices: &'a [SorfMatrix], + + /// One centroid table per block, keyed on `(block_sizes[i], bit_width)`. + pub(crate) centroid_tables: &'a [Buffer], + + /// Per-block executed `(norms, codes)` storage children, in block order. + pub(crate) block_storages: &'a [Block], +} + +// TODO(connor): Clean up this function! +pub(crate) fn decode_typed( + decode: DecodeInputs<'_>, + vector_validity: Validity, + num_vectors: usize, + ctx: &mut ExecutionCtx, +) -> VortexResult +where + T: NativePType + Float + FromPrimitive, +{ + let metadata = decode.metadata; + let dimensions = usize::try_from(metadata.dimensions) + .vortex_expect("dimensions stays representable as usize"); + let mask = vector_validity.execute_mask(num_vectors, ctx)?; + + let output_len = num_vectors + .checked_mul(dimensions) + .ok_or_else(|| vortex_err!("TurboQuant decoded vector length overflow"))?; + let mut output = BufferMut::::with_capacity(output_len); + + let mut decoded_blocks: Vec> = + decode.block_sizes.iter().map(|&b| vec![0.0; b]).collect(); + let mut inverse_blocks: Vec> = + decode.block_sizes.iter().map(|&b| vec![0.0; b]).collect(); + // `total_width == sum(block_sizes) >= dimensions` (enforced by validate_block_sum at metadata + // construction/deserialize), so the `row_scratch[..dimensions]` copy below is always in bounds. + let mut row_scratch = vec![0.0f32; decode.total_width]; + + let block_norms: Vec<&[T]> = decode + .block_storages + .iter() + .map(|bs| bs.norms.as_slice::()) + .collect(); + let block_codes: Vec<&[u8]> = decode + .block_storages + .iter() + .map(|bs| bs.codes.as_slice::()) + .collect(); + + // `decode_row` is fallible and validates each code against its block's centroid table at the + // lookup site. Validation happens here rather than up front over every physical row so that + // null / masked-out rows, whose placeholder codes are never decoded, cannot trip a bounds + // error: the closure is only invoked for rows selected by `mask` below. + let mut decode_row = |output: &mut BufferMut, row: usize| -> VortexResult<()> { + let mut offset = 0usize; + for (block_index, &block) in decode.block_sizes.iter().enumerate() { + let code_row = &block_codes[block_index][row * block..][..block]; + let centroids = decode.centroid_tables[block_index].as_slice(); + for (dst, &code) in decoded_blocks[block_index].iter_mut().zip(code_row.iter()) { + *dst = *centroids.get(code as usize).ok_or_else(|| { + vortex_err!( + "TurboQuant code {code} exceeds centroid count {} for block {block_index}", + centroids.len() + ) + })?; + } + decode.sorf_matrices[block_index].inverse_transform( + &decoded_blocks[block_index], + &mut inverse_blocks[block_index], + ); + + let norm = block_norms[block_index][row]; + let norm_f32 = norm + .to_f32() + .vortex_expect("to_f32 is infallible for supported float types"); + // A stored norm must be a finite, non-negative magnitude; reject malformed storage that + // would otherwise scale the reconstruction by garbage (or sign-flip it). + vortex_ensure!( + norm_f32.is_finite() && norm_f32 >= 0.0, + "TurboQuant stored block norm is not a valid magnitude, got {norm_f32}" + ); + for (dst, &value) in row_scratch[offset..offset + block] + .iter_mut() + .zip(inverse_blocks[block_index].iter()) + { + *dst = value * norm_f32; + } + offset += block; + } + debug_assert_eq!(offset, decode.total_width); + + for &value in &row_scratch[..dimensions] { + let value = T::from_f32(value) + .vortex_expect("from_f32 is infallible for supported float types"); + // SAFETY: total pushes equal `output_len` across all match arms below. + unsafe { output.push_unchecked(value) }; + } + Ok(()) + }; + + match &mask { + Mask::AllFalse(_) => { + // SAFETY: `output` has capacity `output_len` and this writes exactly `output_len` + // zero placeholders, so the push stays within the reserved capacity. + unsafe { output.push_n_unchecked(T::zero(), output_len) }; + } + Mask::AllTrue(_) => { + for row in 0..num_vectors { + decode_row(&mut output, row)?; + } + } + Mask::Values(values_mask) => { + let mut cursor = 0; + for &(start, end) in values_mask.slices() { + if start > cursor { + // SAFETY: total pushes across all arms equal `output_len`. + unsafe { output.push_n_unchecked(T::zero(), (start - cursor) * dimensions) }; + } + for row in start..end { + decode_row(&mut output, row)?; + } + cursor = end; + } + if cursor < num_vectors { + // SAFETY: total pushes across all arms equal `output_len`. + unsafe { output.push_n_unchecked(T::zero(), (num_vectors - cursor) * dimensions) }; + } + } + } + + let elements = PrimitiveArray::new::(output.freeze(), Validity::NonNullable); + let fsl = FixedSizeListArray::try_new( + elements.into_array(), + metadata.dimensions, + vector_validity, + num_vectors, + )?; + + Vector::try_new_vector_array(fsl.into_array()) +} diff --git a/vortex-turboquant/src/vector/mod.rs b/vortex-turboquant/src/vector/mod.rs index f4fe8726103..effa59a0f47 100644 --- a/vortex-turboquant/src/vector/mod.rs +++ b/vortex-turboquant/src/vector/mod.rs @@ -1,27 +1,8 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: Copyright the Vortex contributors -//! Vector-side helpers: normalization, quantization, and physical storage layout. +//! Vector-side helpers: block-aware quantization and physical storage layout. -pub(crate) mod normalize; +pub(crate) mod dequantize; pub(crate) mod quantize; pub(crate) mod storage; - -use vortex_error::VortexResult; -use vortex_error::vortex_err; - -/// Compute the padded SORF dimension for an original vector dimension. -/// -/// The SORF transform requires a power-of-two width, so non-power-of-two input dimensions are -/// padded with zeros up to the next power of two. The padded dimension is stored implicitly via -/// [`TurboQuantMetadata::dimensions`](crate::TurboQuantMetadata) plus the codes child's -/// `FixedSizeList` width and recovered at decode time via this function. Returns an error when -/// the next power of two overflows the input integer type. -pub(crate) fn tq_padded_dim(dimensions: u32) -> VortexResult { - let padded_dim = dimensions - .checked_next_power_of_two() - .ok_or_else(|| vortex_err!("TurboQuant padded dimension overflow for {dimensions}"))?; - - usize::try_from(padded_dim) - .map_err(|_| vortex_err!("TurboQuant padded dimension does not fit usize")) -} diff --git a/vortex-turboquant/src/vector/normalize.rs b/vortex-turboquant/src/vector/normalize.rs deleted file mode 100644 index 642949eecf6..00000000000 --- a/vortex-turboquant/src/vector/normalize.rs +++ /dev/null @@ -1,236 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: Copyright the Vortex contributors - -//! TurboQuant-local vector normalization. - -// TODO(connor): Remove this comment once we delete the other version in `vortex-tensor`. -// The tensor crate also has a `normalize_as_l2_denorm` helper, but TurboQuant needs different -// validity semantics: a null vector is not a zero vector, so invalid rows keep their row validity -// on both `L2Denorm` children and downstream quantization skips them. - -use num_traits::Float; -use vortex_array::ArrayRef; -use vortex_array::ExecutionCtx; -use vortex_array::IntoArray; -use vortex_array::arrays::ExtensionArray; -use vortex_array::arrays::FixedSizeListArray; -use vortex_array::arrays::PrimitiveArray; -use vortex_array::arrays::ScalarFnArray; -use vortex_array::arrays::extension::ExtensionArrayExt; -use vortex_array::arrays::fixed_size_list::FixedSizeListArrayExt; -use vortex_array::dtype::NativePType; -use vortex_array::extension::EmptyMetadata; -use vortex_array::match_each_float_ptype; -use vortex_array::validity::Validity; -use vortex_buffer::BufferMut; -use vortex_error::VortexResult; -use vortex_error::vortex_ensure_eq; -use vortex_error::vortex_err; -use vortex_mask::Mask; -use vortex_mask::MaskValues; -use vortex_tensor::scalar_fns::l2_denorm::L2Denorm; -use vortex_tensor::scalar_fns::l2_norm::L2Norm; -use vortex_tensor::vector::AnyVector; -use vortex_tensor::vector::Vector; - -/// Normalize a `Vector` array and wrap it with its original row norms with [`L2Denorm`]. -/// -/// This preserves input row validity on both [`L2Denorm`] children. Or in other words, validity is -/// propagated down to the children so that TurboQuant can skip quantizing those vectors (as it does -/// not have a good way to represent 0 vectors in its quantized domain). -pub(crate) fn tq_normalize_as_l2_denorm( - input: ArrayRef, - ctx: &mut ExecutionCtx, -) -> VortexResult { - let row_count = input.len(); - let vector_metadata = input - .dtype() - .as_extension_opt() - .and_then(|ext_dtype| ext_dtype.metadata_opt::()) - .ok_or_else(|| vortex_err!("TurboQuant normalization expects a Vector extension array"))?; - let dimensions = vector_metadata.dimensions() as usize; - let vector_validity = input.validity()?; - - // Use `L2Norm` to calculate the normals for each vector. - let norms: ArrayRef = L2Norm::try_new_array(input.clone(), row_count)? - .into_array() - .execute(ctx)?; - let primitive_norms: PrimitiveArray = norms.clone().execute(ctx)?; - - let input: ExtensionArray = input.execute(ctx)?; - let storage: FixedSizeListArray = input.storage_array().clone().execute(ctx)?; - vortex_ensure_eq!( - storage.list_size() as usize, - dimensions, - "Vector storage dimension must be {dimensions}, got {}", - storage.list_size() - ); - let elements: PrimitiveArray = storage.elements().clone().execute(ctx)?; - - let mask = vector_validity.execute_mask(row_count, ctx)?; - - let normalized = match_each_float_ptype!(elements.ptype(), |T| { - normalize_vectors::( - &elements, - &primitive_norms, - &mask, - dimensions, - vector_validity.clone(), - ) - })?; - - // SAFETY: matches the lossy-encoding relaxation documented on - // [`L2Denorm::new_array_unchecked`]. Norms come from `L2Norm` over the same input, so they - // match the vector element type and row count. Valid nonzero rows are divided by their stored - // norm and are unit-norm. Valid zero-norm rows and invalid rows use physical zero placeholders; - // invalid rows remain guarded by row-level invalid validity. - unsafe { L2Denorm::new_array_unchecked(normalized, norms, row_count) } -} - -fn normalize_vectors( - elements: &PrimitiveArray, - norms: &PrimitiveArray, - mask: &Mask, - dimensions: usize, - vector_validity: Validity, -) -> VortexResult -where - T: Float + NativePType, -{ - let num_vectors = norms.len(); - - let values = elements.as_slice::(); - let norm_values = norms.as_slice::(); - - let output_len = num_vectors - .checked_mul(dimensions) - .ok_or_else(|| vortex_err!("TurboQuant normalized vector length overflow"))?; - let mut output = BufferMut::::with_capacity(output_len); - - // The total number of pushes is always exactly `num_vectors * dimensions == output_len` - // across every arm below, which is the invariant the per-row `unsafe` blocks rely on. - match mask { - Mask::AllFalse(_) => { - // Every row is invalid: bulk-fill the output with zero placeholders. - // - // SAFETY: `output` was allocated with capacity `output_len`, and this push writes - // exactly `output_len` zero placeholders. - unsafe { output.push_n_unchecked(T::zero(), output_len) }; - } - Mask::AllTrue(_) => { - for i in 0..num_vectors { - // SAFETY: `output` was allocated with capacity `output_len = num_vectors * - // dimensions`. This loop runs `num_vectors` times and each call pushes exactly - // `dimensions` elements, so capacity for `dimensions` more elements always - // remains. - unsafe { normalize_one_row::(&mut output, values, norm_values, dimensions, i) }; - } - } - Mask::Values(values_mask) => { - // SAFETY: `output` was allocated with capacity `output_len = num_vectors * - // dimensions`, which is the bound the helper requires. - unsafe { - normalize_vectors_with_mask::( - &mut output, - values, - norm_values, - dimensions, - num_vectors, - values_mask, - ) - }; - } - } - - // Vector elements are always non-nullable. - let elements = PrimitiveArray::new::(output.freeze(), Validity::NonNullable); - - #[expect( - clippy::cast_possible_truncation, - reason = "this initially came from a u32" - )] - let storage = FixedSizeListArray::try_new( - elements.into_array(), - dimensions as u32, - vector_validity, - num_vectors, - )?; - - Ok( - ExtensionArray::try_new_from_vtable(Vector, EmptyMetadata, storage.into_array())? - .into_array(), - ) -} - -/// Normalize a single valid row, or push `dimensions` zero placeholders if the row's L2 norm -/// is zero. -/// -/// A valid vector with L2 norm zero is all zeros, so dividing through it would be undefined. -/// Treating it the same as an invalid row preserves the original semantics. -/// -/// # Safety -/// -/// `output` must have capacity for at least `dimensions` more elements before this call. -unsafe fn normalize_one_row( - output: &mut BufferMut, - values: &[T], - norm_values: &[T], - dimensions: usize, - i: usize, -) where - T: Float + NativePType, -{ - let norm = norm_values[i]; - - if norm == T::zero() { - // SAFETY: caller guarantees capacity for `dimensions` more elements. - unsafe { output.push_n_unchecked(T::zero(), dimensions) }; - } else { - let row_values = &values[i * dimensions..][..dimensions]; - - for &value in row_values { - // SAFETY: caller guarantees capacity for `dimensions` more elements. - unsafe { output.push_unchecked(value / norm) }; - } - } -} - -/// Walk the pre-cached run boundaries of a `Values` mask, bulk-pushing zero placeholders for -/// invalid runs and normalizing valid runs row by row. -/// -/// # Safety -/// -/// `output` must have capacity for at least `num_vectors * dimensions` more elements before -/// this call. -unsafe fn normalize_vectors_with_mask( - output: &mut BufferMut, - values: &[T], - norm_values: &[T], - dimensions: usize, - num_vectors: usize, - values_mask: &MaskValues, -) where - T: Float + NativePType, -{ - let mut cursor = 0; - - for &(start, end) in values_mask.slices() { - if start > cursor { - // SAFETY: capacity invariant from caller. - unsafe { output.push_n_unchecked(T::zero(), (start - cursor) * dimensions) }; - } - - for i in start..end { - // SAFETY: capacity invariant from caller — each call pushes `dimensions` and the - // total number of valid rows in the mask is bounded by `num_vectors`. - unsafe { normalize_one_row::(output, values, norm_values, dimensions, i) }; - } - - cursor = end; - } - - if cursor < num_vectors { - // SAFETY: capacity invariant from caller. - unsafe { output.push_n_unchecked(T::zero(), (num_vectors - cursor) * dimensions) }; - } -} diff --git a/vortex-turboquant/src/vector/quantize.rs b/vortex-turboquant/src/vector/quantize.rs index 0861b9f6805..9c3e56605c0 100644 --- a/vortex-turboquant/src/vector/quantize.rs +++ b/vortex-turboquant/src/vector/quantize.rs @@ -1,161 +1,297 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: Copyright the Vortex contributors -//! Core TurboQuant quantization helpers. +//! Block-aware TurboQuant encode pipeline. //! -//! Quantization consumes the TurboQuant-local normalized `Vector` child. Valid rows are transformed -//! and mapped to scalar centroid indices. Invalid rows remain in the full-length output but are -//! skipped: their physical code bytes are placeholders guarded by the `codes` row validity. +//! Each block of an input vector array is encoded independently: per-row L2 norm, per-row SORF +//! transform sized to the block, and per-row scalar quantization against the block's centroid +//! table. The output is one [`Block`] per block in `block_sizes`, each row-aligned to the +//! input row count and carrying the input's row validity. //! -//! This matters because TurboQuant's scalar codebook is optimized for coordinates of transformed -//! unit-norm vectors. The codebook does not generally contain an exact zero centroid, and a -//! physical code byte of `0` means "centroid 0", not "zero coordinate". Null vectors therefore -//! should not be converted to zero vectors and fed through the quantizer. +//! # Block slicing +//! +//! Block `i` covers input coordinates `[offset_i .. offset_i + block_sizes[i])`, where +//! `offset_i = sum(block_sizes[..i])`. When a block extends past `dimensions` its tail is +//! zero-padded; a block whose `offset_i >= dimensions` is entirely padding. Such overspilling +//! block lists are valid, not rejected; `resolve_block_sizes` emits a `tracing::warn!` only for a +//! block lying entirely past `dimensions` or a sum exceeding `2 * dimensions`. +//! +//! # Per-block algorithm +//! +//! For each block `i` of each valid input row, the encoder: +//! +//! 1. Slices the block out of the input, zero-padding any range that extends past `dimensions`. +//! 2. Computes the block's L2 norm and writes it into that block's `norms` column. +//! 3. Divides the slice by that norm to produce a unit-norm block. +//! 4. Applies a SORF transform of width `block_sizes[i]` seeded with +//! `derive_block_seed(config.seed(), i)`, so every block has its own distinct rotation even +//! when two blocks share the same width. +//! 5. Scalar-quantizes the rotated coordinates against a `2^bit_width`-entry centroid table built +//! for width `block_sizes[i]` and writes the codes into that block's `codes` column. +//! +//! # Null and zero-norm rows +//! +//! Per-row null and zero-norm handling mirrors the previous single-block pipeline: a null row +//! writes zero placeholders into every block's `norms` and `codes`, and a valid row whose block +//! slice has zero norm writes zeros into that block's children only. use half::f16; +use num_traits::Float; +use num_traits::FromPrimitive; +use vortex_array::ArrayRef; use vortex_array::ExecutionCtx; +use vortex_array::arrays::ExtensionArray; use vortex_array::arrays::FixedSizeListArray; use vortex_array::arrays::PrimitiveArray; +use vortex_array::arrays::extension::ExtensionArrayExt; use vortex_array::arrays::fixed_size_list::FixedSizeListArrayExt; +use vortex_array::dtype::NativePType; use vortex_array::dtype::PType; +use vortex_array::match_each_float_ptype; +use vortex_array::validity::Validity; use vortex_buffer::Buffer; use vortex_buffer::BufferMut; +use vortex_error::VortexExpect; use vortex_error::VortexResult; use vortex_error::vortex_bail; use vortex_error::vortex_err; use vortex_mask::Mask; +use vortex_tensor::vector::AnyVector; -use super::tq_padded_dim; -use crate::TurboQuantConfig; -use crate::centroids::compute_centroid_boundaries; -use crate::centroids::compute_or_get_centroids; +use crate::centroids::compute_or_get_codebook; use crate::centroids::find_nearest_centroid; -use crate::sorf::SorfMatrix; +use crate::sorf::splitmix64::derive_block_seed; +use crate::sorf::transform::SorfMatrix; +use crate::vector::storage::Block; -/// Shared intermediate results from the quantization loop. -pub(crate) struct QuantizationResult { - pub(crate) all_indices: Buffer, - pub(crate) padded_dim: usize, +/// Per-block precomputed runtime state shared across rows. +/// +/// Built once per encode call and reused for every row of the input array. +pub(crate) struct BlockRuntimeState { + /// One [`SorfMatrix`] per block, sized to its block width and seeded from [`derive_block_seed`] + /// `(global_seed, block_index)`. + matrices: Vec, + /// Precomputed centroid boundaries used by [`find_nearest_centroid`], one cheap-to-clone + /// reference-counted [`Buffer`] per block. + boundaries: Vec>, } -pub(crate) fn empty_quantization(padded_dim: usize) -> QuantizationResult { - QuantizationResult { - all_indices: Buffer::empty(), - padded_dim, +/// Build the per-block SORF transforms and centroid tables for a given config and resolved block +/// list. Inexpensive when the centroid cache is warm. +pub(crate) fn prepare_block_state( + seed: u64, + num_rounds: u8, + bit_width: u8, + block_sizes: &[u32], +) -> VortexResult { + let mut matrices = Vec::with_capacity(block_sizes.len()); + let mut boundaries = Vec::with_capacity(block_sizes.len()); + + for (index, &block) in block_sizes.iter().enumerate() { + let block_usize = usize::try_from(block) + .map_err(|_| vortex_err!("TurboQuant block {block} does not fit usize"))?; + + // Each block gets a distinct SORF rotation derived from the global seed and its index. + let seed_i = derive_block_seed(seed, index); + + matrices.push(SorfMatrix::try_new( + block_usize, + num_rounds as usize, + seed_i, + )?); + + boundaries.push( + compute_or_get_codebook(block, bit_width)? + .boundaries + .clone(), + ); } + + Ok(BlockRuntimeState { + matrices, + boundaries, + }) } -/// Core quantization: transform and quantize already-normalized rows. -/// -/// # Safety +/// Encode every block of `input` (the original `Vector` extension array that is not pre-normalized) +/// into its own `(norms, codes)` row-aligned pair. /// -/// The input `fsl` must contain unit-norm vectors (already L2-normalized) for every valid row. -/// Invalid rows are left row-aligned in the output but are not transformed or quantized. The -/// transform and centroid lookup happen in f32. -pub(crate) unsafe fn turboquant_quantize_core( - fsl: &FixedSizeListArray, - config: &TurboQuantConfig, +/// Returns one [`Block`] per block in `block_sizes`, each carrying `num_vectors` rows. +pub(crate) fn turboquant_encode_blocks( + input: ArrayRef, + block_sizes: &[u32], + state: &BlockRuntimeState, + vector_validity: Validity, ctx: &mut ExecutionCtx, -) -> VortexResult { - let dimension = fsl.list_size(); - let num_vectors = fsl.len(); - let padded_dim = tq_padded_dim(dimension)?; - - let sorf_transform = - SorfMatrix::try_new(padded_dim, config.num_rounds() as usize, config.seed())?; - debug_assert_eq!(sorf_transform.padded_dim(), padded_dim); - let padded_dim_u32 = u32::try_from(padded_dim) - .map_err(|_| vortex_err!("TurboQuant padded dimension does not fit u32"))?; - - let elements_prim: PrimitiveArray = fsl.elements().clone().execute(ctx)?; - let f32_elements = cast_to_f32(elements_prim)?; - let validity = fsl.validity()?; - let mask = validity.execute_mask(num_vectors, ctx)?; - - let centroids = compute_or_get_centroids(padded_dim_u32, config.bit_width())?; - let boundaries = compute_centroid_boundaries(¢roids); - - let codes_len = num_vectors - .checked_mul(padded_dim) - .ok_or_else(|| vortex_err!("TurboQuant codes length overflow"))?; - let mut all_indices = BufferMut::::with_capacity(codes_len); - - let mut padded = vec![0.0f32; padded_dim]; - let mut transformed = vec![0.0f32; padded_dim]; - - // Pad, SORF-transform, and quantize a single row, pushing `padded_dim` codes into - // `all_indices`. Captures the read-only inputs and the scratch buffers so each call site - // only needs to pass `all_indices` and the row index. - // - // NB: `all_indices` cannot be captured here: the `Values` arm interleaves the closure call - // with direct `all_indices.push_n_unchecked` calls. - let f32_slice = f32_elements.as_slice(); - let dimension = dimension as usize; - let mut quantize_row = |all_indices: &mut BufferMut, row: usize| { - // Reuse `padded` and `transformed` from the outer scope. - padded[..dimension].copy_from_slice(&f32_slice[row * dimension..][..dimension]); - padded[dimension..].fill(0.0); - sorf_transform.transform(&padded, &mut transformed); - - for &value in &transformed { - // SAFETY: total pushes across all match arms equal `codes_len`. - unsafe { all_indices.push_unchecked(find_nearest_centroid(value, &boundaries)) }; - } - }; - - // The total number of pushes is always exactly `num_vectors * padded_dim == codes_len` - // across every arm below, which is the invariant the per-row `unsafe` blocks rely on. - match &mask { - Mask::AllFalse(_) => { - // Every row is invalid: bulk-fill placeholder zero codes. - // - // SAFETY: `all_indices` was allocated with capacity `codes_len`, and this push - // writes exactly `codes_len` zero codes. - unsafe { all_indices.push_n_unchecked(0, codes_len) }; - } - Mask::AllTrue(_) => { - for row in 0..num_vectors { - quantize_row(&mut all_indices, row); - } - } - Mask::Values(values_mask) => { - let mut cursor = 0; +) -> VortexResult> { + let num_vectors = input.len(); + let vector_metadata = input + .dtype() + .as_extension_opt() + .and_then(|ext_dtype| ext_dtype.metadata_opt::()) + .ok_or_else(|| vortex_err!("TurboQuant encode expects a Vector extension array"))?; - for &(start, end) in values_mask.slices() { - if start > cursor { - // SAFETY: total pushes across all arms equal `codes_len`. - unsafe { all_indices.push_n_unchecked(0, (start - cursor) * padded_dim) }; - } + let dimensions = usize::try_from(vector_metadata.dimensions()) + .map_err(|_| vortex_err!("TurboQuant dimensions does not fit usize"))?; + let element_ptype = vector_metadata.element_ptype(); - for row in start..end { - quantize_row(&mut all_indices, row); - } + let extension: ExtensionArray = input.execute(ctx)?; + let storage: FixedSizeListArray = extension.storage_array().clone().execute(ctx)?; + let elements: PrimitiveArray = storage.elements().clone().execute(ctx)?; + let mask = vector_validity.execute_mask(num_vectors, ctx)?; - cursor = end; + // TODO(connor): It would be more "correct" to compute norms **before** casting to f32. + let f32_input = cast_to_f32(elements)?; + let f32_slice = f32_input.as_slice(); + + // `encode_blocks_typed` is monomorphized per float ptype although its hot loop runs in f32, and + // only the output norm column depends on `T`. + let block_arrays = match_each_float_ptype!(element_ptype, |T| { + encode_blocks_typed::( + f32_slice, + dimensions, + num_vectors, + &mask, + block_sizes, + state, + vector_validity.clone(), + )? + }); + + Ok(block_arrays) +} + +// TODO(connor): Clean up this function! +fn encode_blocks_typed( + input: &[f32], + dimensions: usize, + num_vectors: usize, + mask: &Mask, + block_sizes: &[u32], + state: &BlockRuntimeState, + vector_validity: Validity, +) -> VortexResult> +where + T: NativePType + Float + FromPrimitive, +{ + let block_widths: Vec = block_sizes + .iter() + .map(|&b| { + usize::try_from(b).map_err(|_| vortex_err!("TurboQuant block {b} does not fit usize")) + }) + .collect::>>()?; + + // `total_block_width` sizes the per-block scratch and the offset-invariant assert below; the + // `sum >= dimensions` rule itself is enforced upstream by `validate_block_sum` (via + // `resolve_block_sizes`), so it is not re-checked here. + let total_block_width: usize = block_widths.iter().sum(); + + // Per-block output buffers. `norms_out[b]` collects `num_vectors` block-norm values; + // `codes_out[b]` collects `num_vectors * block_sizes[b]` u8 codes. + let mut norms_out: Vec> = block_sizes + .iter() + .map(|_| BufferMut::::with_capacity(num_vectors)) + .collect(); + let mut codes_out: Vec> = block_widths + .iter() + .map(|&b| { + let len = num_vectors + .checked_mul(b) + .ok_or_else(|| vortex_err!("TurboQuant codes length overflow"))?; + Ok::<_, vortex_error::VortexError>(BufferMut::::with_capacity(len)) + }) + .collect::>()?; + + // Per-block scratch buffers reused across rows. + let mut padded_scratch: Vec> = block_widths.iter().map(|&b| vec![0.0f32; b]).collect(); + let mut transformed_scratch: Vec> = + block_widths.iter().map(|&b| vec![0.0f32; b]).collect(); + + for row in 0..num_vectors { + let is_valid = mask.value(row); + let row_input = &input[row * dimensions..][..dimensions]; + let mut offset = 0usize; + for (block_index, &block) in block_widths.iter().enumerate() { + if !is_valid { + // SAFETY: norms_out[block_index] reserved `num_vectors` capacity at start. + unsafe { norms_out[block_index].push_unchecked(T::zero()) }; + // SAFETY: codes_out[block_index] reserved `num_vectors * block` capacity. + unsafe { codes_out[block_index].push_n_unchecked(0u8, block) }; + offset += block; + continue; + } + // Copy the row's block slice into the scratch buffer, zero-padding the final block + // when `offset + block > dimensions`. + let take = block.min(dimensions.saturating_sub(offset)); + if take > 0 { + padded_scratch[block_index][..take] + .copy_from_slice(&row_input[offset..offset + take]); } + if take < block { + padded_scratch[block_index][take..].fill(0.0); + } + // Computed in f32 to match the SORF transform precision. For f64 inputs this is an + // intentional precision downgrade relative to the legacy per-input-ptype `L2Norm`, + // accepted as part of the block-decomposition wire-format break. + let norm_sq: f32 = padded_scratch[block_index] + .iter() + .map(|&v| v * v) + .sum::(); + let norm_f32 = norm_sq.sqrt(); + let norm_value = T::from_f32(norm_f32) + .vortex_expect("from_f32 is infallible for supported float types"); + // Reject a non-finite stored norm (an input magnitude out of the element type's range) + // rather than emit an array the decoder cannot reconstruct. + if !norm_value.is_finite() { + vortex_bail!( + "TurboQuant block norm is not finite; an input magnitude is out of range" + ); + } + // SAFETY: capacity reserved above. + unsafe { norms_out[block_index].push_unchecked(norm_value) }; - if cursor < num_vectors { - // SAFETY: total pushes across all arms equal `codes_len`. - unsafe { all_indices.push_n_unchecked(0, (num_vectors - cursor) * padded_dim) }; + if norm_f32 == 0.0 { + // SAFETY: capacity reserved above. + unsafe { codes_out[block_index].push_n_unchecked(0u8, block) }; + offset += block; + continue; } + + // Normalize in place by the block norm. + for value in padded_scratch[block_index].iter_mut() { + *value /= norm_f32; + } + state.matrices[block_index].transform( + &padded_scratch[block_index], + &mut transformed_scratch[block_index], + ); + + let boundaries = &state.boundaries[block_index]; + for &value in &transformed_scratch[block_index] { + let code = find_nearest_centroid(value, boundaries); + // SAFETY: capacity reserved above. + unsafe { codes_out[block_index].push_unchecked(code) }; + } + offset += block; } + debug_assert_eq!(offset, total_block_width); } - Ok(QuantizationResult { - all_indices: all_indices.freeze(), - padded_dim, - }) + let mut result = Vec::with_capacity(block_sizes.len()); + for block_index in 0..block_sizes.len() { + let norms_buf = std::mem::take(&mut norms_out[block_index]).freeze(); + let codes_buf = std::mem::take(&mut codes_out[block_index]).freeze(); + let norms = PrimitiveArray::new::(norms_buf, vector_validity.clone()); + let codes = PrimitiveArray::new::(codes_buf, Validity::NonNullable); + result.push(Block { norms, codes }); + } + Ok(result) } /// Cast a float [`PrimitiveArray`] to a `Buffer`. /// -/// Several operations in this crate (SORF transform, TurboQuant quantization) work exclusively -/// in f32. This function handles the cast from any float ptype: -/// -/// - f16: losslessly widened to f32. -/// - f32: zero-copy buffer extraction. -/// - f64: truncated to f32 precision. Values outside f32 range become +/- infinity. This is -/// acceptable because callers of this function operate in f32 and document this constraint. +/// All in-loop arithmetic happens in f32 for SORF compatibility; the input element ptype is +/// lossily widened or narrowed once at the start. fn cast_to_f32(prim: PrimitiveArray) -> VortexResult> { match prim.ptype() { PType::F16 => Ok(prim diff --git a/vortex-turboquant/src/vector/storage.rs b/vortex-turboquant/src/vector/storage.rs index d1b4f06cc05..468582b0a97 100644 --- a/vortex-turboquant/src/vector/storage.rs +++ b/vortex-turboquant/src/vector/storage.rs @@ -3,23 +3,22 @@ //! TurboQuant physical storage helpers. //! -//! TurboQuant storage is row-aligned and full length: +//! Block-decomposed TurboQuant storage is a row-aligned outer struct of inner +//! `Struct { norms, codes }` blocks, one per power-of-two block size in `metadata.block_sizes`: //! //! ```text //! Struct { -//! norms: Primitive, -//! codes: FixedSizeList, padded_dim, vector_validity>, +//! block_0: Struct { +//! norms: Primitive, +//! codes: FixedSizeList, block_sizes[0], vector_validity>, +//! }, +//! ... +//! block_{N-1}: Struct { norms: ..., codes: FixedSizeList }, //! } //! ``` //! -//! Row nullability is carried on the outer struct and on the `norms` and `codes` field arrays. -//! This is deliberate duplication: null vectors remain null throughout encode/decode instead of being -//! converted into zero vectors. The code bytes for invalid rows are physical placeholders only; the -//! field-level validity records that those rows were not quantized. -//! -//! Parsing treats the outer struct validity as authoritative. Child validity may be wider than -//! the struct validity (for example after a generic mask only updates the struct validity), but -//! each child must be valid wherever the struct row is valid. +//! Outer struct validity is authoritative. Each inner block's struct validity must cover the outer. +//! Additionally, each inner block's `norms` and `codes` validity must cover the inner struct. use vortex_array::ArrayRef; use vortex_array::ExecutionCtx; @@ -35,130 +34,150 @@ use vortex_array::dtype::FieldNames; use vortex_array::validity::Validity; use vortex_error::VortexResult; use vortex_error::vortex_ensure; -use vortex_error::vortex_err; -use vortex_mask::Mask; +use vortex_error::vortex_ensure_eq; -use super::quantize::QuantizationResult; use crate::vtable::TurboQuantMetadata; use crate::vtable::tq_metadata; -/// Name of the stored row-norm child. +/// Name of the stored row-norm child inside an inner block struct. pub(crate) const NORMS_FIELD: &str = "norms"; -/// Name of the stored quantized-code child. +/// Name of the stored quantized-code child inside an inner block struct. pub(crate) const CODES_FIELD: &str = "codes"; -/// Executed storage children of a TurboQuant extension array plus the authoritative outer -/// struct validity. Every child is row-aligned to `len` and every child's validity covers -/// `vector_validity`. +/// Deterministic field name for the inner struct of block index `index`. +pub(crate) fn block_field_name(index: usize) -> String { + format!("block_{index}") +} + +/// The stored `(norms, codes)` of a single block. +/// +/// Encode produces these from the quantized rows. Decode recovers them by executing and unwrapping +/// the physical storage. +pub(crate) struct Block { + /// Per-row stored block L2 norm, in `metadata.element_ptype`. + pub(crate) norms: PrimitiveArray, + + /// Flat per-row centroid indices, `num_vectors * block_sizes[i]` bytes long. Indexed as + /// `codes[row * block_sizes[i] + j]`. + /// + /// The codes are flat here and only wrapped into a `FixedSizeList` by [`build_storage`] (and + /// unwrapped back to flat by [`parse_storage`]). + pub(crate) codes: PrimitiveArray, +} + +/// Executed storage of a TurboQuant extension array, decomposed into per-block children plus the +/// authoritative outer struct validity. Every child is row-aligned to `len` and every inner-block +/// child's validity covers `vector_validity`. pub(crate) struct TurboQuantParsedStorage { /// Metadata recovered from the input extension dtype. pub(crate) metadata: TurboQuantMetadata, - /// Authoritative row validity for the quantized vectors, taken from the outer struct. + + /// Authoritative row validity, taken from the outer struct. pub(crate) vector_validity: Validity, - /// Per-row stored L2 norm of the original input vector, in `metadata.element_ptype`. - pub(crate) norms: PrimitiveArray, - /// Flat `u8` per-row centroid indices, `num_vectors * padded_dim` entries long. - pub(crate) codes: PrimitiveArray, + + /// One [`Block`] per entry in `metadata.block_sizes`, in order. + pub(crate) blocks: Vec, + /// Row count. pub(crate) len: usize, } -/// Build the `codes: FixedSizeList, padded_dim>` storage child. +/// Build the outer TurboQuant storage array from per-block encoder output. /// -/// Each row of `padded_dim` u8 codes indexes into the deterministic centroid codebook derived -/// from `(padded_dim, bit_width)`. The centroid values are intentionally not stored in the array. -pub(crate) fn build_codes_child( +/// `blocks` must have one entry per block in `block_sizes`, in block order. Each block's flat codes +/// are wrapped into a `FixedSizeList` of the block's width and paired with its norms in a +/// `Struct { norms, codes }` field of the outer struct. +pub(crate) fn build_storage( + blocks: Vec, + block_sizes: &[u32], num_vectors: usize, - quantization: QuantizationResult, vector_validity: Validity, ) -> VortexResult { - let codes = PrimitiveArray::new::(quantization.all_indices, Validity::NonNullable); - let padded_dim_u32 = u32::try_from(quantization.padded_dim) - .map_err(|_| vortex_err!("TurboQuant padded dimension does not fit u32"))?; + let mut names = Vec::with_capacity(blocks.len()); + let mut fields = Vec::with_capacity(blocks.len()); + + for (index, (block, &block_size)) in blocks.into_iter().zip(block_sizes.iter()).enumerate() { + names.push(block_field_name(index)); + + let codes_fsl = FixedSizeListArray::try_new( + block.codes.into_array(), + block_size, + vector_validity.clone(), + num_vectors, + )? + .into_array(); + let inner = StructArray::try_new( + FieldNames::from([NORMS_FIELD, CODES_FIELD]), + vec![block.norms.into_array(), codes_fsl], + num_vectors, + vector_validity.clone(), + )? + .into_array(); + fields.push(inner); + } - Ok(FixedSizeListArray::try_new( - codes.into_array(), - padded_dim_u32, - vector_validity, - num_vectors, - )? - .into_array()) -} - -/// Build the TurboQuant `Struct { norms, codes }` storage array. -pub(crate) fn build_storage( - norms: ArrayRef, - codes: ArrayRef, - len: usize, - vector_validity: Validity, -) -> VortexResult { Ok(StructArray::try_new( - FieldNames::from([NORMS_FIELD, CODES_FIELD]), - vec![norms, codes], - len, + FieldNames::from_iter(names), + fields, + num_vectors, vector_validity, )? .into_array()) } -/// Parse a TurboQuant extension array into executed storage children. +/// Parse a TurboQuant extension array into per-block executed storage children. pub(crate) fn parse_storage( input: ArrayRef, ctx: &mut ExecutionCtx, ) -> VortexResult { let metadata = tq_metadata(input.dtype())?; let ext: ExtensionArray = input.execute(ctx)?; - let storage: StructArray = ext.storage_array().clone().execute(ctx)?; - - let norms: PrimitiveArray = storage - .unmasked_field_by_name(NORMS_FIELD)? - .clone() - .execute(ctx)?; - - let codes_fsl: FixedSizeListArray = storage - .unmasked_field_by_name(CODES_FIELD)? - .clone() - .execute(ctx)?; - let codes: PrimitiveArray = codes_fsl.elements().clone().execute(ctx)?; - - let len = storage.len(); - let struct_validity = storage.struct_validity(); - let norms_validity = norms.validity()?; - let codes_validity = codes_fsl.validity()?; - - let struct_mask = struct_validity.execute_mask(len, ctx)?; - let norms_mask = norms_validity.execute_mask(len, ctx)?; - let codes_mask = codes_validity.execute_mask(len, ctx)?; - validate_child_validity_covers_struct(&struct_mask, &norms_mask, &codes_mask)?; + let outer: StructArray = ext.storage_array().clone().execute(ctx)?; + + let len = outer.len(); + let outer_validity = outer.struct_validity(); + let outer_mask = outer_validity.execute_mask(len, ctx)?; + + let mut blocks = Vec::with_capacity(metadata.block_sizes.len()); + for (index, &block) in metadata.block_sizes.iter().enumerate() { + let name = block_field_name(index); + let inner: StructArray = outer.unmasked_field_by_name(&name)?.clone().execute(ctx)?; + + // Ensure the outer struct mask covers the block mask. + let inner_validity = inner.struct_validity(); + let inner_mask = inner_validity.execute_mask(len, ctx)?; + vortex_ensure!(outer_mask.clone().bitand_not(&inner_mask).all_false()); + + let norms: PrimitiveArray = inner + .unmasked_field_by_name(NORMS_FIELD)? + .clone() + .execute(ctx)?; + let codes_fsl: FixedSizeListArray = inner + .unmasked_field_by_name(CODES_FIELD)? + .clone() + .execute(ctx)?; + vortex_ensure_eq!( + codes_fsl.list_size(), + block, + "TurboQuant inner block {name} {CODES_FIELD} list size must be {block}, got {}", + codes_fsl.list_size() + ); + let codes: PrimitiveArray = codes_fsl.elements().clone().execute(ctx)?; + + // Ensure that block mask covers the norms and codes masks. + let norms_mask = norms.validity()?.execute_mask(len, ctx)?; + let codes_mask = codes_fsl.validity()?.execute_mask(len, ctx)?; + vortex_ensure!(inner_mask.clone().bitand_not(&norms_mask).all_false()); + vortex_ensure!(inner_mask.clone().bitand_not(&codes_mask).all_false()); + + blocks.push(Block { norms, codes }); + } Ok(TurboQuantParsedStorage { metadata, - vector_validity: struct_validity, - norms, - codes, + vector_validity: outer_validity, + blocks, len, }) } - -/// Validate that both child masks cover the struct mask: every row that the struct considers -/// valid must also be valid in the `norms` and `codes` children. -/// -/// `struct_mask & !child_mask` selects rows where the struct is valid but the child is not. If -/// no such row exists, the child covers the struct. [`Mask::bitand_not`] is variant-specialized, -/// so this short-circuits in `O(1)` when either mask is `AllTrue` or `AllFalse`. -fn validate_child_validity_covers_struct( - struct_mask: &Mask, - norms_mask: &Mask, - codes_mask: &Mask, -) -> VortexResult<()> { - vortex_ensure!( - struct_mask.clone().bitand_not(norms_mask).all_false(), - "TurboQuant {NORMS_FIELD} row validity must cover storage validity" - ); - vortex_ensure!( - struct_mask.clone().bitand_not(codes_mask).all_false(), - "TurboQuant {CODES_FIELD} row validity must cover storage validity" - ); - Ok(()) -} diff --git a/vortex-turboquant/src/vtable.rs b/vortex-turboquant/src/vtable.rs index 854bcee6c70..b7c886f8435 100644 --- a/vortex-turboquant/src/vtable.rs +++ b/vortex-turboquant/src/vtable.rs @@ -22,29 +22,40 @@ use vortex_error::vortex_err; use crate::TurboQuantConfig; use crate::config::MIN_DIMENSION; +use crate::config::validate_block_shape; +use crate::config::validate_block_sum; use crate::vector::storage::CODES_FIELD; use crate::vector::storage::NORMS_FIELD; -use crate::vector::tq_padded_dim; +use crate::vector::storage::block_field_name; /// TurboQuant logical extension type. Per-array configuration lives in [`TurboQuantMetadata`]. #[derive(Clone, Debug, Default, PartialEq, Eq, Hash)] pub struct TurboQuant; -/// Serialized metadata for a TurboQuant extension array. The fields together suffice to -/// reconstruct the SORF transform and centroid codebook at decode time. -#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)] +/// Serialized metadata for a TurboQuant extension array. The fields together suffice to reconstruct +/// the SORF transforms, centroid codebooks, and storage layout at decode time. +#[derive(Clone, Debug, PartialEq, Eq, Hash)] pub struct TurboQuantMetadata { - /// Original vector element ptype and stored row-norm ptype. Restricted to `f16` / `f32` / - /// `f64`. + /// Original vector element ptype and stored row-norm ptype. Restricted to `f16`/`f32`/`f64`. pub element_ptype: PType, - /// Original vector dimension before SORF padding to the next power of two. + + /// Original vector dimension before block decomposition. pub dimensions: u32, + /// Bits per coordinate in the scalar quantizer codebook (`1..=8`). pub bit_width: u8, - /// Seed used to derive the deterministic SORF transform. + + /// Global seed used to derive each block's deterministic SORF transform. pub seed: u64, - /// Number of sign-diagonal plus Walsh-Hadamard rounds in the SORF transform. + + /// Number of sign-diagonal plus Walsh-Hadamard rounds in each block's SORF transform. pub num_rounds: u8, + + /// Powers-of-two block sizes the encoder used. + /// + /// Note that this is always non-empty. Additionally, `sum(block_sizes) >= dimensions` and each + /// entry is at least `MIN_BLOCK_SIZE`. + pub block_sizes: Vec, } impl ExtVTable for TurboQuant { @@ -61,9 +72,10 @@ impl ExtVTable for TurboQuant { let proto = TurboQuantMetadataProto { element_ptype: metadata.element_ptype as i32, dimensions: metadata.dimensions, - bit_width: u32::from(metadata.bit_width), + bit_width: metadata.bit_width as u32, seed: metadata.seed, - num_rounds: u32::from(metadata.num_rounds), + num_rounds: metadata.num_rounds as u32, + block_sizes: metadata.block_sizes.clone(), }; Ok(proto.encode_to_vec()) @@ -89,6 +101,11 @@ impl ExtVTable for TurboQuant { bit_width, seed: proto.seed, num_rounds, + // Block decomposition intentionally breaks the pre-block on-disk format: arrays + // written before this field existed decode to an empty `block_sizes`, which + // `validate_tq_metadata` rejects below with a clean error (not a panic). There is no + // backward-compatibility shim because the TurboQuant on-disk format is not yet stable. + block_sizes: proto.block_sizes, }; validate_tq_metadata(&metadata)?; @@ -109,7 +126,7 @@ impl ExtVTable for TurboQuant { } /// Wire-format representation of [`TurboQuantMetadata`]. Field tags MUST NOT change once -/// shipped; new fields must use unused tags and remain optional. +/// shipped; new fields must use unused tags. #[derive(Clone, PartialEq, Message)] struct TurboQuantMetadataProto { #[prost(enumeration = "PType", tag = "1")] @@ -122,6 +139,8 @@ struct TurboQuantMetadataProto { seed: u64, #[prost(uint32, tag = "5")] num_rounds: u32, + #[prost(uint32, repeated, tag = "6")] + block_sizes: Vec, } /// Extract TurboQuant metadata from a dtype. @@ -136,30 +155,54 @@ pub(crate) fn tq_metadata(dtype: &DType) -> VortexResult { .metadata_opt::() .ok_or_else(|| vortex_err!("expected a TurboQuant extension array, got {dtype}"))?; - Ok(*metadata) + Ok(metadata.clone()) } +/// Construct the storage dtype for a given metadata and row nullability. +/// +/// Produces an outer struct of `metadata.block_sizes.len()` inner `Struct { norms, codes }` fields, +/// each parameterized by its own block size. pub(crate) fn tq_storage_dtype( metadata: &TurboQuantMetadata, row_nullability: Nullability, ) -> VortexResult { - let padded_dim = u32::try_from(tq_padded_dim(metadata.dimensions)?) - .map_err(|_| vortex_err!("TurboQuant padded dimension does not fit u32"))?; + let mut names = Vec::with_capacity(metadata.block_sizes.len()); + let mut fields = Vec::with_capacity(metadata.block_sizes.len()); + + for (index, &block_size) in metadata.block_sizes.iter().enumerate() { + names.push(block_field_name(index)); + fields.push(inner_block_dtype( + metadata.element_ptype, + block_size, + row_nullability, + )); + } Ok(DType::Struct( + StructFields::new(FieldNames::from_iter(names), fields), + row_nullability, + )) +} + +/// The struct type for each block. +/// +/// Note that we propagate the nullability through both the fields and the outer struct itself for +/// simplicity. +fn inner_block_dtype(element_ptype: PType, block_size: u32, row_nullability: Nullability) -> DType { + DType::Struct( StructFields::new( FieldNames::from([NORMS_FIELD, CODES_FIELD]), vec![ - DType::Primitive(metadata.element_ptype, row_nullability), + DType::Primitive(element_ptype, row_nullability), DType::FixedSizeList( Arc::new(DType::Primitive(PType::U8, Nullability::NonNullable)), - padded_dim, + block_size, row_nullability, ), ], ), row_nullability, - )) + ) } /// Validate [`TurboQuantMetadata`] invariants. Called on both serialize and deserialize so a @@ -176,55 +219,98 @@ fn validate_tq_metadata(metadata: &TurboQuantMetadata) -> VortexResult<()> { metadata.element_ptype ); - tq_padded_dim(metadata.dimensions)?; + validate_block_shape(&metadata.block_sizes)?; + validate_block_sum(&metadata.block_sizes, metadata.dimensions)?; - TurboQuantConfig::try_new(metadata.bit_width, metadata.seed, metadata.num_rounds).map(|_| ()) + TurboQuantConfig::try_new( + metadata.bit_width, + metadata.seed, + metadata.num_rounds, + Some(metadata.block_sizes.clone()), + ) + .map(|_| ()) } /// Validate that `dtype` matches the storage shape produced by [`tq_storage_dtype`] for /// `metadata`. Called from [`TurboQuant::validate_dtype`]. fn validate_tq_storage_dtype(metadata: &TurboQuantMetadata, dtype: &DType) -> VortexResult<()> { - let DType::Struct(fields, _) = dtype else { + let DType::Struct(outer_fields, _) = dtype else { vortex_bail!("TurboQuant storage dtype must be a Struct, got {dtype}"); }; - let expected_names = FieldNames::from([NORMS_FIELD, CODES_FIELD]); + + let expected_names: Vec<_> = (0..metadata.block_sizes.len()) + .map(block_field_name) + .collect(); + vortex_ensure_eq!( + outer_fields.names(), + &FieldNames::from_iter(expected_names.iter().cloned()), + "TurboQuant storage outer fields must be {:?}, got {}", + expected_names, + outer_fields.names() + ); + + for (index, &block) in metadata.block_sizes.iter().enumerate() { + let name = block_field_name(index); + let Some(inner) = outer_fields.field(&name) else { + vortex_bail!("TurboQuant storage missing inner field {name}"); + }; + validate_inner_block_dtype(metadata.element_ptype, block, &name, &inner)?; + } + + Ok(()) +} + +fn validate_inner_block_dtype( + element_ptype: PType, + block: u32, + name: &str, + dtype: &DType, +) -> VortexResult<()> { + let DType::Struct(fields, _) = dtype else { + vortex_bail!("TurboQuant inner block {name} must be a Struct, got {dtype}"); + }; + let expected = FieldNames::from([NORMS_FIELD, CODES_FIELD]); vortex_ensure_eq!( fields.names(), - &expected_names, - "TurboQuant storage fields must be {expected_names}, got {}", + &expected, + "TurboQuant inner block {name} fields must be {expected}, got {}", fields.names() ); let Some(norms_dtype) = fields.field(NORMS_FIELD) else { - vortex_bail!("TurboQuant storage missing {NORMS_FIELD} field"); + vortex_bail!("TurboQuant inner block {name} missing {NORMS_FIELD}"); }; let DType::Primitive(norms_ptype, _) = norms_dtype else { - vortex_bail!("TurboQuant {NORMS_FIELD} field must be primitive, got {norms_dtype}"); + vortex_bail!( + "TurboQuant inner block {name} {NORMS_FIELD} must be primitive, got {norms_dtype}" + ); }; vortex_ensure_eq!( norms_ptype, - metadata.element_ptype, - "TurboQuant {NORMS_FIELD} ptype must be {}, got {norms_ptype}", - metadata.element_ptype + element_ptype, + "TurboQuant inner block {name} {NORMS_FIELD} ptype must be {element_ptype}, got \ + {norms_ptype}" ); let Some(codes_dtype) = fields.field(CODES_FIELD) else { - vortex_bail!("TurboQuant storage missing {CODES_FIELD} field"); + vortex_bail!("TurboQuant inner block {name} missing {CODES_FIELD}"); }; let DType::FixedSizeList(element_dtype, list_size, _) = codes_dtype else { - vortex_bail!("TurboQuant {CODES_FIELD} field must be fixed-size-list, got {codes_dtype}"); + vortex_bail!( + "TurboQuant inner block {name} {CODES_FIELD} must be fixed-size-list, got \ + {codes_dtype}" + ); }; - let padded_dim = u32::try_from(tq_padded_dim(metadata.dimensions)?) - .map_err(|_| vortex_err!("TurboQuant padded dimension does not fit u32"))?; vortex_ensure_eq!( list_size, - padded_dim, - "TurboQuant {CODES_FIELD} list size must be {padded_dim}, got {list_size}" + block, + "TurboQuant inner block {name} {CODES_FIELD} list size must be {block}, got {list_size}" ); vortex_ensure_eq!( element_dtype.as_ref(), &DType::Primitive(PType::U8, Nullability::NonNullable), - "TurboQuant {CODES_FIELD} elements must be non-nullable u8, got {element_dtype}" + "TurboQuant inner block {name} {CODES_FIELD} elements must be non-nullable u8, got \ + {element_dtype}" ); Ok(()) @@ -234,8 +320,16 @@ impl fmt::Display for TurboQuantMetadata { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { write!( f, - "element_ptype: {}, dimensions: {}, bit_width: {}, seed: {}, num_rounds: {}", - self.element_ptype, self.dimensions, self.bit_width, self.seed, self.num_rounds - ) + "element_ptype: {}, dimensions: {}, bit_width: {}, seed: {}, num_rounds: {}, \ + block_sizes: [", + self.element_ptype, self.dimensions, self.bit_width, self.seed, self.num_rounds, + )?; + for (index, block) in self.block_sizes.iter().enumerate() { + if index > 0 { + write!(f, ", ")?; + } + write!(f, "{block}")?; + } + write!(f, "]") } }