feat(ops): add LinearBitNet — ternary weight GEMV with zero-skip #477

Open
eriirfos-eng wants to merge 1 commit into ruvnet:main from eriirfos-eng:feat/bitnet-ternary-gemv

@eriirfos-eng

What this adds

A LinearBitNet struct in crates/ruvector-sparse-inference/src/ops.rs, alongside the existing Linear.

The difference: weights are stored as i8 in {−1, 0, +1} rather than f32. The forward pass skips every multiply-accumulate where the weight is zero. The skip is exact with respect to the quantized weights (a zero weight contributes nothing to the sum) and needs no special hardware. At the sparsity levels typical of BitNet b1.58 and similar ternary-quantized models (50–70% zeros), this leaves only 30–50% of the linear layer's MACs active.
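The zero-skip forward pass described above can be sketched roughly as follows. This is a minimal illustration, not the code from this PR: the `TernaryLinear` name, its fields, and the row-major layout are assumptions made for the example.

```rust
/// Hypothetical stand-in for `LinearBitNet`; the name, fields, and
/// row-major layout are assumptions made for this sketch.
struct TernaryLinear {
    out_features: usize,
    in_features: usize,
    /// Row-major `out_features x in_features`, each entry in {-1, 0, +1}.
    weights: Vec<i8>,
    bias: Option<Vec<f32>>,
}

impl TernaryLinear {
    /// Sparse GEMV: zero weights are skipped, and +1 / -1 reduce to an
    /// add / subtract, so the inner loop needs no multiplication.
    fn forward(&self, input: &[f32]) -> Vec<f32> {
        assert!(self.in_features > 0, "in_features must be non-zero");
        assert_eq!(input.len(), self.in_features, "input length mismatch");
        assert_eq!(
            self.weights.len(),
            self.out_features * self.in_features,
            "weight matrix has the wrong size"
        );

        let mut output = Vec::with_capacity(self.out_features);
        for (row_idx, row) in self.weights.chunks_exact(self.in_features).enumerate() {
            let mut acc = 0.0f32;
            for (&w, &x) in row.iter().zip(input) {
                match w {
                    0 => {}        // zero weight: skip entirely
                    1 => acc += x, // +1: add the activation
                    -1 => acc -= x, // -1: subtract the activation
                    _ => unreachable!("weights must be ternary"),
                }
            }
            if let Some(bias) = &self.bias {
                acc += bias[row_idx];
            }
            output.push(acc);
        }
        output
    }
}
```

The branch on `w` is the simplest way to show the skip. Whether a branchy inner loop beats a dense SIMD loop depends on the sparsity pattern and the target CPU, so any speedup should be measured rather than assumed.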

Why it belongs here

ruvector-sparse-inference already has a SparseFfn that exploits activation sparsity (skipping neurons that activate to zero). LinearBitNet is the complementary case: weight sparsity, where the zeros are baked into the model at quantization time rather than determined at runtime.

API

// Quantize an f32 weight matrix at load time
let layer = LinearBitNet::from_f32(out_features, in_features, &weights_f32, threshold, bias);

// Sparse forward — zero weights are skipped in the inner loop
let output = layer.forward(&input);

// Fraction of weights that are zero (useful for benchmarking)
let s = layer.sparsity();

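threshold is the absolute value below which a weight becomes 0. BitNet b1.58's absmean quantization scales the weight matrix by its mean absolute value and then rounds each entry to the nearest of {−1, 0, +1}, which is equivalent to a threshold of half the mean absolute value on the unscaled weights. A threshold derived from the mean absolute value is therefore a reasonable starting point.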
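As a rough sketch of threshold-based quantization, the helpers below compute the mean absolute value of a weight slice, map each weight to {−1, 0, +1}, and report the resulting sparsity. The function names are hypothetical and do not come from this PR. Inputs are assumed to be finite, and the multiplier applied to the mean absolute value is left to the caller.

```rust
/// Mean absolute value of a weight slice (0.0 for an empty slice).
/// Hypothetical helper, not part of the PR.
fn mean_abs(weights: &[f32]) -> f32 {
    if weights.is_empty() {
        return 0.0;
    }
    weights.iter().map(|w| w.abs()).sum::<f32>() / weights.len() as f32
}

/// Map each weight to {-1, 0, +1}: anything with |w| < threshold becomes 0,
/// everything else keeps only its sign. Exact zeros stay 0 even when the
/// threshold is 0. Assumes finite inputs.
fn quantize_ternary(weights: &[f32], threshold: f32) -> Vec<i8> {
    weights
        .iter()
        .map(|&w| {
            if w.abs() < threshold || w == 0.0 {
                0
            } else if w > 0.0 {
                1
            } else {
                -1
            }
        })
        .collect()
}

/// Fraction of quantized weights that are zero.
fn sparsity(quantized: &[i8]) -> f32 {
    if quantized.is_empty() {
        return 0.0;
    }
    let zeros = quantized.iter().filter(|&&q| q == 0).count();
    zeros as f32 / quantized.len() as f32
}
```

A caller could then write `quantize_ternary(&w, 0.5 * mean_abs(&w))`, or pass a different multiple, and check the outcome with `sparsity` before committing to a threshold.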

Changes

  • crates/ruvector-sparse-inference/src/ops.rs: LinearBitNet struct with from_f32, forward, sparsity
  • Three tests: forward correctness (hand-verifiable), sparsity reporting, sign-consistency with dense at zero threshold
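One plausible reading of the zero-threshold check is sketched below: with a threshold of 0, every nonzero weight maps to its sign, so the ternary forward pass should match a dense GEMV over sign(W). This is an interpretation, not the test from this PR, and all function names are hypothetical. Both GEMV helpers assume a non-empty input.

```rust
/// Dense reference GEMV over row-major f32 weights (row length = input length).
fn dense_gemv(weights: &[f32], input: &[f32]) -> Vec<f32> {
    weights
        .chunks_exact(input.len())
        .map(|row| row.iter().zip(input).map(|(w, x)| w * x).sum::<f32>())
        .collect()
}

/// Ternary GEMV that skips zero weights.
fn ternary_gemv(weights: &[i8], input: &[f32]) -> Vec<f32> {
    weights
        .chunks_exact(input.len())
        .map(|row| {
            let mut acc = 0.0f32;
            for (w, x) in row.iter().zip(input) {
                if *w != 0 {
                    acc += f32::from(*w) * *x;
                }
            }
            acc
        })
        .collect()
}

/// Zero-threshold quantization: each weight maps to its sign, zeros stay zero.
fn to_sign(weights: &[f32]) -> Vec<i8> {
    weights
        .iter()
        .map(|&w| if w > 0.0 { 1 } else if w < 0.0 { -1 } else { 0 })
        .collect()
}

/// Returns true when the ternary pass and the dense pass over sign(W) agree.
fn zero_threshold_matches_dense(weights: &[f32], input: &[f32]) -> bool {
    let signs = to_sign(weights);
    let signs_f32: Vec<f32> = signs.iter().map(|&s| f32::from(s)).collect();
    ternary_gemv(&signs, input) == dense_gemv(&signs_f32, input)
}
```

Exact equality holds for small examples like this one because both paths add the same terms in the same order. A test on larger random matrices would more safely compare with a tolerance.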

What it does not change

Existing Linear, Embedding, RMSNorm, LayerNorm and all existing tests are untouched. All 91 existing tests pass.

Adds LinearBitNet alongside the existing Linear struct in ops.rs.

Weights are stored as i8 in {-1, 0, +1} and quantized from f32 at load
time using an absolute threshold. The forward pass skips any multiply-
accumulate where the weight is zero; the skip itself is exact, not
approximate. At typical ternary sparsity levels (50-70% zeros in BitNet
b1.58 and similar schemes) this cuts active MACs by roughly half or more,
with no change in output relative to a dense pass over the same ternary
weights.

- from_f32(): quantize an f32 matrix at a given threshold
- forward(): sparse GEMV, zero-weight skipping in inner loop
- sparsity(): reports fraction of zero weights (useful for benchmarking)

Three tests added alongside the existing ops tests.