Use clean at::cuda::CUDAStream but not wrapped with std::optional (#11)
SigureMo merged 13 commits into PFCCLab:paddle
Conversation
Pull request overview
This PR removes the std::optional wrapper around at::cuda::CUDAStream by constructing comm_stream directly (via an initializer-list lambda) and updates stream usages accordingly.
Changes:
- Replace `std::optional<at::cuda::CUDAStream> comm_stream` with a plain `at::cuda::CUDAStream`.
- Initialize `device_id` and `comm_stream` in `Buffer`'s constructor initializer list.
- Update call sites from `comm_stream.value()` to `comm_stream`.
Reviewed changes
Copilot reviewed 2 out of 2 changed files in this pull request and generated 15 comments.
| File | Description |
|---|---|
| csrc/deep_ep.hpp | Changes Buffer to store comm_stream directly and adjusts accessors. |
| csrc/deep_ep.cpp | Moves stream/context setup into initializer list and updates many stream call sites. |
Comments suppressed due to low confidence (2)
csrc/deep_ep.cpp:1447
- This `internode::cached_notify` call passes `comm_stream` (`at::cuda::CUDAStream`) where the API expects `cudaStream_t` (see csrc/kernels/api.cuh). Please pass `comm_stream.stream()` / `get_comm_stream_raw()` instead.
```cpp
rank,
comm_stream,
config.get_rdma_buffer_size_hint(hidden_int4 * sizeof(int4), num_ranks),
num_nvl_bytes,
false,
low_latency_mode);
```
csrc/deep_ep.cpp:1127
`internode::cached_notify`'s signature takes `cudaStream_t stream` (see csrc/kernels/api.cuh), but this call passes `comm_stream` (`at::cuda::CUDAStream`). This should be `comm_stream.stream()` / `get_comm_stream_raw()`.
```cpp
barrier_signal_ptrs_gpu,
rank,
comm_stream,
config.get_rdma_buffer_size_hint(hidden_int4 * sizeof(int4), num_ranks),
num_nvl_bytes,
true,
low_latency_mode);
```
…ng it to a CUDAStream and back to a phi::Stream. This change is made in the deep_ep.cpp file, specifically in the get_comm_stream function. The modified line now directly retrieves the comm stream as a cudaStream_t and constructs a phi::Stream from it, eliminating unnecessary conversions.
Force-pushed from cee6d1e to fc00d22
…converting it to a CUDAStream and back to a phi::Stream. This change is made in the deep_ep.cpp file, specifically in the get_comm_stream function. The modified line now directly retrieves the comm stream as a cudaStream_t and constructs a phi::Stream from it, eliminating unnecessary conversions." This reverts commit f6f5d0d.
ShigureNyako left a comment:
I'm requesting changes for this round.

Blocking issue:

In csrc/deep_ep.cpp, the assignments to `comm_ctx` / `calc_ctx` were moved into the lambda that initializes `comm_stream`, but in csrc/deep_ep.hpp those two members are still declared after `comm_stream`. Member initialization runs in declaration order, so this writes to them before their lifetimes begin, which is undefined behavior.

The rest of the direction looks reasonable to me: collapsing `std::optional<at::cuda::CUDAStream>` into holding the stream directly, and switching to `at::cuda::getStreamFromExternal` to align with the Paddle compatibility layer, are both sound overall. The repository currently has no CI checks to reference, so this conclusion is based mainly on code review.
```cpp
comm_ctx = reinterpret_cast<paddle::distributed::ProcessGroupNCCL*>(pg)->GetOrCreateCommContext(
    place, phi::distributed::CommType::ALLTOALL);
calc_ctx = reinterpret_cast<phi::GPUContext*>(
    reinterpret_cast<paddle::distributed::ProcessGroupNCCL*>(pg)->GetDeviceContext(place, true));
```
Putting the `comm_ctx` / `calc_ctx` assignments inside the lambda that initializes `comm_stream` runs into member-initialization order.

`Buffer`'s current declaration order still puts `comm_stream` first and `comm_ctx` / `calc_ctx` after it (csrc/deep_ep.hpp:89-92), and C++ initializes members in declaration order, not in the order written in the initializer list. So at the point where `comm_stream` is constructed, the lifetimes of the two later members have not yet begun, and writing to them here is undefined behavior.

I suggest one of the following:

- Reorder the member declarations so that `comm_ctx` / `calc_ctx` are initialized before `comm_stream`;
- Have the lambda construct only the stream, and assign `comm_ctx` / `calc_ctx` in the constructor body instead.
Looks fine to me; these all appear to be changes consistent with the previous logic.

One question, though (not closely related to this PR): since we now maintain our own tls_current_streams in the ATen layer, it may conflict with the stream state the framework itself maintains, for example:

```python
with paddle.device.stream_guard(stream):
    ...  # code in here that calls at::cuda::getCurrentCUDAStream() will not see the Paddle stream change
```

From what I can see, DeepEP currently treats the framework-managed streams as external streams almost everywhere, which avoids maintaining two sets of state. Still, we should double-check for anything missed, otherwise stream-sensitive scenarios such as CUDA Graphs could break.
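The potential desync can be modeled in plain Python with two independent thread-local "current stream" stores (all names here are illustrative stand-ins, not real Paddle or ATen APIs):

```python
import contextlib
import threading

# Stand-ins for the two independent "current stream" stores: the
# framework's (Paddle) and the ATen compatibility layer's
# tls_current_streams.
_paddle_tls = threading.local()
_aten_tls = threading.local()

def paddle_current_stream():
    return getattr(_paddle_tls, "stream", "default")

def aten_get_current_cuda_stream():
    return getattr(_aten_tls, "stream", "default")

@contextlib.contextmanager
def paddle_stream_guard(stream):
    # Only updates the framework-side store; the ATen-side store is
    # unaware of the change -- the desync described above.
    prev = paddle_current_stream()
    _paddle_tls.stream = stream
    try:
        yield
    finally:
        _paddle_tls.stream = prev

with paddle_stream_guard("s1"):
    assert paddle_current_stream() == "s1"
    # The ATen layer still reports its own, stale notion of "current":
    assert aten_get_current_cuda_stream() == "default"
```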
(Someone just pinged me about this, on a weekend no less, and it is exactly this case: PaddlePaddle/FastDeploy#7344)
Separately, beyond DeepEP: since ATen now maintains its own set of streams, problems that used to be caught at compile time may be deferred to runtime. It is probably worth considering whether there is some way to handle this automatically.
Co-authored-by: Nyakku Shigure <sigure.qaq@gmail.com>
ShigureNyako left a comment:
The previous blocking issue has been fixed: `comm_ctx` / `calc_ctx` are now declared before `comm_stream` in csrc/deep_ep.hpp, consistent with the lambda in the constructor, so the member-lifetime ordering problem is gone.

From a code-review standpoint, this is good to go on my end.

One addition: I agree that the ATen tls_current_streams vs. Paddle stream-state synchronization issue @SigureMo raised is worth following up on, but it is a broader question about the design of the current compatibility layer and not a blocker for this PR; this PR continues to treat the framework streams involved here as external streams, which is the right direction. The repository still has no CI checks to reference, so this conclusion is again based mainly on code review.
Previously, because `at::cuda::CUDAStream` has no `operator=` overload, it could not be assigned directly, so a `std::optional` container was used. Instead, we should use the initializer list as before, returning a stream for construction from a lambda expression, thereby eliminating the diff introduced by `std::optional`. Compile errors were also fixed with reference to Paddle.

Prerequisite PRs: PaddlePaddle/Paddle#78576, PaddlePaddle/Paddle#78584