TorchComms Integration for MSCCL++ #771
Open
michael-beebe wants to merge 5 commits into microsoft:main from
Conversation
Author
@microsoft-github-policy-service agree company="Microsoft"
Contributor
Pull request overview
Adds a TorchComms backend module (_comms_mscclpp*.so) that adapts TorchComms collective calls onto MSCCL++’s AlgorithmCollection (native + DSL), enabling PyTorch users to select MSCCL++ as a TorchComms backend at runtime.
Changes:
- Introduces a new C++ TorchComms backend implementation (bootstrap via c10d::Store, algorithm selection, executor + scratch, CUDA-event-based work tracking).
- Adds TorchComms-focused tests/benchmarks and a quickstart section documenting build/run steps.
- Adds a build option MSCCLPP_BUILD_EXT_TORCHCOMMS and a dedicated CMake target for building the backend module.
Reviewed changes
Copilot reviewed 21 out of 21 changed files in this pull request and generated 9 comments.
| File | Description |
|---|---|
| CMakeLists.txt | Adds MSCCLPP_BUILD_EXT_TORCHCOMMS option and conditionally builds the TorchComms backend. |
| docs/quickstart.md | Documents how to build, use, test, and benchmark TorchComms support. |
| python/mscclpp_torchcomm/CMakeLists.txt | Fetches TorchComms sources/headers, builds the _comms_mscclpp pybind module, links MSCCL++ + Torch. |
| python/mscclpp_torchcomm/__init__.py | Package stub for the TorchComms backend directory. |
| python/mscclpp_torchcomm/requirements_cuda12.txt | Optional pip requirements for the TorchComms backend environment. |
| python/mscclpp_torchcomm/csrc/TorchCommMSCCLPP.hpp | Declares the TorchComms backend class and supported/unsupported collectives. |
| python/mscclpp_torchcomm/csrc/TorchCommMSCCLPP.cpp | Implements init/finalize, algorithm selection wiring, and collective dispatch. |
| python/mscclpp_torchcomm/csrc/TorchCommMSCCLPPBootstrap.hpp | Declares the bootstrap helper for rank/size + UniqueId exchange via the store. |
| python/mscclpp_torchcomm/csrc/TorchCommMSCCLPPBootstrap.cpp | Implements UniqueId exchange and MSCCL++ communicator creation via TcpBootstrap. |
| python/mscclpp_torchcomm/csrc/TorchCommMSCCLPPPy.cpp | Exposes the minimal pybind module + TorchComms dynamic loader entrypoint. |
| python/mscclpp_torchcomm/csrc/TorchWorkMSCCLPP.hpp | Declares the CUDA event pool + TorchWork implementation. |
| python/mscclpp_torchcomm/csrc/TorchWorkMSCCLPP.cpp | Implements pooled CUDA events and async work completion tracking. |
| test/torchcomms/test_correctness.py | Correctness coverage for allreduce/allgather/reducescatter via TorchComms. |
| test/torchcomms/test_error_handling.py | Verifies clear runtime errors for unsupported ops and invalid usage patterns. |
| test/torchcomms/test_sizes.py | Sweeps message sizes to exercise selection boundaries and correctness. |
| test/torchcomms/test_training_loop.py | Simulates a multi-iteration training-loop allreduce pattern. |
| test/torchcomms/test_user_algorithms.py | Validates user algorithm/selector registration via AlgorithmCollectionBuilder. |
| test/torchcomms/test_multicomm.py | Documents the current multi-communicator limitation (expected skip). |
| test/torchcomms/bench_torchcomms.py | TorchComms benchmark driver for allreduce/allgather. |
| test/torchcomms/bench_report.py | Generates a report/figures from benchmark JSON output. |
| test/torchcomms/run_benchmarks.sh | Convenience runner to produce benchmark JSON + report + plots. |
chhwang requested changes on Apr 13, 2026
michael-beebe added a commit to michael-beebe/mscclpp that referenced this pull request on Apr 28, 2026
Review feedback (chhwang):
- TorchCommMSCCLPP::init(): replace raw cudaSetDevice with RAII CudaDeviceGuard to restore the previous device on return/exception
- TorchCommMSCCLPP::init(): remove the redundant cudaGetDevice call; use device_.index() directly for compute capability queries
- Add pip install support via a separate mscclpp-torchcomms package with pyproject.toml, scikit-build-core, and auto-discovery of the backend .so
- docs/quickstart.md: add a tested-version table

Review feedback (Copilot bot):
- TorchCommMSCCLPPBootstrap: add a "_" delimiter between name and counter in the store key to prevent collisions; make counter_ std::atomic<int>
- TorchCommMSCCLPP::finalize(): wrap cudaStreamSynchronize and cudaStreamDestroy with MSCCLPP_CUDATHROW for error surfacing
- All 4 supported collectives: replace tensor.contiguous() with TORCH_CHECK(tensor.is_contiguous()) to prevent silently dropping results for non-contiguous tensors
- CMakeLists.txt: replace the manual glog search with find_package(glog REQUIRED) for consistency with codebase conventions

Rename and documentation:
- Rename python/mscclpp_torchcomm to python/mscclpp_torchcomms for consistency with the torchcomms library naming
- Add docs/torchcomms.md: a standalone doc covering architecture, algorithm selection, user-defined algorithms, testing, benchmarks, limitations, and troubleshooting
- Slim down the quickstart.md TorchComms section to a brief snippet + link
- Add a torchcomms entry to docs/index.rst
- Add "import mscclpp_torchcomms" to all test/benchmark files for automatic backend .so discovery (no env var needed)
- python/mscclpp_torchcomm/: TorchComms integration for MSCCL++
- CMakeLists.txt: FetchContent torchcomms, links mscclpp + PyTorch
- TorchCommMSCCLPP: backend class with init/finalize lifecycle,
algorithm selection via AlgorithmCollection, GPU event-based
async work tracking
- TorchCommMSCCLPPBootstrap: rank discovery via c10d::Store
- TorchWorkMSCCLPP: GPU event pool + async completion handles
- TorchCommMSCCLPPPy: pybind11 module + dynamic loader interface
- CMakeLists.txt: add MSCCLPP_BUILD_EXT_TORCHCOMMS option (OFF default)
- Supported: allreduce (10 native algorithms), allgather (2 algorithms)
- Uses same algorithm selector as NCCL extension
- Links mscclpp shared lib (not static) to avoid dual-singleton crashes
- test_correctness.py: allreduce/allgather with --sweep mode for multi-size/dtype coverage, in-place and repeated variants
- test_sizes.py: message size sweep from 1 element to 32MB
- test_error_handling.py: unsupported ops, invalid reduce ops, metadata
- test_training_loop.py: simulated multi-iteration training loop
- test_multicomm.py: multiple communicators (known limitation)
- test_user_algorithms.py: DSL algorithm registration via builder
- bench_torchcomms.py: allreduce/allgather benchmark with CUDA event timing, curated sizes per native algorithm, JSON output
- bench_report.py: generates a report + latency/bandwidth figures with algorithm region annotations
- run_benchmarks.sh: orchestrator script
- docs/quickstart.md: build instructions, usage example, supported collectives table, environment variables, test/benchmark commands
- Consistent with existing doc style (dollar prompts, MSCCLPP_BUILD var)
Force-pushed e90ebe6 to a1f4b97
Force-pushed a1f4b97 to 78af9ac
What This PR Does
This PR adds TorchComms support to MSCCL++, allowing PyTorch users to use MSCCL++ collectives through the TorchComms API with a single line: torchcomms.new_comm("mscclpp", device).
This is valuable because it gives PyTorch training frameworks (torchtitan, FSDP2, etc.) a clean way to use MSCCL++ for high-performance collectives without LD_PRELOAD hacks or custom CUDA kernel code. Users can run MSCCL++ for the hot-path collectives (allreduce, allgather) and NCCL for everything else — mixed-backend training with no code changes.
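As an illustration of the mixed-backend pattern described above, here is a minimal sketch. Only `torchcomms.new_comm("mscclpp", device)` comes from this PR; the allreduce/wait call shapes and the tensor setup are assumptions, and the snippet requires a CUDA device with the backend module built and importable.

```python
# Sketch only: the allreduce()/wait() call shapes are assumptions based on the
# TorchComms work-handle pattern this PR describes, not a verified API.
import torch
import torchcomms

device = torch.device("cuda", 0)
mscclpp_comm = torchcomms.new_comm("mscclpp", device)  # hot-path collectives via MSCCL++
nccl_comm = torchcomms.new_comm("nccl", device)        # everything else via NCCL

x = torch.ones(1 << 20, device=device)
work = mscclpp_comm.allreduce(x)  # hypothetical call shape
work.wait()
```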
Architecture
Communicator Lifecycle
When a user calls torchcomms.new_comm("mscclpp", device), TorchComms dlopen's our _comms_mscclpp.*.so module and calls init(), which:
- exchanges the UniqueId through c10d::Store (rank 0 generates, others read) and creates the MSCCL++ Communicator with a TcpBootstrap
- allocates scratch via GpuBuffer (cuMemMap) for native algorithms that need intermediate storage
- calls AlgorithmCollectionBuilder::buildDefaultAlgorithms(), which registers 12 native algorithms + 2 DSL plans, then wires up the topology-aware algorithm selector

What Happens When You Call a Collective
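The call path this PR describes (executeCollective() builds a CollectiveRequest, calls selectAlgorithm(), then executes) can be illustrated with a small pure-Python analogue. The class and function names mirror the PR's description of the C++ backend, but the size threshold and algorithm names below are made up for illustration:

```python
# Toy analogue of the backend's central dispatch. The 1 MiB threshold and the
# algorithm names are invented; the real selector in MSCCL++ also considers
# topology, NVLS support, and compute capability.
from dataclasses import dataclass


@dataclass
class CollectiveRequest:
    op: str          # e.g. "allreduce"
    nbytes: int      # message size in bytes
    world_size: int


def select_algorithm(req: CollectiveRequest) -> str:
    # Unsupported ops throw with guidance, matching the PR's behavior.
    if req.op != "allreduce":
        raise RuntimeError(
            f"{req.op} is not supported by the MSCCL++ backend; "
            "use a separate NCCL/RCCL communicator for this op")
    # Toy size-based choice standing in for the topology-aware selector.
    return "allreduce_small_latency" if req.nbytes <= 1 << 20 else "allreduce_large_bw"


def execute_collective(op: str, nbytes: int, world_size: int = 8) -> str:
    req = CollectiveRequest(op, nbytes, world_size)
    algo = select_algorithm(req)
    # A real backend would now run Algorithm::execute() on the GPU stream.
    return algo
```

Calling `execute_collective("allreduce", 4096)` picks the small-message algorithm, while an unsupported op such as `"broadcast"` raises a RuntimeError naming the operation.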
Component Diagram
The backend is a thin adapter. It does not implement any collective algorithms; it delegates entirely to MSCCL++'s AlgorithmCollection, which selects the optimal native algorithm based on message size, topology, NVLS support, and compute capability.

Files to Review
Core Backend (focus here)
- TorchCommMSCCLPP.hpp / TorchCommMSCCLPP.cpp: init() bootstraps and builds the AlgorithmCollection. executeCollective() is the central dispatch: it builds a CollectiveRequest, calls selectAlgorithm(), and executes. Unsupported ops throw with NCCL/RCCL guidance.
- TorchCommMSCCLPPBootstrap.hpp/cpp: rank 0 generates a UniqueId and writes it to c10d::Store; other ranks read it. Same pattern as TorchComms' NCCL backend.
- TorchWorkMSCCLPP.hpp/cpp: wait() uses cudaStreamWaitEvent for GPU-side sync, with no CPU blocking.
- TorchCommMSCCLPPPy.cpp: implements DynamicLoaderInterface for TorchComms' dlopen discovery.
- CMakeLists.txt (torchcomm): fetches TorchComms sources/headers and builds the pybind module.

Build System
- CMakeLists.txt (root): adds the MSCCLPP_BUILD_EXT_TORCHCOMMS option (OFF by default) and an add_subdirectory() call.

Tests, Benchmarks, Docs
Tests (6 files, ~960 lines), benchmarks (3 files, ~500 lines), and docs (quickstart.md) are straightforward and lower review priority.
Supported Collectives
Key Design Decisions
1. Thin adapter, not a reimplementation.
The backend calls AlgorithmCollection::selectAlgorithm() and Algorithm::execute(). It does not contain any collective kernel code. Algorithm registration, selection logic, and kernel implementations all live in MSCCL++ core.

2. Same algorithm selector as the NCCL extension.
We reuse algorithm_selector.hpp from src/ext/nccl/, so the TorchComms path selects the same algorithms as the LD_PRELOAD NCCL shim. This avoids divergence and ensures consistent behavior.

3. Shared library linking (not static).
The module links against libmscclpp.so (not mscclpp_static.a) to avoid dual-singleton crashes. mscclpp_collectives.so links against the shared lib, so if we statically linked, there would be two copies of singletons like UnixSocketServer::instance().

4. GpuBuffer for scratch allocation.
Scratch memory is allocated via mscclpp::GpuBuffer (cuMemMap) instead of plain cudaMalloc. This registers POSIX file descriptors in the unix socket server, which is required for cross-rank IPC sharing. Plain cudaMalloc causes "Requested fd not found" crashes.

5. Build-gated behind
MSCCLPP_BUILD_EXT_TORCHCOMMS (OFF by default). No impact on existing builds. TorchComms headers are fetched on-demand via CMake FetchContent only when the option is enabled.
6. GPU event pooling.
Every collective call needs 2 CUDA events (start + end) for async tracking. Creating/destroying events costs ~5-10μs each. The pool amortizes this across thousands of collective calls in a training loop.
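The amortization argument can be seen with a toy pool. Event here is a plain Python stand-in for a CUDA event (the real pool would wrap cudaEventCreate/cudaEventDestroy); creation_count simply counts real constructions:

```python
# Toy event pool: acquire() reuses a previously created event when one is
# free, so thousands of collective calls cost only two real constructions.
class Event:
    creation_count = 0  # tracks how many events were actually constructed

    def __init__(self):
        Event.creation_count += 1


class EventPool:
    def __init__(self):
        self._free = []

    def acquire(self):
        # Reuse a released event if available, otherwise construct one.
        return self._free.pop() if self._free else Event()

    def release(self, event):
        self._free.append(event)


pool = EventPool()
# Each "collective" needs a start and an end event.
for _ in range(1000):
    start, end = pool.acquire(), pool.acquire()
    # ... record events around the collective, let the work complete ...
    pool.release(start)
    pool.release(end)

# Only two events were ever constructed; the other 1998 acquires were reuses.
assert Event.creation_count == 2
```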
7. User-defined algorithms via AlgorithmCollectionBuilder singleton.
Custom algorithms (DSL or native) are configured on the builder before creating the TorchComms communicator. The backend picks them up during init(). No algorithm registration API lives on the backend itself.

Limitations
- Algorithm::execute() operates on contiguous buffers (one input pointer, one output pointer), so the backend implements all_gather_single and reduce_scatter_single but not the tensor-list variants. The tensor-list variants throw with guidance to use the single-tensor variant instead.
- Unsupported operations throw a RuntimeError with an explicit message naming the operation and suggesting the caller use a separate NCCL/RCCL communicator. This is the expected pattern for mixed-backend training.
- The backend copies algorithm_selector.hpp from src/ext/nccl/ rather than sharing it through a common path. A TODO in the code notes this should be consolidated into AlgorithmCollectionBuilder so all consumers get a default selector automatically.

How to Build and Test
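The build is gated behind a CMake option, so a build might look like the following sketch. Only the MSCCLPP_BUILD_EXT_TORCHCOMMS flag and the test file path come from this PR; the build directory layout and the torchrun launcher are assumptions.

```shell
# Hedged sketch; adjust paths and launcher to your environment.
cmake -S . -B build -DMSCCLPP_BUILD_EXT_TORCHCOMMS=ON
cmake --build build

# Multi-rank test launch (torchrun usage is an assumption, not from this PR).
torchrun --nproc_per_node=2 test/torchcomms/test_correctness.py
```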