diff --git a/encodings/alp/public-api.lock b/encodings/alp/public-api.lock index 1671433c077..ee64904d6cf 100644 --- a/encodings/alp/public-api.lock +++ b/encodings/alp/public-api.lock @@ -12,14 +12,10 @@ pub fn vortex_alp::ALPArray::exponents(&self) -> vortex_alp::Exponents pub fn vortex_alp::ALPArray::into_parts(self) -> (vortex_array::array::ArrayRef, vortex_alp::Exponents, core::option::Option, vortex_array::dtype::DType) -pub fn vortex_alp::ALPArray::new(encoded: vortex_array::array::ArrayRef, exponents: vortex_alp::Exponents, patches: core::option::Option) -> Self - pub fn vortex_alp::ALPArray::patches(&self) -> core::option::Option<&vortex_array::patches::Patches> pub fn vortex_alp::ALPArray::ptype(&self) -> vortex_array::dtype::ptype::PType -pub fn vortex_alp::ALPArray::try_new(encoded: vortex_array::array::ArrayRef, exponents: vortex_alp::Exponents, patches: core::option::Option) -> vortex_error::VortexResult - impl vortex_alp::ALPArray pub fn vortex_alp::ALPArray::to_array(&self) -> vortex_array::array::ArrayRef @@ -88,8 +84,6 @@ pub fn vortex_alp::ALPRDArray::right_bit_width(&self) -> u8 pub fn vortex_alp::ALPRDArray::right_parts(&self) -> &vortex_array::array::ArrayRef -pub fn vortex_alp::ALPRDArray::try_new(dtype: vortex_array::dtype::DType, left_parts: vortex_array::array::ArrayRef, left_parts_dictionary: vortex_buffer::buffer::Buffer, right_parts: vortex_array::array::ArrayRef, right_bit_width: u8, left_parts_patches: core::option::Option) -> vortex_error::VortexResult - impl vortex_alp::ALPRDArray pub fn vortex_alp::ALPRDArray::to_array(&self) -> vortex_array::array::ArrayRef @@ -152,6 +146,8 @@ impl vortex_alp::ALPRDVTable pub const vortex_alp::ALPRDVTable::ID: vortex_array::vtable::dyn_::ArrayId +pub fn vortex_alp::ALPRDVTable::try_new(dtype: vortex_array::dtype::DType, left_parts: vortex_array::array::ArrayRef, left_parts_dictionary: vortex_buffer::buffer::Buffer, right_parts: vortex_array::array::ArrayRef, right_bit_width: u8, 
left_parts_patches: core::option::Option) -> vortex_error::VortexResult + impl core::fmt::Debug for vortex_alp::ALPRDVTable pub fn vortex_alp::ALPRDVTable::fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result @@ -240,6 +236,10 @@ impl vortex_alp::ALPVTable pub const vortex_alp::ALPVTable::ID: vortex_array::vtable::dyn_::ArrayId +pub fn vortex_alp::ALPVTable::new(encoded: vortex_array::array::ArrayRef, exponents: vortex_alp::Exponents, patches: core::option::Option) -> vortex_alp::ALPArray + +pub fn vortex_alp::ALPVTable::try_new(encoded: vortex_array::array::ArrayRef, exponents: vortex_alp::Exponents, patches: core::option::Option) -> vortex_error::VortexResult + impl core::fmt::Debug for vortex_alp::ALPVTable pub fn vortex_alp::ALPVTable::fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result diff --git a/encodings/alp/src/alp/array.rs b/encodings/alp/src/alp/array.rs index 6cdfb0241bc..8cc27002983 100644 --- a/encodings/alp/src/alp/array.rs +++ b/encodings/alp/src/alp/array.rs @@ -179,7 +179,7 @@ impl VTable for ALPVTable { }) .transpose()?; - ALPArray::try_new( + Self::try_new( encoded, Exponents { e: u8::try_from(metadata.exp_e)?, @@ -268,98 +268,15 @@ pub struct ALPArray { #[derive(Debug)] pub struct ALPVTable; +#[allow(clippy::new_ret_no_self)] impl ALPVTable { pub const ID: ArrayId = ArrayId::new_ref("vortex.alp"); -} - -#[derive(Clone, prost::Message)] -pub struct ALPMetadata { - #[prost(uint32, tag = "1")] - pub(crate) exp_e: u32, - #[prost(uint32, tag = "2")] - pub(crate) exp_f: u32, - #[prost(message, optional, tag = "3")] - pub(crate) patches: Option, -} - -impl ALPArray { - fn validate( - encoded: &ArrayRef, - exponents: Exponents, - patches: Option<&Patches>, - ) -> VortexResult<()> { - vortex_ensure!( - matches!( - encoded.dtype(), - DType::Primitive(PType::I32 | PType::I64, _) - ), - "ALP encoded ints have invalid DType {}", - encoded.dtype(), - ); - - // Validate exponents are in-bounds for the float, and that patches have 
the proper - // length and type. - let Exponents { e, f } = exponents; - match encoded.dtype().as_ptype() { - PType::I32 => { - vortex_ensure!(exponents.e <= f32::MAX_EXPONENT, "e out of bounds: {e}"); - vortex_ensure!(exponents.f <= f32::MAX_EXPONENT, "f out of bounds: {f}"); - if let Some(patches) = patches { - Self::validate_patches::(patches, encoded)?; - } - } - PType::I64 => { - vortex_ensure!(e <= f64::MAX_EXPONENT, "e out of bounds: {e}"); - vortex_ensure!(f <= f64::MAX_EXPONENT, "f out of bounds: {f}"); - if let Some(patches) = patches { - Self::validate_patches::(patches, encoded)?; - } - } - _ => unreachable!(), - } - - // Validate patches - if let Some(patches) = patches { - vortex_ensure!( - patches.array_len() == encoded.len(), - "patches array_len != encoded len: {} != {}", - patches.array_len(), - encoded.len() - ); - - // Verify that the patches DType are of the proper DType. - } - - Ok(()) - } - - /// Validate that any patches provided are valid for the ALPArray. - fn validate_patches(patches: &Patches, encoded: &ArrayRef) -> VortexResult<()> { - vortex_ensure!( - patches.array_len() == encoded.len(), - "patches array_len != encoded len: {} != {}", - patches.array_len(), - encoded.len() - ); - - let expected_type = DType::Primitive(T::PTYPE, encoded.dtype().nullability()); - vortex_ensure!( - patches.dtype() == &expected_type, - "Expected patches type {expected_type}, actual {}", - patches.dtype(), - ); - - Ok(()) - } -} - -impl ALPArray { /// Build a new `ALPArray` from components, panicking on validation failure. /// - /// See [`ALPArray::try_new`] for reference on preconditions that must pass before - /// calling this method. - pub fn new(encoded: ArrayRef, exponents: Exponents, patches: Option) -> Self { + /// See [`ALPVTable::try_new`] for reference on preconditions that must + /// pass before calling this method. 
+ pub fn new(encoded: ArrayRef, exponents: Exponents, patches: Option) -> ALPArray { Self::try_new(encoded, exponents, patches).vortex_expect("ALPArray new") } @@ -385,12 +302,12 @@ impl ALPArray { /// # Examples /// /// ``` - /// # use vortex_alp::{ALPArray, Exponents}; + /// # use vortex_alp::{ALPArray, ALPVTable, Exponents}; /// # use vortex_array::IntoArray; /// # use vortex_buffer::buffer; /// /// // Returns error because buffer has wrong PType. - /// let result = ALPArray::try_new( + /// let result = ALPVTable::try_new( /// buffer![1i8].into_array(), /// Exponents { e: 1, f: 1 }, /// None @@ -398,7 +315,7 @@ impl ALPArray { /// assert!(result.is_err()); /// /// // Returns error because Exponents are out of bounds for f32 - /// let result = ALPArray::try_new( + /// let result = ALPVTable::try_new( /// buffer![1i32, 2i32].into_array(), /// Exponents { e: 100, f: 100 }, /// None @@ -406,7 +323,7 @@ impl ALPArray { /// assert!(result.is_err()); /// /// // Success! - /// let value = ALPArray::try_new( + /// let value = ALPVTable::try_new( /// buffer![0i32].into_array(), /// Exponents { e: 1, f: 1 }, /// None @@ -418,8 +335,8 @@ impl ALPArray { encoded: ArrayRef, exponents: Exponents, patches: Option, - ) -> VortexResult { - Self::validate(&encoded, exponents, patches.as_ref())?; + ) -> VortexResult { + validate_alp(&encoded, exponents, patches.as_ref())?; let dtype = match encoded.dtype() { DType::Primitive(PType::I32, nullability) => DType::Primitive(PType::F32, *nullability), @@ -428,53 +345,130 @@ impl ALPArray { }; let len = encoded.len(); - Ok(Self { + Ok(ALPArray { common: ArrayCommon::new(len, dtype), encoded, exponents, patches, }) } +} - /// Build a new `ALPArray` from components without validation. - /// - /// See [`ALPArray::try_new`] for information about the preconditions that should be checked - /// **before** calling this method. 
- pub(crate) unsafe fn new_unchecked( - encoded: ArrayRef, - exponents: Exponents, - patches: Option, - dtype: DType, - ) -> Self { - let len = encoded.len(); - Self { - common: ArrayCommon::new(len, dtype), - encoded, - exponents, - patches, +#[derive(Clone, prost::Message)] +pub struct ALPMetadata { + #[prost(uint32, tag = "1")] + pub(crate) exp_e: u32, + #[prost(uint32, tag = "2")] + pub(crate) exp_f: u32, + #[prost(message, optional, tag = "3")] + pub(crate) patches: Option, +} + +fn validate_alp( + encoded: &ArrayRef, + exponents: Exponents, + patches: Option<&Patches>, +) -> VortexResult<()> { + vortex_ensure!( + matches!( + encoded.dtype(), + DType::Primitive(PType::I32 | PType::I64, _) + ), + "ALP encoded ints have invalid DType {}", + encoded.dtype(), + ); + + // Validate exponents are in-bounds for the float, and that patches have the proper + // length and type. + let Exponents { e, f } = exponents; + match encoded.dtype().as_ptype() { + PType::I32 => { + vortex_ensure!(exponents.e <= f32::MAX_EXPONENT, "e out of bounds: {e}"); + vortex_ensure!(exponents.f <= f32::MAX_EXPONENT, "f out of bounds: {f}"); + if let Some(patches) = patches { + validate_alp_patches::(patches, encoded)?; + } + } + PType::I64 => { + vortex_ensure!(e <= f64::MAX_EXPONENT, "e out of bounds: {e}"); + vortex_ensure!(f <= f64::MAX_EXPONENT, "f out of bounds: {f}"); + + if let Some(patches) = patches { + validate_alp_patches::(patches, encoded)?; + } } + _ => unreachable!(), } - pub fn ptype(&self) -> PType { + // Validate patches + if let Some(patches) = patches { + vortex_ensure!( + patches.array_len() == encoded.len(), + "patches array_len != encoded len: {} != {}", + patches.array_len(), + encoded.len() + ); + } + + Ok(()) +} + +fn validate_alp_patches(patches: &Patches, encoded: &ArrayRef) -> VortexResult<()> { + vortex_ensure!( + patches.array_len() == encoded.len(), + "patches array_len != encoded len: {} != {}", + patches.array_len(), + encoded.len() + ); + + let expected_type 
= DType::Primitive(T::PTYPE, encoded.dtype().nullability()); + vortex_ensure!( + patches.dtype() == &expected_type, + "Expected patches type {expected_type}, actual {}", + patches.dtype(), + ); + + Ok(()) +} + +/// Extension trait for [`ALPArray`] methods. +pub trait ALPArrayExt: Sized { + /// Returns the primitive type of the array. + fn ptype(&self) -> PType; + + /// Returns a reference to the encoded child array. + fn encoded(&self) -> &ArrayRef; + + /// Returns the ALP exponents. + fn exponents(&self) -> Exponents; + + /// Returns a reference to the patches, if any. + fn patches(&self) -> Option<&Patches>; + + /// Consumes the array and returns its parts. + fn into_parts(self) -> (ArrayRef, Exponents, Option, DType); +} + +impl ALPArrayExt for ALPArray { + fn ptype(&self) -> PType { self.common.dtype().as_ptype() } - pub fn encoded(&self) -> &ArrayRef { + fn encoded(&self) -> &ArrayRef { &self.encoded } #[inline] - pub fn exponents(&self) -> Exponents { + fn exponents(&self) -> Exponents { self.exponents } - pub fn patches(&self) -> Option<&Patches> { + fn patches(&self) -> Option<&Patches> { self.patches.as_ref() } - /// Consumes the array and returns its parts. #[inline] - pub fn into_parts(self) -> (ArrayRef, Exponents, Option, DType) { + fn into_parts(self) -> (ArrayRef, Exponents, Option, DType) { ( self.encoded, self.exponents, @@ -484,6 +478,27 @@ impl ALPArray { } } +impl ALPArray { + /// Build a new `ALPArray` from components without validation. + /// + /// See [`ALPVTable::try_new`] for information about the preconditions + /// that should be checked **before** calling this method. 
+ pub(crate) unsafe fn new_unchecked( + encoded: ArrayRef, + exponents: Exponents, + patches: Option, + dtype: DType, + ) -> Self { + let len = encoded.len(); + Self { + common: ArrayCommon::new(len, dtype), + encoded, + exponents, + patches, + } + } +} + impl ValidityChild for ALPVTable { fn validity_child(array: &ALPArray) -> &ArrayRef { array.encoded() @@ -805,7 +820,7 @@ mod tests { .unwrap(); // Build a new ALPArray with the same encoded data but patches without chunk_offsets. - let alp_without_chunk_offsets = ALPArray::new( + let alp_without_chunk_offsets = ALPVTable::new( normally_encoded.encoded().clone(), normally_encoded.exponents(), Some(patches_without_chunk_offsets), diff --git a/encodings/alp/src/alp/compress.rs b/encodings/alp/src/alp/compress.rs index 3759f354a22..960e5edc956 100644 --- a/encodings/alp/src/alp/compress.rs +++ b/encodings/alp/src/alp/compress.rs @@ -141,6 +141,7 @@ mod tests { use vortex_buffer::buffer; use super::*; + use crate::alp::ALPArrayExt; use crate::decompress_into_array; #[test] diff --git a/encodings/alp/src/alp/compute/between.rs b/encodings/alp/src/alp/compute/between.rs index d6fe760c192..0787f411b11 100644 --- a/encodings/alp/src/alp/compute/between.rs +++ b/encodings/alp/src/alp/compute/between.rs @@ -17,6 +17,7 @@ use vortex_array::scalar_fn::fns::between::StrictComparison; use vortex_error::VortexResult; use crate::ALPArray; +use crate::ALPArrayExt; use crate::ALPFloat; use crate::ALPVTable; use crate::match_each_alp_float_ptype; @@ -101,6 +102,7 @@ mod tests { use vortex_array::scalar_fn::fns::between::StrictComparison; use crate::ALPArray; + use crate::ALPArrayExt; use crate::alp::compute::between::between_impl; use crate::alp_encode; diff --git a/encodings/alp/src/alp/compute/cast.rs b/encodings/alp/src/alp/compute/cast.rs index fff47f8b479..fb1b9fd0ab7 100644 --- a/encodings/alp/src/alp/compute/cast.rs +++ b/encodings/alp/src/alp/compute/cast.rs @@ -10,6 +10,7 @@ use 
vortex_array::scalar_fn::fns::cast::CastReduce; use vortex_error::VortexResult; use crate::alp::ALPArray; +use crate::alp::ALPArrayExt; use crate::alp::ALPVTable; impl CastReduce for ALPVTable { @@ -76,6 +77,7 @@ mod tests { use vortex_error::VortexExpect; use vortex_error::VortexResult; + use crate::ALPArrayExt; use crate::alp_encode; #[test] diff --git a/encodings/alp/src/alp/compute/compare.rs b/encodings/alp/src/alp/compute/compare.rs index 57cd881a970..0971cee37ed 100644 --- a/encodings/alp/src/alp/compute/compare.rs +++ b/encodings/alp/src/alp/compute/compare.rs @@ -19,6 +19,7 @@ use vortex_error::vortex_bail; use vortex_error::vortex_err; use crate::ALPArray; +use crate::ALPArrayExt; use crate::ALPFloat; use crate::ALPVTable; use crate::match_each_alp_float_ptype; diff --git a/encodings/alp/src/alp/compute/filter.rs b/encodings/alp/src/alp/compute/filter.rs index 119134e9578..c5f7f1a2658 100644 --- a/encodings/alp/src/alp/compute/filter.rs +++ b/encodings/alp/src/alp/compute/filter.rs @@ -9,6 +9,7 @@ use vortex_error::VortexResult; use vortex_mask::Mask; use crate::ALPArray; +use crate::ALPArrayExt; use crate::ALPVTable; impl FilterKernel for ALPVTable { diff --git a/encodings/alp/src/alp/compute/mask.rs b/encodings/alp/src/alp/compute/mask.rs index e99c805f6ba..b526df99c8f 100644 --- a/encodings/alp/src/alp/compute/mask.rs +++ b/encodings/alp/src/alp/compute/mask.rs @@ -11,6 +11,7 @@ use vortex_array::validity::Validity; use vortex_error::VortexResult; use crate::ALPArray; +use crate::ALPArrayExt; use crate::ALPVTable; impl MaskReduce for ALPVTable { @@ -21,7 +22,7 @@ impl MaskReduce for ALPVTable { } let masked_encoded = array.encoded().clone().mask(mask.clone())?; Ok(Some( - ALPArray::new(masked_encoded, array.exponents(), None).into_array(), + ALPVTable::new(masked_encoded, array.exponents(), None).into_array(), )) } } @@ -40,7 +41,7 @@ impl MaskKernel for ALPVTable { .transpose()? 
.flatten(); Ok(Some( - ALPArray::new(masked_encoded, array.exponents(), masked_patches).into_array(), + ALPVTable::new(masked_encoded, array.exponents(), masked_patches).into_array(), )) } } @@ -54,6 +55,7 @@ mod test { use vortex_array::compute::conformance::mask::test_mask_conformance; use vortex_buffer::buffer; + use crate::ALPArrayExt; use crate::alp_encode; #[rstest] diff --git a/encodings/alp/src/alp/compute/nan_count.rs b/encodings/alp/src/alp/compute/nan_count.rs index 74d8e27ebb9..fe62e0323fe 100644 --- a/encodings/alp/src/alp/compute/nan_count.rs +++ b/encodings/alp/src/alp/compute/nan_count.rs @@ -8,6 +8,7 @@ use vortex_array::register_kernel; use vortex_error::VortexResult; use crate::ALPArray; +use crate::ALPArrayExt; use crate::ALPVTable; impl NaNCountKernel for ALPVTable { diff --git a/encodings/alp/src/alp/compute/slice.rs b/encodings/alp/src/alp/compute/slice.rs index c9e656df550..afeec7d5b93 100644 --- a/encodings/alp/src/alp/compute/slice.rs +++ b/encodings/alp/src/alp/compute/slice.rs @@ -9,7 +9,7 @@ use vortex_array::IntoArray; use vortex_array::arrays::slice::SliceKernel; use vortex_error::VortexResult; -use crate::ALPArray; +use crate::ALPArrayExt; use crate::ALPVTable; impl SliceKernel for ALPVTable { @@ -18,7 +18,7 @@ impl SliceKernel for ALPVTable { range: Range, _ctx: &mut ExecutionCtx, ) -> VortexResult> { - let sliced_alp = ALPArray::new( + let sliced_alp = ALPVTable::new( array.encoded().slice(range.clone())?, array.exponents(), array diff --git a/encodings/alp/src/alp/compute/take.rs b/encodings/alp/src/alp/compute/take.rs index d91162cef8f..98ad505769f 100644 --- a/encodings/alp/src/alp/compute/take.rs +++ b/encodings/alp/src/alp/compute/take.rs @@ -9,6 +9,7 @@ use vortex_array::arrays::dict::TakeExecute; use vortex_error::VortexResult; use crate::ALPArray; +use crate::ALPArrayExt; use crate::ALPVTable; impl TakeExecute for ALPVTable { @@ -32,7 +33,7 @@ impl TakeExecute for ALPVTable { }) .transpose()?; Ok(Some( - 
ALPArray::new(taken_encoded, array.exponents(), taken_patches).into_array(), + ALPVTable::new(taken_encoded, array.exponents(), taken_patches).into_array(), )) } } diff --git a/encodings/alp/src/alp/decompress.rs b/encodings/alp/src/alp/decompress.rs index c18ecf0b2d9..a2da6634150 100644 --- a/encodings/alp/src/alp/decompress.rs +++ b/encodings/alp/src/alp/decompress.rs @@ -16,6 +16,7 @@ use vortex_buffer::BufferMut; use vortex_error::VortexResult; use crate::ALPArray; +use crate::ALPArrayExt; use crate::ALPFloat; use crate::Exponents; use crate::match_each_alp_float_ptype; diff --git a/encodings/alp/src/alp/ops.rs b/encodings/alp/src/alp/ops.rs index a0e6388a4dc..40b133c35ad 100644 --- a/encodings/alp/src/alp/ops.rs +++ b/encodings/alp/src/alp/ops.rs @@ -7,6 +7,7 @@ use vortex_error::VortexExpect; use vortex_error::VortexResult; use crate::ALPArray; +use crate::ALPArrayExt; use crate::ALPFloat; use crate::ALPVTable; use crate::match_each_alp_float_ptype; diff --git a/encodings/alp/src/alp_rd/array.rs b/encodings/alp/src/alp_rd/array.rs index 7242edec0f7..080e093320d 100644 --- a/encodings/alp/src/alp_rd/array.rs +++ b/encodings/alp/src/alp_rd/array.rs @@ -239,7 +239,7 @@ impl VTable for ALPRDVTable { }) .transpose()?; - ALPRDArray::try_new( + Self::try_new( dtype.clone(), left_parts, left_parts_dictionary, @@ -370,9 +370,7 @@ pub struct ALPRDVTable; impl ALPRDVTable { pub const ID: ArrayId = ArrayId::new_ref("vortex.alprd"); -} -impl ALPRDArray { /// Build a new `ALPRDArray` from components. 
pub fn try_new( dtype: DType, @@ -381,7 +379,7 @@ impl ALPRDArray { right_parts: ArrayRef, right_bit_width: u8, left_parts_patches: Option, - ) -> VortexResult { + ) -> VortexResult { if !dtype.is_float() { vortex_bail!("ALPRDArray given invalid DType ({dtype})"); } @@ -428,7 +426,7 @@ impl ALPRDArray { .transpose()?; let len = left_parts.len(); - Ok(Self { + Ok(ALPRDArray { common: ArrayCommon::new(len, dtype), left_parts, left_parts_dictionary, @@ -437,70 +435,88 @@ impl ALPRDArray { left_parts_patches, }) } +} - /// Build a new `ALPRDArray` from components. This does not perform any validation, and instead - /// it constructs it from parts. - pub(crate) unsafe fn new_unchecked( - dtype: DType, - left_parts: ArrayRef, - left_parts_dictionary: Buffer, - right_parts: ArrayRef, - right_bit_width: u8, - left_parts_patches: Option, - ) -> Self { - let len = left_parts.len(); - Self { - common: ArrayCommon::new(len, dtype), - left_parts, - left_parts_patches, - left_parts_dictionary, - right_parts, - right_bit_width, - } - } - +/// Extension trait for [`ALPRDArray`] methods. +pub trait ALPRDArrayExt { /// Returns true if logical type of the array values is f32. - /// - /// Returns false if the logical type of the array values is f64. + fn is_f32(&self) -> bool; + + /// The leftmost (most significant) bits of the floating point values stored in the array. + fn left_parts(&self) -> &ArrayRef; + + /// The rightmost (least significant) bits of the floating point values stored in the array. + fn right_parts(&self) -> &ArrayRef; + + /// Returns the right bit width. + fn right_bit_width(&self) -> u8; + + /// Patches of left-most bits. + fn left_parts_patches(&self) -> Option<&Patches>; + + /// The dictionary that maps the codes in `left_parts` into bit patterns. + fn left_parts_dictionary(&self) -> &Buffer; + + /// Replace the left parts patches. 
+ fn replace_left_parts_patches(&mut self, patches: Option); +} + +impl ALPRDArrayExt for ALPRDArray { #[inline] - pub fn is_f32(&self) -> bool { + fn is_f32(&self) -> bool { matches!(self.common.dtype(), DType::Primitive(PType::F32, _)) } - /// The leftmost (most significant) bits of the floating point values stored in the array. - /// - /// These are bit-packed and dictionary encoded, and cannot directly be interpreted without - /// the metadata of this array. - pub fn left_parts(&self) -> &ArrayRef { + fn left_parts(&self) -> &ArrayRef { &self.left_parts } - /// The rightmost (least significant) bits of the floating point values stored in the array. - pub fn right_parts(&self) -> &ArrayRef { + fn right_parts(&self) -> &ArrayRef { &self.right_parts } #[inline] - pub fn right_bit_width(&self) -> u8 { + fn right_bit_width(&self) -> u8 { self.right_bit_width } - /// Patches of left-most bits. - pub fn left_parts_patches(&self) -> Option<&Patches> { + fn left_parts_patches(&self) -> Option<&Patches> { self.left_parts_patches.as_ref() } - /// The dictionary that maps the codes in `left_parts` into bit patterns. #[inline] - pub fn left_parts_dictionary(&self) -> &Buffer { + fn left_parts_dictionary(&self) -> &Buffer { &self.left_parts_dictionary } - pub fn replace_left_parts_patches(&mut self, patches: Option) { + fn replace_left_parts_patches(&mut self, patches: Option) { self.left_parts_patches = patches; } } +impl ALPRDArray { + /// Build a new `ALPRDArray` from components. This does not perform any validation, and instead + /// it constructs it from parts. 
+ pub(crate) unsafe fn new_unchecked( + dtype: DType, + left_parts: ArrayRef, + left_parts_dictionary: Buffer, + right_parts: ArrayRef, + right_bit_width: u8, + left_parts_patches: Option, + ) -> Self { + let len = left_parts.len(); + Self { + common: ArrayCommon::new(len, dtype), + left_parts, + left_parts_patches, + left_parts_dictionary, + right_parts, + right_bit_width, + } + } +} + impl ValidityChild for ALPRDVTable { fn validity_child(array: &ALPRDArray) -> &ArrayRef { array.left_parts() diff --git a/encodings/alp/src/alp_rd/compute/cast.rs b/encodings/alp/src/alp_rd/compute/cast.rs index 3ef1dc8be26..5472dacec61 100644 --- a/encodings/alp/src/alp_rd/compute/cast.rs +++ b/encodings/alp/src/alp_rd/compute/cast.rs @@ -9,6 +9,7 @@ use vortex_array::scalar_fn::fns::cast::CastReduce; use vortex_error::VortexResult; use crate::alp_rd::ALPRDArray; +use crate::alp_rd::ALPRDArrayExt; use crate::alp_rd::ALPRDVTable; impl CastReduce for ALPRDVTable { @@ -28,7 +29,7 @@ impl CastReduce for ALPRDVTable { )?; return Ok(Some( - ALPRDArray::try_new( + ALPRDVTable::try_new( dtype.clone(), new_left_parts, array.left_parts_dictionary().clone(), diff --git a/encodings/alp/src/alp_rd/compute/filter.rs b/encodings/alp/src/alp_rd/compute/filter.rs index d21622f49d0..64b37f7dc47 100644 --- a/encodings/alp/src/alp_rd/compute/filter.rs +++ b/encodings/alp/src/alp_rd/compute/filter.rs @@ -9,6 +9,7 @@ use vortex_error::VortexResult; use vortex_mask::Mask; use crate::ALPRDArray; +use crate::ALPRDArrayExt; use crate::ALPRDVTable; impl FilterKernel for ALPRDVTable { @@ -24,7 +25,7 @@ impl FilterKernel for ALPRDVTable { .flatten(); Ok(Some( - ALPRDArray::try_new( + ALPRDVTable::try_new( array.dtype().clone(), array.left_parts().filter(mask.clone())?, array.left_parts_dictionary().clone(), @@ -48,6 +49,7 @@ mod test { use vortex_buffer::buffer; use vortex_mask::Mask; + use crate::ALPRDArrayExt; use crate::ALPRDFloat; use crate::RDEncoder; diff --git a/encodings/alp/src/alp_rd/compute/mask.rs 
b/encodings/alp/src/alp_rd/compute/mask.rs index 8064cf1bc29..85784d74b0f 100644 --- a/encodings/alp/src/alp_rd/compute/mask.rs +++ b/encodings/alp/src/alp_rd/compute/mask.rs @@ -10,6 +10,7 @@ use vortex_array::scalar_fn::fns::mask::MaskReduce; use vortex_error::VortexResult; use crate::ALPRDArray; +use crate::ALPRDArrayExt; use crate::ALPRDVTable; impl MaskReduce for ALPRDVTable { @@ -20,7 +21,7 @@ impl MaskReduce for ALPRDVTable { [array.left_parts().clone(), mask.clone()], )?; Ok(Some( - ALPRDArray::try_new( + ALPRDVTable::try_new( array.dtype().as_nullable(), masked_left_parts, array.left_parts_dictionary().clone(), diff --git a/encodings/alp/src/alp_rd/compute/take.rs b/encodings/alp/src/alp_rd/compute/take.rs index d2363fc058a..4795e31da11 100644 --- a/encodings/alp/src/alp_rd/compute/take.rs +++ b/encodings/alp/src/alp_rd/compute/take.rs @@ -11,6 +11,7 @@ use vortex_array::scalar::Scalar; use vortex_error::VortexResult; use crate::ALPRDArray; +use crate::ALPRDArrayExt; use crate::ALPRDVTable; impl TakeExecute for ALPRDVTable { @@ -39,7 +40,7 @@ impl TakeExecute for ALPRDVTable { .fill_null(Scalar::zero_value(array.right_parts().dtype()))?; Ok(Some( - ALPRDArray::try_new( + ALPRDVTable::try_new( array .dtype() .with_nullability(taken_left_parts.dtype().nullability()), @@ -63,6 +64,7 @@ mod test { use vortex_array::assert_arrays_eq; use vortex_array::compute::conformance::take::test_take_conformance; + use crate::ALPRDArrayExt; use crate::ALPRDFloat; use crate::RDEncoder; diff --git a/encodings/alp/src/alp_rd/mod.rs b/encodings/alp/src/alp_rd/mod.rs index 58188f7ab9f..f631bca472c 100644 --- a/encodings/alp/src/alp_rd/mod.rs +++ b/encodings/alp/src/alp_rd/mod.rs @@ -273,7 +273,7 @@ impl RDEncoder { .vortex_expect("Patches construction in encode") }); - ALPRDArray::try_new( + ALPRDVTable::try_new( DType::Primitive(T::PTYPE, packed_left.dtype().nullability()), packed_left, Buffer::::copy_from(&self.codes), diff --git a/encodings/alp/src/alp_rd/ops.rs 
b/encodings/alp/src/alp_rd/ops.rs index 8ae65b15179..4c0123d31ea 100644 --- a/encodings/alp/src/alp_rd/ops.rs +++ b/encodings/alp/src/alp_rd/ops.rs @@ -8,6 +8,7 @@ use vortex_error::VortexExpect; use vortex_error::VortexResult; use crate::ALPRDArray; +use crate::ALPRDArrayExt; use crate::ALPRDVTable; impl OperationsVTable for ALPRDVTable { @@ -64,6 +65,7 @@ mod test { use vortex_array::assert_arrays_eq; use vortex_array::scalar::Scalar; + use crate::ALPRDArrayExt; use crate::ALPRDFloat; use crate::RDEncoder; diff --git a/encodings/alp/src/alp_rd/slice.rs b/encodings/alp/src/alp_rd/slice.rs index 583d06960fa..a7c9891f503 100644 --- a/encodings/alp/src/alp_rd/slice.rs +++ b/encodings/alp/src/alp_rd/slice.rs @@ -10,6 +10,7 @@ use vortex_array::arrays::slice::SliceKernel; use vortex_error::VortexResult; use crate::alp_rd::ALPRDArray; +use crate::alp_rd::ALPRDArrayExt; use crate::alp_rd::ALPRDVTable; impl SliceKernel for ALPRDVTable { diff --git a/encodings/bytebool/src/array.rs b/encodings/bytebool/src/array.rs index 097dbd5e66c..5b58165528c 100644 --- a/encodings/bytebool/src/array.rs +++ b/encodings/bytebool/src/array.rs @@ -155,7 +155,7 @@ impl VTable for ByteBoolVTable { } let buffer = buffers[0].clone(); - Ok(ByteBoolArray::new(buffer, validity)) + Ok(Self::new(buffer, validity)) } fn with_children(array: &mut Self::Array, children: Vec) -> VortexResult<()> { @@ -210,10 +210,10 @@ pub struct ByteBoolVTable; impl ByteBoolVTable { pub const ID: ArrayId = ArrayId::new_ref("vortex.bytebool"); -} -impl ByteBoolArray { - pub fn new(buffer: BufferHandle, validity: Validity) -> Self { + /// Build a new `ByteBoolArray` from a buffer and validity. 
+ #[allow(clippy::new_ret_no_self)] + pub fn new(buffer: BufferHandle, validity: Validity) -> ByteBoolArray { let length = buffer.len(); if let Some(vlen) = validity.maybe_len() && length != vlen @@ -224,7 +224,7 @@ impl ByteBoolArray { vlen ); } - Self { + ByteBoolArray { common: ArrayCommon::new(length, DType::Bool(validity.nullability())), buffer, validity, @@ -232,18 +232,30 @@ impl ByteBoolArray { } // TODO(ngates): deprecate construction from vec - pub fn from_vec>(data: Vec, validity: V) -> Self { + /// Build a new `ByteBoolArray` from a vec of bools and validity. + pub fn from_vec>(data: Vec, validity: V) -> ByteBoolArray { let validity = validity.into(); // SAFETY: we are transmuting a Vec into a Vec let data: Vec = unsafe { std::mem::transmute(data) }; Self::new(BufferHandle::new_host(ByteBuffer::from(data)), validity) } +} + +/// Extension trait for [`ByteBoolArray`] methods. +pub trait ByteBoolArrayExt { + /// Returns a reference to the underlying buffer. + fn buffer(&self) -> &BufferHandle; + + /// Returns the buffer contents as a slice of bools. 
+ fn as_slice(&self) -> &[bool]; +} - pub fn buffer(&self) -> &BufferHandle { +impl ByteBoolArrayExt for ByteBoolArray { + fn buffer(&self) -> &BufferHandle { &self.buffer } - pub fn as_slice(&self) -> &[bool] { + fn as_slice(&self) -> &[bool] { // Safety: The internal buffer contains byte-sized bools unsafe { std::mem::transmute(self.buffer().as_host().as_slice()) } } @@ -266,7 +278,7 @@ impl OperationsVTable for ByteBoolVTable { impl From> for ByteBoolArray { fn from(value: Vec) -> Self { - Self::from_vec(value, Validity::AllValid) + ByteBoolVTable::from_vec(value, Validity::AllValid) } } @@ -277,7 +289,7 @@ impl From>> for ByteBoolArray { // This doesn't reallocate, and the compiler even vectorizes it let data = value.into_iter().map(Option::unwrap_or_default).collect(); - Self::from_vec(data, validity) + ByteBoolVTable::from_vec(data, validity) } } diff --git a/encodings/bytebool/src/compute.rs b/encodings/bytebool/src/compute.rs index 9b743051510..d80f4d1c973 100644 --- a/encodings/bytebool/src/compute.rs +++ b/encodings/bytebool/src/compute.rs @@ -16,6 +16,7 @@ use vortex_array::vtable::ValidityHelper; use vortex_error::VortexResult; use super::ByteBoolArray; +use super::ByteBoolArrayExt; use super::ByteBoolVTable; impl CastReduce for ByteBoolVTable { @@ -32,7 +33,7 @@ impl CastReduce for ByteBoolVTable { .cast_nullability(dtype.nullability(), array.len())?; return Ok(Some( - ByteBoolArray::new(array.buffer().clone(), new_validity).into_array(), + ByteBoolVTable::new(array.buffer().clone(), new_validity).into_array(), )); } @@ -44,7 +45,7 @@ impl CastReduce for ByteBoolVTable { impl MaskReduce for ByteBoolVTable { fn mask(array: &ByteBoolArray, mask: &ArrayRef) -> VortexResult> { Ok(Some( - ByteBoolArray::new( + ByteBoolVTable::new( array.buffer().clone(), array .validity() @@ -80,7 +81,7 @@ impl TakeExecute for ByteBoolVTable { }); Ok(Some( - ByteBoolArray::from_vec(taken_bools, validity).into_array(), + ByteBoolVTable::from_vec(taken_bools, 
validity).into_array(), )) } } diff --git a/encodings/bytebool/src/slice.rs b/encodings/bytebool/src/slice.rs index b1d93f063ec..b585f014cbf 100644 --- a/encodings/bytebool/src/slice.rs +++ b/encodings/bytebool/src/slice.rs @@ -10,12 +10,13 @@ use vortex_array::vtable::ValidityHelper; use vortex_error::VortexResult; use crate::ByteBoolArray; +use crate::ByteBoolArrayExt; use crate::ByteBoolVTable; impl SliceReduce for ByteBoolVTable { fn slice(array: &ByteBoolArray, range: Range) -> VortexResult> { Ok(Some( - ByteBoolArray::new( + ByteBoolVTable::new( array.buffer().slice(range.clone()), array.validity().slice(range)?, ) diff --git a/encodings/datetime-parts/src/array.rs b/encodings/datetime-parts/src/array.rs index c1ab6444d26..bd3e5189498 100644 --- a/encodings/datetime-parts/src/array.rs +++ b/encodings/datetime-parts/src/array.rs @@ -203,7 +203,7 @@ impl VTable for DateTimePartsVTable { len, )?; - DateTimePartsArray::try_new(dtype.clone(), days, seconds, subseconds) + Self::try_new(dtype.clone(), days, seconds, subseconds) } fn with_children(array: &mut Self::Array, children: Vec) -> VortexResult<()> { @@ -264,15 +264,14 @@ pub struct DateTimePartsVTable; impl DateTimePartsVTable { pub const ID: ArrayId = ArrayId::new_ref("vortex.datetimeparts"); -} -impl DateTimePartsArray { + /// Build a new `DateTimePartsArray` from components. 
pub fn try_new( dtype: DType, days: ArrayRef, seconds: ArrayRef, subseconds: ArrayRef, - ) -> VortexResult { + ) -> VortexResult { if !days.dtype().is_int() || (dtype.is_nullable() != days.dtype().is_nullable()) { vortex_bail!( "Expected integer with nullability {}, got {}", @@ -297,30 +296,32 @@ impl DateTimePartsArray { ); } - Ok(Self { + Ok(DateTimePartsArray { common: ArrayCommon::new(length, dtype), days, seconds, subseconds, }) } +} - pub(crate) unsafe fn new_unchecked( - dtype: DType, - days: ArrayRef, - seconds: ArrayRef, - subseconds: ArrayRef, - ) -> Self { - let len = days.len(); - Self { - common: ArrayCommon::new(len, dtype), - days, - seconds, - subseconds, - } - } +/// Extension trait for [`DateTimePartsArray`] methods. +pub trait DateTimePartsArrayExt: Sized { + /// Consumes the array and returns its parts. + fn into_parts(self) -> DateTimePartsArrayParts; + + /// Returns a reference to the days child array. + fn days(&self) -> &ArrayRef; + + /// Returns a reference to the seconds child array. + fn seconds(&self) -> &ArrayRef; - pub fn into_parts(self) -> DateTimePartsArrayParts { + /// Returns a reference to the subseconds child array. 
+ fn subseconds(&self) -> &ArrayRef; +} + +impl DateTimePartsArrayExt for DateTimePartsArray { + fn into_parts(self) -> DateTimePartsArrayParts { DateTimePartsArrayParts { dtype: self.common.into_dtype(), days: self.days, @@ -329,19 +330,36 @@ impl DateTimePartsArray { } } - pub fn days(&self) -> &ArrayRef { + fn days(&self) -> &ArrayRef { &self.days } - pub fn seconds(&self) -> &ArrayRef { + fn seconds(&self) -> &ArrayRef { &self.seconds } - pub fn subseconds(&self) -> &ArrayRef { + fn subseconds(&self) -> &ArrayRef { &self.subseconds } } +impl DateTimePartsArray { + pub(crate) unsafe fn new_unchecked( + dtype: DType, + days: ArrayRef, + seconds: ArrayRef, + subseconds: ArrayRef, + ) -> Self { + let len = days.len(); + Self { + common: ArrayCommon::new(len, dtype), + days, + seconds, + subseconds, + } + } +} + impl ValidityChild for DateTimePartsVTable { fn validity_child(array: &DateTimePartsArray) -> &ArrayRef { array.days() diff --git a/encodings/datetime-parts/src/canonical.rs b/encodings/datetime-parts/src/canonical.rs index 36ae6f01444..240e2ad93dc 100644 --- a/encodings/datetime-parts/src/canonical.rs +++ b/encodings/datetime-parts/src/canonical.rs @@ -19,6 +19,7 @@ use vortex_error::VortexResult; use vortex_error::vortex_panic; use crate::DateTimePartsArray; +use crate::DateTimePartsArrayExt; /// Decode an [Array] into a [TemporalArray]. 
/// diff --git a/encodings/datetime-parts/src/compress.rs b/encodings/datetime-parts/src/compress.rs index 65bbea1c116..c25e685d978 100644 --- a/encodings/datetime-parts/src/compress.rs +++ b/encodings/datetime-parts/src/compress.rs @@ -15,6 +15,7 @@ use vortex_error::VortexError; use vortex_error::VortexResult; use crate::DateTimePartsArray; +use crate::DateTimePartsVTable; use crate::timestamp; pub struct TemporalParts { @@ -69,7 +70,7 @@ impl TryFrom for DateTimePartsArray { seconds, subseconds, } = split_temporal(array)?; - DateTimePartsArray::try_new(DType::Extension(ext_dtype), days, seconds, subseconds) + DateTimePartsVTable::try_new(DType::Extension(ext_dtype), days, seconds, subseconds) } } diff --git a/encodings/datetime-parts/src/compute/cast.rs b/encodings/datetime-parts/src/compute/cast.rs index d507ac7ec59..0543fa55d7a 100644 --- a/encodings/datetime-parts/src/compute/cast.rs +++ b/encodings/datetime-parts/src/compute/cast.rs @@ -10,6 +10,7 @@ use vortex_array::scalar_fn::fns::cast::CastReduce; use vortex_error::VortexResult; use crate::DateTimePartsArray; +use crate::DateTimePartsArrayExt; use crate::DateTimePartsVTable; impl CastReduce for DateTimePartsVTable { @@ -19,7 +20,7 @@ impl CastReduce for DateTimePartsVTable { }; Ok(Some( - DateTimePartsArray::try_new( + DateTimePartsVTable::try_new( dtype.clone(), array .days() diff --git a/encodings/datetime-parts/src/compute/compare.rs b/encodings/datetime-parts/src/compute/compare.rs index 9ef6a1b5893..9537792e696 100644 --- a/encodings/datetime-parts/src/compute/compare.rs +++ b/encodings/datetime-parts/src/compute/compare.rs @@ -17,6 +17,7 @@ use vortex_array::scalar_fn::fns::operators::Operator; use vortex_error::VortexResult; use crate::array::DateTimePartsArray; +use crate::array::DateTimePartsArrayExt; use crate::array::DateTimePartsVTable; use crate::timestamp; @@ -312,7 +313,7 @@ mod test { Some("UTC".into()), ); - let lhs = DateTimePartsArray::try_new( + let lhs = DateTimePartsVTable::try_new( 
DType::Extension(temporal_array.ext_dtype()), PrimitiveArray::new(buffer![0i32], lhs_validity).into_array(), PrimitiveArray::new(buffer![0u32], Validity::NonNullable).into_array(), diff --git a/encodings/datetime-parts/src/compute/filter.rs b/encodings/datetime-parts/src/compute/filter.rs index 2769f31bee2..c360def8a73 100644 --- a/encodings/datetime-parts/src/compute/filter.rs +++ b/encodings/datetime-parts/src/compute/filter.rs @@ -8,12 +8,13 @@ use vortex_error::VortexResult; use vortex_mask::Mask; use crate::DateTimePartsArray; +use crate::DateTimePartsArrayExt; use crate::DateTimePartsVTable; impl FilterReduce for DateTimePartsVTable { fn filter(array: &DateTimePartsArray, mask: &Mask) -> VortexResult> { Ok(Some( - DateTimePartsArray::try_new( + DateTimePartsVTable::try_new( array.dtype().clone(), array.days().filter(mask.clone())?, array.seconds().filter(mask.clone())?, diff --git a/encodings/datetime-parts/src/compute/is_constant.rs b/encodings/datetime-parts/src/compute/is_constant.rs index 62216a09046..0ed2bfe0bd6 100644 --- a/encodings/datetime-parts/src/compute/is_constant.rs +++ b/encodings/datetime-parts/src/compute/is_constant.rs @@ -9,6 +9,7 @@ use vortex_array::register_kernel; use vortex_error::VortexResult; use crate::DateTimePartsArray; +use crate::DateTimePartsArrayExt; use crate::DateTimePartsVTable; impl IsConstantKernel for DateTimePartsVTable { diff --git a/encodings/datetime-parts/src/compute/mask.rs b/encodings/datetime-parts/src/compute/mask.rs index fb62aaed78a..2b85ce15aa5 100644 --- a/encodings/datetime-parts/src/compute/mask.rs +++ b/encodings/datetime-parts/src/compute/mask.rs @@ -8,6 +8,7 @@ use vortex_array::scalar_fn::fns::mask::MaskReduce; use vortex_error::VortexResult; use crate::DateTimePartsArray; +use crate::DateTimePartsArrayExt; use crate::DateTimePartsArrayParts; use crate::DateTimePartsVTable; @@ -21,7 +22,7 @@ impl MaskReduce for DateTimePartsVTable { } = array.clone().into_parts(); let masked_days = 
days.mask(mask.clone())?; Ok(Some( - DateTimePartsArray::try_new(dtype.as_nullable(), masked_days, seconds, subseconds)? + DateTimePartsVTable::try_new(dtype.as_nullable(), masked_days, seconds, subseconds)? .into_array(), )) } diff --git a/encodings/datetime-parts/src/compute/rules.rs b/encodings/datetime-parts/src/compute/rules.rs index e7da7d275bf..d964e07007c 100644 --- a/encodings/datetime-parts/src/compute/rules.rs +++ b/encodings/datetime-parts/src/compute/rules.rs @@ -26,6 +26,7 @@ use vortex_error::VortexExpect; use vortex_error::VortexResult; use crate::DateTimePartsArray; +use crate::DateTimePartsArrayExt; use crate::DateTimePartsVTable; use crate::timestamp; @@ -58,7 +59,7 @@ impl ArrayParentReduceRule for DTPFilterPushDownRule { return Ok(None); } - DateTimePartsArray::try_new( + DateTimePartsVTable::try_new( child.dtype().clone(), child.days().clone().filter(parent.filter_mask().clone())?, ConstantArray::new( diff --git a/encodings/datetime-parts/src/compute/slice.rs b/encodings/datetime-parts/src/compute/slice.rs index 80293c2cac0..e23f81e383d 100644 --- a/encodings/datetime-parts/src/compute/slice.rs +++ b/encodings/datetime-parts/src/compute/slice.rs @@ -9,6 +9,7 @@ use vortex_array::arrays::slice::SliceReduce; use vortex_error::VortexResult; use crate::DateTimePartsArray; +use crate::DateTimePartsArrayExt; use crate::DateTimePartsVTable; impl SliceReduce for DateTimePartsVTable { diff --git a/encodings/datetime-parts/src/compute/take.rs b/encodings/datetime-parts/src/compute/take.rs index cef10682df3..96a8aca9b39 100644 --- a/encodings/datetime-parts/src/compute/take.rs +++ b/encodings/datetime-parts/src/compute/take.rs @@ -16,6 +16,7 @@ use vortex_error::VortexResult; use vortex_error::vortex_panic; use crate::DateTimePartsArray; +use crate::DateTimePartsArrayExt; use crate::DateTimePartsVTable; fn take_datetime_parts(array: &DateTimePartsArray, indices: &ArrayRef) -> VortexResult { @@ -36,7 +37,7 @@ fn take_datetime_parts(array: 
&DateTimePartsArray, indices: &ArrayRef) -> Vortex }; if !taken_seconds.dtype().is_nullable() && !taken_subseconds.dtype().is_nullable() { - return Ok(DateTimePartsArray::try_new( + return Ok(DateTimePartsVTable::try_new( dtype, taken_days, taken_seconds, @@ -80,7 +81,7 @@ fn take_datetime_parts(array: &DateTimePartsArray, indices: &ArrayRef) -> Vortex let taken_subseconds = taken_subseconds.fill_null(subseconds_fill)?; Ok( - DateTimePartsArray::try_new(dtype, taken_days, taken_seconds, taken_subseconds)? + DateTimePartsVTable::try_new(dtype, taken_days, taken_seconds, taken_subseconds)? .into_array(), ) } diff --git a/encodings/datetime-parts/src/ops.rs b/encodings/datetime-parts/src/ops.rs index f6226298e2f..09e97bfd853 100644 --- a/encodings/datetime-parts/src/ops.rs +++ b/encodings/datetime-parts/src/ops.rs @@ -11,6 +11,7 @@ use vortex_error::VortexResult; use vortex_error::vortex_panic; use crate::DateTimePartsArray; +use crate::DateTimePartsArrayExt; use crate::DateTimePartsVTable; use crate::timestamp; use crate::timestamp::TimestampParts; diff --git a/encodings/decimal-byte-parts/src/decimal_byte_parts/compute/cast.rs b/encodings/decimal-byte-parts/src/decimal_byte_parts/compute/cast.rs index 55c08c7138f..b8934a607cf 100644 --- a/encodings/decimal-byte-parts/src/decimal_byte_parts/compute/cast.rs +++ b/encodings/decimal-byte-parts/src/decimal_byte_parts/compute/cast.rs @@ -10,6 +10,7 @@ use vortex_array::scalar_fn::fns::cast::CastReduce; use vortex_error::VortexResult; use crate::DecimalBytePartsArray; +use crate::DecimalBytePartsArrayExt; use crate::DecimalBytePartsVTable; impl CastReduce for DecimalBytePartsVTable { @@ -30,7 +31,7 @@ impl CastReduce for DecimalBytePartsVTable { .cast(array.msp().dtype().with_nullability(*target_nullability))?; return Ok(Some( - DecimalBytePartsArray::try_new(new_msp, *target_decimal)?.into_array(), + DecimalBytePartsVTable::try_new(new_msp, *target_decimal)?.into_array(), )); } @@ -53,11 +54,12 @@ mod tests { use 
vortex_buffer::buffer; use crate::DecimalBytePartsArray; + use crate::DecimalBytePartsVTable; #[test] fn test_cast_decimal_byte_parts_nullability() { let decimal_dtype = DecimalDType::new(10, 2); - let array = DecimalBytePartsArray::try_new( + let array = DecimalBytePartsVTable::try_new( buffer![100i32, 200, 300, 400].into_array(), decimal_dtype, ) @@ -81,7 +83,7 @@ mod tests { #[test] fn test_cast_decimal_byte_parts_nullable_to_non_nullable() { let decimal_dtype = DecimalDType::new(10, 2); - let array = DecimalBytePartsArray::try_new( + let array = DecimalBytePartsVTable::try_new( PrimitiveArray::from_option_iter([Some(100i32), None, Some(300)]).into_array(), decimal_dtype, ) @@ -96,24 +98,24 @@ mod tests { } #[rstest] - #[case::i32(DecimalBytePartsArray::try_new( + #[case::i32(DecimalBytePartsVTable::try_new( buffer![100i32, 200, 300, 400, 500].into_array(), DecimalDType::new(10, 2), ).unwrap())] - #[case::i64(DecimalBytePartsArray::try_new( + #[case::i64(DecimalBytePartsVTable::try_new( buffer![1000i64, 2000, 3000, 4000].into_array(), DecimalDType::new(19, 4), ).unwrap())] - #[case::nullable(DecimalBytePartsArray::try_new( + #[case::nullable(DecimalBytePartsVTable::try_new( PrimitiveArray::from_option_iter([Some(100i32), None, Some(300), Some(400), None]) .into_array(), DecimalDType::new(10, 2), ).unwrap())] - #[case::single(DecimalBytePartsArray::try_new( + #[case::single(DecimalBytePartsVTable::try_new( buffer![42i32].into_array(), DecimalDType::new(5, 1), ).unwrap())] - #[case::negative(DecimalBytePartsArray::try_new( + #[case::negative(DecimalBytePartsVTable::try_new( buffer![-100i32, -200, 300, -400, 500].into_array(), DecimalDType::new(10, 2), ).unwrap())] diff --git a/encodings/decimal-byte-parts/src/decimal_byte_parts/compute/compare.rs b/encodings/decimal-byte-parts/src/decimal_byte_parts/compute/compare.rs index 2babae469d0..54d91dd16a8 100644 --- a/encodings/decimal-byte-parts/src/decimal_byte_parts/compute/compare.rs +++ 
b/encodings/decimal-byte-parts/src/decimal_byte_parts/compute/compare.rs @@ -158,13 +158,13 @@ mod tests { use vortex_buffer::buffer; use vortex_error::VortexResult; - use crate::DecimalBytePartsArray; + use crate::DecimalBytePartsVTable; #[test] fn compare_decimal_const() { let decimal_dtype = DecimalDType::new(8, 2); let dtype = DType::Decimal(decimal_dtype, Nullability::Nullable); - let lhs = DecimalBytePartsArray::try_new( + let lhs = DecimalBytePartsVTable::try_new( PrimitiveArray::new(buffer![100i32, 200i32, 400i32], Validity::AllValid).into_array(), decimal_dtype, ) @@ -184,7 +184,7 @@ mod tests { #[test] fn test_byteparts_compare_nullable() -> VortexResult<()> { let decimal_type = DecimalDType::new(19, -11); - let lhs = DecimalBytePartsArray::try_new( + let lhs = DecimalBytePartsVTable::try_new( PrimitiveArray::new( buffer![1i64, 2i64, 3i64, 4i64], Validity::Array(BoolArray::from_iter([false, true, true, true]).into_array()), @@ -215,7 +215,7 @@ mod tests { fn compare_decimal_const_unconvertible_comparison() { let decimal_dtype = DecimalDType::new(40, 2); let dtype = DType::Decimal(decimal_dtype, Nullability::Nullable); - let lhs = DecimalBytePartsArray::try_new( + let lhs = DecimalBytePartsVTable::try_new( PrimitiveArray::new(buffer![100i32, 200i32, 400i32], Validity::AllValid).into_array(), decimal_dtype, ) diff --git a/encodings/decimal-byte-parts/src/decimal_byte_parts/compute/filter.rs b/encodings/decimal-byte-parts/src/decimal_byte_parts/compute/filter.rs index be10cf3f882..832331b1848 100644 --- a/encodings/decimal-byte-parts/src/decimal_byte_parts/compute/filter.rs +++ b/encodings/decimal-byte-parts/src/decimal_byte_parts/compute/filter.rs @@ -8,11 +8,12 @@ use vortex_error::VortexResult; use vortex_mask::Mask; use crate::DecimalBytePartsArray; +use crate::DecimalBytePartsArrayExt; use crate::DecimalBytePartsVTable; impl FilterReduce for DecimalBytePartsVTable { fn filter(array: &DecimalBytePartsArray, mask: &Mask) -> VortexResult> { - 
DecimalBytePartsArray::try_new(array.msp.filter(mask.clone())?, *array.decimal_dtype()) + DecimalBytePartsVTable::try_new(array.msp.filter(mask.clone())?, *array.decimal_dtype()) .map(|d| Some(d.into_array())) } } @@ -25,7 +26,7 @@ mod test { use vortex_array::dtype::DecimalDType; use vortex_buffer::buffer; - use crate::DecimalBytePartsArray; + use crate::DecimalBytePartsVTable; #[test] fn test_filter_decimal_byte_parts() { @@ -33,7 +34,7 @@ mod test { let msp = buffer![100i32, 200, 300, 400, 500].into_array(); let decimal_dtype = DecimalDType::new(8, 2); - let array = DecimalBytePartsArray::try_new(msp, decimal_dtype).unwrap(); + let array = DecimalBytePartsVTable::try_new(msp, decimal_dtype).unwrap(); test_filter_conformance(&array.into_array()); // Test with nullable values @@ -41,7 +42,7 @@ mod test { .into_array(); let decimal_dtype = DecimalDType::new(18, 4); - let array = DecimalBytePartsArray::try_new(msp, decimal_dtype).unwrap(); + let array = DecimalBytePartsVTable::try_new(msp, decimal_dtype).unwrap(); test_filter_conformance(&array.into_array()); } } diff --git a/encodings/decimal-byte-parts/src/decimal_byte_parts/compute/mask.rs b/encodings/decimal-byte-parts/src/decimal_byte_parts/compute/mask.rs index 8fd3557499d..d23be82e5bf 100644 --- a/encodings/decimal-byte-parts/src/decimal_byte_parts/compute/mask.rs +++ b/encodings/decimal-byte-parts/src/decimal_byte_parts/compute/mask.rs @@ -10,6 +10,7 @@ use vortex_array::scalar_fn::fns::mask::MaskReduce; use vortex_error::VortexResult; use crate::DecimalBytePartsArray; +use crate::DecimalBytePartsArrayExt; use crate::DecimalBytePartsVTable; impl MaskReduce for DecimalBytePartsVTable { @@ -20,7 +21,7 @@ impl MaskReduce for DecimalBytePartsVTable { [array.msp.clone(), mask.clone()], )?; Ok(Some( - DecimalBytePartsArray::try_new(masked_msp, *array.decimal_dtype())?.into_array(), + DecimalBytePartsVTable::try_new(masked_msp, *array.decimal_dtype())?.into_array(), )) } } diff --git 
a/encodings/decimal-byte-parts/src/decimal_byte_parts/compute/mod.rs b/encodings/decimal-byte-parts/src/decimal_byte_parts/compute/mod.rs index 7b0cf2cffcb..58772522256 100644 --- a/encodings/decimal-byte-parts/src/decimal_byte_parts/compute/mod.rs +++ b/encodings/decimal-byte-parts/src/decimal_byte_parts/compute/mod.rs @@ -19,50 +19,51 @@ mod tests { use vortex_buffer::buffer; use crate::DecimalBytePartsArray; + use crate::DecimalBytePartsVTable; #[rstest] // Basic decimal byte parts arrays - #[case::decimal_i32(DecimalBytePartsArray::try_new( + #[case::decimal_i32(DecimalBytePartsVTable::try_new( buffer![100i32, 200, 300, 400, 500].into_array(), DecimalDType::new(10, 2) ).unwrap())] - #[case::decimal_i64(DecimalBytePartsArray::try_new( + #[case::decimal_i64(DecimalBytePartsVTable::try_new( buffer![1000i64, 2000, 3000, 4000, 5000].into_array(), DecimalDType::new(19, 4) ).unwrap())] // Nullable arrays - #[case::decimal_nullable_i32(DecimalBytePartsArray::try_new( + #[case::decimal_nullable_i32(DecimalBytePartsVTable::try_new( PrimitiveArray::from_option_iter([Some(100i32), None, Some(300), Some(400), None]).into_array(), DecimalDType::new(10, 2) ).unwrap())] - #[case::decimal_nullable_i64(DecimalBytePartsArray::try_new( + #[case::decimal_nullable_i64(DecimalBytePartsVTable::try_new( PrimitiveArray::from_option_iter([Some(1000i64), None, Some(3000), Some(4000), None]).into_array(), DecimalDType::new(19, 4) ).unwrap())] // Different precision/scale combinations - #[case::decimal_high_precision(DecimalBytePartsArray::try_new( + #[case::decimal_high_precision(DecimalBytePartsVTable::try_new( buffer![123456789i32, 987654321, -123456789].into_array(), DecimalDType::new(38, 10) ).unwrap())] - #[case::decimal_zero_scale(DecimalBytePartsArray::try_new( + #[case::decimal_zero_scale(DecimalBytePartsVTable::try_new( buffer![100i32, 200, 300].into_array(), DecimalDType::new(10, 0) ).unwrap())] // Edge cases - #[case::decimal_single(DecimalBytePartsArray::try_new( + 
#[case::decimal_single(DecimalBytePartsVTable::try_new( buffer![42i32].into_array(), DecimalDType::new(5, 1) ).unwrap())] - #[case::decimal_negative(DecimalBytePartsArray::try_new( + #[case::decimal_negative(DecimalBytePartsVTable::try_new( buffer![-100i32, -200, 300, -400, 500].into_array(), DecimalDType::new(10, 2) ).unwrap())] // Large arrays - #[case::decimal_large(DecimalBytePartsArray::try_new( + #[case::decimal_large(DecimalBytePartsVTable::try_new( PrimitiveArray::from_iter((0..1500).map(|i| i * 100)).into_array(), DecimalDType::new(10, 2) ).unwrap())] - #[case::decimal_large_i64(DecimalBytePartsArray::try_new( + #[case::decimal_large_i64(DecimalBytePartsVTable::try_new( PrimitiveArray::from_iter((0..2000i64).map(|i| i * 1000000)).into_array(), DecimalDType::new(19, 6) ).unwrap())] diff --git a/encodings/decimal-byte-parts/src/decimal_byte_parts/compute/take.rs b/encodings/decimal-byte-parts/src/decimal_byte_parts/compute/take.rs index bf75bbb4768..af14ce7f493 100644 --- a/encodings/decimal-byte-parts/src/decimal_byte_parts/compute/take.rs +++ b/encodings/decimal-byte-parts/src/decimal_byte_parts/compute/take.rs @@ -9,6 +9,7 @@ use vortex_array::arrays::dict::TakeExecute; use vortex_error::VortexResult; use crate::DecimalBytePartsArray; +use crate::DecimalBytePartsArrayExt; use crate::DecimalBytePartsVTable; impl TakeExecute for DecimalBytePartsVTable { @@ -17,7 +18,7 @@ impl TakeExecute for DecimalBytePartsVTable { indices: &ArrayRef, _ctx: &mut ExecutionCtx, ) -> VortexResult> { - DecimalBytePartsArray::try_new(array.msp.take(indices.to_array())?, *array.decimal_dtype()) + DecimalBytePartsVTable::try_new(array.msp.take(indices.to_array())?, *array.decimal_dtype()) .map(|a| Some(a.into_array())) } } diff --git a/encodings/decimal-byte-parts/src/decimal_byte_parts/mod.rs b/encodings/decimal-byte-parts/src/decimal_byte_parts/mod.rs index 2d28829608b..b76815650ee 100644 --- a/encodings/decimal-byte-parts/src/decimal_byte_parts/mod.rs +++ 
b/encodings/decimal-byte-parts/src/decimal_byte_parts/mod.rs @@ -168,7 +168,7 @@ impl VTable for DecimalBytePartsVTable { "lower_part_count > 0 not currently supported" ); - DecimalBytePartsArray::try_new(msp, *decimal_dtype) + Self::try_new(msp, *decimal_dtype) } fn with_children(array: &mut Self::Array, children: Vec) -> VortexResult<()> { @@ -224,45 +224,41 @@ pub struct DecimalBytePartsArrayParts { pub dtype: DType, } -impl DecimalBytePartsArray { - pub fn try_new(msp: ArrayRef, decimal_dtype: DecimalDType) -> VortexResult { - if !msp.dtype().is_signed_int() { - vortex_bail!("decimal bytes parts, first part must be a signed array") - } - - let nullable = msp.dtype().nullability(); - let len = msp.len(); - Ok(Self { - msp, - _lower_parts: Vec::new(), - common: ArrayCommon::new(len, DType::Decimal(decimal_dtype, nullable)), - }) - } +/// Extension trait for [`DecimalBytePartsArray`] methods. +pub trait DecimalBytePartsArrayExt: Sized { + /// If `_lower_parts` is supported check all calls use this correctly. + fn into_parts(self) -> DecimalBytePartsArrayParts; - pub(crate) unsafe fn new_unchecked(msp: ArrayRef, decimal_dtype: DecimalDType) -> Self { - let nullable = msp.dtype().nullability(); - let len = msp.len(); - Self { - msp, - _lower_parts: Vec::new(), - common: ArrayCommon::new(len, DType::Decimal(decimal_dtype, nullable)), - } - } + /// Returns the decimal dtype of this array. + fn decimal_dtype(&self) -> &DecimalDType; +} - /// If `_lower_parts` is supported check all calls use this correctly. 
- pub fn into_parts(self) -> DecimalBytePartsArrayParts { +impl DecimalBytePartsArrayExt for DecimalBytePartsArray { + fn into_parts(self) -> DecimalBytePartsArrayParts { DecimalBytePartsArrayParts { msp: self.msp, dtype: self.common.into_dtype(), } } - pub fn decimal_dtype(&self) -> &DecimalDType { + fn decimal_dtype(&self) -> &DecimalDType { self.common .dtype() .as_decimal_opt() .vortex_expect("must be a decimal dtype") } +} + +impl DecimalBytePartsArray { + pub(crate) unsafe fn new_unchecked(msp: ArrayRef, decimal_dtype: DecimalDType) -> Self { + let nullable = msp.dtype().nullability(); + let len = msp.len(); + Self { + msp, + _lower_parts: Vec::new(), + common: ArrayCommon::new(len, DType::Decimal(decimal_dtype, nullable)), + } + } pub(crate) fn msp(&self) -> &ArrayRef { &self.msp @@ -274,6 +270,24 @@ pub struct DecimalBytePartsVTable; impl DecimalBytePartsVTable { pub const ID: ArrayId = ArrayId::new_ref("vortex.decimal_byte_parts"); + + /// Build a new `DecimalBytePartsArray` from components. + pub fn try_new( + msp: ArrayRef, + decimal_dtype: DecimalDType, + ) -> VortexResult { + if !msp.dtype().is_signed_int() { + vortex_bail!("decimal bytes parts, first part must be a signed array") + } + + let nullable = msp.dtype().nullability(); + let len = msp.len(); + Ok(DecimalBytePartsArray { + msp, + _lower_parts: Vec::new(), + common: ArrayCommon::new(len, DType::Decimal(decimal_dtype, nullable)), + }) + } } /// Converts a DecimalBytePartsArray to its canonical DecimalArray representation. 
@@ -338,13 +352,13 @@ mod tests { use vortex_array::validity::Validity; use vortex_buffer::buffer; - use crate::DecimalBytePartsArray; + use crate::DecimalBytePartsVTable; #[test] fn test_scalar_at_decimal_parts() { let decimal_dtype = DecimalDType::new(8, 2); let dtype = DType::Decimal(decimal_dtype, Nullability::Nullable); - let array = DecimalBytePartsArray::try_new( + let array = DecimalBytePartsVTable::try_new( PrimitiveArray::new( buffer![100i32, 200i32, 400i32], Validity::Array(BoolArray::from_iter(vec![false, true, true]).into_array()), diff --git a/encodings/decimal-byte-parts/src/decimal_byte_parts/rules.rs b/encodings/decimal-byte-parts/src/decimal_byte_parts/rules.rs index 14bf0a076d9..1157fb36ef6 100644 --- a/encodings/decimal-byte-parts/src/decimal_byte_parts/rules.rs +++ b/encodings/decimal-byte-parts/src/decimal_byte_parts/rules.rs @@ -15,6 +15,7 @@ use vortex_array::scalar_fn::fns::mask::MaskReduceAdaptor; use vortex_error::VortexResult; use crate::DecimalBytePartsArray; +use crate::DecimalBytePartsArrayExt; use crate::DecimalBytePartsVTable; pub(super) const PARENT_RULES: ParentRuleSet = ParentRuleSet::new(&[ @@ -45,7 +46,7 @@ impl ArrayParentReduceRule for DecimalBytePartsFilterPus let new_msp = child.msp.filter(parent.filter_mask().clone())?; let new_child = - DecimalBytePartsArray::try_new(new_msp, *child.decimal_dtype())?.into_array(); + DecimalBytePartsVTable::try_new(new_msp, *child.decimal_dtype())?.into_array(); Ok(Some(new_child)) } } diff --git a/encodings/decimal-byte-parts/src/decimal_byte_parts/slice.rs b/encodings/decimal-byte-parts/src/decimal_byte_parts/slice.rs index e445d08a4ef..a119a4a7771 100644 --- a/encodings/decimal-byte-parts/src/decimal_byte_parts/slice.rs +++ b/encodings/decimal-byte-parts/src/decimal_byte_parts/slice.rs @@ -9,6 +9,7 @@ use vortex_array::arrays::slice::SliceReduce; use vortex_error::VortexResult; use crate::DecimalBytePartsArray; +use crate::DecimalBytePartsArrayExt; use crate::DecimalBytePartsVTable; 
impl SliceReduce for DecimalBytePartsVTable { diff --git a/encodings/fastlanes/benches/bitpacking_take.rs b/encodings/fastlanes/benches/bitpacking_take.rs index 9bf7ea4db79..d27cd80c45b 100644 --- a/encodings/fastlanes/benches/bitpacking_take.rs +++ b/encodings/fastlanes/benches/bitpacking_take.rs @@ -19,6 +19,7 @@ use vortex_array::compute::warm_up_vtables; use vortex_array::validity::Validity; use vortex_buffer::Buffer; use vortex_buffer::buffer; +use vortex_fastlanes::BitPackedArrayExt; use vortex_fastlanes::bitpack_compress::bitpack_to_best_bit_width; fn main() { diff --git a/encodings/fastlanes/benches/compute_between.rs b/encodings/fastlanes/benches/compute_between.rs index d3bc7cc6259..13887e84ff6 100644 --- a/encodings/fastlanes/benches/compute_between.rs +++ b/encodings/fastlanes/benches/compute_between.rs @@ -5,7 +5,8 @@ use num_traits::NumCast; use rand::Rng; use rand::rngs::StdRng; -use vortex_alp::ALPArray; +use vortex_alp::ALPArrayExt; +use vortex_alp::ALPVTable; use vortex_alp::alp_encode; use vortex_array::ArrayRef; use vortex_array::IntoArray; @@ -56,7 +57,7 @@ fn generate_alp_bit_pack_primitive_array( let bp = bitpack_to_best_bit_width(&encoded) .vortex_expect("") .into_array(); - ALPArray::new(bp, alp.exponents(), None).into_array() + ALPVTable::new(bp, alp.exponents(), None).into_array() } const BENCH_ARGS: &[usize] = &[2 << 10, 2 << 13, 2 << 14]; diff --git a/encodings/fastlanes/src/bitpacking/array/bitpack_compress.rs b/encodings/fastlanes/src/bitpacking/array/bitpack_compress.rs index c0f3e90f4fb..80a06735326 100644 --- a/encodings/fastlanes/src/bitpacking/array/bitpack_compress.rs +++ b/encodings/fastlanes/src/bitpacking/array/bitpack_compress.rs @@ -439,6 +439,8 @@ mod test { use vortex_session::VortexSession; use super::*; + use crate::BitPackedArrayExt; + use crate::BitPackedVTable; use crate::bitpack_compress::test_harness::make_array; static SESSION: LazyLock = @@ -463,7 +465,7 @@ mod test { Validity::from_iter(valid_values), ); 
assert!(values.ptype().is_unsigned_int()); - let compressed = BitPackedArray::encode(&values.into_array(), 4).unwrap(); + let compressed = BitPackedVTable::encode(&values.into_array(), 4).unwrap(); assert!(compressed.patches().is_none()); assert_eq!( (0..(1 << 4)).collect::>(), @@ -482,7 +484,7 @@ mod test { let array = PrimitiveArray::new(values, Validity::AllValid); assert!(array.ptype().is_signed_int()); - let err = BitPackedArray::encode(&array.into_array(), 1024u32.ilog2() as u8).unwrap_err(); + let err = BitPackedVTable::encode(&array.into_array(), 1024u32.ilog2() as u8).unwrap_err(); assert!(matches!(err, VortexError::InvalidArgument(_, _))); } diff --git a/encodings/fastlanes/src/bitpacking/array/bitpack_decompress.rs b/encodings/fastlanes/src/bitpacking/array/bitpack_decompress.rs index 392c61bcf4c..d196ffb99c9 100644 --- a/encodings/fastlanes/src/bitpacking/array/bitpack_decompress.rs +++ b/encodings/fastlanes/src/bitpacking/array/bitpack_decompress.rs @@ -25,6 +25,7 @@ use vortex_error::vortex_panic; use vortex_mask::Mask; use crate::BitPackedArray; +use crate::BitPackedArrayExt; use crate::unpack_iter::BitPacked; /// Unpacks a bit-packed array into a primitive array. 
@@ -255,6 +256,7 @@ mod tests { use vortex_session::VortexSession; use super::*; + use crate::BitPackedVTable; use crate::bitpack_compress::bitpack_encode; static SESSION: LazyLock = @@ -262,7 +264,7 @@ mod tests { fn compression_roundtrip(n: usize) { let values = PrimitiveArray::from_iter((0..n).map(|i| (i % 2047) as u16)); - let compressed = BitPackedArray::encode(&values.clone().into_array(), 11).unwrap(); + let compressed = BitPackedVTable::encode(&values.clone().into_array(), 11).unwrap(); assert_arrays_eq!(compressed, values); values diff --git a/encodings/fastlanes/src/bitpacking/array/mod.rs b/encodings/fastlanes/src/bitpacking/array/mod.rs index 0f9c9f1fdf4..85aa8b48694 100644 --- a/encodings/fastlanes/src/bitpacking/array/mod.rs +++ b/encodings/fastlanes/src/bitpacking/array/mod.rs @@ -3,8 +3,6 @@ use fastlanes::BitPacking; use vortex_array::ArrayCommon; -use vortex_array::ArrayRef; -use vortex_array::arrays::PrimitiveVTable; use vortex_array::buffer::BufferHandle; use vortex_array::dtype::DType; use vortex_array::dtype::NativePType; @@ -12,14 +10,12 @@ use vortex_array::dtype::PType; use vortex_array::patches::Patches; use vortex_array::validity::Validity; use vortex_error::VortexResult; -use vortex_error::vortex_bail; use vortex_error::vortex_ensure; pub mod bitpack_compress; pub mod bitpack_decompress; pub mod unpack_iter; -use crate::bitpack_compress::bitpack_encode; use crate::unpack_iter::BitPacked; use crate::unpack_iter::BitUnpackedChunks; @@ -63,7 +59,7 @@ impl BitPackedArray { /// This invariant is upheld by the compressor, but callers must ensure this if they wish to /// construct a new `BitPackedArray` from parts. /// - /// See also the [`encode`][Self::encode] method on this type for a safe path to create a new + /// See also [`BitPackedVTable::encode`][crate::BitPackedVTable::encode] for a safe path to create a new /// bit-packed array. 
pub(crate) unsafe fn new_unchecked( packed: BufferHandle, @@ -83,131 +79,119 @@ impl BitPackedArray { validity, } } +} - /// A safe constructor for a `BitPackedArray` from its components: - /// - /// * `packed` is ByteBuffer holding the compressed data that was packed with FastLanes - /// bit-packing to a `bit_width` bits per value. `length` is the length of the original - /// vector. Note that the packed is padded with zeros to the next multiple of 1024 elements - /// if `length` is not divisible by 1024. - /// * `ptype` of the original data - /// * `validity` to track any nulls - /// * `patches` optionally provided for values that did not pack - /// - /// Any failure in validation will result in an error. - /// - /// # Validation - /// - /// * The `ptype` must be an integer - /// * `validity` must have `length` len - /// * Any patches must have any `array_len` equal to `length` - /// * The `packed` buffer must be exactly sized to hold `length` values of `bit_width` rounded - /// up to the next multiple of 1024. - /// - /// Any violation of these preconditions will result in an error. 
- pub fn try_new( - packed: BufferHandle, - ptype: PType, - validity: Validity, - patches: Option, - bit_width: u8, - length: usize, - offset: u16, - ) -> VortexResult { - Self::validate( - &packed, - ptype, - &validity, - patches.as_ref(), - bit_width, - length, - offset, - )?; +pub(crate) fn validate( + packed: &BufferHandle, + ptype: PType, + validity: &Validity, + patches: Option<&Patches>, + bit_width: u8, + length: usize, + offset: u16, +) -> VortexResult<()> { + vortex_ensure!(ptype.is_int(), MismatchedTypes: "integer", ptype); + vortex_ensure!(bit_width <= 64, "Unsupported bit width {bit_width}"); + + if let Some(validity_len) = validity.maybe_len() { + vortex_ensure!( + validity_len == length, + "BitPackedArray validity length {validity_len} != array length {length}", + ); + } - let dtype = DType::Primitive(ptype, validity.nullability()); + // Validate offset for sliced arrays + vortex_ensure!( + offset < 1024, + "Offset must be less than the full block i.e., 1024, got {offset}" + ); - // SAFETY: all components validated above - unsafe { - Ok(Self::new_unchecked( - packed, dtype, validity, patches, bit_width, length, offset, - )) - } + // Validate patches + if let Some(patches) = patches { + validate_patches(patches, ptype, length)?; } - fn validate( - packed: &BufferHandle, - ptype: PType, - validity: &Validity, - patches: Option<&Patches>, - bit_width: u8, - length: usize, - offset: u16, - ) -> VortexResult<()> { - vortex_ensure!(ptype.is_int(), MismatchedTypes: "integer", ptype); - vortex_ensure!(bit_width <= 64, "Unsupported bit width {bit_width}"); - - if let Some(validity_len) = validity.maybe_len() { - vortex_ensure!( - validity_len == length, - "BitPackedArray validity length {validity_len} != array length {length}", - ); - } + // Validate packed buffer + let expected_packed_len = + (length + offset as usize).div_ceil(1024) * (128 * bit_width as usize); + vortex_ensure!( + packed.len() == expected_packed_len, + "Expected {} packed bytes, got {}", + 
expected_packed_len, + packed.len() + ); + + Ok(()) +} - // Validate offset for sliced arrays - vortex_ensure!( - offset < 1024, - "Offset must be less than the full block i.e., 1024, got {offset}" - ); +fn validate_patches(patches: &Patches, ptype: PType, len: usize) -> VortexResult<()> { + // Ensure that array and patches have same ptype + vortex_ensure!( + patches.dtype().eq_ignore_nullability(ptype.into()), + "Patches DType {} does not match BitPackedArray dtype {}", + patches.dtype().as_nonnullable(), + ptype + ); + + vortex_ensure!( + patches.array_len() == len, + "BitPackedArray patches length {} != expected {len}", + patches.array_len(), + ); + + Ok(()) +} - // Validate patches - if let Some(patches) = patches { - Self::validate_patches(patches, ptype, length)?; - } +/// Extension trait for [`BitPackedArray`] methods. +pub trait BitPackedArrayExt: Sized { + /// Returns the primitive type of the array. + fn ptype(&self) -> PType; - // Validate packed buffer - let expected_packed_len = - (length + offset as usize).div_ceil(1024) * (128 * bit_width as usize); - vortex_ensure!( - packed.len() == expected_packed_len, - "Expected {} packed bytes, got {}", - expected_packed_len, - packed.len() - ); + /// Underlying bit packed values as byte array + fn packed(&self) -> &BufferHandle; - Ok(()) - } + /// Access the slice of packed values as an array of `T` + fn packed_slice(&self) -> &[T]; - fn validate_patches(patches: &Patches, ptype: PType, len: usize) -> VortexResult<()> { - // Ensure that array and patches have same ptype - vortex_ensure!( - patches.dtype().eq_ignore_nullability(ptype.into()), - "Patches DType {} does not match BitPackedArray dtype {}", - patches.dtype().as_nonnullable(), - ptype - ); + /// Accessor for bit unpacked chunks + fn unpacked_chunks(&self) -> BitUnpackedChunks; - vortex_ensure!( - patches.array_len() == len, - "BitPackedArray patches length {} != expected {len}", - patches.array_len(), - ); + /// Bit-width of the packed values + fn 
bit_width(&self) -> u8; - Ok(()) - } + /// Access the patches array. + /// + /// If present, patches MUST be a `SparseArray` with equal-length to this array, and whose + /// indices indicate the locations of patches. The indices must have non-zero length. + fn patches(&self) -> Option<&Patches>; + + /// Replace the patches on this array. + fn replace_patches(&mut self, patches: Option); + + /// Returns the offset within the first block. + fn offset(&self) -> u16; - pub fn ptype(&self) -> PType { + /// Calculate the maximum value that **can** be contained by this array, given its bit-width. + /// + /// Note that this value need not actually be present in the array. + fn max_packed_value(&self) -> usize; + + /// Decompose this array into its constituent parts. + fn into_parts(self) -> BitPackedArrayParts; +} + +impl BitPackedArrayExt for BitPackedArray { + fn ptype(&self) -> PType { self.common.dtype().as_ptype() } - /// Underlying bit packed values as byte array #[inline] - pub fn packed(&self) -> &BufferHandle { + fn packed(&self) -> &BufferHandle { &self.packed } - /// Access the slice of packed values as an array of `T` #[inline] - pub fn packed_slice(&self) -> &[T] { + fn packed_slice(&self) -> &[T] { let packed_bytes = self.packed().as_host(); let packed_ptr: *const T = packed_bytes.as_ptr().cast(); // Return number of elements of type `T` packed in the buffer @@ -219,8 +203,7 @@ impl BitPackedArray { unsafe { std::slice::from_raw_parts(packed_ptr, packed_len) } } - /// Accessor for bit unpacked chunks - pub fn unpacked_chunks(&self) -> BitUnpackedChunks { + fn unpacked_chunks(&self) -> BitUnpackedChunks { assert_eq!( T::PTYPE, self.ptype(), @@ -229,59 +212,31 @@ impl BitPackedArray { BitUnpackedChunks::new(self) } - /// Bit-width of the packed values #[inline] - pub fn bit_width(&self) -> u8 { + fn bit_width(&self) -> u8 { self.bit_width } - /// Access the patches array. 
- /// - /// If present, patches MUST be a `SparseArray` with equal-length to this array, and whose - /// indices indicate the locations of patches. The indices must have non-zero length. #[inline] - pub fn patches(&self) -> Option<&Patches> { + fn patches(&self) -> Option<&Patches> { self.patches.as_ref() } - pub fn replace_patches(&mut self, patches: Option) { + fn replace_patches(&mut self, patches: Option) { self.patches = patches; } #[inline] - pub fn offset(&self) -> u16 { + fn offset(&self) -> u16 { self.offset } - /// Bit-pack an array of primitive integers down to the target bit-width using the FastLanes - /// SIMD-accelerated packing kernels. - /// - /// # Errors - /// - /// If the provided array is not an integer type, an error will be returned. - /// - /// If the provided array contains negative values, an error will be returned. - /// - /// If the requested bit-width for packing is larger than the array's native width, an - /// error will be returned. - // FIXME(ngates): take a PrimitiveArray - pub fn encode(array: &ArrayRef, bit_width: u8) -> VortexResult { - if let Some(parray) = array.as_opt::() { - bitpack_encode(parray, bit_width, None) - } else { - vortex_bail!(InvalidArgument: "Bitpacking can only encode primitive arrays"); - } - } - - /// Calculate the maximum value that **can** be contained by this array, given its bit-width. - /// - /// Note that this value need not actually be present in the array. 
#[inline] - pub fn max_packed_value(&self) -> usize { + fn max_packed_value(&self) -> usize { (1 << self.bit_width()) - 1 } - pub fn into_parts(self) -> BitPackedArrayParts { + fn into_parts(self) -> BitPackedArrayParts { BitPackedArrayParts { offset: self.offset, bit_width: self.bit_width, @@ -301,7 +256,8 @@ mod test { use vortex_array::assert_arrays_eq; use vortex_buffer::Buffer; - use crate::BitPackedArray; + use crate::BitPackedArrayExt; + use crate::BitPackedVTable; #[test] fn test_encode() { @@ -315,7 +271,7 @@ mod test { Some(u64::MAX), ]; let uncompressed = PrimitiveArray::from_option_iter(values); - let packed = BitPackedArray::encode(&uncompressed.into_array(), 1).unwrap(); + let packed = BitPackedVTable::encode(&uncompressed.into_array(), 1).unwrap(); let expected = PrimitiveArray::from_option_iter(values); assert_arrays_eq!(packed.to_primitive(), expected); } @@ -324,9 +280,9 @@ mod test { fn test_encode_too_wide() { let values = [Some(1u8), None, Some(1), None, Some(1), None]; let uncompressed = PrimitiveArray::from_option_iter(values); - let _packed = BitPackedArray::encode(&uncompressed.clone().into_array(), 8) + let _packed = BitPackedVTable::encode(&uncompressed.clone().into_array(), 8) .expect_err("Cannot pack value into the same width"); - let _packed = BitPackedArray::encode(&uncompressed.into_array(), 9) + let _packed = BitPackedVTable::encode(&uncompressed.into_array(), 9) .expect_err("Cannot pack value into larger width"); } @@ -335,7 +291,7 @@ mod test { let values: Buffer = (0i32..=512).collect(); let parray = values.clone().into_array(); - let packed_with_patches = BitPackedArray::encode(&parray, 9).unwrap(); + let packed_with_patches = BitPackedVTable::encode(&parray, 9).unwrap(); assert!(packed_with_patches.patches().is_some()); assert_arrays_eq!( packed_with_patches.to_primitive(), diff --git a/encodings/fastlanes/src/bitpacking/array/unpack_iter.rs b/encodings/fastlanes/src/bitpacking/array/unpack_iter.rs index 
93c44e59172..3b353eba408 100644 --- a/encodings/fastlanes/src/bitpacking/array/unpack_iter.rs +++ b/encodings/fastlanes/src/bitpacking/array/unpack_iter.rs @@ -13,6 +13,7 @@ use vortex_array::dtype::PhysicalPType; use vortex_buffer::ByteBuffer; use crate::BitPackedArray; +use crate::BitPackedArrayExt; const CHUNK_SIZE: usize = 1024; @@ -54,10 +55,11 @@ impl> UnpackStrategy for BitPackingStr /// use lending_iterator::prelude::LendingIterator; /// use vortex_array::IntoArray; /// use vortex_buffer::buffer; -/// use vortex_fastlanes::BitPackedArray; +/// use vortex_fastlanes::BitPackedArrayExt; +/// use vortex_fastlanes::BitPackedVTable; /// use vortex_fastlanes::unpack_iter::BitUnpackedChunks; /// -/// let array = BitPackedArray::encode(&buffer![2, 3, 4, 5].into_array(), 2).unwrap(); +/// let array = BitPackedVTable::encode(&buffer![2, 3, 4, 5].into_array(), 2).unwrap(); /// let mut unpacked_chunks: BitUnpackedChunks = array.unpacked_chunks(); /// /// if let Some(header) = unpacked_chunks.initial() { diff --git a/encodings/fastlanes/src/bitpacking/compute/cast.rs b/encodings/fastlanes/src/bitpacking/compute/cast.rs index 42cd28b63b4..b808532ccbf 100644 --- a/encodings/fastlanes/src/bitpacking/compute/cast.rs +++ b/encodings/fastlanes/src/bitpacking/compute/cast.rs @@ -11,6 +11,7 @@ use vortex_array::vtable::ValidityHelper; use vortex_error::VortexResult; use crate::bitpacking::BitPackedArray; +use crate::bitpacking::BitPackedArrayExt; use crate::bitpacking::BitPackedVTable; impl CastReduce for BitPackedVTable { @@ -21,7 +22,7 @@ impl CastReduce for BitPackedVTable { .clone() .cast_nullability(dtype.nullability(), array.len())?; return Ok(Some( - BitPackedArray::try_new( + BitPackedVTable::try_new( array.packed().clone(), dtype.as_ptype(), new_validity, @@ -64,11 +65,12 @@ mod tests { use vortex_buffer::buffer; use crate::BitPackedArray; + use crate::BitPackedVTable; #[test] fn test_cast_bitpacked_u8_to_u32() { let packed = - BitPackedArray::encode(&buffer![10u8, 20, 
30, 40, 50, 60].into_array(), 6).unwrap(); + BitPackedVTable::encode(&buffer![10u8, 20, 30, 40, 50, 60].into_array(), 6).unwrap(); let casted = packed .into_array() @@ -88,7 +90,7 @@ mod tests { #[test] fn test_cast_bitpacked_nullable() { let values = PrimitiveArray::from_option_iter([Some(5u16), None, Some(10), Some(15), None]); - let packed = BitPackedArray::encode(&values.into_array(), 4).unwrap(); + let packed = BitPackedVTable::encode(&values.into_array(), 4).unwrap(); let casted = packed .into_array() @@ -101,10 +103,10 @@ mod tests { } #[rstest] - #[case(BitPackedArray::encode(&buffer![0u8, 10, 20, 30, 40, 50, 60, 63].into_array(), 6).unwrap())] - #[case(BitPackedArray::encode(&buffer![0u16, 100, 200, 300, 400, 500].into_array(), 9).unwrap())] - #[case(BitPackedArray::encode(&buffer![0u32, 1000, 2000, 3000, 4000].into_array(), 12).unwrap())] - #[case(BitPackedArray::encode(&PrimitiveArray::from_option_iter([Some(1u32), None, Some(7), Some(15), None]).into_array(), 4).unwrap())] + #[case(BitPackedVTable::encode(&buffer![0u8, 10, 20, 30, 40, 50, 60, 63].into_array(), 6).unwrap())] + #[case(BitPackedVTable::encode(&buffer![0u16, 100, 200, 300, 400, 500].into_array(), 9).unwrap())] + #[case(BitPackedVTable::encode(&buffer![0u32, 1000, 2000, 3000, 4000].into_array(), 12).unwrap())] + #[case(BitPackedVTable::encode(&PrimitiveArray::from_option_iter([Some(1u32), None, Some(7), Some(15), None]).into_array(), 4).unwrap())] fn test_cast_bitpacked_conformance(#[case] array: BitPackedArray) { test_cast_conformance(&array.into_array()); } diff --git a/encodings/fastlanes/src/bitpacking/compute/filter.rs b/encodings/fastlanes/src/bitpacking/compute/filter.rs index b2445c2de2c..5a34e78c16f 100644 --- a/encodings/fastlanes/src/bitpacking/compute/filter.rs +++ b/encodings/fastlanes/src/bitpacking/compute/filter.rs @@ -24,6 +24,7 @@ use vortex_mask::MaskValues; use super::chunked_indices; use super::take::UNPACK_CHUNK_THRESHOLD; use crate::BitPackedArray; +use 
crate::BitPackedArrayExt; use crate::BitPackedVTable; /// The threshold over which it is faster to fully unpack the entire [`BitPackedArray`] and then @@ -172,13 +173,14 @@ mod test { use vortex_buffer::buffer; use vortex_mask::Mask; - use crate::BitPackedArray; + use crate::BitPackedArrayExt; + use crate::BitPackedVTable; #[test] fn take_indices() { // Create a u8 array modulo 63. let unpacked = PrimitiveArray::from_iter((0..4096).map(|i| (i % 63) as u8)); - let bitpacked = BitPackedArray::encode(&unpacked.into_array(), 6).unwrap(); + let bitpacked = BitPackedVTable::encode(&unpacked.into_array(), 6).unwrap(); let mask = Mask::from_indices(bitpacked.len(), vec![0, 125, 2047, 2049, 2151, 2790]); @@ -193,7 +195,7 @@ mod test { fn take_sliced_indices() { // Create a u8 array modulo 63. let unpacked = PrimitiveArray::from_iter((0..4096).map(|i| (i % 63) as u8)); - let bitpacked = BitPackedArray::encode(&unpacked.into_array(), 6).unwrap(); + let bitpacked = BitPackedVTable::encode(&unpacked.into_array(), 6).unwrap(); let sliced = bitpacked.slice(128..2050).unwrap(); let mask = Mask::from_indices(sliced.len(), vec![1919, 1921]); @@ -205,7 +207,7 @@ mod test { #[test] fn filter_bitpacked() { let unpacked = PrimitiveArray::from_iter((0..4096).map(|i| (i % 63) as u8)); - let bitpacked = BitPackedArray::encode(&unpacked.into_array(), 6).unwrap(); + let bitpacked = BitPackedVTable::encode(&unpacked.into_array(), 6).unwrap(); let filtered = bitpacked .filter(Mask::from_indices(4096, (0..1024).collect())) .unwrap(); @@ -219,7 +221,7 @@ mod test { fn filter_bitpacked_signed() { let values: Buffer = (0..500).collect(); let unpacked = PrimitiveArray::new(values.clone(), Validity::NonNullable); - let bitpacked = BitPackedArray::encode(&unpacked.into_array(), 9).unwrap(); + let bitpacked = BitPackedVTable::encode(&unpacked.into_array(), 9).unwrap(); let filtered = bitpacked .filter(Mask::from_indices(values.len(), (0..250).collect())) .unwrap() @@ -235,17 +237,17 @@ mod test { fn 
test_filter_bitpacked_conformance() { // Test with u8 values let unpacked = buffer![1u8, 2, 3, 4, 5].into_array(); - let bitpacked = BitPackedArray::encode(&unpacked, 3).unwrap(); + let bitpacked = BitPackedVTable::encode(&unpacked, 3).unwrap(); test_filter_conformance(&bitpacked.into_array()); // Test with u32 values let unpacked = buffer![100u32, 200, 300, 400, 500].into_array(); - let bitpacked = BitPackedArray::encode(&unpacked, 9).unwrap(); + let bitpacked = BitPackedVTable::encode(&unpacked, 9).unwrap(); test_filter_conformance(&bitpacked.into_array()); // Test with nullable values let unpacked = PrimitiveArray::from_option_iter([Some(1u16), None, Some(3), Some(4), None]); - let bitpacked = BitPackedArray::encode(&unpacked.into_array(), 3).unwrap(); + let bitpacked = BitPackedVTable::encode(&unpacked.into_array(), 3).unwrap(); test_filter_conformance(&bitpacked.into_array()); } @@ -260,7 +262,7 @@ mod test { // Values 0-127 fit in 7 bits, but 1000 and 2000 do not. let values: Vec = vec![0, 10, 1000, 20, 30, 2000, 40, 50, 60, 70]; let unpacked = PrimitiveArray::from_iter(values.clone()); - let bitpacked = BitPackedArray::encode(&unpacked.into_array(), 7).unwrap(); + let bitpacked = BitPackedVTable::encode(&unpacked.into_array(), 7).unwrap(); assert!( bitpacked.patches().is_some(), "Expected patches for values exceeding bit width" @@ -292,7 +294,7 @@ mod test { }) .collect(); let unpacked = PrimitiveArray::from_iter(values.clone()); - let bitpacked = BitPackedArray::encode(&unpacked.into_array(), 7).unwrap(); + let bitpacked = BitPackedVTable::encode(&unpacked.into_array(), 7).unwrap(); assert!( bitpacked.patches().is_some(), "Expected patches for values exceeding bit width" diff --git a/encodings/fastlanes/src/bitpacking/compute/is_constant.rs b/encodings/fastlanes/src/bitpacking/compute/is_constant.rs index a5098773283..40705d40da7 100644 --- a/encodings/fastlanes/src/bitpacking/compute/is_constant.rs +++ 
b/encodings/fastlanes/src/bitpacking/compute/is_constant.rs @@ -19,6 +19,7 @@ use vortex_array::register_kernel; use vortex_error::VortexResult; use crate::BitPackedArray; +use crate::BitPackedArrayExt; use crate::BitPackedVTable; use crate::unpack_iter::BitPacked; @@ -173,11 +174,11 @@ mod tests { use vortex_array::compute::is_constant; use vortex_buffer::buffer; - use crate::BitPackedArray; + use crate::BitPackedVTable; #[test] fn is_constant_with_patches() { - let array = BitPackedArray::encode(&buffer![4; 1025].into_array(), 2).unwrap(); + let array = BitPackedVTable::encode(&buffer![4; 1025].into_array(), 2).unwrap(); assert!(is_constant(&array.into_array()).unwrap().unwrap()); } } diff --git a/encodings/fastlanes/src/bitpacking/compute/slice.rs b/encodings/fastlanes/src/bitpacking/compute/slice.rs index 4ec64d05514..4012381b005 100644 --- a/encodings/fastlanes/src/bitpacking/compute/slice.rs +++ b/encodings/fastlanes/src/bitpacking/compute/slice.rs @@ -11,6 +11,7 @@ use vortex_array::arrays::slice::SliceKernel; use vortex_error::VortexResult; use crate::BitPackedArray; +use crate::BitPackedArrayExt; use crate::BitPackedVTable; impl SliceKernel for BitPackedVTable { @@ -63,6 +64,7 @@ mod tests { use vortex_error::VortexResult; use vortex_session::VortexSession; + use crate::BitPackedArrayExt; use crate::BitPackedVTable; use crate::bitpack_compress::bitpack_encode; diff --git a/encodings/fastlanes/src/bitpacking/compute/take.rs b/encodings/fastlanes/src/bitpacking/compute/take.rs index 4b71dad9a87..3ea7fd3a704 100644 --- a/encodings/fastlanes/src/bitpacking/compute/take.rs +++ b/encodings/fastlanes/src/bitpacking/compute/take.rs @@ -26,6 +26,7 @@ use vortex_error::VortexResult; use super::chunked_indices; use crate::BitPackedArray; +use crate::BitPackedArrayExt; use crate::BitPackedVTable; use crate::bitpack_decompress; @@ -163,6 +164,8 @@ mod test { use vortex_buffer::buffer; use crate::BitPackedArray; + use crate::BitPackedArrayExt; + use 
crate::BitPackedVTable; use crate::bitpacking::compute::take::take_primitive; #[test] @@ -171,7 +174,7 @@ mod test { // Create a u8 array modulo 63. let unpacked = PrimitiveArray::from_iter((0..4096).map(|i| (i % 63) as u8)); - let bitpacked = BitPackedArray::encode(&unpacked.into_array(), 6).unwrap(); + let bitpacked = BitPackedVTable::encode(&unpacked.into_array(), 6).unwrap(); let primitive_result = bitpacked.take(indices.to_array()).unwrap(); assert_arrays_eq!( @@ -183,7 +186,7 @@ mod test { #[test] fn take_with_patches() { let unpacked = Buffer::from_iter(0u32..1024).into_array(); - let bitpacked = BitPackedArray::encode(&unpacked, 2).unwrap(); + let bitpacked = BitPackedVTable::encode(&unpacked, 2).unwrap(); let indices = buffer![0, 2, 4, 6].into_array(); @@ -197,7 +200,7 @@ mod test { // Create a u8 array modulo 63. let unpacked = PrimitiveArray::from_iter((0..4096).map(|i| (i % 63) as u8)); - let bitpacked = BitPackedArray::encode(&unpacked.into_array(), 6).unwrap(); + let bitpacked = BitPackedVTable::encode(&unpacked.into_array(), 6).unwrap(); let sliced = bitpacked.slice(128..2050).unwrap(); let primitive_result = sliced.take(indices.to_array()).unwrap(); @@ -210,7 +213,7 @@ mod test { let num_patches: usize = 128; let values = (0..u16::MAX as u32 + num_patches as u32).collect::>(); let uncompressed = PrimitiveArray::new(values.clone(), Validity::NonNullable); - let packed = BitPackedArray::encode(&uncompressed.into_array(), 16).unwrap(); + let packed = BitPackedVTable::encode(&uncompressed.into_array(), 16).unwrap(); assert!(packed.patches().is_some()); let rng = rng(); @@ -240,7 +243,7 @@ mod test { #[cfg_attr(miri, ignore)] fn take_signed_with_patches() { let start = - BitPackedArray::encode(&buffer![1i32, 2i32, 3i32, 4i32].into_array(), 1).unwrap(); + BitPackedVTable::encode(&buffer![1i32, 2i32, 3i32, 4i32].into_array(), 1).unwrap(); let taken_primitive = take_primitive::( &start, @@ -255,7 +258,7 @@ mod test { #[test] fn 
take_nullable_with_nullables() { let start = - BitPackedArray::encode(&buffer![1i32, 2i32, 3i32, 4i32].into_array(), 1).unwrap(); + BitPackedVTable::encode(&buffer![1i32, 2i32, 3i32, 4i32].into_array(), 1).unwrap(); let taken_primitive = start .take( @@ -270,15 +273,15 @@ mod test { } #[rstest] - #[case(BitPackedArray::encode(&PrimitiveArray::from_iter((0..100).map(|i| (i % 63) as u8)).into_array(), 6).unwrap())] - #[case(BitPackedArray::encode(&PrimitiveArray::from_iter((0..256).map(|i| i as u32)).into_array(), 8).unwrap())] - #[case(BitPackedArray::encode(&buffer![1i32, 2, 3, 4, 5, 6, 7, 8].into_array(), 3).unwrap())] - #[case(BitPackedArray::encode( + #[case(BitPackedVTable::encode(&PrimitiveArray::from_iter((0..100).map(|i| (i % 63) as u8)).into_array(), 6).unwrap())] + #[case(BitPackedVTable::encode(&PrimitiveArray::from_iter((0..256).map(|i| i as u32)).into_array(), 8).unwrap())] + #[case(BitPackedVTable::encode(&buffer![1i32, 2, 3, 4, 5, 6, 7, 8].into_array(), 3).unwrap())] + #[case(BitPackedVTable::encode( &PrimitiveArray::from_option_iter([Some(10u16), None, Some(20), Some(30), None]).into_array(), 5 ).unwrap())] - #[case(BitPackedArray::encode(&buffer![42u32].into_array(), 6).unwrap())] - #[case(BitPackedArray::encode(&PrimitiveArray::from_iter((0..1024).map(|i| i as u32)).into_array(), 8).unwrap())] + #[case(BitPackedVTable::encode(&buffer![42u32].into_array(), 6).unwrap())] + #[case(BitPackedVTable::encode(&PrimitiveArray::from_iter((0..1024).map(|i| i as u32)).into_array(), 8).unwrap())] fn test_take_bitpacked_conformance(#[case] bitpacked: BitPackedArray) { use vortex_array::compute::conformance::take::test_take_conformance; test_take_conformance(&bitpacked.into_array()); diff --git a/encodings/fastlanes/src/bitpacking/mod.rs b/encodings/fastlanes/src/bitpacking/mod.rs index d2b6c38f6e2..58c58f9eb4e 100644 --- a/encodings/fastlanes/src/bitpacking/mod.rs +++ b/encodings/fastlanes/src/bitpacking/mod.rs @@ -3,6 +3,7 @@ mod array; pub use 
array::BitPackedArray; +pub use array::BitPackedArrayExt; pub use array::BitPackedArrayParts; pub use array::bitpack_compress; pub use array::bitpack_decompress; diff --git a/encodings/fastlanes/src/bitpacking/vtable/mod.rs b/encodings/fastlanes/src/bitpacking/vtable/mod.rs index 77cf102e440..49750e8eb27 100644 --- a/encodings/fastlanes/src/bitpacking/vtable/mod.rs +++ b/encodings/fastlanes/src/bitpacking/vtable/mod.rs @@ -40,6 +40,7 @@ use vortex_error::vortex_panic; use vortex_session::VortexSession; use crate::BitPackedArray; +use crate::BitPackedArrayExt; use crate::bitpack_decompress::unpack_array; use crate::bitpack_decompress::unpack_into_primitive_builder; use crate::bitpacking::vtable::kernels::PARENT_KERNELS; @@ -315,7 +316,7 @@ impl VTable for BitPackedVTable { }) .transpose()?; - BitPackedArray::try_new( + Self::try_new( packed, PType::try_from(dtype)?, validity, @@ -372,4 +373,80 @@ pub struct BitPackedVTable; impl BitPackedVTable { pub const ID: ArrayId = ArrayId::new_ref("fastlanes.bitpacked"); + + /// A safe constructor for a `BitPackedArray` from its components: + /// + /// * `packed` is ByteBuffer holding the compressed data that was packed with FastLanes + /// bit-packing to a `bit_width` bits per value. `length` is the length of the original + /// vector. Note that the packed is padded with zeros to the next multiple of 1024 elements + /// if `length` is not divisible by 1024. + /// * `ptype` of the original data + /// * `validity` to track any nulls + /// * `patches` optionally provided for values that did not pack + /// + /// Any failure in validation will result in an error. + /// + /// # Validation + /// + /// * The `ptype` must be an integer + /// * `validity` must have `length` len + /// * Any patches must have any `array_len` equal to `length` + /// * The `packed` buffer must be exactly sized to hold `length` values of `bit_width` rounded + /// up to the next multiple of 1024. 
+ /// + /// Any violation of these preconditions will result in an error. + pub fn try_new( + packed: BufferHandle, + ptype: PType, + validity: Validity, + patches: Option, + bit_width: u8, + length: usize, + offset: u16, + ) -> VortexResult { + use crate::bitpacking::array::validate; + + validate( + &packed, + ptype, + &validity, + patches.as_ref(), + bit_width, + length, + offset, + )?; + + let dtype = DType::Primitive(ptype, validity.nullability()); + + // SAFETY: all components validated above + unsafe { + Ok(BitPackedArray::new_unchecked( + packed, dtype, validity, patches, bit_width, length, offset, + )) + } + } + + /// Bit-pack an array of primitive integers down to the target bit-width using the FastLanes + /// SIMD-accelerated packing kernels. + /// + /// # Errors + /// + /// If the provided array is not an integer type, an error will be returned. + /// + /// If the provided array contains negative values, an error will be returned. + /// + /// If the requested bit-width for packing is larger than the array's native width, an + /// error will be returned. 
+ // FIXME(ngates): take a PrimitiveArray + pub fn encode(array: &ArrayRef, bit_width: u8) -> VortexResult { + use vortex_array::arrays::PrimitiveVTable; + + use crate::bitpack_compress::bitpack_encode; + + if let Some(parray) = array.as_opt::() { + bitpack_encode(parray, bit_width, None) + } else { + vortex_bail!(InvalidArgument: "Bitpacking can only encode primitive arrays"); + } + } } diff --git a/encodings/fastlanes/src/bitpacking/vtable/operations.rs b/encodings/fastlanes/src/bitpacking/vtable/operations.rs index 86cad42c433..cd4e6de0850 100644 --- a/encodings/fastlanes/src/bitpacking/vtable/operations.rs +++ b/encodings/fastlanes/src/bitpacking/vtable/operations.rs @@ -6,6 +6,7 @@ use vortex_array::vtable::OperationsVTable; use vortex_error::VortexResult; use crate::BitPackedArray; +use crate::BitPackedArrayExt; use crate::BitPackedVTable; use crate::bitpack_decompress; @@ -50,6 +51,7 @@ mod test { use vortex_buffer::buffer; use crate::BitPackedArray; + use crate::BitPackedArrayExt; use crate::BitPackedVTable; static SESSION: LazyLock = @@ -71,7 +73,7 @@ mod test { #[test] pub fn slice_block() { - let arr = BitPackedArray::encode( + let arr = BitPackedVTable::encode( &PrimitiveArray::from_iter((0u32..2048).map(|v| v % 64)).into_array(), 6, ) @@ -85,7 +87,7 @@ mod test { #[test] pub fn slice_within_block() { - let arr = BitPackedArray::encode( + let arr = BitPackedVTable::encode( &PrimitiveArray::from_iter((0u32..2048).map(|v| v % 64)).into_array(), 6, ) @@ -99,7 +101,7 @@ mod test { #[test] fn slice_within_block_u8s() { - let packed = BitPackedArray::encode( + let packed = BitPackedVTable::encode( &PrimitiveArray::from_iter((0..10_000).map(|i| (i % 63) as u8)).into_array(), 7, ) @@ -112,7 +114,7 @@ mod test { #[test] fn slice_block_boundary_u8s() { - let packed = BitPackedArray::encode( + let packed = BitPackedVTable::encode( &PrimitiveArray::from_iter((0..10_000).map(|i| (i % 63) as u8)).into_array(), 7, ) @@ -125,7 +127,7 @@ mod test { #[test] fn 
double_slice_within_block() { - let arr = BitPackedArray::encode( + let arr = BitPackedVTable::encode( &PrimitiveArray::from_iter((0u32..2048).map(|v| v % 64)).into_array(), 6, ) @@ -145,7 +147,7 @@ mod test { #[test] fn slice_empty_patches() { // We create an array that has 1 element that does not fit in the 6-bit range. - let array = BitPackedArray::encode(&buffer![0u32..=64].into_array(), 6).unwrap(); + let array = BitPackedVTable::encode(&buffer![0u32..=64].into_array(), 6).unwrap(); assert!(array.patches().is_some()); @@ -161,7 +163,7 @@ mod test { fn take_after_slice() { // Check that our take implementation respects the offsets applied after slicing. - let array = BitPackedArray::encode( + let array = BitPackedVTable::encode( &PrimitiveArray::from_iter((63u32..).take(3072)).into_array(), 6, ) @@ -218,7 +220,7 @@ mod test { fn scalar_at() { let values = (0u32..257).collect::>(); let uncompressed = values.clone().into_array(); - let packed = BitPackedArray::encode(&uncompressed, 8).unwrap(); + let packed = BitPackedVTable::encode(&uncompressed, 8).unwrap(); assert!(packed.patches().is_some()); let patches = packed.patches().unwrap().indices().clone(); diff --git a/encodings/fastlanes/src/delta/array/delta_compress.rs b/encodings/fastlanes/src/delta/array/delta_compress.rs index f15c3ad8d58..9d9c6578059 100644 --- a/encodings/fastlanes/src/delta/array/delta_compress.rs +++ b/encodings/fastlanes/src/delta/array/delta_compress.rs @@ -101,7 +101,7 @@ mod tests { use vortex_error::VortexResult; use vortex_session::VortexSession; - use crate::DeltaArray; + use crate::DeltaVTable; use crate::delta::array::delta_decompress::delta_decompress; static SESSION: LazyLock = @@ -125,7 +125,7 @@ mod tests { } fn do_roundtrip_test(input: PrimitiveArray) -> VortexResult<()> { - let delta = DeltaArray::try_from_primitive_array(&input)?; + let delta = DeltaVTable::try_from_primitive_array(&input)?; assert_eq!(delta.len(), input.len()); let decompressed = delta_decompress(&delta, 
&mut SESSION.create_execution_ctx())?; assert_arrays_eq!(decompressed, input); diff --git a/encodings/fastlanes/src/delta/array/delta_decompress.rs b/encodings/fastlanes/src/delta/array/delta_decompress.rs index f3b3868fde7..63bc27c89ad 100644 --- a/encodings/fastlanes/src/delta/array/delta_decompress.rs +++ b/encodings/fastlanes/src/delta/array/delta_decompress.rs @@ -17,6 +17,7 @@ use vortex_buffer::BufferMut; use vortex_error::VortexResult; use crate::DeltaArray; +use crate::DeltaArrayExt; pub fn delta_decompress( array: &DeltaArray, diff --git a/encodings/fastlanes/src/delta/array/mod.rs b/encodings/fastlanes/src/delta/array/mod.rs index e83f8307cc4..1c52249a4c1 100644 --- a/encodings/fastlanes/src/delta/array/mod.rs +++ b/encodings/fastlanes/src/delta/array/mod.rs @@ -4,18 +4,11 @@ use fastlanes::FastLanes; use vortex_array::ArrayCommon; use vortex_array::ArrayRef; -use vortex_array::IntoArray; -use vortex_array::arrays::PrimitiveArray; use vortex_array::dtype::DType; -use vortex_array::dtype::NativePType; use vortex_array::dtype::PType; use vortex_array::match_each_unsigned_integer_ptype; use vortex_array::stats::ArrayStats; -use vortex_array::validity::Validity; -use vortex_buffer::Buffer; use vortex_error::VortexExpect as _; -use vortex_error::VortexResult; -use vortex_error::vortex_bail; pub mod delta_compress; pub mod delta_decompress; @@ -28,8 +21,8 @@ pub mod delta_decompress; /// # Examples /// /// ``` -/// use vortex_fastlanes::DeltaArray; -/// let array = DeltaArray::try_from_vec(vec![1_u32, 2, 3, 5, 10, 11]).unwrap(); +/// use vortex_fastlanes::DeltaVTable; +/// let array = DeltaVTable::try_from_vec(vec![1_u32, 2, 3, 5, 10, 11]).unwrap(); /// ``` /// /// # Details @@ -62,78 +55,56 @@ pub struct DeltaArray { pub(super) deltas: ArrayRef, } -impl DeltaArray { - // TODO(ngates): remove constructing from vec - pub fn try_from_vec(vec: Vec) -> VortexResult { - Self::try_from_primitive_array(&PrimitiveArray::new( - Buffer::copy_from(vec), - 
Validity::NonNullable, - )) - } +/// Extension trait for [`DeltaArray`] methods. +pub trait DeltaArrayExt { + fn bases(&self) -> &ArrayRef; + + fn deltas(&self) -> &ArrayRef; + + fn len(&self) -> usize; - pub fn try_from_primitive_array(array: &PrimitiveArray) -> VortexResult { - let (bases, deltas) = delta_compress::delta_compress(array)?; + fn is_empty(&self) -> bool; + + fn dtype(&self) -> &DType; + + /// The logical offset into the first chunk of [`Self::deltas`]. + fn offset(&self) -> usize; +} - Self::try_from_delta_compress_parts(bases.into_array(), deltas.into_array()) +impl DeltaArrayExt for DeltaArray { + #[inline] + fn bases(&self) -> &ArrayRef { + &self.bases } - /// Create a [`DeltaArray`] from the given `bases` and `deltas` arrays. - /// Note the `deltas` might be nullable - pub fn try_from_delta_compress_parts(bases: ArrayRef, deltas: ArrayRef) -> VortexResult { - let logical_len = deltas.len(); - Self::try_new(bases, deltas, 0, logical_len) + #[inline] + fn deltas(&self) -> &ArrayRef { + &self.deltas } - pub fn try_new( - bases: ArrayRef, - deltas: ArrayRef, - offset: usize, - logical_len: usize, - ) -> VortexResult { - if offset >= 1024 { - vortex_bail!("offset must be less than 1024: {}", offset); - } - if offset + logical_len > deltas.len() { - vortex_bail!( - "offset + logical_len, {} + {}, must be less than or equal to the size of deltas: {}", - offset, - logical_len, - deltas.len() - ) - } - if !bases.dtype().eq_ignore_nullability(deltas.dtype()) { - vortex_bail!( - "DeltaArray: bases and deltas must have the same dtype, got {:?} and {:?}", - bases.dtype(), - deltas.dtype() - ); - } - let DType::Primitive(ptype, _) = bases.dtype().clone() else { - vortex_bail!( - "DeltaArray: dtype must be an integer, got {}", - bases.dtype() - ); - }; - - if !ptype.is_int() { - vortex_bail!("DeltaArray: ptype must be an integer, got {}", ptype); - } + #[inline] + fn len(&self) -> usize { + self.common.len() + } - let lanes = lane_count(ptype); + #[inline] + 
fn is_empty(&self) -> bool { + self.common.len() == 0 + } - if deltas.len().is_multiple_of(1024) != bases.len().is_multiple_of(lanes) { - vortex_bail!( - "deltas length ({}) is a multiple of 1024 iff bases length ({}) is a multiple of LANES ({})", - deltas.len(), - bases.len(), - lanes, - ); - } + #[inline] + fn dtype(&self) -> &DType { + self.common.dtype() + } - // SAFETY: validation done above - Ok(unsafe { Self::new_unchecked(bases, deltas, offset, logical_len) }) + #[inline] + /// The logical offset into the first chunk of [`Self::deltas`]. + fn offset(&self) -> usize { + self.offset } +} +impl DeltaArray { pub(crate) unsafe fn new_unchecked( bases: ArrayRef, deltas: ArrayRef, @@ -149,16 +120,6 @@ impl DeltaArray { } } - #[inline] - pub fn bases(&self) -> &ArrayRef { - &self.bases - } - - #[inline] - pub fn deltas(&self) -> &ArrayRef { - &self.deltas - } - #[inline] pub(crate) fn lanes(&self) -> usize { let ptype = @@ -166,27 +127,6 @@ impl DeltaArray { lane_count(ptype) } - #[inline] - pub fn len(&self) -> usize { - self.common.len() - } - - #[inline] - pub fn is_empty(&self) -> bool { - self.common.len() == 0 - } - - #[inline] - pub fn dtype(&self) -> &DType { - self.common.dtype() - } - - #[inline] - /// The logical offset into the first chunk of [`Self::deltas`]. 
- pub fn offset(&self) -> usize { - self.offset - } - #[inline] pub(crate) fn bases_len(&self) -> usize { self.bases.len() diff --git a/encodings/fastlanes/src/delta/compute/cast.rs b/encodings/fastlanes/src/delta/compute/cast.rs index 70640c6c79b..70c8d868e93 100644 --- a/encodings/fastlanes/src/delta/compute/cast.rs +++ b/encodings/fastlanes/src/delta/compute/cast.rs @@ -11,6 +11,7 @@ use vortex_error::VortexResult; use vortex_error::vortex_panic; use crate::delta::DeltaArray; +use crate::delta::DeltaArrayExt; use crate::delta::DeltaVTable; impl CastReduce for DeltaVTable { @@ -38,7 +39,7 @@ impl CastReduce for DeltaVTable { // Create a new DeltaArray with the casted components Ok(Some( - DeltaArray::try_from_delta_compress_parts(casted_bases, casted_deltas)?.into_array(), + DeltaVTable::try_from_delta_compress_parts(casted_bases, casted_deltas)?.into_array(), )) } } @@ -56,12 +57,12 @@ mod tests { use vortex_array::dtype::PType; use vortex_buffer::buffer; - use crate::delta::DeltaArray; + use crate::delta::DeltaVTable; #[test] fn test_cast_delta_u8_to_u32() { let primitive = PrimitiveArray::from_iter([10u8, 20, 30, 40, 50]); - let array = DeltaArray::try_from_primitive_array(&primitive).unwrap(); + let array = DeltaVTable::try_from_primitive_array(&primitive).unwrap(); let casted = array .into_array() @@ -84,7 +85,7 @@ mod tests { buffer![100u16, 0, 200, 300, 0], vortex_array::validity::Validity::NonNullable, ); - let array = DeltaArray::try_from_primitive_array(&values).unwrap(); + let array = DeltaVTable::try_from_primitive_array(&values).unwrap(); let casted = array .into_array() @@ -122,7 +123,7 @@ mod tests { ) )] fn test_cast_delta_conformance(#[case] primitive: PrimitiveArray) { - let delta_array = DeltaArray::try_from_primitive_array(&primitive).unwrap(); + let delta_array = DeltaVTable::try_from_primitive_array(&primitive).unwrap(); test_cast_conformance(&delta_array.into_array()); } } diff --git a/encodings/fastlanes/src/delta/mod.rs 
b/encodings/fastlanes/src/delta/mod.rs index 3a8e08bbde5..7fd9cdb9aa0 100644 --- a/encodings/fastlanes/src/delta/mod.rs +++ b/encodings/fastlanes/src/delta/mod.rs @@ -3,6 +3,7 @@ mod array; pub use array::DeltaArray; +pub use array::DeltaArrayExt; pub use array::delta_compress::delta_compress; mod compute; diff --git a/encodings/fastlanes/src/delta/vtable/mod.rs b/encodings/fastlanes/src/delta/vtable/mod.rs index 96c0e0dd248..15cdfe0d519 100644 --- a/encodings/fastlanes/src/delta/vtable/mod.rs +++ b/encodings/fastlanes/src/delta/vtable/mod.rs @@ -29,6 +29,7 @@ use vortex_error::vortex_panic; use vortex_session::VortexSession; use crate::DeltaArray; +use crate::DeltaArrayExt; use crate::delta::array::delta_decompress::delta_decompress; mod operations; @@ -186,7 +187,7 @@ impl VTable for DeltaVTable { let bases = children.get(0, dtype, bases_len)?; let deltas = children.get(1, dtype, deltas_len)?; - DeltaArray::try_new(bases, deltas, metadata.0.offset as usize, len) + Self::try_new(bases, deltas, metadata.0.offset as usize, len) } fn execute(array: &Self::Array, ctx: &mut ExecutionCtx) -> VortexResult { @@ -199,6 +200,97 @@ pub struct DeltaVTable; impl DeltaVTable { pub const ID: ArrayId = ArrayId::new_ref("fastlanes.delta"); + + // TODO(ngates): remove constructing from vec + pub fn try_from_vec( + vec: Vec, + ) -> VortexResult { + use vortex_array::arrays::PrimitiveArray; + use vortex_array::validity::Validity; + use vortex_buffer::Buffer; + + Self::try_from_primitive_array(&PrimitiveArray::new( + Buffer::copy_from(vec), + Validity::NonNullable, + )) + } + + pub fn try_from_primitive_array( + array: &vortex_array::arrays::PrimitiveArray, + ) -> VortexResult { + use vortex_array::IntoArray; + + use crate::delta::array::delta_compress; + + let (bases, deltas) = delta_compress::delta_compress(array)?; + + Self::try_from_delta_compress_parts(bases.into_array(), deltas.into_array()) + } + + /// Create a [`DeltaArray`] from the given `bases` and `deltas` arrays. 
+ /// Note the `deltas` might be nullable + pub fn try_from_delta_compress_parts( + bases: ArrayRef, + deltas: ArrayRef, + ) -> VortexResult { + let logical_len = deltas.len(); + Self::try_new(bases, deltas, 0, logical_len) + } + + pub fn try_new( + bases: ArrayRef, + deltas: ArrayRef, + offset: usize, + logical_len: usize, + ) -> VortexResult { + use vortex_array::dtype::DType; + use vortex_error::vortex_bail; + + use crate::delta::array::lane_count; + + if offset >= 1024 { + vortex_bail!("offset must be less than 1024: {}", offset); + } + if offset + logical_len > deltas.len() { + vortex_bail!( + "offset + logical_len, {} + {}, must be less than or equal to the size of deltas: {}", + offset, + logical_len, + deltas.len() + ) + } + if !bases.dtype().eq_ignore_nullability(deltas.dtype()) { + vortex_bail!( + "DeltaArray: bases and deltas must have the same dtype, got {:?} and {:?}", + bases.dtype(), + deltas.dtype() + ); + } + let DType::Primitive(ptype, _) = bases.dtype().clone() else { + vortex_bail!( + "DeltaArray: dtype must be an integer, got {}", + bases.dtype() + ); + }; + + if !ptype.is_int() { + vortex_bail!("DeltaArray: ptype must be an integer, got {}", ptype); + } + + let lanes = lane_count(ptype); + + if deltas.len().is_multiple_of(1024) != bases.len().is_multiple_of(lanes) { + vortex_bail!( + "deltas length ({}) is a multiple of 1024 iff bases length ({}) is a multiple of LANES ({})", + deltas.len(), + bases.len(), + lanes, + ); + } + + // SAFETY: validation done above + Ok(unsafe { DeltaArray::new_unchecked(bases, deltas, offset, logical_len) }) + } } #[cfg(test)] diff --git a/encodings/fastlanes/src/delta/vtable/operations.rs b/encodings/fastlanes/src/delta/vtable/operations.rs index ea9107ed534..6ad7b8a10a2 100644 --- a/encodings/fastlanes/src/delta/vtable/operations.rs +++ b/encodings/fastlanes/src/delta/vtable/operations.rs @@ -26,10 +26,11 @@ mod tests { use vortex_array::compute::conformance::consistency::test_array_consistency; use 
crate::DeltaArray; + use crate::DeltaVTable; #[test] fn test_slice_non_jagged_array_first_chunk_of_two() { - let delta = DeltaArray::try_from_vec((0u32..2048).collect()).unwrap(); + let delta = DeltaVTable::try_from_vec((0u32..2048).collect()).unwrap(); let actual = delta.slice(10..250).unwrap(); let expected = PrimitiveArray::from_iter(10u32..250).into_array(); @@ -38,7 +39,7 @@ mod tests { #[test] fn test_slice_non_jagged_array_second_chunk_of_two() { - let delta = DeltaArray::try_from_vec((0u32..2048).collect()).unwrap(); + let delta = DeltaVTable::try_from_vec((0u32..2048).collect()).unwrap(); let actual = delta.slice(1024 + 10..1024 + 250).unwrap(); let expected = PrimitiveArray::from_iter((1024 + 10u32)..(1024 + 250)).into_array(); @@ -47,7 +48,7 @@ mod tests { #[test] fn test_slice_non_jagged_array_span_two_chunks_chunk_of_two() { - let delta = DeltaArray::try_from_vec((0u32..2048).collect()).unwrap(); + let delta = DeltaVTable::try_from_vec((0u32..2048).collect()).unwrap(); let actual = delta.slice(1000..1048).unwrap(); let expected = PrimitiveArray::from_iter(1000u32..1048).into_array(); @@ -56,7 +57,7 @@ mod tests { #[test] fn test_slice_non_jagged_array_span_two_chunks_chunk_of_four() { - let delta = DeltaArray::try_from_vec((0u32..4096).collect()).unwrap(); + let delta = DeltaVTable::try_from_vec((0u32..4096).collect()).unwrap(); let actual = delta.slice(2040..2050).unwrap(); let expected = PrimitiveArray::from_iter(2040u32..2050).into_array(); @@ -65,7 +66,7 @@ mod tests { #[test] fn test_slice_non_jagged_array_whole() { - let delta = DeltaArray::try_from_vec((0u32..4096).collect()).unwrap(); + let delta = DeltaVTable::try_from_vec((0u32..4096).collect()).unwrap(); let actual = delta.slice(0..4096).unwrap(); let expected = PrimitiveArray::from_iter(0u32..4096).into_array(); @@ -74,7 +75,7 @@ mod tests { #[test] fn test_slice_non_jagged_array_empty() { - let delta = DeltaArray::try_from_vec((0u32..4096).collect()).unwrap(); + let delta = 
DeltaVTable::try_from_vec((0u32..4096).collect()).unwrap(); let actual = delta.slice(0..0).unwrap(); let expected = PrimitiveArray::from_iter(Vec::::new()).into_array(); @@ -91,7 +92,7 @@ mod tests { #[test] fn test_slice_jagged_array_second_chunk_of_two() { - let delta = DeltaArray::try_from_vec((0u32..2000).collect()).unwrap(); + let delta = DeltaVTable::try_from_vec((0u32..2000).collect()).unwrap(); let actual = delta.slice(1024 + 10..1024 + 250).unwrap(); let expected = PrimitiveArray::from_iter((1024 + 10u32)..(1024 + 250)).into_array(); @@ -100,7 +101,7 @@ mod tests { #[test] fn test_slice_jagged_array_empty() { - let delta = DeltaArray::try_from_vec((0u32..4000).collect()).unwrap(); + let delta = DeltaVTable::try_from_vec((0u32..4000).collect()).unwrap(); let actual = delta.slice(0..0).unwrap(); let expected = PrimitiveArray::from_iter(Vec::::new()).into_array(); @@ -117,7 +118,7 @@ mod tests { #[test] fn test_slice_of_slice_of_non_jagged() { - let delta = DeltaArray::try_from_vec((0u32..2048).collect()).unwrap(); + let delta = DeltaVTable::try_from_vec((0u32..2048).collect()).unwrap(); let sliced = delta.slice(10..1013).unwrap(); let sliced_again = sliced.slice(0..2).unwrap(); @@ -128,7 +129,7 @@ mod tests { #[test] fn test_slice_of_slice_of_jagged() { - let delta = DeltaArray::try_from_vec((0u32..2000).collect()).unwrap(); + let delta = DeltaVTable::try_from_vec((0u32..2000).collect()).unwrap(); let sliced = delta.slice(10..1013).unwrap(); let sliced_again = sliced.slice(0..2).unwrap(); @@ -139,7 +140,7 @@ mod tests { #[test] fn test_slice_of_slice_second_chunk_of_non_jagged() { - let delta = DeltaArray::try_from_vec((0u32..2048).collect()).unwrap(); + let delta = DeltaVTable::try_from_vec((0u32..2048).collect()).unwrap(); let sliced = delta.slice(1034..1050).unwrap(); let sliced_again = sliced.slice(0..2).unwrap(); @@ -150,7 +151,7 @@ mod tests { #[test] fn test_slice_of_slice_second_chunk_of_jagged() { - let delta = 
DeltaArray::try_from_vec((0u32..2000).collect()).unwrap(); + let delta = DeltaVTable::try_from_vec((0u32..2000).collect()).unwrap(); let sliced = delta.slice(1034..1050).unwrap(); let sliced_again = sliced.slice(0..2).unwrap(); @@ -161,7 +162,7 @@ mod tests { #[test] fn test_slice_of_slice_spanning_two_chunks_of_non_jagged() { - let delta = DeltaArray::try_from_vec((0u32..2048).collect()).unwrap(); + let delta = DeltaVTable::try_from_vec((0u32..2048).collect()).unwrap(); let sliced = delta.slice(1010..1050).unwrap(); let sliced_again = sliced.slice(5..20).unwrap(); @@ -172,7 +173,7 @@ mod tests { #[test] fn test_slice_of_slice_spanning_two_chunks_of_jagged() { - let delta = DeltaArray::try_from_vec((0u32..2000).collect()).unwrap(); + let delta = DeltaVTable::try_from_vec((0u32..2000).collect()).unwrap(); let sliced = delta.slice(1010..1050).unwrap(); let sliced_again = sliced.slice(5..20).unwrap(); @@ -183,7 +184,7 @@ mod tests { #[test] fn test_scalar_at_non_jagged_array() { - let delta = DeltaArray::try_from_vec((0u32..2048).collect()) + let delta = DeltaVTable::try_from_vec((0u32..2048).collect()) .unwrap() .into_array(); @@ -194,14 +195,14 @@ mod tests { #[test] #[should_panic] fn test_scalar_at_non_jagged_array_oob() { - let delta = DeltaArray::try_from_vec((0u32..2048).collect()) + let delta = DeltaVTable::try_from_vec((0u32..2048).collect()) .unwrap() .into_array(); delta.scalar_at(2048).unwrap(); } #[test] fn test_scalar_at_jagged_array() { - let delta = DeltaArray::try_from_vec((0u32..2000).collect()) + let delta = DeltaVTable::try_from_vec((0u32..2000).collect()) .unwrap() .into_array(); @@ -212,7 +213,7 @@ mod tests { #[test] #[should_panic] fn test_scalar_at_jagged_array_oob() { - let delta = DeltaArray::try_from_vec((0u32..2000).collect()) + let delta = DeltaVTable::try_from_vec((0u32..2000).collect()) .unwrap() .into_array(); delta.scalar_at(2000).unwrap(); @@ -220,23 +221,23 @@ mod tests { #[rstest] // Basic delta arrays - 
#[case::delta_u32(DeltaArray::try_from_vec((0u32..100).collect()).unwrap())] - #[case::delta_u64(DeltaArray::try_from_vec((0..100).map(|i| i as u64 * 10).collect()).unwrap())] + #[case::delta_u32(DeltaVTable::try_from_vec((0u32..100).collect()).unwrap())] + #[case::delta_u64(DeltaVTable::try_from_vec((0..100).map(|i| i as u64 * 10).collect()).unwrap())] // Large arrays (multiple chunks) - #[case::delta_large_u32(DeltaArray::try_from_vec((0u32..2048).collect()).unwrap())] - #[case::delta_large_u64(DeltaArray::try_from_vec((0u64..2048).collect()).unwrap())] + #[case::delta_large_u32(DeltaVTable::try_from_vec((0u32..2048).collect()).unwrap())] + #[case::delta_large_u64(DeltaVTable::try_from_vec((0u64..2048).collect()).unwrap())] // Single element - #[case::delta_single(DeltaArray::try_from_vec(vec![42u32]).unwrap())] + #[case::delta_single(DeltaVTable::try_from_vec(vec![42u32]).unwrap())] fn test_delta_consistency(#[case] array: DeltaArray) { test_array_consistency(&array.into_array()); } #[rstest] - #[case::delta_u8_basic(DeltaArray::try_from_vec(vec![1u8, 1, 1, 1, 1]).unwrap())] - #[case::delta_u16_basic(DeltaArray::try_from_vec(vec![1u16, 1, 1, 1, 1]).unwrap())] - #[case::delta_u32_basic(DeltaArray::try_from_vec(vec![1u32, 1, 1, 1, 1]).unwrap())] - #[case::delta_u64_basic(DeltaArray::try_from_vec(vec![1u64, 1, 1, 1, 1]).unwrap())] - #[case::delta_u32_large(DeltaArray::try_from_vec(vec![1u32; 100]).unwrap())] + #[case::delta_u8_basic(DeltaVTable::try_from_vec(vec![1u8, 1, 1, 1, 1]).unwrap())] + #[case::delta_u16_basic(DeltaVTable::try_from_vec(vec![1u16, 1, 1, 1, 1]).unwrap())] + #[case::delta_u32_basic(DeltaVTable::try_from_vec(vec![1u32, 1, 1, 1, 1]).unwrap())] + #[case::delta_u64_basic(DeltaVTable::try_from_vec(vec![1u64, 1, 1, 1, 1]).unwrap())] + #[case::delta_u32_large(DeltaVTable::try_from_vec(vec![1u32; 100]).unwrap())] fn test_delta_binary_numeric(#[case] array: DeltaArray) { test_binary_numeric_array(array.into_array()); } diff --git 
a/encodings/fastlanes/src/delta/vtable/slice.rs b/encodings/fastlanes/src/delta/vtable/slice.rs index 1878f41f1e0..eed033063fe 100644 --- a/encodings/fastlanes/src/delta/vtable/slice.rs +++ b/encodings/fastlanes/src/delta/vtable/slice.rs @@ -10,6 +10,7 @@ use vortex_array::arrays::slice::SliceReduce; use vortex_error::VortexResult; use crate::DeltaArray; +use crate::DeltaArrayExt; use crate::delta::vtable::DeltaVTable; impl SliceReduce for DeltaVTable { diff --git a/encodings/fastlanes/src/delta/vtable/validity.rs b/encodings/fastlanes/src/delta/vtable/validity.rs index 71b930025c6..54e4dedd8ce 100644 --- a/encodings/fastlanes/src/delta/vtable/validity.rs +++ b/encodings/fastlanes/src/delta/vtable/validity.rs @@ -5,6 +5,7 @@ use vortex_array::ArrayRef; use vortex_array::vtable::ValidityChildSliceHelper; use crate::DeltaArray; +use crate::DeltaArrayExt; impl ValidityChildSliceHelper for DeltaArray { fn unsliced_child_and_slice(&self) -> (&ArrayRef, usize, usize) { diff --git a/encodings/fastlanes/src/for/array/for_compress.rs b/encodings/fastlanes/src/for/array/for_compress.rs index 95277505360..999c3c80b07 100644 --- a/encodings/fastlanes/src/for/array/for_compress.rs +++ b/encodings/fastlanes/src/for/array/for_compress.rs @@ -13,25 +13,24 @@ use vortex_error::VortexResult; use vortex_error::vortex_err; use crate::FoRArray; - -impl FoRArray { - pub fn encode(array: PrimitiveArray) -> VortexResult { - let stats = ArrayStats::from(array.statistics().to_owned()); - let min = array - .statistics() - .compute_stat(Stat::Min)? 
- .ok_or_else(|| vortex_err!("Min stat not found"))?; - - let encoded = match_each_integer_ptype!(array.ptype(), |T| { - compress_primitive::(array, T::try_from(&min)?)?.into_array() - }); - let for_array = FoRArray::try_new(encoded, min)?; - for_array - .stats_set() - .to_ref(for_array.as_ref()) - .inherit_from(stats.to_ref(for_array.as_ref())); - Ok(for_array) - } +use crate::FoRVTable; + +pub(crate) fn encode_for(array: PrimitiveArray) -> VortexResult { + let stats = ArrayStats::from(array.statistics().to_owned()); + let min = array + .statistics() + .compute_stat(Stat::Min)? + .ok_or_else(|| vortex_err!("Min stat not found"))?; + + let encoded = match_each_integer_ptype!(array.ptype(), |T| { + compress_primitive::(array, T::try_from(&min)?)?.into_array() + }); + let for_array = FoRVTable::try_new(encoded, min)?; + for_array + .stats_set() + .to_ref(for_array.as_ref()) + .inherit_from(stats.to_ref(for_array.as_ref())); + Ok(for_array) } fn compress_primitive( @@ -67,7 +66,8 @@ mod test { use vortex_session::VortexSession; use super::*; - use crate::BitPackedArray; + use crate::BitPackedVTable; + use crate::FoRArrayExt; use crate::r#for::array::for_decompress::decompress; use crate::r#for::array::for_decompress::fused_decompress; @@ -80,7 +80,7 @@ mod test { (1i32..10).collect::>(), Validity::NonNullable, ); - let compressed = FoRArray::encode(array.clone()).unwrap(); + let compressed = FoRVTable::encode(array.clone()).unwrap(); assert_eq!(i32::try_from(compressed.reference_scalar()).unwrap(), 1); assert_arrays_eq!(compressed, array); @@ -95,7 +95,7 @@ mod test { .collect::>(), Validity::NonNullable, ); - let compressed = FoRArray::encode(array).unwrap(); + let compressed = FoRVTable::encode(array).unwrap(); assert_eq!( u32::try_from(compressed.reference_scalar()).unwrap(), 1_000_000u32 @@ -108,7 +108,7 @@ mod test { assert_eq!(array.statistics().len(), 0); let dtype = array.dtype().clone(); - let compressed = FoRArray::encode(array).unwrap(); + let compressed = 
FoRVTable::encode(array).unwrap(); assert_eq!(compressed.reference_scalar().dtype(), &dtype); assert!(compressed.reference_scalar().dtype().is_signed_int()); assert!(compressed.encoded().dtype().is_signed_int()); @@ -121,7 +121,7 @@ mod test { fn test_decompress() { // Create a range offset by a million. let array = PrimitiveArray::from_iter((0u32..100_000).step_by(1024).map(|v| v + 1_000_000)); - let compressed = FoRArray::encode(array.clone()).unwrap(); + let compressed = FoRVTable::encode(array.clone()).unwrap(); assert_arrays_eq!(compressed, array); } @@ -130,8 +130,8 @@ mod test { // Create a range offset by a million. let expect = PrimitiveArray::from_iter((0u32..1024).map(|x| x % 7 + 10)); let array = PrimitiveArray::from_iter((0u32..1024).map(|x| x % 7)); - let bp = BitPackedArray::encode(&array.into_array(), 3).unwrap(); - let compressed = FoRArray::try_new(bp.into_array(), 10u32.into()).unwrap(); + let bp = BitPackedVTable::encode(&array.into_array(), 3).unwrap(); + let compressed = FoRVTable::try_new(bp.into_array(), 10u32.into()).unwrap(); assert_arrays_eq!(compressed, expect); } @@ -140,8 +140,8 @@ mod test { // Create a range offset by a million. 
let expect = PrimitiveArray::from_iter((0u32..1024).map(|x| x % 7 + 10)); let array = PrimitiveArray::from_iter((0u32..1024).map(|x| x % 7)); - let bp = BitPackedArray::encode(&array.into_array(), 2).unwrap(); - let compressed = FoRArray::try_new(bp.clone().into_array(), 10u32.into()).unwrap(); + let bp = BitPackedVTable::encode(&array.into_array(), 2).unwrap(); + let compressed = FoRVTable::try_new(bp.clone().into_array(), 10u32.into()).unwrap(); let decompressed = fused_decompress::(&compressed, &bp, &mut SESSION.create_execution_ctx())?; assert_arrays_eq!(decompressed, expect); @@ -151,7 +151,7 @@ mod test { #[test] fn test_overflow() -> VortexResult<()> { let array = PrimitiveArray::from_iter(i8::MIN..=i8::MAX); - let compressed = FoRArray::encode(array.clone()).unwrap(); + let compressed = FoRVTable::encode(array.clone()).unwrap(); assert_eq!( i8::MIN, compressed diff --git a/encodings/fastlanes/src/for/array/for_decompress.rs b/encodings/fastlanes/src/for/array/for_decompress.rs index f7292a5a06e..24b6f298796 100644 --- a/encodings/fastlanes/src/for/array/for_decompress.rs +++ b/encodings/fastlanes/src/for/array/for_decompress.rs @@ -19,8 +19,10 @@ use vortex_error::VortexExpect; use vortex_error::VortexResult; use crate::BitPackedArray; +use crate::BitPackedArrayExt; use crate::BitPackedVTable; use crate::FoRArray; +use crate::FoRArrayExt; use crate::bitpack_decompress; use crate::unpack_iter::UnpackStrategy; use crate::unpack_iter::UnpackedChunks; diff --git a/encodings/fastlanes/src/for/array/mod.rs b/encodings/fastlanes/src/for/array/mod.rs index 5542c336657..df866778ca2 100644 --- a/encodings/fastlanes/src/for/array/mod.rs +++ b/encodings/fastlanes/src/for/array/mod.rs @@ -6,8 +6,6 @@ use vortex_array::ArrayRef; use vortex_array::dtype::PType; use vortex_array::scalar::Scalar; use vortex_array::stats::ArrayStats; -use vortex_error::VortexResult; -use vortex_error::vortex_bail; pub mod for_compress; pub mod for_decompress; @@ -24,25 +22,6 @@ pub struct 
FoRArray { } impl FoRArray { - pub fn try_new(encoded: ArrayRef, reference: Scalar) -> VortexResult { - if reference.is_null() { - vortex_bail!("Reference value cannot be null"); - } - let reference = reference.cast( - &reference - .dtype() - .with_nullability(encoded.dtype().nullability()), - )?; - - let len = encoded.len(); - let dtype = reference.dtype().clone(); - Ok(Self { - encoded, - reference, - common: ArrayCommon::new(len, dtype), - }) - } - pub(crate) unsafe fn new_unchecked(encoded: ArrayRef, reference: Scalar) -> Self { let len = encoded.len(); let dtype = reference.dtype().clone(); @@ -54,22 +33,36 @@ impl FoRArray { } #[inline] - pub fn ptype(&self) -> PType { - self.dtype().as_ptype() + pub(crate) fn stats_set(&self) -> &ArrayStats { + self.common.stats() } +} + +/// Extension trait for [`FoRArray`] methods. +pub trait FoRArrayExt { + /// Returns the primitive type of this array. + fn ptype(&self) -> PType; + + /// Returns a reference to the encoded child array. + fn encoded(&self) -> &ArrayRef; + /// Returns the reference scalar value. 
+ fn reference_scalar(&self) -> &Scalar; +} + +impl FoRArrayExt for FoRArray { #[inline] - pub fn encoded(&self) -> &ArrayRef { - &self.encoded + fn ptype(&self) -> PType { + self.dtype().as_ptype() } #[inline] - pub fn reference_scalar(&self) -> &Scalar { - &self.reference + fn encoded(&self) -> &ArrayRef { + &self.encoded } #[inline] - pub(crate) fn stats_set(&self) -> &ArrayStats { - self.common.stats() + fn reference_scalar(&self) -> &Scalar { + &self.reference } } diff --git a/encodings/fastlanes/src/for/compute/cast.rs b/encodings/fastlanes/src/for/compute/cast.rs index b6a4e308338..dbd3b99c01e 100644 --- a/encodings/fastlanes/src/for/compute/cast.rs +++ b/encodings/fastlanes/src/for/compute/cast.rs @@ -9,6 +9,7 @@ use vortex_array::scalar_fn::fns::cast::CastReduce; use vortex_error::VortexResult; use crate::r#for::FoRArray; +use crate::r#for::FoRArrayExt; use crate::r#for::FoRVTable; impl CastReduce for FoRVTable { @@ -23,7 +24,7 @@ impl CastReduce for FoRVTable { let casted_reference = array.reference_scalar().cast(dtype)?; Ok(Some( - FoRArray::try_new(casted_child, casted_reference)?.into_array(), + FoRVTable::try_new(casted_child, casted_reference)?.into_array(), )) } } @@ -43,10 +44,11 @@ mod tests { use vortex_buffer::buffer; use crate::FoRArray; + use crate::FoRVTable; #[test] fn test_cast_for_i32_to_i64() { - let for_array = FoRArray::try_new( + let for_array = FoRVTable::try_new( buffer![0i32, 10, 20, 30, 40].into_array(), Scalar::from(100i32), ) @@ -71,7 +73,7 @@ mod tests { #[test] fn test_cast_for_nullable() { let values = PrimitiveArray::from_option_iter([Some(0i32), None, Some(20), Some(30), None]); - let for_array = FoRArray::try_new(values.into_array(), Scalar::from(50i32)).unwrap(); + let for_array = FoRVTable::try_new(values.into_array(), Scalar::from(50i32)).unwrap(); let casted = for_array .into_array() @@ -84,19 +86,19 @@ mod tests { } #[rstest] - #[case(FoRArray::try_new( + #[case(FoRVTable::try_new( buffer![0i32, 1, 2, 3, 
4].into_array(), Scalar::from(100i32) ).unwrap())] - #[case(FoRArray::try_new( + #[case(FoRVTable::try_new( buffer![0u64, 10, 20, 30].into_array(), Scalar::from(1000u64) ).unwrap())] - #[case(FoRArray::try_new( + #[case(FoRVTable::try_new( PrimitiveArray::from_option_iter([Some(0i16), None, Some(5), Some(10), None]).into_array(), Scalar::from(50i16) ).unwrap())] - #[case(FoRArray::try_new( + #[case(FoRVTable::try_new( buffer![-10i32, -5, 0, 5, 10].into_array(), Scalar::from(-100i32) ).unwrap())] diff --git a/encodings/fastlanes/src/for/compute/compare.rs b/encodings/fastlanes/src/for/compute/compare.rs index 4de2093d937..b8ea047245a 100644 --- a/encodings/fastlanes/src/for/compute/compare.rs +++ b/encodings/fastlanes/src/for/compute/compare.rs @@ -23,6 +23,7 @@ use vortex_error::VortexExpect as _; use vortex_error::VortexResult; use crate::FoRArray; +use crate::FoRArrayExt; use crate::FoRVTable; impl CompareKernel for FoRVTable { @@ -104,7 +105,7 @@ mod tests { fn test_compare_constant() { let reference = Scalar::from(10); // 10, 30, 12 - let lhs = FoRArray::try_new( + let lhs = FoRVTable::try_new( PrimitiveArray::new(buffer!(0i32, 20, 2), Validity::AllValid).into_array(), reference, ) @@ -143,7 +144,7 @@ mod tests { fn test_compare_nullable_constant() { let reference = Scalar::from(0); // 10, 30, 12 - let lhs = FoRArray::try_new( + let lhs = FoRVTable::try_new( PrimitiveArray::new(buffer!(0i32, 20, 2), Validity::NonNullable).into_array(), reference, ) @@ -169,7 +170,7 @@ mod tests { fn compare_non_encodable_constant() { let reference = Scalar::from(10); // 10, 30, 12 - let lhs = FoRArray::try_new( + let lhs = FoRVTable::try_new( PrimitiveArray::new(buffer!(0i32, 10, 1), Validity::AllValid).into_array(), reference, ) @@ -198,7 +199,7 @@ mod tests { fn compare_large_constant() { let reference = Scalar::from(-9219218377546224477i64); #[allow(clippy::cast_possible_truncation)] - let lhs = FoRArray::try_new( + let lhs = FoRVTable::try_new( PrimitiveArray::new( 
buffer![0i64, 9654309310445864926u64 as i64], Validity::AllValid, diff --git a/encodings/fastlanes/src/for/compute/is_constant.rs b/encodings/fastlanes/src/for/compute/is_constant.rs index 534daaede05..ab7056720ff 100644 --- a/encodings/fastlanes/src/for/compute/is_constant.rs +++ b/encodings/fastlanes/src/for/compute/is_constant.rs @@ -9,6 +9,7 @@ use vortex_array::register_kernel; use vortex_error::VortexResult; use crate::FoRArray; +use crate::FoRArrayExt; use crate::FoRVTable; impl IsConstantKernel for FoRVTable { diff --git a/encodings/fastlanes/src/for/compute/is_sorted.rs b/encodings/fastlanes/src/for/compute/is_sorted.rs index a1984d8489b..31c8e93b2aa 100644 --- a/encodings/fastlanes/src/for/compute/is_sorted.rs +++ b/encodings/fastlanes/src/for/compute/is_sorted.rs @@ -11,6 +11,7 @@ use vortex_array::register_kernel; use vortex_error::VortexResult; use crate::FoRArray; +use crate::FoRArrayExt; use crate::FoRVTable; /// FoR can express sortedness directly on its encoded form. @@ -101,12 +102,13 @@ mod test { use vortex_array::validity::Validity; use vortex_buffer::buffer; - use crate::FoRArray; + use crate::FoRArrayExt; + use crate::FoRVTable; #[test] fn test_sorted() { let a = PrimitiveArray::new(buffer![-1, 0, i8::MAX], Validity::NonNullable); - let b = FoRArray::encode(a).unwrap(); + let b = FoRVTable::encode(a).unwrap(); assert!( is_sorted(&b.clone().into_array()).unwrap().unwrap(), "{}", @@ -114,7 +116,7 @@ mod test { ); let a = PrimitiveArray::new(buffer![i8::MIN, 0, i8::MAX], Validity::NonNullable); - let b = FoRArray::encode(a).unwrap(); + let b = FoRVTable::encode(a).unwrap(); assert!( is_sorted(&b.clone().into_array()).unwrap().unwrap(), "{}", @@ -122,7 +124,7 @@ mod test { ); let a = PrimitiveArray::new(buffer![i8::MIN, 0, 30, 127], Validity::NonNullable); - let b = FoRArray::encode(a).unwrap(); + let b = FoRVTable::encode(a).unwrap(); assert!( is_sorted(&b.clone().into_array()).unwrap().unwrap(), "{}", @@ -130,7 +132,7 @@ mod test { ); let a = 
PrimitiveArray::new(buffer![i8::MIN, -3, -1], Validity::NonNullable); - let b = FoRArray::encode(a).unwrap(); + let b = FoRVTable::encode(a).unwrap(); assert!( is_sorted(&b.clone().into_array()).unwrap().unwrap(), "{}", @@ -138,7 +140,7 @@ mod test { ); let a = PrimitiveArray::new(buffer![-10, -3, -1], Validity::NonNullable); - let b = FoRArray::encode(a).unwrap(); + let b = FoRVTable::encode(a).unwrap(); assert!( is_sorted(&b.clone().into_array()).unwrap().unwrap(), "{}", @@ -146,7 +148,7 @@ mod test { ); let a = PrimitiveArray::new(buffer![-10, -11, -1], Validity::NonNullable); - let b = FoRArray::encode(a).unwrap(); + let b = FoRVTable::encode(a).unwrap(); assert!( !is_sorted(&b.clone().into_array()).unwrap().unwrap(), "{}", @@ -154,7 +156,7 @@ mod test { ); let a = PrimitiveArray::new(buffer![-10, i8::MIN, -1], Validity::NonNullable); - let b = FoRArray::encode(a).unwrap(); + let b = FoRVTable::encode(a).unwrap(); assert!( !is_sorted(&b.clone().into_array()).unwrap().unwrap(), "{}", diff --git a/encodings/fastlanes/src/for/compute/mod.rs b/encodings/fastlanes/src/for/compute/mod.rs index f261ab052fe..f7711fcb6f8 100644 --- a/encodings/fastlanes/src/for/compute/mod.rs +++ b/encodings/fastlanes/src/for/compute/mod.rs @@ -16,6 +16,7 @@ use vortex_error::VortexResult; use vortex_mask::Mask; use crate::FoRArray; +use crate::FoRArrayExt; use crate::FoRVTable; impl TakeExecute for FoRVTable { @@ -25,7 +26,7 @@ impl TakeExecute for FoRVTable { _ctx: &mut ExecutionCtx, ) -> VortexResult> { Ok(Some( - FoRArray::try_new( + FoRVTable::try_new( array.encoded().take(indices.to_array())?, array.reference_scalar().clone(), )? 
@@ -36,7 +37,7 @@ impl TakeExecute for FoRVTable { impl FilterReduce for FoRVTable { fn filter(array: &FoRArray, mask: &Mask) -> VortexResult> { - FoRArray::try_new( + FoRVTable::try_new( array.encoded().filter(mask.clone())?, array.reference_scalar().clone(), ) @@ -54,38 +55,39 @@ mod test { use vortex_buffer::buffer; use crate::FoRArray; + use crate::FoRVTable; #[test] fn test_filter_for_array() { // Test with i32 values let values = buffer![100i32, 101, 102, 103, 104].into_array(); let reference = Scalar::from(100i32); - let for_array = FoRArray::try_new(values, reference).unwrap(); + let for_array = FoRVTable::try_new(values, reference).unwrap(); test_filter_conformance(&for_array.into_array()); // Test with u64 values let values = buffer![1000u64, 1001, 1002, 1003, 1004].into_array(); let reference = Scalar::from(1000u64); - let for_array = FoRArray::try_new(values, reference).unwrap(); + let for_array = FoRVTable::try_new(values, reference).unwrap(); test_filter_conformance(&for_array.into_array()); // Test with nullable values let values = PrimitiveArray::from_option_iter([Some(50i16), None, Some(52), Some(53), None]); let reference = Scalar::from(50i16); - let for_array = FoRArray::try_new(values.into_array(), reference).unwrap(); + let for_array = FoRVTable::try_new(values.into_array(), reference).unwrap(); test_filter_conformance(&for_array.into_array()); } #[rstest] - #[case(FoRArray::try_new(buffer![100i32, 101, 102, 103, 104].into_array(), Scalar::from(100i32)).unwrap())] - #[case(FoRArray::try_new(buffer![1000u64, 1001, 1002, 1003, 1004].into_array(), Scalar::from(1000u64)).unwrap())] - #[case(FoRArray::try_new( + #[case(FoRVTable::try_new(buffer![100i32, 101, 102, 103, 104].into_array(), Scalar::from(100i32)).unwrap())] + #[case(FoRVTable::try_new(buffer![1000u64, 1001, 1002, 1003, 1004].into_array(), Scalar::from(1000u64)).unwrap())] + #[case(FoRVTable::try_new( PrimitiveArray::from_option_iter([Some(50i16), None, Some(52), Some(53), 
None]).into_array(), Scalar::from(50i16) ).unwrap())] - #[case(FoRArray::try_new(buffer![-100i32, -99, -98, -97, -96].into_array(), Scalar::from(-100i32)).unwrap())] - #[case(FoRArray::try_new(buffer![42i64].into_array(), Scalar::from(40i64)).unwrap())] + #[case(FoRVTable::try_new(buffer![-100i32, -99, -98, -97, -96].into_array(), Scalar::from(-100i32)).unwrap())] + #[case(FoRVTable::try_new(buffer![42i64].into_array(), Scalar::from(40i64)).unwrap())] fn test_take_for_conformance(#[case] for_array: FoRArray) { use vortex_array::compute::conformance::take::test_take_conformance; test_take_conformance(&for_array.into_array()); @@ -103,55 +105,56 @@ mod tests { use vortex_buffer::buffer; use crate::FoRArray; + use crate::FoRVTable; #[rstest] // Basic FoR arrays - #[case::for_i32(FoRArray::try_new( + #[case::for_i32(FoRVTable::try_new( buffer![100i32, 101, 102, 103, 104].into_array(), Scalar::from(100i32) ).unwrap())] - #[case::for_u64(FoRArray::try_new( + #[case::for_u64(FoRVTable::try_new( buffer![1000u64, 1001, 1002, 1003, 1004].into_array(), Scalar::from(1000u64) ).unwrap())] // Nullable arrays - #[case::for_nullable_i16(FoRArray::try_new( + #[case::for_nullable_i16(FoRVTable::try_new( PrimitiveArray::from_option_iter([Some(50i16), None, Some(52), Some(53), None]).into_array(), Scalar::from(50i16) ).unwrap())] - #[case::for_nullable_i32(FoRArray::try_new( + #[case::for_nullable_i32(FoRVTable::try_new( PrimitiveArray::from_option_iter([Some(200i32), None, Some(202), Some(203), None]).into_array(), Scalar::from(200i32) ).unwrap())] // Negative values - #[case::for_negative(FoRArray::try_new( + #[case::for_negative(FoRVTable::try_new( buffer![-100i32, -99, -98, -97, -96].into_array(), Scalar::from(-100i32) ).unwrap())] // Edge cases - #[case::for_single(FoRArray::try_new( + #[case::for_single(FoRVTable::try_new( buffer![42i64].into_array(), Scalar::from(40i64) ).unwrap())] - #[case::for_zero_ref(FoRArray::try_new( + #[case::for_zero_ref(FoRVTable::try_new( 
buffer![0u32, 1, 2, 3, 4].into_array(), Scalar::from(0u32) ).unwrap())] // Large arrays (> 1024 elements for fastlanes chunking) - #[case::for_large(FoRArray::try_new( + #[case::for_large(FoRVTable::try_new( PrimitiveArray::from_iter((0..1500).map(|i| 5000 + i)).into_array(), Scalar::from(5000i32) ).unwrap())] - #[case::for_very_large(FoRArray::try_new( + #[case::for_very_large(FoRVTable::try_new( PrimitiveArray::from_iter((0..3072).map(|i| 10000 + i as i64)).into_array(), Scalar::from(10000i64) ).unwrap())] - #[case::for_large_nullable(FoRArray::try_new( + #[case::for_large_nullable(FoRVTable::try_new( PrimitiveArray::from_option_iter((0..2048).map(|i| (i % 15 == 0).then_some(1000 + i))).into_array(), Scalar::from(1000i32) ).unwrap())] // Arrays with large deltas from reference - #[case::for_large_deltas(FoRArray::try_new( + #[case::for_large_deltas(FoRVTable::try_new( buffer![100i64, 200, 300, 400, 500].into_array(), Scalar::from(100i64) ).unwrap())] @@ -161,23 +164,23 @@ mod tests { } #[rstest] - #[case::for_i32_basic(FoRArray::try_new( + #[case::for_i32_basic(FoRVTable::try_new( buffer![100i32, 101, 102, 103, 104].into_array(), Scalar::from(100i32) ).unwrap())] - #[case::for_u32_basic(FoRArray::try_new( + #[case::for_u32_basic(FoRVTable::try_new( buffer![1000u32, 1001, 1002, 1003, 1004].into_array(), Scalar::from(1000u32) ).unwrap())] - #[case::for_i64_basic(FoRArray::try_new( + #[case::for_i64_basic(FoRVTable::try_new( buffer![5000i64, 5001, 5002, 5003, 5004].into_array(), Scalar::from(5000i64) ).unwrap())] - #[case::for_u64_basic(FoRArray::try_new( + #[case::for_u64_basic(FoRVTable::try_new( buffer![10000u64, 10001, 10002, 10003, 10004].into_array(), Scalar::from(10000u64) ).unwrap())] - #[case::for_i32_large(FoRArray::try_new( + #[case::for_i32_large(FoRVTable::try_new( PrimitiveArray::from_iter((0..100).map(|i| 2000 + i)).into_array(), Scalar::from(2000i32) ).unwrap())] diff --git a/encodings/fastlanes/src/for/mod.rs b/encodings/fastlanes/src/for/mod.rs 
index d192690a3e3..470fb3c6a72 100644 --- a/encodings/fastlanes/src/for/mod.rs +++ b/encodings/fastlanes/src/for/mod.rs @@ -3,6 +3,7 @@ mod array; pub use array::FoRArray; +pub use array::FoRArrayExt; mod compute; diff --git a/encodings/fastlanes/src/for/vtable/mod.rs b/encodings/fastlanes/src/for/vtable/mod.rs index df173e39203..6d3b32d237d 100644 --- a/encodings/fastlanes/src/for/vtable/mod.rs +++ b/encodings/fastlanes/src/for/vtable/mod.rs @@ -27,6 +27,7 @@ use vortex_error::vortex_panic; use vortex_session::VortexSession; use crate::FoRArray; +use crate::FoRArrayExt; use crate::r#for::array::for_decompress::decompress; use crate::r#for::vtable::kernels::PARENT_KERNELS; use crate::r#for::vtable::rules::PARENT_RULES; @@ -154,7 +155,7 @@ impl VTable for FoRVTable { let encoded = children.get(0, dtype, len)?; - FoRArray::try_new(encoded, metadata.clone()) + Self::try_new(encoded, metadata.clone()) } fn reduce_parent( @@ -184,4 +185,29 @@ pub struct FoRVTable; impl FoRVTable { pub const ID: ArrayId = ArrayId::new_ref("fastlanes.for"); + + /// Create a new FoR-encoded array from an encoded child array and a reference scalar. + pub fn try_new(encoded: ArrayRef, reference: Scalar) -> VortexResult { + if reference.is_null() { + vortex_bail!("Reference value cannot be null"); + } + let reference = reference.cast( + &reference + .dtype() + .with_nullability(encoded.dtype().nullability()), + )?; + + let len = encoded.len(); + let dtype = reference.dtype().clone(); + Ok(FoRArray { + encoded, + reference, + common: vortex_array::ArrayCommon::new(len, dtype), + }) + } + + /// Encode a primitive array using Frame of Reference encoding. 
+ pub fn encode(array: vortex_array::arrays::PrimitiveArray) -> VortexResult { + crate::r#for::array::for_compress::encode_for(array) + } } diff --git a/encodings/fastlanes/src/for/vtable/operations.rs b/encodings/fastlanes/src/for/vtable/operations.rs index 1277a2f7be1..94318ea3c75 100644 --- a/encodings/fastlanes/src/for/vtable/operations.rs +++ b/encodings/fastlanes/src/for/vtable/operations.rs @@ -9,6 +9,7 @@ use vortex_error::VortexResult; use super::FoRVTable; use crate::FoRArray; +use crate::FoRArrayExt; impl OperationsVTable for FoRVTable { fn scalar_at(array: &FoRArray, index: usize) -> VortexResult { @@ -38,12 +39,12 @@ mod test { use vortex_array::arrays::PrimitiveArray; use vortex_array::assert_arrays_eq; - use crate::FoRArray; + use crate::FoRVTable; #[test] fn for_scalar_at() { let for_arr = - FoRArray::encode(PrimitiveArray::from_iter([-100, 1100, 1500, 1900])).unwrap(); + FoRVTable::encode(PrimitiveArray::from_iter([-100, 1100, 1500, 1900])).unwrap(); let expected = PrimitiveArray::from_iter([-100, 1100, 1500, 1900]); assert_arrays_eq!(for_arr, expected); } diff --git a/encodings/fastlanes/src/for/vtable/slice.rs b/encodings/fastlanes/src/for/vtable/slice.rs index e27017ac657..5cbf6f9ec58 100644 --- a/encodings/fastlanes/src/for/vtable/slice.rs +++ b/encodings/fastlanes/src/for/vtable/slice.rs @@ -9,6 +9,7 @@ use vortex_array::arrays::slice::SliceReduce; use vortex_error::VortexResult; use crate::FoRArray; +use crate::FoRArrayExt; use crate::FoRVTable; impl SliceReduce for FoRVTable { diff --git a/encodings/fastlanes/src/for/vtable/validity.rs b/encodings/fastlanes/src/for/vtable/validity.rs index 210a7f4b095..94c4a9df7ee 100644 --- a/encodings/fastlanes/src/for/vtable/validity.rs +++ b/encodings/fastlanes/src/for/vtable/validity.rs @@ -6,6 +6,7 @@ use vortex_array::vtable::ValidityChild; use super::FoRVTable; use crate::FoRArray; +use crate::FoRArrayExt; impl ValidityChild for FoRVTable { fn validity_child(array: &FoRArray) -> &ArrayRef { diff 
--git a/encodings/fastlanes/src/rle/array/mod.rs b/encodings/fastlanes/src/rle/array/mod.rs index 95772416889..14f79385bf7 100644 --- a/encodings/fastlanes/src/rle/array/mod.rs +++ b/encodings/fastlanes/src/rle/array/mod.rs @@ -37,84 +37,53 @@ pub struct RLEArray { pub(super) offset: usize, } -impl RLEArray { - fn validate( - values: &ArrayRef, - indices: &ArrayRef, - value_idx_offsets: &ArrayRef, - offset: usize, - ) -> VortexResult<()> { - vortex_ensure!( - offset < 1024, - "Offset must be smaller than 1024, got {}", - offset - ); - - vortex_ensure!( - values.dtype().is_primitive(), - "RLE values must be a primitive type, got {}", - values.dtype() - ); - - vortex_ensure!( - matches!(indices.dtype().as_ptype(), PType::U8 | PType::U16), - "RLE indices must be u8 or u16, got {}", - indices.dtype() - ); - - vortex_ensure!( - value_idx_offsets.dtype().is_unsigned_int() && !value_idx_offsets.dtype().is_nullable(), - "RLE value idx offsets must be non-nullable unsigned integer, got {}", - value_idx_offsets.dtype() - ); - - vortex_ensure!( - indices.len().div_ceil(FL_CHUNK_SIZE) == value_idx_offsets.len(), - "RLE must have one value idx offset per chunk, got {}", - value_idx_offsets.len() - ); - - vortex_ensure!( - indices.len() >= values.len(), - "RLE must have at least as many indices as values, got {} indices and {} values", - indices.len(), - values.len() - ); - - Ok(()) - } - - /// Create a new chunk-based RLE array from its components. - /// - /// # Arguments - /// - /// * `values` - Unique values from all chunks - /// * `indices` - Chunk-local indices from all chunks - /// * `values_idx_offsets` - Start indices for each value chunk. 
- /// * `offset` - Offset into the first chunk - /// * `length` - Array length - pub fn try_new( - values: ArrayRef, - indices: ArrayRef, - values_idx_offsets: ArrayRef, - offset: usize, - length: usize, - ) -> VortexResult { - assert_eq!(indices.len() % FL_CHUNK_SIZE, 0); - Self::validate(&values, &indices, &values_idx_offsets, offset)?; - - // Ensure that the DType has the same nullability as the indices array. - let dtype = DType::Primitive(values.dtype().as_ptype(), indices.dtype().nullability()); - - Ok(Self { - common: ArrayCommon::new(length, dtype), - values, - indices, - values_idx_offsets, - offset, - }) - } +pub(crate) fn validate( + values: &ArrayRef, + indices: &ArrayRef, + value_idx_offsets: &ArrayRef, + offset: usize, +) -> VortexResult<()> { + vortex_ensure!( + offset < 1024, + "Offset must be smaller than 1024, got {}", + offset + ); + + vortex_ensure!( + values.dtype().is_primitive(), + "RLE values must be a primitive type, got {}", + values.dtype() + ); + + vortex_ensure!( + matches!(indices.dtype().as_ptype(), PType::U8 | PType::U16), + "RLE indices must be u8 or u16, got {}", + indices.dtype() + ); + + vortex_ensure!( + value_idx_offsets.dtype().is_unsigned_int() && !value_idx_offsets.dtype().is_nullable(), + "RLE value idx offsets must be non-nullable unsigned integer, got {}", + value_idx_offsets.dtype() + ); + + vortex_ensure!( + indices.len().div_ceil(FL_CHUNK_SIZE) == value_idx_offsets.len(), + "RLE must have one value idx offset per chunk, got {}", + value_idx_offsets.len() + ); + + vortex_ensure!( + indices.len() >= values.len(), + "RLE must have at least as many indices as values, got {} indices and {} values", + indices.len(), + values.len() + ); + + Ok(()) +} +impl RLEArray { /// Create a new RLEArray without validation. 
/// /// # Safety @@ -141,36 +110,6 @@ impl RLEArray { } } - #[inline] - pub fn len(&self) -> usize { - self.common.len() - } - - #[inline] - pub fn is_empty(&self) -> bool { - self.common.len() == 0 - } - - #[inline] - pub fn dtype(&self) -> &DType { - self.common.dtype() - } - - #[inline] - pub fn values(&self) -> &ArrayRef { - &self.values - } - - #[inline] - pub fn indices(&self) -> &ArrayRef { - &self.indices - } - - #[inline] - pub fn values_idx_offsets(&self) -> &ArrayRef { - &self.values_idx_offsets - } - /// Values index offset relative to the first chunk. /// /// Offsets in `values_idx_offsets` are absolute and need to be shifted @@ -196,15 +135,70 @@ impl RLEArray { .expect("index must be of type usize") } - /// Index offset into the array #[inline] - pub fn offset(&self) -> usize { - self.offset + pub(crate) fn stats_set(&self) -> &ArrayStats { + self.common.stats() + } +} + +/// Extension trait for [`RLEArray`] methods. +pub trait RLEArrayExt { + /// Returns the length of the array. + fn len(&self) -> usize; + + /// Returns whether the array is empty. + fn is_empty(&self) -> bool; + + /// Returns the data type of the array. + fn dtype(&self) -> &DType; + + /// Returns a reference to the values array. + fn values(&self) -> &ArrayRef; + + /// Returns a reference to the indices array. + fn indices(&self) -> &ArrayRef; + + /// Returns a reference to the values index offsets array. + fn values_idx_offsets(&self) -> &ArrayRef; + + /// Index offset into the array. 
+ fn offset(&self) -> usize; +} + +impl RLEArrayExt for RLEArray { + #[inline] + fn len(&self) -> usize { + self.common.len() } #[inline] - pub(crate) fn stats_set(&self) -> &ArrayStats { - self.common.stats() + fn is_empty(&self) -> bool { + self.common.len() == 0 + } + + #[inline] + fn dtype(&self) -> &DType { + self.common.dtype() + } + + #[inline] + fn values(&self) -> &ArrayRef { + &self.values + } + + #[inline] + fn indices(&self) -> &ArrayRef { + &self.indices + } + + #[inline] + fn values_idx_offsets(&self) -> &ArrayRef { + &self.values_idx_offsets + } + + #[inline] + fn offset(&self) -> usize { + self.offset } } @@ -226,7 +220,8 @@ mod tests { use vortex_buffer::ByteBufferMut; use vortex_session::registry::ReadContext; - use crate::RLEArray; + use crate::RLEArrayExt; + use crate::RLEVTable; use crate::test::SESSION; #[test] @@ -238,7 +233,7 @@ mod tests { PrimitiveArray::from_iter([0u16, 0, 1, 1, 2].iter().cycle().take(1024).copied()) .into_array(); let values_idx_offsets = PrimitiveArray::from_iter([0u64]).into_array(); - let rle_array = RLEArray::try_new(values, indices, values_idx_offsets, 0, 5).unwrap(); + let rle_array = RLEVTable::try_new(values, indices, values_idx_offsets, 0, 5).unwrap(); assert_eq!(rle_array.len(), 5); assert_eq!(rle_array.values().len(), 3); @@ -265,7 +260,7 @@ mod tests { ) .into_array(); - let rle_array = RLEArray::try_new( + let rle_array = RLEVTable::try_new( values.clone(), indices_with_validity, values_idx_offsets, @@ -301,7 +296,7 @@ mod tests { ) .into_array(); - let rle_array = RLEArray::try_new( + let rle_array = RLEVTable::try_new( values.clone(), indices_with_validity, values_idx_offsets, @@ -338,7 +333,7 @@ mod tests { ) .into_array(); - let rle_array = RLEArray::try_new( + let rle_array = RLEVTable::try_new( values.clone(), indices_with_validity, values_idx_offsets, @@ -380,7 +375,7 @@ mod tests { ) .into_array(); - let rle_array = RLEArray::try_new( + let rle_array = RLEVTable::try_new( values.clone(), 
indices_with_validity, values_idx_offsets, @@ -402,7 +397,7 @@ mod tests { let values = PrimitiveArray::from_iter(Vec::::new()).into_array(); let indices = PrimitiveArray::from_iter(Vec::::new()).into_array(); let values_idx_offsets = PrimitiveArray::from_iter(Vec::::new()).into_array(); - let rle_array = RLEArray::try_new( + let rle_array = RLEVTable::try_new( values, indices.clone(), values_idx_offsets, @@ -420,7 +415,7 @@ mod tests { let values = PrimitiveArray::from_iter([10u32, 20, 30, 40]).into_array(); let indices = PrimitiveArray::from_iter([0u16, 1].repeat(1024)).into_array(); let values_idx_offsets = PrimitiveArray::from_iter([0u64, 2]).into_array(); - let rle_array = RLEArray::try_new(values, indices, values_idx_offsets, 0, 2048).unwrap(); + let rle_array = RLEVTable::try_new(values, indices, values_idx_offsets, 0, 2048).unwrap(); assert_eq!(rle_array.len(), 2048); assert_eq!(rle_array.values().len(), 4); @@ -432,7 +427,7 @@ mod tests { #[test] fn test_rle_serialization() { let primitive = PrimitiveArray::from_iter((0..2048).map(|i| (i / 100) as u32)); - let rle_array = RLEArray::encode(&primitive).unwrap(); + let rle_array = RLEVTable::encode(&primitive).unwrap(); assert_eq!(rle_array.len(), 2048); let original_data = rle_array.to_primitive(); @@ -467,9 +462,9 @@ mod tests { #[test] fn test_rle_serialization_slice() { let primitive = PrimitiveArray::from_iter((0..2048).map(|i| (i / 100) as u32)); - let rle_array = RLEArray::encode(&primitive).unwrap(); + let rle_array = RLEVTable::encode(&primitive).unwrap(); - let sliced = RLEArray::try_new( + let sliced = RLEVTable::try_new( rle_array.values().clone(), rle_array.indices().clone(), rle_array.values_idx_offsets().clone(), diff --git a/encodings/fastlanes/src/rle/array/rle_compress.rs b/encodings/fastlanes/src/rle/array/rle_compress.rs index 8eec282e844..d82d40dcc9e 100644 --- a/encodings/fastlanes/src/rle/array/rle_compress.rs +++ b/encodings/fastlanes/src/rle/array/rle_compress.rs @@ -18,12 +18,11 @@ 
use vortex_error::VortexResult; use crate::FL_CHUNK_SIZE; use crate::RLEArray; +use crate::RLEVTable; -impl RLEArray { - /// Encodes a primitive array of unsigned integers using FastLanes RLE. - pub fn encode(array: &PrimitiveArray) -> VortexResult { - match_each_native_ptype!(array.ptype(), |T| { rle_encode_typed::(array) }) - } +/// Encodes a primitive array of unsigned integers using FastLanes RLE. +pub(crate) fn encode_rle(array: &PrimitiveArray) -> VortexResult { + match_each_native_ptype!(array.ptype(), |T| { rle_encode_typed::(array) }) } /// Encodes a primitive array of unsigned integers using FastLanes RLE. @@ -100,7 +99,7 @@ where // SAFETY: NativeValue is repr(transparent) to T. let values_buf = unsafe { values_buf.transmute::().freeze() }; - RLEArray::try_new( + RLEVTable::try_new( values_buf.into_array(), PrimitiveArray::new(indices_buf.freeze(), padded_validity(array)).into_array(), values_idx_offsets.into_array(), @@ -144,13 +143,15 @@ mod tests { use vortex_buffer::Buffer; use super::*; + use crate::RLEArrayExt; + use crate::RLEVTable; #[test] fn test_encode_decode() { // u8 let values_u8: Buffer = [1, 1, 2, 2, 3, 3].iter().copied().collect(); let array_u8 = values_u8.into_array(); - let encoded_u8 = RLEArray::encode(&array_u8.to_primitive()).unwrap(); + let encoded_u8 = RLEVTable::encode(&array_u8.to_primitive()).unwrap(); let decoded_u8 = encoded_u8.to_primitive(); let expected_u8 = PrimitiveArray::from_iter(vec![1u8, 1, 2, 2, 3, 3]); assert_arrays_eq!(decoded_u8, expected_u8); @@ -158,7 +159,7 @@ mod tests { // u16 let values_u16: Buffer = [100, 100, 200, 200].iter().copied().collect(); let array_u16 = values_u16.into_array(); - let encoded_u16 = RLEArray::encode(&array_u16.to_primitive()).unwrap(); + let encoded_u16 = RLEVTable::encode(&array_u16.to_primitive()).unwrap(); let decoded_u16 = encoded_u16.to_primitive(); let expected_u16 = PrimitiveArray::from_iter(vec![100u16, 100, 200, 200]); assert_arrays_eq!(decoded_u16, expected_u16); @@ -166,7 
+167,7 @@ mod tests { // u64 let values_u64: Buffer = [1000, 1000, 2000].iter().copied().collect(); let array_u64 = values_u64.into_array(); - let encoded_u64 = RLEArray::encode(&array_u64.to_primitive()).unwrap(); + let encoded_u64 = RLEVTable::encode(&array_u64.to_primitive()).unwrap(); let decoded_u64 = encoded_u64.to_primitive(); let expected_u64 = PrimitiveArray::from_iter(vec![1000u64, 1000, 2000]); assert_arrays_eq!(decoded_u64, expected_u64); @@ -176,7 +177,7 @@ mod tests { fn test_length() { let values: Buffer = [1, 1, 2, 2, 2, 3].iter().copied().collect(); let array = values.into_array(); - let encoded = RLEArray::encode(&array.to_primitive()).unwrap(); + let encoded = RLEVTable::encode(&array.to_primitive()).unwrap(); assert_eq!(encoded.len(), 6); } @@ -184,7 +185,7 @@ mod tests { fn test_empty_length() { let values: Buffer = Buffer::empty(); let array = values.into_array(); - let encoded = RLEArray::encode(&array.to_primitive()).unwrap(); + let encoded = RLEVTable::encode(&array.to_primitive()).unwrap(); assert_eq!(encoded.len(), 0); assert_eq!(encoded.values().len(), 0); @@ -195,7 +196,7 @@ mod tests { let values: Buffer = vec![42; 2000].into_iter().collect(); let array = values.into_array(); - let encoded = RLEArray::encode(&array.to_primitive()).unwrap(); + let encoded = RLEVTable::encode(&array.to_primitive()).unwrap(); assert_eq!(encoded.values().len(), 2); // 2 chunks, each storing value 42 let decoded = encoded.to_primitive(); // Verify round-trip @@ -208,7 +209,7 @@ mod tests { let values: Buffer = (0u8..=255).collect(); let array = values.into_array(); - let encoded = RLEArray::encode(&array.to_primitive()).unwrap(); + let encoded = RLEVTable::encode(&array.to_primitive()).unwrap(); assert_eq!(encoded.values().len(), 256); let decoded = encoded.to_primitive(); // Verify round-trip @@ -222,7 +223,7 @@ mod tests { let values: Buffer = (0..1500).map(|i| (i / 100) as u32).collect(); let array = values.into_array(); - let encoded = 
RLEArray::encode(&array.to_primitive()).unwrap(); + let encoded = RLEVTable::encode(&array.to_primitive()).unwrap(); assert_eq!(encoded.len(), 1500); assert_arrays_eq!(encoded, array); @@ -236,7 +237,7 @@ mod tests { let values: Buffer = (0..2048).map(|i| (i / 100) as u32).collect(); let array = values.into_array(); - let encoded = RLEArray::encode(&array.to_primitive()).unwrap(); + let encoded = RLEVTable::encode(&array.to_primitive()).unwrap(); assert_eq!(encoded.len(), 2048); assert_arrays_eq!(encoded, array); @@ -257,7 +258,7 @@ mod tests { #[case::f64((-2000..2000).map(|i| i as f64).collect::>())] fn test_roundtrip_primitive_types(#[case] values: Buffer) { let primitive = values.clone().into_array().to_primitive(); - let result = RLEArray::encode(&primitive).unwrap(); + let result = RLEVTable::encode(&primitive).unwrap(); let decoded = result.to_primitive(); let expected = PrimitiveArray::new(values, primitive.validity().clone()); assert_arrays_eq!(decoded, expected); @@ -271,7 +272,7 @@ mod tests { #[case(vec![0f64, -0f64])] fn test_float_zeros(#[case] values: Vec) { let primitive = PrimitiveArray::from_iter(values); - let rle = RLEArray::encode(&primitive).unwrap(); + let rle = RLEVTable::encode(&primitive).unwrap(); let decoded = rle.to_primitive(); assert_arrays_eq!(primitive, decoded); } diff --git a/encodings/fastlanes/src/rle/array/rle_decompress.rs b/encodings/fastlanes/src/rle/array/rle_decompress.rs index 07489d2c5f9..dc579047857 100644 --- a/encodings/fastlanes/src/rle/array/rle_decompress.rs +++ b/encodings/fastlanes/src/rle/array/rle_decompress.rs @@ -19,6 +19,7 @@ use vortex_error::vortex_panic; use crate::FL_CHUNK_SIZE; use crate::RLEArray; +use crate::RLEArrayExt; /// Decompresses an RLE array back into a primitive array. 
#[expect( diff --git a/encodings/fastlanes/src/rle/compute/cast.rs b/encodings/fastlanes/src/rle/compute/cast.rs index f401e9ad098..96dea72286e 100644 --- a/encodings/fastlanes/src/rle/compute/cast.rs +++ b/encodings/fastlanes/src/rle/compute/cast.rs @@ -9,6 +9,7 @@ use vortex_array::scalar_fn::fns::cast::CastReduce; use vortex_error::VortexResult; use crate::rle::RLEArray; +use crate::rle::RLEArrayExt; use crate::rle::RLEVTable; impl CastReduce for RLEVTable { @@ -55,7 +56,7 @@ mod tests { use vortex_array::validity::Validity; use vortex_buffer::Buffer; - use crate::rle::RLEArray; + use crate::rle::RLEVTable; #[test] fn try_cast_rle_success() { @@ -63,7 +64,7 @@ mod tests { Buffer::from_iter([10u8, 20, 30, 40, 50]), Validity::from_iter([true, true, true, true, true]), ); - let rle = RLEArray::encode(&primitive).unwrap(); + let rle = RLEVTable::encode(&primitive).unwrap(); let casted = rle .into_array() @@ -79,7 +80,7 @@ mod tests { Buffer::from_iter([10u8, 20, 30, 40, 50]), Validity::from_iter([true, false, true, true, false]), ); - let rle = RLEArray::encode(&primitive).unwrap(); + let rle = RLEVTable::encode(&primitive).unwrap(); rle.into_array() .cast(DType::Primitive(PType::U8, Nullability::NonNullable)) .and_then(|a| a.to_canonical().map(|c| c.into_array())) @@ -136,7 +137,7 @@ mod tests { ) )] fn test_cast_rle_conformance(#[case] primitive: PrimitiveArray) { - let rle_array = RLEArray::encode(&primitive).unwrap(); + let rle_array = RLEVTable::encode(&primitive).unwrap(); test_cast_conformance(&rle_array.into_array()); } } diff --git a/encodings/fastlanes/src/rle/kernel.rs b/encodings/fastlanes/src/rle/kernel.rs index 8c22c8bfc24..02f1af3ed04 100644 --- a/encodings/fastlanes/src/rle/kernel.rs +++ b/encodings/fastlanes/src/rle/kernel.rs @@ -13,6 +13,7 @@ use vortex_error::VortexResult; use crate::FL_CHUNK_SIZE; use crate::RLEArray; +use crate::RLEArrayExt; use crate::RLEVTable; pub(crate) static PARENT_KERNELS: ParentKernelSet = diff --git 
a/encodings/fastlanes/src/rle/mod.rs b/encodings/fastlanes/src/rle/mod.rs index f3400fcfe21..e61647b6911 100644 --- a/encodings/fastlanes/src/rle/mod.rs +++ b/encodings/fastlanes/src/rle/mod.rs @@ -3,6 +3,7 @@ mod array; pub use array::RLEArray; +pub use array::RLEArrayExt; mod compute; mod kernel; diff --git a/encodings/fastlanes/src/rle/vtable/mod.rs b/encodings/fastlanes/src/rle/vtable/mod.rs index 0a682d3abf7..331bf5e2bcc 100644 --- a/encodings/fastlanes/src/rle/vtable/mod.rs +++ b/encodings/fastlanes/src/rle/vtable/mod.rs @@ -27,6 +27,7 @@ use vortex_error::vortex_panic; use vortex_session::VortexSession; use crate::RLEArray; +use crate::RLEArrayExt; use crate::rle::array::rle_decompress::rle_decompress; use crate::rle::kernel::PARENT_KERNELS; use crate::rle::vtable::rules::RULES; @@ -212,7 +213,7 @@ impl VTable for RLEVTable { usize::try_from(metadata.values_idx_offsets_len)?, )?; - RLEArray::try_new( + Self::try_new( values, indices, values_idx_offsets, @@ -240,6 +241,47 @@ pub struct RLEVTable; impl RLEVTable { pub const ID: ArrayId = ArrayId::new_ref("fastlanes.rle"); + + /// Create a new chunk-based RLE array from its components. + /// + /// # Arguments + /// + /// * `values` - Unique values from all chunks + /// * `indices` - Chunk-local indices from all chunks + /// * `values_idx_offsets` - Start indices for each value chunk. + /// * `offset` - Offset into the first chunk + /// * `length` - Array length + pub fn try_new( + values: ArrayRef, + indices: ArrayRef, + values_idx_offsets: ArrayRef, + offset: usize, + length: usize, + ) -> VortexResult { + use vortex_array::ArrayCommon; + + use crate::FL_CHUNK_SIZE; + use crate::rle::array::validate; + + assert_eq!(indices.len() % FL_CHUNK_SIZE, 0); + validate(&values, &indices, &values_idx_offsets, offset)?; + + // Ensure that the DType has the same nullability as the indices array. 
+ let dtype = DType::Primitive(values.dtype().as_ptype(), indices.dtype().nullability()); + + Ok(RLEArray { + common: ArrayCommon::new(length, dtype), + values, + indices, + values_idx_offsets, + offset, + }) + } + + /// Encodes a primitive array of unsigned integers using FastLanes RLE. + pub fn encode(array: &vortex_array::arrays::PrimitiveArray) -> VortexResult { + crate::rle::array::rle_compress::encode_rle(array) + } } #[cfg(test)] diff --git a/encodings/fastlanes/src/rle/vtable/operations.rs b/encodings/fastlanes/src/rle/vtable/operations.rs index 31afaaee60d..936d8b7c873 100644 --- a/encodings/fastlanes/src/rle/vtable/operations.rs +++ b/encodings/fastlanes/src/rle/vtable/operations.rs @@ -9,6 +9,7 @@ use vortex_error::VortexResult; use super::RLEVTable; use crate::FL_CHUNK_SIZE; use crate::RLEArray; +use crate::RLEArrayExt; impl OperationsVTable for RLEVTable { fn scalar_at(array: &RLEArray, index: usize) -> VortexResult { @@ -59,7 +60,7 @@ mod tests { .into_array(); let values_idx_offsets = PrimitiveArray::from_iter([0u64]).into_array(); - RLEArray::try_new( + RLEVTable::try_new( values, indices.clone(), values_idx_offsets, @@ -94,7 +95,7 @@ mod tests { ) .into_array(); - RLEArray::try_new( + RLEVTable::try_new( values, indices.clone(), values_idx_offsets, @@ -162,7 +163,7 @@ mod tests { let expected: Vec = (0..3000).map(|i| (i / 50) as u16).collect(); let array = values.into_array(); - let encoded = RLEArray::encode(&array.to_primitive()).unwrap(); + let encoded = RLEVTable::encode(&array.to_primitive()).unwrap(); // Access scalars from multiple chunks. for &idx in &[1023, 1024, 1025, 2047, 2048, 2049] { @@ -268,7 +269,7 @@ mod tests { let expected: Vec = (0..2100).map(|i| (i / 100) as u32).collect(); let array = values.into_array(); - let encoded = RLEArray::encode(&array.to_primitive()).unwrap(); + let encoded = RLEVTable::encode(&array.to_primitive()).unwrap(); // Slice across first and second chunk. 
let slice = encoded.slice(500..1500).unwrap(); diff --git a/encodings/fastlanes/src/rle/vtable/validity.rs b/encodings/fastlanes/src/rle/vtable/validity.rs index 893a0677766..feabc07aefd 100644 --- a/encodings/fastlanes/src/rle/vtable/validity.rs +++ b/encodings/fastlanes/src/rle/vtable/validity.rs @@ -7,6 +7,7 @@ use vortex_array::vtable::ValidityChildSliceHelper; use super::RLEVTable; use crate::RLEArray; +use crate::RLEArrayExt; impl ValidityChild for RLEVTable { fn validity_child(array: &RLEArray) -> &ArrayRef { diff --git a/encodings/fsst/public-api.lock b/encodings/fsst/public-api.lock index 8451466972c..6a44f4ca5bd 100644 --- a/encodings/fsst/public-api.lock +++ b/encodings/fsst/public-api.lock @@ -16,8 +16,6 @@ pub fn vortex_fsst::FSSTArray::symbol_lengths(&self) -> &vortex_buffer::buffer:: pub fn vortex_fsst::FSSTArray::symbols(&self) -> &vortex_buffer::buffer::Buffer -pub fn vortex_fsst::FSSTArray::try_new(dtype: vortex_array::dtype::DType, symbols: vortex_buffer::buffer::Buffer, symbol_lengths: vortex_buffer::buffer::Buffer, codes: vortex_array::arrays::varbin::array::VarBinArray, uncompressed_lengths: vortex_array::array::ArrayRef) -> vortex_error::VortexResult - pub fn vortex_fsst::FSSTArray::uncompressed_lengths(&self) -> &vortex_array::array::ArrayRef pub fn vortex_fsst::FSSTArray::uncompressed_lengths_dtype(&self) -> &vortex_array::dtype::DType @@ -92,6 +90,8 @@ impl vortex_fsst::FSSTVTable pub const vortex_fsst::FSSTVTable::ID: vortex_array::vtable::dyn_::ArrayId +pub fn vortex_fsst::FSSTVTable::try_new(dtype: vortex_array::dtype::DType, symbols: vortex_buffer::buffer::Buffer, symbol_lengths: vortex_buffer::buffer::Buffer, codes: vortex_array::arrays::varbin::array::VarBinArray, uncompressed_lengths: vortex_array::array::ArrayRef) -> vortex_error::VortexResult + impl core::fmt::Debug for vortex_fsst::FSSTVTable pub fn vortex_fsst::FSSTVTable::fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result diff --git a/encodings/fsst/src/array.rs 
b/encodings/fsst/src/array.rs index 984738aa072..9821de4e06b 100644 --- a/encodings/fsst/src/array.rs +++ b/encodings/fsst/src/array.rs @@ -245,7 +245,7 @@ impl VTable for FSSTVTable { len, )?; - return FSSTArray::try_new( + return Self::try_new( dtype.clone(), symbols, symbol_lengths, @@ -292,7 +292,7 @@ impl VTable for FSSTVTable { codes_validity, )?; - return FSSTArray::try_new( + return Self::try_new( dtype.clone(), symbols, symbol_lengths, @@ -392,24 +392,15 @@ pub struct FSSTVTable; impl FSSTVTable { pub const ID: ArrayId = ArrayId::new_ref("vortex.fsst"); -} -impl FSSTArray { /// Build an FSST array from a set of `symbols` and `codes`. - /// - /// Symbols are 8-bytes and can represent short strings, each of which is assigned - /// a code. - /// - /// The `codes` array is a Binary array where each binary datum is a sequence of 8-bit codes. - /// Each code corresponds either to a symbol, or to the "escape code", - /// which tells the decoder to emit the following byte without doing a table lookup. 
pub fn try_new( dtype: DType, symbols: Buffer, symbol_lengths: Buffer, codes: VarBinArray, uncompressed_lengths: ArrayRef, - ) -> VortexResult { + ) -> VortexResult { // Check: symbols must not have length > MAX_CODE if symbols.len() > 255 { vortex_bail!(InvalidArgument: "symbols array must have length <= 255"); @@ -433,7 +424,7 @@ impl FSSTArray { // SAFETY: all components validated above unsafe { - Ok(Self::new_unchecked( + Ok(FSSTArray::new_unchecked( dtype, symbols, symbol_lengths, @@ -442,78 +433,101 @@ impl FSSTArray { )) } } +} - pub(crate) unsafe fn new_unchecked( - dtype: DType, - symbols: Buffer, - symbol_lengths: Buffer, - codes: VarBinArray, - uncompressed_lengths: ArrayRef, - ) -> Self { - let symbols2 = symbols.clone(); - let symbol_lengths2 = symbol_lengths.clone(); - let compressor = Arc::new(LazyLock::new(Box::new(move || { - Compressor::rebuild_from(symbols2.as_slice(), symbol_lengths2.as_slice()) - }) - as Box Compressor + Send>)); - let len = codes.len(); - let codes_array = codes.clone().into_array(); +/// Extension trait for [`FSSTArray`] methods. +pub trait FSSTArrayExt { + /// Access the symbol table array. + fn symbols(&self) -> &Buffer; - Self { - common: ArrayCommon::new(len, dtype), - symbols, - symbol_lengths, - codes, - codes_array, - uncompressed_lengths, - compressor, - } - } + /// Access the symbol lengths array. + fn symbol_lengths(&self) -> &Buffer; + + /// Access the codes array. + fn codes(&self) -> &VarBinArray; - /// Access the symbol table array - pub fn symbols(&self) -> &Buffer { + /// Get the DType of the codes array. + fn codes_dtype(&self) -> &DType; + + /// Get the uncompressed length for each element in the array. + fn uncompressed_lengths(&self) -> &ArrayRef; + + /// Get the DType of the uncompressed lengths array. + fn uncompressed_lengths_dtype(&self) -> &DType; + + /// Build a [`Decompressor`][fsst::Decompressor] that can be used to decompress values from + /// this array. 
+ fn decompressor(&self) -> Decompressor<'_>; + + /// Retrieves the FSST compressor. + fn compressor(&self) -> &Compressor; +} + +impl FSSTArrayExt for FSSTArray { + fn symbols(&self) -> &Buffer { &self.symbols } - /// Access the symbol table array - pub fn symbol_lengths(&self) -> &Buffer { + fn symbol_lengths(&self) -> &Buffer { &self.symbol_lengths } - /// Access the codes array - pub fn codes(&self) -> &VarBinArray { + fn codes(&self) -> &VarBinArray { &self.codes } - /// Get the DType of the codes array #[inline] - pub fn codes_dtype(&self) -> &DType { + fn codes_dtype(&self) -> &DType { self.codes.dtype() } - /// Get the uncompressed length for each element in the array. - pub fn uncompressed_lengths(&self) -> &ArrayRef { + fn uncompressed_lengths(&self) -> &ArrayRef { &self.uncompressed_lengths } - /// Get the DType of the uncompressed lengths array #[inline] - pub fn uncompressed_lengths_dtype(&self) -> &DType { + fn uncompressed_lengths_dtype(&self) -> &DType { self.uncompressed_lengths.dtype() } - /// Build a [`Decompressor`][fsst::Decompressor] that can be used to decompress values from - /// this array. - pub fn decompressor(&self) -> Decompressor<'_> { + fn decompressor(&self) -> Decompressor<'_> { Decompressor::new(self.symbols().as_slice(), self.symbol_lengths().as_slice()) } - /// Retrieves the FSST compressor. 
- pub fn compressor(&self) -> &Compressor { + fn compressor(&self) -> &Compressor { self.compressor.as_ref() } } +impl FSSTArray { + pub(crate) unsafe fn new_unchecked( + dtype: DType, + symbols: Buffer, + symbol_lengths: Buffer, + codes: VarBinArray, + uncompressed_lengths: ArrayRef, + ) -> Self { + let symbols2 = symbols.clone(); + let symbol_lengths2 = symbol_lengths.clone(); + let compressor = Arc::new(LazyLock::new(Box::new(move || { + Compressor::rebuild_from(symbols2.as_slice(), symbol_lengths2.as_slice()) + }) + as Box Compressor + Send>)); + let len = codes.len(); + let codes_array = codes.clone().into_array(); + + Self { + common: ArrayCommon::new(len, dtype), + symbols, + symbol_lengths, + codes, + codes_array, + uncompressed_lengths, + compressor, + } + } +} + impl ValidityChild for FSSTVTable { fn validity_child(array: &FSSTArray) -> &ArrayRef { &array.codes_array @@ -540,6 +554,7 @@ mod test { use vortex_buffer::Buffer; use vortex_error::VortexError; + use crate::FSSTArrayExt; use crate::FSSTVTable; use crate::array::FSSTMetadata; use crate::fsst_compress_iter; diff --git a/encodings/fsst/src/canonical.rs b/encodings/fsst/src/canonical.rs index b5474d1d923..745736d52a4 100644 --- a/encodings/fsst/src/canonical.rs +++ b/encodings/fsst/src/canonical.rs @@ -19,6 +19,7 @@ use vortex_buffer::ByteBufferMut; use vortex_error::VortexResult; use crate::FSSTArray; +use crate::FSSTArrayExt; pub(super) fn canonicalize_fsst( array: &FSSTArray, diff --git a/encodings/fsst/src/compress.rs b/encodings/fsst/src/compress.rs index 669e488a16c..444633371e8 100644 --- a/encodings/fsst/src/compress.rs +++ b/encodings/fsst/src/compress.rs @@ -15,6 +15,7 @@ use vortex_buffer::BufferMut; use vortex_error::VortexExpect; use crate::FSSTArray; +use crate::FSSTVTable; /// Compress a string array using FSST. 
pub fn fsst_compress + AsRef>( @@ -103,7 +104,7 @@ where let uncompressed_lengths = uncompressed_lengths.into_array(); - FSSTArray::try_new(dtype, symbols, symbol_lengths, codes, uncompressed_lengths) + FSSTVTable::try_new(dtype, symbols, symbol_lengths, codes, uncompressed_lengths) .vortex_expect("building FSSTArray from parts") } diff --git a/encodings/fsst/src/compute/cast.rs b/encodings/fsst/src/compute/cast.rs index d6ecd8bb614..f9e7ecbb99d 100644 --- a/encodings/fsst/src/compute/cast.rs +++ b/encodings/fsst/src/compute/cast.rs @@ -10,6 +10,7 @@ use vortex_array::scalar_fn::fns::cast::CastReduce; use vortex_error::VortexResult; use crate::FSSTArray; +use crate::FSSTArrayExt; use crate::FSSTVTable; impl CastReduce for FSSTVTable { @@ -25,7 +26,7 @@ impl CastReduce for FSSTVTable { .cast(array.codes().dtype().with_nullability(dtype.nullability()))?; Ok(Some( - FSSTArray::try_new( + FSSTVTable::try_new( dtype.clone(), array.symbols().clone(), array.symbol_lengths().clone(), diff --git a/encodings/fsst/src/compute/compare.rs b/encodings/fsst/src/compute/compare.rs index af0cf557eb1..7f01721c833 100644 --- a/encodings/fsst/src/compute/compare.rs +++ b/encodings/fsst/src/compute/compare.rs @@ -20,6 +20,7 @@ use vortex_error::VortexResult; use vortex_error::vortex_bail; use crate::FSSTArray; +use crate::FSSTArrayExt; use crate::FSSTVTable; impl CompareKernel for FSSTVTable { diff --git a/encodings/fsst/src/compute/filter.rs b/encodings/fsst/src/compute/filter.rs index 2aa0cc207a9..0d46ac99f24 100644 --- a/encodings/fsst/src/compute/filter.rs +++ b/encodings/fsst/src/compute/filter.rs @@ -11,6 +11,7 @@ use vortex_error::VortexResult; use vortex_mask::Mask; use crate::FSSTArray; +use crate::FSSTArrayExt; use crate::FSSTVTable; impl FilterKernel for FSSTVTable { @@ -27,7 +28,7 @@ impl FilterKernel for FSSTVTable { .vortex_expect("must be VarBinVTable"); Ok(Some( - FSSTArray::try_new( + FSSTVTable::try_new( array.dtype().clone(), array.symbols().clone(), 
array.symbol_lengths().clone(), diff --git a/encodings/fsst/src/compute/mod.rs b/encodings/fsst/src/compute/mod.rs index 0c98126e098..438ca27c5dd 100644 --- a/encodings/fsst/src/compute/mod.rs +++ b/encodings/fsst/src/compute/mod.rs @@ -18,6 +18,7 @@ use vortex_error::VortexResult; use vortex_error::vortex_err; use crate::FSSTArray; +use crate::FSSTArrayExt; use crate::FSSTVTable; impl TakeExecute for FSSTVTable { @@ -27,7 +28,7 @@ impl TakeExecute for FSSTVTable { _ctx: &mut ExecutionCtx, ) -> VortexResult> { Ok(Some( - FSSTArray::try_new( + FSSTVTable::try_new( array .dtype() .clone() diff --git a/encodings/fsst/src/ops.rs b/encodings/fsst/src/ops.rs index f75fa067ac1..d940ede421f 100644 --- a/encodings/fsst/src/ops.rs +++ b/encodings/fsst/src/ops.rs @@ -9,6 +9,7 @@ use vortex_error::VortexExpect; use vortex_error::VortexResult; use crate::FSSTArray; +use crate::FSSTArrayExt; use crate::FSSTVTable; impl OperationsVTable for FSSTVTable { diff --git a/encodings/fsst/src/slice.rs b/encodings/fsst/src/slice.rs index d914bb2c85a..5cb64f5644d 100644 --- a/encodings/fsst/src/slice.rs +++ b/encodings/fsst/src/slice.rs @@ -11,6 +11,7 @@ use vortex_error::VortexResult; use vortex_error::vortex_err; use crate::FSSTArray; +use crate::FSSTArrayExt; use crate::FSSTVTable; impl SliceReduce for FSSTVTable { diff --git a/encodings/pco/public-api.lock b/encodings/pco/public-api.lock index 599f0da1072..b49697cf7d5 100644 --- a/encodings/pco/public-api.lock +++ b/encodings/pco/public-api.lock @@ -6,12 +6,6 @@ impl vortex_pco::PcoArray pub fn vortex_pco::PcoArray::decompress(&self) -> vortex_error::VortexResult -pub fn vortex_pco::PcoArray::from_array(array: vortex_array::array::ArrayRef, level: usize, nums_per_page: usize) -> vortex_error::VortexResult - -pub fn vortex_pco::PcoArray::from_primitive(parray: &vortex_array::arrays::primitive::array::PrimitiveArray, level: usize, values_per_page: usize) -> vortex_error::VortexResult - -pub fn vortex_pco::PcoArray::new(chunk_metas: 
alloc::vec::Vec, pages: alloc::vec::Vec, dtype: vortex_array::dtype::DType, metadata: vortex_pco::PcoMetadata, len: usize, validity: vortex_array::validity::Validity) -> Self - impl vortex_pco::PcoArray pub fn vortex_pco::PcoArray::to_array(&self) -> vortex_array::array::ArrayRef @@ -120,6 +114,12 @@ impl vortex_pco::PcoVTable pub const vortex_pco::PcoVTable::ID: vortex_array::vtable::dyn_::ArrayId +pub fn vortex_pco::PcoVTable::from_array(array: vortex_array::array::ArrayRef, level: usize, nums_per_page: usize) -> vortex_error::VortexResult + +pub fn vortex_pco::PcoVTable::from_primitive(parray: &vortex_array::arrays::primitive::array::PrimitiveArray, level: usize, values_per_page: usize) -> vortex_error::VortexResult + +pub fn vortex_pco::PcoVTable::new(chunk_metas: alloc::vec::Vec, pages: alloc::vec::Vec, dtype: vortex_array::dtype::DType, metadata: vortex_pco::PcoMetadata, len: usize, validity: vortex_array::validity::Validity) -> vortex_pco::PcoArray + impl core::fmt::Debug for vortex_pco::PcoVTable pub fn vortex_pco::PcoVTable::fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result diff --git a/encodings/pco/src/array.rs b/encodings/pco/src/array.rs index bd9627cbad3..11c47492ed5 100644 --- a/encodings/pco/src/array.rs +++ b/encodings/pco/src/array.rs @@ -235,7 +235,7 @@ impl VTable for PcoVTable { .sum::(); vortex_ensure!(pages.len() == expected_n_pages); - Ok(PcoArray::new( + Ok(PcoVTable::new( chunk_metas, pages, dtype.clone(), @@ -324,7 +324,38 @@ pub struct PcoArray { slice_stop: usize, } -impl PcoArray { +/// Extension trait for [`PcoArray`] instance methods. +pub trait PcoArrayExt { + /// Decompress this [`PcoArray`] back into a [`PrimitiveArray`]. + fn decompress(&self) -> VortexResult; +} + +impl PcoArrayExt for PcoArray { + fn decompress(&self) -> VortexResult { + // To start, we figure out which chunks and pages we need to decompress, and with + // what value offset into the first such page. 
+ let number_type = number_type_from_dtype(self.common.dtype()); + let values_byte_buffer = match_number_enum!( + number_type, + NumberType => { + decompress_values_typed::(self)? + } + ); + + Ok(PrimitiveArray::from_values_byte_buffer( + values_byte_buffer, + self.common.dtype().as_ptype(), + self.unsliced_validity + .slice(self.slice_start..self.slice_stop)?, + self.common.len(), + )) + } +} + +impl PcoVTable { + /// Create a new [`PcoArray`] from pre-compressed chunk metadata buffers, page buffers, + /// dtype, metadata, length, and validity. + #[allow(clippy::new_ret_no_self)] pub fn new( chunk_metas: Vec, pages: Vec, @@ -332,8 +363,8 @@ impl PcoArray { metadata: PcoMetadata, len: usize, validity: Validity, - ) -> Self { - Self { + ) -> PcoArray { + PcoArray { chunk_metas, pages, metadata, @@ -345,14 +376,37 @@ impl PcoArray { } } + /// Compress a [`PrimitiveArray`] into a [`PcoArray`] with the given compression level + /// and values per page. pub fn from_primitive( parray: &PrimitiveArray, level: usize, values_per_page: usize, - ) -> VortexResult { - Self::from_primitive_with_values_per_chunk(parray, level, VALUES_PER_CHUNK, values_per_page) + ) -> VortexResult { + PcoArray::from_primitive_with_values_per_chunk( + parray, + level, + VALUES_PER_CHUNK, + values_per_page, + ) + } + + /// Compress an [`ArrayRef`] into a [`PcoArray`] with the given compression level + /// and values per page. The array must be a primitive array. 
+ pub fn from_array( + array: ArrayRef, + level: usize, + nums_per_page: usize, + ) -> VortexResult { + if let Some(parray) = array.as_opt::() { + Self::from_primitive(parray, level, nums_per_page) + } else { + Err(vortex_err!("Pco can only encode primitive arrays")) + } } +} +impl PcoArray { pub(crate) fn from_primitive_with_values_per_chunk( parray: &PrimitiveArray, level: usize, @@ -416,7 +470,7 @@ impl PcoArray { header, chunks: chunk_infos, }; - Ok(PcoArray::new( + Ok(PcoVTable::new( chunk_meta_buffers, page_buffers, parray.dtype().clone(), @@ -426,107 +480,6 @@ impl PcoArray { )) } - pub fn from_array(array: ArrayRef, level: usize, nums_per_page: usize) -> VortexResult { - if let Some(parray) = array.as_opt::() { - Self::from_primitive(parray, level, nums_per_page) - } else { - Err(vortex_err!("Pco can only encode primitive arrays")) - } - } - - pub fn decompress(&self) -> VortexResult { - // To start, we figure out which chunks and pages we need to decompress, and with - // what value offset into the first such page. - let number_type = number_type_from_dtype(self.common.dtype()); - let values_byte_buffer = match_number_enum!( - number_type, - NumberType => { - self.decompress_values_typed::()? - } - ); - - Ok(PrimitiveArray::from_values_byte_buffer( - values_byte_buffer, - self.common.dtype().as_ptype(), - self.unsliced_validity - .slice(self.slice_start..self.slice_stop)?, - self.common.len(), - )) - } - - fn decompress_values_typed(&self) -> VortexResult { - // To start, we figure out what range of values we need to decompress. - let slice_value_indices = self - .unsliced_validity - .to_mask(self.unsliced_n_rows) - .valid_counts_for_indices(&[self.slice_start, self.slice_stop]); - let slice_value_start = slice_value_indices[0]; - let slice_value_stop = slice_value_indices[1]; - let slice_n_values = slice_value_stop - slice_value_start; - - // Then we decompress those pages into a buffer. 
Note that these values - // may exceed the bounds of the slice, so we need to slice later. - let (fd, _) = - FileDecompressor::new(self.metadata.header.as_slice()).map_err(vortex_err_from_pco)?; - let mut decompressed_values = BufferMut::::with_capacity(slice_n_values); - let mut page_idx = 0; - let mut page_value_start = 0; - let mut n_skipped_values = 0; - for (chunk_info, chunk_meta) in self.metadata.chunks.iter().zip(&self.chunk_metas) { - // lazily initialize chunk decompressor - let mut chunk_decompressor: Option> = None; - for page_info in &chunk_info.pages { - let page_n_values = page_info.n_values as usize; - let page_value_stop = page_value_start + page_n_values; - - if page_value_start >= slice_value_stop { - break; - } - - if page_value_stop > slice_value_start { - // we need this page - let old_len = decompressed_values.len(); - let new_len = old_len + page_n_values; - decompressed_values.reserve(page_n_values); - unsafe { - decompressed_values.set_len(new_len); - } - let page: &[u8] = self.pages[page_idx].as_ref(); - - let mut cd = match chunk_decompressor.take() { - Some(d) => d, - None => { - let (new_cd, _) = fd - .chunk_decompressor(chunk_meta.as_ref()) - .map_err(vortex_err_from_pco)?; - new_cd - } - }; - - let mut pd = cd - .page_decompressor(page, page_n_values) - .map_err(vortex_err_from_pco)?; - pd.read(&mut decompressed_values[old_len..new_len]) - .map_err(vortex_err_from_pco)?; - - chunk_decompressor = Some(cd); - } else { - n_skipped_values += page_n_values; - } - - page_value_start = page_value_stop; - page_idx += 1; - } - } - - // Slice only the values requested. 
- let value_offset = slice_value_start - n_skipped_values; - Ok(decompressed_values - .freeze() - .slice(value_offset..value_offset + slice_n_values) - .into_byte_buffer()) - } - pub(crate) fn _slice(&self, start: usize, stop: usize) -> Self { let new_start = self.slice_start + start; let new_stop = self.slice_start + stop; @@ -555,6 +508,79 @@ impl PcoArray { } } +fn decompress_values_typed(array: &PcoArray) -> VortexResult { + // To start, we figure out what range of values we need to decompress. + let slice_value_indices = array + .unsliced_validity + .to_mask(array.unsliced_n_rows) + .valid_counts_for_indices(&[array.slice_start, array.slice_stop]); + let slice_value_start = slice_value_indices[0]; + let slice_value_stop = slice_value_indices[1]; + let slice_n_values = slice_value_stop - slice_value_start; + + // Then we decompress those pages into a buffer. Note that these values + // may exceed the bounds of the slice, so we need to slice later. + let (fd, _) = + FileDecompressor::new(array.metadata.header.as_slice()).map_err(vortex_err_from_pco)?; + let mut decompressed_values = BufferMut::::with_capacity(slice_n_values); + let mut page_idx = 0; + let mut page_value_start = 0; + let mut n_skipped_values = 0; + for (chunk_info, chunk_meta) in array.metadata.chunks.iter().zip(&array.chunk_metas) { + // lazily initialize chunk decompressor + let mut chunk_decompressor: Option> = None; + for page_info in &chunk_info.pages { + let page_n_values = page_info.n_values as usize; + let page_value_stop = page_value_start + page_n_values; + + if page_value_start >= slice_value_stop { + break; + } + + if page_value_stop > slice_value_start { + // we need this page + let old_len = decompressed_values.len(); + let new_len = old_len + page_n_values; + decompressed_values.reserve(page_n_values); + unsafe { + decompressed_values.set_len(new_len); + } + let page: &[u8] = array.pages[page_idx].as_ref(); + + let mut cd = match chunk_decompressor.take() { + Some(d) => d, + None 
=> { + let (new_cd, _) = fd + .chunk_decompressor(chunk_meta.as_ref()) + .map_err(vortex_err_from_pco)?; + new_cd + } + }; + + let mut pd = cd + .page_decompressor(page, page_n_values) + .map_err(vortex_err_from_pco)?; + pd.read(&mut decompressed_values[old_len..new_len]) + .map_err(vortex_err_from_pco)?; + + chunk_decompressor = Some(cd); + } else { + n_skipped_values += page_n_values; + } + + page_value_start = page_value_stop; + page_idx += 1; + } + } + + // Slice only the values requested. + let value_offset = slice_value_start - n_skipped_values; + Ok(decompressed_values + .freeze() + .slice(value_offset..value_offset + slice_n_values) + .into_byte_buffer()) +} + impl ValiditySliceHelper for PcoArray { fn unsliced_validity_and_slice(&self) -> (&Validity, usize, usize) { (&self.unsliced_validity, self.slice_start, self.slice_stop) @@ -575,7 +601,7 @@ mod tests { use vortex_array::validity::Validity; use vortex_buffer::buffer; - use crate::PcoArray; + use crate::PcoVTable; #[test] fn test_slice_nullable() { @@ -584,7 +610,7 @@ mod tests { buffer![10u32, 20, 30, 40, 50, 60], Validity::from_iter([false, true, true, true, true, false]), ); - let pco = PcoArray::from_primitive(&values, 0, 128).unwrap(); + let pco = PcoVTable::from_primitive(&values, 0, 128).unwrap(); assert_arrays_eq!( pco, PrimitiveArray::from_option_iter([ diff --git a/encodings/pco/src/compute/cast.rs b/encodings/pco/src/compute/cast.rs index 2bcd46aeb79..88e3b84e352 100644 --- a/encodings/pco/src/compute/cast.rs +++ b/encodings/pco/src/compute/cast.rs @@ -30,7 +30,7 @@ impl CastReduce for PcoVTable { .cast_nullability(dtype.nullability(), array.len())?; return Ok(Some( - PcoArray::new( + PcoVTable::new( array.chunk_metas.clone(), array.pages.clone(), dtype.clone(), @@ -62,12 +62,12 @@ mod tests { use vortex_array::validity::Validity; use vortex_buffer::buffer; - use crate::PcoArray; + use crate::PcoVTable; #[test] fn test_cast_pco_f32_to_f64() { let values = PrimitiveArray::from_iter([1.0f32, 
2.0, 3.0, 4.0, 5.0]); - let pco = PcoArray::from_primitive(&values, 0, 128).unwrap(); + let pco = PcoVTable::from_primitive(&values, 0, 128).unwrap(); let casted = pco .into_array() @@ -88,7 +88,7 @@ mod tests { fn test_cast_pco_nullability_change() { // Test casting from NonNullable to Nullable let values = PrimitiveArray::from_iter([10u32, 20, 30, 40]); - let pco = PcoArray::from_primitive(&values, 0, 128).unwrap(); + let pco = PcoVTable::from_primitive(&values, 0, 128).unwrap(); let casted = pco .into_array() @@ -106,7 +106,7 @@ mod tests { buffer![10u32, 20, 30, 40, 50, 60], Validity::from_iter([true, true, true, true, true, true]), ); - let pco = PcoArray::from_primitive(&values, 0, 128).unwrap(); + let pco = PcoVTable::from_primitive(&values, 0, 128).unwrap(); let sliced = pco.slice(1..5).unwrap(); let casted = sliced .cast(DType::Primitive(PType::U32, Nullability::NonNullable)) @@ -129,7 +129,7 @@ mod tests { Some(50), Some(60), ]); - let pco = PcoArray::from_primitive(&values, 0, 128).unwrap(); + let pco = PcoVTable::from_primitive(&values, 0, 128).unwrap(); let sliced = pco.slice(1..5).unwrap(); let casted = sliced .cast(DType::Primitive(PType::U32, Nullability::NonNullable)) @@ -163,7 +163,7 @@ mod tests { Validity::NonNullable, ))] fn test_cast_pco_conformance(#[case] values: PrimitiveArray) { - let pco = PcoArray::from_primitive(&values, 0, 128).unwrap(); + let pco = PcoVTable::from_primitive(&values, 0, 128).unwrap(); test_cast_conformance(&pco.into_array()); } } diff --git a/encodings/pco/src/compute/mod.rs b/encodings/pco/src/compute/mod.rs index a5db93b4833..a78110d54f8 100644 --- a/encodings/pco/src/compute/mod.rs +++ b/encodings/pco/src/compute/mod.rs @@ -11,45 +11,46 @@ mod tests { use vortex_array::compute::conformance::consistency::test_array_consistency; use crate::PcoArray; + use crate::PcoVTable; fn pco_f32() -> PcoArray { let values = PrimitiveArray::from_iter([1.23f32, 4.56, 7.89, 10.11, 12.13]); - PcoArray::from_primitive(&values, 0, 
128).unwrap() + PcoVTable::from_primitive(&values, 0, 128).unwrap() } fn pco_f64() -> PcoArray { let values = PrimitiveArray::from_iter([100.1f64, 200.2, 300.3, 400.4, 500.5]); - PcoArray::from_primitive(&values, 0, 128).unwrap() + PcoVTable::from_primitive(&values, 0, 128).unwrap() } fn pco_i32() -> PcoArray { let values = PrimitiveArray::from_iter([100i32, 200, 300, 400, 500]); - PcoArray::from_primitive(&values, 0, 128).unwrap() + PcoVTable::from_primitive(&values, 0, 128).unwrap() } fn pco_u64() -> PcoArray { let values = PrimitiveArray::from_iter([1000u64, 2000, 3000, 4000]); - PcoArray::from_primitive(&values, 0, 128).unwrap() + PcoVTable::from_primitive(&values, 0, 128).unwrap() } fn pco_i16() -> PcoArray { let values = PrimitiveArray::from_iter([10i16, 20, 30, 40, 50]); - PcoArray::from_primitive(&values, 0, 128).unwrap() + PcoVTable::from_primitive(&values, 0, 128).unwrap() } fn pco_i32_alt() -> PcoArray { let values = PrimitiveArray::from_iter([1i32, 2, 3, 4, 5]); - PcoArray::from_primitive(&values, 0, 128).unwrap() + PcoVTable::from_primitive(&values, 0, 128).unwrap() } fn pco_single() -> PcoArray { let values = PrimitiveArray::from_iter([42.42f64]); - PcoArray::from_primitive(&values, 0, 128).unwrap() + PcoVTable::from_primitive(&values, 0, 128).unwrap() } fn pco_large() -> PcoArray { let values = PrimitiveArray::from_iter(0u32..1000); - PcoArray::from_primitive(&values, 3, 128).unwrap() + PcoVTable::from_primitive(&values, 3, 128).unwrap() } #[rstest] diff --git a/encodings/pco/src/test.rs b/encodings/pco/src/test.rs index 85d7dd64220..33267191d2b 100644 --- a/encodings/pco/src/test.rs +++ b/encodings/pco/src/test.rs @@ -36,13 +36,14 @@ static SESSION: LazyLock = LazyLock::new(|| { }); use crate::PcoArray; +use crate::PcoArrayExt; use crate::PcoVTable; #[test] fn test_compress_decompress() { let data: Vec = (0..200).collect(); let array = PrimitiveArray::from_iter(data.clone()); - let compressed = PcoArray::from_primitive(&array, 3, 0).unwrap(); + let 
compressed = PcoVTable::from_primitive(&array, 3, 0).unwrap(); // this data should be compressible assert!(compressed.pages.len() < array.nbytes() as usize); @@ -64,7 +65,7 @@ fn test_compress_decompress() { #[test] fn test_compress_decompress_small() { let array = PrimitiveArray::from_option_iter([None, Some(1)]); - let compressed = PcoArray::from_primitive(&array, 3, 0).unwrap(); + let compressed = PcoVTable::from_primitive(&array, 3, 0).unwrap(); let expected = array.into_array(); assert_arrays_eq!(compressed, expected); @@ -77,7 +78,7 @@ fn test_compress_decompress_small() { fn test_empty() { let data: Vec = vec![]; let array = PrimitiveArray::from_iter(data.clone()); - let compressed = PcoArray::from_primitive(&array, 3, 100).unwrap(); + let compressed = PcoVTable::from_primitive(&array, 3, 100).unwrap(); let primitive = compressed.decompress().unwrap(); assert_arrays_eq!(primitive, PrimitiveArray::from_iter(data)); } @@ -132,7 +133,7 @@ fn test_validity_vtable() { Buffer::from(data), Validity::Array(BoolArray::from_iter(mask_bools.clone()).into_array()), ); - let compressed = PcoArray::from_primitive(&array, 3, 0).unwrap(); + let compressed = PcoVTable::from_primitive(&array, 3, 0).unwrap(); assert_eq!( compressed.validity_mask().unwrap(), Mask::from_iter(mask_bools) @@ -146,7 +147,7 @@ fn test_validity_vtable() { #[test] fn test_serde() -> VortexResult<()> { let data: PrimitiveArray = (0i32..1_000_000).collect(); - let pco = PcoArray::from_primitive(&data, 3, 100)?.into_array(); + let pco = PcoVTable::from_primitive(&data, 3, 100)?.into_array(); let context = ArrayContext::empty(); diff --git a/encodings/runend/benches/run_end_compress.rs b/encodings/runend/benches/run_end_compress.rs index 533045c7314..8af06130a1a 100644 --- a/encodings/runend/benches/run_end_compress.rs +++ b/encodings/runend/benches/run_end_compress.rs @@ -15,7 +15,7 @@ use vortex_array::compute::warm_up_vtables; use vortex_array::dtype::IntegerPType; use vortex_array::validity::Validity; 
use vortex_buffer::Buffer; -use vortex_runend::RunEndArray; +use vortex_runend::RunEndVTable; use vortex_runend::compress::runend_encode; fn main() { @@ -72,7 +72,7 @@ fn decompress(bencher: Bencher, (length, run_step): (usize, usi .collect::>() .into_array(); - let run_end_array = RunEndArray::new(ends, values); + let run_end_array = RunEndVTable::new(ends, values); let array = run_end_array.into_array(); bencher @@ -97,7 +97,7 @@ fn take_indices(bencher: Bencher, (length, run_step): (usize, usize)) { let source_array = PrimitiveArray::from_iter(0..(length as i32)).into_array(); let (ends, values) = runend_encode(&values); - let runend_array = RunEndArray::try_new(ends.into_array(), values) + let runend_array = RunEndVTable::try_new(ends.into_array(), values) .unwrap() .into_array(); @@ -129,7 +129,7 @@ fn decompress_utf8(bencher: Bencher, (length, run_step): (usize, usize)) { let values = VarBinViewArray::from_iter_str((0..num_runs).map(|i| format!("run_value_{i}"))) .into_array(); - let run_end_array = RunEndArray::new(ends, values); + let run_end_array = RunEndVTable::new(ends, values); let array = run_end_array.into_array(); bencher diff --git a/encodings/runend/benches/run_end_null_count.rs b/encodings/runend/benches/run_end_null_count.rs index 4a3b8823e2c..1c1095f5cfc 100644 --- a/encodings/runend/benches/run_end_null_count.rs +++ b/encodings/runend/benches/run_end_null_count.rs @@ -13,6 +13,7 @@ use vortex_array::arrays::PrimitiveArray; use vortex_array::compute::warm_up_vtables; use vortex_buffer::Buffer; use vortex_runend::RunEndArray; +use vortex_runend::RunEndVTable; fn main() { warm_up_vtables(); @@ -70,5 +71,5 @@ fn fixture(n: usize, run_step: usize, valid_density: f64) -> RunEndArray { ) .into_array(); - RunEndArray::new(ends, values) + RunEndVTable::new(ends, values) } diff --git a/encodings/runend/src/arbitrary.rs b/encodings/runend/src/arbitrary.rs index f4a3a83f150..114e849ed32 100644 --- a/encodings/runend/src/arbitrary.rs +++ 
b/encodings/runend/src/arbitrary.rs @@ -15,6 +15,7 @@ use vortex_buffer::Buffer; use vortex_error::VortexExpect; use crate::RunEndArray; +use crate::RunEndVTable; /// A wrapper type to implement `Arbitrary` for `RunEndArray`. #[derive(Clone, Debug)] @@ -42,7 +43,7 @@ impl ArbitraryRunEndArray { // Empty RunEndArray let ends = PrimitiveArray::from_iter(Vec::::new()).into_array(); let values = ArbitraryArray::arbitrary_with(u, Some(0), dtype)?.0; - let runend_array = RunEndArray::try_new(ends, values) + let runend_array = RunEndVTable::try_new(ends, values) .vortex_expect("Empty RunEndArray creation should succeed"); return Ok(ArbitraryRunEndArray(runend_array)); } @@ -54,7 +55,7 @@ impl ArbitraryRunEndArray { // Each end must be > previous end, and first end must be >= 1 let ends = random_strictly_sorted_ends(u, num_runs, len)?; - let runend_array = RunEndArray::try_new(ends, values) + let runend_array = RunEndVTable::try_new(ends, values) .vortex_expect("RunEndArray creation should succeed in arbitrary impl"); Ok(ArbitraryRunEndArray(runend_array)) diff --git a/encodings/runend/src/array.rs b/encodings/runend/src/array.rs index 0bfb439c010..bc4e43c2953 100644 --- a/encodings/runend/src/array.rs +++ b/encodings/runend/src/array.rs @@ -163,7 +163,7 @@ impl VTable for RunEndVTable { let values = children.get(1, dtype, runs)?; - RunEndArray::try_new_offset_length( + Self::try_new_offset_length( ends, values, usize::try_from(metadata.offset).vortex_expect("Offset must be a valid usize"), @@ -225,89 +225,11 @@ pub struct RunEndVTable; impl RunEndVTable { pub const ID: ArrayId = ArrayId::new_ref("vortex.runend"); -} - -impl RunEndArray { - fn validate( - ends: &ArrayRef, - values: &ArrayRef, - offset: usize, - length: usize, - ) -> VortexResult<()> { - // DType validation - vortex_ensure!( - ends.dtype().is_unsigned_int(), - "run ends must be unsigned integers, was {}", - ends.dtype(), - ); - vortex_ensure!( - ends.len() == values.len(), - "run ends len != run values len, 
{} != {}", - ends.len(), - values.len() - ); - - // Handle empty run-ends - if ends.is_empty() { - vortex_ensure!( - offset == 0, - "non-zero offset provided for empty RunEndArray" - ); - return Ok(()); - } - - // Avoid building a non-empty array with zero logical length. - if length == 0 { - vortex_ensure!( - ends.is_empty(), - "run ends must be empty when length is zero" - ); - return Ok(()); - } - - debug_assert!({ - // Run ends must be strictly sorted for binary search to work correctly. - let pre_validation = ends.statistics().to_owned(); - - let is_sorted = ends - .statistics() - .compute_is_strict_sorted() - .unwrap_or(false); - - // Preserve the original statistics since compute_is_strict_sorted may have mutated them. - // We don't want to run with different stats in debug mode and outside. - ends.statistics().inherit(pre_validation.iter()); - is_sorted - }); - // Skip host-only validation when ends are not host-resident. - if !ends.is_host() { - return Ok(()); - } - - // Validate the offset and length are valid for the given ends and values - if offset != 0 && length != 0 { - let first_run_end = usize::try_from(&ends.scalar_at(0)?)?; - if first_run_end <= offset { - vortex_bail!("First run end {first_run_end} must be bigger than offset {offset}"); - } - } - - let last_run_end = usize::try_from(&ends.scalar_at(ends.len() - 1)?)?; - let min_required_end = offset + length; - if last_run_end < min_required_end { - vortex_bail!("Last run end {last_run_end} must be >= offset+length {min_required_end}"); - } - - Ok(()) - } -} - -impl RunEndArray { /// Build a new `RunEndArray` from an array of run `ends` and an array of `values`. /// - /// Panics if any of the validation conditions described in [`RunEndArray::try_new`] is - /// not satisfied. + /// Panics if any of the validation conditions described in + /// [`RunEndVTable::try_new`] is not satisfied. 
/// /// # Examples /// @@ -316,11 +238,11 @@ impl RunEndArray { /// # use vortex_array::IntoArray; /// # use vortex_buffer::buffer; /// # use vortex_error::VortexResult; - /// # use vortex_runend::RunEndArray; + /// # use vortex_runend::RunEndVTable; /// # fn main() -> VortexResult<()> { /// let ends = buffer![2u8, 3u8].into_array(); /// let values = BoolArray::from_iter([false, true]).into_array(); - /// let run_end = RunEndArray::new(ends, values); + /// let run_end = RunEndVTable::new(ends, values); /// /// // Array encodes /// assert_eq!(run_end.scalar_at(0)?, false.into()); @@ -329,16 +251,13 @@ impl RunEndArray { /// # Ok(()) /// # } /// ``` - pub fn new(ends: ArrayRef, values: ArrayRef) -> Self { + #[allow(clippy::new_ret_no_self)] + pub fn new(ends: ArrayRef, values: ArrayRef) -> RunEndArray { Self::try_new(ends, values).vortex_expect("RunEndArray new") } /// Build a new `RunEndArray` from components. - /// - /// # Validation - /// - /// The `ends` must be non-nullable unsigned integers. - pub fn try_new(ends: ArrayRef, values: ArrayRef) -> VortexResult { + pub fn try_new(ends: ArrayRef, values: ArrayRef) -> VortexResult { let length: usize = if ends.is_empty() { 0 } else { @@ -349,18 +268,16 @@ impl RunEndArray { } /// Construct a new sliced `RunEndArray` with the provided offset and length. - /// - /// This performs all the same validation as [`RunEndArray::try_new`]. pub fn try_new_offset_length( ends: ArrayRef, values: ArrayRef, offset: usize, length: usize, - ) -> VortexResult { - Self::validate(&ends, &values, offset, length)?; + ) -> VortexResult { + validate_runend(&ends, &values, offset, length)?; let dtype = values.dtype().clone(); - Ok(Self { + Ok(RunEndArray { common: ArrayCommon::new(length, dtype), ends, values, @@ -368,48 +285,13 @@ impl RunEndArray { }) } - /// Build a new `RunEndArray` without validation. 
- /// - /// # Safety - /// - /// The caller must ensure that all the validation performed in [`RunEndArray::try_new`] is - /// satisfied before calling this function. - /// - /// See [`RunEndArray::try_new`] for the preconditions needed to build a new array. - pub unsafe fn new_unchecked( - ends: ArrayRef, - values: ArrayRef, - offset: usize, - length: usize, - ) -> Self { - let dtype = values.dtype().clone(); - Self { - common: ArrayCommon::new(length, dtype), - ends, - values, - offset, - } - } - - /// Convert the given logical index to an index into the `values` array - pub fn find_physical_index(&self, index: usize) -> VortexResult { - Ok(self - .ends() - .as_primitive_typed() - .search_sorted( - &PValue::from(index + self.offset()), - SearchSortedSide::Right, - )? - .to_ends_index(self.ends().len())) - } - /// Run the array through run-end encoding. - pub fn encode(array: ArrayRef) -> VortexResult { + pub fn encode(array: ArrayRef) -> VortexResult { if let Some(parray) = array.as_opt::() { let (ends, values) = runend_encode(parray); // SAFETY: runend_encode handles this unsafe { - Ok(Self::new_unchecked( + Ok(RunEndArray::new_unchecked( ends.into_array(), values, 0, @@ -420,36 +302,129 @@ impl RunEndArray { vortex_bail!("REE can only encode primitive arrays") } } +} + +fn validate_runend( + ends: &ArrayRef, + values: &ArrayRef, + offset: usize, + length: usize, +) -> VortexResult<()> { + // DType validation + vortex_ensure!( + ends.dtype().is_unsigned_int(), + "run ends must be unsigned integers, was {}", + ends.dtype(), + ); + vortex_ensure!( + ends.len() == values.len(), + "run ends len != run values len, {} != {}", + ends.len(), + values.len() + ); + + // Handle empty run-ends + if ends.is_empty() { + vortex_ensure!( + offset == 0, + "non-zero offset provided for empty RunEndArray" + ); + return Ok(()); + } + + // Avoid building a non-empty array with zero logical length. 
+ if length == 0 { + vortex_ensure!( + ends.is_empty(), + "run ends must be empty when length is zero" + ); + return Ok(()); + } + + debug_assert!({ + // Run ends must be strictly sorted for binary search to work correctly. + let pre_validation = ends.statistics().to_owned(); + + let is_sorted = ends + .statistics() + .compute_is_strict_sorted() + .unwrap_or(false); + + // Preserve the original statistics since compute_is_strict_sorted may have mutated them. + // We don't want to run with different stats in debug mode and outside. + ends.statistics().inherit(pre_validation.iter()); + is_sorted + }); + + // Skip host-only validation when ends are not host-resident. + if !ends.is_host() { + return Ok(()); + } + + // Validate the offset and length are valid for the given ends and values + if offset != 0 && length != 0 { + let first_run_end = usize::try_from(&ends.scalar_at(0)?)?; + if first_run_end <= offset { + vortex_bail!("First run end {first_run_end} must be bigger than offset {offset}"); + } + } + + let last_run_end = usize::try_from(&ends.scalar_at(ends.len() - 1)?)?; + let min_required_end = offset + length; + if last_run_end < min_required_end { + vortex_bail!("Last run end {last_run_end} must be >= offset+length {min_required_end}"); + } + + Ok(()) +} + +/// Extension trait for [`RunEndArray`] methods. +pub trait RunEndArrayExt: Sized { + /// Convert the given logical index to an index into the `values` array. + fn find_physical_index(&self, index: usize) -> VortexResult; /// The offset that the `ends` is relative to. - /// - /// This is generally zero for a "new" array, and non-zero after a slicing operation. + fn offset(&self) -> usize; + + /// The encoded "ends" of value runs. + fn ends(&self) -> &ArrayRef; + + /// The scalar values. + fn values(&self) -> &ArrayRef; + + /// Split a `RunEndArray` into parts. 
+ fn into_parts(self) -> RunEndArrayParts; +} + +impl RunEndArrayExt for RunEndArray { + fn find_physical_index(&self, index: usize) -> VortexResult { + Ok(self + .ends() + .as_primitive_typed() + .search_sorted( + &PValue::from(index + self.offset()), + SearchSortedSide::Right, + )? + .to_ends_index(self.ends().len())) + } + #[inline] - pub fn offset(&self) -> usize { + fn offset(&self) -> usize { self.offset } - /// The encoded "ends" of value runs. - /// - /// The `i`-th element indicates that there is a run of the same value, beginning - /// at `ends[i]` (inclusive) and terminating at `ends[i+1]` (exclusive). #[inline] - pub fn ends(&self) -> &ArrayRef { + fn ends(&self) -> &ArrayRef { &self.ends } - /// The scalar values. - /// - /// The `i`-th element is the scalar value for the `i`-th repeated run. The run begins - /// at `ends[i]` (inclusive) and terminates at `ends[i+1]` (exclusive). #[inline] - pub fn values(&self) -> &ArrayRef { + fn values(&self) -> &ArrayRef { &self.values } - /// Split an `RunEndArray` into parts. #[inline] - pub fn into_parts(self) -> RunEndArrayParts { + fn into_parts(self) -> RunEndArrayParts { RunEndArrayParts { ends: self.ends, values: self.values, @@ -457,6 +432,29 @@ impl RunEndArray { } } +impl RunEndArray { + /// Build a new `RunEndArray` without validation. + /// + /// # Safety + /// + /// The caller must ensure that all the validation performed in + /// [`RunEndVTable::try_new`] is satisfied. + pub unsafe fn new_unchecked( + ends: ArrayRef, + values: ArrayRef, + offset: usize, + length: usize, + ) -> Self { + let dtype = values.dtype().clone(); + Self { + common: ArrayCommon::new(length, dtype), + ends, + values, + offset, + } + } +} + impl ValidityVTable for RunEndVTable { fn validity(array: &RunEndArray) -> VortexResult { Ok(match array.values().validity()? 
{ @@ -512,11 +510,11 @@ mod tests { use vortex_array::dtype::PType; use vortex_buffer::buffer; - use crate::RunEndArray; + use crate::RunEndVTable; #[test] fn test_runend_constructor() { - let arr = RunEndArray::new( + let arr = RunEndVTable::new( buffer![2u32, 5, 10].into_array(), buffer![1i32, 2, 3].into_array(), ); @@ -536,7 +534,7 @@ mod tests { #[test] fn test_runend_utf8() { let values = VarBinViewArray::from_iter_str(["a", "b", "c"]).into_array(); - let arr = RunEndArray::new(buffer![2u32, 5, 10].into_array(), values); + let arr = RunEndVTable::new(buffer![2u32, 5, 10].into_array(), values); assert_eq!(arr.len(), 10); assert_eq!(arr.dtype(), &DType::Utf8(Nullability::NonNullable)); @@ -553,7 +551,7 @@ mod tests { let dict = DictArray::try_new(dict_codes, dict_values).unwrap(); let arr = - RunEndArray::try_new(buffer![2u32, 5, 10].into_array(), dict.into_array()).unwrap(); + RunEndVTable::try_new(buffer![2u32, 5, 10].into_array(), dict.into_array()).unwrap(); assert_eq!(arr.len(), 10); let expected = diff --git a/encodings/runend/src/arrow.rs b/encodings/runend/src/arrow.rs index 5ecd04a20c2..812a7ea02b9 100644 --- a/encodings/runend/src/arrow.rs +++ b/encodings/runend/src/arrow.rs @@ -79,6 +79,7 @@ mod tests { use vortex_session::VortexSession; use crate::RunEndArray; + use crate::RunEndVTable; static SESSION: LazyLock = LazyLock::new(|| VortexSession::empty().with::()); @@ -253,7 +254,7 @@ mod tests { #[test] fn test_sliced_runend_to_arrow_ree() -> VortexResult<()> { - let array = RunEndArray::encode( + let array = RunEndVTable::encode( PrimitiveArray::from_iter(vec![10i32, 10, 20, 20, 20, 30, 30]).into_array(), )?; // Slicing from index 1 produces a non-zero offset in the RunEndArray. 
diff --git a/encodings/runend/src/compute/cast.rs b/encodings/runend/src/compute/cast.rs index 831f2fcf6f0..d5b6c8a54c1 100644 --- a/encodings/runend/src/compute/cast.rs +++ b/encodings/runend/src/compute/cast.rs @@ -9,6 +9,7 @@ use vortex_array::scalar_fn::fns::cast::CastReduce; use vortex_error::VortexResult; use crate::RunEndArray; +use crate::RunEndArrayExt; use crate::RunEndVTable; impl CastReduce for RunEndVTable { @@ -48,10 +49,11 @@ mod tests { use vortex_buffer::buffer; use crate::RunEndArray; + use crate::RunEndVTable; #[test] fn test_cast_runend_i32_to_i64() { - let runend = RunEndArray::try_new( + let runend = RunEndVTable::try_new( buffer![3u64, 5, 8, 10].into_array(), buffer![100i32, 200, 100, 300].into_array(), ) @@ -90,7 +92,7 @@ mod tests { #[test] fn test_cast_runend_nullable() { - let runend = RunEndArray::try_new( + let runend = RunEndVTable::try_new( buffer![2u64, 4, 7].into_array(), PrimitiveArray::from_option_iter([Some(10i32), None, Some(20)]).into_array(), ) @@ -109,7 +111,7 @@ mod tests { #[test] fn test_cast_runend_with_offset() { // Create a RunEndArray: [100, 100, 100, 200, 200, 300, 300, 300, 300, 300] - let runend = RunEndArray::try_new( + let runend = RunEndVTable::try_new( buffer![3u64, 5, 10].into_array(), buffer![100i32, 200, 300].into_array(), ) @@ -134,23 +136,23 @@ mod tests { } #[rstest] - #[case(RunEndArray::try_new( + #[case(RunEndVTable::try_new( buffer![3u64, 5, 8].into_array(), buffer![100i32, 200, 300].into_array() ).unwrap())] - #[case(RunEndArray::try_new( + #[case(RunEndVTable::try_new( buffer![1u64, 4, 10].into_array(), buffer![1.5f32, 2.5, 3.5].into_array() ).unwrap())] - #[case(RunEndArray::try_new( + #[case(RunEndVTable::try_new( buffer![2u64, 3, 5].into_array(), PrimitiveArray::from_option_iter([Some(42i32), None, Some(84)]).into_array() ).unwrap())] - #[case(RunEndArray::try_new( + #[case(RunEndVTable::try_new( buffer![10u64].into_array(), buffer![255u8].into_array() ).unwrap())] - #[case(RunEndArray::try_new( + 
#[case(RunEndVTable::try_new( buffer![2u64, 4, 6, 8, 10].into_array(), BoolArray::from_iter(vec![true, false, true, false, true]).into_array() ).unwrap())] diff --git a/encodings/runend/src/compute/compare.rs b/encodings/runend/src/compute/compare.rs index d3b66f41a2a..57faeb9f50f 100644 --- a/encodings/runend/src/compute/compare.rs +++ b/encodings/runend/src/compute/compare.rs @@ -14,6 +14,7 @@ use vortex_array::scalar_fn::fns::operators::Operator; use vortex_error::VortexResult; use crate::RunEndArray; +use crate::RunEndArrayExt; use crate::RunEndVTable; use crate::compress::runend_decode_bools; @@ -55,9 +56,10 @@ mod test { use vortex_array::scalar_fn::fns::operators::Operator; use crate::RunEndArray; + use crate::RunEndVTable; fn ree_array() -> RunEndArray { - RunEndArray::encode( + RunEndVTable::encode( PrimitiveArray::from_iter([1, 1, 1, 4, 4, 4, 2, 2, 5, 5, 5, 5]).into_array(), ) .unwrap() diff --git a/encodings/runend/src/compute/fill_null.rs b/encodings/runend/src/compute/fill_null.rs index 3cda14b945f..bd823a61557 100644 --- a/encodings/runend/src/compute/fill_null.rs +++ b/encodings/runend/src/compute/fill_null.rs @@ -9,6 +9,7 @@ use vortex_array::scalar_fn::fns::fill_null::FillNullReduce; use vortex_error::VortexResult; use crate::RunEndArray; +use crate::RunEndArrayExt; use crate::RunEndArrayParts; use crate::RunEndVTable; diff --git a/encodings/runend/src/compute/filter.rs b/encodings/runend/src/compute/filter.rs index c8e87326095..ae480d72cb6 100644 --- a/encodings/runend/src/compute/filter.rs +++ b/encodings/runend/src/compute/filter.rs @@ -21,6 +21,7 @@ use vortex_error::VortexResult; use vortex_mask::Mask; use crate::RunEndArray; +use crate::RunEndArrayExt; use crate::RunEndVTable; use crate::compute::take::take_indices_unchecked; @@ -123,9 +124,10 @@ mod tests { use vortex_mask::Mask; use crate::RunEndArray; + use crate::RunEndVTable; fn ree_array() -> RunEndArray { - RunEndArray::encode( + RunEndVTable::encode( PrimitiveArray::from_iter([1, 1, 
1, 4, 4, 4, 2, 2, 5, 5, 5, 5]).into_array(), ) .unwrap() @@ -138,7 +140,7 @@ mod tests { assert_arrays_eq!( filtered, - RunEndArray::new( + RunEndVTable::new( PrimitiveArray::from_iter([1u8, 2, 3]).into_array(), PrimitiveArray::from_iter([1i32, 4, 2]).into_array() ) diff --git a/encodings/runend/src/compute/is_constant.rs b/encodings/runend/src/compute/is_constant.rs index cd4df321672..1246c56688f 100644 --- a/encodings/runend/src/compute/is_constant.rs +++ b/encodings/runend/src/compute/is_constant.rs @@ -10,6 +10,7 @@ use vortex_array::expr::stats::Stat; use vortex_array::register_kernel; use vortex_error::VortexResult; +use crate::RunEndArrayExt; use crate::RunEndVTable; impl IsConstantKernel for RunEndVTable { diff --git a/encodings/runend/src/compute/is_sorted.rs b/encodings/runend/src/compute/is_sorted.rs index 4c57297b775..cb343cea92b 100644 --- a/encodings/runend/src/compute/is_sorted.rs +++ b/encodings/runend/src/compute/is_sorted.rs @@ -10,6 +10,7 @@ use vortex_array::register_kernel; use vortex_error::VortexResult; use crate::RunEndArray; +use crate::RunEndArrayExt; use crate::RunEndVTable; impl IsSortedKernel for RunEndVTable { diff --git a/encodings/runend/src/compute/min_max.rs b/encodings/runend/src/compute/min_max.rs index c58616ae714..57bc4b52ab2 100644 --- a/encodings/runend/src/compute/min_max.rs +++ b/encodings/runend/src/compute/min_max.rs @@ -9,6 +9,7 @@ use vortex_array::register_kernel; use vortex_error::VortexResult; use crate::RunEndArray; +use crate::RunEndArrayExt; use crate::RunEndVTable; impl MinMaxKernel for RunEndVTable { diff --git a/encodings/runend/src/compute/mod.rs b/encodings/runend/src/compute/mod.rs index d0d14b956aa..a76baff2c76 100644 --- a/encodings/runend/src/compute/mod.rs +++ b/encodings/runend/src/compute/mod.rs @@ -20,27 +20,28 @@ mod tests { use vortex_buffer::buffer; use crate::RunEndArray; + use crate::RunEndVTable; #[rstest] // Simple run-end arrays - #[case::runend_i32(RunEndArray::encode( + 
#[case::runend_i32(RunEndVTable::encode( buffer![1i32, 1, 1, 2, 2, 3, 3, 3, 3].into_array() ).unwrap())] - #[case::runend_single_run(RunEndArray::encode( + #[case::runend_single_run(RunEndVTable::encode( buffer![5i32, 5, 5, 5, 5].into_array() ).unwrap())] - #[case::runend_alternating(RunEndArray::encode( + #[case::runend_alternating(RunEndVTable::encode( buffer![1i32, 2, 1, 2, 1, 2].into_array() ).unwrap())] // Different types - #[case::runend_u64(RunEndArray::encode( + #[case::runend_u64(RunEndVTable::encode( buffer![100u64, 100, 200, 200, 200].into_array() ).unwrap())] // Edge cases - #[case::runend_single(RunEndArray::encode( + #[case::runend_single(RunEndVTable::encode( buffer![42i32].into_array() ).unwrap())] - #[case::runend_large(RunEndArray::encode( + #[case::runend_large(RunEndVTable::encode( PrimitiveArray::from_iter((0..1000).map(|i| i / 10)).into_array() ).unwrap())] diff --git a/encodings/runend/src/compute/take.rs b/encodings/runend/src/compute/take.rs index b9125b881ab..52d3f0d878e 100644 --- a/encodings/runend/src/compute/take.rs +++ b/encodings/runend/src/compute/take.rs @@ -21,6 +21,7 @@ use vortex_error::VortexResult; use vortex_error::vortex_bail; use crate::RunEndArray; +use crate::RunEndArrayExt; use crate::RunEndVTable; impl TakeExecute for RunEndVTable { @@ -103,9 +104,10 @@ mod test { use vortex_buffer::buffer; use crate::RunEndArray; + use crate::RunEndVTable; fn ree_array() -> RunEndArray { - RunEndArray::encode(buffer![1, 1, 1, 4, 4, 4, 2, 2, 5, 5, 5, 5].into_array()).unwrap() + RunEndVTable::encode(buffer![1, 1, 1, 4, 4, 4, 2, 2, 5, 5, 5, 5].into_array()).unwrap() } #[test] @@ -153,10 +155,10 @@ mod test { #[rstest] #[case(ree_array())] - #[case(RunEndArray::encode( + #[case(RunEndVTable::encode( buffer![1u8, 1, 2, 2, 2, 3, 3, 3, 3, 4].into_array(), ).unwrap())] - #[case(RunEndArray::encode( + #[case(RunEndVTable::encode( PrimitiveArray::from_option_iter([ Some(10), Some(10), @@ -168,9 +170,9 @@ mod test { ]) .into_array(), ).unwrap())] 
- #[case(RunEndArray::encode(buffer![42i32, 42, 42, 42, 42].into_array()) + #[case(RunEndVTable::encode(buffer![42i32, 42, 42, 42, 42].into_array()) .unwrap())] - #[case(RunEndArray::encode( + #[case(RunEndVTable::encode( buffer![1i32, 2, 3, 4, 5, 6, 7, 8, 9, 10].into_array(), ).unwrap())] #[case({ @@ -180,7 +182,7 @@ mod test { values.push(i); } } - RunEndArray::encode(PrimitiveArray::from_iter(values).into_array()).unwrap() + RunEndVTable::encode(PrimitiveArray::from_iter(values).into_array()).unwrap() })] fn test_take_runend_conformance(#[case] array: RunEndArray) { test_take_conformance(&array.into_array()); @@ -189,7 +191,7 @@ mod test { #[rstest] #[case(ree_array().slice(3..6).unwrap())] #[case({ - let array = RunEndArray::encode( + let array = RunEndVTable::encode( buffer![1i32, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3].into_array(), ) .unwrap(); diff --git a/encodings/runend/src/compute/take_from.rs b/encodings/runend/src/compute/take_from.rs index cf6b3117be2..5c03f3ba312 100644 --- a/encodings/runend/src/compute/take_from.rs +++ b/encodings/runend/src/compute/take_from.rs @@ -11,6 +11,7 @@ use vortex_array::kernel::ExecuteParentKernel; use vortex_error::VortexResult; use crate::RunEndArray; +use crate::RunEndArrayExt; use crate::RunEndVTable; #[derive(Debug)] @@ -64,6 +65,8 @@ mod tests { use vortex_session::VortexSession; use crate::RunEndArray; + use crate::RunEndArrayExt; + use crate::RunEndVTable; use crate::compute::take_from::RunEndVTableTakeFrom; /// Build a DictArray whose codes are run-end encoded. 
@@ -73,7 +76,7 @@ mod tests { /// Codes: `[0, 0, 0, 1, 1, 0, 0]` /// RunEnd encoded codes: ends=`[3, 5, 7]`, values=`[0, 1, 0]` fn make_dict_with_runend_codes() -> (RunEndArray, DictArray) { - let codes = RunEndArray::encode(buffer![0u32, 0, 0, 1, 1, 0, 0].into_array()).unwrap(); + let codes = RunEndVTable::encode(buffer![0u32, 0, 0, 1, 1, 0, 0].into_array()).unwrap(); let values = buffer![2i32, 3].into_array(); let dict = DictArray::try_new(codes.clone().into_array(), values).unwrap(); (codes, dict) diff --git a/encodings/runend/src/kernel.rs b/encodings/runend/src/kernel.rs index 5155b3236d4..d626df4f725 100644 --- a/encodings/runend/src/kernel.rs +++ b/encodings/runend/src/kernel.rs @@ -17,6 +17,7 @@ use vortex_array::scalar_fn::fns::binary::CompareExecuteAdaptor; use vortex_error::VortexResult; use crate::RunEndArray; +use crate::RunEndArrayExt; use crate::RunEndVTable; use crate::compute::take_from::RunEndVTableTakeFrom; diff --git a/encodings/runend/src/ops.rs b/encodings/runend/src/ops.rs index 5c7b60e0727..eca72750a1c 100644 --- a/encodings/runend/src/ops.rs +++ b/encodings/runend/src/ops.rs @@ -11,6 +11,7 @@ use vortex_array::vtable::OperationsVTable; use vortex_error::VortexResult; use crate::RunEndArray; +use crate::RunEndArrayExt; use crate::RunEndVTable; impl OperationsVTable for RunEndVTable { @@ -54,11 +55,11 @@ mod tests { use vortex_array::dtype::PType; use vortex_buffer::buffer; - use crate::RunEndArray; + use crate::RunEndVTable; #[test] fn slice_array() { - let arr = RunEndArray::try_new( + let arr = RunEndVTable::try_new( buffer![2u32, 5, 10].into_array(), buffer![1i32, 2, 3].into_array(), ) @@ -77,7 +78,7 @@ mod tests { #[test] fn double_slice() { - let arr = RunEndArray::try_new( + let arr = RunEndVTable::try_new( buffer![2u32, 5, 10].into_array(), buffer![1i32, 2, 3].into_array(), ) @@ -94,7 +95,7 @@ mod tests { #[test] fn slice_end_inclusive() { - let arr = RunEndArray::try_new( + let arr = RunEndVTable::try_new( buffer![2u32, 5, 
10].into_array(), buffer![1i32, 2, 3].into_array(), ) @@ -113,7 +114,7 @@ mod tests { #[test] fn slice_at_end() { - let re_array = RunEndArray::try_new( + let re_array = RunEndVTable::try_new( buffer![7_u64, 10].into_array(), buffer![2_u64, 3].into_array(), ) @@ -127,7 +128,7 @@ mod tests { #[test] fn slice_single_end() { - let re_array = RunEndArray::try_new( + let re_array = RunEndVTable::try_new( buffer![7_u64, 10].into_array(), buffer![2_u64, 3].into_array(), ) @@ -151,7 +152,7 @@ mod tests { #[test] fn ree_scalar_at_end() { - let scalar = RunEndArray::encode(buffer![1, 1, 1, 4, 4, 4, 2, 2, 5, 5, 5, 5].into_array()) + let scalar = RunEndVTable::encode(buffer![1, 1, 1, 4, 4, 4, 2, 2, 5, 5, 5, 5].into_array()) .unwrap() .scalar_at(11) .unwrap(); @@ -163,7 +164,7 @@ mod tests { fn slice_along_run_boundaries() { // Create a runend array with runs: [1, 1, 1] [4, 4, 4] [2, 2] [5, 5, 5, 5] // Run ends at indices: 3, 6, 8, 12 - let arr = RunEndArray::try_new( + let arr = RunEndVTable::try_new( buffer![3u32, 6, 8, 12].into_array(), buffer![1i32, 4, 2, 5].into_array(), ) diff --git a/encodings/runend/src/rules.rs b/encodings/runend/src/rules.rs index df756772f44..e449b4279dd 100644 --- a/encodings/runend/src/rules.rs +++ b/encodings/runend/src/rules.rs @@ -15,6 +15,7 @@ use vortex_array::scalar_fn::fns::fill_null::FillNullReduceAdaptor; use vortex_error::VortexResult; use crate::RunEndArray; +use crate::RunEndArrayExt; use crate::RunEndVTable; pub(super) const RULES: ParentRuleSet = ParentRuleSet::new(&[ diff --git a/encodings/sequence/public-api.lock b/encodings/sequence/public-api.lock index a48416a322a..5bdadc65d1c 100644 --- a/encodings/sequence/public-api.lock +++ b/encodings/sequence/public-api.lock @@ -14,10 +14,6 @@ pub fn vortex_sequence::SequenceArray::multiplier(&self) -> vortex_array::scalar pub fn vortex_sequence::SequenceArray::ptype(&self) -> vortex_array::dtype::ptype::PType -pub fn vortex_sequence::SequenceArray::try_new(base: 
vortex_array::scalar::typed_view::primitive::pvalue::PValue, multiplier: vortex_array::scalar::typed_view::primitive::pvalue::PValue, ptype: vortex_array::dtype::ptype::PType, nullability: vortex_array::dtype::nullability::Nullability, length: usize) -> vortex_error::VortexResult - -pub fn vortex_sequence::SequenceArray::try_new_typed>(base: T, multiplier: T, nullability: vortex_array::dtype::nullability::Nullability, length: usize) -> vortex_error::VortexResult - impl vortex_sequence::SequenceArray pub fn vortex_sequence::SequenceArray::to_array(&self) -> vortex_array::array::ArrayRef @@ -66,6 +62,10 @@ impl vortex_sequence::SequenceVTable pub const vortex_sequence::SequenceVTable::ID: vortex_array::vtable::dyn_::ArrayId +pub fn vortex_sequence::SequenceVTable::try_new(base: vortex_array::scalar::typed_view::primitive::pvalue::PValue, multiplier: vortex_array::scalar::typed_view::primitive::pvalue::PValue, ptype: vortex_array::dtype::ptype::PType, nullability: vortex_array::dtype::nullability::Nullability, length: usize) -> vortex_error::VortexResult + +pub fn vortex_sequence::SequenceVTable::try_new_typed>(base: T, multiplier: T, nullability: vortex_array::dtype::nullability::Nullability, length: usize) -> vortex_error::VortexResult + impl core::fmt::Debug for vortex_sequence::SequenceVTable pub fn vortex_sequence::SequenceVTable::fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result diff --git a/encodings/sequence/src/array.rs b/encodings/sequence/src/array.rs index ca4d9e44d2c..36f6fdb8609 100644 --- a/encodings/sequence/src/array.rs +++ b/encodings/sequence/src/array.rs @@ -80,45 +80,54 @@ pub struct SequenceArray { multiplier: PValue, } -impl SequenceArray { - pub fn try_new_typed>( - base: T, - multiplier: T, - nullability: Nullability, - length: usize, - ) -> VortexResult { - Self::try_new( - base.into(), - multiplier.into(), - T::PTYPE, - nullability, - length, - ) +/// Extension trait for [`SequenceArray`] methods. 
+pub trait SequenceArrayExt: Sized { + /// Returns the primitive type of the array. + fn ptype(&self) -> PType; + + /// Returns the base value. + fn base(&self) -> PValue; + + /// Returns the multiplier value. + fn multiplier(&self) -> PValue; + + /// Returns the validated final value of a sequence array. + fn last(&self) -> PValue; + + /// Consumes the array and returns its parts. + fn into_parts(self) -> SequenceArrayParts; +} + +impl SequenceArrayExt for SequenceArray { + fn ptype(&self) -> PType { + self.common.dtype().as_ptype() } - /// Constructs a sequence array using two integer values (with the same ptype). - pub fn try_new( - base: PValue, - multiplier: PValue, - ptype: PType, - nullability: Nullability, - length: usize, - ) -> VortexResult { - if !ptype.is_int() { - vortex_bail!("only integer ptype are supported in SequenceArray currently") - } + fn base(&self) -> PValue { + self.base + } - Self::try_last(base, multiplier, ptype, length).map_err(|e| { - e.with_context(format!( - "final value not expressible, base = {base:?}, multiplier = {multiplier:?}, len = {length} ", - )) - })?; + fn multiplier(&self) -> PValue { + self.multiplier + } - // SAFETY: we just validated that `ptype` is an integer and that the final - // element is representable via `try_last`. - Ok(unsafe { Self::new_unchecked(base, multiplier, ptype, nullability, length) }) + fn last(&self) -> PValue { + Self::try_last(self.base, self.multiplier, self.ptype(), self.common.len()) + .vortex_expect("validated array") + } + + fn into_parts(self) -> SequenceArrayParts { + SequenceArrayParts { + base: self.base, + multiplier: self.multiplier, + len: self.common.len(), + ptype: self.common.dtype().as_ptype(), + nullability: self.common.dtype().nullability(), + } } +} +impl SequenceArray { /// Constructs a [`SequenceArray`] without validating that the `ptype` is an integer /// type or that the final element is representable. 
/// @@ -127,9 +136,6 @@ impl SequenceArray { /// The caller must ensure that: /// - `ptype` is an integer type (i.e., `ptype.is_int()` returns `true`). /// - `base + (length - 1) * multiplier` does not overflow the range of `ptype`. - /// - /// Violating the first invariant will cause a panic. Violating the second will - /// cause silent wraparound when materializing elements, producing incorrect values. pub(crate) unsafe fn new_unchecked( base: PValue, multiplier: PValue, @@ -167,18 +173,6 @@ impl SequenceArray { } } - pub fn ptype(&self) -> PType { - self.common.dtype().as_ptype() - } - - pub fn base(&self) -> PValue { - self.base - } - - pub fn multiplier(&self) -> PValue { - self.multiplier - } - pub(crate) fn try_last( base: PValue, multiplier: PValue, @@ -216,22 +210,6 @@ impl SequenceArray { PValue::from(value) }) } - - /// Returns the validated final value of a sequence array - pub fn last(&self) -> PValue { - Self::try_last(self.base, self.multiplier, self.ptype(), self.common.len()) - .vortex_expect("validated array") - } - - pub fn into_parts(self) -> SequenceArrayParts { - SequenceArrayParts { - base: self.base, - multiplier: self.multiplier, - len: self.common.len(), - ptype: self.common.dtype().as_ptype(), - nullability: self.common.dtype().nullability(), - } - } } impl VTable for SequenceVTable { @@ -362,7 +340,7 @@ impl VTable for SequenceVTable { _buffers: &[BufferHandle], _children: &dyn ArrayChildren, ) -> VortexResult { - SequenceArray::try_new( + Self::try_new( metadata.base, metadata.multiplier, dtype.as_ptype(), @@ -422,6 +400,45 @@ pub struct SequenceVTable; impl SequenceVTable { pub const ID: ArrayId = ArrayId::new_ref("vortex.sequence"); + + /// Constructs a sequence array using a typed base and multiplier. 
+ pub fn try_new_typed>( + base: T, + multiplier: T, + nullability: Nullability, + length: usize, + ) -> VortexResult { + Self::try_new( + base.into(), + multiplier.into(), + T::PTYPE, + nullability, + length, + ) + } + + /// Constructs a sequence array using two integer values (with the same ptype). + pub fn try_new( + base: PValue, + multiplier: PValue, + ptype: PType, + nullability: Nullability, + length: usize, + ) -> VortexResult { + if !ptype.is_int() { + vortex_bail!("only integer ptype are supported in SequenceArray currently") + } + + SequenceArray::try_last(base, multiplier, ptype, length).map_err(|e| { + e.with_context(format!( + "final value not expressible, base = {base:?}, multiplier = {multiplier:?}, len = {length} ", + )) + })?; + + // SAFETY: we just validated that `ptype` is an integer and that the final + // element is representable via `try_last`. + Ok(unsafe { SequenceArray::new_unchecked(base, multiplier, ptype, nullability, length) }) + } } #[cfg(test)] @@ -436,11 +453,11 @@ mod tests { use vortex_array::scalar::ScalarValue; use vortex_error::VortexResult; - use crate::array::SequenceArray; + use crate::array::SequenceVTable; #[test] fn test_sequence_canonical() { - let arr = SequenceArray::try_new_typed(2i64, 3, Nullability::NonNullable, 4).unwrap(); + let arr = SequenceVTable::try_new_typed(2i64, 3, Nullability::NonNullable, 4).unwrap(); let canon = PrimitiveArray::from_iter((0..4).map(|i| 2i64 + i * 3)); @@ -449,7 +466,7 @@ mod tests { #[test] fn test_sequence_slice_canonical() { - let arr = SequenceArray::try_new_typed(2i64, 3, Nullability::NonNullable, 4) + let arr = SequenceVTable::try_new_typed(2i64, 3, Nullability::NonNullable, 4) .unwrap() .slice(2..3) .unwrap(); @@ -461,7 +478,7 @@ mod tests { #[test] fn test_sequence_scalar_at() { - let scalar = SequenceArray::try_new_typed(2i64, 3, Nullability::NonNullable, 4) + let scalar = SequenceVTable::try_new_typed(2i64, 3, Nullability::NonNullable, 4) .unwrap() .scalar_at(2) .unwrap(); @@ 
-474,19 +491,19 @@ mod tests { #[test] fn test_sequence_min_max() { - assert!(SequenceArray::try_new_typed(-127i8, -1i8, Nullability::NonNullable, 2).is_ok()); - assert!(SequenceArray::try_new_typed(126i8, -1i8, Nullability::NonNullable, 2).is_ok()); + assert!(SequenceVTable::try_new_typed(-127i8, -1i8, Nullability::NonNullable, 2).is_ok()); + assert!(SequenceVTable::try_new_typed(126i8, -1i8, Nullability::NonNullable, 2).is_ok()); } #[test] fn test_sequence_too_big() { - assert!(SequenceArray::try_new_typed(127i8, 1i8, Nullability::NonNullable, 2).is_err()); - assert!(SequenceArray::try_new_typed(-128i8, -1i8, Nullability::NonNullable, 2).is_err()); + assert!(SequenceVTable::try_new_typed(127i8, 1i8, Nullability::NonNullable, 2).is_err()); + assert!(SequenceVTable::try_new_typed(-128i8, -1i8, Nullability::NonNullable, 2).is_err()); } #[test] fn positive_multiplier_is_strict_sorted() -> VortexResult<()> { - let arr = SequenceArray::try_new_typed(0i64, 3, Nullability::NonNullable, 4)?; + let arr = SequenceVTable::try_new_typed(0i64, 3, Nullability::NonNullable, 4)?; let is_sorted = arr .statistics() @@ -502,7 +519,7 @@ mod tests { #[test] fn zero_multiplier_is_sorted_not_strict() -> VortexResult<()> { - let arr = SequenceArray::try_new_typed(5i64, 0, Nullability::NonNullable, 4)?; + let arr = SequenceVTable::try_new_typed(5i64, 0, Nullability::NonNullable, 4)?; let is_sorted = arr .statistics() @@ -518,7 +535,7 @@ mod tests { #[test] fn negative_multiplier_not_sorted() -> VortexResult<()> { - let arr = SequenceArray::try_new_typed(10i64, -1, Nullability::NonNullable, 4)?; + let arr = SequenceVTable::try_new_typed(10i64, -1, Nullability::NonNullable, 4)?; let is_sorted = arr .statistics() @@ -537,7 +554,7 @@ mod tests { #[test] fn test_large_multiplier_sorted() -> VortexResult<()> { let large_multiplier = (i64::MAX as u64) + 1; - let arr = SequenceArray::try_new_typed(0, large_multiplier, Nullability::NonNullable, 2)?; + let arr = SequenceVTable::try_new_typed(0, 
large_multiplier, Nullability::NonNullable, 2)?; let is_sorted = arr .statistics() diff --git a/encodings/sequence/src/compress.rs b/encodings/sequence/src/compress.rs index 09ffa47bc55..1e5ea23e0f8 100644 --- a/encodings/sequence/src/compress.rs +++ b/encodings/sequence/src/compress.rs @@ -19,6 +19,8 @@ use vortex_buffer::trusted_len::TrustedLen; use vortex_error::VortexResult; use crate::SequenceArray; +use crate::SequenceArrayExt; +use crate::SequenceVTable; /// An iterator that yields `base, base + step, base + 2*step, ...` via repeated addition. struct SequenceIter { @@ -112,7 +114,7 @@ fn encode_primitive_array + CheckedAdd + CheckedSu ) -> VortexResult> { if slice.len() == 1 { // The multiplier here can be any value, zero is chosen - return SequenceArray::try_new_typed(slice[0], P::zero(), nullability, 1) + return SequenceVTable::try_new_typed(slice[0], P::zero(), nullability, 1) .map(|a| Some(a.into_array())); } let base = slice[0]; @@ -133,7 +135,7 @@ fn encode_primitive_array + CheckedAdd + CheckedSu .windows(2) .all(|w| Some(w[1]) == w[0].checked_add(&multiplier)) .then_some( - SequenceArray::try_new_typed(base, multiplier, nullability, slice.len()) + SequenceVTable::try_new_typed(base, multiplier, nullability, slice.len()) .map(|a| a.into_array()), ) .transpose() diff --git a/encodings/sequence/src/compute/cast.rs b/encodings/sequence/src/compute/cast.rs index 12f24e4ded9..3854f6b5ced 100644 --- a/encodings/sequence/src/compute/cast.rs +++ b/encodings/sequence/src/compute/cast.rs @@ -12,6 +12,7 @@ use vortex_error::VortexResult; use vortex_error::vortex_err; use crate::SequenceArray; +use crate::SequenceArrayExt; use crate::SequenceVTable; impl CastReduce for SequenceVTable { @@ -32,7 +33,7 @@ impl CastReduce for SequenceVTable { // For SequenceArray, we can just create a new one with the same parameters // but different nullability return Ok(Some( - SequenceArray::try_new( + SequenceVTable::try_new( array.base(), array.multiplier(), *target_ptype, @@ 
-71,7 +72,7 @@ impl CastReduce for SequenceVTable { .ok_or_else(|| vortex_err!("Cast resulted in null multiplier value"))?; return Ok(Some( - SequenceArray::try_new( + SequenceVTable::try_new( new_base, new_multiplier, *target_ptype, @@ -100,11 +101,12 @@ mod tests { use vortex_array::dtype::PType; use crate::SequenceArray; + use crate::SequenceVTable; #[test] fn test_cast_sequence_nullability() { let sequence = - SequenceArray::try_new_typed(0u32, 1u32, Nullability::NonNullable, 4).unwrap(); + SequenceVTable::try_new_typed(0u32, 1u32, Nullability::NonNullable, 4).unwrap(); // Cast to nullable let casted = sequence @@ -120,7 +122,7 @@ mod tests { #[test] fn test_cast_sequence_u32_to_i64() { let sequence = - SequenceArray::try_new_typed(100u32, 10u32, Nullability::NonNullable, 4).unwrap(); + SequenceVTable::try_new_typed(100u32, 10u32, Nullability::NonNullable, 4).unwrap(); let casted = sequence .into_array() @@ -140,7 +142,7 @@ mod tests { fn test_cast_sequence_i16_to_i32_nullable() { // Test ptype change AND nullability change in one cast let sequence = - SequenceArray::try_new_typed(5i16, 3i16, Nullability::NonNullable, 3).unwrap(); + SequenceVTable::try_new_typed(5i16, 3i16, Nullability::NonNullable, 3).unwrap(); let casted = sequence .into_array() @@ -162,7 +164,7 @@ mod tests { #[test] fn test_cast_sequence_to_float_delegates_to_canonical() { let sequence = - SequenceArray::try_new_typed(0i32, 1i32, Nullability::NonNullable, 5).unwrap(); + SequenceVTable::try_new_typed(0i32, 1i32, Nullability::NonNullable, 5).unwrap(); // Cast to float should delegate to canonical (SequenceArray doesn't support float) let casted = sequence @@ -184,15 +186,15 @@ mod tests { } #[rstest] - #[case::i32(SequenceArray::try_new_typed(0i32, 1i32, Nullability::NonNullable, 5).unwrap())] - #[case::u64(SequenceArray::try_new_typed(1000u64, 100u64, Nullability::NonNullable, 4).unwrap())] + #[case::i32(SequenceVTable::try_new_typed(0i32, 1i32, Nullability::NonNullable, 5).unwrap())] + 
#[case::u64(SequenceVTable::try_new_typed(1000u64, 100u64, Nullability::NonNullable, 4).unwrap())] // TODO(DK): SequenceArray does not actually conform. You cannot cast this array to u8 even // though all its values are representable therein. // - // #[case::negative_step(SequenceArray::try_new_typed(100i32, -10i32, Nullability::NonNullable, + // #[case::negative_step(SequenceVTable::try_new_typed(100i32, -10i32, Nullability::NonNullable, // 5).unwrap())] - #[case::single(SequenceArray::try_new_typed(42i64, 0i64, Nullability::NonNullable, 1).unwrap())] - #[case::constant(SequenceArray::try_new_typed( + #[case::single(SequenceVTable::try_new_typed(42i64, 0i64, Nullability::NonNullable, 1).unwrap())] + #[case::constant(SequenceVTable::try_new_typed( 100i32, 0i32, // multiplier of 0 means constant array Nullability::NonNullable, diff --git a/encodings/sequence/src/compute/compare.rs b/encodings/sequence/src/compute/compare.rs index 6e21e916729..87fc97955bf 100644 --- a/encodings/sequence/src/compute/compare.rs +++ b/encodings/sequence/src/compute/compare.rs @@ -21,6 +21,7 @@ use vortex_error::vortex_bail; use vortex_error::vortex_err; use crate::SequenceArray; +use crate::SequenceArrayExt; use crate::array::SequenceVTable; impl CompareKernel for SequenceVTable { @@ -145,11 +146,11 @@ mod tests { use vortex_array::dtype::Nullability::Nullable; use vortex_array::scalar_fn::fns::operators::Operator; - use crate::SequenceArray; + use crate::SequenceVTable; #[test] fn test_compare_match() { - let lhs = SequenceArray::try_new_typed(2i64, 1, NonNullable, 4).unwrap(); + let lhs = SequenceVTable::try_new_typed(2i64, 1, NonNullable, 4).unwrap(); let rhs = ConstantArray::new(4i64, lhs.len()); let result = lhs .into_array() @@ -161,7 +162,7 @@ mod tests { #[test] fn test_compare_match_scale() { - let lhs = SequenceArray::try_new_typed(2i64, 3, Nullable, 4).unwrap(); + let lhs = SequenceVTable::try_new_typed(2i64, 3, Nullable, 4).unwrap(); let rhs = ConstantArray::new(8i64, 
lhs.len()); let result = lhs .into_array() @@ -173,7 +174,7 @@ mod tests { #[test] fn test_compare_no_match() { - let lhs = SequenceArray::try_new_typed(2i64, 1, NonNullable, 4).unwrap(); + let lhs = SequenceVTable::try_new_typed(2i64, 1, NonNullable, 4).unwrap(); let rhs = ConstantArray::new(1i64, lhs.len()); let result = lhs .into_array() diff --git a/encodings/sequence/src/compute/filter.rs b/encodings/sequence/src/compute/filter.rs index 3f804cc3bc3..c596bb89c93 100644 --- a/encodings/sequence/src/compute/filter.rs +++ b/encodings/sequence/src/compute/filter.rs @@ -15,6 +15,7 @@ use vortex_error::VortexResult; use vortex_mask::Mask; use crate::SequenceArray; +use crate::SequenceArrayExt; use crate::SequenceVTable; impl FilterKernel for SequenceVTable { @@ -54,20 +55,21 @@ mod tests { use vortex_array::dtype::Nullability; use crate::SequenceArray; + use crate::SequenceVTable; #[rstest] - #[case(SequenceArray::try_new_typed(0i32, 1, Nullability::NonNullable, 5).unwrap())] - #[case(SequenceArray::try_new_typed(10i32, 2, Nullability::NonNullable, 5).unwrap())] - #[case(SequenceArray::try_new_typed(100i32, -3, Nullability::NonNullable, 5).unwrap())] - #[case(SequenceArray::try_new_typed(0i32, 1, Nullability::NonNullable, 1).unwrap())] - #[case(SequenceArray::try_new_typed(0i32, 1, Nullability::NonNullable, MEDIUM_SIZE).unwrap())] - #[case(SequenceArray::try_new_typed(0i32, 1, Nullability::NonNullable, LARGE_SIZE).unwrap())] - #[case(SequenceArray::try_new_typed(0i64, 1, Nullability::NonNullable, 5).unwrap())] - #[case(SequenceArray::try_new_typed(1000i64, 50, Nullability::NonNullable, 5).unwrap())] - #[case(SequenceArray::try_new_typed(-100i64, 10, Nullability::NonNullable, MEDIUM_SIZE).unwrap())] - #[case(SequenceArray::try_new_typed(0u32, 1, Nullability::NonNullable, 5).unwrap())] - #[case(SequenceArray::try_new_typed(0u32, 5, Nullability::NonNullable, MEDIUM_SIZE).unwrap())] - #[case(SequenceArray::try_new_typed(0u64, 1, Nullability::NonNullable, 
LARGE_SIZE).unwrap())] + #[case(SequenceVTable::try_new_typed(0i32, 1, Nullability::NonNullable, 5).unwrap())] + #[case(SequenceVTable::try_new_typed(10i32, 2, Nullability::NonNullable, 5).unwrap())] + #[case(SequenceVTable::try_new_typed(100i32, -3, Nullability::NonNullable, 5).unwrap())] + #[case(SequenceVTable::try_new_typed(0i32, 1, Nullability::NonNullable, 1).unwrap())] + #[case(SequenceVTable::try_new_typed(0i32, 1, Nullability::NonNullable, MEDIUM_SIZE).unwrap())] + #[case(SequenceVTable::try_new_typed(0i32, 1, Nullability::NonNullable, LARGE_SIZE).unwrap())] + #[case(SequenceVTable::try_new_typed(0i64, 1, Nullability::NonNullable, 5).unwrap())] + #[case(SequenceVTable::try_new_typed(1000i64, 50, Nullability::NonNullable, 5).unwrap())] + #[case(SequenceVTable::try_new_typed(-100i64, 10, Nullability::NonNullable, MEDIUM_SIZE).unwrap())] + #[case(SequenceVTable::try_new_typed(0u32, 1, Nullability::NonNullable, 5).unwrap())] + #[case(SequenceVTable::try_new_typed(0u32, 5, Nullability::NonNullable, MEDIUM_SIZE).unwrap())] + #[case(SequenceVTable::try_new_typed(0u64, 1, Nullability::NonNullable, LARGE_SIZE).unwrap())] fn test_filter_sequence_conformance(#[case] array: SequenceArray) { test_filter_conformance(&array.into_array()); } diff --git a/encodings/sequence/src/compute/is_sorted.rs b/encodings/sequence/src/compute/is_sorted.rs index cec77aa837a..9e88175e54e 100644 --- a/encodings/sequence/src/compute/is_sorted.rs +++ b/encodings/sequence/src/compute/is_sorted.rs @@ -9,6 +9,7 @@ use vortex_array::register_kernel; use vortex_error::VortexResult; use crate::SequenceArray; +use crate::SequenceArrayExt; use crate::SequenceVTable; impl IsSortedKernel for SequenceVTable { diff --git a/encodings/sequence/src/compute/list_contains.rs b/encodings/sequence/src/compute/list_contains.rs index c7c4c8565c9..80afb4535d0 100644 --- a/encodings/sequence/src/compute/list_contains.rs +++ b/encodings/sequence/src/compute/list_contains.rs @@ -9,6 +9,7 @@ use 
vortex_array::scalar_fn::fns::list_contains::ListContainsElementReduce; use vortex_error::VortexExpect; use vortex_error::VortexResult; +use crate::SequenceArrayExt; use crate::array::SequenceVTable; use crate::compute::compare::find_intersection_scalar; @@ -60,7 +61,7 @@ mod tests { use vortex_array::expr::root; use vortex_array::scalar::Scalar; - use crate::SequenceArray; + use crate::SequenceVTable; #[test] fn test_list_contains_seq() { @@ -74,7 +75,7 @@ mod tests { // [1, 3] in 1 // 2 // 3 - let array = SequenceArray::try_new_typed(1, 1, Nullability::NonNullable, 3).unwrap(); + let array = SequenceVTable::try_new_typed(1, 1, Nullability::NonNullable, 3).unwrap(); let expr = list_contains(lit(list_scalar.clone()), root()); let result = array.apply(&expr).unwrap(); @@ -86,7 +87,7 @@ mod tests { // [1, 3] in 1 // 3 // 5 - let array = SequenceArray::try_new_typed(1, 2, Nullability::NonNullable, 3).unwrap(); + let array = SequenceVTable::try_new_typed(1, 2, Nullability::NonNullable, 3).unwrap(); let expr = list_contains(lit(list_scalar), root()); let result = array.apply(&expr).unwrap(); diff --git a/encodings/sequence/src/compute/min_max.rs b/encodings/sequence/src/compute/min_max.rs index 8753d83a239..13fd4609609 100644 --- a/encodings/sequence/src/compute/min_max.rs +++ b/encodings/sequence/src/compute/min_max.rs @@ -10,6 +10,7 @@ use vortex_array::scalar::Scalar; use vortex_error::VortexResult; use crate::SequenceArray; +use crate::SequenceArrayExt; use crate::array::SequenceVTable; impl MinMaxKernel for SequenceVTable { diff --git a/encodings/sequence/src/compute/mod.rs b/encodings/sequence/src/compute/mod.rs index a08302c8302..3a20dcaa110 100644 --- a/encodings/sequence/src/compute/mod.rs +++ b/encodings/sequence/src/compute/mod.rs @@ -18,16 +18,17 @@ mod tests { use vortex_array::dtype::Nullability; use crate::SequenceArray; + use crate::SequenceVTable; #[rstest] // Basic sequence arrays - A[i] = base + i * multiplier - 
#[case::sequence_i32(SequenceArray::try_new_typed( + #[case::sequence_i32(SequenceVTable::try_new_typed( 0i32, // base 1i32, // multiplier Nullability::NonNullable, 5 // length ).unwrap())] // Results in [0, 1, 2, 3, 4] - #[case::sequence_i64_step2(SequenceArray::try_new_typed( + #[case::sequence_i64_step2(SequenceVTable::try_new_typed( 10i64, // base 2i64, // multiplier Nullability::NonNullable, @@ -35,13 +36,13 @@ mod tests { ).unwrap())] // Results in [10, 12, 14, 16, 18] // Different types - #[case::sequence_u32(SequenceArray::try_new_typed( + #[case::sequence_u32(SequenceVTable::try_new_typed( 100u32, // base 10u32, // multiplier Nullability::NonNullable, 5 // length ).unwrap())] // Results in [100, 110, 120, 130, 140] - #[case::sequence_i16(SequenceArray::try_new_typed( + #[case::sequence_i16(SequenceVTable::try_new_typed( -10i16, // base 3i16, // multiplier Nullability::NonNullable, @@ -49,19 +50,19 @@ mod tests { ).unwrap())] // Results in [-10, -7, -4, -1, 2] // Edge cases - #[case::sequence_single(SequenceArray::try_new_typed( + #[case::sequence_single(SequenceVTable::try_new_typed( 42i32, 0i32, // multiplier of 0 means constant array Nullability::NonNullable, 1 ).unwrap())] - #[case::sequence_zero_multiplier(SequenceArray::try_new_typed( + #[case::sequence_zero_multiplier(SequenceVTable::try_new_typed( 100i32, 0i32, // All values will be 100 Nullability::NonNullable, 5 ).unwrap())] - #[case::sequence_negative_step(SequenceArray::try_new_typed( + #[case::sequence_negative_step(SequenceVTable::try_new_typed( 100i32, -10i32, // Decreasing sequence Nullability::NonNullable, @@ -69,13 +70,13 @@ mod tests { ).unwrap())] // Results in [100, 90, 80, 70, 60] // Large arrays - #[case::sequence_large(SequenceArray::try_new_typed( + #[case::sequence_large(SequenceVTable::try_new_typed( 0i64, 1i64, Nullability::NonNullable, 2000 ).unwrap())] // Results in [0, 1, 2, ..., 1999] - #[case::sequence_large_step(SequenceArray::try_new_typed( + 
#[case::sequence_large_step(SequenceVTable::try_new_typed( 1000i32, 100i32, Nullability::NonNullable, diff --git a/encodings/sequence/src/compute/slice.rs b/encodings/sequence/src/compute/slice.rs index 14954c14ed7..4625462de9d 100644 --- a/encodings/sequence/src/compute/slice.rs +++ b/encodings/sequence/src/compute/slice.rs @@ -9,6 +9,7 @@ use vortex_array::arrays::slice::SliceReduce; use vortex_error::VortexResult; use crate::SequenceArray; +use crate::SequenceArrayExt; use crate::SequenceVTable; impl SliceReduce for SequenceVTable { diff --git a/encodings/sequence/src/compute/take.rs b/encodings/sequence/src/compute/take.rs index 4295e50ff9d..1a0c615fe86 100644 --- a/encodings/sequence/src/compute/take.rs +++ b/encodings/sequence/src/compute/take.rs @@ -26,6 +26,7 @@ use vortex_mask::AllOr; use vortex_mask::Mask; use crate::SequenceArray; +use crate::SequenceArrayExt; use crate::SequenceVTable; fn take_inner( @@ -112,51 +113,52 @@ mod test { use vortex_array::dtype::Nullability; use crate::SequenceArray; + use crate::SequenceVTable; #[rstest] - #[case::basic_sequence(SequenceArray::try_new_typed( + #[case::basic_sequence(SequenceVTable::try_new_typed( 0i32, 1i32, Nullability::NonNullable, 10 ).unwrap())] - #[case::sequence_with_multiplier(SequenceArray::try_new_typed( + #[case::sequence_with_multiplier(SequenceVTable::try_new_typed( 10i32, 5i32, Nullability::Nullable, 20 ).unwrap())] - #[case::sequence_i64(SequenceArray::try_new_typed( + #[case::sequence_i64(SequenceVTable::try_new_typed( 100i64, 10i64, Nullability::NonNullable, 50 ).unwrap())] - #[case::sequence_u32(SequenceArray::try_new_typed( + #[case::sequence_u32(SequenceVTable::try_new_typed( 0u32, 2u32, Nullability::NonNullable, 100 ).unwrap())] - #[case::sequence_negative_step(SequenceArray::try_new_typed( + #[case::sequence_negative_step(SequenceVTable::try_new_typed( 1000i32, -10i32, Nullability::Nullable, 30 ).unwrap())] - #[case::sequence_constant(SequenceArray::try_new_typed( + 
#[case::sequence_constant(SequenceVTable::try_new_typed( 42i32, 0i32, // multiplier of 0 means all values are the same Nullability::Nullable, 15 ).unwrap())] - #[case::sequence_i16(SequenceArray::try_new_typed( + #[case::sequence_i16(SequenceVTable::try_new_typed( -100i16, 3i16, Nullability::NonNullable, 25 ).unwrap())] - #[case::sequence_large(SequenceArray::try_new_typed( + #[case::sequence_large(SequenceVTable::try_new_typed( 0i64, 1i64, Nullability::Nullable, @@ -170,7 +172,8 @@ mod test { #[test] #[should_panic(expected = "out of bounds")] fn test_bounds_check() { - let array = SequenceArray::try_new_typed(0i32, 1i32, Nullability::NonNullable, 10).unwrap(); + let array = + SequenceVTable::try_new_typed(0i32, 1i32, Nullability::NonNullable, 10).unwrap(); let indices = PrimitiveArray::from_iter([0i32, 20]); let _array = array .take(indices.into_array()) diff --git a/encodings/sequence/src/lib.rs b/encodings/sequence/src/lib.rs index 974a88eb0aa..b28e667a987 100644 --- a/encodings/sequence/src/lib.rs +++ b/encodings/sequence/src/lib.rs @@ -10,6 +10,7 @@ mod rules; /// Represents the equation A\[i\] = a * i + b. /// This can be used for compression, fast comparisons and also for row ids. pub use array::SequenceArray; +pub use array::SequenceArrayExt; pub use array::SequenceArrayParts; /// Represents the equation A\[i\] = a * i + b. /// This can be used for compression, fast comparisons and also for row ids. 
diff --git a/encodings/sparse/public-api.lock b/encodings/sparse/public-api.lock index 5116e90a2ac..64412b32da1 100644 --- a/encodings/sparse/public-api.lock +++ b/encodings/sparse/public-api.lock @@ -24,18 +24,12 @@ pub struct vortex_sparse::SparseArray impl vortex_sparse::SparseArray -pub fn vortex_sparse::SparseArray::encode(array: &vortex_array::array::ArrayRef, fill_value: core::option::Option) -> vortex_error::VortexResult - pub fn vortex_sparse::SparseArray::fill_scalar(&self) -> &vortex_array::scalar::Scalar pub fn vortex_sparse::SparseArray::patches(&self) -> &vortex_array::patches::Patches pub fn vortex_sparse::SparseArray::resolved_patches(&self) -> vortex_error::VortexResult -pub fn vortex_sparse::SparseArray::try_new(indices: vortex_array::array::ArrayRef, values: vortex_array::array::ArrayRef, len: usize, fill_value: vortex_array::scalar::Scalar) -> vortex_error::VortexResult - -pub fn vortex_sparse::SparseArray::try_new_from_patches(patches: vortex_array::patches::Patches, fill_value: vortex_array::scalar::Scalar) -> vortex_error::VortexResult - impl vortex_sparse::SparseArray pub fn vortex_sparse::SparseArray::to_array(&self) -> vortex_array::array::ArrayRef @@ -78,6 +72,12 @@ impl vortex_sparse::SparseVTable pub const vortex_sparse::SparseVTable::ID: vortex_array::vtable::dyn_::ArrayId +pub fn vortex_sparse::SparseVTable::encode(array: &vortex_array::array::ArrayRef, fill_value: core::option::Option) -> vortex_error::VortexResult + +pub fn vortex_sparse::SparseVTable::try_new(indices: vortex_array::array::ArrayRef, values: vortex_array::array::ArrayRef, len: usize, fill_value: vortex_array::scalar::Scalar) -> vortex_error::VortexResult + +pub fn vortex_sparse::SparseVTable::try_new_from_patches(patches: vortex_array::patches::Patches, fill_value: vortex_array::scalar::Scalar) -> vortex_error::VortexResult + impl core::fmt::Debug for vortex_sparse::SparseVTable pub fn vortex_sparse::SparseVTable::fmt(&self, f: &mut core::fmt::Formatter<'_>) -> 
core::fmt::Result diff --git a/encodings/sparse/src/canonical.rs b/encodings/sparse/src/canonical.rs index 441ee671898..905f25cfd0b 100644 --- a/encodings/sparse/src/canonical.rs +++ b/encodings/sparse/src/canonical.rs @@ -54,6 +54,7 @@ use vortex_error::vortex_panic; use crate::ConstantArray; use crate::SparseArray; +use crate::SparseArrayExt; pub(super) fn execute_sparse( array: &SparseArray, @@ -569,7 +570,7 @@ mod test { use vortex_error::VortexResult; use vortex_mask::Mask; - use crate::SparseArray; + use crate::SparseVTable; #[rstest] #[case(Some(true))] @@ -579,7 +580,7 @@ mod test { let indices = buffer![0u64, 1, 7].into_array(); let values = BoolArray::from_iter([Some(true), None, Some(false)]).into_array(); let sparse_bools = - SparseArray::try_new(indices, values, 10, Scalar::from(fill_value)).unwrap(); + SparseVTable::try_new(indices, values, 10, Scalar::from(fill_value)).unwrap(); let actual = sparse_bools.to_bool(); let expected = BoolArray::from_iter([ @@ -606,7 +607,7 @@ mod test { let indices = buffer![0u64, 1, 7].into_array(); let values = PrimitiveArray::from_option_iter([Some(0i32), None, Some(1)]).into_array(); let sparse_ints = - SparseArray::try_new(indices, values, 10, Scalar::from(fill_value)).unwrap(); + SparseVTable::try_new(indices, values, 10, Scalar::from(fill_value)).unwrap(); assert_eq!(*sparse_ints.dtype(), DType::Primitive(PType::I32, Nullable)); let flat_ints = sparse_ints.to_primitive(); @@ -657,7 +658,7 @@ mod test { vec![Scalar::from(Some(-10i32)), Scalar::from(Some(-1i32))], ); let len = 10; - let sparse_struct = SparseArray::try_new(indices, patch_values, len, fill_scalar).unwrap(); + let sparse_struct = SparseVTable::try_new(indices, patch_values, len, fill_scalar).unwrap(); let expected_a = PrimitiveArray::from_option_iter((0..len).map(|i| { if i == 0 { @@ -724,7 +725,7 @@ mod test { let fill_scalar = Scalar::null(struct_dtype); let len = 10; - let sparse_struct = SparseArray::try_new(indices, patch_values, len, 
fill_scalar).unwrap(); + let sparse_struct = SparseVTable::try_new(indices, patch_values, len, fill_scalar).unwrap(); let expected_a = PrimitiveArray::from_option_iter((0..len).map(|i| { if i == 0 { @@ -775,7 +776,7 @@ mod test { .into_array(); let len = 10; let fill_scalar = Scalar::decimal(DecimalValue::I32(123), decimal_dtype, Nullable); - let sparse_struct = SparseArray::try_new(indices, patch_values, len, fill_scalar).unwrap(); + let sparse_struct = SparseVTable::try_new(indices, patch_values, len, fill_scalar).unwrap(); let expected = DecimalArray::new( buffer![100i128, 200, 123, 123, 123, 123, 123, 300, 4000, 123], @@ -810,7 +811,7 @@ mod test { ]) .into_array(); - let array = SparseArray::try_new( + let array = SparseVTable::try_new( buffer![0u16, 3, 4, 5, 7, 9, 10].into_array(), strings, 12, @@ -851,7 +852,7 @@ mod test { ]) .into_array(); - let array = SparseArray::try_new( + let array = SparseVTable::try_new( buffer![0u16, 3, 4, 5, 7, 9, 10].into_array(), strings, 12, @@ -885,7 +886,7 @@ mod test { VarBinViewArray::from_iter_str(["hello", "goodbye", "hello", "bonjour", "你好"]) .into_array(); - let array = SparseArray::try_new( + let array = SparseVTable::try_new( buffer![0u16, 3, 4, 5, 8].into_array(), strings, 9, @@ -915,7 +916,7 @@ mod test { ]) .into_array(); - let array = SparseArray::try_new( + let array = SparseVTable::try_new( buffer![0u16, 3, 4, 5, 7, 9, 10].into_array(), strings, 12, @@ -956,7 +957,7 @@ mod test { ]) .into_array(); - let array = SparseArray::try_new( + let array = SparseVTable::try_new( buffer![0u16, 3, 4, 5, 7, 9, 10].into_array(), binaries, 12, @@ -1003,7 +1004,7 @@ mod test { let indices = buffer![0u8, 3u8, 4u8, 5u8].into_array(); let fill_value = Scalar::null(lists.dtype().clone()); - let sparse = SparseArray::try_new(indices, lists, 6, fill_value) + let sparse = SparseVTable::try_new(indices, lists, 6, fill_value) .unwrap() .into_array(); @@ -1057,7 +1058,7 @@ mod test { let indices = buffer![0u8, 3u8, 4u8, 
5u8].into_array(); let fill_value = Scalar::null(lists.dtype().clone()); - let sparse = SparseArray::try_new(indices, lists, 6, fill_value) + let sparse = SparseVTable::try_new(indices, lists, 6, fill_value) .unwrap() .into_array(); @@ -1100,7 +1101,7 @@ mod test { let indices = buffer![0u8, 3u8, 4u8, 5u8].into_array(); let fill_value = Scalar::from(Some(vec![5i32, 6, 7, 8])); - let sparse = SparseArray::try_new(indices, lists, 6, fill_value) + let sparse = SparseVTable::try_new(indices, lists, 6, fill_value) .unwrap() .into_array(); @@ -1169,7 +1170,7 @@ mod test { ]) .into_array(); - let array = SparseArray::try_new( + let array = SparseVTable::try_new( buffer![0u16, 3, 4, 5, 7, 9, 10].into_array(), strings, 12, @@ -1211,7 +1212,7 @@ mod test { 3, Nullable, )); - let sparse = SparseArray::try_new(indices, fsl, 5, fill_value) + let sparse = SparseVTable::try_new(indices, fsl, 5, fill_value) .unwrap() .into_array(); @@ -1249,7 +1250,7 @@ mod test { ], NonNullable, ); - let sparse = SparseArray::try_new(indices, fsl, 6, fill_value) + let sparse = SparseVTable::try_new(indices, fsl, 6, fill_value) .unwrap() .into_array(); @@ -1287,7 +1288,7 @@ mod test { ], Nullable, ); - let sparse = SparseArray::try_new(indices, fsl, 6, fill_value) + let sparse = SparseVTable::try_new(indices, fsl, 6, fill_value) .unwrap() .into_array(); @@ -1335,7 +1336,7 @@ mod test { NonNullable, ); - let sparse = SparseArray::try_new(indices, fsl, 100, fill_value) + let sparse = SparseVTable::try_new(indices, fsl, 100, fill_value) .unwrap() .into_array(); @@ -1392,7 +1393,7 @@ mod test { ], NonNullable, ); - let sparse = SparseArray::try_new(indices, fsl, 1, fill_value) + let sparse = SparseVTable::try_new(indices, fsl, 1, fill_value) .unwrap() .into_array(); @@ -1418,7 +1419,7 @@ mod test { let indices = buffer![0u8, 1u8, 2u8, 3u8].into_array(); let fill_value = Scalar::from(Some(vec![42i32; 252])); // 252 + 4 elements = 256 > u8::MAX - let sparse = SparseArray::try_new(indices, lists, 5, 
fill_value) + let sparse = SparseVTable::try_new(indices, lists, 5, fill_value) .unwrap() .into_array(); @@ -1480,7 +1481,7 @@ mod test { // - Index 7: List 2 [30, 31, 32, 33] // - Index 8-9: null let indices = buffer![1u8, 4, 7].into_array(); - let sparse = SparseArray::try_new( + let sparse = SparseVTable::try_new( indices, list_view.into_array(), 10, @@ -1566,7 +1567,8 @@ mod test { // Extract only the values we need from the sliced array let values = sliced.slice(0..2).unwrap(); let sparse = - SparseArray::try_new(indices, values, 5, Scalar::null(sliced.dtype().clone())).unwrap(); + SparseVTable::try_new(indices, values, 5, Scalar::null(sliced.dtype().clone())) + .unwrap(); let canonical = sparse.to_canonical()?.into_array(); let result_listview = canonical.to_listview(); diff --git a/encodings/sparse/src/compute/cast.rs b/encodings/sparse/src/compute/cast.rs index 8d7c660a8cc..25dde2802bd 100644 --- a/encodings/sparse/src/compute/cast.rs +++ b/encodings/sparse/src/compute/cast.rs @@ -9,6 +9,7 @@ use vortex_array::scalar_fn::fns::cast::CastReduce; use vortex_error::VortexResult; use crate::SparseArray; +use crate::SparseArrayExt; use crate::SparseVTable; impl CastReduce for SparseVTable { @@ -21,7 +22,7 @@ impl CastReduce for SparseVTable { .map_values(|values| values.cast(dtype.clone()))?; Ok(Some( - SparseArray::try_new_from_patches(casted_patches, casted_fill)?.into_array(), + SparseVTable::try_new_from_patches(casted_patches, casted_fill)?.into_array(), )) } } @@ -42,10 +43,11 @@ mod tests { use vortex_buffer::buffer; use crate::SparseArray; + use crate::SparseVTable; #[test] fn test_cast_sparse_i32_to_i64() { - let sparse = SparseArray::try_new( + let sparse = SparseVTable::try_new( buffer![2u64, 5, 8].into_array(), buffer![100i32, 200, 300].into_array(), 10, @@ -68,7 +70,7 @@ mod tests { #[test] fn test_cast_sparse_with_null_fill() { - let sparse = SparseArray::try_new( + let sparse = SparseVTable::try_new( buffer![1u64, 3, 5].into_array(), 
PrimitiveArray::from_option_iter([Some(42i32), Some(84), Some(126)]).into_array(), 8, @@ -87,25 +89,25 @@ mod tests { } #[rstest] - #[case(SparseArray::try_new( + #[case(SparseVTable::try_new( buffer![2u64, 5, 8].into_array(), buffer![100i32, 200, 300].into_array(), 10, Scalar::from(0i32) ).unwrap())] - #[case(SparseArray::try_new( + #[case(SparseVTable::try_new( buffer![0u64, 4, 9].into_array(), buffer![1.5f32, 2.5, 3.5].into_array(), 10, Scalar::from(0.0f32) ).unwrap())] - #[case(SparseArray::try_new( + #[case(SparseVTable::try_new( buffer![1u64, 3, 7].into_array(), PrimitiveArray::from_option_iter([Some(100i32), None, Some(300)]).into_array(), 10, Scalar::null_native::() ).unwrap())] - #[case(SparseArray::try_new( + #[case(SparseVTable::try_new( buffer![5u64].into_array(), buffer![42u8].into_array(), 10, diff --git a/encodings/sparse/src/compute/filter.rs b/encodings/sparse/src/compute/filter.rs index 2f313925901..7b71bba92fc 100644 --- a/encodings/sparse/src/compute/filter.rs +++ b/encodings/sparse/src/compute/filter.rs @@ -10,6 +10,7 @@ use vortex_mask::Mask; use crate::ConstantArray; use crate::SparseArray; +use crate::SparseArrayExt; use crate::SparseVTable; impl FilterKernel for SparseVTable { @@ -27,7 +28,7 @@ impl FilterKernel for SparseVTable { }; Ok(Some( - SparseArray::try_new_from_patches(new_patches, array.fill_scalar().clone())? + SparseVTable::try_new_from_patches(new_patches, array.fill_scalar().clone())? .into_array(), )) } @@ -52,11 +53,11 @@ mod tests { use vortex_buffer::buffer; use vortex_mask::Mask; - use crate::SparseArray; + use crate::SparseVTable; #[fixture] fn array() -> ArrayRef { - SparseArray::try_new( + SparseVTable::try_new( buffer![2u64, 9, 15].into_array(), PrimitiveArray::new(buffer![33_i32, 44, 55], Validity::AllValid).into_array(), 20, @@ -76,7 +77,7 @@ mod tests { // Construct expected SparseArray: index 2 was kept, which had value 33. // The new index is 0 (since it's the only element). 
- let expected = SparseArray::try_new( + let expected = SparseVTable::try_new( buffer![0u64].into_array(), PrimitiveArray::new(buffer![33_i32], Validity::AllValid).into_array(), 1, @@ -90,7 +91,7 @@ mod tests { #[test] fn true_fill_value() { let mask = Mask::from_iter([false, true, false, true, false, true, true]); - let array = SparseArray::try_new( + let array = SparseVTable::try_new( buffer![0_u64, 3, 6].into_array(), PrimitiveArray::new(buffer![33_i32, 44, 55], Validity::AllValid).into_array(), 7, @@ -105,7 +106,7 @@ mod tests { // Mask keeps indices 1, 3, 5, 6 -> new indices 0, 1, 2, 3. // Index 3 (value 44) maps to new index 1. // Index 6 (value 55) maps to new index 3. - let expected = SparseArray::try_new( + let expected = SparseVTable::try_new( buffer![1u64, 3].into_array(), PrimitiveArray::new(buffer![44_i32, 55], Validity::AllValid).into_array(), 4, @@ -120,7 +121,7 @@ mod tests { fn test_filter_sparse_array() { let null_fill_value = Scalar::null(DType::Primitive(PType::I32, Nullability::Nullable)); test_filter_conformance( - &SparseArray::try_new( + &SparseVTable::try_new( buffer![1u64, 2, 4].into_array(), buffer![100i32, 200, 300] .into_array() @@ -135,7 +136,7 @@ mod tests { let ten_fill_value = Scalar::from(10i32); test_filter_conformance( - &SparseArray::try_new( + &SparseVTable::try_new( buffer![1u64, 2, 4].into_array(), buffer![100i32, 200, 300].into_array(), 5, diff --git a/encodings/sparse/src/compute/mod.rs b/encodings/sparse/src/compute/mod.rs index 0f0ad4b98c4..e88eec94fd5 100644 --- a/encodings/sparse/src/compute/mod.rs +++ b/encodings/sparse/src/compute/mod.rs @@ -24,11 +24,11 @@ mod test { use vortex_buffer::buffer; use vortex_mask::Mask; - use crate::SparseArray; + use crate::SparseVTable; #[fixture] fn array() -> ArrayRef { - SparseArray::try_new( + SparseVTable::try_new( buffer![2u64, 9, 15].into_array(), PrimitiveArray::new(buffer![33_i32, 44, 55], Validity::AllValid).into_array(), 20, @@ -48,7 +48,7 @@ mod test { // Construct expected 
SparseArray: index 2 was kept, which had value 33. // The new index is 0 (since it's the only element). - let expected = SparseArray::try_new( + let expected = SparseVTable::try_new( buffer![0u64].into_array(), PrimitiveArray::new(buffer![33_i32], Validity::AllValid).into_array(), 1, @@ -62,7 +62,7 @@ mod test { #[test] fn true_fill_value() { let mask = Mask::from_iter([false, true, false, true, false, true, true]); - let array = SparseArray::try_new( + let array = SparseVTable::try_new( buffer![0_u64, 3, 6].into_array(), PrimitiveArray::new(buffer![33_i32, 44, 55], Validity::AllValid).into_array(), 7, @@ -77,7 +77,7 @@ mod test { // Mask keeps indices 1, 3, 5, 6 -> new indices 0, 1, 2, 3. // Index 3 (value 44) maps to new index 1. // Index 6 (value 55) maps to new index 3. - let expected = SparseArray::try_new( + let expected = SparseVTable::try_new( buffer![1u64, 3].into_array(), PrimitiveArray::new(buffer![44_i32, 55], Validity::AllValid).into_array(), 4, @@ -97,7 +97,7 @@ mod test { fn test_mask_sparse_array() { let null_fill_value = Scalar::null(DType::Primitive(PType::I32, Nullability::Nullable)); test_mask_conformance( - &SparseArray::try_new( + &SparseVTable::try_new( buffer![1u64, 2, 4].into_array(), buffer![100i32, 200, 300] .into_array() @@ -112,7 +112,7 @@ mod test { let ten_fill_value = Scalar::from(10i32); test_mask_conformance( - &SparseArray::try_new( + &SparseVTable::try_new( buffer![1u64, 2, 4].into_array(), buffer![100i32, 200, 300].into_array(), 5, @@ -139,49 +139,50 @@ mod tests { use vortex_buffer::buffer; use crate::SparseArray; + use crate::SparseVTable; #[rstest] // Basic sparse arrays - #[case::sparse_i32_null_fill(SparseArray::try_new( + #[case::sparse_i32_null_fill(SparseVTable::try_new( buffer![2u64, 5, 8].into_array(), PrimitiveArray::from_option_iter([Some(100i32), Some(200), Some(300)]).into_array(), 10, Scalar::null_native::() ).unwrap())] - #[case::sparse_i32_value_fill(SparseArray::try_new( + 
#[case::sparse_i32_value_fill(SparseVTable::try_new( buffer![1u64, 3, 7].into_array(), buffer![42i32, 84, 126].into_array(), 10, Scalar::from(0i32) ).unwrap())] // Different types - #[case::sparse_u64(SparseArray::try_new( + #[case::sparse_u64(SparseVTable::try_new( buffer![0u64, 4, 9].into_array(), buffer![1000u64, 2000, 3000].into_array(), 10, Scalar::from(999u64) ).unwrap())] - #[case::sparse_f32(SparseArray::try_new( + #[case::sparse_f32(SparseVTable::try_new( buffer![2u64, 6].into_array(), buffer![std::f32::consts::PI, std::f32::consts::E].into_array(), 8, Scalar::from(0.0f32) ).unwrap())] // Edge cases - #[case::sparse_single_patch(SparseArray::try_new( + #[case::sparse_single_patch(SparseVTable::try_new( buffer![5u64].into_array(), buffer![42i32].into_array(), 10, Scalar::from(-1i32) ).unwrap())] - #[case::sparse_dense_patches(SparseArray::try_new( + #[case::sparse_dense_patches(SparseVTable::try_new( buffer![0u64, 1, 2, 3, 4].into_array(), PrimitiveArray::from_option_iter([Some(10i32), Some(20), Some(30), Some(40), Some(50)]).into_array(), 5, Scalar::null_native::() ).unwrap())] // Large sparse arrays - #[case::sparse_large(SparseArray::try_new( + #[case::sparse_large(SparseVTable::try_new( buffer![100u64, 500, 900, 1500, 1999].into_array(), buffer![111i32, 222, 333, 444, 555].into_array(), 2000, @@ -190,7 +191,7 @@ mod tests { // Nullable patches #[case::sparse_nullable_patches({ let null_fill_value = Scalar::null(DType::Primitive(PType::I32, Nullability::Nullable)); - SparseArray::try_new( + SparseVTable::try_new( buffer![1u64, 4, 7].into_array(), PrimitiveArray::from_option_iter([Some(100i32), None, Some(300)]) .into_array() @@ -206,37 +207,37 @@ mod tests { } #[rstest] - #[case::sparse_i32_basic(SparseArray::try_new( + #[case::sparse_i32_basic(SparseVTable::try_new( buffer![2u64, 5, 8].into_array(), buffer![100i32, 200, 300].into_array(), 10, Scalar::from(0i32) ).unwrap())] - #[case::sparse_u32_basic(SparseArray::try_new( + 
#[case::sparse_u32_basic(SparseVTable::try_new( buffer![1u64, 3, 7].into_array(), buffer![1000u32, 2000, 3000].into_array(), 10, Scalar::from(100u32) ).unwrap())] - #[case::sparse_i64_basic(SparseArray::try_new( + #[case::sparse_i64_basic(SparseVTable::try_new( buffer![0u64, 4, 9].into_array(), buffer![5000i64, 6000, 7000].into_array(), 10, Scalar::from(1000i64) ).unwrap())] - #[case::sparse_f32_basic(SparseArray::try_new( + #[case::sparse_f32_basic(SparseVTable::try_new( buffer![2u64, 6].into_array(), buffer![1.5f32, 2.5].into_array(), 8, Scalar::from(0.5f32) ).unwrap())] - #[case::sparse_f64_basic(SparseArray::try_new( + #[case::sparse_f64_basic(SparseVTable::try_new( buffer![1u64, 5, 9].into_array(), buffer![10.1f64, 20.2, 30.3].into_array(), 10, Scalar::from(5.0f64) ).unwrap())] - #[case::sparse_i32_large(SparseArray::try_new( + #[case::sparse_i32_large(SparseVTable::try_new( buffer![10u64, 50, 90, 150, 199].into_array(), buffer![111i32, 222, 333, 444, 555].into_array(), 200, diff --git a/encodings/sparse/src/compute/take.rs b/encodings/sparse/src/compute/take.rs index 88813dc89fb..4fff1668315 100644 --- a/encodings/sparse/src/compute/take.rs +++ b/encodings/sparse/src/compute/take.rs @@ -10,6 +10,7 @@ use vortex_error::VortexResult; use crate::ConstantArray; use crate::SparseArray; +use crate::SparseArrayExt; use crate::SparseVTable; impl TakeExecute for SparseVTable { @@ -41,7 +42,7 @@ impl TakeExecute for SparseVTable { } Ok(Some( - SparseArray::try_new_from_patches( + SparseVTable::try_new_from_patches( new_patches, array.fill_scalar().cast( &array @@ -69,6 +70,7 @@ mod test { use vortex_buffer::buffer; use crate::SparseArray; + use crate::SparseVTable; fn test_array_fill_value() -> Scalar { // making this const is annoying @@ -76,7 +78,7 @@ mod test { } fn sparse_array() -> ArrayRef { - SparseArray::try_new( + SparseVTable::try_new( buffer![0u64, 37, 47, 99].into_array(), PrimitiveArray::new(buffer![1.23f64, 0.47, 9.99, 3.5], 
Validity::AllValid).into_array(), 100, @@ -129,7 +131,7 @@ mod test { #[test] fn nullable_take() { - let arr = SparseArray::try_new( + let arr = SparseVTable::try_new( buffer![1u32].into_array(), buffer![10].into_array(), 10, @@ -150,7 +152,7 @@ mod test { #[test] fn nullable_take_with_many_patches() { - let arr = SparseArray::try_new( + let arr = SparseVTable::try_new( buffer![1u32, 3, 7, 8, 9].into_array(), buffer![10, 8, 3, 2, 1].into_array(), 10, @@ -170,13 +172,13 @@ mod test { } #[rstest] - #[case(SparseArray::try_new( + #[case(SparseVTable::try_new( buffer![0u64, 37, 47, 99].into_array(), PrimitiveArray::new(buffer![1.23f64, 0.47, 9.99, 3.5], Validity::AllValid).into_array(), 100, Scalar::null_native::(), ).unwrap())] - #[case(SparseArray::try_new( + #[case(SparseVTable::try_new( buffer![1u32, 3, 7, 8, 9].into_array(), buffer![10, 8, 3, 2, 1].into_array(), 10, @@ -184,14 +186,14 @@ mod test { ).unwrap())] #[case({ let nullable_values = PrimitiveArray::from_option_iter([Some(100i64), None, Some(300)]); - SparseArray::try_new( + SparseVTable::try_new( buffer![2u64, 4, 6].into_array(), nullable_values.into_array(), 10, Scalar::null_native::(), ).unwrap() })] - #[case(SparseArray::try_new( + #[case(SparseVTable::try_new( buffer![5u64].into_array(), buffer![999i32].into_array(), 20, diff --git a/encodings/sparse/src/lib.rs b/encodings/sparse/src/lib.rs index f78ee9dad49..c8454ccbce7 100644 --- a/encodings/sparse/src/lib.rs +++ b/encodings/sparse/src/lib.rs @@ -207,7 +207,7 @@ impl VTable for SparseVTable { )?; let patch_values = children.get(1, dtype, metadata.patches.len()?)?; - SparseArray::try_new( + Self::try_new( patch_indices, patch_values, len, @@ -272,15 +272,14 @@ pub struct SparseVTable; impl SparseVTable { pub const ID: ArrayId = ArrayId::new_ref("vortex.sparse"); -} -impl SparseArray { + /// Build a new `SparseArray` from indices, values, length, and a fill value. 
pub fn try_new( indices: ArrayRef, values: ArrayRef, len: usize, fill_value: Scalar, - ) -> VortexResult { + ) -> VortexResult { vortex_ensure!( indices.len() == values.len(), "Mismatched indices {} and values {} length", @@ -307,7 +306,7 @@ impl SparseArray { } let dtype = fill_value.dtype().clone(); - Ok(Self { + Ok(SparseArray { common: ArrayCommon::new(len, dtype), // TODO(0ax1): handle chunk offsets patches: Patches::new(len, 0, indices, values, None)?, @@ -315,8 +314,8 @@ impl SparseArray { }) } - /// Build a new SparseArray from an existing set of patches. - pub fn try_new_from_patches(patches: Patches, fill_value: Scalar) -> VortexResult { + /// Build a new `SparseArray` from an existing set of patches. + pub fn try_new_from_patches(patches: Patches, fill_value: Scalar) -> VortexResult { vortex_ensure!( fill_value.dtype() == patches.values().dtype(), "fill value, {:?}, should be instance of values dtype, {} but was {}.", @@ -327,55 +326,14 @@ impl SparseArray { let len = patches.array_len(); let dtype = fill_value.dtype().clone(); - Ok(Self { + Ok(SparseArray { common: ArrayCommon::new(len, dtype), patches, fill_value, }) } - pub(crate) unsafe fn new_unchecked(patches: Patches, fill_value: Scalar) -> Self { - let len = patches.array_len(); - let dtype = fill_value.dtype().clone(); - Self { - common: ArrayCommon::new(len, dtype), - patches, - fill_value, - } - } - - #[inline] - pub fn patches(&self) -> &Patches { - &self.patches - } - - #[inline] - pub fn resolved_patches(&self) -> VortexResult { - let patches = self.patches(); - let indices_offset = Scalar::from(patches.offset()).cast(patches.indices().dtype())?; - let indices = patches.indices().to_array().binary( - ConstantArray::new(indices_offset, patches.indices().len()).into_array(), - Operator::Sub, - )?; - - Patches::new( - patches.array_len(), - 0, - indices, - patches.values().clone(), - // TODO(0ax1): handle chunk offsets - None, - ) - } - - #[inline] - pub fn fill_scalar(&self) -> &Scalar { - 
&self.fill_value - } - /// Encode given array as a SparseArray. - /// - /// Optionally provided fill value will be respected if the array is less than 90% null. pub fn encode(array: &ArrayRef, fill_value: Option) -> VortexResult { if let Some(fill_value) = fill_value.as_ref() && array.dtype() != fill_value.dtype() @@ -416,7 +374,7 @@ impl SparseArray { } }; - return Ok(SparseArray::try_new( + return Ok(Self::try_new( non_null_indices, non_null_values, array.len(), @@ -464,11 +422,66 @@ impl SparseArray { Mask::Values(values) => values.indices().iter().map(|v| *v as u64).collect(), }; - SparseArray::try_new(indices.into_array(), non_top_values, array.len(), fill) + Self::try_new(indices.into_array(), non_top_values, array.len(), fill) .map(|a| a.into_array()) } } +/// Extension trait for [`SparseArray`] methods. +pub trait SparseArrayExt { + /// Returns a reference to the patches. + fn patches(&self) -> &Patches; + + /// Returns patches with resolved (offset-subtracted) indices. + fn resolved_patches(&self) -> VortexResult; + + /// Returns a reference to the fill scalar. 
+ fn fill_scalar(&self) -> &Scalar; +} + +impl SparseArrayExt for SparseArray { + #[inline] + fn patches(&self) -> &Patches { + &self.patches + } + + #[inline] + fn resolved_patches(&self) -> VortexResult { + let patches = self.patches(); + let indices_offset = Scalar::from(patches.offset()).cast(patches.indices().dtype())?; + let indices = patches.indices().to_array().binary( + ConstantArray::new(indices_offset, patches.indices().len()).into_array(), + Operator::Sub, + )?; + + Patches::new( + patches.array_len(), + 0, + indices, + patches.values().clone(), + // TODO(0ax1): handle chunk offsets + None, + ) + } + + #[inline] + fn fill_scalar(&self) -> &Scalar { + &self.fill_value + } +} + +impl SparseArray { + pub(crate) unsafe fn new_unchecked(patches: Patches, fill_value: Scalar) -> Self { + let len = patches.array_len(); + let dtype = fill_value.dtype().clone(); + Self { + common: ArrayCommon::new(len, dtype), + patches, + fill_value, + } + } +} + impl ValidityVTable for SparseVTable { fn validity(array: &SparseArray) -> VortexResult { let patches = unsafe { @@ -525,7 +538,7 @@ mod test { let mut values = buffer![100i32, 200, 300].into_array(); values = values.cast(fill_value.dtype().clone()).unwrap(); - SparseArray::try_new(buffer![2u64, 5, 8].into_array(), values, 10, fill_value) + SparseVTable::try_new(buffer![2u64, 5, 8].into_array(), values, 10, fill_value) .unwrap() .into_array() } @@ -548,7 +561,7 @@ mod test { #[test] pub fn test_scalar_at_again() { - let arr = SparseArray::try_new( + let arr = SparseVTable::try_new( ConstantArray::new(10u32, 1).into_array(), ConstantArray::new(Scalar::primitive(1234u32, Nullability::Nullable), 1).into_array(), 100, @@ -584,7 +597,7 @@ mod test { #[test] pub fn validity_mask_sliced_nonnull_fill() { - let sliced = SparseArray::try_new( + let sliced = SparseVTable::try_new( buffer![2u64, 5, 8].into_array(), ConstantArray::new( Scalar::null(DType::Primitive(PType::F32, Nullability::Nullable)), @@ -647,7 +660,7 @@ mod test { 
let values = buffer![15_u32, 135, 13531, 42].into_array(); let indices = buffer![10_u64, 11, 50, 100].into_array(); - SparseArray::try_new(indices, values, 100, 0_u32.into()).unwrap(); + SparseVTable::try_new(indices, values, 100, 0_u32.into()).unwrap(); } #[test] @@ -655,7 +668,7 @@ mod test { let values = buffer![15_u32, 135, 13531, 42].into_array(); let indices = buffer![10_u64, 11, 50, 100].into_array(); - SparseArray::try_new(indices, values, 101, 0_u32.into()).unwrap(); + SparseVTable::try_new(indices, values, 101, 0_u32.into()).unwrap(); } #[test] @@ -666,7 +679,7 @@ mod test { true, true, false, true, false, true, false, true, true, false, true, false, ]), ); - let sparse = SparseArray::encode(&original.clone().into_array(), None) + let sparse = SparseVTable::encode(&original.clone().into_array(), None) .vortex_expect("SparseArray::encode should succeed for test data"); assert_eq!( sparse.validity_mask().unwrap(), @@ -683,7 +696,7 @@ mod test { let values = PrimitiveArray::from_option_iter([Some(0i16), Some(1), None, None, Some(4)]) .into_array(); let array = - SparseArray::try_new(indices, values, 10, Scalar::null_native::()).unwrap(); + SparseVTable::try_new(indices, values, 10, Scalar::null_native::()).unwrap(); let actual = array.validity_mask().unwrap(); let expected = Mask::from_iter([ true, false, true, false, false, false, false, false, true, false, diff --git a/encodings/sparse/src/ops.rs b/encodings/sparse/src/ops.rs index 7fbe2b4af4c..93dc47de0bf 100644 --- a/encodings/sparse/src/ops.rs +++ b/encodings/sparse/src/ops.rs @@ -6,6 +6,7 @@ use vortex_array::vtable::OperationsVTable; use vortex_error::VortexResult; use crate::SparseArray; +use crate::SparseArrayExt; use crate::SparseVTable; impl OperationsVTable for SparseVTable { @@ -32,7 +33,7 @@ mod tests { let values = buffer![0u64].into_array(); let indices = buffer![0u8].into_array(); - let sparse = SparseArray::try_new(indices, values, 1000, 999u64.into()).unwrap(); + let sparse = 
SparseVTable::try_new(indices, values, 1000, 999u64.into()).unwrap(); let sliced = sparse.slice(0..1000).unwrap(); let mut expected = vec![999u64; 1000]; expected[0] = 0; diff --git a/encodings/sparse/src/rules.rs b/encodings/sparse/src/rules.rs index 221af9060ae..0e1f14f6a7b 100644 --- a/encodings/sparse/src/rules.rs +++ b/encodings/sparse/src/rules.rs @@ -11,6 +11,7 @@ use vortex_array::scalar_fn::fns::not::NotReduceAdaptor; use vortex_error::VortexResult; use crate::SparseArray; +use crate::SparseArrayExt; use crate::SparseVTable; pub(crate) static RULES: ParentRuleSet = ParentRuleSet::new(&[ @@ -23,7 +24,7 @@ impl NotReduce for SparseVTable { let inverted_fill = array.fill_scalar().as_bool().invert().into_scalar(); let inverted_patches = array.patches().clone().map_values(|values| values.not())?; Ok(Some( - SparseArray::try_new_from_patches(inverted_patches, inverted_fill)?.into_array(), + SparseVTable::try_new_from_patches(inverted_patches, inverted_fill)?.into_array(), )) } } diff --git a/encodings/sparse/src/slice.rs b/encodings/sparse/src/slice.rs index adf448a2891..14c1a9a8c14 100644 --- a/encodings/sparse/src/slice.rs +++ b/encodings/sparse/src/slice.rs @@ -11,6 +11,7 @@ use vortex_error::VortexResult; use crate::ConstantArray; use crate::SparseArray; +use crate::SparseArrayExt; use crate::SparseVTable; impl SliceKernel for SparseVTable { diff --git a/encodings/zigzag/src/array.rs b/encodings/zigzag/src/array.rs index 5386cc5c454..89c5af09f14 100644 --- a/encodings/zigzag/src/array.rs +++ b/encodings/zigzag/src/array.rs @@ -136,7 +136,7 @@ impl VTable for ZigZagVTable { let encoded_type = DType::Primitive(ptype.to_unsigned(), dtype.nullability()); let encoded = children.get(0, &encoded_type, len)?; - ZigZagArray::try_new(encoded) + Self::try_new(encoded) } fn with_children(array: &mut Self::Array, children: Vec) -> VortexResult<()> { @@ -180,16 +180,17 @@ pub struct ZigZagArray { #[derive(Debug)] pub struct ZigZagVTable; +#[allow(clippy::new_ret_no_self)] 
impl ZigZagVTable { pub const ID: ArrayId = ArrayId::new_ref("vortex.zigzag"); -} -impl ZigZagArray { - pub fn new(encoded: ArrayRef) -> Self { + /// Build a new `ZigZagArray`, panicking on validation failure. + pub fn new(encoded: ArrayRef) -> ZigZagArray { Self::try_new(encoded).vortex_expect("ZigZagArray new") } - pub fn try_new(encoded: ArrayRef) -> VortexResult { + /// Build a new `ZigZagArray` from an unsigned integer encoded array. + pub fn try_new(encoded: ArrayRef) -> VortexResult { let encoded_dtype = encoded.dtype().clone(); if !encoded_dtype.is_unsigned_int() { vortex_bail!(MismatchedTypes: "unsigned int", encoded_dtype); @@ -199,17 +200,28 @@ impl ZigZagArray { .with_nullability(encoded_dtype.nullability()); let len = encoded.len(); - Ok(Self { + Ok(ZigZagArray { common: ArrayCommon::new(len, dtype), encoded, }) } +} + +/// Extension trait for [`ZigZagArray`] methods. +pub trait ZigZagArrayExt { + /// Returns the primitive type of the array. + fn ptype(&self) -> PType; + + /// Returns a reference to the encoded child array. 
+ fn encoded(&self) -> &ArrayRef; +} - pub fn ptype(&self) -> PType { +impl ZigZagArrayExt for ZigZagArray { + fn ptype(&self) -> PType { self.dtype().as_ptype() } - pub fn encoded(&self) -> &ArrayRef { + fn encoded(&self) -> &ArrayRef { &self.encoded } } diff --git a/encodings/zigzag/src/compress.rs b/encodings/zigzag/src/compress.rs index f53a04ff40c..2b5bf61a6fb 100644 --- a/encodings/zigzag/src/compress.rs +++ b/encodings/zigzag/src/compress.rs @@ -14,6 +14,7 @@ use vortex_error::vortex_panic; use zigzag::ZigZag as ExternalZigZag; use crate::ZigZagArray; +use crate::ZigZagVTable; pub fn zigzag_encode(parray: PrimitiveArray) -> VortexResult { let validity = parray.validity().clone(); @@ -27,7 +28,7 @@ pub fn zigzag_encode(parray: PrimitiveArray) -> VortexResult { parray.ptype() ), }; - ZigZagArray::try_new(encoded.into_array()) + ZigZagVTable::try_new(encoded.into_array()) } fn zigzag_encode_primitive( diff --git a/encodings/zigzag/src/compute/cast.rs b/encodings/zigzag/src/compute/cast.rs index c92aa6b60bf..e316e097332 100644 --- a/encodings/zigzag/src/compute/cast.rs +++ b/encodings/zigzag/src/compute/cast.rs @@ -9,6 +9,7 @@ use vortex_array::scalar_fn::fns::cast::CastReduce; use vortex_error::VortexResult; use crate::ZigZagArray; +use crate::ZigZagArrayExt; use crate::ZigZagVTable; impl CastReduce for ZigZagVTable { @@ -20,7 +21,7 @@ impl CastReduce for ZigZagVTable { let new_encoded_dtype = DType::Primitive(dtype.as_ptype().to_unsigned(), dtype.nullability()); let new_encoded = array.encoded().cast(new_encoded_dtype)?; - Ok(Some(ZigZagArray::try_new(new_encoded)?.into_array())) + Ok(Some(ZigZagVTable::try_new(new_encoded)?.into_array())) } } diff --git a/encodings/zigzag/src/compute/mod.rs b/encodings/zigzag/src/compute/mod.rs index e88451924c0..b985ce127d1 100644 --- a/encodings/zigzag/src/compute/mod.rs +++ b/encodings/zigzag/src/compute/mod.rs @@ -17,12 +17,13 @@ use vortex_error::VortexResult; use vortex_mask::Mask; use crate::ZigZagArray; +use 
crate::ZigZagArrayExt; use crate::ZigZagVTable; impl FilterReduce for ZigZagVTable { fn filter(array: &ZigZagArray, mask: &Mask) -> VortexResult> { let encoded = array.encoded().filter(mask.clone())?; - Ok(Some(ZigZagArray::try_new(encoded)?.into_array())) + Ok(Some(ZigZagVTable::try_new(encoded)?.into_array())) } } @@ -33,7 +34,7 @@ impl TakeExecute for ZigZagVTable { _ctx: &mut ExecutionCtx, ) -> VortexResult> { let encoded = array.encoded().take(indices.to_array())?; - Ok(Some(ZigZagArray::try_new(encoded)?.into_array())) + Ok(Some(ZigZagVTable::try_new(encoded)?.into_array())) } } @@ -44,7 +45,7 @@ impl MaskReduce for ZigZagVTable { EmptyOptions, [array.encoded().clone(), mask.clone()], )?; - Ok(Some(ZigZagArray::try_new(masked_encoded)?.into_array())) + Ok(Some(ZigZagVTable::try_new(masked_encoded)?.into_array())) } } diff --git a/encodings/zigzag/src/slice.rs b/encodings/zigzag/src/slice.rs index 74d31cb63b3..00117868af3 100644 --- a/encodings/zigzag/src/slice.rs +++ b/encodings/zigzag/src/slice.rs @@ -8,13 +8,13 @@ use vortex_array::IntoArray; use vortex_array::arrays::slice::SliceReduce; use vortex_error::VortexResult; -use crate::ZigZagArray; +use crate::ZigZagArrayExt; use crate::ZigZagVTable; impl SliceReduce for ZigZagVTable { fn slice(array: &Self::Array, range: Range) -> VortexResult> { Ok(Some( - ZigZagArray::new(array.encoded().slice(range)?).into_array(), + ZigZagVTable::new(array.encoded().slice(range)?).into_array(), )) } } diff --git a/encodings/zstd/benches/listview_rebuild.rs b/encodings/zstd/benches/listview_rebuild.rs index 68147912f17..93ebda0b2cf 100644 --- a/encodings/zstd/benches/listview_rebuild.rs +++ b/encodings/zstd/benches/listview_rebuild.rs @@ -10,13 +10,13 @@ use vortex_array::arrays::VarBinViewArray; use vortex_array::arrays::listview::ListViewRebuildMode; use vortex_array::validity::Validity; use vortex_buffer::Buffer; -use vortex_zstd::ZstdArray; +use vortex_zstd::ZstdVTable; #[divan::bench(sample_size = 1000)] fn 
rebuild_naive(bencher: Bencher) { let dudes = VarBinViewArray::from_iter_str(["Washington", "Adams", "Jefferson", "Madison"]) .into_array(); - let dudes = ZstdArray::from_array(dudes, 9, 1024).unwrap().into_array(); + let dudes = ZstdVTable::from_array(dudes, 9, 1024).unwrap().into_array(); let offsets = std::iter::repeat_n(0u32, 1024) .collect::>() diff --git a/encodings/zstd/public-api.lock b/encodings/zstd/public-api.lock index eeba254f97a..2a6d3bdc66c 100644 --- a/encodings/zstd/public-api.lock +++ b/encodings/zstd/public-api.lock @@ -2,25 +2,13 @@ pub mod vortex_zstd pub struct vortex_zstd::ZstdArray -impl vortex_zstd::ZstdArray - -pub fn vortex_zstd::ZstdArray::decompress(&self) -> vortex_error::VortexResult - -pub fn vortex_zstd::ZstdArray::from_array(array: vortex_array::array::ArrayRef, level: i32, values_per_frame: usize) -> vortex_error::VortexResult - -pub fn vortex_zstd::ZstdArray::from_canonical(canonical: &vortex_array::canonical::Canonical, level: i32, values_per_frame: usize) -> vortex_error::VortexResult> - -pub fn vortex_zstd::ZstdArray::from_primitive(parray: &vortex_array::arrays::primitive::array::PrimitiveArray, level: i32, values_per_frame: usize) -> vortex_error::VortexResult - -pub fn vortex_zstd::ZstdArray::from_primitive_without_dict(parray: &vortex_array::arrays::primitive::array::PrimitiveArray, level: i32, values_per_frame: usize) -> vortex_error::VortexResult - -pub fn vortex_zstd::ZstdArray::from_var_bin_view(vbv: &vortex_array::arrays::varbinview::array::VarBinViewArray, level: i32, values_per_frame: usize) -> vortex_error::VortexResult +pub trait vortex_zstd::ZstdArrayExt: core::marker::Sized -pub fn vortex_zstd::ZstdArray::from_var_bin_view_without_dict(vbv: &vortex_array::arrays::varbinview::array::VarBinViewArray, level: i32, values_per_frame: usize) -> vortex_error::VortexResult +pub fn vortex_zstd::ZstdArrayExt::decompress(&self) -> vortex_error::VortexResult -pub fn vortex_zstd::ZstdArray::into_parts(self) -> 
vortex_zstd::ZstdArrayParts +pub fn vortex_zstd::ZstdArrayExt::into_parts(self) -> vortex_zstd::ZstdArrayParts -pub fn vortex_zstd::ZstdArray::new(dictionary: core::option::Option, frames: alloc::vec::Vec, dtype: vortex_array::dtype::DType, metadata: vortex_zstd::ZstdMetadata, n_rows: usize, validity: vortex_array::validity::Validity) -> Self +impl vortex_zstd::ZstdArrayExt for vortex_zstd::ZstdArray impl vortex_zstd::ZstdArray @@ -160,6 +148,20 @@ impl vortex_zstd::ZstdVTable pub const vortex_zstd::ZstdVTable::ID: vortex_array::vtable::dyn_::ArrayId +pub fn vortex_zstd::ZstdVTable::from_array(array: vortex_array::array::ArrayRef, level: i32, values_per_frame: usize) -> vortex_error::VortexResult + +pub fn vortex_zstd::ZstdVTable::from_canonical(canonical: &vortex_array::canonical::Canonical, level: i32, values_per_frame: usize) -> vortex_error::VortexResult> + +pub fn vortex_zstd::ZstdVTable::from_primitive(parray: &vortex_array::arrays::primitive::array::PrimitiveArray, level: i32, values_per_frame: usize) -> vortex_error::VortexResult + +pub fn vortex_zstd::ZstdVTable::from_primitive_without_dict(parray: &vortex_array::arrays::primitive::array::PrimitiveArray, level: i32, values_per_frame: usize) -> vortex_error::VortexResult + +pub fn vortex_zstd::ZstdVTable::from_var_bin_view(vbv: &vortex_array::arrays::varbinview::array::VarBinViewArray, level: i32, values_per_frame: usize) -> vortex_error::VortexResult + +pub fn vortex_zstd::ZstdVTable::from_var_bin_view_without_dict(vbv: &vortex_array::arrays::varbinview::array::VarBinViewArray, level: i32, values_per_frame: usize) -> vortex_error::VortexResult + +pub fn vortex_zstd::ZstdVTable::new(dictionary: core::option::Option, frames: alloc::vec::Vec, dtype: vortex_array::dtype::DType, metadata: vortex_zstd::ZstdMetadata, n_rows: usize, validity: vortex_array::validity::Validity) -> vortex_zstd::ZstdArray + impl core::fmt::Debug for vortex_zstd::ZstdVTable pub fn vortex_zstd::ZstdVTable::fmt(&self, f: &mut 
core::fmt::Formatter<'_>) -> core::fmt::Result diff --git a/encodings/zstd/src/array.rs b/encodings/zstd/src/array.rs index db16ad9a56a..a8d3df3ff3e 100644 --- a/encodings/zstd/src/array.rs +++ b/encodings/zstd/src/array.rs @@ -245,7 +245,7 @@ impl VTable for ZstdVTable { ) }; - Ok(ZstdArray::new( + Ok(ZstdVTable::new( dictionary_buffer, compressed_buffers, dtype.clone(), @@ -289,6 +289,129 @@ pub struct ZstdVTable; impl ZstdVTable { pub const ID: ArrayId = ArrayId::new_ref("vortex.zstd"); + + /// Creates a new ZstdArray from its constituent parts. + #[allow(clippy::new_ret_no_self)] + pub fn new( + dictionary: Option, + frames: Vec, + dtype: DType, + metadata: ZstdMetadata, + n_rows: usize, + validity: Validity, + ) -> ZstdArray { + ZstdArray { + dictionary, + frames, + metadata, + common: ArrayCommon::new(n_rows, dtype), + unsliced_validity: validity, + unsliced_n_rows: n_rows, + slice_start: 0, + slice_stop: n_rows, + } + } + + /// Creates a ZstdArray from a primitive array. + /// + /// # Arguments + /// * `parray` - The primitive array to compress + /// * `level` - Zstd compression level (0 = default, negative = fast, positive = better compression) + /// * `values_per_frame` - Number of values per frame (0 = single frame) + pub fn from_primitive( + parray: &PrimitiveArray, + level: i32, + values_per_frame: usize, + ) -> VortexResult { + from_primitive_impl(parray, level, values_per_frame, true) + } + + /// Creates a ZstdArray from a primitive array without using a dictionary. + /// + /// This is useful when the compressed data will be decompressed by systems + /// that don't support ZSTD dictionaries (e.g., nvCOMP on GPU). + /// + /// Note: Without a dictionary, each frame is compressed independently. + /// Dictionaries are trained from sample data from previously seen frames, + /// to improve compression ratio. 
+ /// + /// # Arguments + /// * `parray` - The primitive array to compress + /// * `level` - Zstd compression level (0 = default, negative = fast, positive = better compression) + /// * `values_per_frame` - Number of values per frame (0 = single frame) + pub fn from_primitive_without_dict( + parray: &PrimitiveArray, + level: i32, + values_per_frame: usize, + ) -> VortexResult { + from_primitive_impl(parray, level, values_per_frame, false) + } + + /// Creates a ZstdArray from a VarBinView array. + /// + /// # Arguments + /// * `vbv` - The VarBinView array to compress + /// * `level` - Zstd compression level (0 = default, negative = fast, positive = better compression) + /// * `values_per_frame` - Number of values per frame (0 = single frame) + pub fn from_var_bin_view( + vbv: &VarBinViewArray, + level: i32, + values_per_frame: usize, + ) -> VortexResult { + from_var_bin_view_impl(vbv, level, values_per_frame, true) + } + + /// Creates a ZstdArray from a VarBinView array without using a dictionary. + /// + /// This is useful when the compressed data will be decompressed by systems + /// that don't support ZSTD dictionaries (e.g., nvCOMP on GPU). + /// + /// Note: Without a dictionary, each frame is compressed independently. + /// Dictionaries are trained from sample data from previously seen frames, + /// to improve compression ratio. + /// + /// # Arguments + /// * `vbv` - The VarBinView array to compress + /// * `level` - Zstd compression level (0 = default, negative = fast, positive = better compression) + /// * `values_per_frame` - Number of values per frame (0 = single frame) + pub fn from_var_bin_view_without_dict( + vbv: &VarBinViewArray, + level: i32, + values_per_frame: usize, + ) -> VortexResult { + from_var_bin_view_impl(vbv, level, values_per_frame, false) + } + + /// Creates a ZstdArray from a canonical array representation. 
+ pub fn from_canonical( + canonical: &Canonical, + level: i32, + values_per_frame: usize, + ) -> VortexResult> { + match canonical { + Canonical::Primitive(parray) => Ok(Some(ZstdVTable::from_primitive( + parray, + level, + values_per_frame, + )?)), + Canonical::VarBinView(vbv) => Ok(Some(ZstdVTable::from_var_bin_view( + vbv, + level, + values_per_frame, + )?)), + _ => Ok(None), + } + } + + /// Creates a ZstdArray from any array by first converting it to canonical form. + pub fn from_array( + array: ArrayRef, + level: i32, + values_per_frame: usize, + ) -> VortexResult { + Self::from_canonical(&array.to_canonical()?, level, values_per_frame)? + .ok_or_else(|| vortex_err!("Zstd can only encode Primitive and VarBinView arrays")) + } } #[derive(Clone, Debug)] @@ -397,51 +520,29 @@ pub fn reconstruct_views(buffer: &ByteBuffer) -> Buffer { res.freeze() } -impl ZstdArray { - pub fn new( - dictionary: Option, - frames: Vec, - dtype: DType, - metadata: ZstdMetadata, - n_rows: usize, - validity: Validity, - ) -> Self { - Self { - dictionary, - frames, - metadata, - common: ArrayCommon::new(n_rows, dtype), - unsliced_validity: validity, - unsliced_n_rows: n_rows, - slice_start: 0, - slice_stop: n_rows, - } - } - - fn compress_values( - value_bytes: &ByteBuffer, - frame_byte_starts: &[usize], - level: i32, - values_per_frame: usize, - n_values: usize, - use_dictionary: bool, - ) -> VortexResult { - let n_frames = frame_byte_starts.len(); - - // Would-be sample sizes if we end up applying zstd dictionary - let mut sample_sizes = Vec::with_capacity(n_frames); - for i in 0..n_frames { - let frame_byte_end = frame_byte_starts - .get(i + 1) - .copied() - .unwrap_or(value_bytes.len()); - sample_sizes.push(frame_byte_end - frame_byte_starts[i]); - } - debug_assert_eq!(sample_sizes.iter().sum::(), value_bytes.len()); - - let (dictionary, mut compressor) = if !use_dictionary - || sample_sizes.len() < MIN_SAMPLES_FOR_DICTIONARY - { +fn compress_values( + value_bytes: &ByteBuffer, + 
frame_byte_starts: &[usize], + level: i32, + values_per_frame: usize, + n_values: usize, + use_dictionary: bool, +) -> VortexResult { + let n_frames = frame_byte_starts.len(); + + // Would-be sample sizes if we end up applying zstd dictionary + let mut sample_sizes = Vec::with_capacity(n_frames); + for i in 0..n_frames { + let frame_byte_end = frame_byte_starts + .get(i + 1) + .copied() + .unwrap_or(value_bytes.len()); + sample_sizes.push(frame_byte_end - frame_byte_starts[i]); + } + debug_assert_eq!(sample_sizes.iter().sum::(), value_bytes.len()); + + let (dictionary, mut compressor) = + if !use_dictionary || sample_sizes.len() < MIN_SAMPLES_FOR_DICTIONARY { // no dictionary (None, zstd::bulk::Compressor::new(level)?) } else { @@ -454,252 +555,167 @@ impl ZstdArray { (Some(ByteBuffer::from(dict)), compressor) }; - let mut frame_metas = vec![]; - let mut frames = vec![]; - for i in 0..n_frames { - let frame_byte_end = frame_byte_starts - .get(i + 1) - .copied() - .unwrap_or(value_bytes.len()); - - let uncompressed = &value_bytes.slice(frame_byte_starts[i]..frame_byte_end); - let compressed = compressor - .compress(uncompressed) - .map_err(|err| VortexError::from(err).with_context("while compressing"))?; - frame_metas.push(ZstdFrameMetadata { - uncompressed_size: uncompressed.len() as u64, - n_values: values_per_frame.min(n_values - i * values_per_frame) as u64, - }); - frames.push(ByteBuffer::from(compressed)); - } - - Ok(Frames { - dictionary, - frames, - frame_metas, - }) - } - - /// Creates a ZstdArray from a primitive array. 
- /// - /// # Arguments - /// * `parray` - The primitive array to compress - /// * `level` - Zstd compression level (0 = default, negative = fast, positive = better compression) - /// * `values_per_frame` - Number of values per frame (0 = single frame) - pub fn from_primitive( - parray: &PrimitiveArray, - level: i32, - values_per_frame: usize, - ) -> VortexResult { - Self::from_primitive_impl(parray, level, values_per_frame, true) - } - - /// Creates a ZstdArray from a primitive array without using a dictionary. - /// - /// This is useful when the compressed data will be decompressed by systems - /// that don't support ZSTD dictionaries (e.g., nvCOMP on GPU). - /// - /// Note: Without a dictionary, each frame is compressed independently. - /// Dictionaries are trained from sample data from previously seen frames, - /// to improve compression ratio. - /// - /// # Arguments - /// * `parray` - The primitive array to compress - /// * `level` - Zstd compression level (0 = default, negative = fast, positive = better compression) - /// * `values_per_frame` - Number of values per frame (0 = single frame) - pub fn from_primitive_without_dict( - parray: &PrimitiveArray, - level: i32, - values_per_frame: usize, - ) -> VortexResult { - Self::from_primitive_impl(parray, level, values_per_frame, false) - } - - fn from_primitive_impl( - parray: &PrimitiveArray, - level: i32, - values_per_frame: usize, - use_dictionary: bool, - ) -> VortexResult { - let dtype = parray.dtype().clone(); - let byte_width = parray.ptype().byte_width(); - - // We compress only the valid elements. - let values = collect_valid_primitive(parray)?; - let n_values = values.len(); - let values_per_frame = if values_per_frame > 0 { - values_per_frame - } else { - n_values - }; - - let value_bytes = values.buffer_handle().try_to_host_sync()?; - // Align frames to buffer alignment. This is necessary for overaligned buffers. 
- let alignment = *value_bytes.alignment(); - let step_width = (values_per_frame * byte_width).div_ceil(alignment) * alignment; - - let frame_byte_starts = (0..n_values * byte_width) - .step_by(step_width) - .collect::>(); - let Frames { - dictionary, - frames, - frame_metas, - } = Self::compress_values( - &value_bytes, - &frame_byte_starts, - level, - values_per_frame, - n_values, - use_dictionary, - )?; - - let metadata = ZstdMetadata { - dictionary_size: dictionary - .as_ref() - .map_or(0, |dict| dict.len()) - .try_into()?, - frames: frame_metas, - }; + let mut frame_metas = vec![]; + let mut frames = vec![]; + for i in 0..n_frames { + let frame_byte_end = frame_byte_starts + .get(i + 1) + .copied() + .unwrap_or(value_bytes.len()); + + let uncompressed = &value_bytes.slice(frame_byte_starts[i]..frame_byte_end); + let compressed = compressor + .compress(uncompressed) + .map_err(|err| VortexError::from(err).with_context("while compressing"))?; + frame_metas.push(ZstdFrameMetadata { + uncompressed_size: uncompressed.len() as u64, + n_values: values_per_frame.min(n_values - i * values_per_frame) as u64, + }); + frames.push(ByteBuffer::from(compressed)); + } + + Ok(Frames { + dictionary, + frames, + frame_metas, + }) +} - Ok(ZstdArray::new( - dictionary, - frames, - dtype, - metadata, - parray.len(), - parray.validity().clone(), - )) - } +fn from_primitive_impl( + parray: &PrimitiveArray, + level: i32, + values_per_frame: usize, + use_dictionary: bool, +) -> VortexResult { + let dtype = parray.dtype().clone(); + let byte_width = parray.ptype().byte_width(); + + // We compress only the valid elements. + let values = collect_valid_primitive(parray)?; + let n_values = values.len(); + let values_per_frame = if values_per_frame > 0 { + values_per_frame + } else { + n_values + }; - /// Creates a ZstdArray from a VarBinView array. 
- /// - /// # Arguments - /// * `vbv` - The VarBinView array to compress - /// * `level` - Zstd compression level (0 = default, negative = fast, positive = better compression) - /// * `values_per_frame` - Number of values per frame (0 = single frame) - pub fn from_var_bin_view( - vbv: &VarBinViewArray, - level: i32, - values_per_frame: usize, - ) -> VortexResult { - Self::from_var_bin_view_impl(vbv, level, values_per_frame, true) - } + let value_bytes = values.buffer_handle().try_to_host_sync()?; + // Align frames to buffer alignment. This is necessary for overaligned buffers. + let alignment = *value_bytes.alignment(); + let step_width = (values_per_frame * byte_width).div_ceil(alignment) * alignment; + + let frame_byte_starts = (0..n_values * byte_width) + .step_by(step_width) + .collect::>(); + let Frames { + dictionary, + frames, + frame_metas, + } = compress_values( + &value_bytes, + &frame_byte_starts, + level, + values_per_frame, + n_values, + use_dictionary, + )?; + + let metadata = ZstdMetadata { + dictionary_size: dictionary + .as_ref() + .map_or(0, |dict| dict.len()) + .try_into()?, + frames: frame_metas, + }; - /// Creates a ZstdArray from a VarBinView array without using a dictionary. - /// - /// This is useful when the compressed data will be decompressed by systems - /// that don't support ZSTD dictionaries (e.g., nvCOMP on GPU). - /// - /// Note: Without a dictionary, each frame is compressed independently. - /// Dictionaries are trained from sample data from previously seen frames, - /// to improve compression ratio. 
- /// - /// # Arguments - /// * `vbv` - The VarBinView array to compress - /// * `level` - Zstd compression level (0 = default, negative = fast, positive = better compression) - /// * `values_per_frame` - Number of values per frame (0 = single frame) - pub fn from_var_bin_view_without_dict( - vbv: &VarBinViewArray, - level: i32, - values_per_frame: usize, - ) -> VortexResult { - Self::from_var_bin_view_impl(vbv, level, values_per_frame, false) - } + Ok(ZstdVTable::new( + dictionary, + frames, + dtype, + metadata, + parray.len(), + parray.validity().clone(), + )) +} - fn from_var_bin_view_impl( - vbv: &VarBinViewArray, - level: i32, - values_per_frame: usize, - use_dictionary: bool, - ) -> VortexResult { - // Approach for strings: we prefix each string with its length as a u32. - // This is the same as what Parquet does. In some cases it may be better - // to separate the binary data and lengths as two separate streams, but - // this approach is simpler and can be best in cases when there is - // mutual information between strings and their lengths. - let dtype = vbv.dtype().clone(); - - // We compress only the valid elements. - let (value_bytes, value_byte_indices) = collect_valid_vbv(vbv)?; - let n_values = value_byte_indices.len(); - let values_per_frame = if values_per_frame > 0 { - values_per_frame - } else { - n_values - }; +fn from_var_bin_view_impl( + vbv: &VarBinViewArray, + level: i32, + values_per_frame: usize, + use_dictionary: bool, +) -> VortexResult { + // Approach for strings: we prefix each string with its length as a u32. + // This is the same as what Parquet does. In some cases it may be better + // to separate the binary data and lengths as two separate streams, but + // this approach is simpler and can be best in cases when there is + // mutual information between strings and their lengths. + let dtype = vbv.dtype().clone(); + + // We compress only the valid elements. 
+ let (value_bytes, value_byte_indices) = collect_valid_vbv(vbv)?; + let n_values = value_byte_indices.len(); + let values_per_frame = if values_per_frame > 0 { + values_per_frame + } else { + n_values + }; - let frame_byte_starts = (0..n_values) - .step_by(values_per_frame) - .map(|i| value_byte_indices[i]) - .collect::>(); - let Frames { - dictionary, - frames, - frame_metas, - } = Self::compress_values( - &value_bytes, - &frame_byte_starts, - level, - values_per_frame, - n_values, - use_dictionary, - )?; - - let metadata = ZstdMetadata { - dictionary_size: dictionary - .as_ref() - .map_or(0, |dict| dict.len()) - .try_into()?, - frames: frame_metas, - }; - Ok(ZstdArray::new( - dictionary, - frames, - dtype, - metadata, - vbv.len(), - vbv.validity().clone(), - )) - } + let frame_byte_starts = (0..n_values) + .step_by(values_per_frame) + .map(|i| value_byte_indices[i]) + .collect::>(); + let Frames { + dictionary, + frames, + frame_metas, + } = compress_values( + &value_bytes, + &frame_byte_starts, + level, + values_per_frame, + n_values, + use_dictionary, + )?; + + let metadata = ZstdMetadata { + dictionary_size: dictionary + .as_ref() + .map_or(0, |dict| dict.len()) + .try_into()?, + frames: frame_metas, + }; + Ok(ZstdVTable::new( + dictionary, + frames, + dtype, + metadata, + vbv.len(), + vbv.validity().clone(), + )) +} - pub fn from_canonical( - canonical: &Canonical, - level: i32, - values_per_frame: usize, - ) -> VortexResult> { - match canonical { - Canonical::Primitive(parray) => Ok(Some(ZstdArray::from_primitive( - parray, - level, - values_per_frame, - )?)), - Canonical::VarBinView(vbv) => Ok(Some(ZstdArray::from_var_bin_view( - vbv, - level, - values_per_frame, - )?)), - _ => Ok(None), - } +fn byte_width(array: &ZstdArray) -> usize { + if array.common.dtype().is_primitive() { + array.common.dtype().as_ptype().byte_width() + } else { + 1 } +} - pub fn from_array(array: ArrayRef, level: i32, values_per_frame: usize) -> VortexResult { - 
Self::from_canonical(&array.to_canonical()?, level, values_per_frame)? - .ok_or_else(|| vortex_err!("Zstd can only encode Primitive and VarBinView arrays")) - } +/// Extension trait for [`ZstdArray`] instance methods. +pub trait ZstdArrayExt: Sized { + /// Decompresses the array, returning the uncompressed array. + fn decompress(&self) -> VortexResult; - fn byte_width(&self) -> usize { - if self.common.dtype().is_primitive() { - self.common.dtype().as_ptype().byte_width() - } else { - 1 - } - } + /// Consumes the array and returns its parts. + fn into_parts(self) -> ZstdArrayParts; +} - pub fn decompress(&self) -> VortexResult { +impl ZstdArrayExt for ZstdArray { + fn decompress(&self) -> VortexResult { // To start, we figure out which frames we need to decompress, and with // what row offset into the first such frame. - let byte_width = self.byte_width(); + let bw = byte_width(self); let slice_n_rows = self.slice_stop - self.slice_start; let slice_value_indices = self .unsliced_validity @@ -722,7 +738,7 @@ impl ZstdArray { .vortex_expect("Uncompressed size must fit in usize"); let frame_n_values = if frame_meta.n_values == 0 { // possibly older primitive-only metadata that just didn't store this - frame_uncompressed_size / byte_width + frame_uncompressed_size / bw } else { usize::try_from(frame_meta.n_values).vortex_expect("frame size must fit usize") }; @@ -746,7 +762,7 @@ impl ZstdArray { }; let mut decompressed = ByteBufferMut::with_capacity_aligned( uncompressed_size_to_decompress, - Alignment::new(byte_width), + Alignment::new(bw), ); unsafe { // safety: we immediately fill all bytes in the following loop, @@ -799,8 +815,8 @@ impl ZstdArray { match dtype { DType::Primitive(..) 
=> { let slice_values_buffer = decompressed.slice( - (slice_value_idx_start - n_skipped_values) * byte_width - ..(slice_value_idx_stop - n_skipped_values) * byte_width, + (slice_value_idx_start - n_skipped_values) * bw + ..(slice_value_idx_stop - n_skipped_values) * bw, ); let primitive = PrimitiveArray::from_values_byte_buffer( slice_values_buffer, @@ -869,6 +885,21 @@ impl ZstdArray { } } + fn into_parts(self) -> ZstdArrayParts { + ZstdArrayParts { + dictionary: self.dictionary, + frames: self.frames, + metadata: self.metadata, + dtype: self.common.into_dtype(), + validity: self.unsliced_validity, + n_rows: self.unsliced_n_rows, + slice_start: self.slice_start, + slice_stop: self.slice_stop, + } + } +} + +impl ZstdArray { pub(crate) fn _slice(&self, start: usize, stop: usize) -> ZstdArray { let new_start = self.slice_start + start; let new_stop = self.slice_start + stop; @@ -895,20 +926,6 @@ impl ZstdArray { } } - /// Consumes the array and returns its parts. - pub fn into_parts(self) -> ZstdArrayParts { - ZstdArrayParts { - dictionary: self.dictionary, - frames: self.frames, - metadata: self.metadata, - dtype: self.common.into_dtype(), - validity: self.unsliced_validity, - n_rows: self.unsliced_n_rows, - slice_start: self.slice_start, - slice_stop: self.slice_stop, - } - } - pub(crate) fn dtype(&self) -> &DType { self.common.dtype() } diff --git a/encodings/zstd/src/compute/cast.rs b/encodings/zstd/src/compute/cast.rs index 932ffebed5d..7ce0d581e3e 100644 --- a/encodings/zstd/src/compute/cast.rs +++ b/encodings/zstd/src/compute/cast.rs @@ -32,7 +32,7 @@ impl CastReduce for ZstdVTable { (Nullability::NonNullable, Nullability::Nullable) => { // nonnull => null, trivial cast by altering the validity Ok(Some( - ZstdArray::new( + ZstdVTable::new( array.dictionary.clone(), array.frames.clone(), dtype.clone(), @@ -58,7 +58,7 @@ impl CastReduce for ZstdVTable { // If there are no nulls, the cast is trivial Ok(Some( - ZstdArray::new( + ZstdVTable::new( 
array.dictionary.clone(), array.frames.clone(), dtype.clone(), @@ -88,12 +88,12 @@ mod tests { use vortex_array::validity::Validity; use vortex_buffer::buffer; - use crate::ZstdArray; + use crate::ZstdVTable; #[test] fn test_cast_zstd_i32_to_i64() { let values = PrimitiveArray::from_iter([1i32, 2, 3, 4, 5]); - let zstd = ZstdArray::from_primitive(&values, 0, 0).unwrap(); + let zstd = ZstdVTable::from_primitive(&values, 0, 0).unwrap(); let casted = zstd .into_array() @@ -111,7 +111,7 @@ mod tests { #[test] fn test_cast_zstd_nullability_change() { let values = PrimitiveArray::from_iter([10u32, 20, 30, 40]); - let zstd = ZstdArray::from_primitive(&values, 0, 0).unwrap(); + let zstd = ZstdVTable::from_primitive(&values, 0, 0).unwrap(); let casted = zstd .into_array() @@ -129,7 +129,7 @@ mod tests { buffer![10u32, 20, 30, 40, 50, 60], Validity::from_iter([true, true, true, true, true, true]), ); - let zstd = ZstdArray::from_primitive(&values, 0, 128).unwrap(); + let zstd = ZstdVTable::from_primitive(&values, 0, 128).unwrap(); let sliced = zstd.slice(1..5).unwrap(); let casted = sliced .cast(DType::Primitive(PType::U32, Nullability::NonNullable)) @@ -153,7 +153,7 @@ mod tests { Some(50), Some(60), ]); - let zstd = ZstdArray::from_primitive(&values, 0, 128).unwrap(); + let zstd = ZstdVTable::from_primitive(&values, 0, 128).unwrap(); let sliced = zstd.slice(1..5).unwrap(); let casted = sliced .cast(DType::Primitive(PType::U32, Nullability::NonNullable)) @@ -185,7 +185,7 @@ mod tests { Validity::NonNullable, ))] fn test_cast_zstd_conformance(#[case] values: PrimitiveArray) { - let zstd = ZstdArray::from_primitive(&values, 0, 0).unwrap(); + let zstd = ZstdVTable::from_primitive(&values, 0, 0).unwrap(); test_cast_conformance(&zstd.into_array()); } } diff --git a/encodings/zstd/src/compute/mod.rs b/encodings/zstd/src/compute/mod.rs index 73e941b82c3..6d51a9e6fb7 100644 --- a/encodings/zstd/src/compute/mod.rs +++ b/encodings/zstd/src/compute/mod.rs @@ -12,26 +12,27 @@ mod tests 
{ use vortex_buffer::buffer; use crate::ZstdArray; + use crate::ZstdVTable; fn zstd_i32() -> ZstdArray { let values = PrimitiveArray::from_iter([100i32, 200, 300, 400, 500]); - ZstdArray::from_primitive(&values, 0, 0).unwrap() + ZstdVTable::from_primitive(&values, 0, 0).unwrap() } fn zstd_f64() -> ZstdArray { let values = PrimitiveArray::from_iter([1.1f64, 2.2, 3.3, 4.4, 5.5]); - ZstdArray::from_primitive(&values, 0, 0).unwrap() + ZstdVTable::from_primitive(&values, 0, 0).unwrap() } fn zstd_u32() -> ZstdArray { let values = PrimitiveArray::from_iter([10u32, 20, 30, 40, 50]); - ZstdArray::from_primitive(&values, 0, 0).unwrap() + ZstdVTable::from_primitive(&values, 0, 0).unwrap() } fn zstd_nullable_i64() -> ZstdArray { let values = PrimitiveArray::from_option_iter([Some(1000i64), None, Some(3000), Some(4000), None]); - ZstdArray::from_primitive(&values, 0, 0).unwrap() + ZstdVTable::from_primitive(&values, 0, 0).unwrap() } fn zstd_single() -> ZstdArray { @@ -39,7 +40,7 @@ mod tests { buffer![42i64], vortex_array::validity::Validity::NonNullable, ); - ZstdArray::from_primitive(&values, 0, 0).unwrap() + ZstdVTable::from_primitive(&values, 0, 0).unwrap() } fn zstd_large() -> ZstdArray { @@ -47,7 +48,7 @@ mod tests { buffer![0u32..1000], vortex_array::validity::Validity::NonNullable, ); - ZstdArray::from_primitive(&values, 3, 0).unwrap() + ZstdVTable::from_primitive(&values, 3, 0).unwrap() } fn zstd_all_same() -> ZstdArray { @@ -55,12 +56,12 @@ mod tests { buffer![42i32; 100], vortex_array::validity::Validity::NonNullable, ); - ZstdArray::from_primitive(&values, 0, 0).unwrap() + ZstdVTable::from_primitive(&values, 0, 0).unwrap() } fn zstd_negative() -> ZstdArray { let values = PrimitiveArray::from_iter([-100i32, -50, 0, 50, 100]); - ZstdArray::from_primitive(&values, 0, 0).unwrap() + ZstdVTable::from_primitive(&values, 0, 0).unwrap() } #[rstest] diff --git a/encodings/zstd/src/test.rs b/encodings/zstd/src/test.rs index 7d3b4fe1d1d..53df9076177 100644 --- 
a/encodings/zstd/src/test.rs +++ b/encodings/zstd/src/test.rs @@ -17,14 +17,15 @@ use vortex_buffer::Alignment; use vortex_buffer::Buffer; use vortex_mask::Mask; -use crate::ZstdArray; +use crate::ZstdArrayExt; +use crate::ZstdVTable; #[test] fn test_zstd_compress_decompress() { let data: Vec = (0..200).collect(); let array = PrimitiveArray::from_iter(data.clone()); - let compressed = ZstdArray::from_primitive(&array, 3, 0).unwrap(); + let compressed = ZstdVTable::from_primitive(&array, 3, 0).unwrap(); // this data should be compressible assert!(compressed.frames.len() < array.nbytes() as usize); assert!(compressed.dictionary.is_none()); @@ -52,7 +53,7 @@ fn test_zstd_empty() { Validity::NonNullable, ); - let compressed = ZstdArray::from_primitive(&array, 3, 100).unwrap(); + let compressed = ZstdVTable::from_primitive(&array, 3, 100).unwrap(); assert_arrays_eq!(compressed, PrimitiveArray::from_iter(data)); } @@ -68,7 +69,7 @@ fn test_zstd_with_validity_and_multi_frame() { Validity::Array(BoolArray::from_iter(validity).into_array()), ); - let compressed = ZstdArray::from_primitive(&array, 0, 30).unwrap(); + let compressed = ZstdVTable::from_primitive(&array, 0, 30).unwrap(); assert!(compressed.dictionary.is_none()); assert_nth_scalar!(compressed, 0, None::); assert_nth_scalar!(compressed, 3, 3); @@ -102,7 +103,7 @@ fn test_zstd_with_dict() { Validity::NonNullable, ); - let compressed = ZstdArray::from_primitive(&array, 0, 16).unwrap(); + let compressed = ZstdVTable::from_primitive(&array, 0, 16).unwrap(); assert!(compressed.dictionary.is_some()); assert_nth_scalar!(compressed, 0, 0); assert_nth_scalar!(compressed, 199, 199); @@ -123,7 +124,7 @@ fn test_validity_vtable() { (0..5).collect::>(), Validity::Array(BoolArray::from_iter(mask_bools.clone()).into_array()), ); - let compressed = ZstdArray::from_primitive(&array, 3, 0).unwrap(); + let compressed = ZstdVTable::from_primitive(&array, 3, 0).unwrap(); assert_eq!( compressed.validity_mask().unwrap(), 
Mask::from_iter(mask_bools) @@ -145,7 +146,7 @@ fn test_zstd_var_bin_view() { ]; let array = VarBinViewArray::from_iter(data, DType::Utf8(Nullability::Nullable)); - let compressed = ZstdArray::from_var_bin_view(&array, 0, 3).unwrap(); + let compressed = ZstdVTable::from_var_bin_view(&array, 0, 3).unwrap(); assert!(compressed.dictionary.is_none()); assert_nth_scalar!(compressed, 0, "foo"); assert_nth_scalar!(compressed, 1, "bar"); @@ -170,7 +171,7 @@ fn test_zstd_decompress_var_bin_view() { ]; let array = VarBinViewArray::from_iter(data, DType::Utf8(Nullability::Nullable)); - let compressed = ZstdArray::from_var_bin_view(&array, 0, 3).unwrap(); + let compressed = ZstdVTable::from_var_bin_view(&array, 0, 3).unwrap(); assert!(compressed.dictionary.is_none()); assert_nth_scalar!(compressed, 0, "foo"); assert_nth_scalar!(compressed, 1, "bar"); @@ -189,7 +190,7 @@ fn test_zstd_decompress_var_bin_view() { fn test_sliced_array_children() { let data: Vec> = (0..10).map(|v| (v != 5).then_some(v)).collect(); let compressed = - ZstdArray::from_primitive(&PrimitiveArray::from_option_iter(data), 0, 100).unwrap(); + ZstdVTable::from_primitive(&PrimitiveArray::from_option_iter(data), 0, 100).unwrap(); let sliced = compressed.slice(0..4).unwrap(); sliced.children(); } @@ -202,7 +203,7 @@ fn test_zstd_frame_start_buffer_alignment() { let aligned_buffer = Buffer::copy_from_aligned(&data, Alignment::new(8)); // u8 array now has a 8-byte alignment. 
let array = PrimitiveArray::new(aligned_buffer, Validity::NonNullable); - let compressed = ZstdArray::from_primitive(&array, 0, 1); + let compressed = ZstdVTable::from_primitive(&array, 0, 1); assert!(compressed.is_ok()); } diff --git a/encodings/zstd/src/zstd_buffers.rs b/encodings/zstd/src/zstd_buffers.rs index f9baaf4da48..a6fccaf7937 100644 --- a/encodings/zstd/src/zstd_buffers.rs +++ b/encodings/zstd/src/zstd_buffers.rs @@ -41,6 +41,49 @@ pub struct ZstdBuffersVTable; impl ZstdBuffersVTable { pub const ID: ArrayId = ArrayId::new_ref("vortex.zstd_buffers"); + + /// Compresses the buffers of the given array using ZSTD. + /// + /// Each buffer of the input array is independently ZSTD-compressed. The children + /// and metadata of the input array are preserved as-is. + pub fn compress(array: &ArrayRef, level: i32) -> VortexResult { + let encoding_id = array.encoding_id(); + let metadata = array + .metadata()? + .ok_or_else(|| vortex_err!("Array does not support serialization"))?; + let buffer_handles = array.buffer_handles(); + let children = array.children(); + + let mut compressed_buffers = Vec::with_capacity(buffer_handles.len()); + let mut uncompressed_sizes = Vec::with_capacity(buffer_handles.len()); + let mut buffer_alignments = Vec::with_capacity(buffer_handles.len()); + + let mut compressor = zstd::bulk::Compressor::new(level)?; + // Compression is currently CPU-only, so we gather all buffers on the host. 
+ for handle in &buffer_handles { + buffer_alignments.push(u32::from(handle.alignment())); + let host_buf = handle.clone().try_to_host_sync()?; + uncompressed_sizes.push(host_buf.len() as u64); + let compressed = compressor.compress(&host_buf)?; + compressed_buffers.push(BufferHandle::new_host(ByteBuffer::from(compressed))); + } + + let compressed = ZstdBuffersArray { + inner_encoding_id: encoding_id, + inner_metadata: metadata, + compressed_buffers, + uncompressed_sizes, + buffer_alignments, + children, + common: ArrayCommon::new(array.len(), array.dtype().clone()), + }; + compressed + .common + .stats() + .to_ref(compressed.as_ref()) + .inherit_from(array.statistics()); + Ok(compressed) + } } /// An encoding that ZSTD-compresses the buffers of any wrapped array. @@ -118,123 +161,24 @@ impl ZstdBuffersDecodePlan { } } -impl ZstdBuffersArray { - fn validate(&self) -> VortexResult<()> { - vortex_ensure_eq!( - self.compressed_buffers.len(), - self.uncompressed_sizes.len(), - "zstd_buffers metadata mismatch: {} compressed buffers vs {} sizes", - self.compressed_buffers.len(), - self.uncompressed_sizes.len() - ); - vortex_ensure_eq!( - self.compressed_buffers.len(), - self.buffer_alignments.len(), - "zstd_buffers metadata mismatch: {} compressed buffers vs {} alignments", - self.compressed_buffers.len(), - self.buffer_alignments.len() - ); - Ok(()) - } - - /// Compresses the buffers of the given array using ZSTD. +/// Extension trait providing instance methods for [`ZstdBuffersArray`]. +pub trait ZstdBuffersArrayExt { + /// Build the inner array from the given (decompressed) buffer handles. /// - /// Each buffer of the input array is independently ZSTD-compressed. The children - /// and metadata of the input array are preserved as-is. - pub fn compress(array: &ArrayRef, level: i32) -> VortexResult { - let encoding_id = array.encoding_id(); - let metadata = array - .metadata()? 
- .ok_or_else(|| vortex_err!("Array does not support serialization"))?; - let buffer_handles = array.buffer_handles(); - let children = array.children(); - - let mut compressed_buffers = Vec::with_capacity(buffer_handles.len()); - let mut uncompressed_sizes = Vec::with_capacity(buffer_handles.len()); - let mut buffer_alignments = Vec::with_capacity(buffer_handles.len()); - - let mut compressor = zstd::bulk::Compressor::new(level)?; - // Compression is currently CPU-only, so we gather all buffers on the host. - for handle in &buffer_handles { - buffer_alignments.push(u32::from(handle.alignment())); - let host_buf = handle.clone().try_to_host_sync()?; - uncompressed_sizes.push(host_buf.len() as u64); - let compressed = compressor.compress(&host_buf)?; - compressed_buffers.push(BufferHandle::new_host(ByteBuffer::from(compressed))); - } - - let compressed = Self { - inner_encoding_id: encoding_id, - inner_metadata: metadata, - compressed_buffers, - uncompressed_sizes, - buffer_alignments, - children, - common: ArrayCommon::new(array.len(), array.dtype().clone()), - }; - compressed - .common - .stats() - .to_ref(compressed.as_ref()) - .inherit_from(array.statistics()); - Ok(compressed) - } - - fn decompress_buffers(&self) -> VortexResult> { - // CPU decode path: zstd::bulk works on host bytes, so compressed buffers are - // materialized on the host via `try_to_host_sync`. 
- let mut decompressor = zstd::bulk::Decompressor::new()?; - let mut result = Vec::with_capacity(self.compressed_buffers.len()); - for (i, (buf, &uncompressed_size)) in self - .compressed_buffers - .iter() - .zip(&self.uncompressed_sizes) - .enumerate() - { - let size = usize::try_from(uncompressed_size)?; - let alignment = self.buffer_alignments.get(i).copied().unwrap_or(1); - - let aligned = Alignment::try_from(alignment)?; - let mut output = ByteBufferMut::with_capacity_aligned(size, aligned); - let spare = output.spare_capacity_mut(); - - // This is currently guaranteed, but still good to check because - // of the unsafe calls below. - if spare.len() < size { - return Err(vortex_err!( - "Insufficient output capacity: expected at least {}, got {}", - size, - spare.len() - )); - } - // SAFETY: we only expose the first `size` bytes and mark them initialized via - // `set_len(size)` after zstd reports how many bytes were written. - let dst = - unsafe { std::slice::from_raw_parts_mut(spare.as_mut_ptr().cast::(), size) }; - let compressed = buf.clone().try_to_host_sync()?; - let written = decompressor.decompress_to_buffer(compressed.as_slice(), dst)?; - if written != size { - return Err(vortex_err!( - "Decompressed size mismatch: expected {}, got {}", - size, - written - )); - } - // SAFETY: zstd wrote exactly `size` initialized bytes into `dst`. - unsafe { output.set_len(size) }; - result.push(BufferHandle::new_host(output.freeze())); - } - Ok(result) - } + /// This is exposed to help non-CPU executors pass uncompressed buffer handles + /// to build the inner array. + fn build_inner( + &self, + buffer_handles: &[BufferHandle], + session: &VortexSession, + ) -> VortexResult; - fn decompress_and_build_inner(&self, session: &VortexSession) -> VortexResult { - let decompressed_buffers = self.decompress_buffers()?; - self.build_inner(&decompressed_buffers, session) - } + /// Create a [`ZstdBuffersDecodePlan`] describing how to decompress all buffers. 
+ fn decode_plan(&self) -> VortexResult; +} - // This is exposed to help non-CPU executors pass uncompressed buffer handles - // to build the inner array. - pub fn build_inner( +impl ZstdBuffersArrayExt for ZstdBuffersArray { + fn build_inner( &self, buffer_handles: &[BufferHandle], session: &VortexSession, @@ -256,10 +200,10 @@ impl ZstdBuffersArray { ) } - pub fn decode_plan(&self) -> VortexResult { + fn decode_plan(&self) -> VortexResult { // If invariants are somehow broken, device decompression could have UB, so ensure // they still hold. - self.validate()?; + validate(self)?; let output_sizes = self .uncompressed_sizes @@ -297,6 +241,78 @@ impl ZstdBuffersArray { } } +fn validate(array: &ZstdBuffersArray) -> VortexResult<()> { + vortex_ensure_eq!( + array.compressed_buffers.len(), + array.uncompressed_sizes.len(), + "zstd_buffers metadata mismatch: {} compressed buffers vs {} sizes", + array.compressed_buffers.len(), + array.uncompressed_sizes.len() + ); + vortex_ensure_eq!( + array.compressed_buffers.len(), + array.buffer_alignments.len(), + "zstd_buffers metadata mismatch: {} compressed buffers vs {} alignments", + array.compressed_buffers.len(), + array.buffer_alignments.len() + ); + Ok(()) +} + +fn decompress_buffers(array: &ZstdBuffersArray) -> VortexResult> { + // CPU decode path: zstd::bulk works on host bytes, so compressed buffers are + // materialized on the host via `try_to_host_sync`. 
+ let mut decompressor = zstd::bulk::Decompressor::new()?; + let mut result = Vec::with_capacity(array.compressed_buffers.len()); + for (i, (buf, &uncompressed_size)) in array + .compressed_buffers + .iter() + .zip(&array.uncompressed_sizes) + .enumerate() + { + let size = usize::try_from(uncompressed_size)?; + let alignment = array.buffer_alignments.get(i).copied().unwrap_or(1); + + let aligned = Alignment::try_from(alignment)?; + let mut output = ByteBufferMut::with_capacity_aligned(size, aligned); + let spare = output.spare_capacity_mut(); + + // This is currently guaranteed, but still good to check because + // of the unsafe calls below. + if spare.len() < size { + return Err(vortex_err!( + "Insufficient output capacity: expected at least {}, got {}", + size, + spare.len() + )); + } + // SAFETY: we only expose the first `size` bytes and mark them initialized via + // `set_len(size)` after zstd reports how many bytes were written. + let dst = unsafe { std::slice::from_raw_parts_mut(spare.as_mut_ptr().cast::(), size) }; + let compressed = buf.clone().try_to_host_sync()?; + let written = decompressor.decompress_to_buffer(compressed.as_slice(), dst)?; + if written != size { + return Err(vortex_err!( + "Decompressed size mismatch: expected {}, got {}", + size, + written + )); + } + // SAFETY: zstd wrote exactly `size` initialized bytes into `dst`. 
+ unsafe { output.set_len(size) }; + result.push(BufferHandle::new_host(output.freeze())); + } + Ok(result) +} + +fn decompress_and_build_inner( + array: &ZstdBuffersArray, + session: &VortexSession, +) -> VortexResult { + let decompressed_buffers = decompress_buffers(array)?; + array.build_inner(&decompressed_buffers, session) +} + fn compute_output_layout( output_sizes: &[usize], output_alignments: &[Alignment], @@ -452,7 +468,7 @@ impl VTable for ZstdBuffersVTable { common: ArrayCommon::new(len, dtype.clone()), }; - array.validate()?; + validate(&array)?; Ok(array) } @@ -463,7 +479,7 @@ impl VTable for ZstdBuffersVTable { fn execute(array: &Self::Array, ctx: &mut ExecutionCtx) -> VortexResult { let session = ctx.session(); - let inner_array = array.decompress_and_build_inner(session)?; + let inner_array = decompress_and_build_inner(array, session)?; inner_array.execute::(ctx) } } @@ -473,7 +489,7 @@ impl OperationsVTable for ZstdBuffersVTable { // TODO(os): maybe we should not support scalar_at, it is really slow, and adding a cache // layer here is weird. 
Valid use of zstd buffers array would be by executing it first into // canonical - let inner_array = array.decompress_and_build_inner(&vortex_array::LEGACY_SESSION)?; + let inner_array = decompress_and_build_inner(array, &vortex_array::LEGACY_SESSION)?; inner_array.scalar_at(index) } } @@ -484,7 +500,7 @@ impl ValidityVTable for ZstdBuffersVTable { return Ok(vortex_array::validity::Validity::NonNullable); } - let inner_array = array.decompress_and_build_inner(&vortex_array::LEGACY_SESSION)?; + let inner_array = decompress_and_build_inner(array, &vortex_array::LEGACY_SESSION)?; inner_array.validity() } } @@ -546,7 +562,7 @@ mod tests { #[case::empty_primitive(make_empty_primitive_array())] #[case::inlined_varbinview(make_inlined_varbinview_array())] fn test_roundtrip(#[case] input: ArrayRef) -> VortexResult<()> { - let compressed = ZstdBuffersArray::compress(&input, 3)?; + let compressed = ZstdBuffersVTable::compress(&input, 3)?; assert_eq!(compressed.common.len(), input.len()); assert_eq!(compressed.common.dtype(), input.dtype()); @@ -563,7 +579,7 @@ mod tests { let input = make_primitive_array(); input.statistics().set(Stat::Min, Precision::exact(0i32)); - let compressed = ZstdBuffersArray::compress(&input, 3)?; + let compressed = ZstdBuffersVTable::compress(&input, 3)?; assert!(compressed.statistics().get(Stat::Min).is_some()); Ok(()) @@ -572,7 +588,7 @@ mod tests { #[test] fn test_validity_delegates_for_nullable_input() -> VortexResult<()> { let input = make_nullable_primitive_array(); - let compressed = ZstdBuffersArray::compress(&input, 3)?.into_array(); + let compressed = ZstdBuffersVTable::compress(&input, 3)?.into_array(); assert_eq!(compressed.all_valid()?, input.all_valid()?); assert_eq!(compressed.all_invalid()?, input.all_invalid()?); diff --git a/vortex-btrblocks/src/compressor/decimal.rs b/vortex-btrblocks/src/compressor/decimal.rs index bf738a72839..87a9e711338 100644 --- a/vortex-btrblocks/src/compressor/decimal.rs +++ 
b/vortex-btrblocks/src/compressor/decimal.rs @@ -9,7 +9,7 @@ use vortex_array::arrays::PrimitiveArray; use vortex_array::arrays::decimal::narrowed_decimal; use vortex_array::dtype::DecimalType; use vortex_array::vtable::ValidityHelper; -use vortex_decimal_byte_parts::DecimalBytePartsArray; +use vortex_decimal_byte_parts::DecimalBytePartsVTable; use vortex_error::VortexResult; use crate::BtrBlocksCompressor; @@ -39,5 +39,5 @@ pub fn compress_decimal( Excludes::none(), )?; - DecimalBytePartsArray::try_new(compressed, decimal.decimal_dtype()).map(|d| d.into_array()) + DecimalBytePartsVTable::try_new(compressed, decimal.decimal_dtype()).map(|d| d.into_array()) } diff --git a/vortex-btrblocks/src/compressor/float/mod.rs b/vortex-btrblocks/src/compressor/float/mod.rs index fab17d37532..67b38fc2924 100644 --- a/vortex-btrblocks/src/compressor/float/mod.rs +++ b/vortex-btrblocks/src/compressor/float/mod.rs @@ -8,7 +8,8 @@ use std::hash::Hash; use std::hash::Hasher; use enum_iterator::Sequence; -use vortex_alp::ALPArray; +use vortex_alp::ALPArrayExt; +use vortex_alp::ALPRDArrayExt; use vortex_alp::ALPVTable; use vortex_alp::RDEncoder; use vortex_alp::alp_encode; @@ -27,7 +28,7 @@ use vortex_array::vtable::VTable; use vortex_array::vtable::ValidityHelper; use vortex_error::VortexResult; use vortex_error::vortex_panic; -use vortex_sparse::SparseArray; +use vortex_sparse::SparseArrayExt; use vortex_sparse::SparseVTable; use self::dictionary::dictionary_encode; @@ -343,7 +344,7 @@ impl Scheme for ALPScheme { let patches = alp.patches().map(compress_patches).transpose()?; - Ok(ALPArray::new(compressed_alp_ints, alp.exponents(), patches).into_array()) + Ok(ALPVTable::new(compressed_alp_ints, alp.exponents(), patches).into_array()) } } @@ -501,7 +502,7 @@ impl Scheme for NullDominated { assert!(ctx.allowed_cascading > 0); // We pass None as we only run this pathway for NULL-dominated float arrays - let sparse_encoded = SparseArray::encode(&stats.src.clone().into_array(), None)?; + 
let sparse_encoded = SparseVTable::encode(&stats.src.clone().into_array(), None)?; if let Some(sparse) = sparse_encoded.as_opt::() { // Compress the values @@ -516,7 +517,7 @@ impl Scheme for NullDominated { Excludes::int_only(&new_excludes), )?; - SparseArray::try_new( + SparseVTable::try_new( compressed_indices, sparse.patches().values().clone(), sparse.len(), @@ -545,7 +546,7 @@ impl Scheme for PcoScheme { _ctx: CompressorContext, _excludes: &[FloatCode], ) -> VortexResult { - Ok(vortex_pco::PcoArray::from_primitive( + Ok(vortex_pco::PcoVTable::from_primitive( stats.source(), pco::DEFAULT_COMPRESSION_LEVEL, 8192, diff --git a/vortex-btrblocks/src/compressor/integer/mod.rs b/vortex-btrblocks/src/compressor/integer/mod.rs index 18c1535668f..a978260f796 100644 --- a/vortex-btrblocks/src/compressor/integer/mod.rs +++ b/vortex-btrblocks/src/compressor/integer/mod.rs @@ -25,16 +25,19 @@ use vortex_error::VortexExpect; use vortex_error::VortexResult; use vortex_error::vortex_bail; use vortex_error::vortex_err; -use vortex_fastlanes::FoRArray; +use vortex_fastlanes::BitPackedArrayExt; +use vortex_fastlanes::FoRArrayExt; +use vortex_fastlanes::FoRVTable; use vortex_fastlanes::bitpack_compress::bit_width_histogram; use vortex_fastlanes::bitpack_compress::bitpack_encode; use vortex_fastlanes::bitpack_compress::find_best_bit_width; use vortex_runend::RunEndArray; use vortex_runend::compress::runend_encode; use vortex_sequence::sequence_encode; -use vortex_sparse::SparseArray; +use vortex_sparse::SparseArrayExt; use vortex_sparse::SparseVTable; -use vortex_zigzag::ZigZagArray; +use vortex_zigzag::ZigZagArrayExt; +use vortex_zigzag::ZigZagVTable; use vortex_zigzag::zigzag_encode; use self::dictionary::dictionary_encode; @@ -384,7 +387,7 @@ impl Scheme for FORScheme { ctx: CompressorContext, excludes: &[IntCode], ) -> VortexResult { - let for_array = FoRArray::encode(stats.src.clone())?; + let for_array = FoRVTable::encode(stats.src.clone())?; let biased = 
for_array.encoded().to_primitive(); let biased_stats = IntegerStats::generate_opts( &biased, @@ -404,7 +407,7 @@ impl Scheme for FORScheme { let compressed = BitPackingScheme.compress(compressor, &biased_stats, leaf_ctx, excludes)?; - let for_compressed = FoRArray::try_new(compressed, for_array.reference_scalar().clone())?; + let for_compressed = FoRVTable::try_new(compressed, for_array.reference_scalar().clone())?; for_compressed .as_ref() .statistics() @@ -476,7 +479,7 @@ impl Scheme for ZigZagScheme { tracing::debug!("zigzag output: {}", compressed.encoding_id()); - Ok(ZigZagArray::try_new(compressed)?.into_array()) + Ok(ZigZagVTable::try_new(compressed)?.into_array()) } } @@ -600,7 +603,7 @@ impl Scheme for SparseScheme { .into_array()); } - let sparse_encoded = SparseArray::encode( + let sparse_encoded = SparseVTable::encode( &stats.src.clone().into_array(), Some(Scalar::primitive_value( top_pvalue, @@ -628,7 +631,7 @@ impl Scheme for SparseScheme { Excludes::int_only(&new_excludes), )?; - SparseArray::try_new( + SparseVTable::try_new( compressed_indices, compressed_values, sparse.len(), @@ -872,7 +875,7 @@ impl Scheme for PcoScheme { _ctx: CompressorContext, _excludes: &[IntCode], ) -> VortexResult { - Ok(vortex_pco::PcoArray::from_primitive( + Ok(vortex_pco::PcoVTable::from_primitive( stats.source(), pco::DEFAULT_COMPRESSION_LEVEL, 8192, diff --git a/vortex-btrblocks/src/compressor/rle.rs b/vortex-btrblocks/src/compressor/rle.rs index c0cb20780bb..87e8a161483 100644 --- a/vortex-btrblocks/src/compressor/rle.rs +++ b/vortex-btrblocks/src/compressor/rle.rs @@ -12,6 +12,8 @@ use vortex_array::ToCanonical; use vortex_array::arrays::PrimitiveArray; use vortex_error::VortexResult; use vortex_fastlanes::RLEArray; +use vortex_fastlanes::RLEArrayExt; +use vortex_fastlanes::RLEVTable; use crate::BtrBlocksCompressor; use crate::CanonicalCompressor; @@ -114,7 +116,7 @@ impl Scheme for RLEScheme { ctx: CompressorContext, excludes: &[C::Code], ) -> VortexResult { - let 
rle_array = RLEArray::encode(RLEStats::source(stats))?; + let rle_array = RLEVTable::encode(RLEStats::source(stats))?; if ctx.allowed_cascading == 0 { return Ok(rle_array.into_array()); diff --git a/vortex-btrblocks/src/compressor/string.rs b/vortex-btrblocks/src/compressor/string.rs index 382dc8c4ae9..4a90f17347e 100644 --- a/vortex-btrblocks/src/compressor/string.rs +++ b/vortex-btrblocks/src/compressor/string.rs @@ -23,12 +23,15 @@ use vortex_array::vtable::ValidityHelper; use vortex_error::VortexExpect; use vortex_error::VortexResult; use vortex_error::vortex_err; -use vortex_fsst::FSSTArray; +use vortex_fsst::FSSTArrayExt; +use vortex_fsst::FSSTVTable; use vortex_fsst::fsst_compress; use vortex_fsst::fsst_train_compressor; -use vortex_sparse::SparseArray; +use vortex_sparse::SparseArrayExt; use vortex_sparse::SparseVTable; use vortex_utils::aliases::hash_set::HashSet; +#[cfg(all(feature = "zstd", feature = "unstable_encodings"))] +use vortex_zstd::ZstdBuffersVTable; use super::integer::DictScheme as IntDictScheme; use super::integer::SequenceScheme as IntSequenceScheme; @@ -378,7 +381,7 @@ impl Scheme for FSSTScheme { fsst.codes().validity().clone(), )?; - let fsst = FSSTArray::try_new( + let fsst = FSSTVTable::try_new( fsst.dtype().clone(), fsst.symbols().clone(), fsst.symbol_lengths().clone(), @@ -496,7 +499,7 @@ impl Scheme for NullDominated { assert!(ctx.allowed_cascading > 0); // We pass None as we only run this pathway for NULL-dominated string arrays - let sparse_encoded = SparseArray::encode(&stats.src.clone().into_array(), None)?; + let sparse_encoded = SparseVTable::encode(&stats.src.clone().into_array(), None)?; if let Some(sparse) = sparse_encoded.as_opt::() { // Compress the indices only (not the values for strings) @@ -509,7 +512,7 @@ impl Scheme for NullDominated { Excludes::int_only(&new_excludes), )?; - SparseArray::try_new( + SparseVTable::try_new( compressed_indices, sparse.patches().values().clone(), sparse.len(), @@ -540,7 +543,7 @@ impl 
Scheme for ZstdScheme { ) -> VortexResult { let compacted = stats.source().compact_buffers()?; Ok( - vortex_zstd::ZstdArray::from_var_bin_view_without_dict(&compacted, 3, 8192)? + vortex_zstd::ZstdVTable::from_var_bin_view_without_dict(&compacted, 3, 8192)? .into_array(), ) } @@ -562,10 +565,7 @@ impl Scheme for ZstdBuffersScheme { _ctx: CompressorContext, _excludes: &[StringCode], ) -> VortexResult { - Ok( - vortex_zstd::ZstdBuffersArray::compress(&stats.source().clone().into_array(), 3)? - .into_array(), - ) + Ok(ZstdBuffersVTable::compress(&stats.source().clone().into_array(), 3)?.into_array()) } } diff --git a/vortex-btrblocks/src/compressor/temporal.rs b/vortex-btrblocks/src/compressor/temporal.rs index 6fb917be58d..1e039111e44 100644 --- a/vortex-btrblocks/src/compressor/temporal.rs +++ b/vortex-btrblocks/src/compressor/temporal.rs @@ -8,7 +8,7 @@ use vortex_array::Canonical; use vortex_array::IntoArray; use vortex_array::ToCanonical; use vortex_array::arrays::TemporalArray; -use vortex_datetime_parts::DateTimePartsArray; +use vortex_datetime_parts::DateTimePartsVTable; use vortex_datetime_parts::TemporalParts; use vortex_datetime_parts::split_temporal; use vortex_error::VortexResult; @@ -48,5 +48,5 @@ pub fn compress_temporal( Excludes::none(), )?; - Ok(DateTimePartsArray::try_new(dtype, days, seconds, subseconds)?.into_array()) + Ok(DateTimePartsVTable::try_new(dtype, days, seconds, subseconds)?.into_array()) } diff --git a/vortex-cuda/benches/bitpacked_cuda.rs b/vortex-cuda/benches/bitpacked_cuda.rs index 44c911f545c..064142feec0 100644 --- a/vortex-cuda/benches/bitpacked_cuda.rs +++ b/vortex-cuda/benches/bitpacked_cuda.rs @@ -24,6 +24,7 @@ use vortex::array::validity::Validity::NonNullable; use vortex::buffer::Buffer; use vortex::dtype::NativePType; use vortex::encodings::fastlanes::BitPackedArray; +use vortex::encodings::fastlanes::BitPackedVTable; use vortex::encodings::fastlanes::unpack_iter::BitPacked; use vortex::error::VortexExpect; use 
vortex::session::VortexSession; @@ -56,7 +57,7 @@ where .collect(); let primitive_array = PrimitiveArray::new(Buffer::from(values), NonNullable); - BitPackedArray::encode(&primitive_array.into_array(), bit_width) + BitPackedVTable::encode(&primitive_array.into_array(), bit_width) .vortex_expect("failed to create BitPacked array") } @@ -96,7 +97,7 @@ where .collect(); let primitive_array = PrimitiveArray::new(Buffer::from(values), NonNullable).into_array(); - BitPackedArray::encode(&primitive_array, bit_width) + BitPackedVTable::encode(&primitive_array, bit_width) .vortex_expect("failed to create BitPacked array with patches") } diff --git a/vortex-cuda/benches/date_time_parts_cuda.rs b/vortex-cuda/benches/date_time_parts_cuda.rs index b4c15496865..f0967ebc85d 100644 --- a/vortex-cuda/benches/date_time_parts_cuda.rs +++ b/vortex-cuda/benches/date_time_parts_cuda.rs @@ -25,6 +25,7 @@ use vortex::buffer::Buffer; use vortex::dtype::DType; use vortex::dtype::Nullability; use vortex::encodings::datetime_parts::DateTimePartsArray; +use vortex::encodings::datetime_parts::DateTimePartsVTable; use vortex::error::VortexExpect; use vortex::extension::datetime::TimeUnit; use vortex::extension::datetime::Timestamp; @@ -44,7 +45,7 @@ fn make_datetimeparts_array(len: usize, time_unit: TimeUnit) -> DateTimePartsArr let dtype = DType::Extension(Timestamp::new(time_unit, Nullability::NonNullable).erased()); - DateTimePartsArray::try_new(dtype, days_arr, seconds_arr, subseconds_arr) + DateTimePartsVTable::try_new(dtype, days_arr, seconds_arr, subseconds_arr) .vortex_expect("Failed to create DateTimePartsArray") } diff --git a/vortex-cuda/benches/dynamic_dispatch_cuda.rs b/vortex-cuda/benches/dynamic_dispatch_cuda.rs index dda49222b40..4f402759dbc 100644 --- a/vortex-cuda/benches/dynamic_dispatch_cuda.rs +++ b/vortex-cuda/benches/dynamic_dispatch_cuda.rs @@ -24,13 +24,15 @@ use vortex::array::scalar::Scalar; use vortex::array::validity::Validity::NonNullable; use 
vortex::buffer::Buffer; use vortex::dtype::PType; -use vortex::encodings::alp::ALPArray; +use vortex::encodings::alp::ALPArrayExt; use vortex::encodings::alp::ALPFloat; +use vortex::encodings::alp::ALPVTable; use vortex::encodings::alp::Exponents; use vortex::encodings::alp::alp_encode; -use vortex::encodings::fastlanes::BitPackedArray; -use vortex::encodings::fastlanes::FoRArray; -use vortex::encodings::runend::RunEndArray; +use vortex::encodings::fastlanes::BitPackedVTable; +use vortex::encodings::fastlanes::FoRArrayExt; +use vortex::encodings::fastlanes::FoRVTable; +use vortex::encodings::runend::RunEndVTable; use vortex::error::VortexExpect; use vortex::error::VortexResult; use vortex::error::vortex_err; @@ -176,9 +178,9 @@ fn bench_for_bitpacked(c: &mut Criterion) { .map(|i| (i as u64 % (max_val + 1)) as u32) .collect(); let prim = PrimitiveArray::new(Buffer::from(residuals), NonNullable); - let bp = BitPackedArray::encode(&prim.into_array(), bit_width).vortex_expect("bitpack"); + let bp = BitPackedVTable::encode(&prim.into_array(), bit_width).vortex_expect("bitpack"); let for_arr = - FoRArray::try_new(bp.into_array(), Scalar::from(reference)).vortex_expect("for"); + FoRVTable::try_new(bp.into_array(), Scalar::from(reference)).vortex_expect("for"); let array = for_arr.into_array(); group.bench_with_input( @@ -220,7 +222,7 @@ fn bench_dict_bp_codes(c: &mut Criterion) { let codes: Vec = (0..*len).map(|i| (i % dict_size) as u32).collect(); let codes_prim = PrimitiveArray::new(Buffer::from(codes), NonNullable); - let codes_bp = BitPackedArray::encode(&codes_prim.into_array(), dict_bit_width) + let codes_bp = BitPackedVTable::encode(&codes_prim.into_array(), dict_bit_width) .vortex_expect("bitpack codes"); let values_prim = PrimitiveArray::new(Buffer::from(dict_values.clone()), NonNullable); let dict = DictArray::new(codes_bp.into_array(), values_prim.into_array()); @@ -267,7 +269,7 @@ fn bench_runend(c: &mut Criterion) { let ends_arr = 
PrimitiveArray::new(Buffer::from(ends), NonNullable).into_array(); let values_arr = PrimitiveArray::new(Buffer::from(values), NonNullable).into_array(); - let re = RunEndArray::new(ends_arr, values_arr); + let re = RunEndVTable::new(ends_arr, values_arr); let array = re.into_array(); group.bench_with_input( @@ -308,9 +310,9 @@ fn bench_dict_bp_codes_bp_for_values(c: &mut Criterion) { // Dict values: residuals 0..63 bitpacked, FoR adds 1_000_000 let dict_residuals: Vec = (0..dict_size as u32).collect(); let dict_prim = PrimitiveArray::new(Buffer::from(dict_residuals), NonNullable); - let dict_bp = BitPackedArray::encode(&dict_prim.into_array(), dict_bit_width) + let dict_bp = BitPackedVTable::encode(&dict_prim.into_array(), dict_bit_width) .vortex_expect("bitpack dict"); - let dict_for = FoRArray::try_new(dict_bp.into_array(), Scalar::from(dict_reference)) + let dict_for = FoRVTable::try_new(dict_bp.into_array(), Scalar::from(dict_reference)) .vortex_expect("for dict"); for (len, len_str) in BENCH_ARGS { @@ -318,7 +320,7 @@ fn bench_dict_bp_codes_bp_for_values(c: &mut Criterion) { let codes: Vec = (0..*len).map(|i| (i % dict_size) as u32).collect(); let codes_prim = PrimitiveArray::new(Buffer::from(codes), NonNullable); - let codes_bp = BitPackedArray::encode(&codes_prim.into_array(), codes_bit_width) + let codes_bp = BitPackedVTable::encode(&codes_prim.into_array(), codes_bit_width) .vortex_expect("bitpack codes"); let dict = DictArray::new(codes_bp.into_array(), dict_for.clone().into_array()); @@ -369,12 +371,12 @@ fn bench_alp_for_bitpacked(c: &mut Criterion) { // Encode: ALP → FoR → BitPacked let alp = alp_encode(&float_prim, Some(exponents)).vortex_expect("alp_encode"); assert!(alp.patches().is_none()); - let for_arr = FoRArray::encode(alp.encoded().to_primitive()).vortex_expect("for encode"); + let for_arr = FoRVTable::encode(alp.encoded().to_primitive()).vortex_expect("for encode"); let bp = - BitPackedArray::encode(for_arr.encoded(), 
bit_width).vortex_expect("bitpack encode"); + BitPackedVTable::encode(for_arr.encoded(), bit_width).vortex_expect("bitpack encode"); - let tree = ALPArray::new( - FoRArray::try_new(bp.into_array(), for_arr.reference_scalar().clone()) + let tree = ALPVTable::new( + FoRVTable::try_new(bp.into_array(), for_arr.reference_scalar().clone()) .vortex_expect("for_new") .into_array(), exponents, diff --git a/vortex-cuda/benches/for_cuda.rs b/vortex-cuda/benches/for_cuda.rs index 31f7b270e92..b6a92776666 100644 --- a/vortex-cuda/benches/for_cuda.rs +++ b/vortex-cuda/benches/for_cuda.rs @@ -25,8 +25,9 @@ use vortex::array::validity::Validity; use vortex::buffer::Buffer; use vortex::dtype::NativePType; use vortex::dtype::PType; -use vortex::encodings::fastlanes::BitPackedArray; +use vortex::encodings::fastlanes::BitPackedVTable; use vortex::encodings::fastlanes::FoRArray; +use vortex::encodings::fastlanes::FoRVTable; use vortex::error::VortexExpect; use vortex::scalar::Scalar; use vortex::session::VortexSession; @@ -55,11 +56,11 @@ where PrimitiveArray::new(Buffer::from(data), Validity::NonNullable).into_array(); if bp && T::PTYPE != PType::U8 { - let child = BitPackedArray::encode(&primitive_array, 8).vortex_expect("failed to bitpack"); - FoRArray::try_new(child.into_array(), reference.into()) + let child = BitPackedVTable::encode(&primitive_array, 8).vortex_expect("failed to bitpack"); + FoRVTable::try_new(child.into_array(), reference.into()) .vortex_expect("failed to create FoR array") } else { - FoRArray::try_new(primitive_array, reference.into()) + FoRVTable::try_new(primitive_array, reference.into()) .vortex_expect("failed to create FoR array") } } diff --git a/vortex-cuda/benches/runend_cuda.rs b/vortex-cuda/benches/runend_cuda.rs index ced52c19c40..e183a5825a4 100644 --- a/vortex-cuda/benches/runend_cuda.rs +++ b/vortex-cuda/benches/runend_cuda.rs @@ -24,6 +24,7 @@ use vortex::array::validity::Validity; use vortex::buffer::Buffer; use vortex::dtype::NativePType; use 
vortex::encodings::runend::RunEndArray; +use vortex::encodings::runend::RunEndVTable; use vortex::session::VortexSession; use vortex_cuda::CudaSession; use vortex_cuda::executor::CudaArrayExt; @@ -54,7 +55,7 @@ where let ends_array = PrimitiveArray::new(Buffer::from(ends), Validity::NonNullable).into_array(); let values_array = PrimitiveArray::new(Buffer::from(values), Validity::NonNullable).into_array(); - RunEndArray::new(ends_array, values_array) + RunEndVTable::new(ends_array, values_array) } /// Benchmark run-end decoding for a specific type with varying run lengths diff --git a/vortex-cuda/benches/zstd_cuda.rs b/vortex-cuda/benches/zstd_cuda.rs index 41be9fa8cf9..9efb0b427dd 100644 --- a/vortex-cuda/benches/zstd_cuda.rs +++ b/vortex-cuda/benches/zstd_cuda.rs @@ -14,7 +14,9 @@ use cudarc::driver::sys::CUevent_flags; use futures::executor::block_on; use vortex::array::arrays::VarBinViewArray; use vortex::encodings::zstd::ZstdArray; +use vortex::encodings::zstd::ZstdArrayExt; use vortex::encodings::zstd::ZstdArrayParts; +use vortex::encodings::zstd::ZstdVTable; use vortex::error::VortexExpect; use vortex::error::VortexResult; use vortex::error::vortex_err; @@ -58,7 +60,7 @@ fn make_zstd_array(num_strings: usize) -> VortexResult<(ZstdArray, usize)> { let zstd_compression_level = -10; // Less compression but faster. let zstd_array = // Disable dictionary as nvCOMP doesn't support ZSTD dictionaries. 
- ZstdArray::from_var_bin_view_without_dict(&var_bin_view, zstd_compression_level, 2048)?; + ZstdVTable::from_var_bin_view_without_dict(&var_bin_view, zstd_compression_level, 2048)?; Ok((zstd_array, uncompressed_size)) } diff --git a/vortex-cuda/src/dynamic_dispatch/mod.rs b/vortex-cuda/src/dynamic_dispatch/mod.rs index a2a431f1490..b7ce8586772 100644 --- a/vortex-cuda/src/dynamic_dispatch/mod.rs +++ b/vortex-cuda/src/dynamic_dispatch/mod.rs @@ -200,14 +200,18 @@ mod tests { use vortex::array::validity::Validity::NonNullable; use vortex::buffer::Buffer; use vortex::dtype::PType; - use vortex::encodings::alp::ALPArray; + use vortex::encodings::alp::ALPArrayExt; use vortex::encodings::alp::ALPFloat; + use vortex::encodings::alp::ALPVTable; use vortex::encodings::alp::Exponents; use vortex::encodings::alp::alp_encode; use vortex::encodings::fastlanes::BitPackedArray; - use vortex::encodings::fastlanes::FoRArray; - use vortex::encodings::runend::RunEndArray; - use vortex::encodings::zigzag::ZigZagArray; + use vortex::encodings::fastlanes::BitPackedArrayExt; + use vortex::encodings::fastlanes::BitPackedVTable; + use vortex::encodings::fastlanes::FoRArrayExt; + use vortex::encodings::fastlanes::FoRVTable; + use vortex::encodings::runend::RunEndVTable; + use vortex::encodings::zigzag::ZigZagVTable; use vortex::error::VortexExpect; use vortex::error::VortexResult; use vortex::session::VortexSession; @@ -229,7 +233,7 @@ mod tests { .map(|i| ((i as u64) % (max_val + 1)) as u32) .collect(); let primitive = PrimitiveArray::new(Buffer::from(values), NonNullable); - BitPackedArray::encode(&primitive.into_array(), bit_width) + BitPackedVTable::encode(&primitive.into_array(), bit_width) .vortex_expect("failed to create BitPacked array") } @@ -451,7 +455,7 @@ mod tests { let expected: Vec = raw.iter().map(|&v| v + reference).collect(); let bp = make_bitpacked_array_u32(bit_width, len); - let for_arr = FoRArray::try_new(bp.into_array(), Scalar::from(reference))?; + let for_arr = 
FoRVTable::try_new(bp.into_array(), Scalar::from(reference))?; let cuda_ctx = CudaSession::create_execution_ctx(&VortexSession::empty())?; let (plan, _bufs) = build_plan(&for_arr.into_array(), &cuda_ctx)?; @@ -476,7 +480,7 @@ mod tests { let ends_arr = PrimitiveArray::new(Buffer::from(ends), NonNullable).into_array(); let values_arr = PrimitiveArray::new(Buffer::from(values), NonNullable).into_array(); - let re = RunEndArray::new(ends_arr, values_arr); + let re = RunEndVTable::new(ends_arr, values_arr); let cuda_ctx = CudaSession::create_execution_ctx(&VortexSession::empty())?; let (plan, _bufs) = build_plan(&re.into_array(), &cuda_ctx)?; @@ -501,12 +505,12 @@ mod tests { // BitPack+FoR the dict values let dict_prim = PrimitiveArray::new(Buffer::from(dict_residuals), NonNullable); - let dict_bp = BitPackedArray::encode(&dict_prim.into_array(), 6)?; - let dict_for = FoRArray::try_new(dict_bp.into_array(), Scalar::from(dict_reference))?; + let dict_bp = BitPackedVTable::encode(&dict_prim.into_array(), 6)?; + let dict_for = FoRVTable::try_new(dict_bp.into_array(), Scalar::from(dict_reference))?; // BitPack the codes let codes_prim = PrimitiveArray::new(Buffer::from(codes), NonNullable); - let codes_bp = BitPackedArray::encode(&codes_prim.into_array(), 6)?; + let codes_bp = BitPackedVTable::encode(&codes_prim.into_array(), 6)?; let dict = DictArray::try_new(codes_bp.into_array(), dict_for.into_array())?; @@ -532,11 +536,11 @@ mod tests { let alp = alp_encode(&float_prim, Some(exponents))?; assert!(alp.patches().is_none()); - let for_arr = FoRArray::encode(alp.encoded().to_primitive())?; - let bp = BitPackedArray::encode(for_arr.encoded(), 6)?; + let for_arr = FoRVTable::encode(alp.encoded().to_primitive())?; + let bp = BitPackedVTable::encode(for_arr.encoded(), 6)?; - let tree = ALPArray::new( - FoRArray::try_new(bp.into_array(), for_arr.reference_scalar().clone())?.into_array(), + let tree = ALPVTable::new( + FoRVTable::try_new(bp.into_array(), 
for_arr.reference_scalar().clone())?.into_array(), exponents, None, ); @@ -566,8 +570,8 @@ mod tests { .collect(); let prim = PrimitiveArray::new(Buffer::from(raw), NonNullable); - let bp = BitPackedArray::encode(&prim.into_array(), bit_width)?; - let zz = ZigZagArray::try_new(bp.into_array())?; + let bp = BitPackedVTable::encode(&prim.into_array(), bit_width)?; + let zz = ZigZagVTable::try_new(bp.into_array())?; let cuda_ctx = CudaSession::create_execution_ctx(&VortexSession::empty())?; let (plan, _bufs) = build_plan(&zz.into_array(), &cuda_ctx)?; @@ -594,8 +598,8 @@ mod tests { let ends_arr = PrimitiveArray::new(Buffer::from(ends), NonNullable).into_array(); let values_arr = PrimitiveArray::new(Buffer::from(values), NonNullable).into_array(); - let re = RunEndArray::new(ends_arr, values_arr); - let for_arr = FoRArray::try_new(re.into_array(), Scalar::from(reference))?; + let re = RunEndVTable::new(ends_arr, values_arr); + let for_arr = FoRVTable::try_new(re.into_array(), Scalar::from(reference))?; let cuda_ctx = CudaSession::create_execution_ctx(&VortexSession::empty())?; let (plan, _bufs) = build_plan(&for_arr.into_array(), &cuda_ctx)?; @@ -623,7 +627,7 @@ mod tests { let codes_prim = PrimitiveArray::new(Buffer::from(codes), NonNullable); let values_prim = PrimitiveArray::new(Buffer::from(dict_values), NonNullable); let dict = DictArray::try_new(codes_prim.into_array(), values_prim.into_array())?; - let for_arr = FoRArray::try_new(dict.into_array(), Scalar::from(reference))?; + let for_arr = FoRVTable::try_new(dict.into_array(), Scalar::from(reference))?; let cuda_ctx = CudaSession::create_execution_ctx(&VortexSession::empty())?; let (plan, _bufs) = build_plan(&for_arr.into_array(), &cuda_ctx)?; @@ -646,8 +650,8 @@ mod tests { // BitPack codes, then wrap in FoR (reference=0 so values unchanged) let bit_width: u8 = 3; let codes_prim = PrimitiveArray::new(Buffer::from(codes), NonNullable); - let codes_bp = BitPackedArray::encode(&codes_prim.into_array(), 
bit_width)?; - let codes_for = FoRArray::try_new(codes_bp.into_array(), Scalar::from(0u32))?; + let codes_bp = BitPackedVTable::encode(&codes_prim.into_array(), bit_width)?; + let codes_for = FoRVTable::try_new(codes_bp.into_array(), Scalar::from(0u32))?; let values_prim = PrimitiveArray::new(Buffer::from(dict_values), NonNullable); let dict = DictArray::try_new(codes_for.into_array(), values_prim.into_array())?; @@ -671,7 +675,7 @@ mod tests { let bit_width: u8 = 2; let codes_prim = PrimitiveArray::new(Buffer::from(codes), NonNullable); - let codes_bp = BitPackedArray::encode(&codes_prim.into_array(), bit_width)?; + let codes_bp = BitPackedVTable::encode(&codes_prim.into_array(), bit_width)?; let values_prim = PrimitiveArray::new(Buffer::from(dict_values), NonNullable); let dict = DictArray::try_new(codes_bp.into_array(), values_prim.into_array())?; diff --git a/vortex-cuda/src/dynamic_dispatch/plan_builder.rs b/vortex-cuda/src/dynamic_dispatch/plan_builder.rs index 734e46b35fd..f0405cf02e6 100644 --- a/vortex-cuda/src/dynamic_dispatch/plan_builder.rs +++ b/vortex-cuda/src/dynamic_dispatch/plan_builder.rs @@ -15,13 +15,18 @@ use vortex::array::arrays::PrimitiveVTable; use vortex::array::arrays::primitive::PrimitiveArrayParts; use vortex::array::buffer::BufferHandle; use vortex::dtype::PType; +use vortex::encodings::alp::ALPArrayExt; use vortex::encodings::alp::ALPFloat; use vortex::encodings::alp::ALPVTable; +use vortex::encodings::fastlanes::BitPackedArrayExt; use vortex::encodings::fastlanes::BitPackedArrayParts; use vortex::encodings::fastlanes::BitPackedVTable; +use vortex::encodings::fastlanes::FoRArrayExt; use vortex::encodings::fastlanes::FoRVTable; +use vortex::encodings::runend::RunEndArrayExt; use vortex::encodings::runend::RunEndArrayParts; use vortex::encodings::runend::RunEndVTable; +use vortex::encodings::zigzag::ZigZagArrayExt; use vortex::encodings::zigzag::ZigZagVTable; use vortex::error::VortexResult; use vortex::error::vortex_bail; diff --git 
a/vortex-cuda/src/kernel/encodings/alp.rs b/vortex-cuda/src/kernel/encodings/alp.rs index 99d8cc67913..94beff82aff 100644 --- a/vortex-cuda/src/kernel/encodings/alp.rs +++ b/vortex-cuda/src/kernel/encodings/alp.rs @@ -17,6 +17,7 @@ use vortex::array::buffer::BufferHandle; use vortex::array::match_each_unsigned_integer_ptype; use vortex::dtype::NativePType; use vortex::encodings::alp::ALPArray; +use vortex::encodings::alp::ALPArrayExt; use vortex::encodings::alp::ALPFloat; use vortex::encodings::alp::ALPVTable; use vortex::encodings::alp::match_each_alp_float_ptype; @@ -124,7 +125,6 @@ mod tests { use vortex::array::validity::Validity; use vortex::buffer::Buffer; use vortex::buffer::buffer; - use vortex::encodings::alp::ALPArray; use vortex::encodings::alp::Exponents; use vortex::error::VortexExpect; use vortex::session::VortexSession; @@ -155,7 +155,7 @@ mod tests { ) .unwrap(); - let alp_array = ALPArray::try_new( + let alp_array = ALPVTable::try_new( PrimitiveArray::new(Buffer::from(encoded_data.clone()), Validity::NonNullable) .into_array(), exponents, diff --git a/vortex-cuda/src/kernel/encodings/bitpacked.rs b/vortex-cuda/src/kernel/encodings/bitpacked.rs index ff078411ea1..e19e92965e4 100644 --- a/vortex-cuda/src/kernel/encodings/bitpacked.rs +++ b/vortex-cuda/src/kernel/encodings/bitpacked.rs @@ -17,6 +17,7 @@ use vortex::array::buffer::DeviceBufferExt; use vortex::array::match_each_integer_ptype; use vortex::dtype::NativePType; use vortex::encodings::fastlanes::BitPackedArray; +use vortex::encodings::fastlanes::BitPackedArrayExt; use vortex::encodings::fastlanes::BitPackedArrayParts; use vortex::encodings::fastlanes::BitPackedVTable; use vortex::encodings::fastlanes::unpack_iter::BitPacked; @@ -201,7 +202,7 @@ mod tests { let array = PrimitiveArray::new(iter.collect::>(), NonNullable); // Last two items should be patched - let bp_with_patches = BitPackedArray::encode(&array.into_array(), bw)?; + let bp_with_patches = 
BitPackedVTable::encode(&array.into_array(), bw)?; assert!(bp_with_patches.patches().is_some()); let cpu_result = bp_with_patches.to_canonical()?.into_array(); @@ -232,7 +233,7 @@ mod tests { ); // Last two items should be patched - let bp_with_patches = BitPackedArray::encode(&array.into_array(), 9)?; + let bp_with_patches = BitPackedVTable::encode(&array.into_array(), 9)?; assert!(bp_with_patches.patches().is_some()); let cpu_result = bp_with_patches.to_canonical()?.into_array(); @@ -274,7 +275,7 @@ mod tests { NonNullable, ); - let bitpacked_array = BitPackedArray::encode(&primitive_array.into_array(), bit_width) + let bitpacked_array = BitPackedVTable::encode(&primitive_array.into_array(), bit_width) .vortex_expect("operation should succeed in test"); let cpu_result = bitpacked_array.to_canonical()?; @@ -323,7 +324,7 @@ mod tests { NonNullable, ); - let bitpacked_array = BitPackedArray::encode(&primitive_array.into_array(), bit_width) + let bitpacked_array = BitPackedVTable::encode(&primitive_array.into_array(), bit_width) .vortex_expect("operation should succeed in test"); let cpu_result = bitpacked_array.to_canonical()?; @@ -388,7 +389,7 @@ mod tests { NonNullable, ); - let bitpacked_array = BitPackedArray::encode(&primitive_array.into_array(), bit_width) + let bitpacked_array = BitPackedVTable::encode(&primitive_array.into_array(), bit_width) .vortex_expect("operation should succeed in test"); let cpu_result = bitpacked_array.to_canonical()?; @@ -485,7 +486,7 @@ mod tests { NonNullable, ); - let bitpacked_array = BitPackedArray::encode(&primitive_array.into_array(), bit_width) + let bitpacked_array = BitPackedVTable::encode(&primitive_array.into_array(), bit_width) .vortex_expect("operation should succeed in test"); let cpu_result = bitpacked_array.to_canonical()?; let gpu_result = block_on(async { @@ -518,7 +519,7 @@ mod tests { NonNullable, ); - let bitpacked_array = BitPackedArray::encode(&primitive_array.into_array(), bit_width) + let bitpacked_array = 
BitPackedVTable::encode(&primitive_array.into_array(), bit_width) .vortex_expect("operation should succeed in test"); let slice_ref = bitpacked_array.clone().into_array().slice(67..3969)?; let mut exec_ctx = ExecutionCtx::new(VortexSession::empty().with::()); diff --git a/vortex-cuda/src/kernel/encodings/date_time_parts.rs b/vortex-cuda/src/kernel/encodings/date_time_parts.rs index 0dfad060890..fd1296a8046 100644 --- a/vortex-cuda/src/kernel/encodings/date_time_parts.rs +++ b/vortex-cuda/src/kernel/encodings/date_time_parts.rs @@ -21,6 +21,7 @@ use vortex::dtype::DType; use vortex::dtype::NativePType; use vortex::dtype::Nullability; use vortex::dtype::PType; +use vortex::encodings::datetime_parts::DateTimePartsArrayExt; use vortex::encodings::datetime_parts::DateTimePartsVTable; use vortex::error::VortexResult; use vortex::error::vortex_bail; @@ -237,7 +238,7 @@ mod tests { None, ); - DateTimePartsArray::try_new( + DateTimePartsVTable::try_new( temporal.dtype().clone(), days_arr, seconds_arr, @@ -348,7 +349,7 @@ mod tests { None, ); - let dtp_array = DateTimePartsArray::try_new( + let dtp_array = DateTimePartsVTable::try_new( temporal.dtype().clone(), days_arr, seconds_arr, diff --git a/vortex-cuda/src/kernel/encodings/decimal_byte_parts.rs b/vortex-cuda/src/kernel/encodings/decimal_byte_parts.rs index 872ba45957a..fa0b219af01 100644 --- a/vortex-cuda/src/kernel/encodings/decimal_byte_parts.rs +++ b/vortex-cuda/src/kernel/encodings/decimal_byte_parts.rs @@ -9,6 +9,7 @@ use vortex::array::ArrayRef; use vortex::array::Canonical; use vortex::array::arrays::DecimalArray; use vortex::array::arrays::primitive::PrimitiveArrayParts; +use vortex::encodings::decimal_byte_parts::DecimalBytePartsArrayExt; use vortex::encodings::decimal_byte_parts::DecimalBytePartsArrayParts; use vortex::encodings::decimal_byte_parts::DecimalBytePartsVTable; use vortex::error::VortexResult; @@ -60,7 +61,6 @@ mod tests { use vortex::array::validity::Validity; use vortex::buffer::Buffer; use 
vortex::dtype::DecimalDType; - use vortex::encodings::decimal_byte_parts::DecimalBytePartsArray; use vortex::error::VortexExpect; use vortex::session::VortexSession; @@ -82,7 +82,7 @@ mod tests { .vortex_expect("create execution context"); let decimal_dtype = DecimalDType::new(precision, scale); - let dbp_array = DecimalBytePartsArray::try_new( + let dbp_array = DecimalBytePartsVTable::try_new( PrimitiveArray::new(encoded, Validity::NonNullable).into_array(), decimal_dtype, ) diff --git a/vortex-cuda/src/kernel/encodings/for_.rs b/vortex-cuda/src/kernel/encodings/for_.rs index 47280834c3c..c5c76b3b6d9 100644 --- a/vortex-cuda/src/kernel/encodings/for_.rs +++ b/vortex-cuda/src/kernel/encodings/for_.rs @@ -16,8 +16,10 @@ use vortex::array::arrays::primitive::PrimitiveArrayParts; use vortex::array::match_each_integer_ptype; use vortex::array::match_each_native_simd_ptype; use vortex::dtype::NativePType; +use vortex::encodings::fastlanes::BitPackedArrayExt; use vortex::encodings::fastlanes::BitPackedVTable; use vortex::encodings::fastlanes::FoRArray; +use vortex::encodings::fastlanes::FoRArrayExt; use vortex::encodings::fastlanes::FoRVTable; use vortex::error::VortexExpect; use vortex::error::VortexResult; @@ -127,7 +129,6 @@ mod tests { use vortex::array::validity::Validity::NonNullable; use vortex::buffer::Buffer; use vortex::dtype::NativePType; - use vortex::encodings::fastlanes::BitPackedArray; use vortex::encodings::fastlanes::FoRArray; use vortex::error::VortexExpect; use vortex::scalar::Scalar; @@ -138,7 +139,7 @@ mod tests { use crate::session::CudaSession; fn make_for_array>(input_data: Vec, reference: T) -> FoRArray { - FoRArray::try_new( + FoRVTable::try_new( PrimitiveArray::new(Buffer::from(input_data), NonNullable).into_array(), reference.into(), ) @@ -180,8 +181,8 @@ mod tests { .take(1024) .collect::>() .into_array(); - let packed = BitPackedArray::encode(&values, 3).unwrap().into_array(); - let for_array = FoRArray::try_new(packed, 
(-8i8).into()).unwrap(); + let packed = BitPackedVTable::encode(&values, 3).unwrap().into_array(); + let for_array = FoRVTable::try_new(packed, (-8i8).into()).unwrap(); let cpu_result = for_array.to_canonical().unwrap(); diff --git a/vortex-cuda/src/kernel/encodings/runend.rs b/vortex-cuda/src/kernel/encodings/runend.rs index 2411d6f0142..4272e03b90a 100644 --- a/vortex-cuda/src/kernel/encodings/runend.rs +++ b/vortex-cuda/src/kernel/encodings/runend.rs @@ -19,6 +19,7 @@ use vortex::array::validity::Validity; use vortex::dtype::NativePType; use vortex::dtype::PType; use vortex::encodings::runend::RunEndArray; +use vortex::encodings::runend::RunEndArrayExt; use vortex::encodings::runend::RunEndArrayParts; use vortex::encodings::runend::RunEndVTable; use vortex::error::VortexResult; @@ -182,7 +183,7 @@ mod tests { PrimitiveArray::new(Buffer::from(ends), Validity::NonNullable).into_array(); let values_array = PrimitiveArray::new(Buffer::from(values), Validity::NonNullable).into_array(); - RunEndArray::new(ends_array, values_array) + RunEndVTable::new(ends_array, values_array) } #[rstest] diff --git a/vortex-cuda/src/kernel/encodings/sequence.rs b/vortex-cuda/src/kernel/encodings/sequence.rs index abe53cdd18f..84078d6e5a7 100644 --- a/vortex-cuda/src/kernel/encodings/sequence.rs +++ b/vortex-cuda/src/kernel/encodings/sequence.rs @@ -14,6 +14,7 @@ use vortex::array::buffer::BufferHandle; use vortex::array::match_each_native_ptype; use vortex::dtype::NativePType; use vortex::dtype::Nullability; +use vortex::encodings::sequence::SequenceArrayExt; use vortex::encodings::sequence::SequenceArrayParts; use vortex::encodings::sequence::SequenceVTable; use vortex::error::VortexResult; @@ -89,7 +90,7 @@ mod tests { use vortex::array::assert_arrays_eq; use vortex::dtype::NativePType; use vortex::dtype::Nullability; - use vortex::encodings::sequence::SequenceArray; + use vortex::encodings::sequence::SequenceVTable; use vortex::scalar::PValue; use vortex::session::VortexSession; @@ 
-126,7 +127,7 @@ mod tests { ) { let mut cuda_ctx = CudaSession::create_execution_ctx(&VortexSession::empty()).unwrap(); - let array = SequenceArray::try_new_typed(base, multiplier, nullability, len).unwrap(); + let array = SequenceVTable::try_new_typed(base, multiplier, nullability, len).unwrap(); let cpu_result = array.to_canonical().unwrap().into_array(); diff --git a/vortex-cuda/src/kernel/encodings/zigzag.rs b/vortex-cuda/src/kernel/encodings/zigzag.rs index 4dd26fc267a..606af5e384c 100644 --- a/vortex-cuda/src/kernel/encodings/zigzag.rs +++ b/vortex-cuda/src/kernel/encodings/zigzag.rs @@ -15,6 +15,7 @@ use vortex::array::match_each_unsigned_integer_ptype; use vortex::dtype::NativePType; use vortex::dtype::PType; use vortex::encodings::zigzag::ZigZagArray; +use vortex::encodings::zigzag::ZigZagArrayExt; use vortex::encodings::zigzag::ZigZagVTable; use vortex::error::VortexResult; use vortex::error::vortex_ensure; @@ -102,7 +103,6 @@ mod tests { use vortex::array::assert_arrays_eq; use vortex::array::validity::Validity::NonNullable; use vortex::buffer::Buffer; - use vortex::encodings::zigzag::ZigZagArray; use vortex::error::VortexExpect; use vortex::session::VortexSession; @@ -119,7 +119,7 @@ mod tests { // So encoded [0, 2, 4, 1, 3] should decode to [0, 1, 2, -1, -2] let encoded_data: Vec = vec![0, 2, 4, 1, 3]; - let zigzag_array = ZigZagArray::try_new( + let zigzag_array = ZigZagVTable::try_new( PrimitiveArray::new(Buffer::from(encoded_data), NonNullable).into_array(), )?; diff --git a/vortex-cuda/src/kernel/encodings/zstd.rs b/vortex-cuda/src/kernel/encodings/zstd.rs index 2a57e87e8d7..cd26533a021 100644 --- a/vortex-cuda/src/kernel/encodings/zstd.rs +++ b/vortex-cuda/src/kernel/encodings/zstd.rs @@ -24,6 +24,7 @@ use vortex::buffer::Buffer; use vortex::buffer::ByteBuffer; use vortex::dtype::DType; use vortex::encodings::zstd::ZstdArray; +use vortex::encodings::zstd::ZstdArrayExt; use vortex::encodings::zstd::ZstdArrayParts; use 
vortex::encodings::zstd::ZstdMetadata; use vortex::encodings::zstd::ZstdVTable; @@ -345,7 +346,6 @@ mod tests { use vortex::array::IntoArray; use vortex::array::arrays::VarBinViewArray; use vortex::array::assert_arrays_eq; - use vortex::encodings::zstd::ZstdArray; use vortex::error::VortexResult; use vortex::session::VortexSession; @@ -366,7 +366,7 @@ mod tests { "baz", ]); - let zstd_array = ZstdArray::from_var_bin_view(&strings, 3, 0)?; + let zstd_array = ZstdVTable::from_var_bin_view(&strings, 3, 0)?; let cpu_result = zstd_array.decompress()?.to_canonical()?; let gpu_result = ZstdExecutor @@ -401,7 +401,7 @@ mod tests { // Compress with ZSTD using values_per_frame=3 to create multiple frames. // 14 strings and 3 values per frame = ceil(14/3) = 5 frames. - let zstd_array = ZstdArray::from_var_bin_view(&strings, 3, 3)?; + let zstd_array = ZstdVTable::from_var_bin_view(&strings, 3, 3)?; let cpu_result = zstd_array.decompress()?.to_canonical()?; let gpu_result = ZstdExecutor @@ -430,7 +430,7 @@ mod tests { "final test string", ]); - let zstd_array = ZstdArray::from_var_bin_view(&strings, 3, 0)?; + let zstd_array = ZstdVTable::from_var_bin_view(&strings, 3, 0)?; // Slice the array to get a subset (indices 2..7) let sliced_zstd = zstd_array.slice(2..7)?; diff --git a/vortex-cuda/src/kernel/encodings/zstd_buffers.rs b/vortex-cuda/src/kernel/encodings/zstd_buffers.rs index bc218c841c0..a1f47ac4421 100644 --- a/vortex-cuda/src/kernel/encodings/zstd_buffers.rs +++ b/vortex-cuda/src/kernel/encodings/zstd_buffers.rs @@ -17,6 +17,7 @@ use vortex::array::buffer::DeviceBuffer; use vortex::buffer::Alignment; use vortex::buffer::Buffer; use vortex::encodings::zstd::ZstdBuffersArray; +use vortex::encodings::zstd::ZstdBuffersArrayExt; use vortex::encodings::zstd::ZstdBuffersVTable; use vortex::error::VortexResult; use vortex::error::vortex_err; @@ -223,7 +224,6 @@ mod tests { use vortex::array::arrays::PrimitiveArray; use vortex::array::arrays::VarBinViewArray; use 
vortex::array::assert_arrays_eq; - use vortex::encodings::zstd::ZstdBuffersArray; use vortex::error::VortexExpect; use vortex::error::VortexResult; use vortex::session::VortexSession; @@ -238,7 +238,7 @@ mod tests { .vortex_expect("failed to create execution context"); let input = PrimitiveArray::from_iter(0i64..1024).into_array(); - let compressed = ZstdBuffersArray::compress(&input, 3)?; + let compressed = ZstdBuffersVTable::compress(&input, 3)?; let cpu_result = compressed.clone().into_array().to_canonical()?; let gpu_result = ZstdBuffersExecutor @@ -265,7 +265,7 @@ mod tests { "baz", ]) .into_array(); - let compressed = ZstdBuffersArray::compress(&input, 3)?; + let compressed = ZstdBuffersVTable::compress(&input, 3)?; let cpu_result = compressed.clone().into_array().to_canonical()?; let gpu_result = ZstdBuffersExecutor diff --git a/vortex-duckdb/src/e2e_test/vortex_scan_test.rs b/vortex-duckdb/src/e2e_test/vortex_scan_test.rs index 75db1ac3cd6..ab58ccc7a17 100644 --- a/vortex-duckdb/src/e2e_test/vortex_scan_test.rs +++ b/vortex-duckdb/src/e2e_test/vortex_scan_test.rs @@ -36,8 +36,8 @@ use vortex::file::WriteOptionsSessionExt; use vortex::io::runtime::BlockingRuntime; use vortex::scalar::PValue; use vortex::scalar::Scalar; -use vortex_runend::RunEndArray; -use vortex_sequence::SequenceArray; +use vortex_runend::RunEndVTable; +use vortex_sequence::SequenceVTable; use crate::RUNTIME; use crate::SESSION; @@ -789,10 +789,10 @@ async fn write_vortex_file_with_encodings() -> NamedTempFile { // 4. Run-End let run_ends = buffer![3u32, 5]; let run_values = buffer![100i32, 200]; - let rle_array = RunEndArray::try_new(run_ends.into_array(), run_values.into_array()).unwrap(); + let rle_array = RunEndVTable::try_new(run_ends.into_array(), run_values.into_array()).unwrap(); // 5. 
Sequence array - let sequence_array = SequenceArray::try_new( + let sequence_array = SequenceVTable::try_new( PValue::I64(0), PValue::I64(10), PType::I64, diff --git a/vortex-duckdb/src/exporter/run_end.rs b/vortex-duckdb/src/exporter/run_end.rs index 2a27af64940..b1aa3b788a5 100644 --- a/vortex-duckdb/src/exporter/run_end.rs +++ b/vortex-duckdb/src/exporter/run_end.rs @@ -11,6 +11,7 @@ use vortex::array::search_sorted::SearchSorted; use vortex::array::search_sorted::SearchSortedSide; use vortex::dtype::IntegerPType; use vortex::encodings::runend::RunEndArray; +use vortex::encodings::runend::RunEndArrayExt; use vortex::encodings::runend::RunEndArrayParts; use vortex::error::VortexExpect; use vortex::error::VortexResult; diff --git a/vortex-duckdb/src/exporter/sequence.rs b/vortex-duckdb/src/exporter/sequence.rs index 380f57c95bf..7c9f8a6bad8 100644 --- a/vortex-duckdb/src/exporter/sequence.rs +++ b/vortex-duckdb/src/exporter/sequence.rs @@ -4,6 +4,7 @@ use bitvec::macros::internal::funty::Fundamental; use vortex::array::ExecutionCtx; use vortex::encodings::sequence::SequenceArray; +use vortex::encodings::sequence::SequenceArrayExt; use vortex::error::VortexExpect; use vortex::error::VortexResult; @@ -44,6 +45,7 @@ impl ColumnExporter for SequenceExporter { #[cfg(test)] mod tests { use vortex::dtype::Nullability; + use vortex::encodings::sequence::SequenceVTable; use vortex_array::VortexSessionExecute; use super::*; @@ -54,7 +56,7 @@ mod tests { #[test] fn test_sequence() { - let arr = SequenceArray::try_new_typed(2, 5, Nullability::NonNullable, 100).unwrap(); + let arr = SequenceVTable::try_new_typed(2, 5, Nullability::NonNullable, 100).unwrap(); let mut chunk = DataChunk::new([LogicalType::new(cpp::duckdb_type::DUCKDB_TYPE_INTEGER)]); new_exporter(&arr) diff --git a/vortex-layout/src/layouts/row_idx/mod.rs b/vortex-layout/src/layouts/row_idx/mod.rs index fd7513aad96..bfaab67d695 100644 --- a/vortex-layout/src/layouts/row_idx/mod.rs +++ 
b/vortex-layout/src/layouts/row_idx/mod.rs @@ -36,6 +36,7 @@ use vortex_error::VortexExpect; use vortex_error::VortexResult; use vortex_mask::Mask; use vortex_sequence::SequenceArray; +use vortex_sequence::SequenceVTable; use vortex_session::VortexSession; use vortex_utils::aliases::dash_map::DashMap; @@ -246,7 +247,7 @@ impl LayoutReader for RowIdxLayoutReader { // Returns a SequenceArray representing the row indices for the given row range, fn idx_array(row_offset: u64, row_range: &Range) -> SequenceArray { - SequenceArray::try_new( + SequenceVTable::try_new( PValue::U64(row_offset + row_range.start), PValue::U64(1), PType::U64, diff --git a/vortex-python/src/arrays/fastlanes.rs b/vortex-python/src/arrays/fastlanes.rs index 120622c2594..ccdd8ae592a 100644 --- a/vortex-python/src/arrays/fastlanes.rs +++ b/vortex-python/src/arrays/fastlanes.rs @@ -2,6 +2,7 @@ // SPDX-FileCopyrightText: Copyright the Vortex contributors use pyo3::prelude::*; +use vortex::encodings::fastlanes::BitPackedArrayExt; use vortex::encodings::fastlanes::BitPackedVTable; use vortex::encodings::fastlanes::DeltaVTable; use vortex::encodings::fastlanes::FoRVTable; diff --git a/vortex-python/src/arrays/range_to_sequence.rs b/vortex-python/src/arrays/range_to_sequence.rs index 450d2fa65a3..e45278fdee7 100644 --- a/vortex-python/src/arrays/range_to_sequence.rs +++ b/vortex-python/src/arrays/range_to_sequence.rs @@ -9,7 +9,7 @@ use vortex::buffer::Buffer; use vortex::dtype::DType; use vortex::dtype::NativePType; use vortex::dtype::Nullability; -use vortex::encodings::sequence::SequenceArray; +use vortex::encodings::sequence::SequenceVTable; use vortex::error::VortexExpect; use vortex::error::VortexResult; use vortex::error::vortex_bail; @@ -43,7 +43,7 @@ pub fn sequence_array_from_range + Into> vortex_bail!("Step, {}, does not fit in requested dtype: {}", step, dtype); }; - Ok(SequenceArray::try_new_typed::(start, step, dtype.nullability(), len)?.into_array()) + 
Ok(SequenceVTable::try_new_typed::(start, step, dtype.nullability(), len)?.into_array()) } fn range_len(start: isize, stop: isize, step: isize) -> Option { diff --git a/vortex/benches/common_encoding_tree_throughput.rs b/vortex/benches/common_encoding_tree_throughput.rs index f4243a7d332..a75d8cd0631 100644 --- a/vortex/benches/common_encoding_tree_throughput.rs +++ b/vortex/benches/common_encoding_tree_throughput.rs @@ -26,16 +26,20 @@ use vortex::array::builtins::ArrayBuiltins; use vortex::array::vtable::ValidityHelper; use vortex::dtype::DType; use vortex::dtype::PType; +use vortex::encodings::alp::ALPArrayExt; use vortex::encodings::alp::alp_encode; -use vortex::encodings::datetime_parts::DateTimePartsArray; +use vortex::encodings::datetime_parts::DateTimePartsVTable; use vortex::encodings::datetime_parts::split_temporal; -use vortex::encodings::fastlanes::FoRArray; -use vortex::encodings::fsst::FSSTArray; +use vortex::encodings::fastlanes::FoRArrayExt; +use vortex::encodings::fastlanes::FoRVTable; +use vortex::encodings::fsst::FSSTArrayExt; +use vortex::encodings::fsst::FSSTVTable; use vortex::encodings::fsst::fsst_compress; use vortex::encodings::fsst::fsst_train_compressor; -use vortex::encodings::runend::RunEndArray; +use vortex::encodings::runend::RunEndArrayExt; +use vortex::encodings::runend::RunEndVTable; use vortex::extension::datetime::TimeUnit; -use vortex_fastlanes::BitPackedArray; +use vortex_fastlanes::BitPackedVTable; #[global_allocator] static GLOBAL: MiMalloc = MiMalloc; @@ -86,10 +90,10 @@ mod setup { /// Create FoR <- BitPacked encoding tree for u64 pub fn for_bp_u64() -> ArrayRef { let (uint_array, ..) 
= setup_primitive_arrays(); - let compressed = FoRArray::encode(uint_array).unwrap(); + let compressed = FoRVTable::encode(uint_array).unwrap(); let inner = compressed.encoded(); - let bp = BitPackedArray::encode(inner, 8).unwrap(); - FoRArray::try_new(bp.into_array(), compressed.reference_scalar().clone()) + let bp = BitPackedVTable::encode(inner, 8).unwrap(); + FoRVTable::try_new(bp.into_array(), compressed.reference_scalar().clone()) .unwrap() .into_array() } @@ -100,13 +104,13 @@ mod setup { let alp_compressed = alp_encode(&float_array, None).unwrap(); // Manually construct ALP <- FoR <- BitPacked tree - let for_array = FoRArray::encode(alp_compressed.encoded().to_primitive()).unwrap(); + let for_array = FoRVTable::encode(alp_compressed.encoded().to_primitive()).unwrap(); let inner = for_array.encoded(); - let bp = BitPackedArray::encode(inner, 8).unwrap(); + let bp = BitPackedVTable::encode(inner, 8).unwrap(); let for_with_bp = - FoRArray::try_new(bp.into_array(), for_array.reference_scalar().clone()).unwrap(); + FoRVTable::try_new(bp.into_array(), for_array.reference_scalar().clone()).unwrap(); - vortex::encodings::alp::ALPArray::try_new( + vortex::encodings::alp::ALPVTable::try_new( for_with_bp.into_array(), alp_compressed.exponents(), alp_compressed.patches().cloned(), @@ -137,7 +141,7 @@ mod setup { let codes_prim = PrimitiveArray::from_iter(codes); // Compress codes with BitPacked (6 bits should be enough for ~50 unique values) - let codes_bp = BitPackedArray::encode(&codes_prim.into_array(), 6) + let codes_bp = BitPackedVTable::encode(&codes_prim.into_array(), 6) .unwrap() .into_array(); @@ -168,25 +172,25 @@ mod setup { } let prim_array = PrimitiveArray::from_iter(values); - let runend = RunEndArray::encode(prim_array.into_array()).unwrap(); + let runend = RunEndVTable::encode(prim_array.into_array()).unwrap(); // Compress the ends with FoR <- BitPacked let ends_prim = runend.ends().to_primitive(); - let ends_for = FoRArray::encode(ends_prim).unwrap(); 
+ let ends_for = FoRVTable::encode(ends_prim).unwrap(); let ends_inner = ends_for.encoded(); - let ends_bp = BitPackedArray::encode(ends_inner, 8).unwrap(); + let ends_bp = BitPackedVTable::encode(ends_inner, 8).unwrap(); let compressed_ends = - FoRArray::try_new(ends_bp.into_array(), ends_for.reference_scalar().clone()) + FoRVTable::try_new(ends_bp.into_array(), ends_for.reference_scalar().clone()) .unwrap() .into_array(); // Compress the values with BitPacked let values_prim = runend.values().to_primitive(); - let compressed_values = BitPackedArray::encode(&values_prim.into_array(), 8) + let compressed_values = BitPackedVTable::encode(&values_prim.into_array(), 8) .unwrap() .into_array(); - RunEndArray::try_new(compressed_ends, compressed_values) + RunEndVTable::try_new(compressed_ends, compressed_values) .unwrap() .into_array() } @@ -246,7 +250,7 @@ mod setup { // Compress the VarBin offsets with BitPacked let codes = fsst.codes(); let offsets_prim = codes.offsets().to_primitive(); - let offsets_bp = BitPackedArray::encode(&offsets_prim.into_array(), 20).unwrap(); + let offsets_bp = BitPackedVTable::encode(&offsets_prim.into_array(), 20).unwrap(); // Rebuild VarBin with compressed offsets let compressed_codes = VarBinArray::try_new( @@ -258,7 +262,7 @@ mod setup { .unwrap(); // Rebuild FSST with compressed codes - let compressed_fsst = FSSTArray::try_new( + let compressed_fsst = FSSTVTable::try_new( fsst.dtype().clone(), fsst.symbols().clone(), fsst.symbol_lengths().clone(), @@ -297,20 +301,20 @@ mod setup { // Compress days with FoR <- BitPacked let days_prim = parts.days.to_primitive(); - let days_for = FoRArray::encode(days_prim).unwrap(); + let days_for = FoRVTable::encode(days_prim).unwrap(); let days_inner = days_for.encoded(); - let days_bp = BitPackedArray::encode(days_inner, 16).unwrap(); + let days_bp = BitPackedVTable::encode(days_inner, 16).unwrap(); let compressed_days = - FoRArray::try_new(days_bp.into_array(), days_for.reference_scalar().clone()) 
+ FoRVTable::try_new(days_bp.into_array(), days_for.reference_scalar().clone()) .unwrap() .into_array(); // Compress seconds with FoR <- BitPacked let seconds_prim = parts.seconds.to_primitive(); - let seconds_for = FoRArray::encode(seconds_prim).unwrap(); + let seconds_for = FoRVTable::encode(seconds_prim).unwrap(); let seconds_inner = seconds_for.encoded(); - let seconds_bp = BitPackedArray::encode(seconds_inner, 17).unwrap(); - let compressed_seconds = FoRArray::try_new( + let seconds_bp = BitPackedVTable::encode(seconds_inner, 17).unwrap(); + let compressed_seconds = FoRVTable::try_new( seconds_bp.into_array(), seconds_for.reference_scalar().clone(), ) @@ -319,17 +323,17 @@ mod setup { // Compress subseconds with FoR <- BitPacked let subseconds_prim = parts.subseconds.to_primitive(); - let subseconds_for = FoRArray::encode(subseconds_prim).unwrap(); + let subseconds_for = FoRVTable::encode(subseconds_prim).unwrap(); let subseconds_inner = subseconds_for.encoded(); - let subseconds_bp = BitPackedArray::encode(subseconds_inner, 20).unwrap(); - let compressed_subseconds = FoRArray::try_new( + let subseconds_bp = BitPackedVTable::encode(subseconds_inner, 20).unwrap(); + let compressed_subseconds = FoRVTable::try_new( subseconds_bp.into_array(), subseconds_for.reference_scalar().clone(), ) .unwrap() .into_array(); - DateTimePartsArray::try_new( + DateTimePartsVTable::try_new( DType::Extension(temporal_array.ext_dtype()), compressed_days, compressed_seconds, diff --git a/vortex/benches/single_encoding_throughput.rs b/vortex/benches/single_encoding_throughput.rs index dd22c30b996..5a64ea0ad05 100644 --- a/vortex/benches/single_encoding_throughput.rs +++ b/vortex/benches/single_encoding_throughput.rs @@ -21,18 +21,18 @@ use vortex::array::builtins::ArrayBuiltins; use vortex::dtype::PType; use vortex::encodings::alp::RDEncoder; use vortex::encodings::alp::alp_encode; -use vortex::encodings::fastlanes::DeltaArray; -use vortex::encodings::fastlanes::FoRArray; +use 
vortex::encodings::fastlanes::DeltaVTable; +use vortex::encodings::fastlanes::FoRVTable; use vortex::encodings::fastlanes::delta_compress; use vortex::encodings::fsst::fsst_compress; use vortex::encodings::fsst::fsst_train_compressor; -use vortex::encodings::pco::PcoArray; -use vortex::encodings::runend::RunEndArray; +use vortex::encodings::pco::PcoVTable; +use vortex::encodings::runend::RunEndVTable; use vortex::encodings::sequence::sequence_encode; use vortex::encodings::zigzag::zigzag_encode; -use vortex::encodings::zstd::ZstdArray; +use vortex::encodings::zstd::ZstdVTable; use vortex_array::dtype::Nullability; -use vortex_sequence::SequenceArray; +use vortex_sequence::SequenceVTable; #[global_allocator] static GLOBAL: MiMalloc = MiMalloc; @@ -122,13 +122,13 @@ fn bench_runend_compress_u32(bencher: Bencher) { with_byte_counter(bencher, NUM_VALUES * 4) .with_inputs(|| uint_array.clone()) - .bench_values(|a| RunEndArray::encode(a.into_array()).unwrap()); + .bench_values(|a| RunEndVTable::encode(a.into_array()).unwrap()); } #[divan::bench(name = "runend_decompress_u32")] fn bench_runend_decompress_u32(bencher: Bencher) { let (uint_array, ..) = setup_primitive_arrays(); - let compressed = RunEndArray::encode(uint_array.into_array()).unwrap(); + let compressed = RunEndVTable::encode(uint_array.into_array()).unwrap(); with_byte_counter(bencher, NUM_VALUES * 4) .with_inputs(|| &compressed) @@ -143,7 +143,7 @@ fn bench_delta_compress_u32(bencher: Bencher) { .with_inputs(|| &uint_array) .bench_refs(|a| { let (bases, deltas) = delta_compress(a).unwrap(); - DeltaArray::try_from_delta_compress_parts(bases.into_array(), deltas.into_array()) + DeltaVTable::try_from_delta_compress_parts(bases.into_array(), deltas.into_array()) .unwrap() }); } @@ -153,7 +153,8 @@ fn bench_delta_decompress_u32(bencher: Bencher) { let (uint_array, ..) 
= setup_primitive_arrays(); let (bases, deltas) = delta_compress(&uint_array).unwrap(); let compressed = - DeltaArray::try_from_delta_compress_parts(bases.into_array(), deltas.into_array()).unwrap(); + DeltaVTable::try_from_delta_compress_parts(bases.into_array(), deltas.into_array()) + .unwrap(); with_byte_counter(bencher, NUM_VALUES * 4) .with_inputs(|| &compressed) @@ -166,13 +167,13 @@ fn bench_for_compress_i32(bencher: Bencher) { with_byte_counter(bencher, NUM_VALUES * 4) .with_inputs(|| int_array.clone()) - .bench_values(|a| FoRArray::encode(a).unwrap()); + .bench_values(|a| FoRVTable::encode(a).unwrap()); } #[divan::bench(name = "for_decompress_i32")] fn bench_for_decompress_i32(bencher: Bencher) { let (_, int_array, _) = setup_primitive_arrays(); - let compressed = FoRArray::encode(int_array).unwrap(); + let compressed = FoRVTable::encode(int_array).unwrap(); with_byte_counter(bencher, NUM_VALUES * 4) .with_inputs(|| &compressed) @@ -231,7 +232,7 @@ fn bench_sequence_compress_u32(bencher: Bencher) { #[divan::bench(name = "sequence_decompress_u32")] fn bench_sequence_decompress_u32(bencher: Bencher) { let compressed = - SequenceArray::try_new_typed(0, 1, Nullability::NonNullable, NUM_VALUES as usize) + SequenceVTable::try_new_typed(0, 1, Nullability::NonNullable, NUM_VALUES as usize) .unwrap() .into_array(); @@ -288,13 +289,13 @@ fn bench_pcodec_compress_f64(bencher: Bencher) { with_byte_counter(bencher, NUM_VALUES * 8) .with_inputs(|| &float_array) - .bench_refs(|a| PcoArray::from_primitive(a, 3, 0).unwrap()); + .bench_refs(|a| PcoVTable::from_primitive(a, 3, 0).unwrap()); } #[divan::bench(name = "pcodec_decompress_f64")] fn bench_pcodec_decompress_f64(bencher: Bencher) { let (_, _, float_array) = setup_primitive_arrays(); - let compressed = PcoArray::from_primitive(&float_array, 3, 0).unwrap(); + let compressed = PcoVTable::from_primitive(&float_array, 3, 0).unwrap(); with_byte_counter(bencher, NUM_VALUES * 8) .with_inputs(|| &compressed) @@ -309,14 
+310,14 @@ fn bench_zstd_compress_u32(bencher: Bencher) { with_byte_counter(bencher, NUM_VALUES * 4) .with_inputs(|| array.clone()) - .bench_values(|a| ZstdArray::from_array(a, 3, 8192).unwrap()); + .bench_values(|a| ZstdVTable::from_array(a, 3, 8192).unwrap()); } #[cfg(feature = "zstd")] #[divan::bench(name = "zstd_decompress_u32")] fn bench_zstd_decompress_u32(bencher: Bencher) { let (uint_array, ..) = setup_primitive_arrays(); - let compressed = ZstdArray::from_array(uint_array.into_array(), 3, 8192).unwrap(); + let compressed = ZstdVTable::from_array(uint_array.into_array(), 3, 8192).unwrap(); with_byte_counter(bencher, NUM_VALUES * 4) .with_inputs(|| &compressed) @@ -377,14 +378,14 @@ fn bench_zstd_compress_string(bencher: Bencher) { with_byte_counter(bencher, nbytes) .with_inputs(|| array.clone()) - .bench_values(|a| ZstdArray::from_array(a, 3, 8192).unwrap()); + .bench_values(|a| ZstdVTable::from_array(a, 3, 8192).unwrap()); } #[cfg(feature = "zstd")] #[divan::bench(name = "zstd_decompress_string")] fn bench_zstd_decompress_string(bencher: Bencher) { let varbinview_arr = VarBinViewArray::from_iter_str(gen_varbin_words(1_000_000, 0.00005)); - let compressed = ZstdArray::from_array(varbinview_arr.clone().into_array(), 3, 8192).unwrap(); + let compressed = ZstdVTable::from_array(varbinview_arr.clone().into_array(), 3, 8192).unwrap(); let nbytes = varbinview_arr.into_array().nbytes() as u64; with_byte_counter(bencher, nbytes)