Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 2 additions & 6 deletions examples/sort-axis.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,10 @@ pub struct Permutation
impl Permutation
{
/// Checks if the permutation is correct
pub fn from_indices(v: Vec<usize>) -> Result<Self, ()>
pub fn from_indices(v: Vec<usize>) -> Option<Self>
Comment thread
RPG-Alex marked this conversation as resolved.
Outdated
{
let perm = Permutation { indices: v };
if perm.correct() {
Ok(perm)
} else {
Err(())
}
perm.correct().then_some(perm)
}

fn correct(&self) -> bool
Expand Down
16 changes: 8 additions & 8 deletions src/doc/crate_feature_flags.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,14 @@
//! ## `serde`
//! - Enables serialization support for serde 1.x
//!
//! ## `rayon`
//! - Enables parallel iterators, parallelized methods, the [`parallel`] module and [`par_azip!`].
//! - Implies std
#![cfg_attr(
not(feature = "rayon"),
doc = "//! ## `rayon`\n//! - Enables parallel iterators, parallelized methods, and the `par_azip!` macro.\n//! - Implies std\n"
)]
#![cfg_attr(
feature = "rayon",
doc = "//! ## `rayon`\n//! - Enables parallel iterators, parallelized methods, the [`crate::parallel`] module and [`crate::parallel::par_azip`].\n//! - Implies std\n"
)]
//!
//! ## `approx`
//! - Enables implementations of traits of the [`approx`] crate.
Expand All @@ -28,8 +33,3 @@
//!
//! ## `matrixmultiply-threading`
//! - Enable the ``threading`` feature in the matrixmultiply package
//!
//! [`parallel`]: crate::parallel

#[cfg(doc)]
use crate::parallel::par_azip;
13 changes: 10 additions & 3 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -84,11 +84,18 @@
//! ## Crate Feature Flags
//!
//! The following crate feature flags are available. They are configured in your
//! `Cargo.toml`. See [`doc::crate_feature_flags`] for more information.
//! `Cargo.toml`. See [`crate::doc::crate_feature_flags`] for more information.
//!
//! - `std`: Rust standard library-using functionality (enabled by default)
//! - `serde`: serialization support for serde 1.x
//! - `rayon`: Parallel iterators, parallelized methods, the [`parallel`] module and [`par_azip!`].
#![cfg_attr(
not(feature = "rayon"),
doc = "//! - `rayon`: Parallel iterators, parallelized methods, and the `par_azip!` macro."
)]
#![cfg_attr(
feature = "rayon",
doc = "//! - `rayon`: Parallel iterators, parallelized methods, the [`parallel`] module and [`par_azip!`]."
)]
//! - `approx` Implementations of traits from the [`approx`] crate.
//! - `blas`: transparent BLAS support for matrix multiplication, needs configuration.
//! - `matrixmultiply-threading`: Use threading from `matrixmultiply`.
Expand Down Expand Up @@ -129,7 +136,7 @@ extern crate std;
#[cfg(feature = "blas")]
extern crate cblas_sys;

#[cfg(docsrs)]
#[cfg(any(doc, docsrs))]
pub mod doc;

use alloc::fmt::Debug;
Expand Down
114 changes: 57 additions & 57 deletions src/linalg/impl_linalg.rs
Comment thread
akern40 marked this conversation as resolved.
Original file line number Diff line number Diff line change
Expand Up @@ -968,6 +968,63 @@ where
is_blas_2d(a._dim(), a._strides(), BlasOrder::F)
}

/// Dot product for dynamic-dimensional arrays (`ArrayD`).
///
/// For one-dimensional arrays, computes the vector dot product, which is the sum
/// of the elementwise products (no conjugation of complex operands).
/// Both arrays must have the same length.
///
/// For two-dimensional arrays, performs matrix multiplication. The array shapes
/// must be compatible in the following ways:
/// - If `self` is *M* × *N*, then `rhs` must be *N* × *K* for matrix-matrix multiplication
/// - If `self` is *M* × *N* and `rhs` is *N*, returns a vector of length *M*
/// - If `self` is *M* and `rhs` is *M* × *N*, returns a vector of length *N*
/// - If both arrays are one-dimensional of length *N*, returns a scalar
///
/// **Panics** if:
/// - The arrays have dimensions other than 1 or 2
/// - The array shapes are incompatible for the operation
/// - For vector dot product: the vectors have different lengths
impl<A> Dot<ArrayRef<A, IxDyn>> for ArrayRef<A, IxDyn>
where A: LinalgScalar
{
type Output = Array<A, IxDyn>;

fn dot(&self, rhs: &ArrayRef<A, IxDyn>) -> Self::Output
{
match (self.ndim(), rhs.ndim()) {
(1, 1) => {
let a = self.view().into_dimensionality::<Ix1>().unwrap();
let b = rhs.view().into_dimensionality::<Ix1>().unwrap();
let result = a.dot(&b);
ArrayD::from_elem(vec![], result)
}
(2, 2) => {
// Matrix-matrix multiplication
let a = self.view().into_dimensionality::<Ix2>().unwrap();
let b = rhs.view().into_dimensionality::<Ix2>().unwrap();
let result = a.dot(&b);
result.into_dimensionality::<IxDyn>().unwrap()
}
(2, 1) => {
// Matrix-vector multiplication
let a = self.view().into_dimensionality::<Ix2>().unwrap();
let b = rhs.view().into_dimensionality::<Ix1>().unwrap();
let result = a.dot(&b);
result.into_dimensionality::<IxDyn>().unwrap()
}
(1, 2) => {
// Vector-matrix multiplication
let a = self.view().into_dimensionality::<Ix1>().unwrap();
let b = rhs.view().into_dimensionality::<Ix2>().unwrap();
let result = a.dot(&b);
result.into_dimensionality::<IxDyn>().unwrap()
}
_ => panic!("Dot product for ArrayD is only supported for 1D and 2D arrays"),
}
}
}

#[cfg(test)]
#[cfg(feature = "blas")]
mod blas_tests
Expand Down Expand Up @@ -1083,60 +1140,3 @@ mod blas_tests
}
}
}

/// Dot product for dynamic-dimensional arrays (`ArrayD`).
///
/// For one-dimensional arrays, computes the vector dot product, which is the sum
/// of the elementwise products (no conjugation of complex operands).
/// Both arrays must have the same length.
///
/// For two-dimensional arrays, performs matrix multiplication. The array shapes
/// must be compatible in the following ways:
/// - If `self` is *M* × *N*, then `rhs` must be *N* × *K* for matrix-matrix multiplication
/// - If `self` is *M* × *N* and `rhs` is *N*, returns a vector of length *M*
/// - If `self` is *M* and `rhs` is *M* × *N*, returns a vector of length *N*
/// - If both arrays are one-dimensional of length *N*, returns a scalar
///
/// **Panics** if:
/// - The arrays have dimensions other than 1 or 2
/// - The array shapes are incompatible for the operation
/// - For vector dot product: the vectors have different lengths
impl<A> Dot<ArrayRef<A, IxDyn>> for ArrayRef<A, IxDyn>
where A: LinalgScalar
{
type Output = Array<A, IxDyn>;

fn dot(&self, rhs: &ArrayRef<A, IxDyn>) -> Self::Output
{
match (self.ndim(), rhs.ndim()) {
(1, 1) => {
let a = self.view().into_dimensionality::<Ix1>().unwrap();
let b = rhs.view().into_dimensionality::<Ix1>().unwrap();
let result = a.dot(&b);
ArrayD::from_elem(vec![], result)
}
(2, 2) => {
// Matrix-matrix multiplication
let a = self.view().into_dimensionality::<Ix2>().unwrap();
let b = rhs.view().into_dimensionality::<Ix2>().unwrap();
let result = a.dot(&b);
result.into_dimensionality::<IxDyn>().unwrap()
}
(2, 1) => {
// Matrix-vector multiplication
let a = self.view().into_dimensionality::<Ix2>().unwrap();
let b = rhs.view().into_dimensionality::<Ix1>().unwrap();
let result = a.dot(&b);
result.into_dimensionality::<IxDyn>().unwrap()
}
(1, 2) => {
// Vector-matrix multiplication
let a = self.view().into_dimensionality::<Ix1>().unwrap();
let b = rhs.view().into_dimensionality::<Ix2>().unwrap();
let result = a.dot(&b);
result.into_dimensionality::<IxDyn>().unwrap()
}
_ => panic!("Dot product for ArrayD is only supported for 1D and 2D arrays"),
}
}
}
1 change: 1 addition & 0 deletions src/zip/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -424,6 +424,7 @@ where D: Dimension
}

#[cfg(feature = "rayon")]
#[allow(dead_code)]
pub(crate) fn uninitialized_for_current_layout<T>(&self) -> Array<MaybeUninit<T>, D>
{
let is_f = self.prefer_f();
Expand Down
5 changes: 3 additions & 2 deletions tests/array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
)]

use approx::assert_relative_eq;
use core::panic;
use defmac::defmac;
#[allow(deprecated)]
use itertools::{zip, Itertools};
Expand Down Expand Up @@ -1005,7 +1006,7 @@ fn iter_size_hint()
fn zero_axes()
{
let mut a = arr1::<f32>(&[]);
if let Some(_) = a.iter().next() {
if a.iter().next().is_some() {
panic!();
}
a.map(|_| panic!());
Expand Down Expand Up @@ -2080,7 +2081,7 @@ fn test_contiguous()
assert!(c.as_slice_memory_order().is_some());
let v = c.slice(s![.., 0..1, ..]);
assert!(!v.is_standard_layout());
assert!(!v.as_slice_memory_order().is_some());
assert!(v.as_slice_memory_order().is_none());

let v = c.slice(s![1..2, .., ..]);
assert!(v.is_standard_layout());
Expand Down
6 changes: 4 additions & 2 deletions tests/iterators.rs
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,7 @@ fn inner_iter_corner_cases()
assert_equal(a0.rows(), vec![aview1(&[0])]);

let a2 = ArcArray::<i32, _>::zeros((0, 3));
assert_equal(a2.rows(), vec![aview1(&[]); 0]);
assert_equal(a2.rows(), Vec::<ArrayView1<'_, i32>>::new());

let a2 = ArcArray::<i32, _>::zeros((3, 0));
assert_equal(a2.rows(), vec![aview1(&[]); 3]);
Expand Down Expand Up @@ -359,11 +359,13 @@ fn axis_iter_zip_partially_consumed_discontiguous()
}
}

use ndarray::ArrayView1;

#[test]
fn outer_iter_corner_cases()
{
let a2 = ArcArray::<i32, _>::zeros((0, 3));
assert_equal(a2.outer_iter(), vec![aview1(&[]); 0]);
assert_equal(a2.outer_iter(), Vec::<ArrayView1<'_, i32>>::new());

let a2 = ArcArray::<i32, _>::zeros((3, 0));
assert_equal(a2.outer_iter(), vec![aview1(&[]); 3]);
Expand Down
2 changes: 1 addition & 1 deletion tests/numeric.rs
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,7 @@ fn var_too_large_ddof()
fn var_nan_ddof()
{
let a = Array2::<f64>::zeros((2, 3));
let v = a.var(std::f64::NAN);
let v = a.var(f64::NAN);
assert!(v.is_nan());
}

Expand Down
8 changes: 4 additions & 4 deletions tests/oper.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,20 +19,20 @@ fn test_oper(op: &str, a: &[f32], b: &[f32], c: &[f32])
let aa = CowArray::from(arr1(a));
let bb = CowArray::from(arr1(b));
let cc = CowArray::from(arr1(c));
test_oper_arr::<f32, _>(op, aa.clone(), bb.clone(), cc.clone());
test_oper_arr(op, aa.clone(), bb.clone(), cc.clone());
let dim = (2, 2);
let aa = aa.to_shape(dim).unwrap();
let bb = bb.to_shape(dim).unwrap();
let cc = cc.to_shape(dim).unwrap();
test_oper_arr::<f32, _>(op, aa.clone(), bb.clone(), cc.clone());
test_oper_arr(op, aa.clone(), bb.clone(), cc.clone());
let dim = (1, 2, 1, 2);
let aa = aa.to_shape(dim).unwrap();
let bb = bb.to_shape(dim).unwrap();
let cc = cc.to_shape(dim).unwrap();
test_oper_arr::<f32, _>(op, aa.clone(), bb.clone(), cc.clone());
test_oper_arr(op, aa.clone(), bb.clone(), cc.clone());
}

fn test_oper_arr<A, D>(op: &str, mut aa: CowArray<f32, D>, bb: CowArray<f32, D>, cc: CowArray<f32, D>)
fn test_oper_arr<D>(op: &str, mut aa: CowArray<f32, D>, bb: CowArray<f32, D>, cc: CowArray<f32, D>)
where D: Dimension
{
match op {
Expand Down