From e53742808e5428c7fe672dbc09468ac1d81e6964 Mon Sep 17 00:00:00 2001 From: Connor Tsui Date: Wed, 27 May 2026 15:53:00 +0100 Subject: [PATCH 1/3] choose better initialization for centroids Signed-off-by: Connor Tsui --- .../src/encodings/turboquant/centroids.rs | 65 ++++++++++++++----- vortex-turboquant/src/centroids.rs | 31 ++++++--- 2 files changed, 70 insertions(+), 26 deletions(-) diff --git a/vortex-tensor/src/encodings/turboquant/centroids.rs b/vortex-tensor/src/encodings/turboquant/centroids.rs index 1af86c79d85..ab653de3c35 100644 --- a/vortex-tensor/src/encodings/turboquant/centroids.rs +++ b/vortex-tensor/src/encodings/turboquant/centroids.rs @@ -3,11 +3,23 @@ //! Max-Lloyd centroid computation for TurboQuant scalar quantizers. //! -//! Pre-computes optimal scalar quantizer centroids for the marginal distribution of coordinates -//! after random rotation of a unit-norm vector. In high dimensions, each coordinate of a randomly -//! rotated unit vector follows a distribution proportional to `(1 - x^2)^((d-3)/2)` on `[-1, 1]`, -//! which converges to `N(0, 1/d)`. The Max-Lloyd algorithm finds optimal quantization centroids -//! that minimize MSE for this distribution. +//! Pre-computes and caches optimal scalar quantizer centroids for the marginal distribution of +//! coordinates after a random orthogonal transform of a unit-norm vector. +//! +//! In high dimensions, each coordinate of a randomly transformed unit vector follows a +//! distribution proportional to `(1 - x^2)^((d-3)/2)` on `[-1, 1]`, which converges to +//! `N(0, 1/d)`. +//! +//! The Max-Lloyd algorithm finds optimal quantization centroids that minimize MSE for this +//! distribution. +//! +//! Centroids are not stored in TurboQuant arrays. They are deterministically derived from +//! `(padded_dim, bit_width)` and cached process-locally. +//! +//! The centroid model follows the random orthogonal transform marginal used by the TurboQuant +//! paper. 
This encoder applies a SORF-style structured transform instead of a dense random Gaussian +//! or orthogonal matrix, so paper-level error bounds should not be treated as verified for this +//! implementation without separate empirical validation. use std::sync::LazyLock; @@ -19,14 +31,18 @@ use vortex_utils::aliases::dash_map::DashMap; use crate::encodings::turboquant::MAX_BIT_WIDTH; use crate::encodings::turboquant::MIN_DIMENSION; +// NB: All of these constants were chosen arbitrarily. + /// The maximum iterations for Max-Lloyd algorithm when computing centroids. const MAX_ITERATIONS: usize = 200; /// The Max-Lloyd convergence threshold for stopping early when computing centroids. const CONVERGENCE_EPSILON: f64 = 1e-12; -/// Number of numerical integration points for computing conditional expectations. -const INTEGRATION_POINTS: usize = 1000; +/// Number of trapezoids used for numerical integration when computing conditional expectations. +/// +/// The trapezoidal rule evaluates the integrand at `INTEGRATION_TRAPEZOIDS + 1` points. +const INTEGRATION_TRAPEZOIDS: usize = 1000; /// Global centroid cache keyed by (dimension, bit_width). static CENTROID_CACHE: LazyLock>> = LazyLock::new(DashMap::default); @@ -34,9 +50,9 @@ static CENTROID_CACHE: LazyLock>> = LazyLock::new /// Get or compute cached centroids for the given dimension and bit width. /// /// Returns `2^bit_width` centroids sorted in ascending order, representing optimal scalar -/// quantization levels for the coordinate distribution after random rotation in +/// quantization levels for the coordinate distribution after a random orthogonal transform in /// `dimension`-dimensional space. 
-pub fn compute_or_get_centroids(dimension: u32, bit_width: u8) -> VortexResult> { +pub(crate) fn compute_or_get_centroids(dimension: u32, bit_width: u8) -> VortexResult> { vortex_ensure!( (1..=MAX_BIT_WIDTH).contains(&bit_width), "TurboQuant bit_width must be 1-{}, got {bit_width}", @@ -86,12 +102,20 @@ impl HalfIntExponent { /// Compute optimal centroids via the Max-Lloyd (Lloyd-Max) algorithm. /// -/// Operates on the marginal distribution of a single coordinate of a randomly rotated unit vector -/// in d dimensions. +/// Operates on the marginal distribution of a single coordinate of a randomly transformed unit +/// vector in d dimensions. /// /// The probability distribution function is: /// `f(x) = C_d * (1 - x^2)^((d-3)/2)` on `[-1, 1]` /// where `C_d` is the normalizing constant. +/// +/// Centroids are seeded uniformly on `[±sqrt(bit_width) * sigma]` (where `sigma` is the standard +/// deviation of the normal distribution that hypersphere dimension values take, and specifically +/// `sigma = 1/sqrt(dimension)`) rather than across the full `[-1, 1]`, which strands most of the +/// centroids in the near-zero-mass tails. +/// +/// Note that the `sqrt(bit_width)` is mostly empirically derived; we do not have a theoretical +/// basis for choosing this other than the fact that it seems to produce good results. fn max_lloyd_centroids(dimension: u32, bit_width: u8) -> Buffer { debug_assert!((1..=MAX_BIT_WIDTH).contains(&bit_width)); let num_centroids = 1usize << bit_width; @@ -99,9 +123,14 @@ fn max_lloyd_centroids(dimension: u32, bit_width: u8) -> Buffer { // For the marginal distribution on [-1, 1], we use the exponent (d-3)/2. let exponent = HalfIntExponent::from_numerator(dimension as i32 - 3); - // Initialize centroids uniformly on [-1, 1]. + // The coordinate marginal concentrates around 0 with this standard deviation.
+ let sigma = 1.0 / f64::from(dimension).sqrt(); + let init_half = (f64::from(bit_width).sqrt() * sigma).min(1.0); + + // Initialize centroids uniformly on [-init_half, init_half], where the mass lives, so no cell + // starts in a zero-mass region and freezes. let mut centroids: Vec = (0..num_centroids) - .map(|idx| -1.0 + (2.0 * (idx as f64) + 1.0) / (num_centroids as f64)) + .map(|idx| -init_half + (2.0 * (idx as f64) + 1.0) * init_half / (num_centroids as f64)) .collect(); let mut boundaries: Vec = vec![0.0; num_centroids + 1]; @@ -145,16 +174,16 @@ fn mean_between_centroids(lo: f64, hi: f64, exponent: HalfIntExponent) -> f64 { return (lo + hi) / 2.0; } - let dx = (hi - lo) / INTEGRATION_POINTS as f64; + let dx = (hi - lo) / INTEGRATION_TRAPEZOIDS as f64; let mut numerator = 0.0; let mut denominator = 0.0; - for step in 0..=INTEGRATION_POINTS { + for step in 0..=INTEGRATION_TRAPEZOIDS { let x_val = lo + (step as f64) * dx; let weight = pdf_unnormalized(x_val, exponent); - let trap_weight = if step == 0 || step == INTEGRATION_POINTS { + let trap_weight = if step == 0 || step == INTEGRATION_TRAPEZOIDS { 0.5 } else { 1.0 @@ -193,7 +222,7 @@ fn pdf_unnormalized(x_val: f64, exponent: HalfIntExponent) -> f64 { /// For `k` centroids, returns `k-1` boundaries. A value below `boundaries[0]` maps to centroid 0, a /// value in `[boundaries[i-1], boundaries[i])` maps to centroid `i`, and a /// value `>= boundaries[k-2]` maps to centroid `k-1`. -pub fn compute_centroid_boundaries(centroids: &[f32]) -> Vec { +pub(crate) fn compute_centroid_boundaries(centroids: &[f32]) -> Vec { centroids.windows(2).map(|w| (w[0] + w[1]) * 0.5).collect() } @@ -203,7 +232,7 @@ pub fn compute_centroid_boundaries(centroids: &[f32]) -> Vec { /// centroids. Uses binary search on the midpoints, avoiding distance comparisons /// in the inner loop. 
#[inline] -pub fn find_nearest_centroid(value: f32, boundaries: &[f32]) -> u8 { +pub(crate) fn find_nearest_centroid(value: f32, boundaries: &[f32]) -> u8 { debug_assert!( boundaries.windows(2).all(|w| w[0] <= w[1]), "boundaries must be sorted" diff --git a/vortex-turboquant/src/centroids.rs b/vortex-turboquant/src/centroids.rs index 8499dfc397a..2be60c0ed4e 100644 --- a/vortex-turboquant/src/centroids.rs +++ b/vortex-turboquant/src/centroids.rs @@ -31,7 +31,7 @@ use vortex_utils::aliases::dash_map::DashMap; use crate::config::MAX_BIT_WIDTH; use crate::config::MIN_DIMENSION; -// NB: Some of these numbers were arbitrarily chosen... +// NB: All of these constants were chosen arbitrarily. /// The maximum iterations for Max-Lloyd algorithm when computing centroids. const MAX_ITERATIONS: usize = 200; @@ -39,8 +39,10 @@ const MAX_ITERATIONS: usize = 200; /// The Max-Lloyd convergence threshold for stopping early when computing centroids. const CONVERGENCE_EPSILON: f64 = 1e-12; -/// Number of numerical integration points for computing conditional expectations. -const INTEGRATION_POINTS: usize = 1000; +/// Number of trapezoids used for numerical integration when computing conditional expectations. +/// +/// The trapezoidal rule evaluates the integrand at `INTEGRATION_TRAPEZOIDS + 1` points. +const INTEGRATION_TRAPEZOIDS: usize = 1000; /// Global centroid cache keyed by (dimension, bit_width). static CENTROID_CACHE: LazyLock>> = LazyLock::new(DashMap::default); @@ -106,6 +108,14 @@ impl HalfIntExponent { /// The probability distribution function is: /// `f(x) = C_d * (1 - x^2)^((d-3)/2)` on `[-1, 1]` /// where `C_d` is the normalizing constant. 
+/// +/// Centroids are seeded uniformly on `[±sqrt(bit_width) * sigma]` (where `sigma` is the standard +/// deviation of the normal distribution that hypersphere dimension values take, and specifically +/// `sigma = 1/sqrt(dimension)`) rather than across the full `[-1, 1]`, which strands most of the +/// centroids in the near-zero-mass tails. +/// +/// Note that the `sqrt(bit_width)` is mostly empirically derived; we do not have a theoretical +/// basis for choosing this other than the fact that it seems to produce good results. fn max_lloyd_centroids(dimension: u32, bit_width: u8) -> Buffer { debug_assert!((1..=MAX_BIT_WIDTH).contains(&bit_width)); let num_centroids = 1usize << bit_width; @@ -113,9 +123,14 @@ fn max_lloyd_centroids(dimension: u32, bit_width: u8) -> Buffer { // For the marginal distribution on [-1, 1], we use the exponent (d-3)/2. let exponent = HalfIntExponent::from_numerator(dimension as i32 - 3); - // Initialize centroids uniformly on [-1, 1]. + // The coordinate marginal concentrates around 0 with this standard deviation. + let sigma = 1.0 / f64::from(dimension).sqrt(); + let init_half = (f64::from(bit_width).sqrt() * sigma).min(1.0); + + // Initialize centroids uniformly on [-init_half, init_half], where the mass lives, so no cell + // starts in a zero-mass region and freezes.
let mut centroids: Vec = (0..num_centroids) - .map(|idx| -1.0 + (2.0 * (idx as f64) + 1.0) / (num_centroids as f64)) + .map(|idx| -init_half + (2.0 * (idx as f64) + 1.0) * init_half / (num_centroids as f64)) .collect(); let mut boundaries: Vec = vec![0.0; num_centroids + 1]; @@ -159,16 +174,16 @@ fn mean_between_centroids(lo: f64, hi: f64, exponent: HalfIntExponent) -> f64 { return (lo + hi) / 2.0; } - let dx = (hi - lo) / INTEGRATION_POINTS as f64; + let dx = (hi - lo) / INTEGRATION_TRAPEZOIDS as f64; let mut numerator = 0.0; let mut denominator = 0.0; - for step in 0..=INTEGRATION_POINTS { + for step in 0..=INTEGRATION_TRAPEZOIDS { let x_val = lo + (step as f64) * dx; let weight = pdf_unnormalized(x_val, exponent); - let trap_weight = if step == 0 || step == INTEGRATION_POINTS { + let trap_weight = if step == 0 || step == INTEGRATION_TRAPEZOIDS { 0.5 } else { 1.0 From 0a22adcd27c339392e4caa5880fb9fa676955026 Mon Sep 17 00:00:00 2001 From: Connor Tsui Date: Fri, 22 May 2026 13:31:13 -0400 Subject: [PATCH 2/3] add distortion benchmark for turboquant Signed-off-by: Connor Tsui --- .../scripts/plot-turboquant-distortion.py | 541 ++++++++++++++++++ .../vector-search-bench/src/distortion.rs | 330 +++++++++++ benchmarks/vector-search-bench/src/lib.rs | 39 ++ benchmarks/vector-search-bench/src/main.rs | 136 +++-- 4 files changed, 1008 insertions(+), 38 deletions(-) create mode 100644 benchmarks/vector-search-bench/scripts/plot-turboquant-distortion.py create mode 100644 benchmarks/vector-search-bench/src/distortion.rs diff --git a/benchmarks/vector-search-bench/scripts/plot-turboquant-distortion.py b/benchmarks/vector-search-bench/scripts/plot-turboquant-distortion.py new file mode 100644 index 00000000000..4d7fa20dc80 --- /dev/null +++ b/benchmarks/vector-search-bench/scripts/plot-turboquant-distortion.py @@ -0,0 +1,541 @@ +# /// script +# requires-python = ">=3.11" +# dependencies = [ +# "matplotlib", +# ] +# /// + +# SPDX-License-Identifier: Apache-2.0 +# 
SPDX-FileCopyrightText: Copyright the Vortex contributors + +"""Sweep bits-vs-distortion for TurboQuant and plot the curves. + +Calls `vector-search-bench distortion` for each (dataset, bits) combination, parses the +table from stdout, and plots reconstruction NMSE and pairwise cosine-error curves with +mean and max shown on a log-scaled y-axis; the median is reported only in the printed summary. + +Each `--dataset` value may optionally pin a train layout with a colon, e.g. +`--dataset cohere-small-100k:single`, for datasets that host more than one layout. + +Usage: + uv run benchmarks/vector-search-bench/scripts/plot-turboquant-distortion.py \\ + --dataset sift-small-500k + uv run benchmarks/vector-search-bench/scripts/plot-turboquant-distortion.py \\ + --dataset sift-small-500k --dataset glove-small-100k --samples 8192 + uv run benchmarks/vector-search-bench/scripts/plot-turboquant-distortion.py \\ + --dataset cohere-small-100k:single --bits 1 2 3 4 5 6 7 8 \\ + --output /tmp/distortion.png +""" + +import argparse +import math +import re +import subprocess +import sys +from dataclasses import dataclass +from pathlib import Path + +import matplotlib.pyplot as plt +from matplotlib.lines import Line2D +from matplotlib.ticker import MaxNLocator + +REPO_ROOT = Path(__file__).resolve().parents[3] +DEFAULT_BINARY = REPO_ROOT / "target" / "release" / "vector-search-bench" + +METRIC_NAMES = [ + "reconstruction NMSE mean", + "reconstruction NMSE median", + "reconstruction NMSE max", + "decoded cosine err mean", + "decoded cosine err median", + "decoded cosine err max", +] + + +@dataclass(frozen=True) +class DatasetTarget: + """One dataset to sweep, with the layout the bench should use for it.""" + + name: str + layout: str | None # `None` means let the bench auto-pick.
+ + +@dataclass +class Run: + target: DatasetTarget + dim: int + bits: int + values: dict[str, float] + + @property + def dataset(self) -> str: + return self.target.name + + +DIM_RE = re.compile(r"dim=(\d+)") + + +def parse_dataset_arg(spec: str, default_layout: str | None) -> DatasetTarget: + """Split a `name[:layout]` CLI value. `default_layout` fills in when no `:` is given.""" + if ":" in spec: + name, layout = spec.split(":", 1) + return DatasetTarget(name=name, layout=layout or None) + return DatasetTarget(name=spec, layout=default_layout) + + +def parse_dim(stdout: str) -> int: + """Pull `dim=N` out of the `## ...` header line.""" + match = DIM_RE.search(stdout) + if not match: + raise RuntimeError(f"could not find dim=N in header:\n{stdout}") + return int(match.group(1)) + + +def parse_table(stdout: str) -> dict[str, float]: + """Pull `metric -> value` rows out of the tabled stdout.""" + row_re = re.compile(r"│\s*(.+?)\s*│\s*([^│]+?)\s*│") + values: dict[str, float] = {} + for line in stdout.splitlines(): + match = row_re.match(line) + if not match: + continue + metric, value = match.group(1).strip(), match.group(2).strip() + if metric in METRIC_NAMES: + values[metric] = float(value) + missing = [m for m in METRIC_NAMES if m not in values] + if missing: + raise RuntimeError(f"could not parse metrics {missing} from:\n{stdout}") + return values + + +def run_one( + binary: Path, + target: DatasetTarget, + bits: int, + samples: int, + seed: int, + rounds: int, +) -> Run: + cmd = [ + str(binary), + "distortion", + "--dataset", + target.name, + "--bits", + str(bits), + "--samples", + str(samples), + "--seed", + str(seed), + "--rounds", + str(rounds), + ] + if target.layout: + cmd.extend(["--layout", target.layout]) + layout_tag = f" layout={target.layout}" if target.layout else "" + print(f" running {target.name}{layout_tag} @ bits={bits} ...", file=sys.stderr) + result = subprocess.run(cmd, capture_output=True, text=True, check=True) + return Run( + 
+ target=target, + dim=parse_dim(result.stdout), + bits=bits, + values=parse_table(result.stdout), + ) + + +def nmse_bound_stage1(bits: int) -> float: + """Paper's NMSE upper bound for TurboQuant_mse (Stage 1). + + From the Stage 1 theorem (`main.tex`, line 272): for a unit-norm vector `x` quantized + to `b` bits per coordinate, `E[||x - x'||^2] <= (sqrt(3)*pi/2) / 4^b`. Because `x` is + unit-norm, `||x - x'||^2` equals the normalized squared error `||x - x'||^2 / ||x||^2`, + so the bound applies to the `reconstruction NMSE mean` curve directly. + """ + return (math.sqrt(3.0) * math.pi / 2.0) / (4.0**bits) + + +def compression_ratio(bits: int, dim: int) -> float: + """Theoretical TurboQuant compression ratio vs f32 storage. + + Per the `vortex_tensor::encodings::turboquant` module docs, each vector is stored + as `padded_dim * bits / 8` bytes of quantized codes plus one f32 stored norm + (4 bytes), where `padded_dim` is the next power of two at least `dim`. The ratio is + nonlinear in `bits` because of POT padding and the per-vector norm overhead. + """ + padded_dim = 1 << (dim - 1).bit_length() if dim > 1 else 1 + per_vector_bytes = padded_dim * bits / 8.0 + 4.0 + original_bytes = dim * 4.0 + return original_bytes / per_vector_bytes + + +def cosine_bound(bits: int, dim: int) -> float: + """Paper's Stage-2 inner-product bound, rendered as an absolute-error envelope. + + From the Stage 2 theorem (`main.tex`, line 288): for unit y and an `x` quantized via + TurboQuant_prod (Stage 2, MSE + QJL residual), `E[|<y, x> - <y, x'>|^2] <= + sqrt(3)*pi^2/d * 4^(-b)`. Taking sqrt gives an upper envelope on the RMS error per + bit width, and by Jensen also on the mean abs error. + + Caveat: Vortex currently implements only Stage 1 (no QJL residual correction). The + Stage 1 inner-product error is biased and can sit *above* this Stage-2 envelope.
+ """ + return math.pi * (3.0**0.25) / math.sqrt(dim) / (2.0**bits) + + +DATASET_PALETTE = [ + "#1f77b4", # blue + "#d62728", # red + "#2ca02c", # green + "#9467bd", # purple + "#ff7f0e", # orange + "#17becf", # teal + "#e377c2", # pink + "#8c564b", # brown + "#7f7f7f", # grey + "#bcbd22", # olive +] + +STAT_STYLES = [ + # (metric_suffix, label, linestyle, linewidth, marker) + ("mean", "mean", "-", 2.4, "o"), + ("max", "max", ":", 1.4, None), +] + + +def plot(runs: list[Run], args: argparse.Namespace) -> None: + by_dataset: dict[str, list[Run]] = {} + for r in runs: + by_dataset.setdefault(r.dataset, []).append(r) + for ds_runs in by_dataset.values(): + ds_runs.sort(key=lambda r: r.bits) + + plt.rcParams.update( + { + "font.size": 11, + "axes.titlesize": 13, + "axes.titleweight": "semibold", + "axes.labelsize": 11, + "axes.spines.top": False, + "axes.spines.right": False, + "axes.grid": True, + "grid.alpha": 0.25, + "grid.linewidth": 0.6, + "legend.frameon": False, + } + ) + + fig, axes = plt.subplots(1, 3, figsize=(20, 6.5), constrained_layout=True) + fig.suptitle( + f"TurboQuant distortion vs bits per coordinate" + f" (samples={args.samples:,}, seed={args.seed}, rounds={args.rounds})", + fontsize=14, + fontweight="semibold", + ) + + dataset_colors = {ds: DATASET_PALETTE[i % len(DATASET_PALETTE)] for i, ds in enumerate(by_dataset)} + dataset_dims = {ds: ds_runs[0].dim for ds, ds_runs in by_dataset.items()} + + plot_panel( + axes[0], + by_dataset, + dataset_colors, + metric_prefix="reconstruction NMSE", + title="Reconstruction NMSE (per vector, normalized squared error)", + ylabel=r"$\|x - x^\prime\|^2 / \|x\|^2$", + ) + bits_axis = sorted({r.bits for r in runs}) + axes[0].plot( + bits_axis, + [nmse_bound_stage1(b) for b in bits_axis], + color="#222222", + linestyle=(0, (4, 2, 1, 2)), + linewidth=1.6, + zorder=0, + ) + + plot_panel( + axes[1], + by_dataset, + dataset_colors, + metric_prefix="decoded cosine err", + title=r"Pairwise cosine error $|\cos(x_i, x_j) - 
\cos(x_i^\prime, x_j^\prime)|$", + ylabel="absolute error", + ) + for dataset, ds_runs in by_dataset.items(): + color = dataset_colors[dataset] + d = ds_runs[0].dim + bits = sorted({r.bits for r in ds_runs}) + axes[1].plot( + bits, + [cosine_bound(b, d) for b in bits], + color=color, + linestyle=(0, (4, 2, 1, 2)), + linewidth=1.2, + alpha=0.6, + zorder=0, + ) + + plot_compression_panel(axes[2], by_dataset, dataset_colors) + + add_legends(fig, axes, dataset_colors, dataset_dims) + fig.text( + 0.5, + -0.015, + "Cosine bound is the paper's Stage-2 (TurboQuant_prod, MSE + QJL residual) " + "envelope; Vortex currently ships Stage 1 only, so empirical curves may sit " + "above it. Compression ratio is theoretical " + "(padded_dim * bits / 8 + 4 bytes per vector), excludes per-shard centroid " + "tables and file metadata.", + ha="center", + fontsize=9, + color="#555555", + wrap=True, + ) + + if args.output: + fig.savefig(args.output, dpi=140, bbox_inches="tight") + print(f"saved {args.output}", file=sys.stderr) + else: + plt.show() + + +def plot_panel( + ax, + by_dataset: dict[str, list[Run]], + dataset_colors: dict[str, str], + metric_prefix: str, + title: str, + ylabel: str, +) -> None: + for dataset, ds_runs in by_dataset.items(): + color = dataset_colors[dataset] + bits = [r.bits for r in ds_runs] + for stat_key, _label, linestyle, linewidth, marker in STAT_STYLES: + metric = f"{metric_prefix} {stat_key}" + ys = [r.values[metric] for r in ds_runs] + ax.plot( + bits, + ys, + color=color, + linestyle=linestyle, + linewidth=linewidth, + marker=marker, + markersize=6, + markerfacecolor=color, + markeredgecolor="white", + markeredgewidth=0.8, + alpha=0.95 if marker else 0.75, + ) + ax.set_yscale("log") + ax.set_xlabel("bits per coordinate") + ax.set_ylabel(ylabel) + ax.set_title(title) + ax.xaxis.set_major_locator(MaxNLocator(integer=True)) + ax.grid(True, which="major", linewidth=0.7, alpha=0.45) + ax.grid(True, which="minor", linewidth=0.4, alpha=0.22) + 
ax.minorticks_on() + + +def plot_compression_panel( + ax, + by_dataset: dict[str, list[Run]], + dataset_colors: dict[str, str], +) -> None: + bits_axis = sorted({r.bits for runs in by_dataset.values() for r in runs}) + for dataset, ds_runs in by_dataset.items(): + color = dataset_colors[dataset] + d = ds_runs[0].dim + padded = 1 << (d - 1).bit_length() if d > 1 else 1 + suffix = f" (padded {padded})" if padded != d else " (no padding)" + ax.plot( + bits_axis, + [compression_ratio(b, d) for b in bits_axis], + color=color, + linestyle="-", + linewidth=2.4, + marker="o", + markersize=6, + markerfacecolor=color, + markeredgecolor="white", + markeredgewidth=0.8, + label=f"{dataset}{suffix}", + ) + ax.set_xlabel("bits per coordinate") + ax.set_ylabel(r"ratio vs f32 (= $4d \,/\, (\mathrm{padded}\!\cdot\! b/8 + 4)$)") + ax.set_title("Compression ratio (theoretical)") + ax.xaxis.set_major_locator(MaxNLocator(integer=True)) + ax.grid(True, which="major", linewidth=0.7, alpha=0.45) + ax.grid(True, which="minor", linewidth=0.4, alpha=0.22) + ax.minorticks_on() + ax.legend( + title="dataset", + loc="upper right", + fontsize=9, + title_fontsize=10, + ) + + +def add_legends( + fig, + axes, + dataset_colors: dict[str, str], + dataset_dims: dict[str, int], +) -> None: + dataset_handles = [ + Line2D( + [], + [], + color=color, + linewidth=2.4, + marker="o", + markersize=6, + markerfacecolor=color, + markeredgecolor="white", + markeredgewidth=0.8, + label=f"{dataset} (d = {dataset_dims[dataset]})", + ) + for dataset, color in dataset_colors.items() + ] + stat_handles = [ + Line2D( + [], + [], + color="#333333", + linestyle=linestyle, + linewidth=linewidth, + marker=marker, + markersize=6 if marker else 0, + markerfacecolor="#333333", + markeredgecolor="white", + markeredgewidth=0.8, + label=label, + ) + for _, label, linestyle, linewidth, marker in STAT_STYLES + ] + nmse_bound_handle_s1 = Line2D( + [], + [], + color="#222222", + linestyle=(0, (4, 2, 1, 2)), + linewidth=1.6, + 
label=r"paper bound: $D_{\mathrm{mse}} \leq \frac{\sqrt{3}\,\pi}{2}\, 4^{-b}$", + ) + cosine_bound_handle = Line2D( + [], + [], + color="#444444", + linestyle=(0, (4, 2, 1, 2)), + linewidth=1.2, + alpha=0.6, + label=( + r"paper Stage-2 bound: " + r"$\sqrt{D_{\mathrm{prod}}} \leq \frac{\pi\,3^{1/4}}{\sqrt{d}}\, 2^{-b}$" + ), + ) + + axes[0].legend( + handles=dataset_handles + [nmse_bound_handle_s1], + title="dataset / bound", + loc="upper right", + fontsize=10, + title_fontsize=10, + ) + axes[1].legend( + handles=stat_handles + [cosine_bound_handle], + title="statistic / bound", + loc="upper right", + fontsize=10, + title_fontsize=10, + ) + + +def main() -> None: + parser = argparse.ArgumentParser(description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter) + parser.add_argument( + "--dataset", + action="append", + required=True, + help=( + "Dataset to sweep (repeat to compare multiple). Optionally suffix " + "`:layout` to pin a specific train layout for that dataset, e.g. " + "`--dataset cohere-small-100k:single`. If omitted, the bench picks " + "the dataset's only layout, or errors if there are several." 
+ ), + ) + parser.add_argument( + "--layout", + default=None, + help=("Default train layout applied to any `--dataset` entry that doesn't pin its own with `:layout`."), + ) + parser.add_argument("--samples", type=int, default=65536) + parser.add_argument( + "--bits", + type=int, + nargs="+", + default=[1, 2, 3, 4, 5, 6, 7, 8], + help="Bit widths to sweep (default: 1..=8).", + ) + parser.add_argument("--seed", type=int, default=42) + parser.add_argument("--rounds", type=int, default=3) + parser.add_argument( + "--binary", + type=Path, + default=DEFAULT_BINARY, + help=f"Path to vector-search-bench (default: {DEFAULT_BINARY}).", + ) + parser.add_argument( + "--output", + type=Path, + default=None, + help="If set, save the chart to this path instead of opening a window.", + ) + args = parser.parse_args() + + print("building vector-search-bench (release) ...", file=sys.stderr) + subprocess.run( + ["cargo", "build", "-p", "vector-search-bench", "--release"], + cwd=REPO_ROOT, + check=True, + ) + + if not args.binary.exists(): + sys.exit(f"binary not found at {args.binary} after build") + + targets = [parse_dataset_arg(spec, args.layout) for spec in args.dataset] + + runs: list[Run] = [] + for target in targets: + layout_tag = f" (layout={target.layout})" if target.layout else "" + print( + f"sweeping {target.name}{layout_tag} over bits {args.bits} ...", + file=sys.stderr, + ) + for bits in args.bits: + runs.append( + run_one( + args.binary, + target, + bits, + args.samples, + args.seed, + args.rounds, + ) + ) + + print_summary(runs) + plot(runs, args) + + +def print_summary(runs: list[Run]) -> None: + print() + print("Summary (one row per (dataset, bits)):") + header = ["dataset", "dim", "bits"] + METRIC_NAMES + widths = [max(len(h), 14) for h in header] + print(" " + " ".join(h.ljust(w) for h, w in zip(header, widths))) + for r in runs: + cells = [r.dataset, str(r.dim), str(r.bits)] + [f"{r.values[m]:.3e}" for m in METRIC_NAMES] + print(" " + " ".join(c.ljust(w) for c, w 
in zip(cells, widths))) + + +if __name__ == "__main__": + main() diff --git a/benchmarks/vector-search-bench/src/distortion.rs b/benchmarks/vector-search-bench/src/distortion.rs new file mode 100644 index 00000000000..664182bb2a4 --- /dev/null +++ b/benchmarks/vector-search-bench/src/distortion.rs @@ -0,0 +1,330 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +//! TurboQuant distortion measurement on real vector datasets. +//! +//! Reports per-vector normalized reconstruction error (`||x - x'||^2 / ||x||^2`) and pairwise +//! cosine-similarity error (`|cos(x_i, x_j) - cos(x'_i, x'_j)|`) after a full encode and decode +//! roundtrip through the [`vortex_tensor::encodings::turboquant`] scheme. This is the same +//! TurboQuant implementation the search subcommand stores on disk via +//! [`BtrBlocksCompressorBuilder::with_turboquant`](vortex_btrblocks::BtrBlocksCompressorBuilder). + +use std::io::Write; + +use anyhow::Context; +use anyhow::Result; +use anyhow::bail; +use rand::SeedableRng; +use rand::rngs::StdRng; +use rand::seq::SliceRandom; +use tabled::settings::Style; +use vortex::array::ArrayRef; +use vortex::array::ExecutionCtx; +use vortex::array::IntoArray; +use vortex::array::VortexSessionExecute; +use vortex::array::arrays::ExtensionArray; +use vortex::array::arrays::FixedSizeListArray; +use vortex::array::arrays::PrimitiveArray; +use vortex::array::arrays::ScalarFnArray; +use vortex::array::arrays::Struct; +use vortex::array::arrays::StructArray; +use vortex::array::arrays::extension::ExtensionArrayExt; +use vortex::array::arrays::fixed_size_list::FixedSizeListArrayExt; +use vortex::array::arrays::struct_::StructArrayExt; +use vortex::array::validity::Validity; +use vortex::buffer::Buffer; +use vortex_bench::conversions::parquet_to_vortex_chunks; +use vortex_bench::vector_dataset; +use vortex_bench::vector_dataset::TrainLayout; +use vortex_bench::vector_dataset::VectorDataset; +use 
vortex_tensor::encodings::turboquant::TurboQuantConfig; +use vortex_tensor::encodings::turboquant::turboquant_encode; +use vortex_tensor::scalar_fns::cosine_similarity::CosineSimilarity; + +use crate::SESSION; +use crate::ingest::transform_chunk; + +/// Inputs to a distortion run. +#[derive(Debug, Clone)] +pub struct DistortionConfig { + /// Dataset to load vectors from. + pub dataset: VectorDataset, + /// Train-split layout (used to locate the local parquet shards). + pub layout: TrainLayout, + /// Bits per quantized coordinate. + pub bits: u8, + /// Seed for the SORF rotation. + pub seed: u64, + /// Number of sign-diagonal plus Walsh-Hadamard rounds in the SORF transform. + pub rounds: u8, + /// Number of base vectors to sample from the first train shard. + pub samples: usize, +} + +/// Mean, median, and max of a sample of distortion measurements. +#[derive(Debug, Clone)] +pub struct DistortionStats { + /// Arithmetic mean. + pub mean: f32, + /// Median of the sorted samples (mean of the two middle values for an even count). + pub median: f32, + /// Maximum. + pub max: f32, +} + +/// Per-dataset distortion report ready to render as markdown. +#[derive(Debug, Clone)] +pub struct DistortionReport { + /// Dataset the vectors came from. + pub dataset: VectorDataset, + /// Train-split layout used to locate the shard. + pub layout: TrainLayout, + /// Vector dimensionality. + pub dim: u32, + /// Bits per quantized coordinate. + pub bits: u8, + /// Seed for the SORF rotation. + pub seed: u64, + /// Number of SORF rounds. + pub rounds: u8, + /// Number of base vectors sampled. + pub samples: usize, + /// Per-vector normalized squared L2 reconstruction error. + pub reconstruction: DistortionStats, + /// Pairwise cosine-similarity error after decoding both sides. + pub decoded_cosine: DistortionStats, +} + +/// Compute reconstruction error and cosine-similarity error for a TurboQuant roundtrip.
+pub async fn run_distortion(config: &DistortionConfig) -> Result { + let dataset = config.dataset; + let layout = config.layout; + + let paths = vector_dataset::download(dataset, layout) + .await + .with_context(|| format!("download {}", dataset.name()))?; + let train_path = paths + .train_files + .first() + .with_context(|| format!("dataset {} has no train shards", dataset.name()))? + .clone(); + + let mut ctx = SESSION.create_execution_ctx(); + + let chunked = parquet_to_vortex_chunks(train_path).await?; + let struct_array: StructArray = chunked.into_array().execute(&mut ctx)?; + let transformed = transform_chunk(struct_array.into_array(), &mut ctx)?; + let emb_full = transformed + .as_opt::() + .with_context(|| { + format!( + "transform_chunk did not return a Struct, got {}", + transformed.dtype() + ) + })? + .unmasked_field_by_name("emb") + .context("transformed chunk missing `emb` field")? + .clone(); + + let n = config.samples.min(emb_full.len()); + if n < 2 { + bail!( + "distortion: need at least 2 sampled vectors for cosine pairs, got {n} (dataset {})", + dataset.name(), + ); + } + let emb = emb_full.slice(0..n)?; + + let original = extract_flat_f32(&emb, &mut ctx)?; + let dim = pairs_per_row(&original, n)?; + + let tq_config = TurboQuantConfig { + bit_width: config.bits, + seed: config.seed, + num_rounds: config.rounds, + }; + let encoded = turboquant_encode(emb.clone(), &tq_config, &mut ctx)?; + let decoded_ext: ExtensionArray = encoded.execute(&mut ctx)?; + let decoded = decoded_ext.into_array(); + let decoded_flat = extract_flat_f32(&decoded, &mut ctx)?; + + let reconstruction = stats(&reconstruction_errors(&original, &decoded_flat, dim, n)); + + let half = n / 2; + let mut shuffled: Vec = (0..n).collect(); + shuffled.shuffle(&mut StdRng::seed_from_u64(config.seed)); + let lhs_indices = indices_to_array(&shuffled[..half]); + let rhs_indices = indices_to_array(&shuffled[half..2 * half]); + + let true_cosines = compute_cosines( + 
emb.take(lhs_indices.clone())?, + emb.take(rhs_indices.clone())?, + &mut ctx, + )?; + let decoded_cosines = compute_cosines( + decoded.take(lhs_indices)?, + decoded.take(rhs_indices)?, + &mut ctx, + )?; + let decoded_cosine = stats(&abs_diff(&true_cosines, &decoded_cosines)); + + Ok(DistortionReport { + dataset, + layout, + dim: u32::try_from(dim).context("dim must fit in u32")?, + bits: config.bits, + seed: config.seed, + rounds: config.rounds, + samples: n, + reconstruction, + decoded_cosine, + }) +} + +/// Extract a flat `f32` slice from a `Vector` extension array. +fn extract_flat_f32(array: &ArrayRef, ctx: &mut ExecutionCtx) -> Result> { + let ext: ExtensionArray = array.clone().execute(ctx)?; + let fsl: FixedSizeListArray = ext.storage_array().clone().execute(ctx)?; + let elements: PrimitiveArray = fsl.elements().clone().execute(ctx)?; + Ok(elements.as_slice::().to_vec()) +} + +/// Compute one cosine per row over two equal-length tensor-like arrays. +fn compute_cosines(lhs: ArrayRef, rhs: ArrayRef, ctx: &mut ExecutionCtx) -> Result> { + let len = lhs.len(); + let sfn: ScalarFnArray = CosineSimilarity::try_new_array(lhs, rhs, len)?; + let prim: PrimitiveArray = sfn.into_array().execute(ctx)?; + Ok(prim.as_slice::().to_vec()) +} + +/// Build a non-nullable `PrimitiveArray` of row indices for use with [`ArrayRef::take`]. +fn indices_to_array(indices: &[usize]) -> ArrayRef { + let buf: Buffer = indices.iter().map(|&i| i as u64).collect(); + PrimitiveArray::new::(buf, Validity::NonNullable).into_array() +} + +fn pairs_per_row(flat: &[f32], num_rows: usize) -> Result { + if num_rows == 0 { + bail!("distortion: cannot derive dim from zero rows"); + } + if !flat.len().is_multiple_of(num_rows) { + bail!( + "distortion: flat element count {} not divisible by row count {num_rows}", + flat.len(), + ); + } + Ok(flat.len() / num_rows) +} + +/// Per-vector normalized reconstruction squared error (NMSE). 
Rows whose original squared norm is +/// below `1e-10` are dropped because their normalized error is numerically undefined. +fn reconstruction_errors( + original: &[f32], + reconstructed: &[f32], + dim: usize, + num_rows: usize, +) -> Vec { + let mut out = Vec::with_capacity(num_rows); + for row in 0..num_rows { + let start = row * dim; + let end = start + dim; + let orig = &original[start..end]; + let recon = &reconstructed[start..end]; + let norm_sq: f32 = orig.iter().map(|&v| v * v).sum(); + if norm_sq < 1e-10 { + continue; + } + let err_sq: f32 = orig + .iter() + .zip(recon.iter()) + .map(|(&a, &b)| (a - b) * (a - b)) + .sum(); + out.push(err_sq / norm_sq); + } + out +} + +fn abs_diff(lhs: &[f32], rhs: &[f32]) -> Vec { + lhs.iter() + .zip(rhs.iter()) + .map(|(&a, &b)| (a - b).abs()) + .collect() +} + +fn stats(samples: &[f32]) -> DistortionStats { + if samples.is_empty() { + return DistortionStats { + mean: f32::NAN, + median: f32::NAN, + max: f32::NAN, + }; + } + + let sum: f64 = samples.iter().map(|&v| f64::from(v)).sum(); + #[expect( + clippy::cast_possible_truncation, + reason = "casting an f64 mean back to f32 is intentional and matches the input precision" + )] + let mean = (sum / samples.len() as f64) as f32; + + let mut sorted = samples.to_vec(); + sorted.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal)); + let mid = sorted.len() / 2; + let median = if sorted.len() % 2 == 1 { + sorted[mid] + } else { + 0.5 * (sorted[mid - 1] + sorted[mid]) + }; + + let max = samples.iter().copied().fold(f32::NEG_INFINITY, f32::max); + + DistortionStats { mean, median, max } +} + +impl DistortionReport { + /// Render the report as a markdown header line followed by a tabled table. 
+ pub fn render(&self, writer: &mut dyn Write) -> Result<()> { + writeln!( + writer, + "## {} | dim={} | layout={} | bits={} | samples={} | seed={} | rounds={}", + self.dataset.name(), + self.dim, + self.layout.label(), + self.bits, + self.samples, + self.seed, + self.rounds, + )?; + + let rows: &[(&str, f32)] = &[ + ("reconstruction NMSE mean", self.reconstruction.mean), + ("reconstruction NMSE median", self.reconstruction.median), + ("reconstruction NMSE max", self.reconstruction.max), + ("decoded cosine err mean", self.decoded_cosine.mean), + ("decoded cosine err median", self.decoded_cosine.median), + ("decoded cosine err max", self.decoded_cosine.max), + ]; + + let mut builder = tabled::builder::Builder::new(); + builder.push_record(["metric", "value"]); + for &(metric, value) in rows { + builder.push_record([metric.to_owned(), format_metric(value)]); + } + let mut table = builder.build(); + table.with(Style::modern()); + writeln!(writer, "{table}")?; + Ok(()) + } +} + +fn format_metric(value: f32) -> String { + if value.is_nan() { + "nan".to_owned() + } else if value == 0.0 { + "0".to_owned() + } else if value.abs() < 1e-3 || value.abs() >= 1e4 { + format!("{value:.3e}") + } else { + format!("{value:.6}") + } +} diff --git a/benchmarks/vector-search-bench/src/lib.rs b/benchmarks/vector-search-bench/src/lib.rs index 643cbb5bd0d..76b24390d09 100644 --- a/benchmarks/vector-search-bench/src/lib.rs +++ b/benchmarks/vector-search-bench/src/lib.rs @@ -5,6 +5,7 @@ pub mod compression; pub mod display; +pub mod distortion; pub mod expression; pub mod ingest; pub mod prepare; @@ -13,9 +14,12 @@ pub mod scan; use std::sync::LazyLock; +use anyhow::Result; use vortex::VortexSessionDefault; use vortex::io::session::RuntimeSessionExt; use vortex::session::VortexSession; +use vortex_bench::vector_dataset::TrainLayout; +use vortex_bench::vector_dataset::VectorDataset; pub static SESSION: LazyLock = LazyLock::new(|| { // SAFETY: called from inside the LazyLock initializer, 
before any other access to @@ -26,3 +30,38 @@ pub static SESSION: LazyLock = LazyLock::new(|| { vortex_tensor::initialize(&session); session }); + +/// Resolve a dataset's [`TrainLayout`]. +/// +/// Every benchmark has different sets of possible dataset layouts available. The user **must** +/// provide one if there are multiple layouts. But if a dataset only has 1 layout, we can choose +/// that for them as the default. +pub fn resolve_layout( + dataset: VectorDataset, + requested: Option, +) -> Result { + let layouts = dataset.layouts(); + + match requested { + Some(layout) => { + dataset.validate_layout(layout)?; + Ok(layout) + } + None => { + if layouts.len() == 1 { + Ok(layouts[0].layout()) + } else { + let allowed = layouts + .iter() + .map(|s| s.layout().label()) + .collect::>() + .join(", "); + anyhow::bail!( + "dataset {} hosts multiple layouts ([{}]): pass --layout to pick one", + dataset.name(), + allowed, + ); + } + } + } +} diff --git a/benchmarks/vector-search-bench/src/main.rs b/benchmarks/vector-search-bench/src/main.rs index 440de142bef..307c050d221 100644 --- a/benchmarks/vector-search-bench/src/main.rs +++ b/benchmarks/vector-search-bench/src/main.rs @@ -1,15 +1,20 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: Copyright the Vortex contributors -//! `vector-search-bench` — on-disk cosine-similarity scan benchmark. +//! `vector-search-bench` benchmarks for cosine-similarity scan and TurboQuant distortion. //! //! ```sh -//! cargo run -p vector-search-bench --release -- \ +//! cargo run -p vector-search-bench --release -- search \ //! --dataset cohere-large-10m \ //! --layout partitioned \ //! --flavors vortex-uncompressed,vortex-turboquant \ //! --iterations 3 \ //! --threshold 0.8 +//! +//! cargo run -p vector-search-bench --release -- distortion \ +//! --dataset sift-small-500k \ +//! --bits 4 \ +//! --samples 4096 //! 
``` use std::path::PathBuf; @@ -17,13 +22,17 @@ use std::path::PathBuf; use anyhow::Context; use anyhow::Result; use clap::Parser; +use clap::Subcommand; use vector_search_bench::compression::ALL_VECTOR_FLAVORS; use vector_search_bench::compression::VectorFlavor; use vector_search_bench::display::DatasetReport; use vector_search_bench::display::render; +use vector_search_bench::distortion::DistortionConfig; +use vector_search_bench::distortion::run_distortion; use vector_search_bench::prepare::CompressedVortexDataset; use vector_search_bench::prepare::prepare_all; use vector_search_bench::query::get_random_query_vector; +use vector_search_bench::resolve_layout; use vector_search_bench::scan::ScanConfig; use vector_search_bench::scan::ScanTiming; use vector_search_bench::scan::run_search_scan; @@ -35,7 +44,21 @@ use vortex_bench::vector_dataset::VectorDataset; #[derive(Parser, Debug)] #[command(version, about, long_about = None)] -struct Args { +struct Cli { + #[command(subcommand)] + command: Command, +} + +#[derive(Subcommand, Debug)] +enum Command { + /// On-disk cosine-similarity scan latency benchmark. + Search(SearchArgs), + /// TurboQuant distortion measurement: reconstruction error and cosine error. + Distortion(DistortionArgs), +} + +#[derive(Parser, Debug)] +struct SearchArgs { /// Dataset to benchmark. Single dataset per CLI invocation by design — large datasets /// are intentionally babysat one at a time. #[arg(long, value_enum)] @@ -86,9 +109,55 @@ struct Args { tracing: bool, } +#[derive(Parser, Debug)] +struct DistortionArgs { + /// Dataset to load vectors from. One dataset per invocation. + #[arg(long, value_enum)] + dataset: VectorDataset, + + /// Train-split layout. Required when the dataset publishes more than one layout. + #[arg(long, value_enum)] + layout: Option, + + /// Bits per quantized coordinate. + #[arg(long, default_value_t = 4)] + bits: u8, + + /// Seed for the SORF rotation. 
+ #[arg(long, default_value_t = 42)] + seed: u64, + + /// Number of sign-diagonal plus Walsh-Hadamard rounds in the SORF transform. + #[arg(long, default_value_t = 3)] + rounds: u8, + + /// Number of base vectors to sample from the first train shard (first N rows). + #[arg(long, default_value_t = 65536)] + samples: usize, + + /// Optional path to write the rendered table to instead of stdout. + #[arg(long)] + output_path: Option, + + /// Emit verbose tracing. + #[arg(short, long)] + verbose: bool, + + /// Enable perfetto tracing output. + #[arg(long)] + tracing: bool, +} + #[tokio::main] async fn main() -> Result<()> { - let args = Args::parse(); + let cli = Cli::parse(); + match cli.command { + Command::Search(args) => run_search(args).await, + Command::Distortion(args) => run_distortion_cmd(args).await, + } +} + +async fn run_search(args: SearchArgs) -> Result<()> { setup_logging_and_tracing(args.verbose, args.tracing)?; let dataset = args.dataset; @@ -105,12 +174,10 @@ async fn main() -> Result<()> { anyhow::bail!("no flavors selected, please pass at least one to --flavors"); } - // Load the source embeddings parquet files. let datasets_paths = vector_dataset::download(dataset, layout) .await .with_context(|| format!("download {}", dataset.name()))?; - // Load all vortex files needed, compressing new ones if needed. let prepared = prepare_all(dataset, layout, &datasets_paths, &args.flavors).await?; let query_vector = get_random_query_vector( @@ -131,14 +198,12 @@ async fn main() -> Result<()> { threshold: args.threshold, }; - // Run all scans and record how long each takes. let mut scan_timings: Vec = Vec::with_capacity(prepared.len()); for prep in &prepared { let timing = run_search_scan(prep, &query_vector.query, &scan_config).await?; scan_timings.push(timing); } - // Collect the benchmark results. 
let pairs: Vec<(VectorFlavor, &CompressedVortexDataset, &ScanTiming)> = prepared .iter() .zip(scan_timings.iter()) @@ -149,8 +214,6 @@ async fn main() -> Result<()> { vortex_results: &pairs, }; - // Emit v3 JSONL if requested. The records carry the per-scan dimensions that - // ScanTiming itself does not (dataset, layout, threshold). if let Some(path) = args.gh_json_v3.as_ref() { let records: Vec = scan_timings .iter() @@ -179,7 +242,6 @@ async fn main() -> Result<()> { v3::write_jsonl_to_path(path, &records)?; } - // Print the results. if let Some(path) = args.output_path { let mut file = std::fs::File::create(&path).with_context(|| format!("create {}", path.display()))?; @@ -193,32 +255,30 @@ async fn main() -> Result<()> { Ok(()) } -/// Every benchmark has different sets of possible dataset layouts available. The user **must** -/// provide one if there are multiple layouts. But if a dataset only has 1 layout, we can choose -/// that for them as the default. -fn resolve_layout(dataset: VectorDataset, requested: Option) -> Result { - let layouts = dataset.layouts(); - - match requested { - Some(layout) => { - dataset.validate_layout(layout)?; - Ok(layout) - } - None => { - if layouts.len() == 1 { - Ok(layouts[0].layout()) - } else { - let allowed = layouts - .iter() - .map(|s| s.layout().label()) - .collect::>() - .join(", "); - anyhow::bail!( - "dataset {} hosts multiple layouts ([{}]): pass --layout to pick one", - dataset.name(), - allowed, - ); - } - } +async fn run_distortion_cmd(args: DistortionArgs) -> Result<()> { + setup_logging_and_tracing(args.verbose, args.tracing)?; + + let layout = resolve_layout(args.dataset, args.layout)?; + let config = DistortionConfig { + dataset: args.dataset, + layout, + bits: args.bits, + seed: args.seed, + rounds: args.rounds, + samples: args.samples, + }; + + let report = run_distortion(&config).await?; + + if let Some(path) = args.output_path { + let mut file = + std::fs::File::create(&path).with_context(|| format!("create 
{}", path.display()))?; + report.render(&mut file)?; + } else { + let stdout = std::io::stdout(); + let mut handle = stdout.lock(); + report.render(&mut handle)?; } + + Ok(()) } From 11b5de1ef4879a5c99563b90751c42c8076db32e Mon Sep 17 00:00:00 2001 From: Connor Tsui Date: Wed, 27 May 2026 17:32:44 +0100 Subject: [PATCH 3/3] address comments Signed-off-by: Connor Tsui --- Cargo.lock | 1 + benchmarks/vector-search-bench/Cargo.toml | 1 + .../scripts/plot-turboquant-distortion.py | 151 ++++++++++----- .../vector-search-bench/src/distortion.rs | 177 +++++++++++------- 4 files changed, 211 insertions(+), 119 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 373a807cc20..6d13e840126 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -9095,6 +9095,7 @@ dependencies = [ "indicatif", "parquet", "rand 0.10.1", + "rand_distr 0.6.0", "serde", "tabled", "tempfile", diff --git a/benchmarks/vector-search-bench/Cargo.toml b/benchmarks/vector-search-bench/Cargo.toml index 126b62d5e2d..02528526080 100644 --- a/benchmarks/vector-search-bench/Cargo.toml +++ b/benchmarks/vector-search-bench/Cargo.toml @@ -24,6 +24,7 @@ futures = { workspace = true } indicatif = { workspace = true } parquet = { workspace = true, features = ["async"] } rand = { workspace = true } +rand_distr = { workspace = true } serde = { workspace = true, features = ["derive"] } tabled = { workspace = true, features = ["std"] } tokio = { workspace = true, features = ["full"] } diff --git a/benchmarks/vector-search-bench/scripts/plot-turboquant-distortion.py b/benchmarks/vector-search-bench/scripts/plot-turboquant-distortion.py index 4d7fa20dc80..6d6b1f22b4f 100644 --- a/benchmarks/vector-search-bench/scripts/plot-turboquant-distortion.py +++ b/benchmarks/vector-search-bench/scripts/plot-turboquant-distortion.py @@ -11,7 +11,7 @@ """Sweep bits-vs-distortion for TurboQuant and plot the curves. 
Calls `vector-search-bench distortion` for each (dataset, bits) combination, parses the -table from stdout, and plots reconstruction NMSE and pairwise cosine-error curves with +table from stdout, and plots reconstruction NMSE and squared cosine-error curves with mean/median/max shown on a log-scaled y-axis. Each `--dataset` value may optionally pin a train layout with a colon, e.g. @@ -37,7 +37,7 @@ import matplotlib.pyplot as plt from matplotlib.lines import Line2D -from matplotlib.ticker import MaxNLocator +from matplotlib.ticker import MaxNLocator, NullLocator REPO_ROOT = Path(__file__).resolve().parents[3] DEFAULT_BINARY = REPO_ROOT / "target" / "release" / "vector-search-bench" @@ -46,9 +46,9 @@ "reconstruction NMSE mean", "reconstruction NMSE median", "reconstruction NMSE max", - "decoded cosine err mean", - "decoded cosine err median", - "decoded cosine err max", + "decoded cosine sqerr mean", + "decoded cosine sqerr median", + "decoded cosine sqerr max", ] @@ -143,17 +143,35 @@ def run_one( ) +# Refined small-b values from `main.tex` line 273-274 ("for b = 1, 2, 3, 4 we have +# D_mse approx 0.36, 0.117, 0.03, 0.009"). Tighter than the general sqrt(3)*pi/2 * 4^(-b) +# upper bound, which is what we fall back to for b >= 5. +_NMSE_UPPER_REFINED = {1: 0.36, 2: 0.117, 3: 0.03, 4: 0.009} + + def nmse_bound_stage1(bits: int) -> float: - """Paper's NMSE upper bound for TurboQuant_mse (Stage 1). + """Paper's Stage-1 unit-norm reconstruction upper bound for TurboQuant_mse. From the Stage 1 theorem (`main.tex`, line 272): for a unit-norm vector `x` quantized - to `b` bits per coordinate, `E[||x - x'||^2] <= (sqrt(3)*pi/2) / 4^b`. Because `x` is - unit-norm, `||x - x'||^2` equals the normalized squared error `||x - x'||^2 / ||x||^2`, - so the bound applies to the `reconstruction NMSE mean` curve directly. + to `b` bits per coordinate, `E[||x - x'||^2] <= (sqrt(3)*pi/2) * 4^(-b)`. 
TurboQuant + internally normalizes each input before quantizing, so the bound applies to per-row + NMSE = `||x - x'||^2 / ||x||^2 = ||unit(x) - unit(x')||^2` directly. For small `b` + (1..=4) the paper gives tighter refined values; we splice those in. """ + if bits in _NMSE_UPPER_REFINED: + return _NMSE_UPPER_REFINED[bits] return (math.sqrt(3.0) * math.pi / 2.0) / (4.0**bits) +def nmse_lower_bound(bits: int) -> float: + """Paper's Shannon lower bound on Stage-1 unit-norm reconstruction. + + From `main.tex` line 297: `D_mse(Q) >= 1/4^b` for any randomized `b`-bit quantizer. + Independent of dimension; applies to NMSE for the same reason as the upper bound. + """ + return 1.0 / (4.0**bits) + + def compression_ratio(bits: int, dim: int) -> float: """Theoretical TurboQuant compression ratio vs f32 storage. @@ -168,18 +186,13 @@ def compression_ratio(bits: int, dim: int) -> float: return original_bytes / per_vector_bytes -def cosine_bound(bits: int, dim: int) -> float: - """Paper's Stage-2 inner-product bound, rendered as an absolute-error envelope. +def cosine_sqerr_lower_bound(bits: int, dim: int) -> float: + """Paper's Shannon lower bound on Stage-2 squared inner-product distortion. - From the Stage 2 theorem (`main.tex`, line 288): for unit y and an `x` quantized via - TurboQuant_prod (Stage 2, MSE + QJL residual), `E[| - |^2] <= - sqrt(3)*pi^2/d * 4^(-b)`. Taking sqrt gives an upper envelope on the RMS error per - bit width, and by Jensen also on the mean abs error. - - Caveat: Vortex currently implements only Stage 1 (no QJL residual correction). The - Stage 1 inner-product error is biased and can sit *above* this Stage-2 envelope. + From `main.tex` line 298: `D_prod(Q) >= ||y||^2 / d * 1/4^b` for any randomized + `b`-bit quantizer. With unit probes (`||y||^2 = 1`) this is `1 / (d * 4^b)`. 
""" - return math.pi * (3.0**0.25) / math.sqrt(dim) / (2.0**bits) + return 1.0 / (dim * (4.0**bits)) DATASET_PALETTE = [ @@ -224,7 +237,17 @@ def plot(runs: list[Run], args: argparse.Namespace) -> None: } ) - fig, axes = plt.subplots(1, 3, figsize=(20, 6.5), constrained_layout=True) + # GridSpec with a dedicated bottom strip for the caption so the long text gets a real + # subplot rect: no clipping by `bbox_inches`, no overlap with axis labels, no reliance + # on matplotlib's `wrap=True` heuristic. Plot row gets the lion's share so the bottom + # caption strip doesn't dominate visually; legends are anchored above the axes via + # `bbox_to_anchor` (see `add_legends`), and constrained_layout reserves space for them + # inside the plot row. + fig = plt.figure(figsize=(22, 9.5), constrained_layout=True) + gs = fig.add_gridspec(2, 3, height_ratios=[12, 1]) + axes = [fig.add_subplot(gs[0, i]) for i in range(3)] + caption_ax = fig.add_subplot(gs[1, :]) + caption_ax.axis("off") fig.suptitle( f"TurboQuant distortion vs bits per coordinate" f" (samples={args.samples:,}, seed={args.seed}, rounds={args.rounds})", @@ -240,7 +263,7 @@ def plot(runs: list[Run], args: argparse.Namespace) -> None: by_dataset, dataset_colors, metric_prefix="reconstruction NMSE", - title="Reconstruction NMSE (per vector, normalized squared error)", + title=r"Reconstruction NMSE (per vector, $\|x - x^\prime\|^2 / \|x\|^2$)", ylabel=r"$\|x - x^\prime\|^2 / \|x\|^2$", ) bits_axis = sorted({r.bits for r in runs}) @@ -252,14 +275,22 @@ def plot(runs: list[Run], args: argparse.Namespace) -> None: linewidth=1.6, zorder=0, ) + axes[0].plot( + bits_axis, + [nmse_lower_bound(b) for b in bits_axis], + color="#222222", + linestyle=(0, (1, 2)), + linewidth=1.4, + zorder=0, + ) plot_panel( axes[1], by_dataset, dataset_colors, - metric_prefix="decoded cosine err", - title=r"Pairwise cosine error $|\cos(x_i, x_j) - \cos(x_i^\prime, x_j^\prime)|$", - ylabel="absolute error", + metric_prefix="decoded cosine sqerr", + 
title=r"Squared cosine error $(\cos(y_i, x_i) - \cos(y_i, x_i^\prime))^2$", + ylabel="squared error", ) for dataset, ds_runs in by_dataset.items(): color = dataset_colors[dataset] @@ -267,29 +298,34 @@ def plot(runs: list[Run], args: argparse.Namespace) -> None: bits = sorted({r.bits for r in ds_runs}) axes[1].plot( bits, - [cosine_bound(b, d) for b in bits], + [cosine_sqerr_lower_bound(b, d) for b in bits], color=color, - linestyle=(0, (4, 2, 1, 2)), - linewidth=1.2, - alpha=0.6, + linestyle=(0, (1, 2)), + linewidth=1.0, + alpha=0.5, zorder=0, ) plot_compression_panel(axes[2], by_dataset, dataset_colors) add_legends(fig, axes, dataset_colors, dataset_dims) - fig.text( + caption_ax.text( 0.5, - -0.015, - "Cosine bound is the paper's Stage-2 (TurboQuant_prod, MSE + QJL residual) " - "envelope; Vortex currently ships Stage 1 only, so empirical curves may sit " - "above it. Compression ratio is theoretical " + 1.0, + "NMSE upper bound uses the paper's refined small-b values for b<=4 and the " + "smooth sqrt(3)*pi/2 * 4^(-b) general formula for b>=5. Lower bounds are the " + "Shannon information-theoretic floor for any randomized b-bit quantizer. " + "Vortex ships TurboQuant Stage 1 only, so no Stage-2 inner-product upper " + "bound is drawn on the cosine panel. Probe vectors y_i are sampled iid " + "uniform on the unit sphere. Compression ratio is theoretical " "(padded_dim * bits / 8 + 4 bytes per vector), excludes per-shard centroid " "tables and file metadata.", ha="center", + va="top", fontsize=9, color="#555555", wrap=True, + transform=caption_ax.transAxes, ) if args.output: @@ -334,6 +370,10 @@ def plot_panel( ax.grid(True, which="major", linewidth=0.7, alpha=0.45) ax.grid(True, which="minor", linewidth=0.4, alpha=0.22) ax.minorticks_on() + # Only the integer bit-widths should get an x-axis line; suppress the in-between + # minor ticks that `minorticks_on()` adds (the y-axis minors stay - they're useful + # on the log scale). 
+ ax.xaxis.set_minor_locator(NullLocator()) def plot_compression_panel( @@ -367,9 +407,12 @@ def plot_compression_panel( ax.grid(True, which="major", linewidth=0.7, alpha=0.45) ax.grid(True, which="minor", linewidth=0.4, alpha=0.22) ax.minorticks_on() + ax.xaxis.set_minor_locator(NullLocator()) ax.legend( title="dataset", - loc="upper right", + loc="lower center", + bbox_to_anchor=(0.5, 1.02), + ncol=2, fontsize=9, title_fontsize=10, ) @@ -412,38 +455,50 @@ def add_legends( ) for _, label, linestyle, linewidth, marker in STAT_STYLES ] - nmse_bound_handle_s1 = Line2D( + nmse_upper_handle = Line2D( [], [], color="#222222", linestyle=(0, (4, 2, 1, 2)), linewidth=1.6, - label=r"paper bound: $D_{\mathrm{mse}} \leq \frac{\sqrt{3}\,\pi}{2}\, 4^{-b}$", + label=( + r"upper bound: " + r"$D_{\mathrm{mse}} \leq \frac{\sqrt{3}\,\pi}{2}\, 4^{-b}$ (refined for $b\!\leq\!4$)" + ), + ) + nmse_lower_handle = Line2D( + [], + [], + color="#222222", + linestyle=(0, (1, 2)), + linewidth=1.4, + label=r"lower bound: $D_{\mathrm{mse}} \geq 4^{-b}$", ) - cosine_bound_handle = Line2D( + cosine_lower_handle = Line2D( [], [], color="#444444", - linestyle=(0, (4, 2, 1, 2)), - linewidth=1.2, - alpha=0.6, - label=( - r"paper Stage-2 bound: " - r"$\sqrt{D_{\mathrm{prod}}} \leq \frac{\pi\,3^{1/4}}{\sqrt{d}}\, 2^{-b}$" - ), + linestyle=(0, (1, 2)), + linewidth=1.0, + alpha=0.5, + label=r"lower bound: $D_{\mathrm{prod}} \geq \frac{1}{d}\, 4^{-b}$", ) axes[0].legend( - handles=dataset_handles + [nmse_bound_handle_s1], + handles=dataset_handles + [nmse_upper_handle, nmse_lower_handle], title="dataset / bound", - loc="upper right", + loc="lower center", + bbox_to_anchor=(0.5, 1.02), + ncol=2, fontsize=10, title_fontsize=10, ) axes[1].legend( - handles=stat_handles + [cosine_bound_handle], + handles=stat_handles + [cosine_lower_handle], title="statistic / bound", - loc="upper right", + loc="lower center", + bbox_to_anchor=(0.5, 1.02), + ncol=3, fontsize=10, title_fontsize=10, ) diff --git 
a/benchmarks/vector-search-bench/src/distortion.rs b/benchmarks/vector-search-bench/src/distortion.rs index 664182bb2a4..2f9a03ea275 100644 --- a/benchmarks/vector-search-bench/src/distortion.rs +++ b/benchmarks/vector-search-bench/src/distortion.rs @@ -3,11 +3,17 @@ //! TurboQuant distortion measurement on real vector datasets. //! -//! Reports per-vector normalized reconstruction error (`||x - x'||^2 / ||x||^2`) and pairwise -//! cosine-similarity error (`|cos(x_i, x_j) - cos(x'_i, x'_j)|`) after a full encode and decode -//! roundtrip through the [`vortex_tensor::encodings::turboquant`] scheme. This is the same -//! TurboQuant implementation the search subcommand stores on disk via +//! Reports per-vector NMSE (`||x - x'||^2 / ||x||^2 = ||unit(x) - unit(x')||^2`) and per- +//! vector squared cosine-similarity error (`(cos(y_i, x_i) - cos(y_i, x'_i))^2`) against a +//! set of independently sampled unit-norm probe vectors `y_i`, after a full encode and +//! decode roundtrip through the [`vortex_tensor::encodings::turboquant`] scheme. This is +//! the same TurboQuant implementation the search subcommand stores on disk via //! [`BtrBlocksCompressorBuilder::with_turboquant`](vortex_btrblocks::BtrBlocksCompressorBuilder). +//! +//! NMSE rather than raw SSE because TurboQuant internally normalizes each input to unit +//! norm before quantizing (storing `||x||` separately), so the paper's Stage-1 bound +//! `E[||unit(x) - unit(x')||^2] <= (sqrt(3) * pi / 2) * 4^(-b)` applies to NMSE directly; +//! raw `||x - x'||^2` sits at `||x||^2` times that bound and isn't comparable across rows. 
use std::io::Write; @@ -16,7 +22,8 @@ use anyhow::Result; use anyhow::bail; use rand::SeedableRng; use rand::rngs::StdRng; -use rand::seq::SliceRandom; +use rand_distr::Distribution; +use rand_distr::Normal; use tabled::settings::Style; use vortex::array::ArrayRef; use vortex::array::ExecutionCtx; @@ -25,21 +32,18 @@ use vortex::array::VortexSessionExecute; use vortex::array::arrays::ExtensionArray; use vortex::array::arrays::FixedSizeListArray; use vortex::array::arrays::PrimitiveArray; -use vortex::array::arrays::ScalarFnArray; use vortex::array::arrays::Struct; use vortex::array::arrays::StructArray; use vortex::array::arrays::extension::ExtensionArrayExt; use vortex::array::arrays::fixed_size_list::FixedSizeListArrayExt; use vortex::array::arrays::struct_::StructArrayExt; -use vortex::array::validity::Validity; -use vortex::buffer::Buffer; +use vortex::error::VortexExpect; use vortex_bench::conversions::parquet_to_vortex_chunks; use vortex_bench::vector_dataset; use vortex_bench::vector_dataset::TrainLayout; use vortex_bench::vector_dataset::VectorDataset; use vortex_tensor::encodings::turboquant::TurboQuantConfig; use vortex_tensor::encodings::turboquant::turboquant_encode; -use vortex_tensor::scalar_fns::cosine_similarity::CosineSimilarity; use crate::SESSION; use crate::ingest::transform_chunk; @@ -89,9 +93,10 @@ pub struct DistortionReport { pub rounds: u8, /// Number of base vectors sampled. pub samples: usize, - /// Per-vector normalized squared L2 reconstruction error. + /// Per-vector NMSE, `||x - x'||^2 / ||x||^2`, equal to `||unit(x) - unit(x')||^2`. pub reconstruction: DistortionStats, - /// Pairwise cosine-similarity error after decoding both sides. + /// Per-vector squared cosine-similarity error against a random unit-norm probe `y_i`, + /// `(cos(y_i, x_i) - cos(y_i, x'_i))^2`. 
pub decoded_cosine: DistortionStats, } @@ -127,9 +132,9 @@ pub async fn run_distortion(config: &DistortionConfig) -> Result Result = (0..n).collect(); - shuffled.shuffle(&mut StdRng::seed_from_u64(config.seed)); - let lhs_indices = indices_to_array(&shuffled[..half]); - let rhs_indices = indices_to_array(&shuffled[half..2 * half]); - - let true_cosines = compute_cosines( - emb.take(lhs_indices.clone())?, - emb.take(rhs_indices.clone())?, - &mut ctx, - )?; - let decoded_cosines = compute_cosines( - decoded.take(lhs_indices)?, - decoded.take(rhs_indices)?, - &mut ctx, - )?; - let decoded_cosine = stats(&abs_diff(&true_cosines, &decoded_cosines)); + let reconstruction = stats(&reconstruction_nmse(&original, &decoded_flat, dim, n)); + + // Sample independent unit-norm probe vectors `y_i` (one per row). The TurboQuant Stage-2 + // bound `E[(<y, x> - <y, x'>)^2] <= sqrt(3) * pi^2 / d * 4^(-b)` holds for any fixed `y`, + // so drawing `y` from the unit sphere is a reasonable empirical sweep. + let probes = random_unit_vectors(n, dim, config.seed)?; + let decoded_cosine = stats(&squared_cosine_errors( + &original, + &decoded_flat, + &probes, + dim, + n, + )); Ok(DistortionReport { dataset, @@ -189,20 +188,6 @@ fn extract_flat_f32(array: &ArrayRef, ctx: &mut ExecutionCtx) -> Result Ok(elements.as_slice::().to_vec()) } -/// Compute one cosine per row over two equal-length tensor-like arrays. -fn compute_cosines(lhs: ArrayRef, rhs: ArrayRef, ctx: &mut ExecutionCtx) -> Result> { - let len = lhs.len(); - let sfn: ScalarFnArray = CosineSimilarity::try_new_array(lhs, rhs, len)?; - let prim: PrimitiveArray = sfn.into_array().execute(ctx)?; - Ok(prim.as_slice::().to_vec()) -} - -/// Build a non-nullable `PrimitiveArray` of row indices for use with [`ArrayRef::take`].
-fn indices_to_array(indices: &[usize]) -> ArrayRef { - let buf: Buffer = indices.iter().map(|&i| i as u64).collect(); - PrimitiveArray::new::(buf, Validity::NonNullable).into_array() -} - fn pairs_per_row(flat: &[f32], num_rows: usize) -> Result { if num_rows == 0 { bail!("distortion: cannot derive dim from zero rows"); @@ -216,38 +201,84 @@ fn pairs_per_row(flat: &[f32], num_rows: usize) -> Result { Ok(flat.len() / num_rows) } -/// Per-vector normalized reconstruction squared error (NMSE). Rows whose original squared norm is -/// below `1e-10` are dropped because their normalized error is numerically undefined. -fn reconstruction_errors( +/// Per-vector NMSE, `||x - x'||^2 / ||x||^2 = ||unit(x) - unit(x')||^2`. Zero-norm rows +/// report `0.0` (encoder maps zero in to zero out, so the unit-norm residual is `0`). +fn reconstruction_nmse( original: &[f32], reconstructed: &[f32], dim: usize, num_rows: usize, ) -> Vec { - let mut out = Vec::with_capacity(num_rows); + (0..num_rows) + .map(|row| { + let start = row * dim; + let end = start + dim; + let orig = &original[start..end]; + let recon = &reconstructed[start..end]; + let norm_sq: f32 = orig.iter().map(|&v| v * v).sum(); + if norm_sq == 0.0 { + return 0.0; + } + let err_sq: f32 = orig + .iter() + .zip(recon.iter()) + .map(|(&a, &b)| (a - b) * (a - b)) + .sum(); + err_sq / norm_sq + }) + .collect() +} + +/// Sample `num_rows` independent `dim`-D vectors with standard-normal entries and normalize each +/// row to unit L2 norm. Used as probe vectors `y_i` for the squared cosine-similarity error. 
+fn random_unit_vectors(num_rows: usize, dim: usize, seed: u64) -> Result> { + let mut rng = StdRng::seed_from_u64(seed); + let normal = Normal::new(0.0_f32, 1.0).context("constructing Normal(0, 1)")?; + let mut buf = vec![0.0_f32; num_rows * dim]; for row in 0..num_rows { let start = row * dim; let end = start + dim; - let orig = &original[start..end]; - let recon = &reconstructed[start..end]; - let norm_sq: f32 = orig.iter().map(|&v| v * v).sum(); - if norm_sq < 1e-10 { - continue; + for v in &mut buf[start..end] { + *v = normal.sample(&mut rng); + } + let norm = buf[start..end].iter().map(|&v| v * v).sum::().sqrt(); + if norm > 0.0 { + for v in &mut buf[start..end] { + *v /= norm; + } } - let err_sq: f32 = orig - .iter() - .zip(recon.iter()) - .map(|(&a, &b)| (a - b) * (a - b)) - .sum(); - out.push(err_sq / norm_sq); } - out + Ok(buf) +} + +/// Cosine similarity of two equal-length vectors, returning `0.0` if either has zero norm. +fn cosine(a: &[f32], b: &[f32]) -> f32 { + let dot: f32 = a.iter().zip(b.iter()).map(|(&x, &y)| x * y).sum(); + let norm_a: f32 = a.iter().map(|&v| v * v).sum::().sqrt(); + let norm_b: f32 = b.iter().map(|&v| v * v).sum::().sqrt(); + let denom = norm_a * norm_b; + if denom == 0.0 { 0.0 } else { dot / denom } } -fn abs_diff(lhs: &[f32], rhs: &[f32]) -> Vec { - lhs.iter() - .zip(rhs.iter()) - .map(|(&a, &b)| (a - b).abs()) +/// Per-row squared cosine-similarity error against probe `y_i`, +/// `(cos(y_i, x_i) - cos(y_i, x'_i))^2`. 
+fn squared_cosine_errors( + original: &[f32], + reconstructed: &[f32], + probes: &[f32], + dim: usize, + num_rows: usize, +) -> Vec { + (0..num_rows) + .map(|row| { + let start = row * dim; + let end = start + dim; + let xi = &original[start..end]; + let xi_dec = &reconstructed[start..end]; + let yi = &probes[start..end]; + let diff = cosine(yi, xi) - cosine(yi, xi_dec); + diff * diff + }) .collect() } @@ -276,7 +307,11 @@ fn stats(samples: &[f32]) -> DistortionStats { 0.5 * (sorted[mid - 1] + sorted[mid]) }; - let max = samples.iter().copied().fold(f32::NEG_INFINITY, f32::max); + let max = samples + .iter() + .copied() + .reduce(f32::max) + .vortex_expect("samples is non-empty per the early return above"); DistortionStats { mean, median, max } } @@ -300,9 +335,9 @@ impl DistortionReport { ("reconstruction NMSE mean", self.reconstruction.mean), ("reconstruction NMSE median", self.reconstruction.median), ("reconstruction NMSE max", self.reconstruction.max), - ("decoded cosine err mean", self.decoded_cosine.mean), - ("decoded cosine err median", self.decoded_cosine.median), - ("decoded cosine err max", self.decoded_cosine.max), + ("decoded cosine sqerr mean", self.decoded_cosine.mean), + ("decoded cosine sqerr median", self.decoded_cosine.median), + ("decoded cosine sqerr max", self.decoded_cosine.max), ]; let mut builder = tabled::builder::Builder::new();