feat: Support ZeRO-2 based on DistributedOptimizer#110
Conversation
4a2527f to
a82ad42
Compare
a82ad42 to
e5b4492
Compare
d5ffecd to
e5b4492
Compare
0d7053c to
457f3a7
Compare
457f3a7 to
e9610d9
Compare
| StartGradSync(); | ||
| } | ||
|
|
||
| if (params_with_grad_.empty()) { |
There was a problem hiding this comment.
这里没有强 wait ReduceScatter 完成?optimizer 可能读到还没同步完成的 grad_shard_buffer
| } | ||
| std::weak_ptr<ParamAndGradBucketGroup> weak_group = it->second; | ||
| param->SetGradAccumulateBypass( | ||
| [weak_group, param](const std::shared_ptr<Tensor> &grad_output, bool overwrite, float learning_rate) { |
There was a problem hiding this comment.
这里param是一个shared_ptr,param Tensor 持有 grad_accumulate_bypass,这个 lambda 函数里面又捕获 param,会不会形成引用环导致该tensor不析构?这里weak_group 应该就是解决这一问题的,param会不会有同样问题
| return; | ||
| } | ||
| // TODO(zbl): check this if sync is only done in last mircobatch | ||
| // if (!inserted) { |
|
|
||
| if (params_with_grad_.size() == params_.size()) { | ||
| // All param grads are ready in this group, trigger grad sync | ||
| StartGradSync(); |
There was a problem hiding this comment.
这里注释是说ready in this group,但我看实现里遍历了所有bucket group,这个是符合预期的吗
|
|
||
| // Only register grads as ready when processing the last microbatch | ||
| // TODO(zbl): Only register grads as ready and trigger grad sync when processing the last microbatch | ||
| // For now, is_last_microbatch_ is always true |
There was a problem hiding this comment.
为什么is_last_microbatch_ always true 呀
| } | ||
|
|
||
| void ParamAndGradBucket::ScaleGradients(float scaling_factor) { | ||
| if (!grad_data_ || scaling_factor == 1.f) { |
There was a problem hiding this comment.
grad_data_不一定有了,这里的判断还合理吗?
| // optimization | ||
| DEFINE_double(learning_rate, 1e-5, "learning rate warmup iterations"); | ||
| DEFINE_bool(use_distributed_optimizer, false, "Whether to enable DistributedOptimizer(only take effects when DP>1)"); | ||
| DEFINE_int32(zero_stage, 1, "ZeRO stage (1/2/3), default 1 (only take effects when use_distributed_optimizer=true)"); |
There was a problem hiding this comment.
那应该加个检查,如果没开 use_distributed_optimizer但设置了zero_stage直接报错,而不是静默
| temp_full_grad_initialized_[bucket_idx] = false; | ||
| } | ||
|
|
||
| if (!temp_full_grad_initialized_[bucket_idx]) { |
There was a problem hiding this comment.
有点怪,temp_full_grad_buffer_ 在第一次使用时必然需要清零,那在 AllocateFlatBuffer 后立即 Fill(0.0f)就可以,为什么还需要 temp_full_grad_initialized_ 这个flag。另外在并行状态下这个写入可靠吗
| void ResetAccumulator(); | ||
|
|
||
| // ZeRO-2: Use this function to take over AccumulateGrad::Backward | ||
| using GradAccumulateBypass |
There was a problem hiding this comment.
这里我认为不要让Tensor 持有 GradAccumulateBypass,用 post_accumulate_grad_hook_ :把 hook 接口扩展一个 TryBypassAccumulate(grad, overwrite, lr),默认返回 false。AccumulateGrad::Backward() 先问 hook 是否接管,ZeRO2 hook 返回 true。这样不污染 Tensor 基类太多,只复用已有 hook 生命周期。
class PostAccumulateGradHook {
public:
virtual void operator()(const std::shared_ptr<Tensor> &grad) = 0;
virtual bool TryBypassAccumulate(const std::shared_ptr<Tensor> &grad_output,
bool overwrite,
float learning_rate) {
return false;
}
virtual ~PostAccumulateGradHook() = default;
};然后 AccumulateGrad ::Backward() 里变成:
const bool overwrite = tensor_->ConsumeGradOverwriteFlag();
auto hook = tensor_->post_accumulate_grad_hook();
if (hook && hook->TryBypassAccumulate(grad_output, overwrite, learning_rate_)) {
tensor_->ResetAccumulator();
return {};
}
再给ZeRO2 定义一个 hook
class Zero2AccumulateGradHook final : public autograd::PostAccumulateGradHook {
public:
Zero2AccumulateGradHook(std::weak_ptr<ParamAndGradBucketGroup> group,
std::shared_ptr<Tensor> param)
: group_(std::move(group)), param_(std::move(param)) {}
bool TryBypassAccumulate(const std::shared_ptr<Tensor> &grad_output,
bool overwrite,
float learning_rate) override {
if (auto group = group_.lock()) {
group->AccumulateParamGrad(param_, grad_output, overwrite, learning_rate);
if (group->config().overlap_grad_reduce) {
group->RegisterGradReady(param_);
}
return true;
}
return false;
}
void operator()(const std::shared_ptr<Tensor> &) override {}
private:
std::weak_ptr<ParamAndGradBucketGroup> group_;
std::shared_ptr<Tensor> param_;
};
基于框架目前 DistOpt 的建设,实现的 ZeRO-2 的梯度分片显存优化策略。
用户接口修改:添加
zero_stage的 gflag,在启用--use_distributed_optimizer的同时可以指定 zero 级别(目前 zero3 为占位符);zero_stage 的信息也作为成员变量存在 DDPConfig 类里。实现上的修改
a. 核心逻辑:ZeRO-2 的核心是对模型参数所对应的梯度信息也按 dp 来分片存储,每个 rank 拿到自己负责的那部分;考虑到原先的 DistOpt 实现依赖于一个大的一维连续 ParamAndGradBuffer,所以为了实现 ZeRO-2,也就需要在初始化时不构造全量的 grad_buffer,仅构造每个 shard 大小的 grad_buffer。
b. grad 于 ParamAndGradBucketGroup 创建的时候构造(见 ParamAndGradBucketGroup 构造函数):每个 group 单独构造各自 rank 上面的 shard grad buffer,以 grad_shard_buffer_list_ 的成员变量存储(按 buckets 存成一个 list,但是实际上默认情况就是一个 group 一个 bucket,所以这里就是一个 size()==1 的 list)。
c. Autograd 反向流程中,按需临时分配内存创建 full grad,用完后释放。考虑到之前修改了 tensor->grad 的 lazy init,以及每轮可能存在的 ZeroGrad(set_to_none=true),此 full grad 创建时机位于 AccumulateGrad::Backward。
d. 补充上述 c 的细节:为了不在 AccumulateGrad::Backward 插入过多 zero2 相关的 if else 判断污染代码逻辑,直接定义一个 pre-accumulate-grad 的 bypass function,用于劫持原先的 AccumulateGrad::Backward,从而实现流程的重写,替换为:视情况创建 full grad;把其中与 tensor->grad 有关的操作,都改为 full grad 的操作;正常完成梯度更新。
局限性分析:目前的 ZeRO-2 实现上,相当于把 full grad 从原先“显存上的永久存储”的角色转换为了一个“autograd 反向流程中随用随分配并及时释放的激活内存”的角色,从而使得显存下降。但是最坏情况下,仍然可能存在一种情况导致显存优化效果不太明显:计算/通信太慢,full grad 的 cudaFree 操作排在流后面太久还没释放,full grad 作为激活值会较长时间占据内存。目前的做法是适当减小 bucket size,这样 full grad 的 reduce scatter 操作会变得更细粒度一些,更能贴近随用随释放的目标。