
TorchComms Integration for MSCCL++ #771

Open

michael-beebe wants to merge 5 commits into microsoft:main from michael-beebe:michaelbeebe/torchcomms

Conversation

@michael-beebe

What This PR Does

This PR adds TorchComms support to MSCCL++, allowing PyTorch users to use MSCCL++ collectives through the TorchComms API with a single line:

comm = torchcomms.new_comm("mscclpp", device, name="grad_sync")
comm.all_reduce(tensor, torchcomms.ReduceOp.SUM, False)

This is valuable because it gives PyTorch training frameworks (torchtitan, FSDP2, etc.) a clean way to use MSCCL++ for high-performance collectives without LD_PRELOAD hacks or custom CUDA kernel code. Users can run MSCCL++ for the hot-path collectives (allreduce, allgather) and NCCL for everything else — mixed-backend training with no code changes.

Architecture

Communicator Lifecycle

When a user calls torchcomms.new_comm("mscclpp", device), TorchComms dlopen's our _comms_mscclpp.*.so module and calls init(), which:

  1. Bootstrap — discovers rank/world_size from torchrun environment, exchanges a UniqueId through c10d::Store (rank 0 generates, others read), creates the MSCCL++ Communicator with a TcpBootstrap
  2. Scratch buffer — allocates 128MB via GpuBuffer (cuMemMap) for native algorithms that need intermediate storage
  3. Executor — creates the DSL plan executor (used by DSL algorithms, ignored by native ones)
  4. Algorithm collection — calls AlgorithmCollectionBuilder::buildDefaultAlgorithms() which registers 12 native algorithms + 2 DSL plans, then wires up the topology-aware algorithm selector
  5. Event pool — pre-allocates a pool of 256 reusable CUDA events for async work tracking
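The bootstrap step (1) follows the usual store-based rendezvous pattern: rank 0 generates the UniqueId and publishes it through the store, and every other rank blocks until the key appears. A minimal pure-Python sketch of that pattern (a threaded dict stands in for c10d::Store; the key name and id bytes are illustrative, not the backend's actual values):

```python
import threading

class ToyStore:
    """Stand-in for c10d::Store: set() publishes, get() blocks until the key exists."""
    def __init__(self):
        self._data, self._cv = {}, threading.Condition()

    def set(self, key, value):
        with self._cv:
            self._data[key] = value
            self._cv.notify_all()

    def get(self, key):
        with self._cv:
            self._cv.wait_for(lambda: key in self._data)
            return self._data[key]

def bootstrap(rank, store, make_unique_id):
    # Rank 0 generates the UniqueId and writes it; everyone else reads it.
    key = "mscclpp_unique_id_0"  # illustrative key name
    if rank == 0:
        uid = make_unique_id()
        store.set(key, uid)
        return uid
    return store.get(key)

store, ids = ToyStore(), {}
threads = [threading.Thread(
    target=lambda r: ids.__setitem__(r, bootstrap(r, store, lambda: b"uid-bytes")),
    args=(r,)) for r in range(4)]
for t in threads: t.start()
for t in threads: t.join()
assert all(v == b"uid-bytes" for v in ids.values())  # every rank agrees on the id
```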

What Happens When You Call a Collective

comm.all_reduce(tensor, torchcomms.ReduceOp.SUM, False)
    │
    ▼
TorchCommMSCCLPP::all_reduce()
    │  validates reduce op (SUM/MIN only)
    │  ensures tensor is contiguous
    │
    ▼
TorchCommMSCCLPP::executeCollective("allreduce", sendbuf, recvbuf, size, dtype, ...)
    │
    │  1. Builds a CollectiveRequest with world_size, nRanksPerNode,
    │     rank, buffer pointers, message size, stream, dtype
    │
    │  2. Calls algorithmCollection_.selectAlgorithm(request)
    │     → selector considers message size, NVLS support, compute
    │       capability, symmetric memory, CUDA graph capture mode
    │     → returns the best algorithm (e.g., nvls_warp_pipeline for 4MB)
    │
    │  3. Creates TorchWorkMSCCLPP handle, records start GPU event
    │
    │  4. Calls algo->execute(comm, input, output, size, dtype, op, stream, executor)
    │     → native algorithms launch a CUDA kernel directly
    │     → DSL algorithms use the executor to interpret a JSON plan
    │
    │  5. Records end GPU event, returns the work handle
    │
    ▼
TorchWorkMSCCLPP (returned to caller)
    │  wait() → cudaStreamWaitEvent on caller's stream (GPU-side, no CPU block)
    │  checkStatus() → polls GPU events for completion/timeout

Component Diagram

torchcomms.new_comm("mscclpp", device)
    │
    ▼
TorchCommMSCCLPPPy.cpp          ← pybind11 module + dynamic loader interface
    │
    ▼
TorchCommMSCCLPP.cpp/hpp        ← backend class (init, finalize, collective dispatch)
    │
    ├── TorchCommMSCCLPPBootstrap.cpp/hpp  ← rank discovery via c10d::Store
    ├── TorchWorkMSCCLPP.cpp/hpp           ← GPU event pool + async work tracking
    │
    ▼
AlgorithmCollection::selectAlgorithm()   ← MSCCL++ native algorithm selection
    │
    ▼
Algorithm::execute()                      ← GPU kernel launch (native or DSL)

The backend is a thin adapter. It does not implement any collective algorithms — it delegates entirely to MSCCL++'s AlgorithmCollection, which selects the optimal native algorithm based on message size, topology, NVLS support, and compute capability.

Files to Review

Core Backend (focus here)

| File | Lines | What It Does |
| --- | --- | --- |
| TorchCommMSCCLPP.hpp | ~180 | Backend class declaration. All public methods, private members with comments. |
| TorchCommMSCCLPP.cpp | ~510 | The main implementation. init() bootstraps and builds the AlgorithmCollection. executeCollective() is the central dispatch: builds a CollectiveRequest, calls selectAlgorithm(), executes. Unsupported ops throw with NCCL/RCCL guidance. |
| TorchCommMSCCLPPBootstrap.hpp/cpp | ~130 | Rank discovery. Rank 0 generates a UniqueId, writes to c10d::Store, other ranks read it. Same pattern as TorchComms' NCCL backend. |
| TorchWorkMSCCLPP.hpp/cpp | ~230 | GPU event pool (amortizes cudaEventCreate cost) + async work handle. wait() uses cudaStreamWaitEvent for GPU-side sync, no CPU blocking. |
| TorchCommMSCCLPPPy.cpp | ~60 | Minimal pybind11 module. Exposes the class + DynamicLoaderInterface for TorchComms' dlopen discovery. |
| CMakeLists.txt (torchcomm) | ~115 | FetchContent for torchcomms headers, links mscclpp + PyTorch + pybind11. |

Build System

| File | Change |
| --- | --- |
| CMakeLists.txt (root) | 2 lines: adds MSCCLPP_BUILD_EXT_TORCHCOMMS option (OFF by default) and add_subdirectory() |

Tests, Benchmarks, Docs

Tests (6 files, ~960 lines), benchmarks (3 files, ~500 lines), and docs (quickstart.md) are straightforward and lower review priority.

Supported Collectives

| Collective | Native Algorithms | Notes |
| --- | --- | --- |
| AllReduce | allpair_packet, nvls_packet, packet, nvls, nvls_warp_pipeline, nvls_block_pipeline, fullmesh, rsag, rsag_pipeline, rsag_zero_copy | Auto-selected by message size + topology. SUM and MIN reduction ops. |
| AllGather | fullmesh, fullmesh2 | Auto-selected by message size. |
| ReduceScatter | (none natively) | Dispatched if a DSL algorithm is registered. |
| AllToAll | (none natively) | Dispatched if a DSL algorithm is registered. |
| Broadcast, Reduce, Send/Recv, Barrier, Scatter, Gather | Not supported | Each throws with an explicit message suggesting NCCL/RCCL. |
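The "throws with explicit message" rows are what make the mixed-backend pattern work: a caller can keep a second communicator and route unsupported operations to it. A hedged sketch with toy communicator objects (the real TorchComms objects and method names differ):

```python
class ToyMscclppComm:
    """Stand-in for the MSCCL++ backend: supports only the hot-path collectives."""
    SUPPORTED = {"allreduce", "allgather", "reducescatter", "alltoall"}

    def run(self, op):
        if op not in self.SUPPORTED:
            raise RuntimeError(f"{op} is not supported by the MSCCL++ backend; "
                               "use an NCCL/RCCL communicator for this operation")
        return f"mscclpp:{op}"

class ToyNcclComm:
    """Stand-in for a fallback NCCL/RCCL communicator: supports everything."""
    def run(self, op):
        return f"nccl:{op}"

def dispatch(op, fast_comm, fallback_comm):
    # Hot-path collectives go to MSCCL++; everything else falls back.
    try:
        return fast_comm.run(op)
    except RuntimeError:
        return fallback_comm.run(op)

fast, fallback = ToyMscclppComm(), ToyNcclComm()
assert dispatch("allreduce", fast, fallback) == "mscclpp:allreduce"
assert dispatch("broadcast", fast, fallback) == "nccl:broadcast"
```

In a real training setup the routing would typically be decided up front per operation rather than via exception handling, but the error messages make the misrouted case fail loudly with actionable guidance.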

Key Design Decisions

1. Thin adapter, not a reimplementation.
The backend calls AlgorithmCollection::selectAlgorithm() and Algorithm::execute(). It does not contain any collective kernel code. Algorithm registration, selection logic, and kernel implementations all live in MSCCL++ core.

2. Same algorithm selector as the NCCL extension.
We reuse algorithm_selector.hpp from src/ext/nccl/ so the TorchComms path selects the same algorithms as the LD_PRELOAD NCCL shim. This avoids divergence and ensures consistent behavior.

3. Shared library linking (not static).
The module links against libmscclpp.so (not mscclpp_static.a) to avoid dual-singleton crashes. mscclpp_collectives.so links against the shared lib, so if we statically linked, there would be two copies of singletons like UnixSocketServer::instance().

4. GpuBuffer for scratch allocation.
Scratch memory is allocated via mscclpp::GpuBuffer (cuMemMap) instead of plain cudaMalloc. This registers POSIX file descriptors in the unix socket server, which is required for cross-rank IPC sharing. Plain cudaMalloc causes "Requested fd not found" crashes.

5. Build-gated behind MSCCLPP_BUILD_EXT_TORCHCOMMS=OFF.
No impact on existing builds. TorchComms headers are fetched on-demand via CMake FetchContent only when the option is enabled.

6. GPU event pooling.
Every collective call needs 2 CUDA events (start + end) for async tracking. Creating/destroying events costs ~5-10μs each. The pool amortizes this across thousands of collective calls in a training loop.
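Design decision 6 can be illustrated with a minimal pool sketch (pure Python; an "expensive" constructor mocks cudaEventCreate, and the pool size of 256 matches the description above):

```python
from collections import deque

class ToyEvent:
    created = 0  # counts "expensive" constructions (stand-in for cudaEventCreate)

    def __init__(self):
        ToyEvent.created += 1

class EventPool:
    """Pre-allocates events once; acquire/release reuse them across collective calls."""
    def __init__(self, size=256):
        self._free = deque(ToyEvent() for _ in range(size))

    def acquire(self):
        # Fall back to fresh creation only if the pool is exhausted.
        return self._free.popleft() if self._free else ToyEvent()

    def release(self, ev):
        self._free.append(ev)

pool = EventPool(size=256)
for _ in range(10_000):            # thousands of collective calls in a training loop
    start, end = pool.acquire(), pool.acquire()
    # ... record start event, launch kernel, record end event ...
    pool.release(start); pool.release(end)

assert ToyEvent.created == 256     # every call reused the pre-allocated events
```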

7. User-defined algorithms via AlgorithmCollectionBuilder singleton.
Custom algorithms (DSL or native) are configured on the builder before creating the TorchComms communicator. The backend picks them up during init(). No algorithm registration API lives on the backend itself.

Limitations

  • Only single-tensor collective variants are implemented. MSCCL++'s Algorithm::execute() operates on contiguous buffers (one input pointer, one output pointer), so the backend implements all_gather_single and reduce_scatter_single but not the tensor-list variants. The tensor-list variants throw with guidance to use the single-tensor variant instead.
  • Unsupported collectives throw at runtime. Broadcast, reduce, send/recv, barrier, scatter, and gather throw a RuntimeError with an explicit message naming the operation and suggesting the caller use a separate NCCL/RCCL communicator. This is the expected pattern for mixed-backend training.
  • Algorithm selector is duplicated from the NCCL extension. The backend includes algorithm_selector.hpp from src/ext/nccl/ rather than sharing it through a common path. A TODO in the code notes this should be consolidated into AlgorithmCollectionBuilder so all consumers get a default selector automatically.

How to Build and Test

# Build
mkdir -p build && cd build
cmake -DCMAKE_BUILD_TYPE=Release -DMSCCLPP_BUILD_EXT_TORCHCOMMS=ON ..
make -j$(nproc)

# Set env var
export TORCHCOMMS_BACKEND_LIB_PATH_MSCCLPP=$PWD/lib/_comms_mscclpp.cpython-*.so

# Run correctness tests
torchrun --nproc_per_node=8 test/torchcomms/test_correctness.py --all

# Run benchmarks
torchrun --nproc_per_node=8 test/torchcomms/bench_torchcomms.py --collective allreduce --warmup 100 --iters 200

@michael-beebe (Author)

@microsoft-github-policy-service agree company="Microsoft"

Copilot AI (Contributor) left a comment

Pull request overview

Adds a TorchComms backend module (_comms_mscclpp*.so) that adapts TorchComms collective calls onto MSCCL++’s AlgorithmCollection (native + DSL), enabling PyTorch users to select MSCCL++ as a TorchComms backend at runtime.

Changes:

  • Introduces a new C++ TorchComms backend implementation (bootstrap via c10d::Store, algorithm selection, executor + scratch, CUDA-event-based work tracking).
  • Adds TorchComms-focused tests/benchmarks and a quickstart section documenting build/run steps.
  • Adds a build option MSCCLPP_BUILD_EXT_TORCHCOMMS and a dedicated CMake target for building the backend module.

Reviewed changes

Copilot reviewed 21 out of 21 changed files in this pull request and generated 9 comments.

| File | Description |
| --- | --- |
| CMakeLists.txt | Adds MSCCLPP_BUILD_EXT_TORCHCOMMS option and conditionally builds the TorchComms backend. |
| docs/quickstart.md | Documents how to build, use, test, and benchmark TorchComms support. |
| python/mscclpp_torchcomm/CMakeLists.txt | Fetches TorchComms sources/headers, builds _comms_mscclpp pybind module, links MSCCL++ + Torch. |
| python/mscclpp_torchcomm/__init__.py | Package stub for TorchComms backend directory. |
| python/mscclpp_torchcomm/requirements_cuda12.txt | Optional pip requirements for TorchComms backend environment. |
| python/mscclpp_torchcomm/csrc/TorchCommMSCCLPP.hpp | Declares the TorchComms backend class and supported/unsupported collectives. |
| python/mscclpp_torchcomm/csrc/TorchCommMSCCLPP.cpp | Implements init/finalize, algorithm selection wiring, and collective dispatch. |
| python/mscclpp_torchcomm/csrc/TorchCommMSCCLPPBootstrap.hpp | Declares bootstrap helper for rank/size + UniqueId exchange via store. |
| python/mscclpp_torchcomm/csrc/TorchCommMSCCLPPBootstrap.cpp | Implements UniqueId exchange and MSCCL++ communicator creation via TcpBootstrap. |
| python/mscclpp_torchcomm/csrc/TorchCommMSCCLPPPy.cpp | Exposes minimal pybind module + TorchComms dynamic loader entrypoint. |
| python/mscclpp_torchcomm/csrc/TorchWorkMSCCLPP.hpp | Declares CUDA event pool + TorchWork implementation. |
| python/mscclpp_torchcomm/csrc/TorchWorkMSCCLPP.cpp | Implements pooled CUDA events and async work completion tracking. |
| test/torchcomms/test_correctness.py | Correctness coverage for allreduce/allgather/reducescatter via TorchComms. |
| test/torchcomms/test_error_handling.py | Verifies clear runtime errors for unsupported ops and invalid usage patterns. |
| test/torchcomms/test_sizes.py | Sweeps message sizes to exercise selection boundaries and correctness. |
| test/torchcomms/test_training_loop.py | Simulates multi-iteration training-loop allreduce pattern. |
| test/torchcomms/test_user_algorithms.py | Validates user algorithm/selector registration via AlgorithmCollectionBuilder. |
| test/torchcomms/test_multicomm.py | Documents current multi-communicator limitation (expected skip). |
| test/torchcomms/bench_torchcomms.py | TorchComms benchmark driver for allreduce/allgather. |
| test/torchcomms/bench_report.py | Generates a report/figures from benchmark JSON output. |
| test/torchcomms/run_benchmarks.sh | Convenience runner to produce benchmark JSON + report + plots. |

michael-beebe added a commit to michael-beebe/mscclpp that referenced this pull request Apr 28, 2026
Review feedback (chhwang):
- TorchCommMSCCLPP::init(): replace raw cudaSetDevice with RAII
  CudaDeviceGuard to restore previous device on return/exception
- TorchCommMSCCLPP::init(): remove redundant cudaGetDevice call, use
  device_.index() directly for compute capability queries
- Add pip install support via separate mscclpp-torchcomms package with
  pyproject.toml, scikit-build-core, and auto-discovery of backend .so
- docs/quickstart.md: add tested version table

Review feedback (Copilot bot):
- TorchCommMSCCLPPBootstrap: add "_" delimiter between name and counter
  in store key to prevent collisions, make counter_ std::atomic<int>
- TorchCommMSCCLPP::finalize(): wrap cudaStreamSynchronize and
  cudaStreamDestroy with MSCCLPP_CUDATHROW for error surfacing
- All 4 supported collectives: replace tensor.contiguous() with
  TORCH_CHECK(tensor.is_contiguous()) to prevent silently dropping
  results for non-contiguous tensors
- CMakeLists.txt: replace manual glog search with find_package(glog
  REQUIRED) for consistency with codebase conventions

Rename and documentation:
- Rename python/mscclpp_torchcomm to python/mscclpp_torchcomms for
  consistency with the torchcomms library naming
- Add docs/torchcomms.md: standalone doc covering architecture,
  algorithm selection, user-defined algorithms, testing, benchmarks,
  limitations, and troubleshooting
- Slim down quickstart.md TorchComms section to brief snippet + link
- Add torchcomms entry to docs/index.rst
- Add import mscclpp_torchcomms to all test/benchmark files for
  automatic backend .so discovery (no env var needed)
- python/mscclpp_torchcomm/: TorchComms integration for MSCCL++
  - CMakeLists.txt: FetchContent torchcomms, links mscclpp + PyTorch
  - TorchCommMSCCLPP: backend class with init/finalize lifecycle,
    algorithm selection via AlgorithmCollection, GPU event-based
    async work tracking
  - TorchCommMSCCLPPBootstrap: rank discovery via c10d::Store
  - TorchWorkMSCCLPP: GPU event pool + async completion handles
  - TorchCommMSCCLPPPy: pybind11 module + dynamic loader interface
- CMakeLists.txt: add MSCCLPP_BUILD_EXT_TORCHCOMMS option (OFF default)
- Supported: allreduce (10 native algorithms), allgather (2 algorithms)
- Uses same algorithm selector as NCCL extension
- Links mscclpp shared lib (not static) to avoid dual-singleton crashes
- test_correctness.py: allreduce/allgather with --sweep mode for
  multi-size/dtype coverage, in-place and repeated variants
- test_sizes.py: message size sweep from 1 element to 32MB
- test_error_handling.py: unsupported ops, invalid reduce ops, metadata
- test_training_loop.py: simulated multi-iteration training loop
- test_multicomm.py: multiple communicators (known limitation)
- test_user_algorithms.py: DSL algorithm registration via builder
- bench_torchcomms.py: allreduce/allgather benchmark with CUDA event
  timing, curated sizes per native algorithm, JSON output
- bench_report.py: generates report + latency/bandwidth figures with
  algorithm region annotations
- run_benchmarks.sh: orchestrator script
- docs/quickstart.md: build instructions, usage example, supported
  collectives table, environment variables, test/benchmark commands
- Consistent with existing doc style (dollar prompts, MSCCLPP_BUILD var)
michael-beebe force-pushed the michaelbeebe/torchcomms branch from e90ebe6 to a1f4b97 on April 28, 2026 at 18:54
michael-beebe added a commit to michael-beebe/mscclpp that referenced this pull request Apr 28, 2026
michael-beebe force-pushed the michaelbeebe/torchcomms branch from a1f4b97 to 78af9ac on April 28, 2026 at 19:07