From df34bb15c496ecf13deef42bdf78f5cb2ea9ba15 Mon Sep 17 00:00:00 2001
From: eriirfos-eng
Date: Wed, 20 May 2026 11:55:01 +0000
Subject: [PATCH] =?UTF-8?q?feat(ops):=20add=20LinearBitNet=20=E2=80=94=20t?=
 =?UTF-8?q?ernary=20weight=20GEMV=20with=20zero-skip?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Adds LinearBitNet alongside the existing Linear struct in ops.rs.
Weights are stored as i8 in {-1, 0, +1} and quantized from f32 at load
time using an absolute threshold. The forward pass skips any
multiply-accumulate where the weight is zero — exact, not approximate.

At typical ternary sparsity levels (50-70% zeros in BitNet b1.58 and
similar schemes) this cuts active MACs by roughly half with no loss in
output fidelity relative to a dense pass over the same ternary weights.

- from_f32(): quantize an f32 matrix at a given threshold
- forward(): sparse GEMV, zero-weight skipping in inner loop
- sparsity(): reports fraction of zero weights (useful for benchmarking)

Three tests added alongside the existing ops tests.
---
 crates/ruvector-sparse-inference/src/ops.rs | 109 ++++++++++++++++++++
 1 file changed, 109 insertions(+)

diff --git a/crates/ruvector-sparse-inference/src/ops.rs b/crates/ruvector-sparse-inference/src/ops.rs
index aaf0424bb2..46b0623799 100644
--- a/crates/ruvector-sparse-inference/src/ops.rs
+++ b/crates/ruvector-sparse-inference/src/ops.rs
@@ -43,6 +43,75 @@ impl Linear {
     }
 }
 
+/// Linear layer with ternary weights ({−1, 0, +1}).
+///
+/// Weights are stored as `i8` and quantized from f32 at load time. Every zero
+/// weight skips its multiply-accumulate in the inner loop — no approximation,
+/// no special hardware required. At the sparsity levels typical of BitNet b1.58
+/// and similar ternary schemes (≥50% zeros) this cuts active MACs roughly in half.
+///
+/// Drop-in for `Linear` when the weight matrix has been ternary-quantized.
+#[derive(Debug, Clone)]
+pub struct LinearBitNet {
+    /// Weights in {−1, 0, +1}, row-major: `weight[i * in_features + j]`
+    pub weight: Vec<i8>,
+    pub bias: Option<Vec<f32>>,
+    pub in_features: usize,
+    pub out_features: usize,
+}
+
+impl LinearBitNet {
+    /// Quantize an f32 weight matrix to ternary using `threshold`.
+    ///
+    /// Values in `[−threshold, +threshold]` (and NaN) become 0; larger values
+    /// become +1 and smaller ones −1. `weights` is row-major and is expected to
+    /// hold `out_features * in_features` entries; the length is not checked here,
+    /// and `forward` panics if a row it reads lies outside the quantized buffer.
+    ///
+    /// BitNet b1.58's absmean scheme rounds `w / mean(|w|)` to {−1, 0, +1}, which
+    /// matches a threshold of about half the mean absolute weight.
+    pub fn from_f32(
+        out_features: usize,
+        in_features: usize,
+        weights: &[f32],
+        threshold: f32,
+        bias: Option<Vec<f32>>,
+    ) -> Self {
+        let ternary = |w: f32| -> i8 {
+            if w > threshold { 1 } else if w < -threshold { -1 } else { 0 }
+        };
+        let weight = weights.iter().map(|&w| ternary(w)).collect();
+        Self { weight, bias, in_features, out_features }
+    }
+
+    /// GEMV forward pass — zero weights are skipped entirely.
+    /// Reads only the first `min(in_features, input.len())` inputs; panics if
+    /// `weight` or `bias` is too short for the `out_features` rows being read.
+    pub fn forward(&self, input: &[f32]) -> Vec<f32> {
+        let mut output = vec![0.0f32; self.out_features];
+        let len = self.in_features.min(input.len());
+        for i in 0..self.out_features {
+            let row = &self.weight[i * self.in_features..i * self.in_features + len];
+            let mut acc = 0.0f32;
+            for (j, &w) in row.iter().enumerate() {
+                if w == 0 { continue; }
+                acc += f32::from(w) * input[j];
+            }
+            if let Some(ref bias) = self.bias {
+                acc += bias[i];
+            }
+            output[i] = acc;
+        }
+        output
+    }
+
+    /// Fraction of weights that are zero (0.0–1.0); 0.0 when there are no weights.
+    pub fn sparsity(&self) -> f32 {
+        let zeros = self.weight.iter().filter(|&&w| w == 0).count();
+        zeros as f32 / self.weight.len().max(1) as f32
+    }
+}
+
 /// Embedding layer
 #[derive(Debug, Clone)]
 pub struct Embedding {
@@ -166,6 +235,46 @@ mod tests {
         assert!((output[1] - 32.2).abs() < 1e-5);
     }
 
+    #[test]
+    fn test_linear_bitnet_forward() {
+        // 2-output, 4-input layer; weights chosen so zero-skipping is verifiable
+        let weights = vec![
+            1.0f32, 0.0, -1.0, 0.0, // row 0: dot([1,0,-1,0], input)
+            0.0, 1.0, 0.0, 1.0,     // row 1: dot([0,1,0,1], input)
+        ];
+        // threshold=0.5 → |w|≤0.5 becomes 0, so 0.0 → 0, ±1.0 → ±1
+        let layer = LinearBitNet::from_f32(2, 4, &weights, 0.5, None);
+        assert_eq!(layer.weight, vec![1, 0, -1, 0, 0, 1, 0, 1]);
+
+        let input = vec![2.0, 3.0, 4.0, 5.0];
+        let out = layer.forward(&input);
+        // row 0: 1*2 + 0 + (-1)*4 + 0 = -2
+        // row 1: 0 + 1*3 + 0 + 1*5 = 8
+        assert!((out[0] - (-2.0)).abs() < 1e-5);
+        assert!((out[1] - 8.0).abs() < 1e-5);
+    }
+
+    #[test]
+    fn test_linear_bitnet_sparsity() {
+        let weights = vec![1.0, 0.0, 0.0, -1.0];
+        let layer = LinearBitNet::from_f32(2, 2, &weights, 0.5, None);
+        assert!((layer.sparsity() - 0.5).abs() < 1e-5);
+    }
+
+    #[test]
+    fn test_linear_bitnet_zero_threshold_keeps_signs() {
+        // threshold=0 → every non-zero weight becomes its sign, so nothing is skipped
+        let weights = vec![2.0f32, -3.0, 1.5, -1.5];
+        let layer = LinearBitNet::from_f32(2, 2, &weights, 0.0, None);
+        assert_eq!(layer.weight, vec![1, -1, 1, -1]);
+        let input = vec![1.0, 2.0];
+        let out = layer.forward(&input);
+        // row 0: sign(2)*1 + sign(-3)*2 = 1-2 = -1
+        // row 1: sign(1.5)*1 + sign(-1.5)*2 = 1-2 = -1
+        assert!((out[0] - (-1.0)).abs() < 1e-5);
+        assert!((out[1] - (-1.0)).abs() < 1e-5);
+    }
+
     #[test]
     fn test_silu() {
         assert!((silu(0.0) - 0.0).abs() < 1e-5);