diff --git a/.codecov.yml b/.codecov.yml index 2280801..5237740 100644 --- a/.codecov.yml +++ b/.codecov.yml @@ -3,12 +3,13 @@ coverage: status: project: default: - target: 60% + target: 98% removed_code_behavior: removals_only - threshold: 5% + threshold: 1% patch: default: - target: 50% + target: 90% + threshold: 2% ignore: - "src/lib.rs" diff --git a/.github/workflows/benchmarks.yml b/.github/workflows/benchmarks.yml index 13b2be0..8239a0c 100644 --- a/.github/workflows/benchmarks.yml +++ b/.github/workflows/benchmarks.yml @@ -44,7 +44,7 @@ env: jobs: bench: runs-on: ubuntu-latest - timeout-minutes: 15 + timeout-minutes: 30 steps: - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 diff --git a/AGENTS.md b/AGENTS.md index bc2dd63..f215352 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -63,6 +63,9 @@ invariant over the convenient edit. - `Result<_, LaError>` for all fallible operations. Panics are reserved for debug-only precondition violations (e.g. LDLT symmetry check) and documented on the method. +- Public APIs that return plain values must be genuinely infallible for all + representable inputs. If callers can observe failure, return `Result` or + `Option` instead of relying on `panic!`, `assert!`, `unwrap`, or `expect`. - Borrow by default (`&T`, `&[T]`); return borrowed views when possible. - Type and function names match textbook vocabulary (`Matrix`, `Vector`, `Lu`, `Ldlt`, `solve_vec`, `det`, `inf_norm`). Avoid Rust-ecosystem diff --git a/README.md b/README.md index ab1c57b..672c217 100644 --- a/README.md +++ b/README.md @@ -79,15 +79,15 @@ use la_stack::prelude::*; fn main() -> Result<(), LaError> { // This system requires pivoting (a[0][0] = 0), so it's a good LU demo. // A = J - I: zeros on diagonal, ones elsewhere. - let a = Matrix::<5>::from_rows([ + let a = Matrix::<5>::try_from_rows([ [0.0, 1.0, 1.0, 1.0, 1.0], [1.0, 0.0, 1.0, 1.0, 1.0], [1.0, 1.0, 0.0, 1.0, 1.0], [1.0, 1.0, 1.0, 0.0, 1.0], [1.0, 1.0, 1.0, 1.0, 0.0], - ]); + ])?; - let b = Vector::<5>::new([14.0, 13.0, 12.0, 11.0, 10.0]); + let b = Vector::<5>::try_new([14.0, 13.0, 12.0, 11.0, 10.0])?; let lu = a.lu(DEFAULT_PIVOT_TOL)?; let x = lu.solve_vec(b)?.into_array(); @@ -112,13 +112,13 @@ use la_stack::prelude::*; fn main() -> Result<(), LaError> { // This matrix is symmetric positive-definite (A = L*L^T) so LDLT works without pivoting. - let a = Matrix::<5>::from_rows([ + let a = Matrix::<5>::try_from_rows([ [1.0, 1.0, 0.0, 0.0, 0.0], [1.0, 2.0, 1.0, 0.0, 0.0], [0.0, 1.0, 2.0, 1.0, 0.0], [0.0, 0.0, 1.0, 2.0, 1.0], [0.0, 0.0, 0.0, 1.0, 2.0], - ]); + ])?; let det = a.ldlt(DEFAULT_SINGULAR_TOL)?.det()?; assert!((det - 1.0).abs() <= 1e-12); @@ -150,11 +150,14 @@ use la_stack::prelude::*; // Evaluated entirely at compile time — no runtime cost. const DET: Result, LaError> = { - let m = Matrix::<3>::from_rows([ + let m = match Matrix::<3>::try_from_rows([ [2.0, 0.0, 0.0], [0.0, 3.0, 0.0], [0.0, 0.0, 5.0], - ]); + ]) { + Ok(matrix) => matrix, + Err(_) => panic!("matrix entries must be finite"), + }; m.det_direct() }; assert_eq!(DET, Ok(Some(30.0))); @@ -196,19 +199,19 @@ use la_stack::prelude::*; fn main() -> Result<(), LaError> { // Exact determinant - let m = Matrix::<3>::from_rows([ + let m = Matrix::<3>::try_from_rows([ [1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0], - ]); + ])?; assert_eq!(m.det_sign_exact()?, 0); // exactly singular let det = m.det_exact()?; assert_eq!(det, BigRational::from_integer(0.into())); // exact zero // Exact linear system solve - let a = Matrix::<2>::from_rows([[1.0, 2.0], [3.0, 4.0]]); - let b = Vector::<2>::new([5.0, 11.0]); + let a = Matrix::<2>::try_from_rows([[1.0, 2.0], [3.0, 4.0]])?; + let b = Vector::<2>::try_new([5.0, 11.0])?; let x = a.solve_exact_f64(b)?.into_array(); assert!((x[0] - 1.0).abs() <= f64::EPSILON); assert!((x[1] - 2.0).abs() <= f64::EPSILON); diff --git a/benches/exact.rs b/benches/exact.rs index e15879a..5348f51 100644 --- a/benches/exact.rs +++ b/benches/exact.rs @@ -1,7 +1,7 @@ //! Benchmarks for exact arithmetic operations. //! //! These benchmarks measure the performance of the `exact` feature's -//! arbitrary-precision methods. They are organised into two classes: +//! arbitrary-precision methods. They are organised into three classes: //! //! 1. **General-case benches** (`exact_d{2..5}`) — a single //! well-conditioned diagonally-dominant matrix per dimension. These @@ -15,12 +15,32 @@ //! the `f64_decompose → BigInt` path). These measure tail behaviour //! that fixed well-conditioned inputs miss and provide stronger //! empirical evidence for `docs/PERFORMANCE.md`. +//! 3. **Random percentile benches** (`exact_random_percentile_d{2..5}`) — +//! a fixed-seed corpus of diagonally-dominant random matrices per +//! dimension. Each operation is pre-timed across the corpus to select +//! p50/p95/p99 cumulative input subsets, then measured with Criterion. -use criterion::{BenchmarkGroup, Criterion, measurement::WallTime}; -use la_stack::{Matrix, Vector}; -use pastey::paste; -use std::fmt::Display; +use std::array; +use std::cell::Cell; +use std::fmt::{self, Display}; use std::hint::black_box; +use std::num::NonZeroUsize; +use std::time::Instant; + +use criterion::{BatchSize, BenchmarkGroup, Criterion, measurement::WallTime}; +use pastey::paste; + +use la_stack::{Matrix, Vector}; + +const RANDOM_INPUTS_PER_DIM: SampleCount = SampleCount::new_unchecked(50); +const RANDOM_INPUT_ARRAY_LEN: usize = RANDOM_INPUTS_PER_DIM.get(); +const RANDOM_TIMING_PASSES: SampleCount = SampleCount::new_unchecked(5); +const RANDOM_SEED: [u8; 32] = [0; 32]; +const RANDOM_PERCENTILES: [RandomPercentile; 3] = [ + RandomPercentile::P50, + RandomPercentile::P95, + RandomPercentile::P99, +]; /// Return a successful benchmark operation result or panic with the named operation. fn require_ok(result: Result, operation: &str) -> T { @@ -30,6 +50,106 @@ fn require_ok(result: Result, operation: &str) -> T { } } +/// Configuration errors for exact-arithmetic benchmark input generation. +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +enum ExactBenchConfigError { + EmptyCorpus, + UnorderedRange { min: i16, max: i16 }, +} + +impl Display for ExactBenchConfigError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match *self { + Self::EmptyCorpus => f.write_str("random input corpus must be nonempty"), + Self::UnorderedRange { min, max } => { + write!(f, "random integer range must be ordered: {min}..={max}") + } + } + } +} + +/// Non-zero sample count used when selecting percentile benchmark inputs. +#[derive(Clone, Copy)] +struct SampleCount { + len: NonZeroUsize, +} + +impl SampleCount { + /// Construct a sample count for compile-time constants with visible nonzero values. + const fn new_unchecked(len: usize) -> Self { + match NonZeroUsize::new(len) { + Some(len) => Self { len }, + None => panic!("random input corpus must be nonempty"), + } + } + + /// Validate a runtime sample count before percentile calculations use it. + const fn new(len: usize) -> Result { + if let Some(len) = NonZeroUsize::new(len) { + Ok(Self { len }) + } else { + Err(ExactBenchConfigError::EmptyCorpus) + } + } + + /// Return the proven nonzero sample count as a raw `usize`. + const fn get(self) -> usize { + self.len.get() + } +} + +/// Inclusive integer range used by the fixed-seed exact benchmark generator. +#[derive(Clone, Copy)] +struct I16Range { + min: i16, + width: u64, +} + +impl I16Range { + /// Validate an inclusive `i16` range and cache its sampling width. + fn new(min: i16, max: i16) -> Result { + if min > max { + return Err(ExactBenchConfigError::UnorderedRange { min, max }); + } + + let width = i32::from(max) - i32::from(min) + 1; + Ok(Self { + min, + width: u64::try_from(width) + .map_err(|_| ExactBenchConfigError::UnorderedRange { min, max })?, + }) + } +} + +/// Percentiles selected from a pre-timed random-input corpus. +#[derive(Clone, Copy)] +enum RandomPercentile { + P50, + P95, + P99, +} + +impl RandomPercentile { + /// Return the percentile value as an integer percentage. + const fn value(self) -> usize { + match self { + Self::P50 => 50, + Self::P95 => 95, + Self::P99 => 99, + } + } + + /// Return the benchmark-name suffix for this percentile. + const fn name(self) -> &'static str { + match self { + Self::P50 => "p50", + Self::P95 => "p95", + Self::P99 => "p99", + } + } +} + +/// Return a deterministic, strictly diagonally-dominant benchmark matrix entry. #[inline] #[allow(clippy::cast_precision_loss)] const fn matrix_entry(r: usize, c: usize) -> f64 { @@ -40,6 +160,7 @@ const fn matrix_entry(r: usize, c: usize) -> f64 { } } +/// Build the deterministic baseline matrix rows for dimension `D`. #[inline] const fn make_matrix_rows() -> [[f64; D]; D] { let mut rows = [[0.0; D]; D]; @@ -55,6 +176,7 @@ const fn make_matrix_rows() -> [[f64; D]; D] { rows } +/// Build the deterministic baseline right-hand-side vector for dimension `D`. #[inline] #[allow(clippy::cast_precision_loss)] fn make_vector_array() -> [f64; D] { @@ -67,6 +189,233 @@ fn make_vector_array() -> [f64; D] { data } +/// Matrix/RHS pair used by random percentile exact-arithmetic benchmarks. +#[derive(Clone, Copy)] +struct ExactRandomInput { + matrix: Matrix, + rhs: Vector, +} + +/// Exact operation timed when selecting representative random inputs. +#[derive(Clone, Copy)] +enum ExactRandomOperation { + DetSignExact, + DetExact, + SolveExact, + SolveExactF64, +} + +impl ExactRandomOperation { + /// Return the benchmark-name stem for this exact operation. + const fn name(self) -> &'static str { + match self { + Self::DetSignExact => "det_sign_exact", + Self::DetExact => "det_exact", + Self::SolveExact => "solve_exact", + Self::SolveExactF64 => "solve_exact_f64", + } + } +} + +/// Deterministic `SplitMix64` generator for reproducible benchmark corpora. +struct SplitMix64 { + state: u64, +} + +impl SplitMix64 { + /// Initialize the generator with a fixed state. + const fn new(state: u64) -> Self { + Self { state } + } + + /// Advance the generator and return the next 64 random bits. + const fn next_u64(&mut self) -> u64 { + self.state = self.state.wrapping_add(0x9E37_79B9_7F4A_7C15); + let mut z = self.state; + z = (z ^ (z >> 30)).wrapping_mul(0xBF58_476D_1CE4_E5B9); + z = (z ^ (z >> 27)).wrapping_mul(0x94D0_49BB_1331_11EB); + z ^ (z >> 31) + } + + #[allow(clippy::cast_possible_truncation)] + /// Draw a random `i16` inside a validated inclusive range. + fn next_i16(&mut self, range: I16Range) -> i16 { + let offset = (self.next_u64() % range.width) as i32; + let value = i32::from(range.min) + offset; + value as i16 + } +} + +/// Derive a stable per-dimension seed from the global random benchmark seed. +#[allow(clippy::cast_possible_truncation)] +fn random_seed_for_dim() -> u64 { + let mut seed = + 0xC0DE_CAFE_D15C_A11Au64 ^ require_ok(u64::try_from(D), "dimension seed conversion"); + for (i, byte) in RANDOM_SEED.iter().copied().enumerate() { + let shift = require_ok(u32::try_from((i % 8) * 8), "seed shift conversion"); + seed ^= u64::from(byte) << shift; + seed = seed.rotate_left(7) ^ require_ok(u64::try_from(i), "seed index conversion"); + } + seed +} + +/// Build a fixed random corpus of finite, strictly diagonally-dominant inputs. +fn make_random_input_corpus() -> [ExactRandomInput; RANDOM_INPUT_ARRAY_LEN] { + let mut rng = SplitMix64::new(random_seed_for_dim::()); + let entry_range = require_ok(I16Range::new(-10, 10), "random integer range"); + array::from_fn(|_| { + let mut rows = [[0.0; D]; D]; + let mut diag = [0_i16; D]; + + for (r, row) in rows.iter_mut().enumerate() { + for (c, entry) in row.iter_mut().enumerate() { + if r == c { + diag[r] = rng.next_i16(entry_range); + } else { + *entry = f64::from(rng.next_i16(entry_range)); + } + } + } + + let shift = + f64::from(require_ok(u8::try_from(D), "dimension shift conversion")).mul_add(10.0, 1.0); + for (i, row) in rows.iter_mut().enumerate() { + row[i] = if diag[i] >= 0 { + f64::from(diag[i]) + shift + } else { + f64::from(diag[i]) - shift + }; + } + + let rhs = array::from_fn(|_| f64::from(rng.next_i16(entry_range))); + + ExactRandomInput { + matrix: require_ok( + Matrix::::try_from_rows(rows), + "random matrix construction", + ), + rhs: require_ok(Vector::::try_new(rhs), "random RHS vector construction"), + } + }) +} + +/// Execute one exact operation on a random benchmark input. +fn run_random_operation( + operation: ExactRandomOperation, + input: ExactRandomInput, +) { + match operation { + ExactRandomOperation::DetSignExact => { + let sign = require_ok( + black_box(input.matrix).det_sign_exact(), + "exact determinant sign", + ); + black_box(sign); + } + ExactRandomOperation::DetExact => { + let det = require_ok(black_box(input.matrix).det_exact(), "exact determinant"); + black_box(det); + } + ExactRandomOperation::SolveExact => { + let x = require_ok( + black_box(input.matrix).solve_exact(black_box(input.rhs)), + "exact linear solve", + ); + let _ = black_box(x); + } + ExactRandomOperation::SolveExactF64 => { + let x = require_ok( + black_box(input.matrix).solve_exact_f64(black_box(input.rhs)), + "exact linear solve converted to f64", + ); + let _ = black_box(x); + } + } +} + +/// Time one exact operation on one random input in nanoseconds. +fn time_random_operation( + operation: ExactRandomOperation, + input: ExactRandomInput, +) -> u128 { + let start = Instant::now(); + run_random_operation(operation, input); + start.elapsed().as_nanos() +} + +/// Time one exact operation repeatedly on one random input. +fn time_random_operation_repeated( + operation: ExactRandomOperation, + input: ExactRandomInput, +) -> u128 { + let mut elapsed = 0; + for _ in 0..RANDOM_TIMING_PASSES.get() { + elapsed += time_random_operation(operation, input); + } + elapsed +} + +/// Convert a percentile request into an index in a sorted timing corpus. +const fn percentile_index(count: SampleCount, percentile: RandomPercentile) -> usize { + ((count.get() - 1) * percentile.value() + 50) / 100 +} + +/// Select cumulative corpus index sets by pre-timing every input for one operation. +fn percentile_input_indices( + corpus: &[ExactRandomInput; RANDOM_INPUT_ARRAY_LEN], + operation: ExactRandomOperation, +) -> [Vec; RANDOM_PERCENTILES.len()] { + let input_count = require_ok(SampleCount::new(corpus.len()), "random input corpus size"); + let mut timings = [(0_u128, 0_usize); RANDOM_INPUT_ARRAY_LEN]; + for (i, input) in corpus.iter().copied().enumerate() { + timings[i] = (time_random_operation_repeated(operation, input), i); + } + timings.sort_unstable(); + + RANDOM_PERCENTILES.map(|percentile| { + let timing_idx = percentile_index(input_count, percentile); + let threshold = timings[timing_idx].0; + let mut indices = Vec::new(); + for &(elapsed, input_idx) in &timings { + if elapsed <= threshold { + indices.push(input_idx); + } + } + indices + }) +} + +/// Add p50/p95/p99 Criterion benches over percentile input sets. +fn bench_random_percentile_operation( + group: &mut BenchmarkGroup<'_, WallTime>, + corpus: &[ExactRandomInput; RANDOM_INPUT_ARRAY_LEN], + operation: ExactRandomOperation, +) { + let index_sets = percentile_input_indices(corpus, operation); + + for (percentile, input_indices) in RANDOM_PERCENTILES.into_iter().zip(index_sets) { + let input_count = require_ok( + SampleCount::new(input_indices.len()), + "percentile input set size", + ); + let cursor = Cell::new(0); + group.bench_function( + format!("{}_{}", operation.name(), percentile.name()), + move |bencher| { + bencher.iter_batched( + || { + let cursor_pos = cursor.get(); + cursor.set((cursor_pos + 1) % input_count.get()); + corpus[input_indices[cursor_pos]] + }, + |sample| run_random_operation(operation, sample), + BatchSize::SmallInput, + ); + }, + ); + } +} + /// Near-singular matrix: base singular matrix + tiny perturbation. /// /// The base `[[1,2,3],[4,5,6],[7,8,9]]` is exactly singular; adding @@ -78,11 +427,14 @@ fn make_vector_array() -> [f64; D] { #[inline] fn near_singular_3x3() -> Matrix<3> { let perturbation = f64::from_bits(0x3CD0_0000_0000_0000); // 2^-50 - Matrix::<3>::from_rows([ - [1.0 + perturbation, 2.0, 3.0], - [4.0, 5.0, 6.0], - [7.0, 8.0, 9.0], - ]) + require_ok( + Matrix::<3>::try_from_rows([ + [1.0 + perturbation, 2.0, 3.0], + [4.0, 5.0, 6.0], + [7.0, 8.0, 9.0], + ]), + "near-singular matrix construction", + ) } /// Large-entry 3×3: strictly diagonally-dominant matrix with diagonal @@ -99,7 +451,10 @@ fn near_singular_3x3() -> Matrix<3> { #[inline] fn large_entries_3x3() -> Matrix<3> { let big = f64::MAX / 2.0; - Matrix::<3>::from_rows([[big, 1.0, 1.0], [1.0, big, 1.0], [1.0, 1.0, big]]) + require_ok( + Matrix::<3>::try_from_rows([[big, 1.0, 1.0], [1.0, big, 1.0], [1.0, 1.0, big]]), + "large-entry matrix construction", + ) } /// Hilbert matrix `H[i][j] = 1 / (i + j + 1)`. @@ -127,7 +482,10 @@ fn hilbert() -> Matrix { } r += 1; } - Matrix::::from_rows(rows) + require_ok( + Matrix::::try_from_rows(rows), + "Hilbert matrix construction", + ) } /// Populate a Criterion group with the four headline exact-arithmetic @@ -179,8 +537,14 @@ fn bench_extreme_group( macro_rules! gen_exact_benches_for_dim { ($c:expr, $d:literal) => { paste! {{ - let a = Matrix::<$d>::from_rows(make_matrix_rows::<$d>()); - let rhs = Vector::<$d>::new(make_vector_array::<$d>()); + let a = require_ok( + Matrix::<$d>::try_from_rows(make_matrix_rows::<$d>()), + "benchmark matrix construction", + ); + let rhs = require_ok( + Vector::<$d>::try_new(make_vector_array::<$d>()), + "benchmark RHS vector construction", + ); let mut [] = ($c).benchmark_group(concat!("exact_d", stringify!($d))); @@ -253,6 +617,39 @@ macro_rules! gen_exact_benches_for_dim { }; } +macro_rules! gen_random_percentile_benches_for_dim { + ($c:expr, $d:literal) => { + paste! {{ + let corpus = make_random_input_corpus::<$d>(); + let mut [] = + ($c).benchmark_group(concat!("exact_random_percentile_d", stringify!($d))); + + bench_random_percentile_operation( + &mut [], + &corpus, + ExactRandomOperation::DetSignExact, + ); + bench_random_percentile_operation( + &mut [], + &corpus, + ExactRandomOperation::DetExact, + ); + bench_random_percentile_operation( + &mut [], + &corpus, + ExactRandomOperation::SolveExact, + ); + bench_random_percentile_operation( + &mut [], + &corpus, + ExactRandomOperation::SolveExactF64, + ); + + [].finish(); + }}; + }; +} + fn main() { let mut c = Criterion::default().configure_from_args(); @@ -264,6 +661,20 @@ fn main() { gen_exact_benches_for_dim!(&mut c, 5); } + // === Random percentile groups === + // + // Each dimension uses a fixed-seed corpus of strictly + // diagonally-dominant integer matrices. For each operation, the corpus + // is pre-timed repeatedly to select cumulative p50/p95/p99 input sets, + // then Criterion cycles through each set with normal sampling. + #[allow(unused_must_use)] + { + gen_random_percentile_benches_for_dim!(&mut c, 2); + gen_random_percentile_benches_for_dim!(&mut c, 3); + gen_random_percentile_benches_for_dim!(&mut c, 4); + gen_random_percentile_benches_for_dim!(&mut c, 5); + } + // === Adversarial / extreme-input groups === // // Each group runs the same four exact-arithmetic benches @@ -279,7 +690,10 @@ fn main() { bench_extreme_group( &mut group, near_singular_3x3(), - Vector::<3>::new([1.0, 2.0, 3.0]), + require_ok( + Vector::<3>::try_new([1.0, 2.0, 3.0]), + "near-singular RHS vector construction", + ), ); group.finish(); } @@ -291,7 +705,10 @@ fn main() { bench_extreme_group( &mut group, large_entries_3x3(), - Vector::<3>::new([1.0, 1.0, 1.0]), + require_ok( + Vector::<3>::try_new([1.0, 1.0, 1.0]), + "large-entry RHS vector construction", + ), ); group.finish(); } @@ -301,13 +718,27 @@ fn main() { // space, exercising the f64 → BigInt scaling path. { let mut group = c.benchmark_group("exact_hilbert_4x4"); - bench_extreme_group(&mut group, hilbert::<4>(), Vector::<4>::new([1.0; 4])); + bench_extreme_group( + &mut group, + hilbert::<4>(), + require_ok( + Vector::<4>::try_new([1.0; 4]), + "Hilbert RHS vector construction", + ), + ); group.finish(); } { let mut group = c.benchmark_group("exact_hilbert_5x5"); - bench_extreme_group(&mut group, hilbert::<5>(), Vector::<5>::new([1.0; 5])); + bench_extreme_group( + &mut group, + hilbert::<5>(), + require_ok( + Vector::<5>::try_new([1.0; 5]), + "Hilbert RHS vector construction", + ), + ); group.finish(); } diff --git a/benches/vs_linalg.rs b/benches/vs_linalg.rs index fb41a6d..7ed10ea 100644 --- a/benches/vs_linalg.rs +++ b/benches/vs_linalg.rs @@ -7,13 +7,15 @@ //! - Determinant is benchmarked via LU on all sides (nalgebra uses closed-forms for 1×1/2×2/3×3). //! - Matrix infinity norm is the maximum absolute row sum on all sides. +use std::fmt::Display; +use std::hint::black_box; + use criterion::Criterion; use faer::linalg::solvers::{PartialPivLu, Solve}; use faer::perm::PermRef; -use la_stack::{DEFAULT_PIVOT_TOL, Matrix, Vector}; use pastey::paste; -use std::fmt::Display; -use std::hint::black_box; + +use la_stack::{DEFAULT_PIVOT_TOL, Matrix, Vector}; /// Return a successful benchmark operation result or panic with the named operation. fn require_ok(result: Result, operation: &str) -> T { @@ -28,10 +30,11 @@ fn require_some(value: Option, operation: &str) -> T { value.unwrap_or_else(|| panic!("{operation} returned no result")) } +/// Return `det(P)` for faer's permutation representation. +/// +/// Sign(det(P)) is +1 for even permutations and -1 for odd. Parity is computed +/// from the number of cycles: `sign = (-1)^(n - cycles)`. fn faer_perm_sign(p: PermRef<'_, usize>) -> f64 { - // Sign(det(P)) for a permutation matrix P is +1 for even permutations, -1 for odd. - // Parity can be computed from the number of cycles: - // sign = (-1)^(n - cycles) let (forward, _inverse) = p.arrays(); let n = forward.len(); @@ -58,6 +61,7 @@ fn faer_perm_sign(p: PermRef<'_, usize>) -> f64 { } } +/// Compute a determinant from a faer partial-pivot LU factorization. fn faer_det_from_partial_piv_lu(lu: &PartialPivLu) -> f64 { // For PA = LU with unit-lower L, det(A) = det(P) * det(U). let u = lu.U(); @@ -68,6 +72,7 @@ fn faer_det_from_partial_piv_lu(lu: &PartialPivLu) -> f64 { det * faer_perm_sign(lu.P()) } +/// Return a deterministic, strictly diagonally-dominant benchmark matrix entry. #[inline] #[allow(clippy::cast_precision_loss)] // D, r, c are small integers, precision loss is not an issue. fn matrix_entry(r: usize, c: usize) -> f64 { @@ -80,6 +85,7 @@ fn matrix_entry(r: usize, c: usize) -> f64 { } } +/// Build the shared matrix rows used by all crates for a dimension. #[inline] fn make_matrix_rows() -> [[f64; D]; D] { let mut rows = [[0.0; D]; D]; @@ -97,12 +103,14 @@ fn make_matrix_rows() -> [[f64; D]; D] { rows } +/// Return a deterministic benchmark vector entry. #[inline] #[allow(clippy::cast_precision_loss)] // i is a small integer, precision loss is not an issue. fn vector_entry(i: usize, offset: f64) -> f64 { (i as f64) + 1.0 + offset } +/// Build the shared vector input used by all crates for a dimension. #[inline] fn make_vector_array(offset: f64) -> [f64; D] { let mut data = [0.0; D]; @@ -116,6 +124,7 @@ fn make_vector_array(offset: f64) -> [f64; D] { data } +/// Compute nalgebra's matrix infinity norm using la-stack's row-sum convention. #[inline] fn nalgebra_inf_norm(m: &nalgebra::SMatrix) -> f64 { // Infinity norm = max absolute row sum. @@ -143,10 +152,22 @@ macro_rules! gen_vs_linalg_benches_for_dim { paste! {{ // Isolate each dimension's inputs to keep types and captures clean. { - let a = Matrix::<$d>::from_rows(make_matrix_rows::<$d>()); - let rhs = Vector::<$d>::new(make_vector_array::<$d>(0.0)); - let v1 = Vector::<$d>::new(make_vector_array::<$d>(0.0)); - let v2 = Vector::<$d>::new(make_vector_array::<$d>(1.0)); + let a = require_ok( + Matrix::<$d>::try_from_rows(make_matrix_rows::<$d>()), + "la_stack matrix construction", + ); + let rhs = require_ok( + Vector::<$d>::try_new(make_vector_array::<$d>(0.0)), + "la_stack RHS vector construction", + ); + let v1 = require_ok( + Vector::<$d>::try_new(make_vector_array::<$d>(0.0)), + "la_stack vector construction", + ); + let v2 = require_ok( + Vector::<$d>::try_new(make_vector_array::<$d>(1.0)), + "la_stack vector construction", + ); let na = nalgebra::SMatrix::::from_fn(|r, c| matrix_entry::<$d>(r, c)); let nrhs = nalgebra::SVector::::from_fn(|i, _| vector_entry(i, 0.0)); diff --git a/docs/BENCHMARKING.md b/docs/BENCHMARKING.md index 30969bd..1c268a0 100644 --- a/docs/BENCHMARKING.md +++ b/docs/BENCHMARKING.md @@ -14,9 +14,14 @@ la-stack has two Criterion benchmark suites: (`det_exact`, `solve_exact`, `det_sign_exact`, etc.) alongside f64 baselines (`det`, `det_direct`) across D=2–5. Use this to understand the cost of exact arithmetic and track optimization progress. - In addition to the per-dimension groups (`exact_d{2..5}`), the suite - includes four adversarial-input groups designed to stress specific - corners of the pipeline: + In addition to the fixed per-dimension groups (`exact_d{2..5}`), the + suite includes random percentile and adversarial-input groups designed + to capture variance and stress specific corners of the pipeline: + + - `exact_random_percentile_d{2..5}` — fixed-seed corpora of 50 + strictly diagonally-dominant random matrices per dimension. Each + operation is pre-timed across the corpus to select representative + p50/p95/p99 inputs, then Criterion measures those inputs normally. - `exact_near_singular_3x3` — a 2^-50 perturbation of a singular base matrix; forces the Bareiss fallback in `det_sign_exact` and exercises the largest intermediate `BigInt` values in `solve_exact`. @@ -25,9 +30,11 @@ la-stack has two Criterion benchmark suites: - `exact_hilbert_4x4` / `exact_hilbert_5x5` — classically ill-conditioned matrices whose non-terminating-in-binary entries stress the `f64_decompose → BigInt` scaling path. - Each adversarial group runs the same four benches (`det_sign_exact`, - `det_exact`, `solve_exact`, `solve_exact_f64`) so the resulting tables - are directly comparable across input classes. + + Each random percentile and adversarial group runs the same four + exact-arithmetic benches (`det_sign_exact`, `det_exact`, `solve_exact`, + `solve_exact_f64`) so the resulting tables are directly comparable + across input classes. ## Quick reference diff --git a/examples/const_det_4x4.rs b/examples/const_det_4x4.rs index f7b1120..7f75983 100644 --- a/examples/const_det_4x4.rs +++ b/examples/const_det_4x4.rs @@ -6,12 +6,15 @@ use la_stack::prelude::*; /// An example 4×4 matrix with small integer entries. -const MAT: Matrix<4> = Matrix::<4>::from_rows([ +const MAT: Matrix<4> = match Matrix::<4>::try_from_rows([ [1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0], [2.0, 6.0, 1.0, 5.0], [3.0, 8.0, 2.0, 9.0], -]); +]) { + Ok(matrix) => matrix, + Err(_) => panic!("matrix entries must be finite"), +}; /// Determinant computed at compile time. const DET: f64 = match MAT.det_direct() { diff --git a/examples/det_5x5.rs b/examples/det_5x5.rs index 5b82b78..5829d67 100644 --- a/examples/det_5x5.rs +++ b/examples/det_5x5.rs @@ -5,13 +5,13 @@ use la_stack::prelude::*; fn main() -> Result<(), LaError> { // 5×5 matrix with zeros on diagonal and ones elsewhere (J - I). // det(J - I) = (D - 1) * (-1)^(D-1) = 4 for D=5. - let a = Matrix::<5>::from_rows([ + let a = Matrix::<5>::try_from_rows([ [0.0, 1.0, 1.0, 1.0, 1.0], [1.0, 0.0, 1.0, 1.0, 1.0], [1.0, 1.0, 0.0, 1.0, 1.0], [1.0, 1.0, 1.0, 0.0, 1.0], [1.0, 1.0, 1.0, 1.0, 0.0], - ]); + ])?; // Compute via explicit LU factorization. let lu = a.lu(DEFAULT_PIVOT_TOL)?; diff --git a/examples/exact_det_3x3.rs b/examples/exact_det_3x3.rs index 6ecadc3..ae1bb96 100644 --- a/examples/exact_det_3x3.rs +++ b/examples/exact_det_3x3.rs @@ -18,11 +18,11 @@ fn main() -> Result<(), LaError> { // Perturb entry (0,0) by 2^-50 ≈ 8.9e-16. // Exact det = 2^-50 × cofactor(0,0) = 2^-50 × (5×9 − 6×8) = −3 × 2^-50. let perturbation = f64::from_bits(0x3CD0_0000_0000_0000); // 2^-50 - let m = Matrix::<3>::from_rows([ + let m = Matrix::<3>::try_from_rows([ [1.0 + perturbation, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0], - ]); + ])?; let Some(det_f64_approx) = m.det_direct()? else { unreachable!("D=3 is supported by det_direct"); diff --git a/examples/exact_sign_3x3.rs b/examples/exact_sign_3x3.rs index efe1973..5bac14b 100644 --- a/examples/exact_sign_3x3.rs +++ b/examples/exact_sign_3x3.rs @@ -17,11 +17,11 @@ fn main() -> Result<(), LaError> { // Perturb entry (0,0) by 2^-50 ≈ 8.9e-16. // Exact det = 2^-50 × cofactor(0,0) = 2^-50 × (5×9 − 6×8) = −3 × 2^-50 < 0. let perturbation = f64::from_bits(0x3CD0_0000_0000_0000); // 2^-50 - let m = Matrix::<3>::from_rows([ + let m = Matrix::<3>::try_from_rows([ [1.0 + perturbation, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0], - ]); + ])?; let sign = m.det_sign_exact()?; let det_f64 = m.det()?; diff --git a/examples/exact_solve_3x3.rs b/examples/exact_solve_3x3.rs index 671226d..3ac505a 100644 --- a/examples/exact_solve_3x3.rs +++ b/examples/exact_solve_3x3.rs @@ -16,13 +16,13 @@ fn main() -> Result<(), LaError> { // arithmetic progression). Perturbing entry (0,0) by 2^-50 ≈ 8.9e-16 // makes it invertible but extremely ill-conditioned. let perturbation = f64::from_bits(0x3CD0_0000_0000_0000); // 2^-50 - let a = Matrix::<3>::from_rows([ + let a = Matrix::<3>::try_from_rows([ [1.0 + perturbation, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0], - ]); + ])?; - let b = Vector::<3>::new([1.0, 2.0, 3.0]); + let b = Vector::<3>::try_new([1.0, 2.0, 3.0])?; // f64 LU solve (using zero pivot tolerance since the matrix is nearly singular // and would be rejected by DEFAULT_PIVOT_TOL). diff --git a/examples/ldlt_solve_3x3.rs b/examples/ldlt_solve_3x3.rs index d365486..2a3b4c7 100644 --- a/examples/ldlt_solve_3x3.rs +++ b/examples/ldlt_solve_3x3.rs @@ -10,10 +10,10 @@ use la_stack::prelude::*; fn main() -> Result<(), LaError> { // Symmetric positive definite 3×3 matrix (classic SPD tridiagonal). - let a = Matrix::<3>::from_rows([[4.0, -1.0, 0.0], [-1.0, 4.0, -1.0], [0.0, -1.0, 4.0]]); + let a = Matrix::<3>::try_from_rows([[4.0, -1.0, 0.0], [-1.0, 4.0, -1.0], [0.0, -1.0, 4.0]])?; // Choose x = [1, 2, 3]. Then b = A x = [2, 4, 10]. - let b = Vector::<3>::new([2.0, 4.0, 10.0]); + let b = Vector::<3>::try_new([2.0, 4.0, 10.0])?; let ldlt = a.ldlt(DEFAULT_SINGULAR_TOL)?; let x = ldlt.solve_vec(b)?.into_array(); diff --git a/examples/solve_5x5.rs b/examples/solve_5x5.rs index 8515146..29870f4 100644 --- a/examples/solve_5x5.rs +++ b/examples/solve_5x5.rs @@ -5,16 +5,16 @@ use la_stack::prelude::*; fn main() -> Result<(), LaError> { // This system requires pivoting (a[0][0] = 0), so it's a good LU demo. // A = J - I: zeros on diagonal, ones elsewhere. - let a = Matrix::<5>::from_rows([ + let a = Matrix::<5>::try_from_rows([ [0.0, 1.0, 1.0, 1.0, 1.0], [1.0, 0.0, 1.0, 1.0, 1.0], [1.0, 1.0, 0.0, 1.0, 1.0], [1.0, 1.0, 1.0, 0.0, 1.0], [1.0, 1.0, 1.0, 1.0, 0.0], - ]); + ])?; // Choose x = [1, 2, 3, 4, 5]. Then b = A x = [14, 13, 12, 11, 10]. - let b = Vector::<5>::new([14.0, 13.0, 12.0, 11.0, 10.0]); + let b = Vector::<5>::try_new([14.0, 13.0, 12.0, 11.0, 10.0])?; let lu = a.lu(DEFAULT_PIVOT_TOL)?; let x = lu.solve_vec(b)?.into_array(); diff --git a/scripts/bench_compare.py b/scripts/bench_compare.py index e1e6275..1fd4a28 100644 --- a/scripts/bench_compare.py +++ b/scripts/bench_compare.py @@ -41,16 +41,21 @@ # Groups and the benchmarks within each group that we track. # # Mirrors the structure of `benches/exact.rs`: general-case per-dimension -# groups (`exact_d{2..5}`) plus adversarial/extreme-input groups that -# share a fixed four-bench layout (`det_sign_exact`, `det_exact`, -# `solve_exact`, `solve_exact_f64`). +# groups (`exact_d{2..5}`), fixed-seed random percentile groups, plus +# adversarial/extreme-input groups that share a fixed four-bench layout +# (`det_sign_exact`, `det_exact`, `solve_exact`, `solve_exact_f64`). _EXTREME_BENCHES: list[str] = ["det_sign_exact", "det_exact", "solve_exact", "solve_exact_f64"] +_RANDOM_PERCENTILE_BENCHES: list[str] = [f"{operation}_{percentile}" for operation in _EXTREME_BENCHES for percentile in ("p50", "p95", "p99")] EXACT_GROUPS: dict[str, list[str]] = { "exact_d2": ["det", "det_direct", "det_exact", "det_exact_f64", "det_sign_exact", "solve_exact", "solve_exact_f64"], "exact_d3": ["det", "det_direct", "det_exact", "det_exact_f64", "det_sign_exact", "solve_exact", "solve_exact_f64"], "exact_d4": ["det", "det_direct", "det_exact", "det_exact_f64", "det_sign_exact", "solve_exact", "solve_exact_f64"], "exact_d5": ["det", "det_direct", "det_exact", "det_exact_f64", "det_sign_exact", "solve_exact", "solve_exact_f64"], + "exact_random_percentile_d2": _RANDOM_PERCENTILE_BENCHES, + "exact_random_percentile_d3": _RANDOM_PERCENTILE_BENCHES, + "exact_random_percentile_d4": _RANDOM_PERCENTILE_BENCHES, + "exact_random_percentile_d5": _RANDOM_PERCENTILE_BENCHES, "exact_near_singular_3x3": _EXTREME_BENCHES, "exact_large_entries_3x3": _EXTREME_BENCHES, "exact_hilbert_4x4": _EXTREME_BENCHES, @@ -210,8 +215,11 @@ def _group_by_group[T: _GroupedItem](items: list[T]) -> dict[str, list[T]]: def _group_heading(group: str) -> str: """Turn a Criterion group name into a readable heading.""" - # exact_d3 -> "D=3", exact_near_singular_3x3 -> "Near-singular 3x3", - # exact_hilbert_4x4 -> "Hilbert 4x4", etc. + # exact_d3 -> "D=3", exact_random_percentile_d3 -> + # "Random percentile D=3", exact_near_singular_3x3 -> + # "Near-singular 3x3", exact_hilbert_4x4 -> "Hilbert 4x4", etc. + if group.startswith("exact_random_percentile_d"): + return f"Random percentile D={group.removeprefix('exact_random_percentile_d')}" if group.startswith("exact_d"): return f"D={group.removeprefix('exact_d')}" if group == "exact_near_singular_3x3": diff --git a/scripts/tests/test_bench_compare.py b/scripts/tests/test_bench_compare.py index f4a9944..eb638b7 100644 --- a/scripts/tests/test_bench_compare.py +++ b/scripts/tests/test_bench_compare.py @@ -37,6 +37,9 @@ def _build_criterion_tree(criterion_dir: Path, stat: str = "median") -> None: ns_group = criterion_dir / "exact_near_singular_3x3" _write_estimates(ns_group / "det_sign_exact" / "new" / "estimates.json", stat, 12000.0) + random_group = criterion_dir / "exact_random_percentile_d3" + _write_estimates(random_group / "det_exact_p95" / "new" / "estimates.json", stat, 33000.0) + # --------------------------------------------------------------------------- # Unit tests for formatting helpers @@ -77,6 +80,9 @@ def test_dimension_group(self) -> None: def test_near_singular(self) -> None: assert bench_compare._group_heading("exact_near_singular_3x3") == "Near-singular 3x3" + def test_random_percentile(self) -> None: + assert bench_compare._group_heading("exact_random_percentile_d4") == "Random percentile D=4" + def test_large_entries(self) -> None: assert bench_compare._group_heading("exact_large_entries_3x3") == "Large entries 3x3" @@ -133,10 +139,11 @@ def test_read_estimate_missing_stat(tmp_path: Path) -> None: def test_collect_results(tmp_path: Path) -> None: _build_criterion_tree(tmp_path) results = bench_compare._collect_results(tmp_path, "new", "median") - assert len(results) == 5 # 2 benches x 2 dims + 1 near-singular + assert len(results) == 6 # 2 benches x 2 dims + 1 near-singular + 1 random percentile groups = {r.group for r in results} assert "exact_d2" in groups assert "exact_d3" in groups + assert "exact_random_percentile_d3" in groups assert "exact_near_singular_3x3" in groups @@ -191,6 +198,7 @@ def test_snapshot_tables_per_dimension(tmp_path: Path) -> None: tables = bench_compare._snapshot_tables(results, "median") assert "### D=2" in tables assert "### D=3" in tables + assert "### Random percentile D=3" in tables assert "### Near-singular 3x3" in tables assert "| Benchmark | Median | 95% CI |" in tables @@ -233,6 +241,7 @@ def test_main_snapshot_writes_output(tmp_path: Path) -> None: text = output.read_text(encoding="utf-8") assert "### D=2" in text + assert "### Random percentile D=3" in text assert "### Near-singular 3x3" in text assert "just bench-compare" in text diff --git a/semgrep.yaml b/semgrep.yaml index ad26f0c..26de90d 100644 --- a/semgrep.yaml +++ b/semgrep.yaml @@ -58,6 +58,42 @@ rules: - pattern: $VALUE.unwrap_or_else(|| f64::INFINITY) - pattern: $VALUE.unwrap_or_else(|| f64::NEG_INFINITY) + - id: la-stack.rust.no-public-infallible-raw-f64-constructors + languages: + - rust + severity: WARNING + message: "Raw f64 Matrix/Vector constructors must be fallible public APIs; keep infallible literal helpers crate-private." + metadata: + category: correctness + rationale: >- + Matrix and Vector store only finite values. Public raw constructors must + return Result so callers receive LaError::NonFinite instead of a panic; + infallible construction is reserved for crate-private validated/literal + paths. + paths: + include: + - "/src/**/*.rs" + - "/tests/semgrep/src/project_rules/raw_f64_constructors.rs" + pattern-regex: '(?m)^\s*pub\s+(?:const\s+)?fn\s+(?:new|from_rows)\s*\([^)]*(?:\[\s*f64\s*;\s*D\s*\]|\[\s*\[\s*f64\s*;\s*D\s*\]\s*;\s*D\s*\])[^)]*\)\s*->\s*(?:Self|(?:Matrix|Vector)\s*<)' + + - id: la-stack.rust.no-public-api-panic-paths + languages: + - regex + severity: WARNING + message: "Public APIs should expose fallibility with Result/Option instead of panic/assert/unwrap paths." + metadata: + category: correctness + rationale: >- + Public functions returning plain values should be genuinely infallible + for all representable inputs. Caller-visible failure belongs in + Result/Option; panic-only paths make recoverable conditions look + infallible. + paths: + include: + - "/src/**/*.rs" + - "/tests/semgrep/src/project_rules/public_api_panic_paths.rs" + pattern-regex: '(?ms)^\s*pub\s+(?:const\s+|async\s+|unsafe\s+)*fn\s+[A-Za-z_][A-Za-z0-9_]*[^;{]*\{(?:(?!^\s*\}).|\n){0,1000}(?:panic!|assert!|debug_assert!|unreachable!|\.unwrap\s*\(|\.expect\s*\()' + - id: la-stack.rust.public-error-enums-non-exhaustive languages: - rust diff --git a/src/exact.rs b/src/exact.rs index 8feaa85..a1f9d87 100644 --- a/src/exact.rs +++ b/src/exact.rs @@ -85,20 +85,22 @@ use crate::vector::{FiniteVector, Vector}; /// `(-1)^is_negative × mantissa × 2^exponent` and `mantissa` is odd (trailing /// zeros stripped). See `REFERENCES.md` \[9-10\]. /// -/// # Panics -/// Panics if `x` is NaN or infinite. -fn f64_decompose(x: f64) -> Option<(NonZeroU64, i32, bool)> { +/// # Errors +/// Returns [`LaError::NonFinite`] if `x` is NaN or infinite. +const fn f64_decompose(x: f64) -> Result, LaError> { let bits = x.to_bits(); let biased_exp = ((bits >> 52) & 0x7FF) as i32; let fraction = bits & 0x000F_FFFF_FFFF_FFFF; // ±0.0 if biased_exp == 0 && fraction == 0 { - return None; + return Ok(None); } - // NaN / Inf — callers must validate finiteness before reaching here. - assert!(biased_exp != 0x7FF, "non-finite f64 in exact conversion"); + if biased_exp == 0x7FF { + cold_path(); + return Err(LaError::non_finite_at(0)); + } let (mantissa, raw_exp) = if biased_exp == 0 { // Subnormal: (-1)^s × 0.fraction × 2^(-1022) @@ -113,11 +115,14 @@ fn f64_decompose(x: f64) -> Option<(NonZeroU64, i32, bool)> { // Strip trailing zeros so the mantissa is odd. let tz = mantissa.trailing_zeros(); let mantissa = mantissa >> tz; - let mantissa = NonZeroU64::new(mantissa)?; + let Some(mantissa) = NonZeroU64::new(mantissa) else { + cold_path(); + return Ok(None); + }; let exponent = raw_exp + tz.cast_signed(); let is_negative = bits >> 63 != 0; - Some((mantissa, exponent, is_negative)) + Ok(Some((mantissa, exponent, is_negative))) } /// Convert a `BigInt × 2^exp` pair to a reduced `BigRational`. @@ -134,15 +139,22 @@ fn bigint_exp_to_bigrational(mut value: BigInt, mut exp: i32) -> BigRational { if exp < 0 && let Some(tz) = value.trailing_zeros() { - let reduce = tz.min(u64::from((-exp).cast_unsigned())); + let exp_abs = exp.unsigned_abs(); + let reduce = tz.min(u64::from(exp_abs)); value >>= reduce; - exp += i32::try_from(reduce).expect("reduce ≤ -exp which fits in i32"); + let reduce = u32::try_from(reduce).unwrap_or(u32::MAX); + let remaining_abs = exp_abs - reduce; + exp = match remaining_abs { + 0 => 0, + 2_147_483_648 => i32::MIN, + value => -value.cast_signed(), + }; } if exp >= 0 { BigRational::new_raw(value << exp.cast_unsigned(), BigInt::from(1u32)) } else { - BigRational::new_raw(value, BigInt::from(1u32) << (-exp).cast_unsigned()) + BigRational::new_raw(value, BigInt::from(1u32) << exp.unsigned_abs()) } } @@ -161,32 +173,37 @@ fn bigint_exp_to_bigrational(mut value: BigInt, mut exp: i32) -> BigRational { /// Decomposed finite f64 in the form `(-1)^is_negative · mantissa · 2^exponent`. /// -/// Zero entries have `mantissa == None`; the other fields are unused in that -/// case. `Default` yields such a zero component, which is what the per-entry -/// initialiser in `decompose_finite_matrix` / `decompose_finite_vec` produces -/// for ±0.0 cells. Non-zero entries carry a [`NonZeroU64`] mantissa, so the -/// exact-arithmetic paths cannot accidentally store a raw zero sentinel after -/// decomposition. +/// `Zero` represents ±0.0. Non-zero entries carry a [`NonZeroU64`] mantissa, so +/// the exact-arithmetic paths cannot accidentally combine an absent mantissa +/// with active exponent/sign fields after decomposition. #[derive(Clone, Copy, Default)] -struct Component { - mantissa: Option, - exponent: i32, - is_negative: bool, +enum Component { + #[default] + Zero, + NonZero { + mantissa: NonZeroU64, + exponent: i32, + is_negative: bool, + }, } /// Decompose every entry of a finite `D×D` matrix via `f64_decompose`. /// /// Returns the per-entry components and the minimum exponent across non-zero /// entries. If every entry is zero, the exponent is `i32::MAX`. -fn decompose_finite_matrix(m: &FiniteMatrix) -> ([[Component; D]; D], i32) { +fn decompose_finite_matrix( + m: &FiniteMatrix, +) -> Result<([[Component; D]; D], i32), LaError> { let mut components = [[Component::default(); D]; D]; let mut e_min = i32::MAX; let matrix = m.as_matrix(); for (r, row) in matrix.rows.iter().enumerate() { for (c, &entry) in row.iter().enumerate() { - if let Some((mantissa, exponent, is_negative)) = f64_decompose(entry) { - components[r][c] = Component { - mantissa: Some(mantissa), + if let Some((mantissa, exponent, is_negative)) = + f64_decompose(entry).map_err(|_| LaError::non_finite_cell(r, c))? + { + components[r][c] = Component::NonZero { + mantissa, exponent, is_negative, }; @@ -194,43 +211,49 @@ fn decompose_finite_matrix(m: &FiniteMatrix) -> ([[Component; } } } - (components, e_min) + Ok((components, e_min)) } /// Decompose every entry of a finite length-`D` vector via `f64_decompose`. /// /// Returns the per-entry components and the minimum exponent across non-zero /// entries. If every entry is zero, the exponent is `i32::MAX`. -fn decompose_finite_vec(v: &FiniteVector) -> ([Component; D], i32) { +fn decompose_finite_vec( + v: &FiniteVector, +) -> Result<([Component; D], i32), LaError> { let mut components = [Component::default(); D]; let mut e_min = i32::MAX; let data = v.as_array(); for (i, &entry) in data.iter().enumerate() { - if let Some((mantissa, exponent, is_negative)) = f64_decompose(entry) { - components[i] = Component { - mantissa: Some(mantissa), + if let Some((mantissa, exponent, is_negative)) = + f64_decompose(entry).map_err(|_| LaError::non_finite_at(i))? + { + components[i] = Component::NonZero { + mantissa, exponent, is_negative, }; e_min = e_min.min(exponent); } } - (components, e_min) + Ok((components, e_min)) } /// Convert a single decomposed component to its scaled `BigInt` -/// representation: `(±mantissa) << (exp − e_min)`. Zero components map -/// to `BigInt::from(0)` through the `None` case; non-zero components reuse -/// their carried [`NonZeroU64`] proof without revalidating the mantissa. +/// representation: `(±mantissa) << (exp − e_min)`. #[inline] fn component_to_bigint(c: Component, e_min: i32) -> BigInt { - c.mantissa.map_or_else( - || BigInt::from(0), - |mantissa| { - let v = BigInt::from(mantissa.get()) << (c.exponent - e_min).cast_unsigned(); - if c.is_negative { -v } else { v } - }, - ) + match c { + Component::Zero => BigInt::from(0), + Component::NonZero { + mantissa, + exponent, + is_negative, + } => { + let v = BigInt::from(mantissa.get()) << (exponent - e_min).cast_unsigned(); + if is_negative { -v } else { v } + } + } } /// Build a `D×D` integer matrix from a component table, scaled to the @@ -341,18 +364,18 @@ fn bareiss_forward_eliminate( /// scaled to a common base `2^e_min` so every entry becomes an integer. /// The Bareiss inner-loop division is exact (guaranteed by the algorithm). /// -fn bareiss_det_int_finite(m: &FiniteMatrix) -> (BigInt, i32) { +fn bareiss_det_int_finite(m: &FiniteMatrix) -> Result<(BigInt, i32), LaError> { // D == 0 has no `a[D-1][D-1]` to read; shortcut to the empty-product // determinant. if D == 0 { - return (BigInt::from(1), 0); + return Ok((BigInt::from(1), 0)); } - let (components, e_min) = decompose_finite_matrix(m); + let (components, e_min) = decompose_finite_matrix(m)?; // All entries are zero → singular (det = 0). if e_min == i32::MAX { - return (BigInt::from(0), 0); + return Ok((BigInt::from(0), 0)); } let mut a = build_bigint_matrix(&components, e_min); @@ -360,7 +383,7 @@ fn bareiss_det_int_finite(m: &FiniteMatrix) -> (BigInt, i32) BareissResult::Upper { sign } => sign, BareissResult::Singular { .. } => { cold_path(); - return (BigInt::from(0), 0); + return Ok((BigInt::from(0), 0)); } }; @@ -371,19 +394,23 @@ fn bareiss_det_int_finite(m: &FiniteMatrix) -> (BigInt, i32) }; // det(original) = det_int × 2^(D × e_min) - let d_i32 = i32::try_from(D).expect("dimension exceeds i32"); - let total_exp = e_min - .checked_mul(d_i32) - .expect("exponent overflow in bareiss_det_int"); + let Ok(d_i32) = i32::try_from(D) else { + cold_path(); + return Err(LaError::unsupported_dimension(D, i32::MAX as usize)); + }; + let Some(total_exp) = e_min.checked_mul(d_i32) else { + cold_path(); + return Err(LaError::Overflow { index: None }); + }; - (det_int, total_exp) + Ok((det_int, total_exp)) } /// Compute the exact determinant of a `D×D` matrix using integer-only Bareiss /// elimination and return the result as a `BigRational`. -fn bareiss_det_finite(m: &FiniteMatrix) -> BigRational { - let (det_int, total_exp) = bareiss_det_int_finite(m); - bigint_exp_to_bigrational(det_int, total_exp) +fn bareiss_det_finite(m: &FiniteMatrix) -> Result { + let (det_int, total_exp) = bareiss_det_int_finite(m)?; + Ok(bigint_exp_to_bigrational(det_int, total_exp)) } /// Solve `A x = b` exactly after matrix and RHS finiteness has been proven. @@ -398,8 +425,8 @@ fn gauss_solve_finite( m: &FiniteMatrix, b: &FiniteVector, ) -> Result<[BigRational; D], LaError> { - let (m_components, m_e_min) = decompose_finite_matrix(m); - let (b_components, b_e_min) = decompose_finite_vec(b); + let (m_components, m_e_min) = decompose_finite_matrix(m)?; + let (b_components, b_e_min) = decompose_finite_vec(b)?; gauss_solve_components(m_components, m_e_min, b_components, b_e_min) } @@ -455,7 +482,7 @@ fn gauss_solve_components( impl FiniteMatrix { /// Exact determinant for an already finite matrix. #[inline] - fn det_exact(&self) -> BigRational { + fn det_exact(&self) -> Result { bareiss_det_finite(self) } @@ -466,7 +493,7 @@ impl FiniteMatrix { /// represented as a finite `f64`. #[inline] fn det_exact_f64(&self) -> Result { - let exact = self.det_exact(); + let exact = self.det_exact()?; let Some(val) = exact.to_f64() else { cold_path(); return Err(LaError::Overflow { index: None }); @@ -509,7 +536,7 @@ impl FiniteMatrix { } result[i] = f; } - Ok(FiniteVector::new_unchecked(Vector::new(result))) + Ok(FiniteVector::new_unchecked(Vector::new_unchecked(result))) } /// Exact determinant sign for an already finite matrix. @@ -539,7 +566,7 @@ impl FiniteMatrix { } cold_path(); - let (det_int, _) = bareiss_det_int_finite(self); + let (det_int, _) = bareiss_det_int_finite(self)?; Ok(match det_int.sign() { Sign::Plus => 1, Sign::Minus => -1, @@ -569,7 +596,7 @@ impl Matrix { /// use la_stack::prelude::*; /// /// # fn main() -> Result<(), LaError> { - /// let m = Matrix::<2>::from_rows([[1.0, 2.0], [3.0, 4.0]]); + /// let m = Matrix::<2>::try_from_rows([[1.0, 2.0], [3.0, 4.0]])?; /// let det = m.det_exact()?; /// // det = 1*4 - 2*3 = -2 (exact) /// assert_eq!(det, BigRational::from_integer((-2).into())); @@ -578,10 +605,14 @@ impl Matrix { /// ``` /// /// # Errors - /// Returns [`LaError::NonFinite`] if any matrix entry is NaN or infinite. + /// Returns [`LaError::Overflow`] if determinant scaling overflows the internal + /// exponent representation. + /// + /// Returns [`LaError::UnsupportedDimension`] if `D` cannot be represented in + /// the internal determinant exponent calculation. #[inline] pub fn det_exact(&self) -> Result { - Ok(FiniteMatrix::try_new(*self)?.det_exact()) + FiniteMatrix::new(*self).det_exact() } /// Exact determinant converted to `f64`. @@ -598,7 +629,7 @@ impl Matrix { /// use la_stack::prelude::*; /// /// # fn main() -> Result<(), LaError> { - /// let m = Matrix::<2>::from_rows([[1.0, 2.0], [3.0, 4.0]]); + /// let m = Matrix::<2>::try_from_rows([[1.0, 2.0], [3.0, 4.0]])?; /// let det = m.det_exact_f64()?; /// assert!((det - (-2.0)).abs() <= f64::EPSILON); /// # Ok(()) @@ -606,12 +637,12 @@ impl Matrix { /// ``` /// /// # Errors - /// Returns [`LaError::NonFinite`] if any matrix entry is NaN or infinite. - /// Returns [`LaError::Overflow`] if the exact determinant is too large to + /// Returns [`LaError::Overflow`] if determinant scaling overflows the internal + /// exponent representation or if the exact determinant is too large to /// represent as a finite `f64`. #[inline] pub fn det_exact_f64(&self) -> Result { - FiniteMatrix::try_new(*self)?.det_exact_f64() + FiniteMatrix::new(*self).det_exact_f64() } /// Exact linear system solve using hybrid integer/rational arithmetic. @@ -647,8 +678,8 @@ impl Matrix { /// /// # fn main() -> Result<(), LaError> { /// // A x = b where A = [[1,2],[3,4]], b = [5, 11] → x = [1, 2] - /// let a = Matrix::<2>::from_rows([[1.0, 2.0], [3.0, 4.0]]); - /// let b = Vector::<2>::new([5.0, 11.0]); + /// let a = Matrix::<2>::try_from_rows([[1.0, 2.0], [3.0, 4.0]])?; + /// let b = Vector::<2>::try_new([5.0, 11.0])?; /// let x = a.solve_exact(b)?; /// assert_eq!(x[0], BigRational::from_integer(1.into())); /// assert_eq!(x[1], BigRational::from_integer(2.into())); @@ -657,13 +688,11 @@ impl Matrix { /// ``` /// /// # Errors - /// Returns [`LaError::NonFinite`] if any matrix or vector entry is NaN or - /// infinite. /// Returns [`LaError::Singular`] if the matrix is exactly singular. #[inline] pub fn solve_exact(&self, b: Vector) -> Result<[BigRational; D], LaError> { - let finite_m = FiniteMatrix::try_new(*self)?; - let finite_b = FiniteVector::try_new(b)?; + let finite_m = FiniteMatrix::new(*self); + let finite_b = FiniteVector::new(b); finite_m.solve_exact(finite_b) } @@ -681,8 +710,8 @@ impl Matrix { /// use la_stack::prelude::*; /// /// # fn main() -> Result<(), LaError> { - /// let a = Matrix::<2>::from_rows([[1.0, 2.0], [3.0, 4.0]]); - /// let b = Vector::<2>::new([5.0, 11.0]); + /// let a = Matrix::<2>::try_from_rows([[1.0, 2.0], [3.0, 4.0]])?; + /// let b = Vector::<2>::try_new([5.0, 11.0])?; /// let x = a.solve_exact_f64(b)?.into_array(); /// assert!((x[0] - 1.0).abs() <= f64::EPSILON); /// assert!((x[1] - 2.0).abs() <= f64::EPSILON); @@ -691,15 +720,13 @@ impl Matrix { /// ``` /// /// # Errors - /// Returns [`LaError::NonFinite`] if any matrix or vector entry is NaN or - /// infinite. /// Returns [`LaError::Singular`] if the matrix is exactly singular. /// Returns [`LaError::Overflow`] if any component of the exact solution is /// too large to represent as a finite `f64`. #[inline] pub fn solve_exact_f64(&self, b: Vector) -> Result, LaError> { - let finite_m = FiniteMatrix::try_new(*self)?; - let finite_b = FiniteVector::try_new(b)?; + let finite_m = FiniteMatrix::new(*self); + let finite_b = FiniteVector::new(b); Ok(finite_m.solve_exact_f64(finite_b)?.into_vector()) } @@ -727,11 +754,11 @@ impl Matrix { /// ``` /// use la_stack::prelude::*; /// - /// let m = Matrix::<3>::from_rows([ + /// let m = Matrix::<3>::try_from_rows([ /// [1.0, 2.0, 3.0], /// [4.0, 5.0, 6.0], /// [7.0, 8.0, 9.0], - /// ]); + /// ])?; /// // This matrix is singular (row 3 = row 1 + row 2 in exact arithmetic). /// assert_eq!(m.det_sign_exact()?, 0); /// @@ -740,10 +767,11 @@ impl Matrix { /// ``` /// /// # Errors - /// Returns [`LaError::NonFinite`] if any matrix entry is NaN or infinite. + /// This exact sign path has no additional runtime errors for finite + /// matrices. #[inline] pub fn det_sign_exact(&self) -> Result { - FiniteMatrix::try_new(*self)?.det_sign_exact() + FiniteMatrix::new(*self).det_sign_exact() } } @@ -777,7 +805,9 @@ mod tests { /// # Panics /// Panics if `x` is NaN or infinite. fn f64_to_bigrational(x: f64) -> BigRational { - let Some((mantissa, exponent, is_negative)) = f64_decompose(x) else { + let Some((mantissa, exponent, is_negative)) = + f64_decompose(x).expect("test helper requires finite f64 input") + else { return BigRational::from_integer(BigInt::from(0)); }; @@ -803,10 +833,13 @@ mod tests { paste! { #[test] fn []() { - let a = FiniteMatrix::<$d>::try_new(Matrix::<$d>::identity()).unwrap(); - let b = FiniteVector::<$d>::try_new(Vector::<$d>::new([1.0; $d])).unwrap(); + let a = FiniteMatrix::<$d>::new(Matrix::<$d>::identity()); + let b = FiniteVector::<$d>::new(Vector::<$d>::new([1.0; $d])); - assert_eq!(a.det_exact(), BigRational::from_integer(BigInt::from(1))); + assert_eq!( + a.det_exact().unwrap(), + BigRational::from_integer(BigInt::from(1)) + ); assert!((a.det_exact_f64().unwrap() - 1.0).abs() <= f64::EPSILON); assert_eq!(a.det_sign_exact().unwrap(), 1); @@ -841,15 +874,13 @@ mod tests { #[test] fn []() { let mut m = Matrix::<$d>::identity(); - assert_eq!(m.set(0, 0, f64::NAN), Some(())); - assert_eq!(m.det_exact(), Err(LaError::NonFinite { row: Some(0), col: 0 })); + assert_eq!(m.set(0, 0, f64::NAN), Err(LaError::NonFinite { row: Some(0), col: 0 })); } #[test] fn []() { let mut m = Matrix::<$d>::identity(); - assert_eq!(m.set(0, 0, f64::INFINITY), Some(())); - assert_eq!(m.det_exact(), Err(LaError::NonFinite { row: Some(0), col: 0 })); + assert_eq!(m.set(0, 0, f64::INFINITY), Err(LaError::NonFinite { row: Some(0), col: 0 })); } } }; @@ -872,8 +903,7 @@ mod tests { #[test] fn []() { let mut m = Matrix::<$d>::identity(); - assert_eq!(m.set(0, 0, f64::NAN), Some(())); - assert_eq!(m.det_exact_f64(), Err(LaError::NonFinite { row: Some(0), col: 0 })); + assert_eq!(m.set(0, 0, f64::NAN), Err(LaError::NonFinite { row: Some(0), col: 0 })); } } }; @@ -1057,9 +1087,8 @@ mod tests { #[test] fn det_sign_exact_returns_err_on_nan() { - let m = Matrix::<2>::from_rows([[f64::NAN, 0.0], [0.0, 1.0]]); assert_eq!( - m.det_sign_exact(), + Matrix::<2>::try_from_rows([[f64::NAN, 0.0], [0.0, 1.0]]), Err(LaError::NonFinite { row: Some(0), col: 0 @@ -1069,9 +1098,8 @@ mod tests { #[test] fn det_sign_exact_returns_err_on_infinity() { - let m = Matrix::<2>::from_rows([[f64::INFINITY, 0.0], [0.0, 1.0]]); assert_eq!( - m.det_sign_exact(), + Matrix::<2>::try_from_rows([[f64::INFINITY, 0.0], [0.0, 1.0]]), Err(LaError::NonFinite { row: Some(0), col: 0 @@ -1083,9 +1111,8 @@ mod tests { fn det_sign_exact_returns_err_on_nan_5x5() { // D ≥ 5 bypasses the fast filter, exercising the bareiss_det path. let mut m = Matrix::<5>::identity(); - assert_eq!(m.set(2, 3, f64::NAN), Some(())); assert_eq!( - m.det_sign_exact(), + m.set(2, 3, f64::NAN), Err(LaError::NonFinite { row: Some(2), col: 3 @@ -1096,9 +1123,8 @@ mod tests { #[test] fn det_sign_exact_returns_err_on_infinity_5x5() { let mut m = Matrix::<5>::identity(); - assert_eq!(m.set(0, 0, f64::INFINITY), Some(())); assert_eq!( - m.det_sign_exact(), + m.set(0, 0, f64::INFINITY), Err(LaError::NonFinite { row: Some(0), col: 0 @@ -1206,13 +1232,13 @@ mod tests { #[test] fn f64_decompose_zero() { - assert!(f64_decompose(0.0).is_none()); - assert!(f64_decompose(-0.0).is_none()); + assert!(f64_decompose(0.0).unwrap().is_none()); + assert!(f64_decompose(-0.0).unwrap().is_none()); } #[test] fn f64_decompose_one() { - let (mant, exp, neg) = f64_decompose(1.0).unwrap(); + let (mant, exp, neg) = f64_decompose(1.0).unwrap().unwrap(); assert_eq!(mant.get(), 1); assert_eq!(exp, 0); assert!(!neg); @@ -1220,7 +1246,7 @@ mod tests { #[test] fn f64_decompose_negative() { - let (mant, exp, neg) = f64_decompose(-3.5).unwrap(); + let (mant, exp, neg) = f64_decompose(-3.5).unwrap().unwrap(); // -3.5 = -7 × 2^(-1), mantissa is 7 (odd after stripping) assert_eq!(mant.get(), 7); assert_eq!(exp, -1); @@ -1231,7 +1257,7 @@ mod tests { fn f64_decompose_subnormal() { let tiny = 5e-324_f64; assert!(tiny.is_subnormal()); - let (mant, exp, neg) = f64_decompose(tiny).unwrap(); + let (mant, exp, neg) = f64_decompose(tiny).unwrap().unwrap(); assert_eq!(mant.get(), 1); assert_eq!(exp, -1074); assert!(!neg); @@ -1239,16 +1265,18 @@ mod tests { #[test] fn f64_decompose_power_of_two() { - let (mant, exp, neg) = f64_decompose(1024.0).unwrap(); + let (mant, exp, neg) = f64_decompose(1024.0).unwrap().unwrap(); assert_eq!(mant.get(), 1); assert_eq!(exp, 10); // 1024 = 2^10 assert!(!neg); } #[test] - #[should_panic(expected = "non-finite f64 in exact conversion")] - fn f64_decompose_panics_on_nan() { - f64_decompose(f64::NAN); + fn f64_decompose_rejects_nan() { + assert_eq!( + f64_decompose(f64::NAN), + Err(LaError::NonFinite { row: None, col: 0 }) + ); } #[test] @@ -1258,15 +1286,15 @@ mod tests { BigInt::from(0) ); - let positive = Component { - mantissa: NonZeroU64::new(3), + let positive = Component::NonZero { + mantissa: NonZeroU64::new(3).unwrap(), exponent: 4, is_negative: false, }; assert_eq!(component_to_bigint(positive, 1), BigInt::from(24)); - let negative = Component { - mantissa: NonZeroU64::new(5), + let negative = Component::NonZero { + mantissa: NonZeroU64::new(5).unwrap(), exponent: 3, is_negative: true, }; @@ -1279,8 +1307,8 @@ mod tests { #[test] fn bareiss_det_int_d0() { - let m = FiniteMatrix::try_new(Matrix::<0>::zero()).unwrap(); - let (det, exp) = bareiss_det_int_finite(&m); + let m = FiniteMatrix::new(Matrix::<0>::zero()); + let (det, exp) = bareiss_det_int_finite(&m).unwrap(); assert_eq!(det, BigInt::from(1)); assert_eq!(exp, 0); } @@ -1300,8 +1328,8 @@ mod tests { (0.5, 1, -1), // 0.5 = 1 × 2^(-1) ]; for &(input, expected_det_int, expected_exp) in cases { - let m = FiniteMatrix::try_new(Matrix::<1>::from_rows([[input]])).unwrap(); - let (det, exp) = bareiss_det_int_finite(&m); + let m = FiniteMatrix::new(Matrix::<1>::from_rows([[input]])); + let (det, exp) = bareiss_det_int_finite(&m).unwrap(); assert_eq!( det, BigInt::from(expected_det_int), @@ -1314,8 +1342,8 @@ mod tests { #[test] fn bareiss_det_int_d2_known() { // det([[1,2],[3,4]]) = -2 - let m = FiniteMatrix::try_new(Matrix::<2>::from_rows([[1.0, 2.0], [3.0, 4.0]])).unwrap(); - let (det_int, total_exp) = bareiss_det_int_finite(&m); + let m = FiniteMatrix::new(Matrix::<2>::from_rows([[1.0, 2.0], [3.0, 4.0]])); + let (det_int, total_exp) = bareiss_det_int_finite(&m).unwrap(); // Reconstruct and verify. let det = bigint_exp_to_bigrational(det_int, total_exp); assert_eq!(det, BigRational::from_integer(BigInt::from(-2))); @@ -1323,21 +1351,20 @@ mod tests { #[test] fn bareiss_det_int_all_zeros() { - let m = FiniteMatrix::try_new(Matrix::<3>::zero()).unwrap(); - let (det, _) = bareiss_det_int_finite(&m); + let m = FiniteMatrix::new(Matrix::<3>::zero()); + let (det, _) = bareiss_det_int_finite(&m).unwrap(); assert_eq!(det, BigInt::from(0)); } #[test] fn bareiss_det_int_sign_matches_det_sign_exact() { // The sign of det_int should match det_sign_exact for various matrices. - let m = FiniteMatrix::try_new(Matrix::<3>::from_rows([ + let m = FiniteMatrix::new(Matrix::<3>::from_rows([ [0.0, 1.0, 0.0], [1.0, 0.0, 0.0], [0.0, 0.0, 1.0], - ])) - .unwrap(); - let (det_int, _) = bareiss_det_int_finite(&m); + ])); + let (det_int, _) = bareiss_det_int_finite(&m).unwrap(); assert_eq!(det_int.sign(), Sign::Minus); // det = -1 } @@ -1345,8 +1372,8 @@ mod tests { fn bareiss_det_int_fractional_entries() { // Entries with negative exponents: 0.5 = 1×2^(-1), 0.25 = 1×2^(-2). // det([[0.5, 0.25], [1.0, 1.0]]) = 0.5×1.0 − 0.25×1.0 = 0.25 - let m = FiniteMatrix::try_new(Matrix::<2>::from_rows([[0.5, 0.25], [1.0, 1.0]])).unwrap(); - let (det_int, total_exp) = bareiss_det_int_finite(&m); + let m = FiniteMatrix::new(Matrix::<2>::from_rows([[0.5, 0.25], [1.0, 1.0]])); + let (det_int, total_exp) = bareiss_det_int_finite(&m).unwrap(); let det = bigint_exp_to_bigrational(det_int, total_exp); assert_eq!(det, BigRational::new(BigInt::from(1), BigInt::from(4))); } @@ -1354,13 +1381,12 @@ mod tests { #[test] fn bareiss_det_int_d3_with_pivoting() { // Zero on diagonal → exercises pivot swap inside bareiss_det_int. - let m = FiniteMatrix::try_new(Matrix::<3>::from_rows([ + let m = FiniteMatrix::new(Matrix::<3>::from_rows([ [0.0, 1.0, 0.0], [1.0, 0.0, 0.0], [0.0, 0.0, 1.0], - ])) - .unwrap(); - let (det_int, total_exp) = bareiss_det_int_finite(&m); + ])); + let (det_int, total_exp) = bareiss_det_int_finite(&m).unwrap(); let det = bigint_exp_to_bigrational(det_int, total_exp); assert_eq!(det, BigRational::from_integer(BigInt::from(-1))); } @@ -1370,9 +1396,8 @@ mod tests { #[test] fn bareiss_det_int_errs_on_nan() { let mut m = Matrix::<3>::identity(); - assert_eq!(m.set(1, 2, f64::NAN), Some(())); assert_eq!( - m.det_exact(), + m.set(1, 2, f64::NAN), Err(LaError::NonFinite { row: Some(1), col: 2 @@ -1383,9 +1408,8 @@ mod tests { #[test] fn bareiss_det_int_errs_on_inf() { let mut m = Matrix::<2>::identity(); - assert_eq!(m.set(0, 0, f64::INFINITY), Some(())); assert_eq!( - m.det_exact(), + m.set(0, 0, f64::INFINITY), Err(LaError::NonFinite { row: Some(0), col: 0 @@ -1399,8 +1423,8 @@ mod tests { paste! { #[test] fn []() { - let m = FiniteMatrix::try_new(Matrix::<$d>::identity()).unwrap(); - let (det_int, total_exp) = bareiss_det_int_finite(&m); + let m = FiniteMatrix::new(Matrix::<$d>::identity()); + let (det_int, total_exp) = bareiss_det_int_finite(&m).unwrap(); let det = bigint_exp_to_bigrational(det_int, total_exp); assert_eq!(det, BigRational::from_integer(BigInt::from(1))); } @@ -1438,6 +1462,13 @@ mod tests { assert_eq!(*r.denom(), BigInt::from(2)); } + #[test] + fn bigint_exp_to_bigrational_negative_exp_reduces_to_integer() { + // 8 × 2^(-3) = 1 after stripping every denominator factor. + let r = bigint_exp_to_bigrational(BigInt::from(8), -3); + assert_eq!(r, BigRational::from_integer(BigInt::from(1))); + } + #[test] fn bigint_exp_to_bigrational_negative_exp_already_odd() { // 3 × 2^(-2) = 3/4 (already in lowest terms since 3 is odd) @@ -1611,35 +1642,27 @@ mod tests { #[test] fn []() { let mut a = Matrix::<$d>::identity(); - assert_eq!(a.set(0, 0, f64::NAN), Some(())); - let b = arbitrary_rhs::<$d>(); - assert_eq!(a.solve_exact(b), Err(LaError::NonFinite { row: Some(0), col: 0 })); + assert_eq!(a.set(0, 0, f64::NAN), Err(LaError::NonFinite { row: Some(0), col: 0 })); } #[test] fn []() { let mut a = Matrix::<$d>::identity(); - assert_eq!(a.set(0, 0, f64::INFINITY), Some(())); - let b = arbitrary_rhs::<$d>(); - assert_eq!(a.solve_exact(b), Err(LaError::NonFinite { row: Some(0), col: 0 })); + assert_eq!(a.set(0, 0, f64::INFINITY), Err(LaError::NonFinite { row: Some(0), col: 0 })); } #[test] fn []() { - let a = Matrix::<$d>::identity(); let mut b_arr = [1.0f64; $d]; b_arr[0] = f64::NAN; - let b = Vector::<$d>::new(b_arr); - assert_eq!(a.solve_exact(b), Err(LaError::NonFinite { row: None, col: 0 })); + assert_eq!(Vector::<$d>::try_new(b_arr), Err(LaError::NonFinite { row: None, col: 0 })); } #[test] fn []() { - let a = Matrix::<$d>::identity(); let mut b_arr = [1.0f64; $d]; b_arr[$d - 1] = f64::INFINITY; - let b = Vector::<$d>::new(b_arr); - assert_eq!(a.solve_exact(b), Err(LaError::NonFinite { row: None, col: $d - 1 })); + assert_eq!(Vector::<$d>::try_new(b_arr), Err(LaError::NonFinite { row: None, col: $d - 1 })); } #[test] @@ -1674,9 +1697,7 @@ mod tests { #[test] fn []() { let mut a = Matrix::<$d>::identity(); - assert_eq!(a.set(0, 0, f64::NAN), Some(())); - let b = arbitrary_rhs::<$d>(); - assert_eq!(a.solve_exact_f64(b), Err(LaError::NonFinite { row: Some(0), col: 0 })); + assert_eq!(a.set(0, 0, f64::NAN), Err(LaError::NonFinite { row: Some(0), col: 0 })); } } }; @@ -2450,20 +2471,12 @@ mod tests { } #[test] - #[should_panic(expected = "non-finite f64 in exact conversion")] - fn f64_to_bigrational_panics_on_nan() { - f64_to_bigrational(f64::NAN); - } - - #[test] - #[should_panic(expected = "non-finite f64 in exact conversion")] - fn f64_to_bigrational_panics_on_inf() { - f64_to_bigrational(f64::INFINITY); - } - - #[test] - #[should_panic(expected = "non-finite f64 in exact conversion")] - fn f64_to_bigrational_panics_on_neg_inf() { - f64_to_bigrational(f64::NEG_INFINITY); + fn f64_decompose_rejects_nonfinite_inputs() { + for value in [f64::NAN, f64::INFINITY, f64::NEG_INFINITY] { + assert_eq!( + f64_decompose(value), + Err(LaError::NonFinite { row: None, col: 0 }) + ); + } } } diff --git a/src/ldlt.rs b/src/ldlt.rs index e8c73e1..b815591 100644 --- a/src/ldlt.rs +++ b/src/ldlt.rs @@ -36,8 +36,45 @@ use crate::{LaError, Tolerance}; #[must_use] #[derive(Clone, Copy, Debug, PartialEq)] pub struct Ldlt { - factors: Matrix, - tol: Tolerance, + factors: LdltFactors, +} + +/// In-place LDLT factor storage whose diagonal entries are finite and usable. +/// +/// Construction through [`Ldlt::factor_symmetric`] proves every stored entry is +/// finite and every diagonal satisfies the factorization tolerance. +#[derive(Clone, Copy, Debug, PartialEq)] +struct LdltFactors { + storage: Matrix, +} + +impl LdltFactors { + /// Construct factors after LDLT factorization has proven the storage invariant. + #[inline] + const fn new_unchecked(storage: Matrix) -> Self { + Self { storage } + } + + /// Return a copied factor row. + #[inline] + #[must_use] + const fn row(&self, index: usize) -> [f64; D] { + self.storage.rows[index] + } + + /// Return a factor entry. + #[inline] + #[must_use] + const fn entry(&self, row: usize, col: usize) -> f64 { + self.storage.rows[row][col] + } + + /// Return a diagonal entry of `D`. + #[inline] + #[must_use] + const fn diag(&self, index: usize) -> f64 { + self.storage.rows[index][index] + } } impl Ldlt { @@ -89,7 +126,9 @@ impl Ldlt { } } - Ok(Self { factors: f, tol }) + Ok(Self { + factors: LdltFactors::new_unchecked(f), + }) } /// Determinant of the original matrix. @@ -102,7 +141,7 @@ impl Ldlt { /// /// # fn main() -> Result<(), LaError> { /// // Symmetric SPD matrix. - /// let a = Matrix::<2>::from_rows([[4.0, 2.0], [2.0, 3.0]]); + /// let a = Matrix::<2>::try_from_rows([[4.0, 2.0], [2.0, 3.0]])?; /// let ldlt = a.ldlt(DEFAULT_SINGULAR_TOL)?; /// /// assert!((ldlt.det()? - 8.0).abs() <= 1e-12); @@ -118,7 +157,7 @@ impl Ldlt { let mut det = 1.0; let mut i = 0; while i < D { - det *= self.factors.rows[i][i]; + det *= self.factors.diag(i); if !det.is_finite() { cold_path(); return Err(LaError::non_finite_at(i)); @@ -135,10 +174,10 @@ impl Ldlt { /// use la_stack::prelude::*; /// /// # fn main() -> Result<(), LaError> { - /// let a = Matrix::<2>::from_rows([[4.0, 2.0], [2.0, 3.0]]); + /// let a = Matrix::<2>::try_from_rows([[4.0, 2.0], [2.0, 3.0]])?; /// let ldlt = a.ldlt(DEFAULT_SINGULAR_TOL)?; /// - /// let b = Vector::<2>::new([1.0, 2.0]); + /// let b = Vector::<2>::try_new([1.0, 2.0])?; /// let x = ldlt.solve_vec(b)?.into_array(); /// /// assert!((x[0] - (-0.125)).abs() <= 1e-12); @@ -148,51 +187,26 @@ impl Ldlt { /// ``` /// /// # Errors - /// Returns [`LaError::NotPositiveSemidefinite`] if a diagonal entry - /// `d = D[i,i]` is negative. - /// - /// Returns [`LaError::Singular`] if a diagonal entry `d = D[i,i]` satisfies - /// `0 <= d <= tol`, where `tol` is the tolerance that was used during factorization. - /// - /// Returns [`LaError::NonFinite`] if NaN/∞ is detected. If - /// `FiniteVector::try_new(b)` rejects a non-finite RHS entry, the error uses - /// `row: None` and `col: i`, where `i` is the offending vector index. The - /// `row`/`col` coordinates follow the convention documented on - /// [`LaError::NonFinite`]: - /// - /// - `row: Some(i), col: i` — the stored `D` diagonal at `(i, i)` is non-finite - /// (only reachable via direct `Ldlt` construction; [`Matrix::ldlt`](crate::Matrix::ldlt) - /// rejects such factorizations). - /// - `row: None, col: i` — either RHS entry `b[i]` is non-finite, or a - /// computed intermediate (forward/back-substitution accumulator or the - /// quotient `x[i] / diag`) overflowed to NaN/∞ at step `i`. + /// Returns [`LaError::NonFinite`] if a computed substitution intermediate + /// overflows to NaN or infinity. Raw non-finite right-hand sides are rejected + /// by [`Vector::try_new`](crate::Vector::try_new) before a [`Vector`] can be + /// passed to this method. #[inline] pub const fn solve_vec(&self, b: Vector) -> Result, LaError> { - match FiniteVector::try_new(b) { - Ok(finite) => match self.solve_finite_vec(finite) { - Ok(x) => Ok(x.into_vector()), - Err(err) => Err(err), - }, + match self.solve_finite_vec(FiniteVector::new(b)) { + Ok(x) => Ok(x.into_vector()), Err(err) => Err(err), } } /// Solve `A x = b` using this LDLT factorization and a finite right-hand side. /// - /// The right-hand side entries are known finite, so this path only checks - /// factorization defensive invariants and computed substitution overflows. + /// The right-hand side entries and stored factors are known finite, so this + /// path only checks computed substitution overflows. /// /// # Errors - /// Returns [`LaError::NotPositiveSemidefinite`] if a diagonal entry - /// `d = D[i,i]` is negative. - /// - /// Returns [`LaError::Singular`] if a diagonal entry `d = D[i,i]` satisfies - /// `0 <= d <= tol`, where `tol` is the tolerance that was used during - /// factorization. - /// - /// Returns [`LaError::NonFinite`] if a stored factorization diagonal is - /// corrupt or if a computed substitution intermediate overflows to NaN or - /// infinity. + /// Returns [`LaError::NonFinite`] if a computed substitution intermediate + /// overflows to NaN or infinity. #[inline] pub(crate) const fn solve_finite_vec( &self, @@ -204,7 +218,7 @@ impl Ldlt { let mut i = 0; while i < D { let mut sum = x[i]; - let row = self.factors.rows[i]; + let row = self.factors.row(i); let mut j = 0; while j < i { sum = (-row[j]).mul_add(x[j], sum); @@ -221,23 +235,7 @@ impl Ldlt { // Diagonal solve: D z = y. let mut i = 0; while i < D { - let diag = self.factors.rows[i][i]; - // A corrupt stored diagonal is a specific matrix cell (i, i), - // distinct from a computed overflow — report it with - // `row: Some(i)` per the `LaError::NonFinite` convention used by - // `Matrix::det`, `Lu::factor`, and `Ldlt::factor`. - if !diag.is_finite() { - cold_path(); - return Err(LaError::non_finite_cell(i, i)); - } - if diag < 0.0 { - cold_path(); - return Err(LaError::not_positive_semidefinite(i, diag)); - } - if diag <= self.tol.get() { - cold_path(); - return Err(LaError::Singular { pivot_col: i }); - } + let diag = self.factors.diag(i); let quotient = x[i] / diag; if !quotient.is_finite() { @@ -255,7 +253,7 @@ impl Ldlt { let mut sum = x[i]; let mut j = i + 1; while j < D { - sum = (-self.factors.rows[j][i]).mul_add(x[j], sum); + sum = (-self.factors.entry(j, i)).mul_add(x[j], sum); j += 1; } if !sum.is_finite() { @@ -266,7 +264,7 @@ impl Ldlt { ii += 1; } - Ok(FiniteVector::new_unchecked(Vector::new(x))) + Ok(FiniteVector::new_unchecked(Vector::new_unchecked(x))) } } @@ -376,10 +374,7 @@ mod tests { #[test] fn solve_2x2_known_spd() { let a = Matrix::<2>::from_rows(black_box([[4.0, 2.0], [2.0, 3.0]])); - let ldlt = FiniteMatrix::try_new(a) - .unwrap() - .ldlt(DEFAULT_SINGULAR_TOL) - .unwrap(); + let ldlt = FiniteMatrix::new(a).ldlt(DEFAULT_SINGULAR_TOL).unwrap(); let b = Vector::<2>::new(black_box([1.0, 2.0])); let x = ldlt.solve_vec(b).unwrap().into_array(); @@ -442,9 +437,8 @@ mod tests { } #[test] - fn nonfinite_detected() { - let a = Matrix::<2>::from_rows([[f64::NAN, 0.0], [0.0, 1.0]]); - let err = a.ldlt(DEFAULT_SINGULAR_TOL).unwrap_err(); + fn matrix_constructor_rejects_nonfinite_diagonal() { + let err = Matrix::<2>::try_from_rows([[f64::NAN, 0.0], [0.0, 1.0]]).unwrap_err(); assert_eq!( err, LaError::NonFinite { @@ -455,9 +449,8 @@ mod tests { } #[test] - fn nonfinite_offdiagonal_detected_before_asymmetry() { - let a = Matrix::<2>::from_rows([[1.0, f64::NAN], [0.0, 1.0]]); - let err = a.ldlt(DEFAULT_SINGULAR_TOL).unwrap_err(); + fn matrix_constructor_rejects_nonfinite_offdiagonal_before_asymmetry() { + let err = Matrix::<2>::try_from_rows([[1.0, f64::NAN], [0.0, 1.0]]).unwrap_err(); assert_eq!( err, LaError::NonFinite { @@ -587,109 +580,32 @@ mod tests { ); } - // ----------------------------------------------------------------------- - // Defensive-path coverage for `solve_vec`. - // - // `Ldlt::factor` guarantees that every stored diagonal is finite and - // strictly greater than the recorded `tol`. `solve_vec` still re-checks - // both invariants as a safety net (see the `!diag.is_finite()` and - // `diag <= self.tol` guards in the diagonal solve). Those branches are - // unreachable through the public API, so the only way to exercise them - // is to construct `Ldlt` directly with corrupt factors. The tests below - // document and verify that the safety nets return the documented error - // variants. - // ----------------------------------------------------------------------- - - macro_rules! gen_solve_vec_defensive_tests { + macro_rules! gen_solve_vec_boundary_tests { ($d:literal) => { paste! { - /// `solve_vec` must surface `NonFinite` when a stored - /// diagonal is NaN, even though `factor` cannot produce - /// such a factorization. The error must pinpoint the - /// corrupt cell at `(D-1, D-1)` per the - /// [`LaError::NonFinite`] convention. - #[test] - fn []() { - let mut factors = Matrix::<$d>::identity(); - factors.rows[$d - 1][$d - 1] = f64::NAN; - let ldlt = Ldlt::<$d> { - factors, - tol: DEFAULT_SINGULAR_TOL, - }; - let b = Vector::<$d>::new([1.0; $d]); - let err = ldlt.solve_vec(b).unwrap_err(); - assert_eq!( - err, - LaError::NonFinite { - row: Some($d - 1), - col: $d - 1, - } - ); - } - - /// `solve_vec` must surface `NotPositiveSemidefinite` when a - /// stored diagonal is negative, even though `ldlt` cannot - /// produce such a factorization. + /// Raw non-finite right-hand sides are rejected before a + /// `Vector` can be passed into `solve_vec`. #[test] - fn []() { - let mut factors = Matrix::<$d>::identity(); - factors.rows[$d - 1][$d - 1] = -1.0; - let ldlt = Ldlt::<$d> { - factors, - tol: DEFAULT_SINGULAR_TOL, - }; - let b = Vector::<$d>::new([1.0; $d]); - let err = ldlt.solve_vec(b).unwrap_err(); - assert_eq!( - err, - LaError::NotPositiveSemidefinite { - pivot_col: $d - 1, - value: -1.0, - } - ); - } - - /// `solve_vec` must surface `Singular` when a stored - /// non-negative diagonal is at or below the recorded tolerance, - /// even though `ldlt` cannot produce such a factorization. - #[test] - fn []() { - let mut factors = Matrix::<$d>::identity(); - factors.rows[$d - 1][$d - 1] = 0.0; - let ldlt = Ldlt::<$d> { - factors, - tol: DEFAULT_SINGULAR_TOL, - }; - let b = Vector::<$d>::new([1.0; $d]); - let err = ldlt.solve_vec(b).unwrap_err(); - assert_eq!(err, LaError::Singular { pivot_col: $d - 1 }); - } - - /// `solve_vec` rejects raw non-finite right-hand sides before - /// entering the finite-RHS solve path. - #[test] - fn []() { - let ldlt = Matrix::<$d>::identity().ldlt(DEFAULT_SINGULAR_TOL).unwrap(); + fn []() { let mut rhs = [1.0; $d]; rhs[$d - 1] = f64::NAN; - let err = ldlt.solve_vec(Vector::<$d>::new(rhs)).unwrap_err(); assert_eq!( - err, - LaError::NonFinite { + Vector::<$d>::try_new(rhs), + Err(LaError::NonFinite { row: None, col: $d - 1, - } + }) ); } } }; } - gen_solve_vec_defensive_tests!(2); - gen_solve_vec_defensive_tests!(3); - gen_solve_vec_defensive_tests!(4); - gen_solve_vec_defensive_tests!(5); + gen_solve_vec_boundary_tests!(2); + gen_solve_vec_boundary_tests!(3); + gen_solve_vec_boundary_tests!(4); + gen_solve_vec_boundary_tests!(5); // ----------------------------------------------------------------------- // Const-evaluability tests. @@ -714,8 +630,7 @@ mod tests { let mut factors = Matrix::<$d>::identity(); factors.rows[0][0] = 2.0; let ldlt = Ldlt::<$d> { - factors, - tol: DEFAULT_SINGULAR_TOL, + factors: LdltFactors::new_unchecked(factors), }; ldlt.det() }; @@ -729,11 +644,10 @@ mod tests { #[test] fn []() { #[allow(clippy::cast_precision_loss)] - const X: [f64; $d] = { - let ldlt = Ldlt::<$d> { - factors: Matrix::<$d>::identity(), - tol: DEFAULT_SINGULAR_TOL, - }; + const X: [f64; $d] = { + let ldlt = Ldlt::<$d> { + factors: LdltFactors::new_unchecked(Matrix::<$d>::identity()), + }; let mut b_arr = [0.0f64; $d]; let mut i = 0; while i < $d { diff --git a/src/lib.rs b/src/lib.rs index 9d3c71a..4621a2e 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -10,15 +10,15 @@ mod readme_doctests { /// /// # fn main() -> Result<(), LaError> { /// // This system requires pivoting (a[0][0] = 0), so it's a good LU demo. - /// let a = Matrix::<5>::from_rows([ + /// let a = Matrix::<5>::try_from_rows([ /// [0.0, 1.0, 1.0, 1.0, 1.0], /// [1.0, 0.0, 1.0, 1.0, 1.0], /// [1.0, 1.0, 0.0, 1.0, 1.0], /// [1.0, 1.0, 1.0, 0.0, 1.0], /// [1.0, 1.0, 1.0, 1.0, 0.0], - /// ]); + /// ])?; /// - /// let b = Vector::<5>::new([14.0, 13.0, 12.0, 11.0, 10.0]); + /// let b = Vector::<5>::try_new([14.0, 13.0, 12.0, 11.0, 10.0])?; /// /// let lu = a.lu(DEFAULT_PIVOT_TOL)?; /// let x = lu.solve_vec(b)?.into_array(); @@ -38,13 +38,13 @@ mod readme_doctests { /// /// # fn main() -> Result<(), LaError> { /// // This matrix is symmetric positive-definite (A = L*L^T) so LDLT works without pivoting. - /// let a = Matrix::<5>::from_rows([ + /// let a = Matrix::<5>::try_from_rows([ /// [1.0, 1.0, 0.0, 0.0, 0.0], /// [1.0, 2.0, 1.0, 0.0, 0.0], /// [0.0, 1.0, 2.0, 1.0, 0.0], /// [0.0, 0.0, 1.0, 2.0, 1.0], /// [0.0, 0.0, 0.0, 1.0, 2.0], - /// ]); + /// ])?; /// /// let det = a.ldlt(DEFAULT_SINGULAR_TOL)?.det()?; /// assert!((det - 1.0).abs() <= 1e-12); @@ -135,7 +135,7 @@ const EPS: f64 = f64::EPSILON; // 2^-52 /// use la_stack::prelude::*; /// /// # fn main() -> Result<(), LaError> { -/// let m = Matrix::<2>::from_rows([[1.0, 2.0], [3.0, 4.0]]); +/// let m = Matrix::<2>::try_from_rows([[1.0, 2.0], [3.0, 4.0]])?; /// let Some(det) = m.det_direct()? else { /// return Ok(()); /// }; diff --git a/src/lu.rs b/src/lu.rs index 4dc82d3..86fa617 100644 --- a/src/lu.rs +++ b/src/lu.rs @@ -10,10 +10,40 @@ use crate::{LaError, Tolerance}; #[must_use] #[derive(Clone, Copy, Debug, PartialEq)] pub struct Lu { - factors: Matrix, + factors: LuFactors, piv: [usize; D], piv_sign: f64, - tol: Tolerance, +} + +/// In-place LU factor storage whose `U` diagonal is finite and usable. +/// +/// Construction through [`Lu::factor_finite`] proves every stored entry is +/// finite and every `U[i,i]` satisfies the factorization tolerance. +#[derive(Clone, Copy, Debug, PartialEq)] +struct LuFactors { + storage: Matrix, +} + +impl LuFactors { + /// Construct factors after LU factorization has proven the storage invariant. + #[inline] + const fn new_unchecked(storage: Matrix) -> Self { + Self { storage } + } + + /// Return a copied factor row. + #[inline] + #[must_use] + const fn row(&self, index: usize) -> [f64; D] { + self.storage.rows[index] + } + + /// Return a diagonal entry of `U`. + #[inline] + #[must_use] + const fn diag(&self, index: usize) -> f64 { + self.storage.rows[index][index] + } } impl Lu { @@ -83,10 +113,9 @@ impl Lu { } Ok(Self { - factors: lu, + factors: LuFactors::new_unchecked(lu), piv, piv_sign, - tol, }) } @@ -97,10 +126,10 @@ impl Lu { /// use la_stack::prelude::*; /// /// # fn main() -> Result<(), LaError> { - /// let a = Matrix::<2>::from_rows([[1.0, 2.0], [3.0, 4.0]]); + /// let a = Matrix::<2>::try_from_rows([[1.0, 2.0], [3.0, 4.0]])?; /// let lu = a.lu(DEFAULT_PIVOT_TOL)?; /// - /// let b = Vector::<2>::new([5.0, 11.0]); + /// let b = Vector::<2>::try_new([5.0, 11.0])?; /// let x = lu.solve_vec(b)?.into_array(); /// /// assert!((x[0] - 1.0).abs() <= 1e-12); @@ -110,45 +139,26 @@ impl Lu { /// ``` /// /// # Errors - /// Returns [`LaError::Singular`] if a diagonal entry of `U` satisfies `|u_ii| <= tol`, where - /// `tol` is the tolerance that was used during factorization. - /// - /// Returns [`LaError::NonFinite`] if NaN/∞ is detected. If - /// `FiniteVector::try_new(b)` rejects a non-finite RHS entry, the error uses - /// `row: None` and `col: i`, where `i` is the offending vector index. The - /// `row`/`col` coordinates follow the convention documented on - /// [`LaError::NonFinite`]: - /// - /// - `row: Some(i), col: i` — the stored `U` diagonal at `(i, i)` is non-finite - /// (only reachable via direct `Lu` construction; [`Matrix::lu`](crate::Matrix::lu) - /// rejects such factorizations). - /// - `row: None, col: i` — either RHS entry `b[i]` is non-finite, or a - /// computed intermediate (forward/back-substitution accumulator or the - /// quotient `sum / diag`) overflowed to NaN/∞ at step `i`. + /// Returns [`LaError::NonFinite`] if a computed substitution intermediate + /// overflows to NaN or infinity. Raw non-finite right-hand sides are rejected + /// by [`Vector::try_new`](crate::Vector::try_new) before a [`Vector`] can be + /// passed to this method. #[inline] pub const fn solve_vec(&self, b: Vector) -> Result, LaError> { - match FiniteVector::try_new(b) { - Ok(finite) => match self.solve_finite_vec(finite) { - Ok(x) => Ok(x.into_vector()), - Err(err) => Err(err), - }, + match self.solve_finite_vec(FiniteVector::new(b)) { + Ok(x) => Ok(x.into_vector()), Err(err) => Err(err), } } /// Solve `A x = b` using this LU factorization and a finite right-hand side. /// - /// The right-hand side entries are known finite, so this path only checks - /// factorization defensive invariants and computed substitution overflows. + /// The right-hand side entries and stored factors are known finite, so this + /// path only checks computed substitution overflows. /// /// # Errors - /// Returns [`LaError::Singular`] if a diagonal entry of `U` satisfies - /// `|u_ii| <= tol`, where `tol` is the tolerance that was used during - /// factorization. - /// - /// Returns [`LaError::NonFinite`] if a stored factorization diagonal is - /// corrupt or if a computed substitution intermediate overflows to NaN or - /// infinity. + /// Returns [`LaError::NonFinite`] if a computed substitution intermediate + /// overflows to NaN or infinity. #[inline] pub(crate) const fn solve_finite_vec( &self, @@ -166,7 +176,7 @@ impl Lu { let mut i = 0; while i < D { let mut sum = x[i]; - let row = self.factors.rows[i]; + let row = self.factors.row(i); let mut j = 0; while j < i { sum = (-row[j]).mul_add(x[j], sum); @@ -185,7 +195,7 @@ impl Lu { while ii < D { let i = D - 1 - ii; let mut sum = x[i]; - let row = self.factors.rows[i]; + let row = self.factors.row(i); let mut j = i + 1; while j < D { sum = (-row[j]).mul_add(x[j], sum); @@ -193,21 +203,10 @@ impl Lu { } let diag = row[i]; - // Distinguish a corrupt stored pivot (row: Some(i), col: i) from - // a computed intermediate overflow (row: None, col: i) so callers - // can diagnose the failure source without inspecting internals. - if !diag.is_finite() { - cold_path(); - return Err(LaError::non_finite_cell(i, i)); - } if !sum.is_finite() { cold_path(); return Err(LaError::non_finite_at(i)); } - if diag.abs() <= self.tol.get() { - cold_path(); - return Err(LaError::Singular { pivot_col: i }); - } let quotient = sum / diag; if !quotient.is_finite() { @@ -218,7 +217,7 @@ impl Lu { ii += 1; } - Ok(FiniteVector::new_unchecked(Vector::new(x))) + Ok(FiniteVector::new_unchecked(Vector::new_unchecked(x))) } /// Determinant of the original matrix. @@ -228,7 +227,7 @@ impl Lu { /// use la_stack::prelude::*; /// /// # fn main() -> Result<(), LaError> { - /// let a = Matrix::<2>::from_rows([[1.0, 2.0], [3.0, 4.0]]); + /// let a = Matrix::<2>::try_from_rows([[1.0, 2.0], [3.0, 4.0]])?; /// let lu = a.lu(DEFAULT_PIVOT_TOL)?; /// /// let det = lu.det()?; @@ -245,7 +244,7 @@ impl Lu { let mut det = self.piv_sign; let mut i = 0; while i < D { - det *= self.factors.rows[i][i]; + det *= self.factors.diag(i); if !det.is_finite() { cold_path(); return Err(LaError::non_finite_at(i)); @@ -422,10 +421,7 @@ mod tests { #[test] fn solve_1x1() { let a = Matrix::<1>::from_rows(black_box([[2.0]])); - let lu = FiniteMatrix::try_new(a) - .unwrap() - .lu(DEFAULT_PIVOT_TOL) - .unwrap(); + let lu = FiniteMatrix::new(a).lu(DEFAULT_PIVOT_TOL).unwrap(); let b = Vector::<1>::new(black_box([6.0])); let solve_fn: fn(&Lu<1>, Vector<1>) -> Result, LaError> = @@ -440,10 +436,7 @@ mod tests { #[test] fn solve_2x2_basic() { let a = Matrix::<2>::from_rows(black_box([[1.0, 2.0], [3.0, 4.0]])); - let lu = FiniteMatrix::try_new(a) - .unwrap() - .lu(DEFAULT_PIVOT_TOL) - .unwrap(); + let lu = FiniteMatrix::new(a).lu(DEFAULT_PIVOT_TOL).unwrap(); let b = Vector::<2>::new(black_box([5.0, 11.0])); let solve_fn: fn(&Lu<2>, Vector<2>) -> Result, LaError> = @@ -522,9 +515,8 @@ mod tests { } #[test] - fn nonfinite_detected_on_pivot_entry() { - let a = Matrix::<2>::from_rows([[f64::NAN, 0.0], [0.0, 1.0]]); - let err = a.lu(DEFAULT_PIVOT_TOL).unwrap_err(); + fn matrix_constructor_rejects_nonfinite_pivot_entry() { + let err = Matrix::<2>::try_from_rows([[f64::NAN, 0.0], [0.0, 1.0]]).unwrap_err(); assert_eq!( err, LaError::NonFinite { @@ -535,9 +527,8 @@ mod tests { } #[test] - fn nonfinite_detected_in_pivot_column_scan() { - let a = Matrix::<2>::from_rows([[1.0, 0.0], [f64::INFINITY, 1.0]]); - let err = a.lu(DEFAULT_PIVOT_TOL).unwrap_err(); + fn matrix_constructor_rejects_nonfinite_pivot_column_entry() { + let err = Matrix::<2>::try_from_rows([[1.0, 0.0], [f64::INFINITY, 1.0]]).unwrap_err(); assert_eq!( err, LaError::NonFinite { @@ -613,103 +604,32 @@ mod tests { assert_eq!(lu.det(), Err(LaError::NonFinite { row: None, col: 3 })); } - // ----------------------------------------------------------------------- - // Defensive-path coverage for `solve_vec`. - // - // `Lu::factor` guarantees that every stored U diagonal is finite and - // satisfies `|U[i,i]| > tol`. `solve_vec` still re-checks during - // back-substitution as a safety net (see the `!diag.is_finite()` and - // `diag.abs() <= self.tol` guards). Those branches are unreachable - // through the public API, so the only way to exercise them is to - // construct `Lu` directly with a corrupt U. The tests below document - // and verify that the safety nets return the documented error variants - // with coordinates that locate the offending stored cell. - // ----------------------------------------------------------------------- - - macro_rules! gen_solve_vec_defensive_tests { + macro_rules! gen_solve_vec_boundary_tests { ($d:literal) => { paste! { - /// `solve_vec` must surface `Singular` when a stored U - /// diagonal is at or below the recorded tolerance, even - /// though `factor` cannot produce such a factorization. - #[test] - fn []() { - let mut factors = Matrix::<$d>::identity(); - factors.rows[$d - 1][$d - 1] = 0.0; - - let mut piv = [0usize; $d]; - for (i, p) in piv.iter_mut().enumerate() { - *p = i; - } - - let lu = Lu::<$d> { - factors, - piv, - piv_sign: 1.0, - tol: DEFAULT_PIVOT_TOL, - }; - let b = Vector::<$d>::new([0.0; $d]); - let err = lu.solve_vec(b).unwrap_err(); - assert_eq!(err, LaError::Singular { pivot_col: $d - 1 }); - } - - /// `solve_vec` must surface `NonFinite` with the corrupt - /// cell's coordinates when a stored U diagonal is NaN, - /// even though `factor` cannot produce such a - /// factorization. The error must pinpoint `(D-1, D-1)` - /// per the [`LaError::NonFinite`] convention. - #[test] - fn []() { - let mut factors = Matrix::<$d>::identity(); - factors.rows[$d - 1][$d - 1] = f64::NAN; - - let mut piv = [0usize; $d]; - for (i, p) in piv.iter_mut().enumerate() { - *p = i; - } - - let lu = Lu::<$d> { - factors, - piv, - piv_sign: 1.0, - tol: DEFAULT_PIVOT_TOL, - }; - let b = Vector::<$d>::new([1.0; $d]); - let err = lu.solve_vec(b).unwrap_err(); - assert_eq!( - err, - LaError::NonFinite { - row: Some($d - 1), - col: $d - 1, - } - ); - } - - /// `solve_vec` rejects raw non-finite right-hand sides before - /// entering the finite-RHS solve path. + /// Raw non-finite right-hand sides are rejected before a + /// `Vector` can be passed into `solve_vec`. #[test] - fn []() { - let lu = Matrix::<$d>::identity().lu(DEFAULT_PIVOT_TOL).unwrap(); + fn []() { let mut rhs = [1.0; $d]; rhs[$d - 1] = f64::NAN; - let err = lu.solve_vec(Vector::<$d>::new(rhs)).unwrap_err(); assert_eq!( - err, - LaError::NonFinite { + Vector::<$d>::try_new(rhs), + Err(LaError::NonFinite { row: None, col: $d - 1, - } + }) ); } } }; } - gen_solve_vec_defensive_tests!(2); - gen_solve_vec_defensive_tests!(3); - gen_solve_vec_defensive_tests!(4); - gen_solve_vec_defensive_tests!(5); + gen_solve_vec_boundary_tests!(2); + gen_solve_vec_boundary_tests!(3); + gen_solve_vec_boundary_tests!(4); + gen_solve_vec_boundary_tests!(5); // ----------------------------------------------------------------------- // Const-evaluability tests. @@ -728,10 +648,9 @@ mod tests { factors.rows[0][0] = 2.0; factors.rows[1][1] = 3.0; let lu = Lu::<2> { - factors, + factors: LuFactors::new_unchecked(factors), piv: [0, 1], piv_sign: 1.0, - tol: DEFAULT_PIVOT_TOL, }; lu.det() }; @@ -744,10 +663,9 @@ mod tests { // Identity factors but `piv_sign = -1.0` encoding a single row swap; // the determinant magnitude is 1 but the sign flips. let lu = Lu::<3> { - factors: Matrix::<3>::identity(), + factors: LuFactors::new_unchecked(Matrix::<3>::identity()), piv: [1, 0, 2], piv_sign: -1.0, - tol: DEFAULT_PIVOT_TOL, }; lu.det() }; @@ -759,10 +677,9 @@ mod tests { // Identity LU ⇒ solve_vec returns the permuted RHS untouched. const X: [f64; 2] = { let lu = Lu::<2> { - factors: Matrix::<2>::identity(), + factors: LuFactors::new_unchecked(Matrix::<2>::identity()), piv: [0, 1], piv_sign: 1.0, - tol: DEFAULT_PIVOT_TOL, }; let b = Vector::<2>::new([1.0, 2.0]); match lu.solve_vec(b) { diff --git a/src/matrix.rs b/src/matrix.rs index c9245c9..bb2bd22 100644 --- a/src/matrix.rs +++ b/src/matrix.rs @@ -15,9 +15,9 @@ pub struct Matrix { /// Fixed-size square matrix whose stored entries are all finite. /// -/// This proof-carrying wrapper lets callers validate raw [`Matrix`] values once -/// at an input boundary, then pass the finite invariant into numerical -/// algorithms without repeatedly checking stored entries for NaN or infinity. +/// This proof-carrying wrapper makes the finite invariant explicit for internal +/// algorithms that should not repeatedly check stored entries for NaN or +/// infinity. #[must_use] #[derive(Clone, Copy, Debug, PartialEq)] #[allow(clippy::redundant_pub_crate)] @@ -28,25 +28,19 @@ pub(crate) struct FiniteMatrix { impl FiniteMatrix { /// Construct a finite matrix without checking the invariant. /// - /// This is crate-internal so public callers must use [`try_new`](Self::try_new) - /// or [`from_rows`](Self::from_rows), which preserve diagnostics for rejected - /// raw entries. + /// This is crate-internal so raw storage still goes through + /// [`Matrix::try_from_rows`], which preserves diagnostics for rejected + /// entries. #[inline] pub(crate) const fn new_unchecked(matrix: Matrix) -> Self { Self { matrix } } - /// Validate that every stored matrix entry is finite. - /// # Errors - /// Returns [`LaError::NonFinite`] with matrix coordinates for the first - /// offending entry in row-major order when `matrix` contains NaN or infinity. + /// Wrap an already-finite matrix for algorithms that carry the invariant + /// explicitly. #[inline] - pub const fn try_new(matrix: Matrix) -> Result { - if let Some((row, col)) = matrix.first_non_finite_cell() { - Err(LaError::non_finite_cell(row, col)) - } else { - Ok(Self::new_unchecked(matrix)) - } + pub const fn new(matrix: Matrix) -> Self { + Self::new_unchecked(matrix) } /// Validate raw row-major storage and construct a finite matrix. @@ -55,7 +49,10 @@ impl FiniteMatrix { /// offending entry in row-major order when `rows` contains NaN or infinity. #[inline] pub const fn from_rows(rows: [[f64; D]; D]) -> Result { - Self::try_new(Matrix::from_rows(rows)) + match Matrix::try_from_rows(rows) { + Ok(matrix) => Ok(Self::new_unchecked(matrix)), + Err(err) => Err(err), + } } /// All-zeros finite matrix. @@ -351,7 +348,11 @@ impl TryFrom> for FiniteMatrix { #[inline] fn try_from(value: Matrix) -> Result { - Self::try_new(value) + if let Some((row, col)) = Matrix::::first_non_finite_cell_in(&value.rows) { + Err(LaError::non_finite_cell(row, col)) + } else { + Ok(Self::new(value)) + } } } @@ -413,17 +414,51 @@ impl SymmetricMatrix { } impl Matrix { - /// Construct from row-major storage. + /// Test-only infallible constructor for finite literal fixtures. + #[cfg(test)] + #[inline] + pub(crate) const fn from_rows(rows: [[f64; D]; D]) -> Self { + match Self::try_from_rows(rows) { + Ok(matrix) => matrix, + Err(_) => panic!("Matrix::from_rows requires finite entries"), + } + } + + /// Try to create a finite matrix from row-major storage. + /// + /// This is the public raw-storage boundary for matrices. Once construction + /// succeeds, methods such as [`lu`](Self::lu), [`ldlt`](Self::ldlt), and + /// [`det`](Self::det) can rely on finite stored entries. /// /// # Examples /// ``` /// use la_stack::prelude::*; /// - /// let m = Matrix::<2>::from_rows([[1.0, 2.0], [3.0, 4.0]]); + /// # fn main() -> Result<(), LaError> { + /// let m = Matrix::<2>::try_from_rows([[1.0, 2.0], [3.0, 4.0]])?; /// assert_eq!(m.get(0, 1), Some(2.0)); + /// # Ok(()) + /// # } /// ``` + /// + /// # Errors + /// Returns [`LaError::NonFinite`] with matrix coordinates for the first + /// offending entry in row-major order when `rows` contains NaN or infinity. #[inline] - pub const fn from_rows(rows: [[f64; D]; D]) -> Self { + pub const fn try_from_rows(rows: [[f64; D]; D]) -> Result { + if let Some((row, col)) = Self::first_non_finite_cell_in(&rows) { + Err(LaError::non_finite_cell(row, col)) + } else { + Ok(Self::from_rows_unchecked(rows)) + } + } + + /// Construct a matrix without checking that entries are finite. + /// + /// This crate-internal escape hatch is reserved for literals and algorithm + /// outputs whose finite invariant is visible at the call site. + #[inline] + pub(crate) const fn from_rows_unchecked(rows: [[f64; D]; D]) -> Self { Self { rows } } @@ -438,9 +473,7 @@ impl Matrix { /// ``` #[inline] pub const fn zero() -> Self { - Self { - rows: [[0.0; D]; D], - } + Self::from_rows_unchecked([[0.0; D]; D]) } /// Identity matrix. @@ -473,9 +506,12 @@ impl Matrix { /// ``` /// use la_stack::prelude::*; /// - /// let m = Matrix::<2>::from_rows([[1.0, 2.0], [3.0, 4.0]]); + /// # fn main() -> Result<(), LaError> { + /// let m = Matrix::<2>::try_from_rows([[1.0, 2.0], [3.0, 4.0]])?; /// assert_eq!(m.get(1, 0), Some(3.0)); /// assert_eq!(m.get(2, 0), None); + /// # Ok(()) + /// # } /// ``` #[inline] #[must_use] @@ -498,7 +534,7 @@ impl Matrix { /// use la_stack::prelude::*; /// /// # fn main() -> Result<(), LaError> { - /// let m = Matrix::<2>::from_rows([[1.0, 2.0], [3.0, 4.0]]); + /// let m = Matrix::<2>::try_from_rows([[1.0, 2.0], [3.0, 4.0]])?; /// assert_eq!(m.get_checked(1, 0)?, 3.0); /// assert_eq!( /// m.get_checked(2, 0), @@ -523,36 +559,40 @@ impl Matrix { } } - /// Set an element with bounds checking. - /// - /// Returns `Some(())` if the index was in bounds, or `None` otherwise. + /// Set a finite element with bounds checking. /// /// # Examples /// ``` /// use la_stack::prelude::*; /// + /// # fn main() -> Result<(), LaError> { /// let mut m = Matrix::<2>::zero(); - /// assert_eq!(m.set(0, 1, 2.5), Some(())); + /// assert_eq!(m.set(0, 1, 2.5), Ok(())); /// assert_eq!(m.get(0, 1), Some(2.5)); - /// assert_eq!(m.set(10, 0, 1.0), None); + /// assert_eq!( + /// m.set(10, 0, 1.0), + /// Err(LaError::IndexOutOfBounds { + /// row: 10, + /// col: 0, + /// dim: 2, + /// }) + /// ); + /// # Ok(()) + /// # } /// ``` + /// + /// # Errors + /// Returns [`LaError::IndexOutOfBounds`] when either index is not `< D`. + /// Returns [`LaError::NonFinite`] when `value` is NaN or infinity. #[inline] - #[must_use] - pub const fn set(&mut self, r: usize, c: usize, value: f64) -> Option<()> { - if r < D && c < D { - self.rows[r][c] = value; - Some(()) - } else { - None - } + pub const fn set(&mut self, row: usize, col: usize, value: f64) -> Result<(), LaError> { + self.set_checked(row, col, value) } /// Set an element, preserving index context on failure. /// - /// The matrix is mutated only when `(row, col)` is in bounds. Prefer - /// [`set`](Self::set) for const or hot paths that only need `Option`-style absence; - /// use this method at public runtime boundaries where failed mutation - /// should return a typed, contextual error. + /// The matrix is mutated only when `(row, col)` is in bounds and `value` is + /// finite. /// /// # Examples /// ``` @@ -577,38 +617,40 @@ impl Matrix { /// /// # Errors /// Returns [`LaError::IndexOutOfBounds`] when either index is not `< D`. + /// Returns [`LaError::NonFinite`] when `value` is NaN or infinity. #[inline] pub const fn set_checked(&mut self, row: usize, col: usize, value: f64) -> Result<(), LaError> { - if row < D && col < D { - self.rows[row][col] = value; - Ok(()) - } else { - Err(LaError::index_out_of_bounds(row, col, D)) + if row >= D || col >= D { + return Err(LaError::index_out_of_bounds(row, col, D)); } + if !value.is_finite() { + return Err(LaError::non_finite_cell(row, col)); + } + self.rows[row][col] = value; + Ok(()) } /// Infinity norm (maximum absolute row sum). /// /// # Non-finite handling - /// Non-finite entries are rejected with source coordinates instead of - /// silently propagating NaN or infinity through the norm. + /// Raw non-finite entries are rejected by [`try_from_rows`](Self::try_from_rows) + /// and [`set`](Self::set) before they can be stored. /// /// Row sums are accumulated in `f64` with ordinary addition. This method - /// checks for non-finite inputs and overflowed accumulators, but it does not - /// provide a certified absolute rounding bound for the returned norm. + /// checks for overflowed accumulators, but it does not provide a certified + /// absolute rounding bound for the returned norm. /// /// # Examples /// ``` /// use la_stack::prelude::*; /// /// # fn main() -> Result<(), LaError> { - /// let m = Matrix::<2>::from_rows([[1.0, -2.0], [3.0, 4.0]]); + /// let m = Matrix::<2>::try_from_rows([[1.0, -2.0], [3.0, 4.0]])?; /// assert!((m.inf_norm()? - 7.0).abs() <= 1e-12); /// - /// // NaN entries are rejected with coordinates. - /// let nan = Matrix::<2>::from_rows([[f64::NAN, 1.0], [2.0, 3.0]]); + /// // Raw NaN entries are rejected with coordinates. /// assert_eq!( - /// nan.inf_norm(), + /// Matrix::<2>::try_from_rows([[f64::NAN, 1.0], [2.0, 3.0]]), /// Err(LaError::NonFinite { /// row: Some(0), /// col: 0, @@ -619,14 +661,11 @@ impl Matrix { /// ``` /// /// # Errors - /// Returns [`LaError::NonFinite`] when any entry is NaN or infinity, or when - /// a row sum overflows to NaN or infinity. + /// Returns [`LaError::NonFinite`] when a row sum overflows to NaN or + /// infinity. #[inline] pub const fn inf_norm(&self) -> Result { - match FiniteMatrix::try_new(*self) { - Ok(finite) => finite.inf_norm(), - Err(err) => Err(err), - } + FiniteMatrix::new(*self).inf_norm() } /// Returns `true` if the matrix is symmetric within a relative tolerance. @@ -646,35 +685,33 @@ impl Matrix { /// raw `f64`; negative, NaN, and infinite tolerances return /// [`LaError::InvalidTolerance`]. /// - /// # NaN / infinity handling - /// Stored NaN or ±∞ entries return [`LaError::NonFinite`] with the - /// offending matrix coordinates. A finite matrix can still return - /// [`LaError::NonFinite`] if computing the scaled symmetry tolerance - /// overflows to NaN or infinity. If both stored entries are finite but - /// their difference overflows to ±∞, the pair is reported as asymmetric. + /// # Overflow handling + /// A finite matrix can return [`LaError::NonFinite`] if computing the scaled + /// symmetry tolerance overflows to NaN or infinity. If both stored entries + /// are finite but their difference overflows to ±∞, the pair is reported as + /// asymmetric. /// /// # Examples /// ``` /// use la_stack::prelude::*; /// /// # fn main() -> Result<(), LaError> { - /// let a = Matrix::<2>::from_rows([[4.0, 2.0], [2.0, 3.0]]); + /// let a = Matrix::<2>::try_from_rows([[4.0, 2.0], [2.0, 3.0]])?; /// let tol = Tolerance::new(1e-12)?; /// assert!(a.is_symmetric(tol)?); /// - /// let b = Matrix::<2>::from_rows([[4.0, 2.0], [3.0, 3.0]]); + /// let b = Matrix::<2>::try_from_rows([[4.0, 2.0], [3.0, 3.0]])?; /// assert!(!b.is_symmetric(tol)?); /// # Ok(()) /// # } /// ``` /// /// # Errors - /// Returns [`LaError::NonFinite`] when any matrix entry is NaN or infinite, - /// or when computing the scaled symmetry tolerance overflows to NaN or - /// infinity. + /// Returns [`LaError::NonFinite`] when computing the scaled symmetry + /// tolerance overflows to NaN or infinity. #[inline] pub fn is_symmetric(&self, rel_tol: Tolerance) -> Result { - FiniteMatrix::try_new(*self)?.is_symmetric(rel_tol) + FiniteMatrix::new(*self).is_symmetric(rel_tol) } /// Returns the indices `(r, c)` (with `r < c`) of the first off-diagonal @@ -686,11 +723,10 @@ impl Matrix { /// predicate is the same as [`is_symmetric`](Self::is_symmetric): /// `|self[r][c] - self[c][r]| <= rel_tol * max(1.0, inf_norm(self))`. /// - /// Stored NaN or ±∞ entries return [`LaError::NonFinite`] with the - /// offending matrix coordinates. A finite matrix can still return - /// [`LaError::NonFinite`] if computing the scaled symmetry tolerance - /// overflows to NaN or infinity. If both stored entries are finite but - /// their difference overflows to ±∞, the pair is reported as asymmetric. + /// A finite matrix can return [`LaError::NonFinite`] if computing the scaled + /// symmetry tolerance overflows to NaN or infinity. If both stored entries + /// are finite but their difference overflows to ±∞, the pair is reported as + /// asymmetric. /// /// The `rel_tol` argument is a [`Tolerance`], so raw caller input must be /// finite and non-negative before it can reach this predicate. Use @@ -703,11 +739,11 @@ impl Matrix { /// use la_stack::prelude::*; /// /// # fn main() -> Result<(), LaError> { - /// let a = Matrix::<3>::from_rows([ + /// let a = Matrix::<3>::try_from_rows([ /// [1.0, 2.0, 0.0], /// [2.0, 4.0, 5.0], /// [0.0, 6.0, 9.0], // 6.0 breaks symmetry with a[1][2] = 5.0 - /// ]); + /// ])?; /// let tol = Tolerance::new(1e-12)?; /// assert_eq!(a.first_asymmetry(tol)?, Some((1, 2))); /// assert_eq!(Matrix::<3>::identity().first_asymmetry(tol)?, None); @@ -716,12 +752,11 @@ impl Matrix { /// ``` /// /// # Errors - /// Returns [`LaError::NonFinite`] when any matrix entry is NaN or infinite, - /// or when computing the scaled symmetry tolerance overflows to NaN or - /// infinity. + /// Returns [`LaError::NonFinite`] when computing the scaled symmetry + /// tolerance overflows to NaN or infinity. #[inline] pub fn first_asymmetry(&self, rel_tol: Tolerance) -> Result, LaError> { - FiniteMatrix::try_new(*self)?.first_asymmetry(rel_tol) + FiniteMatrix::new(*self).first_asymmetry(rel_tol) } /// Compute an LU decomposition with partial pivoting. @@ -731,10 +766,10 @@ impl Matrix { /// use la_stack::prelude::*; /// /// # fn main() -> Result<(), LaError> { - /// let a = Matrix::<2>::from_rows([[1.0, 2.0], [3.0, 4.0]]); + /// let a = Matrix::<2>::try_from_rows([[1.0, 2.0], [3.0, 4.0]])?; /// let lu = a.lu(DEFAULT_PIVOT_TOL)?; /// - /// let b = Vector::<2>::new([5.0, 11.0]); + /// let b = Vector::<2>::try_new([5.0, 11.0])?; /// let x = lu.solve_vec(b)?.into_array(); /// /// assert!((x[0] - 1.0).abs() <= 1e-12); @@ -752,12 +787,11 @@ impl Matrix { /// # Errors /// Returns [`LaError::Singular`] if, for some column `k`, the largest-magnitude candidate pivot /// in that column satisfies `|pivot| <= tol` (so no numerically usable pivot exists). - /// Returns [`LaError::NonFinite`] if a stored matrix entry is NaN/∞, or if - /// an elimination intermediate overflows to NaN/∞ before it can be stored in - /// the returned [`Lu`]. + /// Returns [`LaError::NonFinite`] if an elimination intermediate overflows + /// to NaN/∞ before it can be stored in the returned [`Lu`]. #[inline] pub fn lu(self, tol: Tolerance) -> Result, LaError> { - FiniteMatrix::try_new(self)?.lu(tol) + FiniteMatrix::new(self).lu(tol) } /// Compute an LDLT factorization (`A = L D Lᵀ`) without pivoting. @@ -786,14 +820,14 @@ impl Matrix { /// /// # fn main() -> Result<(), LaError> { /// // Note the symmetric layout: a[0][1] == a[1][0] == 2.0. - /// let a = Matrix::<2>::from_rows([[4.0, 2.0], [2.0, 3.0]]); + /// let a = Matrix::<2>::try_from_rows([[4.0, 2.0], [2.0, 3.0]])?; /// let ldlt = a.ldlt(DEFAULT_SINGULAR_TOL)?; /// /// // det(A) = 8 /// assert!((ldlt.det()? - 8.0).abs() <= 1e-12); /// /// // Solve A x = b - /// let b = Vector::<2>::new([1.0, 2.0]); + /// let b = Vector::<2>::try_new([1.0, 2.0])?; /// let x = ldlt.solve_vec(b)?.into_array(); /// assert!((x[0] - (-0.125)).abs() <= 1e-12); /// assert!((x[1] - 0.75).abs() <= 1e-12); @@ -806,20 +840,21 @@ impl Matrix { /// diagonal entry `d = D[k,k]` is negative. /// Returns [`LaError::Singular`] if `0 <= d <= tol`, treating PSD degeneracy /// as singular/degenerate. - /// Returns [`LaError::NonFinite`] if NaN/∞ is detected during factorization. + /// Returns [`LaError::NonFinite`] if factorization computes a non-finite + /// intermediate. /// Returns [`LaError::Asymmetric`] if the input matrix is not symmetric. #[inline] pub fn ldlt(self, tol: Tolerance) -> Result, LaError> { - FiniteMatrix::try_new(self)?.ldlt(tol) + FiniteMatrix::new(self).ldlt(tol) } /// Return the first non-finite stored cell in row-major order. - const fn first_non_finite_cell(&self) -> Option<(usize, usize)> { + const fn first_non_finite_cell_in(rows: &[[f64; D]; D]) -> Option<(usize, usize)> { let mut r = 0; while r < D { let mut c = 0; while c < D { - if !self.rows[r][c].is_finite() { + if !rows[r][c].is_finite() { return Some((r, c)); } c += 1; @@ -844,7 +879,7 @@ impl Matrix { /// use la_stack::prelude::*; /// /// # fn main() -> Result<(), LaError> { - /// let m = Matrix::<2>::from_rows([[1.0, 2.0], [3.0, 4.0]]); + /// let m = Matrix::<2>::try_from_rows([[1.0, 2.0], [3.0, 4.0]])?; /// assert_eq!(m.det_direct()?, Some(-2.0)); /// /// // D = 0 is the empty product. @@ -857,14 +892,11 @@ impl Matrix { /// ``` /// /// # Errors - /// Returns [`LaError::NonFinite`] when any entry is NaN or infinity, or when - /// the closed-form determinant overflows to NaN or infinity. + /// Returns [`LaError::NonFinite`] when the closed-form determinant overflows + /// to NaN or infinity. #[inline] pub const fn det_direct(&self) -> Result, LaError> { - match FiniteMatrix::try_new(*self) { - Ok(finite) => finite.det_direct(), - Err(err) => Err(err), - } + FiniteMatrix::new(*self).det_direct() } /// Floating-point determinant, using closed-form formulas for D ≤ 4 and @@ -895,12 +927,12 @@ impl Matrix { /// ``` /// /// # Errors - /// Returns [`LaError::NonFinite`] if an entry is non-finite, if the LU - /// fallback computes a non-finite factorization cell, or if the determinant - /// product overflows to NaN or infinity. + /// Returns [`LaError::NonFinite`] if the LU fallback computes a non-finite + /// factorization cell, or if the determinant product overflows to NaN or + /// infinity. #[inline] pub fn det(self) -> Result { - FiniteMatrix::try_new(self)?.det() + FiniteMatrix::new(self).det() } /// Conservative absolute error bound for `det_direct()`. @@ -928,11 +960,11 @@ impl Matrix { /// use la_stack::prelude::*; /// /// # fn main() -> Result<(), LaError> { - /// let m = Matrix::<3>::from_rows([ + /// let m = Matrix::<3>::try_from_rows([ /// [1.0, 2.0, 3.0], /// [4.0, 5.0, 6.0], /// [7.0, 8.0, 9.0], - /// ]); + /// ])?; /// if let (Some(bound), Some(det_approx)) = (m.det_errbound()?, m.det_direct()?) { /// // If |det_approx| > bound, the sign is guaranteed correct. /// let sign_is_certified = det_approx.abs() > bound; @@ -964,14 +996,11 @@ impl Matrix { /// ``` /// /// # Errors - /// Returns [`LaError::NonFinite`] when any entry is NaN or infinity, or when - /// the bound computation overflows to NaN or infinity. + /// Returns [`LaError::NonFinite`] when the bound computation overflows to + /// NaN or infinity. #[inline] pub const fn det_errbound(&self) -> Result, LaError> { - match FiniteMatrix::try_new(*self) { - Ok(finite) => finite.det_errbound(), - Err(err) => Err(err), - } + FiniteMatrix::new(*self).det_errbound() } } @@ -1040,7 +1069,14 @@ mod tests { // Out-of-bounds set fails. let before_failed_set = m; - assert_eq!(m.set($d, 0, 3.0), None); + assert_eq!( + m.set($d, 0, 3.0), + Err(LaError::IndexOutOfBounds { + row: $d, + col: 0, + dim: $d, + }) + ); assert_eq!(m, before_failed_set); assert_eq!( m.set_checked($d, 0, 3.0), @@ -1063,7 +1099,7 @@ mod tests { assert_eq!(m.get(0, 0), Some(1.0)); // In-bounds set works. - assert_eq!(m.set(0, $d - 1, 3.0), Some(())); + assert_eq!(m.set(0, $d - 1, 3.0), Ok(())); assert_eq!(m.get(0, $d - 1), Some(3.0)); assert_eq!(m.set_checked($d - 1, 0, 4.0), Ok(())); assert_eq!(m.get_checked($d - 1, 0), Ok(4.0)); @@ -1181,6 +1217,21 @@ mod tests { ); } + #[test] + fn []() { + let mut rows = [[1.0f64; $d]; $d]; + rows[$d - 1][$d - 1] = f64::INFINITY; + let raw = Matrix::<$d>::from_rows_unchecked(rows); + + assert_eq!( + FiniteMatrix::<$d>::try_from(raw), + Err(LaError::NonFinite { + row: Some($d - 1), + col: $d - 1, + }) + ); + } + #[test] fn []() { let mut rows = [[0.0f64; $d]; $d]; @@ -1208,7 +1259,7 @@ mod tests { #[test] fn []() { - let finite = FiniteMatrix::<$d>::try_new(Matrix::<$d>::identity()).unwrap(); + let finite = FiniteMatrix::<$d>::new(Matrix::<$d>::identity()); let rhs = { let mut arr = [0.0f64; $d]; let values = [1.0f64, 2.0, 3.0, 4.0, 5.0]; @@ -1309,9 +1360,8 @@ mod tests { #[test] fn det_direct_d5_rejects_nonfinite_before_returning_none() { let mut m = Matrix::<5>::identity(); - assert_eq!(m.set(3, 4, f64::NAN), Some(())); assert_eq!( - m.det_direct(), + m.set(3, 4, f64::NAN), Err(LaError::NonFinite { row: Some(3), col: 4, @@ -1326,9 +1376,8 @@ mod tests { #[test] fn det_direct_rejects_nonfinite_entry_with_coordinates() { - let m = Matrix::<3>::from_rows([[1.0, 0.0, 0.0], [0.0, f64::NAN, 0.0], [0.0, 0.0, 1.0]]); assert_eq!( - m.det_direct(), + Matrix::<3>::try_from_rows([[1.0, 0.0, 0.0], [0.0, f64::NAN, 0.0], [0.0, 0.0, 1.0]]), Err(LaError::NonFinite { row: Some(1), col: 1, @@ -1493,9 +1542,8 @@ mod tests { #[test] fn det_returns_nonfinite_error_for_nan_d2() { - let m = Matrix::<2>::from_rows([[f64::NAN, 1.0], [1.0, 1.0]]); assert_eq!( - m.det(), + Matrix::<2>::try_from_rows([[f64::NAN, 1.0], [1.0, 1.0]]), Err(LaError::NonFinite { row: Some(0), col: 0 @@ -1505,10 +1553,12 @@ mod tests { #[test] fn det_returns_nonfinite_error_for_inf_d3() { - let m = - Matrix::<3>::from_rows([[f64::INFINITY, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]); assert_eq!( - m.det(), + Matrix::<3>::try_from_rows([ + [f64::INFINITY, 0.0, 0.0], + [0.0, 1.0, 0.0], + [0.0, 0.0, 1.0] + ]), Err(LaError::NonFinite { row: Some(0), col: 0 @@ -1599,9 +1649,8 @@ mod tests { #[test] fn det_errbound_d1_rejects_nonfinite_even_with_zero_bound() { - let m = Matrix::<1>::from_rows([[f64::INFINITY]]); assert_eq!( - m.det_errbound(), + Matrix::<1>::try_from_rows([[f64::INFINITY]]), Err(LaError::NonFinite { row: Some(0), col: 0, @@ -1612,9 +1661,8 @@ mod tests { #[test] fn det_errbound_d5_rejects_nonfinite_before_returning_none() { let mut m = Matrix::<5>::identity(); - assert_eq!(m.set(4, 1, f64::NAN), Some(())); assert_eq!( - m.det_errbound(), + m.set(4, 1, f64::NAN), Err(LaError::NonFinite { row: Some(4), col: 1, @@ -1624,9 +1672,8 @@ mod tests { #[test] fn det_errbound_rejects_nonfinite_entry_with_coordinates() { - let m = Matrix::<2>::from_rows([[1.0, f64::INFINITY], [0.0, 1.0]]); assert_eq!( - m.det_errbound(), + Matrix::<2>::try_from_rows([[1.0, f64::INFINITY], [0.0, 1.0]]), Err(LaError::NonFinite { row: Some(0), col: 1, @@ -1703,9 +1750,8 @@ mod tests { fn []() { // Before the fix, `NaN > max_row_sum` was always false, so a // matrix full of NaN silently produced inf_norm == 0.0. - let m = Matrix::<$d>::from_rows([[f64::NAN; $d]; $d]); assert_eq!( - m.inf_norm(), + Matrix::<$d>::try_from_rows([[f64::NAN; $d]; $d]), Err(LaError::NonFinite { row: Some(0), col: 0, @@ -1719,9 +1765,8 @@ mod tests { let mut rows = [[0.0f64; $d]; $d]; rows[0][0] = f64::NAN; rows[$d - 1][$d - 1] = 1.0; - let m = Matrix::<$d>::from_rows(rows); assert_eq!( - m.inf_norm(), + Matrix::<$d>::try_from_rows(rows), Err(LaError::NonFinite { row: Some(0), col: 0, @@ -1734,9 +1779,8 @@ mod tests { // Infinity entries should be rejected with their source coordinates. let mut rows = [[0.0f64; $d]; $d]; rows[0][0] = f64::INFINITY; - let m = Matrix::<$d>::from_rows(rows); assert_eq!( - m.inf_norm(), + Matrix::<$d>::try_from_rows(rows), Err(LaError::NonFinite { row: Some(0), col: 0, @@ -1821,16 +1865,8 @@ mod tests { } rows[0][1] = f64::NAN; rows[1][0] = f64::NAN; - let a = Matrix::<$d>::from_rows(rows); assert_eq!( - a.is_symmetric(Tolerance::new(1e-12).unwrap()), - Err(LaError::NonFinite { - row: Some(0), - col: 1, - }) - ); - assert_eq!( - a.first_asymmetry(Tolerance::new(1e-12).unwrap()), + Matrix::<$d>::try_from_rows(rows), Err(LaError::NonFinite { row: Some(0), col: 1, @@ -1865,7 +1901,9 @@ mod tests { rows[$d - 1][0] = -1.0; assert_eq!( - FiniteMatrix::from_rows(rows).and_then(SymmetricMatrix::try_new), + Matrix::<$d>::try_from_rows(rows) + .map(FiniteMatrix::new) + .and_then(SymmetricMatrix::try_new), Err(LaError::Asymmetric { row: 0, col: $d - 1, @@ -1884,10 +1922,8 @@ mod tests { #[test] fn matrix_ldlt_rejects_nonfinite_before_asymmetry() { - let a = Matrix::<2>::from_rows([[1.0, f64::NAN], [0.0, 1.0]]); - assert_eq!( - a.ldlt(DEFAULT_PIVOT_TOL), + Matrix::<2>::try_from_rows([[1.0, f64::NAN], [0.0, 1.0]]), Err(LaError::NonFinite { row: Some(0), col: 1, @@ -1898,9 +1934,7 @@ mod tests { #[test] fn symmetric_matrix_into_matrix_roundtrips_storage_internally() { let a = Matrix::<2>::from_rows([[2.0, 1.0], [1.0, 3.0]]); - let symmetric = FiniteMatrix::try_new(a) - .and_then(SymmetricMatrix::try_new) - .unwrap(); + let symmetric = SymmetricMatrix::try_new(FiniteMatrix::new(a)).unwrap(); assert_eq!(symmetric.into_matrix(), a); } @@ -1927,16 +1961,8 @@ mod tests { #[test] fn first_asymmetry_rejects_infinite_offdiagonal() { - let a = Matrix::<2>::from_rows([[1.0, f64::INFINITY], [0.0, 1.0]]); - assert_eq!( - a.first_asymmetry(Tolerance::new(1e-12).unwrap()), - Err(LaError::NonFinite { - row: Some(0), - col: 1, - }) - ); assert_eq!( - a.is_symmetric(Tolerance::new(1e-12).unwrap()), + Matrix::<2>::try_from_rows([[1.0, f64::INFINITY], [0.0, 1.0]]), Err(LaError::NonFinite { row: Some(0), col: 1, @@ -1946,16 +1972,8 @@ mod tests { #[test] fn first_asymmetry_rejects_nan_diagonal() { - let a = Matrix::<2>::from_rows([[f64::NAN, 1.0], [1.0, 1.0]]); - assert_eq!( - a.first_asymmetry(Tolerance::new(1e-12).unwrap()), - Err(LaError::NonFinite { - row: Some(0), - col: 0, - }) - ); assert_eq!( - a.is_symmetric(Tolerance::new(1e-12).unwrap()), + Matrix::<2>::try_from_rows([[f64::NAN, 1.0], [1.0, 1.0]]), Err(LaError::NonFinite { row: Some(0), col: 0, diff --git a/src/vector.rs b/src/vector.rs index c83c3e4..13234eb 100644 --- a/src/vector.rs +++ b/src/vector.rs @@ -11,9 +11,9 @@ pub struct Vector { /// Fixed-size vector whose stored entries are all finite. /// -/// This proof-carrying wrapper lets callers validate raw [`Vector`] values once -/// at an input boundary, then pass the finite invariant into numerical -/// algorithms without rediscovering that no stored entry is NaN or infinite. +/// This proof-carrying wrapper makes the finite invariant explicit for internal +/// algorithms that should not rediscover that no stored entry is NaN or +/// infinite. #[must_use] #[derive(Clone, Copy, Debug, PartialEq)] #[allow(clippy::redundant_pub_crate)] @@ -24,28 +24,18 @@ pub(crate) struct FiniteVector { impl FiniteVector { /// Construct a finite vector without checking the invariant. /// - /// This is crate-internal so public callers must use [`try_new`](Self::try_new) - /// or [`from_array`](Self::from_array), which preserve diagnostics for - /// rejected raw entries. + /// This is crate-internal so raw storage still goes through + /// [`Vector::try_new`], which preserves diagnostics for rejected entries. #[inline] pub(crate) const fn new_unchecked(vector: Vector) -> Self { Self { vector } } - /// Validate that every stored vector entry is finite. - /// # Errors - /// Returns [`LaError::NonFinite`] with `row: None` and the first offending - /// entry index when `vector` contains NaN or infinity. + /// Wrap an already-finite vector for algorithms that carry the invariant + /// explicitly. #[inline] - pub const fn try_new(vector: Vector) -> Result { - let mut i = 0; - while i < D { - if !vector.data[i].is_finite() { - return Err(LaError::non_finite_at(i)); - } - i += 1; - } - Ok(Self::new_unchecked(vector)) + pub const fn new(vector: Vector) -> Self { + Self::new_unchecked(vector) } /// Validate raw vector storage and construct a finite vector. @@ -54,7 +44,10 @@ impl FiniteVector { /// entry index when `data` contains NaN or infinity. #[inline] pub const fn from_array(data: [f64; D]) -> Result { - Self::try_new(Vector::new(data)) + match Vector::try_new(data) { + Ok(vector) => Ok(Self::new_unchecked(vector)), + Err(err) => Err(err), + } } /// All-zeros finite vector. @@ -103,7 +96,14 @@ impl TryFrom> for FiniteVector { #[inline] fn try_from(value: Vector) -> Result { - Self::try_new(value) + let mut i = 0; + while i < D { + if !value.data[i].is_finite() { + return Err(LaError::non_finite_at(i)); + } + i += 1; + } + Ok(Self::new(value)) } } @@ -117,17 +117,54 @@ impl TryFrom<[f64; D]> for FiniteVector { } impl Vector { - /// Create a vector from a backing array. + /// Test-only infallible constructor for finite literal fixtures. + #[cfg(test)] + #[inline] + pub(crate) const fn new(data: [f64; D]) -> Self { + match Self::try_new(data) { + Ok(vector) => vector, + Err(_) => panic!("Vector::new requires finite entries"), + } + } + + /// Try to create a finite vector from a backing array. + /// + /// This is the public raw-storage boundary for vectors. Once construction + /// succeeds, methods such as [`dot`](Self::dot) and [`norm2_sq`](Self::norm2_sq) + /// do not need to rediscover that stored entries are finite. /// /// # Examples /// ``` /// use la_stack::prelude::*; /// - /// let v = Vector::<3>::new([1.0, 2.0, 3.0]); + /// # fn main() -> Result<(), LaError> { + /// let v = Vector::<3>::try_new([1.0, 2.0, 3.0])?; /// assert_eq!(v.into_array(), [1.0, 2.0, 3.0]); + /// # Ok(()) + /// # } /// ``` + /// + /// # Errors + /// Returns [`LaError::NonFinite`] with the first offending entry index when + /// `data` contains NaN or infinity. + #[inline] + pub const fn try_new(data: [f64; D]) -> Result { + let mut i = 0; + while i < D { + if !data[i].is_finite() { + return Err(LaError::non_finite_at(i)); + } + i += 1; + } + Ok(Self::new_unchecked(data)) + } + + /// Construct a vector without checking that entries are finite. + /// + /// This crate-internal escape hatch is reserved for literals and algorithm + /// outputs whose finite invariant is visible at the call site. #[inline] - pub const fn new(data: [f64; D]) -> Self { + pub(crate) const fn new_unchecked(data: [f64; D]) -> Self { Self { data } } @@ -142,7 +179,7 @@ impl Vector { /// ``` #[inline] pub const fn zero() -> Self { - Self { data: [0.0; D] } + Self::new_unchecked([0.0; D]) } /// Borrow the backing array. @@ -151,8 +188,11 @@ impl Vector { /// ``` /// use la_stack::prelude::*; /// - /// let v = Vector::<2>::new([1.0, -2.0]); + /// # fn main() -> Result<(), LaError> { + /// let v = Vector::<2>::try_new([1.0, -2.0])?; /// assert_eq!(v.as_array(), &[1.0, -2.0]); + /// # Ok(()) + /// # } /// ``` #[inline] #[must_use] @@ -166,9 +206,12 @@ impl Vector { /// ``` /// use la_stack::prelude::*; /// - /// let v = Vector::<2>::new([1.0, 2.0]); + /// # fn main() -> Result<(), LaError> { + /// let v = Vector::<2>::try_new([1.0, 2.0])?; /// let a = v.into_array(); /// assert_eq!(a, [1.0, 2.0]); + /// # Ok(()) + /// # } /// ``` #[inline] #[must_use] @@ -181,35 +224,29 @@ impl Vector { /// Terms are accumulated in `f64` using [`f64::mul_add`] at each index. /// Intermediate rounding occurs, and this method does not provide a /// certified absolute rounding bound for the returned dot product. The - /// returned [`Result`] is still checked for non-finite inputs and for - /// non-finite accumulation. + /// stored entries are finite by construction, so the returned [`Result`] + /// only reports non-finite accumulation. /// /// # Examples /// ``` /// use la_stack::prelude::*; /// /// # fn main() -> Result<(), LaError> { - /// let a = Vector::<3>::new([1.0, 2.0, 3.0]); - /// let b = Vector::<3>::new([-2.0, 0.5, 4.0]); + /// let a = Vector::<3>::try_new([1.0, 2.0, 3.0])?; + /// let b = Vector::<3>::try_new([-2.0, 0.5, 4.0])?; /// assert!((a.dot(b)? - 11.0).abs() <= 1e-12); /// # Ok(()) /// # } /// ``` /// /// # Errors - /// Returns [`LaError::NonFinite`] when either input contains NaN or infinity, - /// or when the accumulated dot product overflows to NaN or infinity. + /// Returns [`LaError::NonFinite`] when the accumulated dot product overflows + /// to NaN or infinity. #[inline] pub const fn dot(self, other: Self) -> Result { let mut acc = 0.0; let mut i = 0; while i < D { - if !self.data[i].is_finite() { - return Err(LaError::non_finite_at(i)); - } - if !other.data[i].is_finite() { - return Err(LaError::non_finite_at(i)); - } acc = self.data[i].mul_add(other.data[i], acc); if !acc.is_finite() { return Err(LaError::non_finite_at(i)); @@ -225,31 +262,28 @@ impl Vector { /// `f64` [`mul_add`](f64::mul_add) accumulation behavior as [`dot`](Self::dot). /// Intermediate rounding occurs, and this method does not provide a /// certified absolute rounding bound for the returned squared norm. The - /// returned [`Result`] is still checked for non-finite inputs and for - /// non-finite accumulation. + /// stored entries are finite by construction, so the returned [`Result`] + /// only reports non-finite accumulation. /// /// # Examples /// ``` /// use la_stack::prelude::*; /// /// # fn main() -> Result<(), LaError> { - /// let v = Vector::<3>::new([1.0, 2.0, 3.0]); + /// let v = Vector::<3>::try_new([1.0, 2.0, 3.0])?; /// assert!((v.norm2_sq()? - 14.0).abs() <= 1e-12); /// # Ok(()) /// # } /// ``` /// /// # Errors - /// Returns [`LaError::NonFinite`] when the input contains NaN or infinity, - /// or when the accumulated norm overflows to NaN or infinity. + /// Returns [`LaError::NonFinite`] when the accumulated norm overflows to NaN + /// or infinity. #[inline] pub const fn norm2_sq(self) -> Result { let mut acc = 0.0; let mut i = 0; while i < D { - if !self.data[i].is_finite() { - return Err(LaError::non_finite_at(i)); - } acc = self.data[i].mul_add(self.data[i], acc); if !acc.is_finite() { return Err(LaError::non_finite_at(i)); @@ -388,21 +422,12 @@ mod tests { } #[test] - fn []() { + fn []() { let mut a_arr = [1.0f64; $d]; a_arr[$d - 1] = f64::NAN; - let a = Vector::<$d>::new(a_arr); - let b = Vector::<$d>::new([1.0; $d]); assert_eq!( - a.dot(b), - Err(LaError::NonFinite { - row: None, - col: $d - 1, - }) - ); - assert_eq!( - a.norm2_sq(), + Vector::<$d>::try_new(a_arr), Err(LaError::NonFinite { row: None, col: $d - 1, @@ -411,13 +436,14 @@ mod tests { } #[test] - fn []() { - let a = Vector::<$d>::new([1.0; $d]); + fn []() { let mut b_arr = [1.0f64; $d]; b_arr[0] = f64::INFINITY; - let b = Vector::<$d>::new(b_arr); - assert_eq!(a.dot(b), Err(LaError::NonFinite { row: None, col: 0 })); + assert_eq!( + Vector::<$d>::try_new(b_arr), + Err(LaError::NonFinite { row: None, col: 0 }) + ); } #[test] @@ -472,6 +498,21 @@ mod tests { }) ); } + + #[test] + fn []() { + let mut arr = [1.0f64; $d]; + arr[$d - 1] = f64::NAN; + let raw = Vector::<$d>::new_unchecked(arr); + + assert_eq!( + FiniteVector::<$d>::try_from(raw), + Err(LaError::NonFinite { + row: None, + col: $d - 1, + }) + ); + } } }; } diff --git a/tests/proptest_exact.rs b/tests/proptest_exact.rs index 230abc1..c67afda 100644 --- a/tests/proptest_exact.rs +++ b/tests/proptest_exact.rs @@ -97,7 +97,7 @@ macro_rules! gen_det_sign_exact_proptests { for i in 0..$d { rows[i][i] = diag[i]; } - let m = Matrix::<$d>::from_rows(rows); + let m = Matrix::<$d>::try_from_rows(rows).unwrap(); let exact_sign = m.det_sign_exact().unwrap(); @@ -118,7 +118,7 @@ macro_rules! gen_det_sign_exact_proptests { for i in 0..$d { rows[i][i] = diag[i]; } - let m = Matrix::<$d>::from_rows(rows); + let m = Matrix::<$d>::try_from_rows(rows).unwrap(); let exact_sign = m.det_sign_exact().unwrap(); let fp_det = m.det().unwrap(); @@ -164,7 +164,7 @@ macro_rules! gen_solve_exact_roundtrip_proptests { x0 in proptest::array::[](small_int_f64()), ) { let rows = make_diagonally_dominant::<$d>(offdiag, diag); - let a = Matrix::<$d>::from_rows(rows); + let a = Matrix::<$d>::try_from_rows(rows).unwrap(); // b = A · x0, computed in f64. Small integers keep // every partial sum exact. @@ -176,7 +176,7 @@ macro_rules! gen_solve_exact_roundtrip_proptests { } b_arr[i] = sum; } - let b = Vector::<$d>::new(b_arr); + let b = Vector::<$d>::try_new(b_arr).unwrap(); let x = a.solve_exact(b).expect("diagonally-dominant A is non-singular"); let expected: [BigRational; $d] = from_fn(|i| { @@ -217,8 +217,8 @@ macro_rules! gen_solve_exact_residual_proptests { b_arr in proptest::array::[](small_int_f64()), ) { let rows = make_diagonally_dominant::<$d>(offdiag, diag); - let a = Matrix::<$d>::from_rows(rows); - let b = Vector::<$d>::new(b_arr); + let a = Matrix::<$d>::try_from_rows(rows).unwrap(); + let b = Vector::<$d>::try_new(b_arr).unwrap(); let x = a.solve_exact(b).expect("diagonally-dominant A is non-singular"); let ax = bigrational_matvec::<$d>(&rows, &x); @@ -256,7 +256,7 @@ macro_rules! gen_det_sign_agrees_with_det_exact_proptests { proptest::array::[](small_int_f64()), ), ) { - let m = Matrix::<$d>::from_rows(entries); + let m = Matrix::<$d>::try_from_rows(entries).unwrap(); let sign = m.det_sign_exact().unwrap(); let det = m.det_exact().unwrap(); let expected: i8 = if det.is_positive() { @@ -299,7 +299,7 @@ macro_rules! gen_det_sign_fast_filter_boundary_proptests { proptest::array::[](small_int_f64()), ), ) { - let m = Matrix::<$d>::from_rows(entries); + let m = Matrix::<$d>::try_from_rows(entries).unwrap(); let det = m .det_direct() .unwrap() diff --git a/tests/proptest_factorizations.rs b/tests/proptest_factorizations.rs index c4ef6a0..fb8498e 100644 --- a/tests/proptest_factorizations.rs +++ b/tests/proptest_factorizations.rs @@ -86,12 +86,12 @@ macro_rules! gen_factorization_proptests { b_arr[i] = sum; } - let a = Matrix::<$d>::from_rows(a_rows); + let a = Matrix::<$d>::try_from_rows(a_rows).unwrap(); let ldlt = a.ldlt(DEFAULT_SINGULAR_TOL).unwrap(); assert_abs_diff_eq!(ldlt.det().unwrap(), expected_det, epsilon = 1e-8); - let b = Vector::<$d>::new(b_arr); + let b = Vector::<$d>::try_new(b_arr).unwrap(); let x = ldlt.solve_vec(b).unwrap().into_array(); for i in 0..$d { assert_abs_diff_eq!(x[i], x_true[i], epsilon = 1e-8); @@ -166,12 +166,12 @@ macro_rules! gen_factorization_proptests { b_arr[i] = sum; } - let a = Matrix::<$d>::from_rows(a_rows); + let a = Matrix::<$d>::try_from_rows(a_rows).unwrap(); let lu = a.lu(DEFAULT_PIVOT_TOL).unwrap(); assert_abs_diff_eq!(lu.det().unwrap(), expected_det, epsilon = 1e-8); - let b = Vector::<$d>::new(b_arr); + let b = Vector::<$d>::try_new(b_arr).unwrap(); let x = lu.solve_vec(b).unwrap().into_array(); for i in 0..$d { assert_abs_diff_eq!(x[i], x_true[i], epsilon = 1e-8); @@ -252,12 +252,12 @@ macro_rules! gen_factorization_proptests { b_arr[i] = sum; } - let a = Matrix::<$d>::from_rows(a_rows); + let a = Matrix::<$d>::try_from_rows(a_rows).unwrap(); let lu = a.lu(DEFAULT_PIVOT_TOL).unwrap(); assert_abs_diff_eq!(lu.det().unwrap(), expected_det, epsilon = 1e-8); - let b = Vector::<$d>::new(b_arr); + let b = Vector::<$d>::try_new(b_arr).unwrap(); let x = lu.solve_vec(b).unwrap().into_array(); for i in 0..$d { assert_abs_diff_eq!(x[i], x_true[i], epsilon = 1e-8); diff --git a/tests/proptest_matrix.rs b/tests/proptest_matrix.rs index 2b4618c..dd685bb 100644 --- a/tests/proptest_matrix.rs +++ b/tests/proptest_matrix.rs @@ -31,12 +31,12 @@ macro_rules! gen_public_api_matrix_proptests { #![proptest_config(ProptestConfig::with_cases(64))] #[test] - fn []( + fn []( rows in proptest::array::[]( proptest::array::[](small_f64()), ), ) { - let m = Matrix::<$d>::from_rows(rows); + let m = Matrix::<$d>::try_from_rows(rows).unwrap(); for r in 0..$d { for c in 0..$d { @@ -73,7 +73,7 @@ macro_rules! gen_public_api_matrix_proptests { v in small_f64(), ) { let mut m = Matrix::<$d>::zero(); - prop_assert_eq!(m.set(r, c, v), Some(())); + prop_assert_eq!(m.set(r, c, v), Ok(())); assert_abs_diff_eq!(m.get(r, c).unwrap(), v, epsilon = 0.0); prop_assert_eq!(m.set_checked(r, c, -v), Ok(())); assert_abs_diff_eq!(m.get_checked(r, c).unwrap(), -v, epsilon = 0.0); @@ -86,10 +86,17 @@ macro_rules! gen_public_api_matrix_proptests { ), v in small_f64(), ) { - let mut m = Matrix::<$d>::from_rows(rows); + let mut m = Matrix::<$d>::try_from_rows(rows).unwrap(); let original = m; - prop_assert_eq!(m.set($d, 0, v), None); + prop_assert_eq!( + m.set($d, 0, v), + Err(LaError::IndexOutOfBounds { + row: $d, + col: 0, + dim: $d, + }) + ); prop_assert_eq!(m, original); prop_assert_eq!( m.set_checked($d, 0, v), @@ -117,7 +124,7 @@ macro_rules! gen_public_api_matrix_proptests { proptest::array::[](small_f64()), ), ) { - let m = Matrix::<$d>::from_rows(rows); + let m = Matrix::<$d>::try_from_rows(rows).unwrap(); let expected = rows .iter() @@ -138,7 +145,7 @@ macro_rules! gen_public_api_matrix_proptests { for i in 0..$d { rows[i][i] = diag[i]; } - let a = Matrix::<$d>::from_rows(rows); + let a = Matrix::<$d>::try_from_rows(rows).unwrap(); let det = a.det().unwrap(); let expected_det = { @@ -155,7 +162,7 @@ macro_rules! gen_public_api_matrix_proptests { assert_abs_diff_eq!(det, expected_det, epsilon = eps); let lu = a.lu(DEFAULT_PIVOT_TOL).unwrap(); - let b = Vector::<$d>::new(b_arr); + let b = Vector::<$d>::try_new(b_arr).unwrap(); let x = lu.solve_vec(b).unwrap().into_array(); for i in 0..$d { @@ -208,14 +215,14 @@ macro_rules! gen_public_api_matrix_proptests { b_arr[i] = sum; } - let a = Matrix::<$d>::from_rows(a_rows); + let a = Matrix::<$d>::try_from_rows(a_rows).unwrap(); let ldlt = a.ldlt(DEFAULT_SINGULAR_TOL).unwrap(); let det_ldlt = ldlt.det().unwrap(); let det_lu = a.lu(DEFAULT_PIVOT_TOL).unwrap().det().unwrap(); assert_abs_diff_eq!(det_ldlt, det_lu, epsilon = 1e-8); - let b = Vector::<$d>::new(b_arr); + let b = Vector::<$d>::try_new(b_arr).unwrap(); let x = ldlt.solve_vec(b).unwrap().into_array(); for i in 0..$d { diff --git a/tests/proptest_vector.rs b/tests/proptest_vector.rs index 563a12b..da171cf 100644 --- a/tests/proptest_vector.rs +++ b/tests/proptest_vector.rs @@ -17,10 +17,10 @@ macro_rules! gen_public_api_vector_proptests { #![proptest_config(ProptestConfig::with_cases(64))] #[test] - fn []( + fn []( arr in proptest::array::[](small_f64()), ) { - let v = Vector::<$d>::new(arr); + let v = Vector::<$d>::try_new(arr).unwrap(); for i in 0..$d { assert_abs_diff_eq!(v.as_array()[i], arr[i], epsilon = 0.0); @@ -37,8 +37,8 @@ macro_rules! gen_public_api_vector_proptests { a_arr in proptest::array::[](small_f64()), b_arr in proptest::array::[](small_f64()), ) { - let a = Vector::<$d>::new(a_arr); - let b = Vector::<$d>::new(b_arr); + let a = Vector::<$d>::try_new(a_arr).unwrap(); + let b = Vector::<$d>::try_new(b_arr).unwrap(); let dot_ab = a.dot(b).unwrap(); let dot_reversed = b.dot(a).unwrap(); diff --git a/tests/semgrep/src/project_rules/public_api_panic_paths.rs b/tests/semgrep/src/project_rules/public_api_panic_paths.rs new file mode 100644 index 0000000..1e75762 --- /dev/null +++ b/tests/semgrep/src/project_rules/public_api_panic_paths.rs @@ -0,0 +1,63 @@ +#![forbid(unsafe_code)] +#![allow(dead_code)] + +pub enum Error { + Invalid, +} + +// ruleid: la-stack.rust.no-public-api-panic-paths +pub fn panics_on_input(value: usize) -> usize { + if value == 0 { + panic!("zero is invalid"); + } + value +} + +// ruleid: la-stack.rust.no-public-api-panic-paths +pub const fn asserts_on_input(value: usize) -> usize { + assert!(value > 0); + value +} + +// ruleid: la-stack.rust.no-public-api-panic-paths +pub fn unwraps_on_input(value: Option) -> usize { + value.unwrap() +} + +// ruleid: la-stack.rust.no-public-api-panic-paths +pub async fn expects_on_input(value: Option) -> usize { + value.expect("value is required") +} + +// ok: la-stack.rust.no-public-api-panic-paths +pub fn fallible_result(value: usize) -> Result { + if value == 0 { + Err(Error::Invalid) + } else { + Ok(value) + } +} + +// ok: la-stack.rust.no-public-api-panic-paths +pub fn fallible_option(value: usize) -> Option { + if value == 0 { None } else { Some(value) } +} + +// ok: la-stack.rust.no-public-api-panic-paths +pub fn total(value: usize) -> usize { + value + 1 +} + +// ok: la-stack.rust.no-public-api-panic-paths +pub(crate) fn crate_private_literal_helper(value: usize) -> usize { + assert!(value > 0); + value +} + +#[cfg(test)] +mod tests { + // ok: la-stack.rust.no-public-api-panic-paths + pub(crate) fn test_only_helper(value: usize) -> usize { + value.expect("test fixture has a value") + } +} diff --git a/tests/semgrep/src/project_rules/raw_f64_constructors.rs b/tests/semgrep/src/project_rules/raw_f64_constructors.rs new file mode 100644 index 0000000..da2c7ee --- /dev/null +++ b/tests/semgrep/src/project_rules/raw_f64_constructors.rs @@ -0,0 +1,51 @@ +#![forbid(unsafe_code)] +#![allow(dead_code)] + +#[must_use] +pub struct Matrix { + rows: [[f64; D]; D], +} + +#[must_use] +pub struct Vector { + data: [f64; D], +} + +#[non_exhaustive] +pub enum LaError { + NonFinite, +} + +impl Matrix { + // ruleid: la-stack.rust.no-public-infallible-raw-f64-constructors + pub const fn from_rows(rows: [[f64; D]; D]) -> Self { + Self { rows } + } + + // ok: la-stack.rust.no-public-infallible-raw-f64-constructors + pub const fn try_from_rows(rows: [[f64; D]; D]) -> Result { + Ok(Self { rows }) + } + + // ok: la-stack.rust.no-public-infallible-raw-f64-constructors + pub(crate) const fn from_rows_literal(rows: [[f64; D]; D]) -> Self { + Self { rows } + } +} + +impl Vector { + // ruleid: la-stack.rust.no-public-infallible-raw-f64-constructors + pub const fn new(data: [f64; D]) -> Self { + Self { data } + } + + // ok: la-stack.rust.no-public-infallible-raw-f64-constructors + pub const fn try_new(data: [f64; D]) -> Result { + Ok(Self { data }) + } + + // ok: la-stack.rust.no-public-infallible-raw-f64-constructors + pub(crate) const fn new_literal(data: [f64; D]) -> Self { + Self { data } + } +}