diff --git a/vortex-array/src/arrays/chunked/array.rs b/vortex-array/src/arrays/chunked/array.rs
index 1f37712b991..98b1c76dd5e 100644
--- a/vortex-array/src/arrays/chunked/array.rs
+++ b/vortex-array/src/arrays/chunked/array.rs
@@ -107,8 +107,10 @@ impl ChunkedArray {
     #[inline]
     pub fn chunk(&self, idx: usize) -> &ArrayRef {
         assert!(idx < self.nchunks(), "chunk index {idx} out of bounds");
-
+        // Plain indexing instead of `get_unchecked`: the assert above only proves
+        // `idx < self.nchunks()`, and keeping the access checked means memory safety
+        // does not hinge on `nchunks()` always matching `self.chunks.len()`.
         &self.chunks[idx]
     }
 
     pub fn nchunks(&self) -> usize {
diff --git a/vortex-array/src/arrays/chunked/compute/zip.rs b/vortex-array/src/arrays/chunked/compute/zip.rs
index f5ff1fd589d..bcb9fbe58d6 100644
--- a/vortex-array/src/arrays/chunked/compute/zip.rs
+++ b/vortex-array/src/arrays/chunked/compute/zip.rs
@@ -29,39 +29,10 @@ impl ZipKernel for Chunked {
             .union_nullability(if_false.dtype().nullability());
 
         let mut out_chunks = Vec::with_capacity(if_true.nchunks() + if_false.nchunks());
-        let mut lhs_idx = 0;
-        let mut rhs_idx = 0;
-        let mut lhs_offset = 0;
-        let mut rhs_offset = 0;
-        let mut pos = 0;
-        let total_len = if_true.len();
-
-        while pos < total_len {
-            let lhs_chunk = if_true.chunk(lhs_idx);
-            let rhs_chunk = if_false.chunk(rhs_idx);
-
-            let lhs_rem = lhs_chunk.len() - lhs_offset;
-            let rhs_rem = rhs_chunk.len() - rhs_offset;
-            let take_until = lhs_rem.min(rhs_rem);
-
-            let mask_slice = mask.slice(pos..pos + take_until)?;
-            let lhs_slice = lhs_chunk.slice(lhs_offset..lhs_offset + take_until)?;
-            let rhs_slice = rhs_chunk.slice(rhs_offset..rhs_offset + take_until)?;
-
-            out_chunks.push(mask_slice.zip(lhs_slice, rhs_slice)?);
-
-            pos += take_until;
-            lhs_offset += take_until;
-            rhs_offset += take_until;
-
-            if lhs_offset == lhs_chunk.len() {
-                lhs_idx += 1;
-                lhs_offset = 0;
-            }
-            if rhs_offset == rhs_chunk.len() {
-                rhs_idx += 1;
-                rhs_offset = 0;
-            }
+        for pair in if_true.paired_chunks(if_false) {
+            let pair = pair?;
+            let mask_slice = mask.slice(pair.pos)?;
+            
out_chunks.push(mask_slice.zip(pair.left, pair.right)?);
         }
 
         // SAFETY: chunks originate from zipping slices of inputs that share dtype/nullability.
diff --git a/vortex-array/src/arrays/chunked/mod.rs b/vortex-array/src/arrays/chunked/mod.rs
index e9cc33ae322..b9a24da320e 100644
--- a/vortex-array/src/arrays/chunked/mod.rs
+++ b/vortex-array/src/arrays/chunked/mod.rs
@@ -5,6 +5,7 @@ mod array;
 pub use array::ChunkedArray;
 
 pub(crate) mod compute;
+mod paired_chunks;
 mod vtable;
 pub use vtable::Chunked;
 
diff --git a/vortex-array/src/arrays/chunked/paired_chunks.rs b/vortex-array/src/arrays/chunked/paired_chunks.rs
new file mode 100644
index 00000000000..1677a489070
--- /dev/null
+++ b/vortex-array/src/arrays/chunked/paired_chunks.rs
@@ -0,0 +1,288 @@
+// SPDX-License-Identifier: Apache-2.0
+// SPDX-FileCopyrightText: Copyright the Vortex contributors
+
+use std::ops::Range;
+
+use vortex_error::VortexResult;
+
+use crate::ArrayRef;
+use crate::arrays::ChunkedArray;
+
+/// Two equal-length slices, one from each array, covering the same positions.
+pub(crate) struct AlignedPair {
+    /// Slice taken from the left array.
+    pub left: ArrayRef,
+    /// Slice taken from the right array.
+    pub right: ArrayRef,
+    /// Range covered by both slices, counted from the start of the arrays.
+    pub pos: Range<usize>,
+}
+
+/// Cursor over a chunk slice that maintains the invariant: `idx` always
+/// points at a non-empty chunk or is past the end.
+struct ChunkCursor<'a> {
+    chunks: &'a [ArrayRef],
+    idx: usize,
+    offset: usize,
+}
+
+impl<'a> ChunkCursor<'a> {
+    /// Creates a cursor positioned on the first non-empty chunk, if any.
+    fn new(chunks: &'a [ArrayRef]) -> Self {
+        let mut cursor = Self {
+            chunks,
+            idx: 0,
+            offset: 0,
+        };
+        cursor.skip_empty();
+        cursor
+    }
+
+    /// Advances `idx` past any empty chunks.
+    fn skip_empty(&mut self) {
+        while self.current_chunk().is_some_and(|chunk| chunk.is_empty()) {
+            self.idx += 1;
+        }
+    }
+
+    /// Returns the chunk at `idx`, or `None` once `idx` is past the end.
+    fn current_chunk(&self) -> Option<&'a ArrayRef> {
+        self.chunks.get(self.idx)
+    }
+
+    /// Number of elements of `chunk` at or after the current offset.
+    fn remaining(&self, chunk: &ArrayRef) -> usize {
+        chunk.len() - self.offset
+    }
+
+    /// Slices the next `n` elements out of `chunk` and advances past them.
+    ///
+    /// The cursor is left untouched when slicing fails.
+    fn take(&mut self, chunk: &ArrayRef, n: usize) -> VortexResult<ArrayRef> {
+        let slice = chunk.slice(self.offset..self.offset + n)?;
+        self.offset += n;
+        if self.offset == chunk.len() {
+            // Chunk exhausted: move on to the next non-empty one.
+            self.idx += 1;
+            self.offset = 0;
+            self.skip_empty();
+        }
+        Ok(slice)
+    }
+}
+
+/// Iterator returned by [`ChunkedArray::paired_chunks`].
+pub(crate) struct PairedChunks<'a> {
+    left: ChunkCursor<'a>,
+    right: ChunkCursor<'a>,
+    pos: usize,
+    total_len: usize,
+}
+
+impl ChunkedArray {
+    /// Walks `self` and `other` in lockstep, yielding [`AlignedPair`]s that
+    /// never straddle a chunk boundary of either array.
+    ///
+    /// # Panics
+    ///
+    /// Panics if the two arrays have different lengths.
+    pub(crate) fn paired_chunks<'a>(&'a self, other: &'a ChunkedArray) -> PairedChunks<'a> {
+        assert_eq!(
+            self.len(),
+            other.len(),
+            "paired_chunks requires arrays of equal length"
+        );
+        PairedChunks {
+            left: ChunkCursor::new(&self.chunks),
+            right: ChunkCursor::new(&other.chunks),
+            pos: 0,
+            total_len: self.len(),
+        }
+    }
+}
+
+impl Iterator for PairedChunks<'_> {
+    type Item = VortexResult<AlignedPair>;
+
+    fn next(&mut self) -> Option<Self::Item> {
+        if self.pos >= self.total_len {
+            return None;
+        }
+
+        let lhs_chunk = self.left.current_chunk()?;
+        let rhs_chunk = self.right.current_chunk()?;
+
+        let take = self
+            .left
+            .remaining(lhs_chunk)
+            .min(self.right.remaining(rhs_chunk));
+
+        let (lhs_slice, rhs_slice) = match self
+            .left
+            .take(lhs_chunk, take)
+            .and_then(|l| self.right.take(rhs_chunk, take).map(|r| (l, r)))
+        {
+            Ok(pair) => pair,
+            Err(e) => {
+                // The cursors may now be out of step with each other, so end
+                // the iteration instead of yielding misaligned pairs (or the
+                // same error forever) on later calls.
+                self.pos = self.total_len;
+                return Some(Err(e));
+            }
+        };
+
+        let start = self.pos;
+        self.pos += take;
+
+        Some(Ok(AlignedPair {
+            left: lhs_slice,
+            right: rhs_slice,
+            pos: start..self.pos,
+        }))
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use vortex_buffer::buffer;
+    use vortex_error::VortexResult;
+
+    use crate::IntoArray;
+    use crate::arrays::ChunkedArray;
+    use crate::dtype::DType;
+    use crate::dtype::Nullability;
+    use crate::dtype::PType;
+
+    fn i32_dtype() -> DType {
+        DType::Primitive(PType::I32, Nullability::NonNullable)
+    }
+
+    #[allow(clippy::type_complexity)]
+    fn collect_pairs(
+        left: &ChunkedArray,
+        right: &ChunkedArray,
+    ) -> VortexResult<Vec<(Vec<i32>, Vec<i32>, std::ops::Range<usize>)>> {
+        use crate::ToCanonical;
+        let mut result = Vec::new();
+        for pair in left.paired_chunks(right) {
+            let pair = pair?;
+            let l: Vec<i32> = pair.left.to_primitive().as_slice::<i32>().to_vec();
+            let r: Vec<i32> = pair.right.to_primitive().as_slice::<i32>().to_vec();
+            result.push((l, r, pair.pos));
+        }
+        Ok(result)
+    }
+
+    #[test]
+    fn test_aligned_chunks() -> VortexResult<()> {
+        let left = ChunkedArray::try_new(
+            vec![buffer![1i32, 2].into_array(), buffer![3i32, 4].into_array()],
+            i32_dtype(),
+        )?;
+        let right = ChunkedArray::try_new(
+            vec![
+                buffer![10i32, 20].into_array(),
+                buffer![30i32, 40].into_array(),
+            ],
+            i32_dtype(),
+        )?;
+
+        let pairs = collect_pairs(&left, &right)?;
+        assert_eq!(pairs.len(), 2);
+        assert_eq!(pairs[0], (vec![1, 2], vec![10, 20], 0..2));
+        assert_eq!(pairs[1], (vec![3, 4], vec![30, 40], 2..4));
+        Ok(())
+    }
+
+    #[test]
+    fn test_misaligned_chunks() -> VortexResult<()> {
+        let left = ChunkedArray::try_new(
+            vec![
+                buffer![1i32, 2].into_array(),
+                buffer![3i32].into_array(),
+                buffer![4i32, 5].into_array(),
+            ],
+            i32_dtype(),
+        )?;
+        let right = ChunkedArray::try_new(
+            vec![
+                buffer![10i32].into_array(),
+                buffer![20i32, 30].into_array(),
+                buffer![40i32, 50].into_array(),
+            ],
+            i32_dtype(),
+        )?;
+
+        let pairs = collect_pairs(&left, &right)?;
+        assert_eq!(pairs.len(), 4);
+        assert_eq!(pairs[0], (vec![1], vec![10], 0..1));
+        assert_eq!(pairs[1], (vec![2], vec![20], 1..2));
+        assert_eq!(pairs[2], (vec![3], vec![30], 2..3));
+        assert_eq!(pairs[3], (vec![4, 5], vec![40, 50], 3..5));
+        Ok(())
+    }
+
+    #[test]
+    fn test_empty_chunks() -> VortexResult<()> {
+        let left = ChunkedArray::try_new(
+            vec![
+                buffer![0i32; 0].into_array(),
+                buffer![1i32, 2, 3].into_array(),
+            ],
+            i32_dtype(),
+        )?;
+        let right = ChunkedArray::try_new(
+            vec![
+                buffer![10i32, 20, 30].into_array(),
+                buffer![0i32; 0].into_array(),
+            ],
+            i32_dtype(),
+        )?;
+
+        let pairs = collect_pairs(&left, &right)?;
+        assert_eq!(pairs.len(), 1);
+        assert_eq!(pairs[0], (vec![1, 2, 3], vec![10, 20, 30], 0..3));
+        Ok(())
+    }
+
+    #[test]
+    fn test_single_element_chunks() -> VortexResult<()> {
+        let left = ChunkedArray::try_new(
+            vec![
+                buffer![1i32].into_array(),
+                buffer![2i32].into_array(),
+                buffer![3i32].into_array(),
+            ],
+            i32_dtype(),
+        )?;
+        let right = ChunkedArray::try_new(vec![buffer![10i32, 20, 30].into_array()], i32_dtype())?;
+
+        let pairs = collect_pairs(&left, &right)?;
+        assert_eq!(pairs.len(), 3);
+        assert_eq!(pairs[0], (vec![1], vec![10], 0..1));
+        assert_eq!(pairs[1], (vec![2], vec![20], 1..2));
+        assert_eq!(pairs[2], (vec![3], vec![30], 2..3));
+        Ok(())
+    }
+
+    #[test]
+    fn test_both_empty() -> VortexResult<()> {
+        let left = ChunkedArray::try_new(vec![], i32_dtype())?;
+        let right = ChunkedArray::try_new(vec![], i32_dtype())?;
+
+        let pairs = collect_pairs(&left, &right)?;
+        assert!(pairs.is_empty());
+        Ok(())
+    }
+
+    #[test]
+    #[should_panic(expected = "paired_chunks requires arrays of equal length")]
+    fn test_length_mismatch_panics() {
+        let left = ChunkedArray::try_new(vec![buffer![1i32, 2].into_array()], i32_dtype()).unwrap();
+        let right =
+            ChunkedArray::try_new(vec![buffer![10i32, 20, 30].into_array()], i32_dtype()).unwrap();
+
+        drop(left.paired_chunks(&right).collect::<Vec<_>>());
+    }
+}
diff --git a/vortex-array/src/scalar_fn/fns/zip/mod.rs b/vortex-array/src/scalar_fn/fns/zip/mod.rs
index c23b5373a28..e5f5f18d964 100644
--- 
a/vortex-array/src/scalar_fn/fns/zip/mod.rs +++ b/vortex-array/src/scalar_fn/fns/zip/mod.rs @@ -116,7 +116,9 @@ impl ScalarFnVTable for Zip { let if_false = args.get(1)?; let mask_array = args.get(2)?; - let mask = mask_array.execute::(ctx)?.to_mask(); + let mask = mask_array + .execute::(ctx)? + .to_mask_fill_null_false(); let return_dtype = if_true .dtype() @@ -228,6 +230,7 @@ mod tests { use super::zip_impl; use crate::ArrayRef; + use crate::DynArray; use crate::IntoArray; use crate::LEGACY_SESSION; use crate::VortexSessionExecute; @@ -235,13 +238,14 @@ mod tests { use crate::arrays::PrimitiveArray; use crate::arrays::Struct; use crate::arrays::StructArray; - use crate::arrays::VarBinViewArray; + use crate::arrays::VarBinView; use crate::arrow::IntoArrowArray; use crate::assert_arrays_eq; use crate::builders::ArrayBuilder; use crate::builders::BufferGrowthStrategy; use crate::builders::VarBinViewBuilder; use crate::builtins::ArrayBuiltins; + use crate::columnar::Columnar; use crate::dtype::DType; use crate::dtype::Nullability; use crate::dtype::PType; @@ -372,10 +376,12 @@ mod tests { let mask = Mask::from_indices(len, indices); let mask_array = mask.into_array(); + let mut ctx = LEGACY_SESSION.create_execution_ctx(); let result = mask_array .clone() .zip(const1.clone(), const2.clone())? - .execute::(&mut LEGACY_SESSION.create_execution_ctx())?; + .execute::(&mut ctx)? + .into_array(); insta::assert_snapshot!(result.display_tree(), @r" root: vortex.varbinview(utf8?, len=100) nbytes=1.66 kB (100.00%) [all_valid] @@ -390,7 +396,7 @@ mod tests { let wrapped_result = mask_array .zip(wrapped1, wrapped2)? 
- .execute::(&mut LEGACY_SESSION.create_execution_ctx())?; + .execute::(&mut ctx)?; assert!(wrapped_result.is::()); Ok(()) @@ -431,11 +437,13 @@ mod tests { let mask = Mask::from_indices(200, (0..100).filter(|i| i % 3 != 0).collect()); let mask_array = mask.clone().into_array(); + let mut ctx = LEGACY_SESSION.create_execution_ctx(); let zipped = mask_array .zip(if_true.clone(), if_false.clone()) .unwrap() - .execute::(&mut LEGACY_SESSION.create_execution_ctx()) + .execute::(&mut ctx) .unwrap(); + let zipped = zipped.as_opt::().unwrap(); assert_eq!(zipped.nbuffers(), 2); let expected = arrow_zip( @@ -448,7 +456,7 @@ mod tests { ) .unwrap(); - let actual = zipped.into_array().into_arrow_preferred().unwrap(); + let actual = zipped.clone().into_array().into_arrow_preferred().unwrap(); assert_eq!(actual.as_ref(), expected.as_ref()); } }