Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
155 changes: 155 additions & 0 deletions examples/reduce_scatter.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,155 @@
/**
 * InfiniCCL Example: ReduceScatter
 *
 * This example demonstrates the planned API for performing a
 * collective sum-reduction across multiple GPUs and nodes.
 */

#include <unistd.h>

#include <iostream>
#include <vector>

// Public API
#include "infiniccl.h"

// Example-Specific Utilities
#include "utils.h"

// Internal Headers (Accessible via example-specific include paths, technically
// not public APIs)
#include "backend_manifest.h"
#include "device.h"
#include "runtime.h"
#include "traits.h"

using namespace infini::ccl;

// Runs the ReduceScatter example end to end: initializes InfiniCCL, fills a
// send buffer with `rank + 1`, performs warm-up and timed `infiniReduceScatter`
// calls with `infiniSum`, validates the received block against
// `1 + 2 + ... + world_size`, and prints metrics from rank 0.
//
// Parameters:
//   argc, argv    Forwarded to `infiniInit`.
//   warmup_iter   Number of untimed warm-up collectives. One collective is
//                 always issued before timing, even when this is <= 1.
//   profile_iter  Number of timed collectives; the reported time is the
//                 average over these iterations.
//   kRecvCount    Number of floats each rank receives. The send buffer holds
//                 `kRecvCount * world_size` floats.
void RunReduceScatterExample(int argc, char **argv, int warmup_iter,
                             int profile_iter, const size_t kRecvCount) {
  constexpr Device::Type kDevType =
      ListGetBest<DevicePriority>(EnabledDevices{});
  using Rt = Runtime<kDevType>;

  CHECK_INFINI(infiniInit(&argc, &argv));

  int rank, size;
  CHECK_INFINI(infiniGetRank(&rank));
  CHECK_INFINI(infiniGetSize(&size));

  // `gethostname` may leave the buffer unterminated when the name is
  // truncated, so zero-initialize and force termination before printing.
  char hostname[256] = {};
  if (gethostname(hostname, sizeof(hostname)) != 0) {
    hostname[0] = '\0';
  }
  hostname[sizeof(hostname) - 1] = '\0';

  // Map local rank to GPU device.
  // Note: this is just for info printing. In practice, this part is not needed.
  const char *local_rank_str = std::getenv("OMPI_COMM_WORLD_LOCAL_RANK");
  int local_rank = 0;
  if (local_rank_str != nullptr) {
    local_rank = std::atoi(local_rank_str);
  }

  std::cout << "[Rank " << rank << "] Host: " << hostname
            << " | GPU: " << Device::StringFromType(kDevType) << " "
            << " | Device " << local_rank << std::endl;

  // Setup Communicator
  infiniComm_t comm = nullptr;
  CHECK_INFINI(infiniCommInitAll(&comm, size, nullptr));

  // ReduceScatter requires `send_count = recv_count * world_size`.
  const size_t kSendCount = kRecvCount * static_cast<size_t>(size);

  // Prepare Data
  // Initialize: each rank provides its (rank + 1) as data.
  std::vector<float> h_send(kSendCount, static_cast<float>(rank + 1));
  std::vector<float> h_recv(kRecvCount, 0.0f);

  float *d_send = nullptr;
  float *d_recv = nullptr;
  const size_t send_bytes = kSendCount * sizeof(*d_send);
  const size_t recv_bytes = kRecvCount * sizeof(*d_recv);

  CHECK_RT(Rt, Rt::Malloc(&d_send, send_bytes));
  CHECK_RT(Rt, Rt::Malloc(&d_recv, recv_bytes));
  CHECK_RT(Rt, Rt::Memcpy(d_send, h_send.data(), send_bytes,
                          Rt::MemcpyHostToDevice));
  CHECK_RT(Rt, Rt::Memcpy(d_recv, h_recv.data(), recv_bytes,
                          Rt::MemcpyHostToDevice));

  if (rank == 0) {
    std::cout << "\n=== Performing ReduceScatter on GPU Memory ==="
              << std::endl;
    std::cout << "Recv data size per rank: " << kRecvCount << " floats ("
              << recv_bytes / 1024 / 1024 << " MB)" << std::endl;
    std::cout << "Send data size per rank: " << kSendCount << " floats ("
              << send_bytes / 1024 / 1024 << " MB)" << std::endl;
    std::cout << "Operation: Sum" << std::endl;
    std::cout << "Warm-up iterations: " << warmup_iter << std::endl;
    std::cout << "Profile iterations: " << profile_iter << std::endl;
  }

  CHECK_RT(Rt, Rt::StreamSynchronize(nullptr));

  // Warm-up and D2H transfer the answer.
  CHECK_INFINI(infiniReduceScatter(d_send, d_recv, kRecvCount, infiniFloat32,
                                   infiniSum, comm, nullptr));
  CHECK_RT(Rt, Rt::Memcpy(h_recv.data(), d_recv, recv_bytes,
                          Rt::MemcpyDeviceToHost));

  // The first warm-up iteration already ran above, hence the start at 1.
  for (int i = 1; i < warmup_iter; ++i) {
    CHECK_INFINI(infiniReduceScatter(d_send, d_recv, kRecvCount, infiniFloat32,
                                     infiniSum, comm, nullptr));
  }
  CHECK_RT(Rt, Rt::StreamSynchronize(nullptr));

  // Profiling
  Timer timer;

  for (int i = 0; i < profile_iter; ++i) {
    CHECK_INFINI(infiniReduceScatter(d_send, d_recv, kRecvCount, infiniFloat32,
                                     infiniSum, comm, nullptr));
  }

  CHECK_RT(Rt, Rt::StreamSynchronize(nullptr));

  // Read the timer right after the synchronization so the device-to-host
  // result copy below is not counted as collective time. Guard against a
  // non-positive iteration count so the average stays well defined.
  const double elapsed =
      profile_iter > 0
          ? timer.elapsed_ms() / static_cast<double>(profile_iter)
          : 0.0;

  CHECK_RT(Rt, Rt::Memcpy(h_recv.data(), d_recv, recv_bytes,
                          Rt::MemcpyDeviceToHost));

  // Result Validation: every element should equal the sum of (r + 1) over all
  // ranks, since each rank contributed a buffer filled with (rank + 1).
  float expected = 0.0f;
  for (int r = 0; r < size; ++r) {
    expected += static_cast<float>(r + 1);
  }

  Validator::ValidateResult(h_recv.data(), kRecvCount, expected, rank);

  // Metrics Reporting (Only from rank 0 for cleaner output)
  if (rank == 0) {
    Metrics metrics{elapsed, recv_bytes, size};
    metrics.Print();
  }

  // Cleanup
  CHECK_RT(Rt, Rt::Free(d_send));
  CHECK_RT(Rt, Rt::Free(d_recv));

  CHECK_INFINI(infiniCommDestroy(comm));
  CHECK_INFINI(infiniFinalize());

  if (rank == 0) {
    std::cout << "InfiniCCL finalized." << std::endl;
  }
}

// Entry point: runs the ReduceScatter example with fixed benchmark settings
// (2 warm-up iterations, 20 profiled iterations, 2^20 floats received per
// rank).
int main(int argc, char **argv) {
  const int warmup_iters{2};
  const int profile_iters{20};
  const size_t recv_count{static_cast<size_t>(1 << 20)};

  RunReduceScatterExample(argc, argv, warmup_iters, profile_iters, recv_count);

  return EXIT_SUCCESS;
}
7 changes: 6 additions & 1 deletion include/comm.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,13 @@ infiniResult_t infiniAllGather(const void *sendbuff, void *recvbuff,
size_t count, infiniDataType_t datatype,
infiniComm_t comm, void *stream);

// Reduce-scatter collective: combines the `sendbuff` contents of all ranks in
// `comm` with `op` and leaves `recvcount` elements of the result in each
// rank's `recvbuff`.
//
//   sendbuff   Input buffer; expected to hold `recvcount * world_size`
//              elements of `datatype`.
//   recvbuff   Output buffer for `recvcount` elements of `datatype`.
//   recvcount  Number of elements each rank receives.
//   datatype   Element type of both buffers.
//   op         Reduction operation applied across ranks.
//   comm       Communicator handle.
//   stream     Backend stream handle; may be null.
//              NOTE(review): exact stream semantics are backend-defined —
//              confirm against the backend implementation.
//
// Returns an `infiniResult_t` status code.
infiniResult_t infiniReduceScatter(const void *sendbuff, void *recvbuff,
size_t recvcount, infiniDataType_t datatype,
infiniRedOp_t op, infiniComm_t comm,
void *stream);

#ifdef __cplusplus
}
#endif

#endif // INFINI_CCL_COMM_H_
#endif // INFINI_CCL_COMM_H_
58 changes: 58 additions & 0 deletions src/base/reduce_scatter.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
#ifndef INFINI_CCL_BASE_REDUCE_SCATTER_H_
#define INFINI_CCL_BASE_REDUCE_SCATTER_H_

#include "comm_impl.h"
#include "communicator.h"
#include "logging.h"
#include "operation.h"
#include "return_status_impl.h"

namespace infini::ccl {

// Backend/device-specific implementation hook for `ReduceScatter`. Only
// declared here; a specialization is expected to provide a static `Apply(...)`
// that `ReduceScatter::Execute` forwards its validated arguments to.
template <BackendType backend_type, Device::Type device_type>
struct ReduceScatterImpl;

// Front-end of the reduce-scatter operation: checks the caller-supplied
// arguments and dispatches to the `ReduceScatterImpl` specialization selected
// by the backend and device template parameters.
class ReduceScatter : public Operation<ReduceScatter> {
 public:
  // Validates the arguments and, if they are acceptable, forwards them to
  // `ReduceScatterImpl<backend_type, device_type>::Apply`.
  //
  // Returns `ReturnStatus::kInvalidArgument` when validation fails; otherwise
  // returns whatever the backend implementation returns.
  template <BackendType backend_type, Device::Type device_type,
            typename... Args>
  static ReturnStatus Execute(const void *send_buff, void *recv_buff,
                              size_t recv_count, DataType datatype,
                              ReductionOpType op, void *comm_handle,
                              void *stream) {
    const bool rejected =
        HasInvalidArgs(send_buff, recv_buff, datatype, op, comm_handle);
    if (rejected) {
      return ReturnStatus::kInvalidArgument;
    }

    using Impl = ReduceScatterImpl<backend_type, device_type>;
    return Impl::Apply(send_buff, recv_buff, recv_count, datatype, op,
                       static_cast<Communicator *>(comm_handle), stream);
  }

 private:
  // Returns true (after logging the reason) if any argument is unusable:
  // a null communicator handle, a null buffer, or an out-of-range reduction
  // op or data type. The checks run in that order and stop at the first
  // failure.
  static bool HasInvalidArgs(const void *send_buff, void *recv_buff,
                             DataType datatype, ReductionOpType op,
                             void *comm_handle) {
    if (comm_handle == nullptr) {
      // TODO(lzm): change to use `glog`.
      LOG("Invalid communicator handle for `ReduceScatter`.");
      return true;
    }

    const bool buffer_missing = send_buff == nullptr || recv_buff == nullptr;
    if (buffer_missing) {
      LOG("Invalid buffer pointer for `ReduceScatter`.");
      return true;
    }

    const bool op_in_range =
        ReductionOpType::kSum <= op && op < ReductionOpType::kNumRedOps;
    if (!op_in_range) {
      LOG("Invalid reduction operation for `ReduceScatter`.");
      return true;
    }

    const bool type_in_range =
        DataType::kChar <= datatype && datatype < DataType::kNumTypes;
    if (!type_in_range) {
      LOG("Invalid data type for `ReduceScatter`.");
      return true;
    }

    return false;
  }
};

} // namespace infini::ccl

#endif // INFINI_CCL_BASE_REDUCE_SCATTER_H_
102 changes: 102 additions & 0 deletions src/ompi/impl/reduce_scatter.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
#ifndef INFINI_CCL_OMPI_IMPL_REDUCE_SCATTER_H_
#define INFINI_CCL_OMPI_IMPL_REDUCE_SCATTER_H_

#include <cstdlib>
#include <limits>
#include <memory>
#include <type_traits>

#include "base/reduce_scatter.h"
#include "communicator.h"
#include "dispatcher.h"
#include "logging.h"
#include "ompi/checks.h"
#include "ompi/comm_instance.h"
#include "ompi/type_map.h"

namespace infini::ccl {

// OpenMPI backend implementation of `ReduceScatter`.
//
// Uses host staging: the device send buffer is copied to a temporary host
// buffer, `MPI_Reduce_scatter_block` runs on host memory, and the reduced
// block is copied back to the device receive buffer.
//
// Declared as `struct` to match the class-key of the primary template.
template <Device::Type device_type>
struct ReduceScatterImpl<BackendType::kOmpi, device_type> {
  // Performs the reduce-scatter on `comm`.
  //
  //   send_buff   Device buffer holding `recv_count * comm->size()` elements.
  //   recv_buff   Device buffer receiving `recv_count` elements.
  //   recv_count  Elements received per rank; must fit in an `int` because
  //               MPI counts are `int`.
  //   data_type   Element type of both buffers.
  //   op          Reduction op. For `kAvg` the reduced values are divided by
  //               the world size on the host before the copy back.
  //   comm        Communicator whose `inter_comm()` is an `OmpiInstance`.
  //   stream      Stream that is synchronized after the device-to-host copy.
  //
  // Returns `kSuccess` on success, `kInternalError` for an unusable
  // communicator, `kInvalidArgument` for a count that cannot be represented,
  // and `kSystemError` when host staging memory cannot be allocated. Failures
  // reported by `CHECK_STATUS` / `INFINI_CHECK_MPI` follow those macros.
  static ReturnStatus Apply(const void *send_buff, void *recv_buff,
                            size_t recv_count, DataType data_type,
                            ReductionOpType op, Communicator *comm,
                            void *stream) {
    constexpr Device::Type kDev =
        ListGetBest<DevicePriority>(ActiveDevices<ReduceScatter>{});
    using Rt = Runtime<kDev>;

    auto *inst = static_cast<OmpiInstance *>(comm->inter_comm());

    if (!inst || inst->handle == MPI_COMM_NULL) {
      LOG("Invalid OpenMPI communicator instance for `ReduceScatter`.");
      return ReturnStatus::kInternalError;
    }

    // Reject counts MPI cannot express before doing any allocation or copy.
    if (recv_count > static_cast<size_t>(std::numeric_limits<int>::max())) {
      LOG("recv_count exceeds MPI int range for `ReduceScatter`.");
      return ReturnStatus::kInvalidArgument;
    }
    const int mpi_recv_count = static_cast<int>(recv_count);

    MPI_Datatype mpi_type = DataTypeToOmpiType(data_type);
    MPI_Op mpi_op = RedOpToOmpiOp(op);

    const size_t world_size = static_cast<size_t>(comm->size());
    const size_t type_size = kDataTypeToSize.at(data_type);

    if (world_size == 0 || type_size == 0) {
      LOG("Invalid communicator size or data type size for `ReduceScatter`.");
      return ReturnStatus::kInternalError;
    }
    // Guard `recv_count * world_size * type_size` against `size_t` overflow.
    if (recv_count >
        std::numeric_limits<size_t>::max() / world_size / type_size) {
      LOG("Send buffer size overflows for `ReduceScatter`.");
      return ReturnStatus::kInvalidArgument;
    }

    const size_t send_count = recv_count * world_size;
    const size_t send_bytes = send_count * type_size;
    const size_t recv_bytes = recv_count * type_size;

    // Handle GPU Memory (Staging Pattern)
    // Note: we simply use host-staging for now.
    //
    // The staging buffers are owned by `unique_ptr` so that every exit path,
    // including early returns taken inside the check macros, releases them.
    struct HostFree {
      void operator()(void *ptr) const noexcept { std::free(ptr); }
    };
    using HostBuffer = std::unique_ptr<void, HostFree>;

    // Request at least one byte: `malloc(0)` may legitimately return null,
    // which must not be mistaken for an allocation failure.
    HostBuffer host_sendbuf(std::malloc(send_bytes > 0 ? send_bytes : 1));
    HostBuffer host_recvbuf(std::malloc(recv_bytes > 0 ? recv_bytes : 1));
    if (!host_sendbuf || !host_recvbuf) {
      LOG("Failed to allocate host buffers for `ReduceScatter` staging.");
      return ReturnStatus::kSystemError;
    }

    CHECK_STATUS(Rt, Rt::Memcpy(host_sendbuf.get(), send_buff, send_bytes,
                                Rt::MemcpyDeviceToHost));
    CHECK_STATUS(Rt, Rt::StreamSynchronize(static_cast<Rt::Stream>(stream)));

    INFINI_CHECK_MPI(MPI_Reduce_scatter_block(
        host_sendbuf.get(), host_recvbuf.get(), mpi_recv_count, mpi_type,
        mpi_op, inst->handle));

    if (op == ReductionOpType::kAvg) {
      DispatchFunc<kDev, AllTypes>(data_type, [&](auto dtype) {
        using T = typename decltype(dtype)::type;

        T *typed_buf = static_cast<T *>(host_recvbuf.get());

        // Simply do the averaging on the CPU before the H2D copy.
        if constexpr (std::is_integral_v<T>) {
          // Multiplying by `static_cast<T>(1.0f / world_size)` would multiply
          // by zero for integer types, so divide instead. The division is
          // done in a wide type so the divisor is never truncated.
          using Wide = std::conditional_t<std::is_signed_v<T>, long long,
                                          unsigned long long>;
          const Wide divisor = static_cast<Wide>(world_size);
          for (size_t i = 0; i < recv_count; ++i) {
            typed_buf[i] =
                static_cast<T>(static_cast<Wide>(typed_buf[i]) / divisor);
          }
        } else if constexpr (std::is_floating_point_v<T>) {
          // Divide in the element's own precision rather than scaling by a
          // `float` reciprocal, which loses accuracy for `double`.
          const T divisor = static_cast<T>(world_size);
          for (size_t i = 0; i < recv_count; ++i) {
            typed_buf[i] /= divisor;
          }
        } else {
          // TODO(lzm): should later use the unified `Cast` function instead of
          // static_cast to support CPU custom types.
          const float scale = 1.0f / static_cast<float>(world_size);
          for (size_t i = 0; i < recv_count; ++i) {
            typed_buf[i] *= static_cast<T>(scale);
          }
        }
      });
    }

    CHECK_STATUS(Rt, Rt::Memcpy(recv_buff, host_recvbuf.get(), recv_bytes,
                                Rt::MemcpyHostToDevice));

    return ReturnStatus::kSuccess;
  }
};

// Marks `ReduceScatter` as enabled for the OpenMPI backend by specializing
// `BackendEnabled` to derive from `std::true_type`.
// NOTE(review): presumably consulted by the backend dispatch logic at compile
// time — confirm where `BackendEnabled` is read.
template <>
struct BackendEnabled<ReduceScatter, BackendType::kOmpi> : std::true_type {};

} // namespace infini::ccl

#endif // INFINI_CCL_OMPI_IMPL_REDUCE_SCATTER_H_