diff --git a/common/src/main/scala/org/apache/comet/CometConf.scala b/common/src/main/scala/org/apache/comet/CometConf.scala index 046ccf0b1c..fa7f081f17 100644 --- a/common/src/main/scala/org/apache/comet/CometConf.scala +++ b/common/src/main/scala/org/apache/comet/CometConf.scala @@ -534,6 +534,16 @@ object CometConf extends ShimCometConf { .checkValue(v => v > 0, "Write buffer size must be positive") .createWithDefault(1) + val COMET_SHUFFLE_SORT_BASED: ConfigEntry[Boolean] = + conf(s"$COMET_EXEC_CONFIG_PREFIX.shuffle.sort_based") + .category(CATEGORY_SHUFFLE) + .doc( + "When enabled, uses sort-based repartitioning for native shuffle. " + + "This avoids per-partition memory overhead from builders, making it more " + + "memory-efficient for large partition counts. Default is false (uses buffered mode).") + .booleanConf + .createWithDefault(false) + val COMET_SHUFFLE_PREFER_DICTIONARY_RATIO: ConfigEntry[Double] = conf( "spark.comet.shuffle.preferDictionary.ratio") .category(CATEGORY_SHUFFLE) diff --git a/native/core/src/execution/planner.rs b/native/core/src/execution/planner.rs index ac35925ace..8d2dc58f60 100644 --- a/native/core/src/execution/planner.rs +++ b/native/core/src/execution/planner.rs @@ -1371,6 +1371,7 @@ impl PhysicalPlanner { }?; let write_buffer_size = writer.write_buffer_size as usize; + let sort_based = writer.sort_based; let shuffle_writer = Arc::new(ShuffleWriterExec::try_new( Arc::clone(&child.native_plan), partitioning, @@ -1379,6 +1380,7 @@ impl PhysicalPlanner { writer.output_index_file.clone(), writer.tracing_enabled, write_buffer_size, + sort_based, )?); Ok(( diff --git a/native/proto/src/proto/operator.proto b/native/proto/src/proto/operator.proto index fb438b26a4..9d2b29c1b8 100644 --- a/native/proto/src/proto/operator.proto +++ b/native/proto/src/proto/operator.proto @@ -294,6 +294,10 @@ message ShuffleWriter { // Size of the write buffer in bytes used when writing shuffle data to disk. // Larger values may improve write performance but use more memory. int32 write_buffer_size = 8; + // When true, uses sort-based repartitioning for native shuffle. + // This avoids per-partition memory overhead from builders, making it more + // memory-efficient for large partition counts. + bool sort_based = 9; } message ParquetWriter { diff --git a/native/shuffle/benches/shuffle_writer.rs b/native/shuffle/benches/shuffle_writer.rs index 27abd919fa..e4e7940630 100644 --- a/native/shuffle/benches/shuffle_writer.rs +++ b/native/shuffle/benches/shuffle_writer.rs @@ -153,6 +153,7 @@ fn create_shuffle_writer_exec( "/tmp/index.out".to_string(), false, 1024 * 1024, + false, ) .unwrap() } diff --git a/native/shuffle/src/bin/shuffle_bench.rs b/native/shuffle/src/bin/shuffle_bench.rs index bb8c2a0380..e7856a6a64 100644 --- a/native/shuffle/src/bin/shuffle_bench.rs +++ b/native/shuffle/src/bin/shuffle_bench.rs @@ -114,6 +114,11 @@ struct Args { /// Each task reads the same input and writes to its own output files. #[arg(long, default_value_t = 1)] concurrent_tasks: usize, + + /// Shuffle mode: 'buffered' buffers all rows before writing (default), + /// 'sort' sorts each batch by partition ID and writes immediately. + #[arg(long, default_value = "buffered")] + mode: String, } fn main() { @@ -141,6 +146,7 @@ fn main() { println!("Partitioning: {}", args.partitioning); println!("Partitions: {}", args.partitions); println!("Codec: {:?}", codec); + println!("Mode: {}", args.mode); println!("Hash columns: {:?}", hash_col_indices); if let Some(mem_limit) = args.memory_limit { println!("Memory limit: {}", format_bytes(mem_limit)); @@ -403,6 +409,7 @@ fn run_shuffle_write( let rt = tokio::runtime::Runtime::new().unwrap(); rt.block_on(async { let start = Instant::now(); + let sort_based = args.mode == "sort"; let (shuffle_metrics, input_metrics) = execute_shuffle_write( input_path.to_str().unwrap(), codec.clone(), @@ -413,6 +420,7 @@ fn run_shuffle_write( args.limit, data_file.to_string(), index_file.to_string(), + sort_based, ) .await .unwrap(); @@ -436,6 +444,7 @@ async fn execute_shuffle_write( limit: usize, data_file: String, index_file: String, + sort_based: bool, ) -> datafusion::common::Result<(MetricsSet, MetricsSet)> { let config = SessionConfig::new().with_batch_size(batch_size); let mut runtime_builder = RuntimeEnvBuilder::new(); @@ -477,6 +486,7 @@ async fn execute_shuffle_write( index_file, false, write_buffer_size, + sort_based, ) .expect("Failed to create ShuffleWriterExec"); @@ -541,6 +551,7 @@ fn run_concurrent_shuffle_writes( let memory_limit = args.memory_limit; let write_buffer_size = args.write_buffer_size; let limit = args.limit; + let sort_based = args.mode == "sort"; handles.push(tokio::spawn(async move { execute_shuffle_write( @@ -553,6 +564,7 @@ fn run_concurrent_shuffle_writes( limit, data_file, index_file, + sort_based, ) .await .unwrap() diff --git a/native/shuffle/src/partitioners/mod.rs b/native/shuffle/src/partitioners/mod.rs index a0bc652b4b..c15e44f07a 100644 --- a/native/shuffle/src/partitioners/mod.rs +++ b/native/shuffle/src/partitioners/mod.rs @@ -19,10 +19,12 @@ mod empty_schema; mod multi_partition; mod partitioned_batch_iterator; mod single_partition; +mod sort_based; mod traits; pub(crate) use empty_schema::EmptySchemaShufflePartitioner; pub(crate) use multi_partition::MultiPartitionShuffleRepartitioner; pub(crate) use partitioned_batch_iterator::PartitionedBatchIterator; pub(crate) use single_partition::SinglePartitionShufflePartitioner; +pub(crate) use sort_based::SortBasedPartitioner; pub(crate) use traits::ShufflePartitioner; diff --git a/native/shuffle/src/partitioners/sort_based.rs b/native/shuffle/src/partitioners/sort_based.rs new file mode 100644 index 0000000000..dfd2d68ff4 --- /dev/null +++ b/native/shuffle/src/partitioners/sort_based.rs @@ -0,0 +1,418 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::metrics::ShufflePartitionerMetrics; +use crate::partitioners::ShufflePartitioner; +use crate::writers::BufBatchWriter; +use crate::{comet_partitioning, CometPartitioning, CompressionCodec, ShuffleBlockWriter}; +use arrow::array::{ArrayRef, RecordBatch, UInt32Array}; +use arrow::compute::take; +use arrow::datatypes::SchemaRef; +use datafusion::common::DataFusionError; +use datafusion::execution::disk_manager::RefCountedTempFile; +use datafusion::execution::memory_pool::{MemoryConsumer, MemoryReservation}; +use datafusion::execution::runtime_env::RuntimeEnv; +use datafusion_comet_spark_expr::murmur3::create_murmur3_hashes; +use std::fmt; +use std::fmt::{Debug, Formatter}; +use std::fs::{File, OpenOptions}; +use std::io::{BufWriter, Seek, Write}; +use std::sync::Arc; +use tokio::time::Instant; + +/// Per-partition writer that owns a persistent BufBatchWriter with BatchCoalescer, +/// so small batches are accumulated to batch_size before encoding. +struct PartitionSpillWriter { + /// The BufBatchWriter that coalesces and encodes batches. + /// None until the first batch is written to this partition. + writer: Option>, + /// Temp file handle — kept alive so the file isn't deleted until we're done + _temp_file: Option, + /// Path to the spill file for copying in shuffle_write + spill_path: Option, +} + +impl PartitionSpillWriter { + fn new() -> Self { + Self { + writer: None, + _temp_file: None, + spill_path: None, + } + } + + fn ensure_writer( + &mut self, + runtime: &RuntimeEnv, + shuffle_block_writer: &ShuffleBlockWriter, + write_buffer_size: usize, + batch_size: usize, + ) -> datafusion::common::Result<()> { + if self.writer.is_none() { + let temp_file = runtime + .disk_manager + .create_tmp_file("sort shuffle spill")?; + let path = temp_file.path().to_path_buf(); + let file = OpenOptions::new() + .write(true) + .create(true) + .truncate(true) + .open(&path) + .map_err(|e| { + DataFusionError::Execution(format!("Error creating spill file: {e}")) + })?; + self.writer = Some(BufBatchWriter::new( + shuffle_block_writer.clone(), + file, + write_buffer_size, + batch_size, + )); + self.spill_path = Some(path); + self._temp_file = Some(temp_file); + } + Ok(()) + } + + #[allow(clippy::too_many_arguments)] + fn write_batch( + &mut self, + batch: &RecordBatch, + runtime: &RuntimeEnv, + shuffle_block_writer: &ShuffleBlockWriter, + write_buffer_size: usize, + batch_size: usize, + encode_time: &datafusion::physical_plan::metrics::Time, + write_time: &datafusion::physical_plan::metrics::Time, + ) -> datafusion::common::Result<()> { + self.ensure_writer(runtime, shuffle_block_writer, write_buffer_size, batch_size)?; + self.writer + .as_mut() + .unwrap() + .write(batch, encode_time, write_time)?; + Ok(()) + } + + fn flush( + &mut self, + encode_time: &datafusion::physical_plan::metrics::Time, + write_time: &datafusion::physical_plan::metrics::Time, + ) -> datafusion::common::Result<()> { + if let Some(writer) = &mut self.writer { + writer.flush(encode_time, write_time)?; + } + Ok(()) + } + + fn path(&self) -> Option<&std::path::Path> { + self.spill_path.as_deref() + } +} + +/// A shuffle repartitioner that sorts each batch by partition ID using counting sort, +/// then slices and writes per-partition sub-batches immediately. This avoids +/// per-partition Arrow builders, so memory usage is O(batch_size) regardless of +/// partition count. +/// +/// Each partition has a persistent BufBatchWriter with a BatchCoalescer that accumulates +/// small slices to batch_size before encoding, avoiding per-slice IPC schema overhead. +pub(crate) struct SortBasedPartitioner { + output_data_file: String, + output_index_file: String, + partition_writers: Vec, + shuffle_block_writer: ShuffleBlockWriter, + partitioning: CometPartitioning, + runtime: Arc, + metrics: ShufflePartitionerMetrics, + batch_size: usize, + /// Held to keep the memory reservation alive for the lifetime of this partitioner + _reservation: MemoryReservation, + write_buffer_size: usize, + hashes_buf: Vec, + partition_ids: Vec, + sorted_indices: Vec, + partition_starts: Vec, +} + +impl SortBasedPartitioner { + #[allow(clippy::too_many_arguments)] + pub(crate) fn try_new( + partition: usize, + output_data_file: String, + output_index_file: String, + schema: SchemaRef, + partitioning: CometPartitioning, + metrics: ShufflePartitionerMetrics, + runtime: Arc, + batch_size: usize, + codec: CompressionCodec, + write_buffer_size: usize, + ) -> datafusion::common::Result { + let num_output_partitions = partitioning.partition_count(); + let shuffle_block_writer = ShuffleBlockWriter::try_new(schema.as_ref(), codec.clone())?; + let partition_writers = (0..num_output_partitions) + .map(|_| PartitionSpillWriter::new()) + .collect(); + let reservation = MemoryConsumer::new(format!("SortBasedPartitioner[{partition}]")) + .register(&runtime.memory_pool); + let hashes_buf = match partitioning { + CometPartitioning::Hash(_, _) | CometPartitioning::RoundRobin(_, _) => { + vec![0u32; batch_size] + } + _ => vec![], + }; + Ok(Self { + output_data_file, + output_index_file, + partition_writers, + shuffle_block_writer, + partitioning, + runtime, + metrics, + batch_size, + _reservation: reservation, + write_buffer_size, + hashes_buf, + partition_ids: vec![0u32; batch_size], + sorted_indices: vec![0u32; batch_size], + partition_starts: vec![0usize; num_output_partitions + 1], + }) + } + + fn compute_partition_ids_and_sort( + &mut self, + input: &RecordBatch, + ) -> datafusion::common::Result<()> { + let num_rows = input.num_rows(); + let num_partitions = self.partitioning.partition_count(); + + match &self.partitioning { + CometPartitioning::Hash(exprs, num_output_partitions) => { + let arrays = exprs + .iter() + .map(|expr| expr.evaluate(input)?.into_array(num_rows)) + .collect::>>()?; + let hashes_buf = &mut self.hashes_buf[..num_rows]; + hashes_buf.fill(42_u32); + create_murmur3_hashes(&arrays, hashes_buf)?; + let partition_ids = &mut self.partition_ids[..num_rows]; + for (idx, hash) in hashes_buf.iter().enumerate() { + partition_ids[idx] = + comet_partitioning::pmod(*hash, *num_output_partitions) as u32; + } + } + CometPartitioning::RangePartitioning( + lex_ordering, + _num_output_partitions, + row_converter, + bounds, + ) => { + let arrays = lex_ordering + .iter() + .map(|expr| expr.expr.evaluate(input)?.into_array(num_rows)) + .collect::>>()?; + let row_batch = row_converter.convert_columns(arrays.as_slice())?; + let partition_ids = &mut self.partition_ids[..num_rows]; + row_batch.iter().enumerate().for_each(|(row_idx, row)| { + partition_ids[row_idx] = bounds + .as_slice() + .partition_point(|bound| bound.row() <= row) + as u32; + }); + } + CometPartitioning::RoundRobin(num_output_partitions, max_hash_columns) => { + let num_columns_to_hash = if *max_hash_columns == 0 { + input.num_columns() + } else { + (*max_hash_columns).min(input.num_columns()) + }; + let columns_to_hash: Vec = (0..num_columns_to_hash) + .map(|i| Arc::clone(input.column(i))) + .collect(); + let hashes_buf = &mut self.hashes_buf[..num_rows]; + hashes_buf.fill(42_u32); + create_murmur3_hashes(&columns_to_hash, hashes_buf)?; + let partition_ids = &mut self.partition_ids[..num_rows]; + for (idx, hash) in hashes_buf.iter().enumerate() { + partition_ids[idx] = + comet_partitioning::pmod(*hash, *num_output_partitions) as u32; + } + } + other => { + return Err(DataFusionError::NotImplemented(format!( + "Unsupported shuffle partitioning scheme {other:?}" + ))); + } + } + + // Counting sort + let partition_starts = &mut self.partition_starts[..num_partitions + 1]; + partition_starts.fill(0); + let partition_ids = &self.partition_ids[..num_rows]; + for &pid in partition_ids.iter() { + partition_starts[pid as usize + 1] += 1; + } + for i in 1..=num_partitions { + partition_starts[i] += partition_starts[i - 1]; + } + let sorted_indices = &mut self.sorted_indices[..num_rows]; + let mut cursors = partition_starts.to_vec(); + for (row_idx, &pid) in partition_ids.iter().enumerate() { + let pos = cursors[pid as usize]; + sorted_indices[pos] = row_idx as u32; + cursors[pid as usize] += 1; + } + Ok(()) + } + + fn process_batch(&mut self, input: RecordBatch) -> datafusion::common::Result<()> { + if input.num_rows() == 0 { + return Ok(()); + } + let num_rows = input.num_rows(); + let num_partitions = self.partitioning.partition_count(); + + self.metrics.data_size.add(input.get_array_memory_size()); + self.metrics.baseline.record_output(num_rows); + + { + let repart_start = Instant::now(); + self.compute_partition_ids_and_sort(&input)?; + self.metrics + .repart_time + .add_duration(repart_start.elapsed()); + } + + let sorted_indices = &self.sorted_indices[..num_rows]; + let partition_starts = &self.partition_starts[..num_partitions + 1]; + let indices_array = UInt32Array::from_iter_values(sorted_indices.iter().copied()); + let sorted_batch = RecordBatch::try_new( + input.schema(), + input + .columns() + .iter() + .map(|col| take(col, &indices_array, None)) + .collect::, _>>()?, + )?; + + for partition_id in 0..num_partitions { + let start = partition_starts[partition_id]; + let end = partition_starts[partition_id + 1]; + let len = end - start; + if len == 0 { + continue; + } + let partition_batch = sorted_batch.slice(start, len); + self.partition_writers[partition_id].write_batch( + &partition_batch, + &self.runtime, + &self.shuffle_block_writer, + self.write_buffer_size, + self.batch_size, + &self.metrics.encode_time, + &self.metrics.write_time, + )?; + } + Ok(()) + } +} + +#[async_trait::async_trait] +impl ShufflePartitioner for SortBasedPartitioner { + async fn insert_batch(&mut self, batch: RecordBatch) -> datafusion::common::Result<()> { + let start_time = Instant::now(); + let mut start = 0; + while start < batch.num_rows() { + let end = (start + self.batch_size).min(batch.num_rows()); + let slice = batch.slice(start, end - start); + self.process_batch(slice)?; + start = end; + } + self.metrics.input_batches.add(1); + self.metrics + .baseline + .elapsed_compute() + .add_duration(start_time.elapsed()); + Ok(()) + } + + fn shuffle_write(&mut self) -> datafusion::common::Result<()> { + let start_time = Instant::now(); + let num_output_partitions = self.partition_writers.len(); + let mut offsets = vec![0i64; num_output_partitions + 1]; + let data_file = self.output_data_file.clone(); + let index_file = self.output_index_file.clone(); + + // Flush all partition writers to ensure all data is written to spill files + for writer in &mut self.partition_writers { + writer.flush(&self.metrics.encode_time, &self.metrics.write_time)?; + } + + let output_data = OpenOptions::new() + .write(true) + .create(true) + .truncate(true) + .open(data_file) + .map_err(|e| DataFusionError::Execution(format!("shuffle write error: {e:?}")))?; + let mut output_data = BufWriter::new(output_data); + + for (i, partition_writer) in self + .partition_writers + .iter() + .enumerate() + .take(num_output_partitions) + { + offsets[i] = output_data.stream_position()? as i64; + if let Some(spill_path) = partition_writer.path() { + let mut spill_file = File::open(spill_path)?; + let mut write_timer = self.metrics.write_time.timer(); + std::io::copy(&mut spill_file, &mut output_data)?; + write_timer.stop(); + } + } + + let mut write_timer = self.metrics.write_time.timer(); + output_data.flush()?; + write_timer.stop(); + + offsets[num_output_partitions] = output_data.stream_position()? as i64; + + let mut write_timer = self.metrics.write_time.timer(); + let mut output_index = BufWriter::new( + File::create(index_file) + .map_err(|e| DataFusionError::Execution(format!("shuffle write error: {e:?}")))?, + ); + for offset in offsets { + output_index.write_all(&offset.to_le_bytes()[..])?; + } + output_index.flush()?; + write_timer.stop(); + + self.metrics + .baseline + .elapsed_compute() + .add_duration(start_time.elapsed()); + Ok(()) + } +} + +impl Debug for SortBasedPartitioner { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + f.debug_struct("SortBasedPartitioner") + .field("partitions", &self.partition_writers.len()) + .finish() + } +} diff --git a/native/shuffle/src/shuffle_writer.rs b/native/shuffle/src/shuffle_writer.rs index 8502c79624..91a8e46502 100644 --- a/native/shuffle/src/shuffle_writer.rs +++ b/native/shuffle/src/shuffle_writer.rs @@ -20,7 +20,7 @@ use crate::metrics::ShufflePartitionerMetrics; use crate::partitioners::{ EmptySchemaShufflePartitioner, MultiPartitionShuffleRepartitioner, ShufflePartitioner, - SinglePartitionShufflePartitioner, + SinglePartitionShufflePartitioner, SortBasedPartitioner, }; use crate::{CometPartitioning, CompressionCodec}; use async_trait::async_trait; @@ -67,6 +67,8 @@ pub struct ShuffleWriterExec { tracing_enabled: bool, /// Size of the write buffer in bytes write_buffer_size: usize, + /// Whether to use sort-based partitioning + sort_based: bool, } impl ShuffleWriterExec { @@ -80,6 +82,7 @@ impl ShuffleWriterExec { output_index_file: String, tracing_enabled: bool, write_buffer_size: usize, + sort_based: bool, ) -> Result { let cache = Arc::new(PlanProperties::new( EquivalenceProperties::new(Arc::clone(&input.schema())), @@ -98,6 +101,7 @@ impl ShuffleWriterExec { codec, tracing_enabled, write_buffer_size, + sort_based, }) } } @@ -158,6 +162,7 @@ impl ExecutionPlan for ShuffleWriterExec { self.output_index_file.clone(), self.tracing_enabled, self.write_buffer_size, + self.sort_based, )?)), _ => panic!("ShuffleWriterExec wrong number of children"), } @@ -185,6 +190,7 @@ impl ExecutionPlan for ShuffleWriterExec { self.codec.clone(), self.tracing_enabled, self.write_buffer_size, + self.sort_based, ) .map_err(|e| ArrowError::ExternalError(Box::new(e))), ) @@ -205,6 +211,7 @@ async fn external_shuffle( codec: CompressionCodec, tracing_enabled: bool, write_buffer_size: usize, + sort_based: bool, ) -> Result { let schema = input.schema(); @@ -229,6 +236,18 @@ async fn external_shuffle( codec, write_buffer_size, )?), + _ if sort_based => Box::new(SortBasedPartitioner::try_new( + partition, + output_data_file, + output_index_file, + Arc::clone(&schema), + partitioning, + metrics, + context.runtime_env(), + context.session_config().batch_size(), + codec, + write_buffer_size, + )?), _ => Box::new(MultiPartitionShuffleRepartitioner::try_new( partition, output_data_file, @@ -312,34 +331,61 @@ mod test { #[test] #[cfg_attr(miri, ignore)] // miri can't call foreign function `ZSTD_createCCtx` fn test_single_partition_shuffle_writer() { - shuffle_write_test(1000, 100, 1, None); - shuffle_write_test(10000, 10, 1, None); + shuffle_write_test(1000, 100, 1, None, false); + shuffle_write_test(10000, 10, 1, None, false); } #[test] #[cfg_attr(miri, ignore)] // miri can't call foreign function `ZSTD_createCCtx` fn test_insert_larger_batch() { - shuffle_write_test(10000, 1, 16, None); + shuffle_write_test(10000, 1, 16, None, false); } #[test] #[cfg_attr(miri, ignore)] // miri can't call foreign function `ZSTD_createCCtx` fn test_insert_smaller_batch() { - shuffle_write_test(1000, 1, 16, None); - shuffle_write_test(1000, 10, 16, None); + shuffle_write_test(1000, 1, 16, None, false); + shuffle_write_test(1000, 10, 16, None, false); } #[test] #[cfg_attr(miri, ignore)] // miri can't call foreign function `ZSTD_createCCtx` fn test_large_number_of_partitions() { - shuffle_write_test(10000, 10, 200, Some(10 * 1024 * 1024)); - shuffle_write_test(10000, 10, 2000, Some(10 * 1024 * 1024)); + shuffle_write_test(10000, 10, 200, Some(10 * 1024 * 1024), false); + shuffle_write_test(10000, 10, 2000, Some(10 * 1024 * 1024), false); } #[test] #[cfg_attr(miri, ignore)] // miri can't call foreign function `ZSTD_createCCtx` fn test_large_number_of_partitions_spilling() { - shuffle_write_test(10000, 100, 200, Some(10 * 1024 * 1024)); + shuffle_write_test(10000, 100, 200, Some(10 * 1024 * 1024), false); + } + + #[test] + #[cfg_attr(miri, ignore)] // miri can't call foreign function `ZSTD_createCCtx` + fn test_sort_based_basic() { + shuffle_write_test(1000, 100, 1, None, true); + shuffle_write_test(10000, 10, 1, None, true); + } + + #[test] + #[cfg_attr(miri, ignore)] // miri can't call foreign function `ZSTD_createCCtx` + fn test_sort_based_insert_larger_batch() { + shuffle_write_test(10000, 1, 16, None, true); + } + + #[test] + #[cfg_attr(miri, ignore)] // miri can't call foreign function `ZSTD_createCCtx` + fn test_sort_based_insert_smaller_batch() { + shuffle_write_test(1000, 1, 16, None, true); + shuffle_write_test(1000, 10, 16, None, true); + } + + #[test] + #[cfg_attr(miri, ignore)] // miri can't call foreign function `ZSTD_createCCtx` + fn test_sort_based_large_number_of_partitions() { + shuffle_write_test(10000, 10, 200, Some(10 * 1024 * 1024), true); + shuffle_write_test(10000, 10, 2000, Some(10 * 1024 * 1024), true); } #[tokio::test] @@ -403,6 +449,7 @@ mod test { num_batches: usize, num_partitions: usize, memory_limit: Option, + sort_based: bool, ) { let batch = create_batch(batch_size); @@ -467,6 +514,7 @@ mod test { "/tmp/index.out".to_string(), false, 1024 * 1024, // write_buffer_size: 1MB default + sort_based, ) .unwrap(); @@ -526,6 +574,7 @@ mod test { index_file.clone(), false, 1024 * 1024, + false, ) .unwrap(); @@ -730,6 +779,7 @@ mod test { index_file.to_str().unwrap().to_string(), false, 1024 * 1024, + false, ) .unwrap(); @@ -818,6 +868,7 @@ mod test { index_file.to_str().unwrap().to_string(), false, 1024 * 1024, + false, ) .unwrap(); diff --git a/native/shuffle/src/writers/spill.rs b/native/shuffle/src/writers/spill.rs index c16caddbf9..46f8588cf9 100644 --- a/native/shuffle/src/writers/spill.rs +++ b/native/shuffle/src/writers/spill.rs @@ -19,6 +19,7 @@ use super::ShuffleBlockWriter; use crate::metrics::ShufflePartitionerMetrics; use crate::partitioners::PartitionedBatchIterator; use crate::writers::buf_batch_writer::BufBatchWriter; +use arrow::array::RecordBatch; use datafusion::common::DataFusionError; use datafusion::execution::disk_manager::RefCountedTempFile; use datafusion::execution::runtime_env::RuntimeEnv; @@ -113,6 +114,32 @@ impl PartitionWriter { } } + /// Write a single batch to this partition's spill file. + #[allow(dead_code)] // TODO: Remove once SortBasedPartitioner is wired in + pub(crate) fn spill_batch( + &mut self, + batch: &RecordBatch, + runtime: &RuntimeEnv, + metrics: &ShufflePartitionerMetrics, + write_buffer_size: usize, + batch_size: usize, + ) -> datafusion::common::Result { + if batch.num_rows() == 0 { + return Ok(0); + } + self.ensure_spill_file_created(runtime)?; + let mut buf_batch_writer = BufBatchWriter::new( + &mut self.shuffle_block_writer, + &mut self.spill_file.as_mut().unwrap().file, + write_buffer_size, + batch_size, + ); + let bytes_written = + buf_batch_writer.write(batch, &metrics.encode_time, &metrics.write_time)?; + buf_batch_writer.flush(&metrics.encode_time, &metrics.write_time)?; + Ok(bytes_written) + } + pub(crate) fn path(&self) -> Option<&std::path::Path> { self.spill_file .as_ref() diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometNativeShuffleWriter.scala b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometNativeShuffleWriter.scala index f27d021ac4..0d7091eae0 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometNativeShuffleWriter.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometNativeShuffleWriter.scala @@ -192,6 +192,7 @@ class CometNativeShuffleWriter[K, V]( CometConf.COMET_EXEC_SHUFFLE_COMPRESSION_ZSTD_LEVEL.get) shuffleWriterBuilder.setWriteBufferSize( CometConf.COMET_SHUFFLE_WRITE_BUFFER_SIZE.get().min(Int.MaxValue).toInt) + shuffleWriterBuilder.setSortBased(CometConf.COMET_SHUFFLE_SORT_BASED.get()) outputPartitioning match { case p if isSinglePartitioning(p) =>