From 1ad360b1ad16f819d5ec2f2b3805dc7b9056f543 Mon Sep 17 00:00:00 2001 From: halfman510 Date: Mon, 18 May 2026 12:20:48 +0000 Subject: [PATCH 1/3] feat: support `reducescatter` with OpenMPI backend implementation Modified file: - `include/comm.h` Added files: - `src/base/reduce_scatter.h` - `src/ompi/impl/reduce_scatter.h` - `examples/reduce_scatter.cc` --- examples/reduce_scatter.cc | 155 +++++++++++++++++++++++++++++++++ include/comm.h | 7 +- src/base/reduce_scatter.h | 58 ++++++++++++ src/ompi/impl/reduce_scatter.h | 102 ++++++++++++++++++++++ 4 files changed, 321 insertions(+), 1 deletion(-) create mode 100644 examples/reduce_scatter.cc create mode 100644 src/base/reduce_scatter.h create mode 100644 src/ompi/impl/reduce_scatter.h diff --git a/examples/reduce_scatter.cc b/examples/reduce_scatter.cc new file mode 100644 index 0000000..30a8f61 --- /dev/null +++ b/examples/reduce_scatter.cc @@ -0,0 +1,155 @@ +/** + * InfiniCCL Example: ReduceScatter + * * This example demonstrates the planned API for performing a + * collective sum-reduction across multiple GPUs and nodes. + */ + +#include + +#include +#include + +// Public API +#include "infiniccl.h" + +// Example-Specific Utilities +#include "utils.h" + +// Internal Headers (Accessible via example-specific include paths, technically +// not public APIs) +#include "backend_manifest.h" +#include "device.h" +#include "runtime.h" +#include "traits.h" + +using namespace infini::ccl; + +void RunReduceScatterExample(int argc, char **argv, int warmup_iter, + int profile_iter, const size_t kRecvCount) { + constexpr Device::Type kDevType = + ListGetBest(EnabledDevices{}); + using Rt = Runtime; + + CHECK_INFINI(infiniInit(&argc, &argv)); + + int rank, size; + CHECK_INFINI(infiniGetRank(&rank)); + CHECK_INFINI(infiniGetSize(&size)); + + char hostname[256]; + gethostname(hostname, sizeof(hostname)); + + // Map local rank to GPU device. + // Note: this is just for info printing. In practice, this part is not needed. + const char *local_rank_str = std::getenv("OMPI_COMM_WORLD_LOCAL_RANK"); + int local_rank = 0; + if (local_rank_str != nullptr) { + local_rank = std::atoi(local_rank_str); + } + + std::cout << "[Rank " << rank << "] Host: " << hostname + << " | GPU: " << Device::StringFromType(kDevType) << " " + << " | Device " << local_rank << std::endl; + + // Setup Communicator + infiniComm_t comm = nullptr; + CHECK_INFINI(infiniCommInitAll(&comm, size, nullptr)); + + // ReduceScatter requires send_count = recv_count * world_size. + const size_t kSendCount = kRecvCount * static_cast(size); + + // Prepare Data + std::vector h_send(kSendCount); + std::vector h_recv(kRecvCount, 0.0f); + + // Initialize: each rank provides its (rank + 1) as data. + for (size_t i = 0; i < kSendCount; ++i) { + h_send[i] = static_cast(rank + 1); + } + + float *d_send, *d_recv; + size_t send_bytes = kSendCount * sizeof(*d_send); + size_t recv_bytes = kRecvCount * sizeof(*d_recv); + + CHECK_RT(Rt, Rt::Malloc(&d_send, send_bytes)); + CHECK_RT(Rt, Rt::Malloc(&d_recv, recv_bytes)); + CHECK_RT(Rt, Rt::Memcpy(d_send, h_send.data(), send_bytes, + Rt::MemcpyHostToDevice)); + CHECK_RT(Rt, Rt::Memcpy(d_recv, h_recv.data(), recv_bytes, + Rt::MemcpyHostToDevice)); + + if (rank == 0) { + std::cout << "\n=== Performing ReduceScatter on GPU Memory ===" + << std::endl; + std::cout << "Recv data size per rank: " << kRecvCount << " floats (" + << recv_bytes / 1024 / 1024 << " MB)" << std::endl; + std::cout << "Send data size per rank: " << kSendCount << " floats (" + << send_bytes / 1024 / 1024 << " MB)" << std::endl; + std::cout << "Operation: Sum" << std::endl; + std::cout << "Warm-up iterations: " << warmup_iter << std::endl; + std::cout << "Profile iterations: " << profile_iter << std::endl; + } + + CHECK_RT(Rt, Rt::StreamSynchronize(nullptr)); + + // Warm-up and D2H transfer the answer. + CHECK_INFINI(infiniReduceScatter(d_send, d_recv, kRecvCount, infiniFloat32, + infiniSum, comm, nullptr)); + CHECK_RT(Rt, Rt::Memcpy(h_recv.data(), d_recv, recv_bytes, + Rt::MemcpyDeviceToHost)); + + for (int i = 1; i < warmup_iter; ++i) { + CHECK_INFINI(infiniReduceScatter(d_send, d_recv, kRecvCount, infiniFloat32, + infiniSum, comm, nullptr)); + } + CHECK_RT(Rt, Rt::StreamSynchronize(nullptr)); + + // Profiling + Timer timer; + + for (int i = 0; i < profile_iter; ++i) { + CHECK_INFINI(infiniReduceScatter(d_send, d_recv, kRecvCount, infiniFloat32, + infiniSum, comm, nullptr)); + } + + CHECK_RT(Rt, Rt::StreamSynchronize(nullptr)); + CHECK_RT(Rt, Rt::Memcpy(h_recv.data(), d_recv, recv_bytes, + Rt::MemcpyDeviceToHost)); + + double elapsed = timer.elapsed_ms() / static_cast(profile_iter); + + // Result Validation: + float expected = 0.0f; + for (int r = 0; r < size; ++r) { + expected += static_cast(r + 1); + } + + Validator::ValidateResult(h_recv.data(), kRecvCount, expected, rank); + + // Metrics Reporting (Only from rank 0 for cleaner output) + if (rank == 0) { + Metrics metrics{elapsed, recv_bytes, size}; + metrics.Print(); + } + + // Cleanup + CHECK_RT(Rt, Rt::Free(d_send)); + CHECK_RT(Rt, Rt::Free(d_recv)); + + CHECK_INFINI(infiniCommDestroy(comm)); + CHECK_INFINI(infiniFinalize()); + + if (rank == 0) { + std::cout << "InfiniCCL finalized." << std::endl; + } +} + +int main(int argc, char **argv) { + int warmup_iters = 2; + int profile_iters = 20; + size_t recv_count = 1 << 20; + + RunReduceScatterExample(argc, argv, warmup_iters, profile_iters, recv_count); + + return EXIT_SUCCESS; +} diff --git a/include/comm.h b/include/comm.h index 0029345..b278063 100644 --- a/include/comm.h +++ b/include/comm.h @@ -45,8 +45,13 @@ infiniResult_t infiniAllGather(const void *sendbuff, void *recvbuff, size_t count, infiniDataType_t datatype, infiniComm_t comm, void *stream); +infiniResult_t infiniReduceScatter(const void *sendbuff, void *recvbuff, + size_t recvcount, infiniDataType_t datatype, + infiniRedOp_t op, infiniComm_t comm, + void *stream); + #ifdef __cplusplus } #endif -#endif // INFINI_CCL_COMM_H_ +#endif // INFINI_CCL_COMM_H_ diff --git a/src/base/reduce_scatter.h b/src/base/reduce_scatter.h new file mode 100644 index 0000000..c06150b --- /dev/null +++ b/src/base/reduce_scatter.h @@ -0,0 +1,58 @@ +#ifndef INFINI_CCL_BASE_REDUCE_SCATTER_H_ +#define INFINI_CCL_BASE_REDUCE_SCATTER_H_ + +#include "comm_impl.h" +#include "communicator.h" +#include "logging.h" +#include "operation.h" +#include "return_status_impl.h" + +namespace infini::ccl { + +template +struct ReduceScatterImpl; + +class ReduceScatter : public Operation { + public: + template + static ReturnStatus Execute(const void *send_buff, void *recv_buff, + size_t recv_count, DataType datatype, + ReductionOpType op, void *comm_handle, + void *stream) { + if (HasInvalidArgs(send_buff, recv_buff, datatype, op, comm_handle)) { + return ReturnStatus::kInvalidArgument; + } + auto *comm = static_cast(comm_handle); + return ReduceScatterImpl::Apply( + send_buff, recv_buff, recv_count, datatype, op, comm, stream); + } + + private: + static bool HasInvalidArgs(const void *send_buff, void *recv_buff, + DataType datatype, ReductionOpType op, + void *comm_handle) { + if (!comm_handle) { + // TODO(lzm): change to use `glog`. + LOG("Invalid communicator handle for `ReduceScatter`."); + return true; + } + if (!send_buff || !recv_buff) { + LOG("Invalid buffer pointer for `ReduceScatter`."); + return true; + } + if (op < ReductionOpType::kSum || op >= ReductionOpType::kNumRedOps) { + LOG("Invalid reduction operation for `ReduceScatter`."); + return true; + } + if (datatype < DataType::kChar || datatype >= DataType::kNumTypes) { + LOG("Invalid data type for `ReduceScatter`."); + return true; + } + return false; + } +}; + +} // namespace infini::ccl + +#endif // INFINI_CCL_BASE_REDUCE_SCATTER_H_ diff --git a/src/ompi/impl/reduce_scatter.h b/src/ompi/impl/reduce_scatter.h new file mode 100644 index 0000000..57fc2ce --- /dev/null +++ b/src/ompi/impl/reduce_scatter.h @@ -0,0 +1,102 @@ +#ifndef INFINI_CCL_OMPI_IMPL_REDUCE_SCATTER_H_ +#define INFINI_CCL_OMPI_IMPL_REDUCE_SCATTER_H_ + +#include + +#include "base/reduce_scatter.h" +#include "communicator.h" +#include "dispatcher.h" +#include "logging.h" +#include "ompi/checks.h" +#include "ompi/comm_instance.h" +#include "ompi/type_map.h" + +namespace infini::ccl { + +template +class ReduceScatterImpl { + public: + static ReturnStatus Apply(const void *send_buff, void *recv_buff, + size_t recv_count, DataType data_type, + ReductionOpType op, Communicator *comm, + void *stream) { + constexpr Device::Type kDev = + ListGetBest(ActiveDevices{}); + using Rt = Runtime; + + auto *inst = static_cast(comm->inter_comm()); + + if (!inst || inst->handle == MPI_COMM_NULL) { + LOG("Invalid OpenMPI communicator instance for `ReduceScatter`."); + return ReturnStatus::kInternalError; + } + + MPI_Datatype mpi_type = DataTypeToOmpiType(data_type); + MPI_Op mpi_op = RedOpToOmpiOp(op); + + size_t world_size = static_cast(comm->size()); + size_t type_size = kDataTypeToSize.at(data_type); + size_t send_count = recv_count * world_size; + size_t send_bytes = send_count * type_size; + size_t recv_bytes = recv_count * type_size; + + // Handle GPU Memory (Staging Pattern) + // Note: we simply use host-staging for now. + void *host_sendbuf = malloc(send_bytes); + void *host_recvbuf = malloc(recv_bytes); + if (!host_sendbuf || !host_recvbuf) { + free(host_sendbuf); + free(host_recvbuf); + LOG("Failed to allocate host buffers for `ReduceScatter` staging."); + return ReturnStatus::kSystemError; + } + + CHECK_STATUS(Rt, Rt::Memcpy(host_sendbuf, send_buff, send_bytes, + Rt::MemcpyDeviceToHost)); + CHECK_STATUS(Rt, Rt::StreamSynchronize(static_cast(stream))); + + if (recv_count > static_cast(std::numeric_limits::max())) { + LOG("recv_count exceeds MPI int range for ReduceScatter."); + free(host_sendbuf); + free(host_recvbuf); + return ReturnStatus::kInvalidArgument; + } + int mpi_recv_count = static_cast(recv_count); + + INFINI_CHECK_MPI(MPI_Reduce_scatter_block(host_sendbuf, host_recvbuf, + mpi_recv_count, mpi_type, mpi_op, + inst->handle)); + + if (op == ReductionOpType::kAvg) { + float scale = 1.0f / static_cast(world_size); + + DispatchFunc(data_type, [&](auto dtype) { + using T = typename decltype(dtype)::type; + + T *typed_buf = static_cast(host_recvbuf); + + // Simply do the averaging on the CPU before the H2D copy. + for (size_t i = 0; i < recv_count; ++i) { + // TODO(lzm): should later use the unified `Cast` function instead of + // static_cast to support CPU custom types. + typed_buf[i] *= static_cast(scale); + } + }); + } + + CHECK_STATUS(Rt, Rt::Memcpy(recv_buff, host_recvbuf, recv_bytes, + Rt::MemcpyHostToDevice)); + + free(host_sendbuf); + free(host_recvbuf); + + return ReturnStatus::kSuccess; + } +}; + +template <> +struct BackendEnabled : std::true_type {}; + +} // namespace infini::ccl + +#endif // INFINI_CCL_OMPI_IMPL_REDUCE_SCATTER_H_ From 6761744b5eae9d68c6c7455a6eb65fbe80248274 Mon Sep 17 00:00:00 2001 From: halfman510 Date: Mon, 18 May 2026 12:38:00 +0000 Subject: [PATCH 2/3] fix: correct the comment style for the output logs in `src/ompi/impl/reduce_scatter.h` --- src/ompi/impl/reduce_scatter.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/ompi/impl/reduce_scatter.h b/src/ompi/impl/reduce_scatter.h index 57fc2ce..88854fc 100644 --- a/src/ompi/impl/reduce_scatter.h +++ b/src/ompi/impl/reduce_scatter.h @@ -56,7 +56,7 @@ class ReduceScatterImpl { CHECK_STATUS(Rt, Rt::StreamSynchronize(static_cast(stream))); if (recv_count > static_cast(std::numeric_limits::max())) { - LOG("recv_count exceeds MPI int range for ReduceScatter."); + LOG("recv_count exceeds MPI int range for `ReduceScatter`."); free(host_sendbuf); free(host_recvbuf); return ReturnStatus::kInvalidArgument; From af38416469cad0bb79211989b0f612d4ebd5430b Mon Sep 17 00:00:00 2001 From: halfman510 Date: Tue, 19 May 2026 05:59:38 +0000 Subject: [PATCH 3/3] fix: correct the code formatting in the comments of `examples/reduce_scatter.cc`. --- examples/reduce_scatter.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/reduce_scatter.cc b/examples/reduce_scatter.cc index 30a8f61..f8b086a 100644 --- a/examples/reduce_scatter.cc +++ b/examples/reduce_scatter.cc @@ -55,7 +55,7 @@ void RunReduceScatterExample(int argc, char **argv, int warmup_iter, infiniComm_t comm = nullptr; CHECK_INFINI(infiniCommInitAll(&comm, size, nullptr)); - // ReduceScatter requires send_count = recv_count * world_size. + // ReduceScatter requires `send_count = recv_count * world_size`. const size_t kSendCount = kRecvCount * static_cast(size); // Prepare Data