diff --git a/vortex-array/public-api.lock b/vortex-array/public-api.lock index 70ff659eda1..965a8a72406 100644 --- a/vortex-array/public-api.lock +++ b/vortex-array/public-api.lock @@ -58,7 +58,7 @@ pub type vortex_array::aggregate_fn::fns::sum::Sum::Options = vortex_array::aggr pub type vortex_array::aggregate_fn::fns::sum::Sum::Partial = vortex_array::aggregate_fn::fns::sum::SumPartial -pub fn vortex_array::aggregate_fn::fns::sum::Sum::accumulate(&self, partial: &mut Self::Partial, batch: &vortex_array::Canonical, _ctx: &mut vortex_array::ExecutionCtx) -> vortex_error::VortexResult<()> +pub fn vortex_array::aggregate_fn::fns::sum::Sum::accumulate(&self, partial: &mut Self::Partial, batch: &vortex_array::Columnar, _ctx: &mut vortex_array::ExecutionCtx) -> vortex_error::VortexResult<()> pub fn vortex_array::aggregate_fn::fns::sum::Sum::combine_partials(&self, partial: &mut Self::Partial, other: vortex_array::scalar::Scalar) -> vortex_error::VortexResult<()> @@ -286,7 +286,7 @@ pub type vortex_array::aggregate_fn::AggregateFnVTable::Options: 'static + core: pub type vortex_array::aggregate_fn::AggregateFnVTable::Partial: 'static + core::marker::Send -pub fn vortex_array::aggregate_fn::AggregateFnVTable::accumulate(&self, state: &mut Self::Partial, batch: &vortex_array::Canonical, ctx: &mut vortex_array::ExecutionCtx) -> vortex_error::VortexResult<()> +pub fn vortex_array::aggregate_fn::AggregateFnVTable::accumulate(&self, state: &mut Self::Partial, batch: &vortex_array::Columnar, ctx: &mut vortex_array::ExecutionCtx) -> vortex_error::VortexResult<()> pub fn vortex_array::aggregate_fn::AggregateFnVTable::combine_partials(&self, partial: &mut Self::Partial, other: vortex_array::scalar::Scalar) -> vortex_error::VortexResult<()> @@ -316,7 +316,7 @@ pub type vortex_array::aggregate_fn::fns::sum::Sum::Options = vortex_array::aggr pub type vortex_array::aggregate_fn::fns::sum::Sum::Partial = vortex_array::aggregate_fn::fns::sum::SumPartial -pub fn vortex_array::aggregate_fn::fns::sum::Sum::accumulate(&self, partial: &mut Self::Partial, batch: &vortex_array::Canonical, _ctx: &mut vortex_array::ExecutionCtx) -> vortex_error::VortexResult<()> +pub fn vortex_array::aggregate_fn::fns::sum::Sum::accumulate(&self, partial: &mut Self::Partial, batch: &vortex_array::Columnar, _ctx: &mut vortex_array::ExecutionCtx) -> vortex_error::VortexResult<()> pub fn vortex_array::aggregate_fn::fns::sum::Sum::combine_partials(&self, partial: &mut Self::Partial, other: vortex_array::scalar::Scalar) -> vortex_error::VortexResult<()> diff --git a/vortex-array/src/aggregate_fn/accumulator.rs b/vortex-array/src/aggregate_fn/accumulator.rs index d3fab024739..5e9a12e53fb 100644 --- a/vortex-array/src/aggregate_fn/accumulator.rs +++ b/vortex-array/src/aggregate_fn/accumulator.rs @@ -7,7 +7,7 @@ use vortex_session::VortexSession; use crate::AnyCanonical; use crate::ArrayRef; -use crate::Canonical; +use crate::Columnar; use crate::DynArray; use crate::VortexSessionExecute; use crate::aggregate_fn::AggregateFn; @@ -131,11 +131,11 @@ impl DynAccumulator for Accumulator { batch = batch.execute(&mut ctx)?; } - // Otherwise, execute the batch until it is canonical and accumulate it into the state. - let canonical = batch.execute::(&mut ctx)?; + // Otherwise, execute the batch until it is columnar and accumulate it into the state. + let columnar = batch.execute::(&mut ctx)?; self.vtable - .accumulate(&mut self.partial, &canonical, &mut ctx) + .accumulate(&mut self.partial, &columnar, &mut ctx) } fn is_saturated(&self) -> bool { diff --git a/vortex-array/src/aggregate_fn/accumulator_grouped.rs b/vortex-array/src/aggregate_fn/accumulator_grouped.rs index ff85da637b9..17d165e3746 100644 --- a/vortex-array/src/aggregate_fn/accumulator_grouped.rs +++ b/vortex-array/src/aggregate_fn/accumulator_grouped.rs @@ -14,6 +14,7 @@ use vortex_session::VortexSession; use crate::AnyCanonical; use crate::ArrayRef; use crate::Canonical; +use crate::Columnar; use crate::DynArray; use crate::ExecutionCtx; use crate::IntoArray; @@ -121,7 +122,11 @@ impl DynGroupedAccumulator for GroupedAccumulator { // We first execute the groups until it is a ListView or FixedSizeList, since we only // dispatch the aggregate kernel over the elements of these arrays. - match groups.clone().execute::(&mut ctx)? { + let canonical = match groups.clone().execute::(&mut ctx)? { + Columnar::Canonical(c) => c, + Columnar::Constant(c) => c.into_array().execute::(&mut ctx)?, + }; + match canonical { Canonical::List(groups) => self.accumulate_list_view(&groups, &mut ctx), Canonical::FixedSizeList(groups) => self.accumulate_fixed_size_list(&groups, &mut ctx), _ => vortex_panic!("We checked the DType above, so this should never happen"), @@ -192,7 +197,7 @@ impl GroupedAccumulator { } // Otherwise, we iterate the offsets and sizes and accumulate each group one by one. - let elements = elements.execute::(ctx)?.into_array(); + let elements = elements.execute::(ctx)?.into_array(); let offsets = groups.offsets(); let sizes = groups.sizes().cast(offsets.dtype().clone())?; let validity = groups.validity().to_mask(offsets.len()); @@ -279,7 +284,7 @@ impl GroupedAccumulator { } // Otherwise, we iterate the offsets and sizes and accumulate each group one by one. - let elements = elements.execute::(ctx)?.into_array(); + let elements = elements.execute::(ctx)?.into_array(); let validity = groups.validity().to_mask(groups.len()); let mut accumulator = Accumulator::try_new( diff --git a/vortex-array/src/aggregate_fn/fns/sum.rs b/vortex-array/src/aggregate_fn/fns/sum.rs index 5b16b585e37..52af2f3c4fb 100644 --- a/vortex-array/src/aggregate_fn/fns/sum.rs +++ b/vortex-array/src/aggregate_fn/fns/sum.rs @@ -14,11 +14,13 @@ use vortex_mask::AllOr; use crate::ArrayRef; use crate::Canonical; +use crate::Columnar; use crate::ExecutionCtx; use crate::aggregate_fn::AggregateFnId; use crate::aggregate_fn::AggregateFnVTable; use crate::aggregate_fn::EmptyOptions; use crate::arrays::BoolArray; +use crate::arrays::ConstantArray; use crate::arrays::DecimalArray; use crate::arrays::PrimitiveArray; use crate::dtype::DType; @@ -149,7 +151,7 @@ impl AggregateFnVTable for Sum { fn accumulate( &self, partial: &mut Self::Partial, - batch: &Canonical, + batch: &Columnar, _ctx: &mut ExecutionCtx, ) -> VortexResult<()> { let mut inner = match partial.current.take() { @@ -158,10 +160,13 @@ impl AggregateFnVTable for Sum { }; let result = match batch { - Canonical::Primitive(p) => accumulate_primitive(&mut inner, p), - Canonical::Bool(b) => accumulate_bool(&mut inner, b), - Canonical::Decimal(d) => accumulate_decimal(&mut inner, d), - _ => vortex_bail!("Unsupported canonical type for sum: {}", batch.dtype()), + Columnar::Canonical(c) => match c { + Canonical::Primitive(p) => accumulate_primitive(&mut inner, p), + Canonical::Bool(b) => accumulate_bool(&mut inner, b), + Canonical::Decimal(d) => accumulate_decimal(&mut inner, d), + _ => vortex_bail!("Unsupported canonical type for sum: {}", batch.dtype()), + }, + Columnar::Constant(c) => accumulate_constant(&mut inner, c), }; match result { @@ -349,6 +354,85 @@ fn accumulate_bool(inner: &mut SumState, b: &BoolArray) -> VortexResult { Ok(checked_add_u64(acc, true_count)) } +/// Accumulate a constant array into the sum state. +/// Computes `scalar * len` and adds to the accumulator. +/// Returns Ok(true) if saturated (overflow), Ok(false) if not. +fn accumulate_constant(inner: &mut SumState, c: &ConstantArray) -> VortexResult { + let scalar = c.scalar(); + if scalar.is_null() || c.is_empty() { + return Ok(false); + } + let len = c.len(); + + match scalar.dtype() { + DType::Bool(_) => { + let SumState::Unsigned(acc) = inner else { + vortex_panic!("expected unsigned sum state for bool input"); + }; + let val = scalar + .as_bool() + .value() + .ok_or_else(|| vortex_err!("Expected non-null bool scalar for sum"))?; + if val { + Ok(checked_add_u64(acc, len as u64)) + } else { + Ok(false) + } + } + DType::Primitive(..) => { + let pvalue = scalar + .as_primitive() + .pvalue() + .ok_or_else(|| vortex_err!("Expected non-null primitive scalar for sum"))?; + match inner { + SumState::Unsigned(acc) => { + let val = pvalue.cast::()?; + match val.checked_mul(len as u64) { + Some(product) => Ok(checked_add_u64(acc, product)), + None => Ok(true), + } + } + SumState::Signed(acc) => { + let val = pvalue.cast::()?; + match i64::try_from(len).ok().and_then(|l| val.checked_mul(l)) { + Some(product) => Ok(checked_add_i64(acc, product)), + None => Ok(true), + } + } + SumState::Float(acc) => { + let val = pvalue.cast::()?; + *acc += val * len as f64; + Ok(false) + } + SumState::Decimal(_) => { + vortex_panic!("decimal sum state with primitive input") + } + } + } + DType::Decimal(..) => { + let SumState::Decimal(acc) = inner else { + vortex_panic!("expected decimal sum state for decimal input"); + }; + let val = scalar + .as_decimal() + .decimal_value() + .ok_or_else(|| vortex_err!("Expected non-null decimal scalar for sum"))?; + let len_decimal = DecimalValue::from(len as i128); + match val.checked_mul(&len_decimal) { + Some(product) => match acc.checked_add(&product) { + Some(r) => { + *acc = r; + Ok(false) + } + None => Ok(true), + }, + None => Ok(true), + } + } + _ => vortex_bail!("Unsupported constant type for sum: {}", scalar.dtype()), + } +} + /// Accumulate a decimal array into the sum state. /// Returns Ok(true) if saturated (overflow), Ok(false) if not. fn accumulate_decimal(inner: &mut SumState, d: &DecimalArray) -> VortexResult { diff --git a/vortex-array/src/aggregate_fn/vtable.rs b/vortex-array/src/aggregate_fn/vtable.rs index 7c53387b56b..0e45e8a54fd 100644 --- a/vortex-array/src/aggregate_fn/vtable.rs +++ b/vortex-array/src/aggregate_fn/vtable.rs @@ -12,7 +12,7 @@ use vortex_error::vortex_bail; use vortex_session::VortexSession; use crate::ArrayRef; -use crate::Canonical; +use crate::Columnar; use crate::DynArray; use crate::ExecutionCtx; use crate::IntoArray; @@ -95,7 +95,7 @@ pub trait AggregateFnVTable: 'static + Sized + Clone + Send + Sync { fn accumulate( &self, state: &mut Self::Partial, - batch: &Canonical, + batch: &Columnar, ctx: &mut ExecutionCtx, ) -> VortexResult<()>;