Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions common/src/main/scala/org/apache/comet/CometConf.scala
Original file line number Diff line number Diff line change
Expand Up @@ -534,6 +534,16 @@ object CometConf extends ShimCometConf {
.checkValue(v => v > 0, "Write buffer size must be positive")
.createWithDefault(1)

val COMET_SHUFFLE_SORT_BASED: ConfigEntry[Boolean] =
conf(s"$COMET_EXEC_CONFIG_PREFIX.shuffle.sort_based")
.category(CATEGORY_SHUFFLE)
.doc(
"When enabled, uses sort-based repartitioning for native shuffle. " +
"This avoids per-partition memory overhead from builders, making it more " +
"memory-efficient for large partition counts. Default is false (uses buffered mode).")
.booleanConf
.createWithDefault(false)

val COMET_SHUFFLE_PREFER_DICTIONARY_RATIO: ConfigEntry[Double] = conf(
"spark.comet.shuffle.preferDictionary.ratio")
.category(CATEGORY_SHUFFLE)
Expand Down
2 changes: 2 additions & 0 deletions native/core/src/execution/planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1371,6 +1371,7 @@ impl PhysicalPlanner {
}?;

let write_buffer_size = writer.write_buffer_size as usize;
let sort_based = writer.sort_based;
let shuffle_writer = Arc::new(ShuffleWriterExec::try_new(
Arc::clone(&child.native_plan),
partitioning,
Expand All @@ -1379,6 +1380,7 @@ impl PhysicalPlanner {
writer.output_index_file.clone(),
writer.tracing_enabled,
write_buffer_size,
sort_based,
)?);

Ok((
Expand Down
4 changes: 4 additions & 0 deletions native/proto/src/proto/operator.proto
Original file line number Diff line number Diff line change
Expand Up @@ -294,6 +294,10 @@ message ShuffleWriter {
// Size of the write buffer in bytes used when writing shuffle data to disk.
// Larger values may improve write performance but use more memory.
int32 write_buffer_size = 8;
// When true, uses sort-based repartitioning for native shuffle.
// This avoids per-partition memory overhead from builders, making it more
// memory-efficient for large partition counts.
bool sort_based = 9;
}

message ParquetWriter {
Expand Down
1 change: 1 addition & 0 deletions native/shuffle/benches/shuffle_writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,7 @@ fn create_shuffle_writer_exec(
"/tmp/index.out".to_string(),
false,
1024 * 1024,
false,
)
.unwrap()
}
Expand Down
12 changes: 12 additions & 0 deletions native/shuffle/src/bin/shuffle_bench.rs
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,11 @@ struct Args {
/// Each task reads the same input and writes to its own output files.
#[arg(long, default_value_t = 1)]
concurrent_tasks: usize,

/// Shuffle mode: 'buffered' buffers all rows before writing (default),
/// 'sort' sorts each batch by partition ID and writes immediately.
#[arg(long, default_value = "buffered")]
mode: String,
}

fn main() {
Expand Down Expand Up @@ -141,6 +146,7 @@ fn main() {
println!("Partitioning: {}", args.partitioning);
println!("Partitions: {}", args.partitions);
println!("Codec: {:?}", codec);
println!("Mode: {}", args.mode);
println!("Hash columns: {:?}", hash_col_indices);
if let Some(mem_limit) = args.memory_limit {
println!("Memory limit: {}", format_bytes(mem_limit));
Expand Down Expand Up @@ -403,6 +409,7 @@ fn run_shuffle_write(
let rt = tokio::runtime::Runtime::new().unwrap();
rt.block_on(async {
let start = Instant::now();
let sort_based = args.mode == "sort";
let (shuffle_metrics, input_metrics) = execute_shuffle_write(
input_path.to_str().unwrap(),
codec.clone(),
Expand All @@ -413,6 +420,7 @@ fn run_shuffle_write(
args.limit,
data_file.to_string(),
index_file.to_string(),
sort_based,
)
.await
.unwrap();
Expand All @@ -436,6 +444,7 @@ async fn execute_shuffle_write(
limit: usize,
data_file: String,
index_file: String,
sort_based: bool,
) -> datafusion::common::Result<(MetricsSet, MetricsSet)> {
let config = SessionConfig::new().with_batch_size(batch_size);
let mut runtime_builder = RuntimeEnvBuilder::new();
Expand Down Expand Up @@ -477,6 +486,7 @@ async fn execute_shuffle_write(
index_file,
false,
write_buffer_size,
sort_based,
)
.expect("Failed to create ShuffleWriterExec");

Expand Down Expand Up @@ -541,6 +551,7 @@ fn run_concurrent_shuffle_writes(
let memory_limit = args.memory_limit;
let write_buffer_size = args.write_buffer_size;
let limit = args.limit;
let sort_based = args.mode == "sort";

handles.push(tokio::spawn(async move {
execute_shuffle_write(
Expand All @@ -553,6 +564,7 @@ fn run_concurrent_shuffle_writes(
limit,
data_file,
index_file,
sort_based,
)
.await
.unwrap()
Expand Down
2 changes: 2 additions & 0 deletions native/shuffle/src/partitioners/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,12 @@ mod empty_schema;
mod multi_partition;
mod partitioned_batch_iterator;
mod single_partition;
mod sort_based;
mod traits;

pub(crate) use empty_schema::EmptySchemaShufflePartitioner;
pub(crate) use multi_partition::MultiPartitionShuffleRepartitioner;
pub(crate) use partitioned_batch_iterator::PartitionedBatchIterator;
pub(crate) use single_partition::SinglePartitionShufflePartitioner;
pub(crate) use sort_based::SortBasedPartitioner;
pub(crate) use traits::ShufflePartitioner;
Loading
Loading