feat: add cosine_distance scalar function#21542
Conversation
Add cosine_distance (and list_cosine_distance alias) to compute cosine distance between two numeric arrays. Includes shared vector math primitives in vector_math.rs for reuse by follow-on functions. Part of apache#21536. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Addresses review comments on apache#21542: - Iterate list offsets/values directly instead of per-row ArrayRef downcast - Remove nested-list unwrap loop (function does not support nested lists) - Drop convert_to_f64_array wrapper (coerce_types already guarantees Float64) - Remove duplicate Rust unit tests now covered by SLT - More descriptive error message for mismatched list lengths - Delete now-unused vector_math module; inline math into sole caller Adds SLT coverage for NULL-element-in-list behavior previously tested only in Rust unit tests. Part of apache#21536. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
|
Thanks for the detailed review, @Jefffrey. Rework pushed in fc3ee90. Walking through each comment: 1. Iterate via offsets/values, not per-row ArrayRef downcast ( 2. Nested-list unwrap loop ( 3. Redundant null/Float64 check ( 4. Ambiguous length-mismatch error wording ( 5. Duplicate Rust unit tests ( 6. 7. Inline the math instead of a separate module ( Full validation matrix ( |
Addresses round-2 review comments on apache#21542: - Widen container variant in coerce_types when inputs mix List and LargeList (or FixedSizeList), so mixed-type calls like `cosine_distance([1.0, 0.0], arrow_cast([0.0, 1.0], 'LargeList(Float64)'))` succeed. Follows the pattern from PR apache#21704 (ArrayConcat). - Coerce bare NULL inputs to a matching list variant so `cosine_distance(NULL, [1.0, 2.0])` returns NULL instead of erroring. - Drop the `list_cosine_distance` alias — the base name is not `array_cosine_distance`, so the `array_X` -> `list_X` convention does not apply. - Expand SLT coverage: mixed-type variants, FixedSizeList inputs, Float32 and Int64 inner types, bare NULL in each position, NULL row in a multi-row VALUES, and an unsupported-type plan error case. Dispatch fallthrough in cosine_distance_inner is now unreachable after the coerce_types widening, changed from exec_err! to internal_err!. Part of apache#21536. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
|
Thanks @Jefffrey. Round-3 pushed in ce312cc: 1. Mixed-type inputs ( 2. 3. Bare NULL input ( 4. Multi-row NULL coverage ( Additional SLT coverage added proactively:
Full validation run clean (fmt, clippy, full + sqlite-extended SLT, CLI, doctests, feature-flag checks, |
|
Thanks @crm26 |
## Which issue does this PR close? Part of apache#21536 — split of apache#21371 into one-function-per-PR. ## Rationale for this change Adds `inner_product(array1, array2)` — the dot product of two equal-length numeric arrays, returning `Float64`. Computed as `sum(array1[i] * array2[i])`. ## What changes are included in this PR? Mirrors the structural pattern of merged apache#21542 (`cosine_distance`): - Same `coerce_types` for `List`/`LargeList`/`FixedSizeList` of any numeric inner type, with widening to `LargeList` when any input is `LargeList` (per the apache#21704 pattern) - Same NULL semantics: bare `NULL` → `NULL`, NULL row → NULL, NULL element in list → NULL - Same Arrow-idiomatic implementation: single `as_float64_array(list_array.values())` downcast, slice by `value_offsets()`, iterate via `ScalarBuffer<f64>` - No alias, no shared module — standalone, inline math The arithmetic is the only semantic divergence from `cosine_distance`: - `dot += a*b` (no magnitude or normalization) - Empty arrays return `0.0` (sum of empty set), not `NULL` - No zero-magnitude special case (`inner_product([0,0], [1,2])` returns `0`, which is well-defined for inner product) ## Are these changes tested? Yes. SLT covers: - Orthogonal, identical, opposite, general non-trivial vectors - Single zero vector, both zero vectors - Bare `NULL` in either or both positions - NULL element inside a list (returns NULL for that row) - Mismatched lengths (error) - `LargeList` inputs - Mixed `(List, LargeList)` in both orders - `(FixedSizeList, FixedSizeList)` and `(FixedSizeList, LargeList)` - `Float32` and `Int64` inner type coercion - Multi-row query with NULL row propagation - Empty arrays (returns `0`) - No-args error - Return-type assertion (`Float64`) ## Are there any user-facing changes? New scalar function `inner_product`, documented in `docs/source/user-guide/sql/scalar_functions.md`. --------- Co-authored-by: Jeffrey Vo <jeffrey.vo.australia@gmail.com> Co-authored-by: Andrew Lamb <andrew@nerdnetworks.org>
## Which issue does this PR close? Part of apache#21536 — split of apache#21371 into one-function-per-PR. Third in the series after apache#21542 (cosine_distance) and apache#21861 (inner_product). ## Rationale for this change Adds `array_normalize(array)` — the L2-normalized version of a numeric input vector. Computed as `array[i] / sqrt(sum(array[i]^2))` per element. Returns the same shape as the input (`List<Float64>` or `LargeList<Float64>`). Aliased as `list_normalize` to match the `array_X`/`list_X` convention used across the crate. ## What changes are included in this PR? Coercion shell mirrors the merged cosine_distance/inner_product pattern: - `coerce_types` accepts `List`/`LargeList`/`FixedSizeList` of any numeric inner type, plus bare `NULL`. After coercion the inner function only sees `List(Float64)` or `LargeList(Float64)`. - Per-row L2 norm computed inline (no shared module), using a single `as_float64_array(list_array.values())` downcast plus `value_offsets()` slicing — no per-row downcasts. - Manual list builder: `Vec<f64>` for values, `Vec<O>` for offsets, `NullBuffer` for row validity. Per-row semantics: - NULL row → NULL output - NULL element in list → NULL row - Empty list → empty list (no division-by-zero hazard) - Zero magnitude → NULL row (consistent with cosine_distance's zero-magnitude → NULL) - Otherwise → divide each element by `sqrt(sum-of-squares)` ## Are these changes tested? Yes. SLT covers: - 3-4-5 right triangle, 3D vector, already-unit-axis, single non-zero component, negative components - Bare `NULL` input, NULL element in list, zero vector, empty array - `LargeList`, `FixedSizeList` (via coercion), `Float32` and `Int64` inner types, integer literals - Multi-row query mixing normal / NULL row / zero-vector row / null-element row - Plan error for non-list input - No-args error - Return-type assertion (`List(Float64)`) - `list_normalize` alias coverage (constant + multi-row with NULL) ## Are there any user-facing changes? New scalar function `array_normalize` (alias `list_normalize`), documented in `docs/source/user-guide/sql/scalar_functions.md`.
Adds `array_add(array1, array2)` returning the element-wise sum of two numeric arrays. Aliased as `list_add`. Follows the per-function split pattern established by cosine_distance (apache#21542), inner_product (apache#21861), and array_normalize (apache#22013) per tracking issue apache#21536. Semantics: - NULL row in either input -> NULL row out - NULL element at position i in either input -> NULL element at i out (per-element propagation, divergent from inner_product which nulls the whole row; chosen because output is a list, not a scalar) - Length mismatch between rows -> exec_err - Empty arrays -> empty array Supports List, LargeList, and FixedSizeList inputs; numeric element types are coerced to Float64. If any input is LargeList, both sides are widened to LargeList for homogeneous runtime dispatch. Uses OffsetBufferBuilder + NullBufferBuilder per the pattern adopted in array_normalize round 1.
Adds `array_add(array1, array2)` returning the element-wise sum of two numeric arrays. Aliased as `list_add`. Follows the per-function split pattern established by cosine_distance (apache#21542), inner_product (apache#21861), and array_normalize (apache#22013) per tracking issue apache#21536. Semantics: - NULL row in either input -> NULL row out - NULL element at position i in either input -> NULL element at i out (per-element propagation, divergent from inner_product which nulls the whole row; chosen because output is a list, not a scalar) - Length mismatch between rows -> exec_err - Empty arrays -> empty array Supports List, LargeList, and FixedSizeList inputs; numeric element types are coerced to Float64. If any input is LargeList, both sides are widened to LargeList for homogeneous runtime dispatch. Uses OffsetBufferBuilder + NullBufferBuilder per the pattern adopted in array_normalize round 1.
Adds `array_add(array1, array2)` returning the element-wise sum of two numeric arrays. Aliased as `list_add`. Follows the per-function split pattern established by cosine_distance (apache#21542), inner_product (apache#21861), and array_normalize (apache#22013) per tracking issue apache#21536. Semantics: - NULL row in either input -> NULL row out - NULL element at position i in either input -> NULL element at i out (per-element propagation, divergent from inner_product which nulls the whole row; chosen because output is a list, not a scalar) - Length mismatch between rows -> exec_err - Empty arrays -> empty array Supports List, LargeList, and FixedSizeList inputs; numeric element types are coerced to Float64. If any input is LargeList, both sides are widened to LargeList for homogeneous runtime dispatch. Uses OffsetBufferBuilder + NullBufferBuilder per the pattern adopted in array_normalize round 1.
## Which issue does this PR close? Partial of apache#21536 — `array_scale` (the list+scalar arithmetic function in the vector math series). ## Rationale for this change Continues the per-function split requested by @alamb on apache#21536. Three sibling PRs already merged: `cosine_distance` (apache#21542), `inner_product` (apache#21861), `array_normalize` (apache#22013). `array_add` is in flight as apache#22459 by @SubhamSinghal. Adds element-wise scalar multiplication for numeric arrays, returning a list of the same shape. Aliased as `list_scale` to match the `array_X` / `list_X` precedent in this crate. ## What changes are included in this PR? - New scalar UDF `array_scale(array, scalar)` in `datafusion/functions-nested/src/array_scale.rs` - Module wire-up + registration in `datafusion/functions-nested/src/lib.rs` - SLT tests at `datafusion/sqllogictest/test_files/array_scale.slt` - Auto-generated function docs entry in `docs/source/user-guide/sql/scalar_functions.md` **Signature:** first arg `List/LargeList/FixedSizeList<numeric>`, second arg numeric scalar. Both coerce to `Float64`. Same list-widening rules as the binary-op siblings. **NULL semantics:** - NULL row in array → NULL row out - NULL scalar → NULL row out (whole-row, because the scalar applies uniformly) - NULL element at position \`i\` → NULL element at \`i\` out (per-element propagation) - Empty array → empty array **Builders:** uses \`OffsetBufferBuilder\` + \`NullBufferBuilder\` per the pattern adopted in the round-1 review of apache#22013. ## Are these changes tested? Yes. \`array_scale.slt\` covers: - Happy paths (positive, negative, zero, fractional, single-element) - NULL propagation at all three levels (NULL row, NULL scalar, NULL element) - All list type variants (\`List\`, \`LargeList\`, \`FixedSizeList\`) - Numeric inner type coercion (Float32, Int64, integer literals) - Multi-row queries with both constant-scalar broadcast and per-row column scalar - Error paths (non-numeric scalar, non-list first arg, wrong arity) - Empty array - \`list_scale\` alias ## Are there any user-facing changes? Yes — new SQL scalar function \`array_scale(array, scalar)\` and its alias \`list_scale\`. Documented in \`docs/source/user-guide/sql/scalar_functions.md\`.
Adds `array_sum(array)` returning the sum of elements in a numeric array. Aliased as `list_sum`. Part of the per-function split sequence on tracking issue apache#21536, following the pattern of the already-merged PRs in this series (cosine_distance apache#21542, inner_product apache#21861, array_normalize apache#22013, array_scale apache#22466). Semantics: - NULL row in array -> NULL row out - NULL elements are skipped (SQL aggregate convention; matches PostgreSQL array_sum, DuckDB list_sum, Spark aggregate). A row whose every element is NULL yields NULL. - Empty array -> 0.0 (additive identity, matches SQL SUM over no rows conceptually, and DuckDB list_sum([]) = 0) Input is List/LargeList/FixedSizeList of any numeric type; elements are coerced to Float64. Output is Float64.
Adds `array_sum(array)` returning the sum of elements in a numeric array. Aliased as `list_sum`. Part of the per-function split sequence on tracking issue apache#21536, following the pattern of the already-merged PRs in this series (cosine_distance apache#21542, inner_product apache#21861, array_normalize apache#22013, array_scale apache#22466). Semantics: - NULL row in array -> NULL row out - NULL elements are skipped (SQL aggregate convention; matches PostgreSQL array_sum, DuckDB list_sum, Spark aggregate). A row whose every element is NULL yields NULL. - Empty array -> 0.0 (additive identity, matches SQL SUM over no rows conceptually, and DuckDB list_sum([]) = 0) Input is List/LargeList/FixedSizeList of any numeric type; elements are coerced to Float64. Output is Float64.
Adds `array_sum(array)` returning the sum of elements in a numeric array. Aliased as `list_sum`. Part of the per-function split sequence on tracking issue apache#21536, following the pattern of the already-merged PRs in this series (cosine_distance apache#21542, inner_product apache#21861, array_normalize apache#22013, array_scale apache#22466). Semantics: - NULL row in array -> NULL row out - NULL elements are skipped (SQL aggregate convention; matches PostgreSQL array_sum, DuckDB list_sum, Spark aggregate). A row whose every element is NULL yields NULL. - Empty array -> 0.0 (additive identity, matches SQL SUM over no rows conceptually, and DuckDB list_sum([]) = 0) Input is List/LargeList/FixedSizeList of any numeric type; elements are coerced to Float64. Output is Float64.
Adds `array_sum(array)` returning the sum of elements in a numeric array. Aliased as `list_sum`. Part of the per-function split sequence on tracking issue apache#21536, following the pattern of the already-merged PRs in this series (cosine_distance apache#21542, inner_product apache#21861, array_normalize apache#22013, array_scale apache#22466). Semantics: - NULL row in array -> NULL row out - NULL elements are skipped (SQL aggregate convention; matches PostgreSQL array_sum, DuckDB list_sum, Spark aggregate). A row whose every element is NULL yields NULL. - Empty array -> 0.0 (additive identity, matches SQL SUM over no rows conceptually, and DuckDB list_sum([]) = 0) Input is List/LargeList/FixedSizeList of any numeric type; elements are coerced to Float64. Output is Float64.
Summary
cosine_distance(array1, array2)/list_cosine_distance— computes cosine distance (1 - cosine similarity) between two numeric arraysvector_math.rsprimitives (dot_product_f64,magnitude_f64,convert_to_f64_array) for reuse by follow-on vector functionsPart of #21536 — first in a series of split PRs (replacing #21371).
Test plan
cosine_distance.sltcovering all edge cases including empty arrays, LargeList, integer coercion, alias, return typecargo clippy,cargo fmt,taplo,prettier,cargo machete— all clean🤖 Generated with Claude Code