From 7c93692d368c0e657f08840932495919a865b1b9 Mon Sep 17 00:00:00 2001 From: bolunz Date: Wed, 4 Feb 2026 15:11:29 +0800 Subject: [PATCH 01/12] feat: Support ZeRO-2 based on DistributedOptimizer --- example/gpt2/main.cc | 9 +- example/llama3/main.cc | 9 +- .../ddp/distributed_data_parallel_config.h | 8 +- .../nn/parallel/ddp/param_and_grad_buffer.h | 26 ++- infini_train/include/tensor.h | 8 + infini_train/src/autograd/accumulate.cc | 10 +- .../parallel/ddp/distributed_data_parallel.cc | 39 +++- .../nn/parallel/ddp/distributed_optimizer.cc | 17 +- .../nn/parallel/ddp/param_and_grad_buffer.cc | 179 ++++++++++++++++-- infini_train/src/tensor.cc | 7 + 10 files changed, 280 insertions(+), 32 deletions(-) diff --git a/example/gpt2/main.cc b/example/gpt2/main.cc index c12b5a28..f427a61b 100644 --- a/example/gpt2/main.cc +++ b/example/gpt2/main.cc @@ -58,6 +58,7 @@ DEFINE_uint32(text_length, 64, "the length of the generated text"); // optimization DEFINE_double(learning_rate, 1e-4, "learning rate warmup iterations"); DEFINE_bool(use_distributed_optimizer, false, "Whether to enable DistributedOptimizer(only take effects when DP>1)"); +DEFINE_int32(zero_stage, 1, "ZeRO stage (1/2/3), default 1 (only take effects when use_distributed_optimizer=true)"); // evaluation DEFINE_uint32(val_loss_every, 0, "every how many steps to evaluate val loss?"); DEFINE_uint32(sample_every, 0, "how often to sample from the model?"); @@ -114,6 +115,7 @@ const std::unordered_map kModelToConfigs = { DEFINE_validator(model, [](const char *, const std::string &value) { return kSupportedModels.contains(value); }); DEFINE_validator(device, [](const char *, const std::string &value) { return value == kDeviceCPU || value == kDeviceCUDA; }); +DEFINE_validator(zero_stage, [](const char *, int32_t value) { return value >= 1 && value <= 3; }); void Train(const nn::parallel::Rank &rank) { using namespace nn::parallel; @@ -253,8 +255,8 @@ void Train(const nn::parallel::Rank &rank) { model = std::make_shared(model, pp_world_size, num_micro_batches, shapes, pp_rank, device, model_config.GetChunkSize()); if (ddp_world_size > 1) { - auto ddp_config - = DistributedDataParallelConfig{.use_distributed_optimizer = FLAGS_use_distributed_optimizer}; + auto ddp_config = DistributedDataParallelConfig{ + .use_distributed_optimizer = FLAGS_use_distributed_optimizer, .zero_stage = FLAGS_zero_stage}; auto *mutable_chunks = dynamic_cast(model.get())->mutable_chunks(); for (int chunk_id = 0; chunk_id < mutable_chunks->size(); ++chunk_id) { (*mutable_chunks)[chunk_id] @@ -266,7 +268,8 @@ void Train(const nn::parallel::Rank &rank) { // before wrapping the model with DistributedDataParallel (DDP). // Otherwise, DDP’s gradient hooks may be lost because new parameter tensors // are created during the conversion. - auto ddp_config = DistributedDataParallelConfig{.use_distributed_optimizer = FLAGS_use_distributed_optimizer}; + auto ddp_config = DistributedDataParallelConfig{.use_distributed_optimizer = FLAGS_use_distributed_optimizer, + .zero_stage = FLAGS_zero_stage}; model = std::make_shared(model, rank, ddp_config); } diff --git a/example/llama3/main.cc b/example/llama3/main.cc index 117551d5..8edf74e5 100644 --- a/example/llama3/main.cc +++ b/example/llama3/main.cc @@ -57,6 +57,7 @@ DEFINE_uint32(text_length, 64, "the length of the generated text"); // optimization DEFINE_double(learning_rate, 1e-5, "learning rate warmup iterations"); DEFINE_bool(use_distributed_optimizer, false, "Whether to enable DistributedOptimizer(only take effects when DP>1)"); +DEFINE_int32(zero_stage, 1, "ZeRO stage (1/2/3), default 1 (only take effects when use_distributed_optimizer=true)"); // evaluation DEFINE_uint32(val_loss_every, 0, "every how many steps to evaluate val loss?"); DEFINE_uint32(sample_every, 0, "how often to sample from the model?"); @@ -100,6 +101,7 @@ constexpr char kDtypeBF16[] = "bfloat16"; DEFINE_validator(model, [](const char *, const std::string &value) { return kSupportedModels.contains(value); }); DEFINE_validator(device, [](const char *, const std::string &value) { return value == kDeviceCPU || value == kDeviceCUDA; }); +DEFINE_validator(zero_stage, [](const char *, int32_t value) { return value >= 1 && value <= 3; }); void Train(const nn::parallel::Rank &rank) { using namespace nn::parallel; @@ -223,8 +225,8 @@ void Train(const nn::parallel::Rank &rank) { model = std::make_shared(model, pp_world_size, num_micro_batches, shapes, pp_rank, device, model_config.GetChunkSize()); if (ddp_world_size > 1) { - auto ddp_config - = DistributedDataParallelConfig{.use_distributed_optimizer = FLAGS_use_distributed_optimizer}; + auto ddp_config = DistributedDataParallelConfig{ + .use_distributed_optimizer = FLAGS_use_distributed_optimizer, .zero_stage = FLAGS_zero_stage}; auto *mutable_chunks = dynamic_cast(model.get())->mutable_chunks(); for (int chunk_id = 0; chunk_id < mutable_chunks->size(); ++chunk_id) { (*mutable_chunks)[chunk_id] @@ -237,7 +239,8 @@ void Train(const nn::parallel::Rank &rank) { // Otherwise, DDP’s gradient hooks may be lost because new parameter tensors // are created during the conversion. - auto ddp_config = DistributedDataParallelConfig{.use_distributed_optimizer = FLAGS_use_distributed_optimizer}; + auto ddp_config = DistributedDataParallelConfig{.use_distributed_optimizer = FLAGS_use_distributed_optimizer, + .zero_stage = FLAGS_zero_stage}; model = std::make_shared(model, rank, ddp_config); } diff --git a/infini_train/include/nn/parallel/ddp/distributed_data_parallel_config.h b/infini_train/include/nn/parallel/ddp/distributed_data_parallel_config.h index 99d30703..4af2ce01 100644 --- a/infini_train/include/nn/parallel/ddp/distributed_data_parallel_config.h +++ b/infini_train/include/nn/parallel/ddp/distributed_data_parallel_config.h @@ -40,6 +40,12 @@ class DistributedDataParallelConfig { // In this case, grad reduce is triggered immediately when a grad is ready or till all grads are ready. bool overlap_grad_reduce = true; + // ZeRO-DP Stage for memory optimization (Only take effects when use_distributed_optimizer=true) + // ZeRO-1: Optimizer states partitioning, by default + // ZeRO-2: Gradients partitioning + // ZeRO-3: Parameters partitioning + int zero_stage = 1; + // Whether to overlap parameter all-gather with forward compute. bool overlap_param_gather = true; @@ -59,7 +65,7 @@ class DistributedDataParallelConfig { // Maximum number of parameters in each ParamAndGradBucket. // NOTE(zbl): This is distinct from DDP Reducer's MB-based bucket caps. // TODO(zbl): To unify the definition of bucket_size argument for users - size_t bucket_size_in_elements = 40000000; + size_t bucket_size_in_elements = 1000000; // Whether to pad bucket sizes to improve NCCL bus bandwidth utilization. bool pad_buckets_for_high_nccl_busbw = false; diff --git a/infini_train/include/nn/parallel/ddp/param_and_grad_buffer.h b/infini_train/include/nn/parallel/ddp/param_and_grad_buffer.h index c83fe9a5..8ae86678 100644 --- a/infini_train/include/nn/parallel/ddp/param_and_grad_buffer.h +++ b/infini_train/include/nn/parallel/ddp/param_and_grad_buffer.h @@ -22,8 +22,8 @@ namespace infini_train::nn::parallel { class ParamAndGradBucket { public: ParamAndGradBucket(const std::vector> ¶ms, const std::shared_ptr ¶m_data, - const std::shared_ptr &grad_data, size_t offset, size_t num_elements_unpadded, - float gradient_scaling_factor, size_t bucket_id); + DataType param_dtype, const std::shared_ptr &grad_data, DataType grad_dtype, + size_t offset, size_t num_elements_unpadded, float gradient_scaling_factor, size_t bucket_id); size_t bucket_id() const { return bucket_id_; } @@ -33,6 +33,10 @@ class ParamAndGradBucket { const std::shared_ptr &grad_data() const { return grad_data_; } + DataType param_dtype() const { return param_dtype_; } + + DataType grad_dtype() const { return grad_dtype_; } + size_t offset() const { return offset_; } size_t num_elements_unpadded() const { return num_elements_unpadded_; } @@ -49,6 +53,8 @@ class ParamAndGradBucket { std::vector> params_; std::shared_ptr param_data_; std::shared_ptr grad_data_; + DataType param_dtype_; + DataType grad_dtype_; size_t offset_ = 0; size_t num_elements_unpadded_ = 0; @@ -73,6 +79,11 @@ class ParamAndGradBucketGroup { // Start grad reduce void StartGradSync(); + // Accumulate a parameter grad into bucket buffer + // ZeRO-2: Use this funtion to take over autograd::AccumulateGrad::Backward + void AccumulateParamGrad(const std::shared_ptr ¶meter, const std::shared_ptr &grad, + bool overwrite, float learning_rate); + // Wait for gradient reduce to complete void FinishGradSync(); @@ -87,6 +98,9 @@ class ParamAndGradBucketGroup { const std::vector> &buckets() const { return buckets_; } + // ZeRO-2: Get a bucket's local grad shard buffer + std::shared_ptr GetLocalGradShardBuffer(size_t bucket_idx) const; + const DistributedDataParallelConfig &config() const { return ddp_config_; } private: @@ -98,12 +112,20 @@ class ParamAndGradBucketGroup { std::unordered_set params_; std::unordered_set params_with_grad_; + // Tensor -> (Bucket, Bucket Index) + std::unordered_map, size_t>> param_to_bucket_; // TODO(zbl): Implement CoalescedWork for aggregate works // According to Megatron-LM's _coalescing_manager std::vector> grad_reduce_work_list_; + std::vector grad_reduce_bucket_indices_; std::vector> param_gather_work_list_; + // ZeRO-2: persistent grad shard buffers and temporary full grad buffers + std::vector> grad_shard_buffer_list_; + std::vector> temp_full_grad_buffer_list_; + std::vector temp_full_grad_initialized_; + std::shared_ptr next_param_gather_bucket_group_ = nullptr; std::vector>> param_buffer_shard_list_; diff --git a/infini_train/include/tensor.h b/infini_train/include/tensor.h index 12f45f57..156fde78 100644 --- a/infini_train/include/tensor.h +++ b/infini_train/include/tensor.h @@ -230,6 +230,12 @@ class Tensor : public std::enable_shared_from_this { std::shared_ptr grad_accumulator(); void ResetAccumulator(); + // ZeRO-2: Use this function to take over AccumulateGrad::Backward + using GradAccumulateBypass + = std::function &grad_output, bool overwrite, float learning_rate)>; + GradAccumulateBypass grad_accumulate_bypass(); + void SetGradAccumulateBypass(GradAccumulateBypass); + void RegisterPostAccumulateGradHook(std::shared_ptr hook); autograd::PostAccumulateGradHook *post_accumulate_grad_hook() const; @@ -244,6 +250,8 @@ class Tensor : public std::enable_shared_from_this { // a strong reference to the accumulator to manage its lifetime. std::shared_ptr grad_accumulator_ = nullptr; std::shared_ptr post_accumulate_grad_hook_ = nullptr; + // ZeRO-2: Use this function to take over AccumulateGrad::Backward + GradAccumulateBypass grad_accumulate_bypass_ = nullptr; bool grad_overwrite_once_ = false; }; diff --git a/infini_train/src/autograd/accumulate.cc b/infini_train/src/autograd/accumulate.cc index d9b70bc1..c3558b47 100644 --- a/infini_train/src/autograd/accumulate.cc +++ b/infini_train/src/autograd/accumulate.cc @@ -33,8 +33,16 @@ AccumulateGrad::Backward(const std::vector> &grad_output "running before autograd). The grad is not cast and will be used as-is."; } + const bool overwrite = tensor_->ConsumeGradOverwriteFlag(); + // ZeRO-2: Use a bypass function to perform grad accumulation in temp full grad buffer + auto bypass = tensor_->grad_accumulate_bypass(); + if (bypass && bypass(grad_output, overwrite, learning_rate_)) { + tensor_->ResetAccumulator(); + return {}; + } + if (grad) { - if (tensor_->ConsumeGradOverwriteFlag()) { + if (overwrite) { // If the tensor is marked to overrite its current grad on next grad update // See notes in `infini_train::nn::parallel::Reducer::PrepareForBackward()` // NOTE(zbl): must copy, cannot change grad buffer address diff --git a/infini_train/src/nn/parallel/ddp/distributed_data_parallel.cc b/infini_train/src/nn/parallel/ddp/distributed_data_parallel.cc index 82f143a9..10c085ad 100644 --- a/infini_train/src/nn/parallel/ddp/distributed_data_parallel.cc +++ b/infini_train/src/nn/parallel/ddp/distributed_data_parallel.cc @@ -24,6 +24,16 @@ DistributedDataParallel::DistributedDataParallel(std::shared_ptr mod const DistributedDataParallelConfig ddp_config) : ddp_config_(ddp_config), ddp_pg_(ProcessGroupFactory::Instance()->Get(GetDataParallelProcessGroupName(rank.GlobalRank()))) { + CHECK(ddp_config_.zero_stage >= 1 && ddp_config_.zero_stage <= 3) + << "DistributedDataParallel: zero_stage must be in 1/2/3."; + if (ddp_config_.zero_stage >= 3) { + LOG(FATAL) << "DistributedDataParallel: ZeRO-3 is not implemented yet."; + } + if (!ddp_config_.use_distributed_optimizer && ddp_config_.zero_stage >= 1) { + LOG(WARNING) << "DistributedDataParallel: zero_stage is ignored because " + "use_distributed_optimizer is false."; + ddp_config_.zero_stage = 1; + } for (auto ¶m : module->Parameters()) { if (!param->requires_grad()) { continue; @@ -83,6 +93,7 @@ void DistributedDataParallel::BuildParamAndGradBuffers() { continue; } + // At the point, zero_stage is already aligned with use_distributed_optimizer. auto buffer = std::make_shared(param_list, param_dtype, grad_dtype, ddp_pg_, ddp_config_); param_grad_buffers_.push_back(buffer); @@ -116,6 +127,32 @@ void DistributedDataParallel::BuildParamAndGradBuffers() { } void DistributedDataParallel::RegisterBackwardHooks() { + if (ddp_config_.zero_stage >= 2) { + auto &module = modules_.at(kModuleName); + for (auto ¶m : module->Parameters()) { + if (!param->requires_grad()) { + continue; + } + auto it = param_to_bucket_group_.find(param.get()); + if (it == param_to_bucket_group_.end()) { + continue; + } + std::weak_ptr weak_group = it->second; + param->SetGradAccumulateBypass( + [weak_group, param](const std::shared_ptr &grad_output, bool overwrite, float learning_rate) { + if (auto group = weak_group.lock()) { + group->AccumulateParamGrad(param, grad_output, overwrite, learning_rate); + if (group->config().overlap_grad_reduce) { + group->RegisterGradReady(param); + } + return true; + } + return false; + }); + } + return; + } + class DDPPostAccumulateHook final : public autograd::PostAccumulateGradHook { public: DDPPostAccumulateHook(DistributedDataParallel *ddp, const std::weak_ptr param) @@ -147,7 +184,7 @@ void DistributedDataParallel::OnGradReady(const std::shared_ptr ¶m) auto it = param_to_bucket_group_.find(param.get()); if (it != param_to_bucket_group_.end()) { CHECK(param->requires_grad()); - if (ddp_config_.overlap_grad_reduce) { + if (ddp_config_.overlap_grad_reduce && (ddp_config_.zero_stage < 2)) { CHECK(param->grad()) << "param.grad being None is not safe when overlap_grad_reduce is True"; } diff --git a/infini_train/src/nn/parallel/ddp/distributed_optimizer.cc b/infini_train/src/nn/parallel/ddp/distributed_optimizer.cc index 55e5800b..48fd7103 100644 --- a/infini_train/src/nn/parallel/ddp/distributed_optimizer.cc +++ b/infini_train/src/nn/parallel/ddp/distributed_optimizer.cc @@ -35,10 +35,13 @@ void DistributedOptimizer::BuildShardParamsAndBindGrads() { shard_params_.clear(); for (const auto &group : bucket_groups_) { - for (const auto &bucket : group->buckets()) { + const bool use_grad_shard = group->config().zero_stage >= 2; + const auto &buckets = group->buckets(); + for (size_t bucket_idx = 0; bucket_idx < buckets.size(); ++bucket_idx) { + const auto &bucket = buckets[bucket_idx]; auto bucket_param = bucket->param_data(); - auto bucket_grad = bucket->grad_data(); + auto bucket_grad = use_grad_shard ? group->GetLocalGradShardBuffer(bucket_idx) : bucket->grad_data(); CHECK(bucket_param) << "DistributedOptimizer requires param buffer."; CHECK(bucket_grad) << "DistributedOptimizer requires grad buffer."; @@ -65,7 +68,9 @@ void DistributedOptimizer::BuildShardParamsAndBindGrads() { CHECK_GT(piece_numel, 0); const size_t param_piece_offset_bytes = local_start * kDataTypeToSize.at(bucket_param->Dtype()); - const size_t grad_piece_offset_bytes = local_start * kDataTypeToSize.at(bucket_grad->Dtype()); + // Adjust the offset since bucket_grad is already the shard of grad under ZeRO-2. + auto offset = use_grad_shard ? (local_start - bucket_shard_start) : local_start; + size_t grad_piece_offset_bytes = offset * kDataTypeToSize.at(bucket_grad->Dtype()); auto param_piece = std::make_shared(*bucket_param, param_piece_offset_bytes, std::vector{static_cast(piece_numel)}); @@ -74,6 +79,12 @@ void DistributedOptimizer::BuildShardParamsAndBindGrads() { std::vector{static_cast(piece_numel)}); param_piece->set_grad(grad_piece); + // if (use_grad_shard) { + // // NOTE(zbl): Under ZeRO-2, param->grad() is the shard of grad, not the full grad. + // // The binding is done in the construnctor of DistributedOptimizer. + // // Not until backward is finished, the value of param->grad() will be updated. + // param->set_grad(grad_piece); + // } shard_params_.push_back(param_piece); } } diff --git a/infini_train/src/nn/parallel/ddp/param_and_grad_buffer.cc b/infini_train/src/nn/parallel/ddp/param_and_grad_buffer.cc index 75a21f63..1ab17810 100644 --- a/infini_train/src/nn/parallel/ddp/param_and_grad_buffer.cc +++ b/infini_train/src/nn/parallel/ddp/param_and_grad_buffer.cc @@ -6,6 +6,7 @@ #include "glog/logging.h" +#include "infini_train/include/dispatcher.h" #include "infini_train/include/nn/modules/module.h" #include "infini_train/include/nn/parallel/ddp/distributed_data_parallel_config.h" #include "infini_train/include/nn/parallel/global.h" @@ -53,12 +54,12 @@ std::vector> ShardBuffer(const std::shared_ptr b } // namespace ParamAndGradBucket::ParamAndGradBucket(const std::vector> ¶ms, - const std::shared_ptr ¶m_data, - const std::shared_ptr &grad_data, size_t offset, + const std::shared_ptr ¶m_data, DataType param_dtype, + const std::shared_ptr &grad_data, DataType grad_dtype, size_t offset, size_t num_elements_unpadded, float gradient_scaling_factor, size_t bucket_id) - : bucket_id_(bucket_id), params_(std::move(params)), param_data_(std::move(param_data)), - grad_data_(std::move(grad_data)), offset_(offset), num_elements_unpadded_(num_elements_unpadded), - gradient_scaling_factor_(gradient_scaling_factor) { + : bucket_id_(bucket_id), params_(std::move(params)), param_data_(std::move(param_data)), param_dtype_(param_dtype), + grad_data_(std::move(grad_data)), grad_dtype_(grad_dtype), offset_(offset), + num_elements_unpadded_(num_elements_unpadded), gradient_scaling_factor_(gradient_scaling_factor) { size_t current_offset = 0; for (const auto ¶m : params_) { auto numel = param->NumElements(); @@ -97,8 +98,12 @@ ParamAndGradBucketGroup::ParamAndGradBucketGroup(const std::vectorparams()) { params_.insert(param.get()); } + for (size_t bucket_idx = 0; bucket_idx < buckets_.size(); ++bucket_idx) { + const auto &bucket = buckets_[bucket_idx]; + for (const auto ¶m : bucket->params()) { + params_.insert(param.get()); + param_to_bucket_[param.get()] = {bucket, bucket_idx}; + } } if (rank_in_collective_pg_ == -1) { auto param = *params_.begin(); @@ -108,15 +113,40 @@ ParamAndGradBucketGroup::ParamAndGradBucketGroup(const std::vector= 2) { + for (size_t i = 0; i < buckets_.size(); ++i) { + auto bucket = buckets_[i]; + CHECK(bucket->param_data()) << "ParamAndGradBucketGroup: param buffer required for ZeRO-2."; + const size_t bucket_numel = bucket->param_data()->NumElements(); + if (bucket_numel == 0) { + continue; + } + CHECK_EQ(bucket_numel % collective_pg_size_, 0); + const size_t shard_numel = bucket_numel / collective_pg_size_; + auto param = bucket->params().front(); + grad_shard_buffer_list_[i] = AllocateFlatBuffer(shard_numel, bucket->grad_dtype(), param->GetDevice()); + } + } } void ParamAndGradBucketGroup::Reset() { params_with_grad_.clear(); grad_reduce_work_list_.clear(); + grad_reduce_bucket_indices_.clear(); param_gather_work_list_.clear(); is_last_microbatch_ = true; grad_reduce_dispatched_ = false; param_gather_dispatched_ = false; + + if (ddp_config_.zero_stage >= 2) { + std::fill(temp_full_grad_buffer_list_.begin(), temp_full_grad_buffer_list_.end(), nullptr); + std::fill(temp_full_grad_initialized_.begin(), temp_full_grad_initialized_.end(), false); + } } void ParamAndGradBucketGroup::RegisterGradReady(const std::shared_ptr ¶meter) { @@ -147,6 +177,69 @@ void ParamAndGradBucketGroup::RegisterGradReady(const std::shared_ptr &p } } +void ParamAndGradBucketGroup::AccumulateParamGrad(const std::shared_ptr ¶meter, + const std::shared_ptr &grad, bool overwrite, + float learning_rate) { + if (ddp_config_.zero_stage < 2) { + LOG(FATAL) << "ParamAndGradBucketGroup: AccumulateParamGrad called when ZeRO-2 is disabled."; + return; + } + if (!grad || !parameter) { + return; + } + + auto it = param_to_bucket_.find(parameter.get()); + if (it == param_to_bucket_.end()) { + return; + } + auto bucket = it->second.first; + const size_t bucket_idx = it->second.second; + + size_t param_start_in_bucket = 0, param_end_in_bucket = 0; + auto found = bucket->GetTensorLocInBucket(parameter, param_start_in_bucket, param_end_in_bucket); + if (!found) { + return; + } + + if (!temp_full_grad_buffer_list_[bucket_idx]) { + CHECK(bucket->param_data()) << "ParamAndGradBucketGroup: param buffer required for ZeRO-2."; + const size_t bucket_numel = bucket->param_data()->NumElements(); + if (bucket_numel == 0) { + return; + } + temp_full_grad_buffer_list_[bucket_idx] + = AllocateFlatBuffer(bucket_numel, bucket->grad_dtype(), parameter->GetDevice()); + temp_full_grad_initialized_[bucket_idx] = false; + } + + if (!temp_full_grad_initialized_[bucket_idx]) { + temp_full_grad_buffer_list_[bucket_idx]->Fill(0.0f); + temp_full_grad_initialized_[bucket_idx] = true; + } + + const size_t offset_bytes = param_start_in_bucket * kDataTypeToSize.at(bucket->grad_dtype()); + auto bucket_grad_view + = std::make_shared(*temp_full_grad_buffer_list_[bucket_idx], offset_bytes, parameter->Dims()); + + if (overwrite) { + bucket_grad_view->CopyFrom(*grad); + } else { + auto kernel = Dispatcher::Instance().GetKernel({parameter->GetDevice()->Type(), "AccumulateGrad"}); + kernel.Call(grad, learning_rate, bucket_grad_view); + } +} + +std::shared_ptr ParamAndGradBucketGroup::GetLocalGradShardBuffer(size_t bucket_idx) const { + if (ddp_config_.zero_stage < 2) { + LOG(WARNING) << "ParamAndGradBucketGroup: GetLocalGradShardBuffer called when ZeRO-2 is disabled."; + return nullptr; + } + if (bucket_idx >= grad_shard_buffer_list_.size()) { + return nullptr; + } + return grad_shard_buffer_list_[bucket_idx]; +} + void ParamAndGradBucketGroup::StartGradSync() { if (!collective_pg_) { LOG(FATAL) << "ParamAndGradBucketGroup: StartGradSync() called with null collective_pg_."; @@ -174,6 +267,20 @@ void ParamAndGradBucketGroup::StartGradSync() { for (auto i = 0; i < buckets_.size(); ++i) { auto bucket = buckets_[i]; + + if (ddp_config_.zero_stage >= 2) { + auto full_grad_buffer = temp_full_grad_buffer_list_[i]; + if (!full_grad_buffer) { + continue; + } + CHECK(grad_shard_buffer_list_[i]) << "ParamAndGradBucketGroup: grad shard buffer missing."; + auto local_data_view = grad_shard_buffer_list_[i]; + grad_reduce_work_list_.push_back( + collective_pg_->ReduceScatter(local_data_view, full_grad_buffer, reduce_op, async_op)); + grad_reduce_bucket_indices_.push_back(i); + continue; + } + std::shared_ptr grad_buffer = bucket->grad_data(); if (!grad_buffer) { continue; @@ -200,6 +307,10 @@ void ParamAndGradBucketGroup::FinishGradSync() { StartGradSync(); } + if (params_with_grad_.empty()) { + return; + } + if (!ddp_config_.overlap_grad_reduce) { // Assume reduce ops are synced and no work needs to be resolved grad_reduce_work_list_.clear(); @@ -211,6 +322,20 @@ void ParamAndGradBucketGroup::FinishGradSync() { << "ParamAndGradBucketGroup: Communication call has not been issued for this bucket(" << params_with_grad_.size() << "/" << params_.size() << " params have grad available)"; + if (ddp_config_.zero_stage >= 2) { + for (size_t idx = 0; idx < grad_reduce_work_list_.size(); ++idx) { + auto &work = grad_reduce_work_list_[idx]; + work->WaitNonBlocking(); + const size_t bucket_idx = grad_reduce_bucket_indices_[idx]; + temp_full_grad_buffer_list_[bucket_idx].reset(); + temp_full_grad_initialized_[bucket_idx] = false; + } + grad_reduce_work_list_.clear(); + grad_reduce_bucket_indices_.clear(); + grad_reduce_dispatched_ = false; + return; + } + for (auto work : grad_reduce_work_list_) { work->WaitNonBlocking(); } grad_reduce_work_list_.clear(); grad_reduce_dispatched_ = false; @@ -399,7 +524,11 @@ void ParamAndGradBuffer::BuildBuckets(DataType param_dtype, DataType grad_dtype) // No param buffer needed if optimzer is not distributed param_buffer_.reset(); } - grad_buffer_ = AllocateFlatBuffer(numel_, grad_dtype, device); + if (ddp_config_.zero_stage >= 2) { + grad_buffer_.reset(); + } else { + grad_buffer_ = AllocateFlatBuffer(numel_, grad_dtype, device); + } LOG(INFO) << "ParamAndGradBuffer: numel_unpadded=" << numel_unpadded_ << ", numel (padded)=" << numel_; @@ -424,14 +553,17 @@ void ParamAndGradBuffer::BuildBuckets(DataType param_dtype, DataType grad_dtype) bucket_param_view = GetBufferView(param_buffer_, start_index, std::vector{static_cast(end_index - start_index)}); } - std::shared_ptr bucket_grad_view = GetBufferView( - grad_buffer_, start_index, std::vector{static_cast(end_index - start_index)}); + std::shared_ptr bucket_grad_view; + if (grad_buffer_) { + bucket_grad_view = GetBufferView(grad_buffer_, start_index, + std::vector{static_cast(end_index - start_index)}); + } // FIXME(zbl): Use default for now float gradient_scaling_factor = 1.0f; - auto bucket - = std::make_shared(bucket_params, bucket_param_view, bucket_grad_view, start_index, - num_elements_unpadded, gradient_scaling_factor, bucket_id); + auto bucket = std::make_shared(bucket_params, bucket_param_view, param_dtype, + bucket_grad_view, grad_dtype, start_index, + num_elements_unpadded, gradient_scaling_factor, bucket_id); for (auto param : bucket_params) { CHECK(param_bucket_map_.find(param.get()) == param_bucket_map_.end()) @@ -454,8 +586,11 @@ void ParamAndGradBuffer::BuildBuckets(DataType param_dtype, DataType grad_dtype) param->SetData(*param_buffer_, param_start_index * kDataTypeToSize.at(param_buffer_->Dtype()), true); } - auto grad_view = GetBufferView(grad_buffer_, param_start_index, param->Dims()); - param->set_grad(grad_view); + std::shared_ptr grad_view; + if (grad_buffer_) { + grad_view = GetBufferView(grad_buffer_, param_start_index, param->Dims()); + param->set_grad(grad_view); + } // Save grad view for each params --i; grads_[i] = grad_view; @@ -506,7 +641,9 @@ void ParamAndGradBuffer::Reset(bool need_rebind) { if (!need_rebind) { grad_buffer_->Fill(0.f); } - need_rebind_grad_views_ = need_rebind; + // NOTE(zbl): Under ZeRO-2, param->grad() is the shard of grad, not the full grad. + // It is constantly pointed to the shard of grad, so no need to rebind. + need_rebind_grad_views_ = need_rebind && (ddp_config_.zero_stage < 2); } void ParamAndGradBuffer::RebindGradViews() { @@ -514,10 +651,16 @@ void ParamAndGradBuffer::RebindGradViews() { return; } + if (!grad_buffer_) { + return; + } + CHECK_EQ(params_.size(), grads_.size()); for (size_t i = 0; i < params_.size(); ++i) { - params_[i]->set_grad(grads_[i]); - params_[i]->MarkGradOverwriteOnNextAccum(); + if (grads_[i]) { + params_[i]->set_grad(grads_[i]); + params_[i]->MarkGradOverwriteOnNextAccum(); + } } need_rebind_grad_views_ = false; diff --git a/infini_train/src/tensor.cc b/infini_train/src/tensor.cc index f7947030..1f485e68 100644 --- a/infini_train/src/tensor.cc +++ b/infini_train/src/tensor.cc @@ -559,6 +559,13 @@ void Tensor::ResetAccumulator() { } } +Tensor::GradAccumulateBypass Tensor::grad_accumulate_bypass() { + CHECK(grad_accumulator_) << "grad_accumulate_bypass() should only be called on leaf tensors"; + return grad_accumulate_bypass_; +} + +void Tensor::SetGradAccumulateBypass(GradAccumulateBypass bypass) { grad_accumulate_bypass_ = std::move(bypass); } + void Tensor::RegisterPostAccumulateGradHook(std::shared_ptr hook) { CHECK(requires_grad_) << "cannot register a hook on a tensor that doesn't require gradient"; From cb60411c246edb2ca93ba9f584a990b8ff9d1b13 Mon Sep 17 00:00:00 2001 From: bolunz Date: Wed, 11 Mar 2026 15:54:28 +0800 Subject: [PATCH 02/12] fix: adapt kernel call under new Device/DeviceGuard impl --- infini_train/src/nn/parallel/ddp/param_and_grad_buffer.cc | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/infini_train/src/nn/parallel/ddp/param_and_grad_buffer.cc b/infini_train/src/nn/parallel/ddp/param_and_grad_buffer.cc index 1ab17810..8399ec04 100644 --- a/infini_train/src/nn/parallel/ddp/param_and_grad_buffer.cc +++ b/infini_train/src/nn/parallel/ddp/param_and_grad_buffer.cc @@ -224,7 +224,8 @@ void ParamAndGradBucketGroup::AccumulateParamGrad(const std::shared_ptr if (overwrite) { bucket_grad_view->CopyFrom(*grad); } else { - auto kernel = Dispatcher::Instance().GetKernel({parameter->GetDevice()->Type(), "AccumulateGrad"}); + auto device = parameter->GetDevice(); + auto kernel = Dispatcher::Instance().GetKernel({device.type(), "AccumulateGrad"}); kernel.Call(grad, learning_rate, bucket_grad_view); } } From 42e457ad4dd551f239fde815f013be5324b3f2ef Mon Sep 17 00:00:00 2001 From: bolunz Date: Thu, 12 Mar 2026 13:14:22 +0800 Subject: [PATCH 03/12] fix: adapt test_config.json to new format --- scripts/test_config.json | 78 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 78 insertions(+) diff --git a/scripts/test_config.json b/scripts/test_config.json index 2f061528..6a4f1d5d 100644 --- a/scripts/test_config.json +++ b/scripts/test_config.json @@ -213,6 +213,18 @@ "use_distributed_optimizer": true } }, + { + "id": "3_distopt_zero2", + "args": { + "dtype": "float32", + "nthread_per_process": 8, + "num_iteration": 10, + "batch_size": 10, + "total_batch_size": 5120, + "use_distributed_optimizer": true, + "zero_stage": 2 + } + }, { "id": "3_bfloat16_distopt", "args": { @@ -224,6 +236,18 @@ "use_distributed_optimizer": true } }, + { + "id": "3_bfloat16_distopt_zero2", + "args": { + "dtype": "bfloat16", + "nthread_per_process": 8, + "num_iteration": 10, + "batch_size": 10, + "total_batch_size": 5120, + "use_distributed_optimizer": true, + "zero_stage": 2 + } + }, { "id": "4_distopt", "args": { @@ -236,6 +260,19 @@ "use_distributed_optimizer": true } }, + { + "id": "4_distopt_zero2", + "args": { + "dtype": "float32", + "nthread_per_process": 8, + "num_iteration": 10, + "batch_size": 40, + "total_batch_size": 5120, + "tensor_parallel": 4, + "use_distributed_optimizer": true, + "zero_stage": 2 + } + }, { "id": "4_bfloat16_distopt", "args": { @@ -248,6 +285,19 @@ "use_distributed_optimizer": true } }, + { + "id": "4_bfloat16_distopt_zero2", + "args": { + "dtype": "bfloat16", + "nthread_per_process": 8, + "num_iteration": 10, + "batch_size": 40, + "total_batch_size": 5120, + "tensor_parallel": 4, + "use_distributed_optimizer": true, + "zero_stage": 2 + } + }, { "id": "5_distopt", "args": { @@ -261,6 +311,20 @@ "use_distributed_optimizer": true } }, + { + "id": "5_distopt_zero2", + "args": { + "dtype": "float32", + "nthread_per_process": 8, + "num_iteration": 10, + "batch_size": 40, + "total_batch_size": 5120, + "tensor_parallel": 4, + "sequence_parallel": true, + "use_distributed_optimizer": true, + "zero_stage": 2 + } + }, { "id": "5_bfloat16_distopt", "args": { @@ -274,6 +338,20 @@ "use_distributed_optimizer": true } }, + { + "id": "5_bfloat16_distopt_zero2", + "args": { + "dtype": "bfloat16", + "nthread_per_process": 8, + "num_iteration": 10, + "batch_size": 40, + "total_batch_size": 5120, + "tensor_parallel": 4, + "sequence_parallel": true, + "use_distributed_optimizer": true, + "zero_stage": 2 + } + }, { "id": "8_distopt", "args": { From 2762da8fd8bca47497d86aeb0639cb0048fe33ed Mon Sep 17 00:00:00 2001 From: bolunz Date: Wed, 13 May 2026 09:45:49 +0000 Subject: [PATCH 04/12] fix: fix distopt behavior on gradient accumulation cases --- .../nn/parallel/ddp/param_and_grad_buffer.cc | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/infini_train/src/nn/parallel/ddp/param_and_grad_buffer.cc b/infini_train/src/nn/parallel/ddp/param_and_grad_buffer.cc index 8399ec04..4f987da3 100644 --- a/infini_train/src/nn/parallel/ddp/param_and_grad_buffer.cc +++ b/infini_train/src/nn/parallel/ddp/param_and_grad_buffer.cc @@ -157,18 +157,21 @@ void ParamAndGradBucketGroup::RegisterGradReady(const std::shared_ptr &p return; } - // Only register grads as ready when processing the last microbatch + // TODO(zbl): Only register grads as ready and trigger grad sync when processing the last microbatch + // For now, is_last_microbatch_ is always true if (is_last_microbatch_) { if (!parameter || params_.find(parameter.get()) == params_.end()) { return; } const bool inserted = params_with_grad_.insert(parameter.get()).second; - if (!inserted) { - LOG(FATAL) << "ParamAndGradBucketGroup: RegisterGradReady() was called twice for the same parameter in a " - "bucket group."; - return; - } + // TODO(zbl): check this if sync is only done in last mircobatch + // if (!inserted) { + // LOG(FATAL) << "ParamAndGradBucketGroup: RegisterGradReady() was called twice for the same parameter in a + // " + // "bucket group."; + // return; + // } if (params_with_grad_.size() == params_.size()) { // All param grads are ready in this group, trigger grad sync @@ -301,6 +304,8 @@ void ParamAndGradBucketGroup::StartGradSync() { } grad_reduce_dispatched_ = true; + // FIXME(zbl): no need to clear params_with_grad_ here if grad sync is only done on last microbatch + params_with_grad_.clear(); } void ParamAndGradBucketGroup::FinishGradSync() { From 7cd51e4be60b2d610dd73a937641088676b9fefc Mon Sep 17 00:00:00 2001 From: bolunz Date: Thu, 14 May 2026 01:22:01 +0000 Subject: [PATCH 05/12] fix: fix some descriptions in comments --- .../src/nn/parallel/ddp/distributed_optimizer.cc | 11 ++++------- .../src/nn/parallel/ddp/param_and_grad_buffer.cc | 13 +++++-------- 2 files changed, 9 insertions(+), 15 deletions(-) diff --git a/infini_train/src/nn/parallel/ddp/distributed_optimizer.cc b/infini_train/src/nn/parallel/ddp/distributed_optimizer.cc index 48fd7103..3b86106b 100644 --- a/infini_train/src/nn/parallel/ddp/distributed_optimizer.cc +++ b/infini_train/src/nn/parallel/ddp/distributed_optimizer.cc @@ -79,12 +79,9 @@ void DistributedOptimizer::BuildShardParamsAndBindGrads() { std::vector{static_cast(piece_numel)}); param_piece->set_grad(grad_piece); - // if (use_grad_shard) { - // // NOTE(zbl): Under ZeRO-2, param->grad() is the shard of grad, not the full grad. - // // The binding is done in the construnctor of DistributedOptimizer. - // // Not until backward is finished, the value of param->grad() will be updated. - // param->set_grad(grad_piece); - // } + // NOTE(zbl): Do not call `param->set_grad(grad_piece);` under ZeRO-2. + // The base optimizer updates param_piece views only; original param->grad() + // would be a partial flattened shard and does not represent the full parameter grad. shard_params_.push_back(param_piece); } } @@ -135,7 +132,7 @@ void DistributedOptimizer::Step() { // 3. Gather updated param shards back to full params StartParamSync(/*force_sync=*/false); - // FIXME(zbl): Call sync before param is actually used in next step + // TODO(zbl): Delay sync call until param is actually used in next step FinishParamSync(/*skip_next_bucket_dispatch=*/true); } diff --git a/infini_train/src/nn/parallel/ddp/param_and_grad_buffer.cc b/infini_train/src/nn/parallel/ddp/param_and_grad_buffer.cc index 4f987da3..9916312c 100644 --- a/infini_train/src/nn/parallel/ddp/param_and_grad_buffer.cc +++ b/infini_train/src/nn/parallel/ddp/param_and_grad_buffer.cc @@ -86,7 +86,7 @@ void ParamAndGradBucket::ScaleGradients(float scaling_factor) { // FIXME(zbl): should perform in-place multiply // grad_data_ *= scaling_factor; - LOG(FATAL) << "ParamAndGradBucket: Should not arrive here"; + LOG(FATAL) << "ParamAndGradBuffer::ScaleGradients(): Inplace multiply not implemented yet."; } ParamAndGradBucketGroup::ParamAndGradBucketGroup(const std::vector> &buckets, @@ -107,8 +107,7 @@ ParamAndGradBucketGroup::ParamAndGradBucketGroup(const std::vectorGetGroupRank(param->GetDevice().Rank().thread_rank()); + rank_in_collective_pg_ = collective_pg_->GetGroupRank(param->GetDevice().Rank().GlobalRank()); } param_buffer_shard_list_.resize(buckets_.size()); @@ -168,9 +167,7 @@ void ParamAndGradBucketGroup::RegisterGradReady(const std::shared_ptr &p // TODO(zbl): check this if sync is only done in last mircobatch // if (!inserted) { // LOG(FATAL) << "ParamAndGradBucketGroup: RegisterGradReady() was called twice for the same parameter in a - // " - // "bucket group."; - // return; + // bucket group."; return; // } if (params_with_grad_.size() == params_.size()) { @@ -304,7 +301,7 @@ void ParamAndGradBucketGroup::StartGradSync() { } grad_reduce_dispatched_ = true; - // FIXME(zbl): no need to clear params_with_grad_ here if grad sync is only done on last microbatch + // TODO(zbl): no need to clear params_with_grad_ here if grad sync is only done on last microbatch params_with_grad_.clear(); } @@ -637,7 +634,7 @@ void ParamAndGradBuffer::ScaleGradients(float scaling_factor) { // FIXME(zbl): should perform in-place multiply // grad_data_ *= scaling_factor; - LOG(FATAL) << "Should not arrive here"; + LOG(FATAL) << "ParamAndGradBuffer::ScaleGradients(): Inplace multiply not implemented yet."; } void ParamAndGradBuffer::Reset(bool need_rebind) { From 3a416dd5544d5d78ff156be765271e8f26636d95 Mon Sep 17 00:00:00 2001 From: bolunz Date: Wed, 20 May 2026 08:32:20 +0000 Subject: [PATCH 06/12] fix: resolve comments --- example/common/utils.cc | 10 +++++ example/common/utils.h | 2 + example/gpt2/main.cc | 2 + example/llama3/main.cc | 2 + .../nn/parallel/ddp/param_and_grad_buffer.h | 1 - .../parallel/ddp/distributed_data_parallel.cc | 25 ++++++----- .../nn/parallel/ddp/param_and_grad_buffer.cc | 41 +++++++++---------- 7 files changed, 50 insertions(+), 33 deletions(-) diff --git a/example/common/utils.cc b/example/common/utils.cc index 03cc7aa0..0d7b966e 100644 --- a/example/common/utils.cc +++ b/example/common/utils.cc @@ -1,5 +1,8 @@ #include "example/common/utils.h" +#include "gflags/gflags.h" +#include "glog/logging.h" + namespace infini_train { float ConvertBF16ToFloat(void *ptr) { @@ -61,4 +64,11 @@ void ReadVectorShardFloat(std::ifstream &ifs, float *dst, int64_t len, int64_t s ifs.seekg(base + std::streamoff(len * sizeof(float))); } +void ValidateDistributedOptimizerFlags(bool use_distributed_optimizer) { + gflags::CommandLineFlagInfo zero_stage_info; + CHECK(gflags::GetCommandLineFlagInfo("zero_stage", &zero_stage_info)); + CHECK(use_distributed_optimizer || zero_stage_info.is_default) + << "--zero_stage requires --use_distributed_optimizer=true."; +} + } // namespace infini_train diff --git a/example/common/utils.h b/example/common/utils.h index 5bab3e97..19e65ff1 100644 --- a/example/common/utils.h +++ b/example/common/utils.h @@ -30,4 +30,6 @@ void ReadVectorAllFloat(std::ifstream &ifs, float *dst, int64_t len); void ReadVectorShardFloat(std::ifstream &ifs, float *dst, int64_t len, int64_t start, int64_t cnt); +void ValidateDistributedOptimizerFlags(bool use_distributed_optimizer); + } // namespace infini_train diff --git a/example/gpt2/main.cc b/example/gpt2/main.cc index f427a61b..0106dbc7 100644 --- a/example/gpt2/main.cc +++ b/example/gpt2/main.cc @@ -36,6 +36,7 @@ #include "example/common/tiny_shakespeare_dataset.h" #include "example/common/tokenizer.h" +#include "example/common/utils.h" #include "example/gpt2/checkpoint_loader.h" #include "example/gpt2/config.h" @@ -451,6 +452,7 @@ void Train(const nn::parallel::Rank &rank) { int main(int argc, char *argv[]) { gflags::ParseCommandLineFlags(&argc, &argv, true); google::InitGoogleLogging(argv[0]); + ValidateDistributedOptimizerFlags(FLAGS_use_distributed_optimizer); auto precision_config = utils::PrecisionCheckConfig::Parse(FLAGS_precision_check); nn::parallel::global::InitAllEnv(FLAGS_nthread_per_process, FLAGS_tensor_parallel, FLAGS_sequence_parallel, diff --git a/example/llama3/main.cc b/example/llama3/main.cc index 8edf74e5..e4e0b09a 100644 --- a/example/llama3/main.cc +++ b/example/llama3/main.cc @@ -35,6 +35,7 @@ #include "example/common/tiny_shakespeare_dataset.h" #include "example/common/tokenizer.h" +#include "example/common/utils.h" #include "example/llama3/checkpoint_loader.h" #include "example/llama3/config.h" @@ -426,6 +427,7 @@ void Train(const nn::parallel::Rank &rank) { int main(int argc, char *argv[]) { gflags::ParseCommandLineFlags(&argc, &argv, true); google::InitGoogleLogging(argv[0]); + ValidateDistributedOptimizerFlags(FLAGS_use_distributed_optimizer); auto precision_config = utils::PrecisionCheckConfig::Parse(FLAGS_precision_check); nn::parallel::global::InitAllEnv(FLAGS_nthread_per_process, FLAGS_tensor_parallel, FLAGS_sequence_parallel, diff --git a/infini_train/include/nn/parallel/ddp/param_and_grad_buffer.h b/infini_train/include/nn/parallel/ddp/param_and_grad_buffer.h index 8ae86678..11b63828 100644 --- a/infini_train/include/nn/parallel/ddp/param_and_grad_buffer.h +++ b/infini_train/include/nn/parallel/ddp/param_and_grad_buffer.h @@ -124,7 +124,6 @@ class ParamAndGradBucketGroup { // ZeRO-2: persistent grad shard buffers and temporary full grad buffers std::vector> grad_shard_buffer_list_; std::vector> temp_full_grad_buffer_list_; - std::vector temp_full_grad_initialized_; std::shared_ptr next_param_gather_bucket_group_ = nullptr; diff --git a/infini_train/src/nn/parallel/ddp/distributed_data_parallel.cc b/infini_train/src/nn/parallel/ddp/distributed_data_parallel.cc index 10c085ad..dc37200c 100644 --- a/infini_train/src/nn/parallel/ddp/distributed_data_parallel.cc +++ b/infini_train/src/nn/parallel/ddp/distributed_data_parallel.cc @@ -138,17 +138,22 @@ void DistributedDataParallel::RegisterBackwardHooks() { continue; } std::weak_ptr weak_group = it->second; - param->SetGradAccumulateBypass( - [weak_group, param](const std::shared_ptr &grad_output, bool overwrite, float learning_rate) { - if (auto group = weak_group.lock()) { - group->AccumulateParamGrad(param, grad_output, overwrite, learning_rate); - if (group->config().overlap_grad_reduce) { - group->RegisterGradReady(param); - } - return true; + std::weak_ptr weak_param = param; + param->SetGradAccumulateBypass([weak_group, weak_param](const std::shared_ptr &grad_output, + bool overwrite, float learning_rate) { + if (auto group = weak_group.lock(); group) { + auto param = weak_param.lock(); + if (!param) { + return false; } - return false; - }); + group->AccumulateParamGrad(param, grad_output, overwrite, learning_rate); + if (group->config().overlap_grad_reduce) { + group->RegisterGradReady(param); + } + return true; + } + return false; + }); } return; } diff --git a/infini_train/src/nn/parallel/ddp/param_and_grad_buffer.cc b/infini_train/src/nn/parallel/ddp/param_and_grad_buffer.cc index 9916312c..691602bc 100644 --- a/infini_train/src/nn/parallel/ddp/param_and_grad_buffer.cc +++ b/infini_train/src/nn/parallel/ddp/param_and_grad_buffer.cc @@ -86,7 +86,7 @@ void ParamAndGradBucket::ScaleGradients(float scaling_factor) { // FIXME(zbl): should perform in-place multiply // grad_data_ *= scaling_factor; - LOG(FATAL) << "ParamAndGradBuffer::ScaleGradients(): Inplace multiply not implemented yet."; + LOG(FATAL) << "ParamAndGradBucket::ScaleGradients(): Inplace multiply not implemented yet."; } ParamAndGradBucketGroup::ParamAndGradBucketGroup(const std::vector> &buckets, @@ -115,7 +115,6 @@ ParamAndGradBucketGroup::ParamAndGradBucketGroup(const std::vector= 2) { for (size_t i = 0; i < buckets_.size(); ++i) { @@ -144,7 +143,6 @@ void ParamAndGradBucketGroup::Reset() { if (ddp_config_.zero_stage >= 2) { std::fill(temp_full_grad_buffer_list_.begin(), temp_full_grad_buffer_list_.end(), nullptr); - std::fill(temp_full_grad_initialized_.begin(), temp_full_grad_initialized_.end(), false); } } @@ -209,12 +207,7 @@ void ParamAndGradBucketGroup::AccumulateParamGrad(const std::shared_ptr } temp_full_grad_buffer_list_[bucket_idx] = AllocateFlatBuffer(bucket_numel, bucket->grad_dtype(), parameter->GetDevice()); - temp_full_grad_initialized_[bucket_idx] = false; - } - - if (!temp_full_grad_initialized_[bucket_idx]) { temp_full_grad_buffer_list_[bucket_idx]->Fill(0.0f); - temp_full_grad_initialized_[bucket_idx] = true; } const size_t offset_bytes = param_start_in_bucket * kDataTypeToSize.at(bucket->grad_dtype()); @@ -257,12 +250,6 @@ void ParamAndGradBucketGroup::StartGradSync() { // TODO(zbl): Check NaN/Inf/too large in grad (options in DistributedDataParallelConfig) - for (auto bucket : buckets_) { - if (bucket->gradient_scaling_factor() != 1.f) { - bucket->ScaleGradients(bucket->gradient_scaling_factor()); - } - } - auto reduce_op = ddp_config_.average_in_collective ? function::ReduceOpType::kAvg : function::ReduceOpType::kSum; auto async_op = ddp_config_.overlap_grad_reduce && (ddp_config_.num_distributed_optimizer_instances == 1); @@ -274,6 +261,11 @@ void ParamAndGradBucketGroup::StartGradSync() { if (!full_grad_buffer) { continue; } + if (bucket->gradient_scaling_factor() != 1.f) { + // FIXME(zbl): should perform in-place multiply + // full_grad_buffer *= bucket->gradient_scaling_factor(); + LOG(FATAL) << "ParamAndGradBucketGroup::StartGradSync(): Inplace multiply not implemented yet."; + } CHECK(grad_shard_buffer_list_[i]) << "ParamAndGradBucketGroup: grad shard buffer missing."; auto local_data_view = grad_shard_buffer_list_[i]; grad_reduce_work_list_.push_back( @@ -282,6 +274,10 @@ void ParamAndGradBucketGroup::StartGradSync() { continue; } + if (bucket->gradient_scaling_factor() != 1.f) { + bucket->ScaleGradients(bucket->gradient_scaling_factor()); + } + std::shared_ptr grad_buffer = bucket->grad_data(); if (!grad_buffer) { continue; @@ -310,28 +306,28 @@ void ParamAndGradBucketGroup::FinishGradSync() { StartGradSync(); } - if (params_with_grad_.empty()) { - return; - } - if (!ddp_config_.overlap_grad_reduce) { // Assume reduce ops are synced and no work needs to be resolved grad_reduce_work_list_.clear(); + grad_reduce_bucket_indices_.clear(); grad_reduce_dispatched_ = false; return; } - CHECK(!grad_reduce_work_list_.empty()) - << "ParamAndGradBucketGroup: Communication call has not been issued for this bucket(" - << params_with_grad_.size() << "/" << params_.size() << " params have grad available)"; + if (grad_reduce_work_list_.empty()) { + grad_reduce_bucket_indices_.clear(); + grad_reduce_dispatched_ = false; + return; + } if (ddp_config_.zero_stage >= 2) { + CHECK_EQ(grad_reduce_work_list_.size(), grad_reduce_bucket_indices_.size()) + << "ParamAndGradBucketGroup: grad reduce works and bucket indices are out of sync."; for (size_t idx = 0; idx < grad_reduce_work_list_.size(); ++idx) { auto &work = grad_reduce_work_list_[idx]; work->WaitNonBlocking(); const size_t bucket_idx = grad_reduce_bucket_indices_[idx]; temp_full_grad_buffer_list_[bucket_idx].reset(); - temp_full_grad_initialized_[bucket_idx] = false; } grad_reduce_work_list_.clear(); grad_reduce_bucket_indices_.clear(); @@ -341,6 +337,7 @@ void ParamAndGradBucketGroup::FinishGradSync() { for (auto work : grad_reduce_work_list_) { work->WaitNonBlocking(); } grad_reduce_work_list_.clear(); + grad_reduce_bucket_indices_.clear(); grad_reduce_dispatched_ = false; } From d0258bd2cc634cddefcf206a40f7eaf773bceeb0 Mon Sep 17 00:00:00 2001 From: bolunz Date: Thu, 21 May 2026 01:37:59 +0000 Subject: [PATCH 07/12] fix: refactor zero-2 AccumulateGrad bypasser --- infini_train/include/autograd/function_hook.h | 6 +++ infini_train/include/tensor.h | 8 ---- infini_train/src/autograd/accumulate.cc | 6 +-- .../parallel/ddp/distributed_data_parallel.cc | 42 ++++++++++++------- infini_train/src/tensor.cc | 7 ---- 5 files changed, 34 insertions(+), 35 deletions(-) diff --git a/infini_train/include/autograd/function_hook.h b/infini_train/include/autograd/function_hook.h index 0cdd4170..b90e6172 100644 --- a/infini_train/include/autograd/function_hook.h +++ b/infini_train/include/autograd/function_hook.h @@ -17,6 +17,12 @@ namespace infini_train::autograd { class PostAccumulateGradHook { public: virtual void operator()(const std::shared_ptr &tensor) = 0; + + // ZeRO-2: Use this function to take over AccumulateGrad::Backward + virtual bool TryBypassAccumulate(const std::shared_ptr &, const std::shared_ptr &, bool, float) { + return false; + } + virtual ~PostAccumulateGradHook() = default; }; diff --git a/infini_train/include/tensor.h b/infini_train/include/tensor.h index 156fde78..12f45f57 100644 --- a/infini_train/include/tensor.h +++ b/infini_train/include/tensor.h @@ -230,12 +230,6 @@ class Tensor : public std::enable_shared_from_this { std::shared_ptr grad_accumulator(); void ResetAccumulator(); - // ZeRO-2: Use this function to take over AccumulateGrad::Backward - using GradAccumulateBypass - = std::function &grad_output, bool overwrite, float learning_rate)>; - GradAccumulateBypass grad_accumulate_bypass(); - void SetGradAccumulateBypass(GradAccumulateBypass); - void RegisterPostAccumulateGradHook(std::shared_ptr hook); autograd::PostAccumulateGradHook *post_accumulate_grad_hook() const; @@ -250,8 +244,6 @@ class Tensor : public std::enable_shared_from_this { // a strong reference to the accumulator to manage its lifetime. std::shared_ptr grad_accumulator_ = nullptr; std::shared_ptr post_accumulate_grad_hook_ = nullptr; - // ZeRO-2: Use this function to take over AccumulateGrad::Backward - GradAccumulateBypass grad_accumulate_bypass_ = nullptr; bool grad_overwrite_once_ = false; }; diff --git a/infini_train/src/autograd/accumulate.cc b/infini_train/src/autograd/accumulate.cc index c3558b47..474fea4a 100644 --- a/infini_train/src/autograd/accumulate.cc +++ b/infini_train/src/autograd/accumulate.cc @@ -34,9 +34,8 @@ AccumulateGrad::Backward(const std::vector> &grad_output } const bool overwrite = tensor_->ConsumeGradOverwriteFlag(); - // ZeRO-2: Use a bypass function to perform grad accumulation in temp full grad buffer - auto bypass = tensor_->grad_accumulate_bypass(); - if (bypass && bypass(grad_output, overwrite, learning_rate_)) { + auto hook = tensor_->post_accumulate_grad_hook(); + if (hook && hook->TryBypassAccumulate(tensor_, grad_output, overwrite, learning_rate_)) { tensor_->ResetAccumulator(); return {}; } @@ -56,7 +55,6 @@ AccumulateGrad::Backward(const std::vector> &grad_output auto new_grad = std::make_shared(*grad_output.get(), 0, grad_output->Dims()); tensor_->set_grad(new_grad); } - auto hook = tensor_->post_accumulate_grad_hook(); if (hook != nullptr) { (*hook)(tensor_->grad()); } diff --git a/infini_train/src/nn/parallel/ddp/distributed_data_parallel.cc b/infini_train/src/nn/parallel/ddp/distributed_data_parallel.cc index dc37200c..02ea1d37 100644 --- a/infini_train/src/nn/parallel/ddp/distributed_data_parallel.cc +++ b/infini_train/src/nn/parallel/ddp/distributed_data_parallel.cc @@ -128,6 +128,30 @@ void DistributedDataParallel::BuildParamAndGradBuffers() { void DistributedDataParallel::RegisterBackwardHooks() { if (ddp_config_.zero_stage >= 2) { + // NOTE(zbl): ZeRO-2 bypasses Tensor::grad accumulation: stash grads in the bucket group's + // temporary full-grad buffer, then mark the bucket ready for reduce-scatter. + class Zero2AccumulateGradHook final : public autograd::PostAccumulateGradHook { + public: + explicit Zero2AccumulateGradHook(std::weak_ptr group) : group_(std::move(group)) {} + + bool TryBypassAccumulate(const std::shared_ptr ¶m, const std::shared_ptr &grad_output, + bool overwrite, float learning_rate) override { + if (auto group = group_.lock(); group) { + group->AccumulateParamGrad(param, grad_output, overwrite, learning_rate); + if (group->config().overlap_grad_reduce) { + group->RegisterGradReady(param); + } + return true; + } + return false; + } + + void operator()(const std::shared_ptr &) override {} + + private: + std::weak_ptr group_; + }; + auto &module = modules_.at(kModuleName); for (auto ¶m : module->Parameters()) { if (!param->requires_grad()) { @@ -138,22 +162,8 @@ void DistributedDataParallel::RegisterBackwardHooks() { continue; } std::weak_ptr weak_group = it->second; - std::weak_ptr weak_param = param; - param->SetGradAccumulateBypass([weak_group, weak_param](const std::shared_ptr &grad_output, - bool overwrite, float learning_rate) { - if (auto group = weak_group.lock(); group) { - auto param = weak_param.lock(); - if (!param) { - return false; - } - group->AccumulateParamGrad(param, grad_output, overwrite, learning_rate); - if (group->config().overlap_grad_reduce) { - group->RegisterGradReady(param); - } - return true; - } - return false; - }); + auto hook = std::make_unique(weak_group); + param->RegisterPostAccumulateGradHook(std::move(hook)); } return; } diff --git a/infini_train/src/tensor.cc b/infini_train/src/tensor.cc index 1f485e68..f7947030 100644 --- a/infini_train/src/tensor.cc +++ b/infini_train/src/tensor.cc @@ -559,13 +559,6 @@ void Tensor::ResetAccumulator() { } } -Tensor::GradAccumulateBypass Tensor::grad_accumulate_bypass() { - CHECK(grad_accumulator_) << "grad_accumulate_bypass() should only be called on leaf tensors"; - return grad_accumulate_bypass_; -} - -void Tensor::SetGradAccumulateBypass(GradAccumulateBypass bypass) { grad_accumulate_bypass_ = std::move(bypass); } - void Tensor::RegisterPostAccumulateGradHook(std::shared_ptr hook) { CHECK(requires_grad_) << "cannot register a hook on a tensor that doesn't require gradient"; From fe53cf3d75fa36698d4387bd359f91d43b067e7b Mon Sep 17 00:00:00 2001 From: bolunz Date: Mon, 25 May 2026 02:51:50 +0000 Subject: [PATCH 08/12] fix: move hook definition to anon namespace --- .../parallel/ddp/distributed_data_parallel.cc | 86 ++++++++++--------- 1 file changed, 45 insertions(+), 41 deletions(-) diff --git a/infini_train/src/nn/parallel/ddp/distributed_data_parallel.cc b/infini_train/src/nn/parallel/ddp/distributed_data_parallel.cc index 02ea1d37..37fd5821 100644 --- a/infini_train/src/nn/parallel/ddp/distributed_data_parallel.cc +++ b/infini_train/src/nn/parallel/ddp/distributed_data_parallel.cc @@ -1,5 +1,6 @@ #include "infini_train/include/nn/parallel/ddp/distributed_data_parallel.h" +#include #include #include #include @@ -18,6 +19,48 @@ namespace infini_train::nn::parallel { namespace { constexpr char kModuleName[] = "module"; + +// NOTE(zbl): ZeRO-2 bypasses Tensor::grad accumulation: stash grads in the bucket group's +// temporary full-grad buffer, then mark the bucket ready for reduce-scatter. +class Zero2AccumulateGradHook final : public autograd::PostAccumulateGradHook { +public: + explicit Zero2AccumulateGradHook(std::weak_ptr group) : group_(std::move(group)) {} + + bool TryBypassAccumulate(const std::shared_ptr ¶m, const std::shared_ptr &grad_output, + bool overwrite, float learning_rate) override { + if (auto group = group_.lock(); group) { + group->AccumulateParamGrad(param, grad_output, overwrite, learning_rate); + if (group->config().overlap_grad_reduce) { + group->RegisterGradReady(param); + } + return true; + } + return false; + } + + void operator()(const std::shared_ptr &) override {} + +private: + std::weak_ptr group_; +}; + +class DDPPostAccumulateHook final : public autograd::PostAccumulateGradHook { +public: + using Callback = std::function &)>; + + DDPPostAccumulateHook(const std::weak_ptr param, Callback callback) + : param_(param), callback_(std::move(callback)) {} + + void operator()(const std::shared_ptr &) override { + if (auto param = param_.lock()) { + callback_(param); + } + } + +private: + std::weak_ptr param_; + Callback callback_; +}; } // namespace DistributedDataParallel::DistributedDataParallel(std::shared_ptr module, const Rank &rank, @@ -128,30 +171,6 @@ void DistributedDataParallel::BuildParamAndGradBuffers() { void DistributedDataParallel::RegisterBackwardHooks() { if (ddp_config_.zero_stage >= 2) { - // NOTE(zbl): ZeRO-2 bypasses Tensor::grad accumulation: stash grads in the bucket group's - // temporary full-grad buffer, then mark the bucket ready for reduce-scatter. - class Zero2AccumulateGradHook final : public autograd::PostAccumulateGradHook { - public: - explicit Zero2AccumulateGradHook(std::weak_ptr group) : group_(std::move(group)) {} - - bool TryBypassAccumulate(const std::shared_ptr ¶m, const std::shared_ptr &grad_output, - bool overwrite, float learning_rate) override { - if (auto group = group_.lock(); group) { - group->AccumulateParamGrad(param, grad_output, overwrite, learning_rate); - if (group->config().overlap_grad_reduce) { - group->RegisterGradReady(param); - } - return true; - } - return false; - } - - void operator()(const std::shared_ptr &) override {} - - private: - std::weak_ptr group_; - }; - auto &module = modules_.at(kModuleName); for (auto ¶m : module->Parameters()) { if (!param->requires_grad()) { @@ -168,29 +187,14 @@ void DistributedDataParallel::RegisterBackwardHooks() { return; } - class DDPPostAccumulateHook final : public autograd::PostAccumulateGradHook { - public: - DDPPostAccumulateHook(DistributedDataParallel *ddp, const std::weak_ptr param) - : ddp_(ddp), param_(param) {} - - void operator()(const std::shared_ptr &) override { - if (auto param = param_.lock()) { - ddp_->OnGradReady(param); - } - } - - private: - DistributedDataParallel *ddp_; - std::weak_ptr param_; - }; - auto &module = modules_.at(kModuleName); for (auto ¶m : module->Parameters()) { if (!param->requires_grad()) { continue; } - auto hook = std::make_unique(this, param); + auto hook = std::make_unique( + param, [this](const std::shared_ptr ¶m) { OnGradReady(param); }); param->RegisterPostAccumulateGradHook(std::move(hook)); } } From d64cbdda16a12092e7e3755511ee7837c43d02d3 Mon Sep 17 00:00:00 2001 From: bolunz Date: Tue, 2 Jun 2026 02:29:01 +0000 Subject: [PATCH 09/12] fix: remove use_distributed_optimizer flag, add PreAccumulateGradHook, add comments --- example/common/utils.cc | 10 ------ example/common/utils.h | 2 -- example/gpt2/main.cc | 14 +++----- example/llama3/main.cc | 14 +++----- infini_train/include/autograd/function_hook.h | 13 ++++++-- .../ddp/distributed_data_parallel_config.h | 13 +++----- .../nn/parallel/ddp/param_and_grad_buffer.h | 33 +++++++++++++++++-- infini_train/include/tensor.h | 7 +++- infini_train/src/autograd/accumulate.cc | 18 ++++++---- .../parallel/ddp/distributed_data_parallel.cc | 33 ++++++++----------- .../nn/parallel/ddp/param_and_grad_buffer.cc | 18 +++++----- infini_train/src/tensor.cc | 10 ++++++ scripts/test_config.json | 22 +++++-------- 13 files changed, 112 insertions(+), 95 deletions(-) diff --git a/example/common/utils.cc b/example/common/utils.cc index 0d7b966e..03cc7aa0 100644 --- a/example/common/utils.cc +++ b/example/common/utils.cc @@ -1,8 +1,5 @@ #include "example/common/utils.h" -#include "gflags/gflags.h" -#include "glog/logging.h" - namespace infini_train { float ConvertBF16ToFloat(void *ptr) { @@ -64,11 +61,4 @@ void ReadVectorShardFloat(std::ifstream &ifs, float *dst, int64_t len, int64_t s ifs.seekg(base + std::streamoff(len * sizeof(float))); } -void ValidateDistributedOptimizerFlags(bool use_distributed_optimizer) { - gflags::CommandLineFlagInfo zero_stage_info; - CHECK(gflags::GetCommandLineFlagInfo("zero_stage", &zero_stage_info)); - CHECK(use_distributed_optimizer || zero_stage_info.is_default) - << "--zero_stage requires --use_distributed_optimizer=true."; -} - } // namespace infini_train diff --git a/example/common/utils.h b/example/common/utils.h index 19e65ff1..5bab3e97 100644 --- a/example/common/utils.h +++ b/example/common/utils.h @@ -30,6 +30,4 @@ void ReadVectorAllFloat(std::ifstream &ifs, float *dst, int64_t len); void ReadVectorShardFloat(std::ifstream &ifs, float *dst, int64_t len, int64_t start, int64_t cnt); -void ValidateDistributedOptimizerFlags(bool use_distributed_optimizer); - } // namespace infini_train diff --git a/example/gpt2/main.cc b/example/gpt2/main.cc index 0106dbc7..4e034894 100644 --- a/example/gpt2/main.cc +++ b/example/gpt2/main.cc @@ -58,8 +58,7 @@ DEFINE_uint32(freq_generate_txt, 10, "frequency of text generation"); DEFINE_uint32(text_length, 64, "the length of the generated text"); // optimization DEFINE_double(learning_rate, 1e-4, "learning rate warmup iterations"); -DEFINE_bool(use_distributed_optimizer, false, "Whether to enable DistributedOptimizer(only take effects when DP>1)"); -DEFINE_int32(zero_stage, 1, "ZeRO stage (1/2/3), default 1 (only take effects when use_distributed_optimizer=true)"); +DEFINE_int32(zero_stage, 0, "ZeRO stage (0/1/2/3); 0 disables DistributedOptimizer"); // evaluation DEFINE_uint32(val_loss_every, 0, "every how many steps to evaluate val loss?"); DEFINE_uint32(sample_every, 0, "how often to sample from the model?"); @@ -116,7 +115,7 @@ const std::unordered_map kModelToConfigs = { DEFINE_validator(model, [](const char *, const std::string &value) { return kSupportedModels.contains(value); }); DEFINE_validator(device, [](const char *, const std::string &value) { return value == kDeviceCPU || value == kDeviceCUDA; }); -DEFINE_validator(zero_stage, [](const char *, int32_t value) { return value >= 1 && value <= 3; }); +DEFINE_validator(zero_stage, [](const char *, int32_t value) { return value >= 0 && value <= 3; }); void Train(const nn::parallel::Rank &rank) { using namespace nn::parallel; @@ -256,8 +255,7 @@ void Train(const nn::parallel::Rank &rank) { model = std::make_shared(model, pp_world_size, num_micro_batches, shapes, pp_rank, device, model_config.GetChunkSize()); if (ddp_world_size > 1) { - auto ddp_config = DistributedDataParallelConfig{ - .use_distributed_optimizer = FLAGS_use_distributed_optimizer, .zero_stage = FLAGS_zero_stage}; + auto ddp_config = DistributedDataParallelConfig{.zero_stage = FLAGS_zero_stage}; auto *mutable_chunks = dynamic_cast(model.get())->mutable_chunks(); for (int chunk_id = 0; chunk_id < mutable_chunks->size(); ++chunk_id) { (*mutable_chunks)[chunk_id] @@ -269,8 +267,7 @@ void Train(const nn::parallel::Rank &rank) { // before wrapping the model with DistributedDataParallel (DDP). // Otherwise, DDP’s gradient hooks may be lost because new parameter tensors // are created during the conversion. - auto ddp_config = DistributedDataParallelConfig{.use_distributed_optimizer = FLAGS_use_distributed_optimizer, - .zero_stage = FLAGS_zero_stage}; + auto ddp_config = DistributedDataParallelConfig{.zero_stage = FLAGS_zero_stage}; model = std::make_shared(model, rank, ddp_config); } @@ -299,7 +296,7 @@ void Train(const nn::parallel::Rank &rank) { auto optimizer_creator = optimizers::SGD::Create(FLAGS_learning_rate); std::shared_ptr optimizer = nullptr; - if (FLAGS_use_distributed_optimizer) { + if (FLAGS_zero_stage >= 1) { auto model_chunks = (pp_world_size > 1) ? *(dynamic_cast(model.get())->mutable_chunks()) : std::vector>{model}; @@ -452,7 +449,6 @@ void Train(const nn::parallel::Rank &rank) { int main(int argc, char *argv[]) { gflags::ParseCommandLineFlags(&argc, &argv, true); google::InitGoogleLogging(argv[0]); - ValidateDistributedOptimizerFlags(FLAGS_use_distributed_optimizer); auto precision_config = utils::PrecisionCheckConfig::Parse(FLAGS_precision_check); nn::parallel::global::InitAllEnv(FLAGS_nthread_per_process, FLAGS_tensor_parallel, FLAGS_sequence_parallel, diff --git a/example/llama3/main.cc b/example/llama3/main.cc index e4e0b09a..945367d8 100644 --- a/example/llama3/main.cc +++ b/example/llama3/main.cc @@ -57,8 +57,7 @@ DEFINE_uint32(freq_generate_txt, 10, "frequency of text generation"); DEFINE_uint32(text_length, 64, "the length of the generated text"); // optimization DEFINE_double(learning_rate, 1e-5, "learning rate warmup iterations"); -DEFINE_bool(use_distributed_optimizer, false, "Whether to enable DistributedOptimizer(only take effects when DP>1)"); -DEFINE_int32(zero_stage, 1, "ZeRO stage (1/2/3), default 1 (only take effects when use_distributed_optimizer=true)"); +DEFINE_int32(zero_stage, 0, "ZeRO stage (0/1/2/3); 0 disables DistributedOptimizer"); // evaluation DEFINE_uint32(val_loss_every, 0, "every how many steps to evaluate val loss?"); DEFINE_uint32(sample_every, 0, "how often to sample from the model?"); @@ -102,7 +101,7 @@ constexpr char kDtypeBF16[] = "bfloat16"; DEFINE_validator(model, [](const char *, const std::string &value) { return kSupportedModels.contains(value); }); DEFINE_validator(device, [](const char *, const std::string &value) { return value == kDeviceCPU || value == kDeviceCUDA; }); -DEFINE_validator(zero_stage, [](const char *, int32_t value) { return value >= 1 && value <= 3; }); +DEFINE_validator(zero_stage, [](const char *, int32_t value) { return value >= 0 && value <= 3; }); void Train(const nn::parallel::Rank &rank) { using namespace nn::parallel; @@ -226,8 +225,7 @@ void Train(const nn::parallel::Rank &rank) { model = std::make_shared(model, pp_world_size, num_micro_batches, shapes, pp_rank, device, model_config.GetChunkSize()); if (ddp_world_size > 1) { - auto ddp_config = DistributedDataParallelConfig{ - .use_distributed_optimizer = FLAGS_use_distributed_optimizer, .zero_stage = FLAGS_zero_stage}; + auto ddp_config = DistributedDataParallelConfig{.zero_stage = FLAGS_zero_stage}; auto *mutable_chunks = dynamic_cast(model.get())->mutable_chunks(); for (int chunk_id = 0; chunk_id < mutable_chunks->size(); ++chunk_id) { (*mutable_chunks)[chunk_id] @@ -240,8 +238,7 @@ void Train(const nn::parallel::Rank &rank) { // Otherwise, DDP’s gradient hooks may be lost because new parameter tensors // are created during the conversion. - auto ddp_config = DistributedDataParallelConfig{.use_distributed_optimizer = FLAGS_use_distributed_optimizer, - .zero_stage = FLAGS_zero_stage}; + auto ddp_config = DistributedDataParallelConfig{.zero_stage = FLAGS_zero_stage}; model = std::make_shared(model, rank, ddp_config); } @@ -278,7 +275,7 @@ void Train(const nn::parallel::Rank &rank) { LOG(INFO) << "Optimizing " << params_to_optimize.size() << " model parameters"; } - if (FLAGS_use_distributed_optimizer) { + if (FLAGS_zero_stage >= 1) { auto model_chunks = (pp_world_size > 1) ? *(dynamic_cast(model.get())->mutable_chunks()) : std::vector>{model}; @@ -427,7 +424,6 @@ void Train(const nn::parallel::Rank &rank) { int main(int argc, char *argv[]) { gflags::ParseCommandLineFlags(&argc, &argv, true); google::InitGoogleLogging(argv[0]); - ValidateDistributedOptimizerFlags(FLAGS_use_distributed_optimizer); auto precision_config = utils::PrecisionCheckConfig::Parse(FLAGS_precision_check); nn::parallel::global::InitAllEnv(FLAGS_nthread_per_process, FLAGS_tensor_parallel, FLAGS_sequence_parallel, diff --git a/infini_train/include/autograd/function_hook.h b/infini_train/include/autograd/function_hook.h index b90e6172..734a0930 100644 --- a/infini_train/include/autograd/function_hook.h +++ b/infini_train/include/autograd/function_hook.h @@ -14,15 +14,22 @@ class ProcessGroup; namespace infini_train::autograd { -class PostAccumulateGradHook { +class PreAccumulateGradHook { public: - virtual void operator()(const std::shared_ptr &tensor) = 0; + virtual void operator()(const std::shared_ptr &grad_output) = 0; - // ZeRO-2: Use this function to take over AccumulateGrad::Backward + // Return true if this hook has handled the current gradient accumulation. virtual bool TryBypassAccumulate(const std::shared_ptr &, const std::shared_ptr &, bool, float) { return false; } + virtual ~PreAccumulateGradHook() = default; +}; + +class PostAccumulateGradHook { +public: + virtual void operator()(const std::shared_ptr &tensor) = 0; + virtual ~PostAccumulateGradHook() = default; }; diff --git a/infini_train/include/nn/parallel/ddp/distributed_data_parallel_config.h b/infini_train/include/nn/parallel/ddp/distributed_data_parallel_config.h index 4af2ce01..9223631f 100644 --- a/infini_train/include/nn/parallel/ddp/distributed_data_parallel_config.h +++ b/infini_train/include/nn/parallel/ddp/distributed_data_parallel_config.h @@ -30,21 +30,16 @@ class DistributedDataParallelConfig { // Ref: // https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/distributed/distributed_data_parallel_config.py // ====================================================== - // Whether to enable DistributedOptimizer (ZeRO-1 equivalent). - // When set true: - // 1) Gradients/params are managed by ParamAndGradBuffer and reduced in groups. - // 2) The classic DDP reducer path is not used (i.e., disable reducer/bucketing in the DDP sense). - bool use_distributed_optimizer = false; - // Whether to overlap gradient reduce-scatter/all-reduce with backward compute. // In this case, grad reduce is triggered immediately when a grad is ready or till all grads are ready. bool overlap_grad_reduce = true; - // ZeRO-DP Stage for memory optimization (Only take effects when use_distributed_optimizer=true) - // ZeRO-1: Optimizer states partitioning, by default + // ZeRO-DP stage for memory optimization. + // ZeRO-0: Disabled; use the classic DDP reducer path. + // ZeRO-1: Optimizer states partitioning // ZeRO-2: Gradients partitioning // ZeRO-3: Parameters partitioning - int zero_stage = 1; + int zero_stage = 0; // Whether to overlap parameter all-gather with forward compute. bool overlap_param_gather = true; diff --git a/infini_train/include/nn/parallel/ddp/param_and_grad_buffer.h b/infini_train/include/nn/parallel/ddp/param_and_grad_buffer.h index 11b63828..4af99d81 100644 --- a/infini_train/include/nn/parallel/ddp/param_and_grad_buffer.h +++ b/infini_train/include/nn/parallel/ddp/param_and_grad_buffer.h @@ -21,6 +21,19 @@ class Work; namespace infini_train::nn::parallel { class ParamAndGradBucket { public: + /** + * @brief Create bucket metadata and flat-buffer views. + * + * @param params Parameters in bucket-local order. + * @param param_data View of this bucket in the flat parameter buffer, or nullptr if unused. + * @param param_dtype Parameter storage dtype. + * @param grad_data View of this bucket in the flat gradient buffer; nullptr for ZeRO-2. + * @param grad_dtype Gradient storage dtype. + * @param offset Bucket start offset in the owning flat buffer. + * @param num_elements_unpadded Bucket element count before padding. + * @param gradient_scaling_factor Pre-collective gradient scale factor. + * @param bucket_id Bucket index in the owning ParamAndGradBuffer. + */ ParamAndGradBucket(const std::vector> ¶ms, const std::shared_ptr ¶m_data, DataType param_dtype, const std::shared_ptr &grad_data, DataType grad_dtype, size_t offset, size_t num_elements_unpadded, float gradient_scaling_factor, size_t bucket_id); @@ -65,6 +78,14 @@ class ParamAndGradBucket { class ParamAndGradBucketGroup { public: + /** + * @brief Group buckets that synchronize gradients and parameters together. + * + * @param buckets Buckets owned by this group. + * @param collective_pg Process group for gradient and parameter collectives. + * @param process_group_size Number of ranks in collective_pg. + * @param ddp_config DDP/DistributedOptimizer behavior config. + */ ParamAndGradBucketGroup(const std::vector> &buckets, const ProcessGroup *collective_pg, size_t process_group_size, DistributedDataParallelConfig ddp_config); @@ -79,8 +100,7 @@ class ParamAndGradBucketGroup { // Start grad reduce void StartGradSync(); - // Accumulate a parameter grad into bucket buffer - // ZeRO-2: Use this funtion to take over autograd::AccumulateGrad::Backward + // Accumulate a parameter grad into bucket storage for the ZeRO-2 pre-accumulate hook. void AccumulateParamGrad(const std::shared_ptr ¶meter, const std::shared_ptr &grad, bool overwrite, float learning_rate); @@ -138,6 +158,15 @@ class ParamAndGradBucketGroup { class ParamAndGradBuffer { public: + /** + * @brief Own flat buffers and bucket metadata for one dtype group. + * + * @param params Parameters with the same parameter/gradient dtype pair. + * @param param_dtype Flat parameter-buffer dtype. + * @param grad_dtype Gradient storage dtype. + * @param ddp_pg Data-parallel process group used by derived bucket groups. + * @param ddp_config DDP/DistributedOptimizer bucketing and padding config. + */ ParamAndGradBuffer(const std::vector> ¶ms, DataType ¶m_dtype, DataType &grad_dtype, const ProcessGroup *ddp_pg, DistributedDataParallelConfig ddp_config); diff --git a/infini_train/include/tensor.h b/infini_train/include/tensor.h index 12f45f57..58011762 100644 --- a/infini_train/include/tensor.h +++ b/infini_train/include/tensor.h @@ -18,6 +18,7 @@ namespace infini_train { namespace autograd { class Function; class AccumulateGrad; +class PreAccumulateGradHook; class PostAccumulateGradHook; } // namespace autograd @@ -230,8 +231,11 @@ class Tensor : public std::enable_shared_from_this { std::shared_ptr grad_accumulator(); void ResetAccumulator(); - void RegisterPostAccumulateGradHook(std::shared_ptr hook); + void RegisterPreAccumulateGradHook(std::shared_ptr hook); + + autograd::PreAccumulateGradHook *pre_accumulate_grad_hook() const; + void RegisterPostAccumulateGradHook(std::shared_ptr hook); autograd::PostAccumulateGradHook *post_accumulate_grad_hook() const; private: @@ -243,6 +247,7 @@ class Tensor : public std::enable_shared_from_this { // FIXME(dcj): This should be a weak_ptr. The autograd graph should hold // a strong reference to the accumulator to manage its lifetime. std::shared_ptr grad_accumulator_ = nullptr; + std::shared_ptr pre_accumulate_grad_hook_ = nullptr; std::shared_ptr post_accumulate_grad_hook_ = nullptr; bool grad_overwrite_once_ = false; diff --git a/infini_train/src/autograd/accumulate.cc b/infini_train/src/autograd/accumulate.cc index 474fea4a..0c34819f 100644 --- a/infini_train/src/autograd/accumulate.cc +++ b/infini_train/src/autograd/accumulate.cc @@ -21,7 +21,6 @@ AccumulateGrad::Backward(const std::vector> &grad_output CHECK_EQ(grad_outputs.size(), 1); auto grad_output = grad_outputs[0]; - auto grad = tensor_->grad(); auto device = tensor_->GetDevice(); core::DeviceGuard guard(device); @@ -34,12 +33,16 @@ AccumulateGrad::Backward(const std::vector> &grad_output } const bool overwrite = tensor_->ConsumeGradOverwriteFlag(); - auto hook = tensor_->post_accumulate_grad_hook(); - if (hook && hook->TryBypassAccumulate(tensor_, grad_output, overwrite, learning_rate_)) { - tensor_->ResetAccumulator(); - return {}; + auto pre_hook = tensor_->pre_accumulate_grad_hook(); + if (pre_hook) { + if (pre_hook->TryBypassAccumulate(tensor_, grad_output, overwrite, learning_rate_)) { + tensor_->ResetAccumulator(); + return {}; + } + (*pre_hook)(grad_output); } + auto grad = tensor_->grad(); if (grad) { if (overwrite) { // If the tensor is marked to overrite its current grad on next grad update @@ -55,8 +58,9 @@ AccumulateGrad::Backward(const std::vector> &grad_output auto new_grad = std::make_shared(*grad_output.get(), 0, grad_output->Dims()); tensor_->set_grad(new_grad); } - if (hook != nullptr) { - (*hook)(tensor_->grad()); + auto post_hook = tensor_->post_accumulate_grad_hook(); + if (post_hook != nullptr) { + (*post_hook)(tensor_->grad()); } tensor_->ResetAccumulator(); } diff --git a/infini_train/src/nn/parallel/ddp/distributed_data_parallel.cc b/infini_train/src/nn/parallel/ddp/distributed_data_parallel.cc index 37fd5821..b149a690 100644 --- a/infini_train/src/nn/parallel/ddp/distributed_data_parallel.cc +++ b/infini_train/src/nn/parallel/ddp/distributed_data_parallel.cc @@ -22,9 +22,9 @@ constexpr char kModuleName[] = "module"; // NOTE(zbl): ZeRO-2 bypasses Tensor::grad accumulation: stash grads in the bucket group's // temporary full-grad buffer, then mark the bucket ready for reduce-scatter. -class Zero2AccumulateGradHook final : public autograd::PostAccumulateGradHook { +class Zero2PreAccumulateGradHook final : public autograd::PreAccumulateGradHook { public: - explicit Zero2AccumulateGradHook(std::weak_ptr group) : group_(std::move(group)) {} + explicit Zero2PreAccumulateGradHook(std::weak_ptr group) : group_(std::move(group)) {} bool TryBypassAccumulate(const std::shared_ptr ¶m, const std::shared_ptr &grad_output, bool overwrite, float learning_rate) override { @@ -67,23 +67,18 @@ DistributedDataParallel::DistributedDataParallel(std::shared_ptr mod const DistributedDataParallelConfig ddp_config) : ddp_config_(ddp_config), ddp_pg_(ProcessGroupFactory::Instance()->Get(GetDataParallelProcessGroupName(rank.GlobalRank()))) { - CHECK(ddp_config_.zero_stage >= 1 && ddp_config_.zero_stage <= 3) - << "DistributedDataParallel: zero_stage must be in 1/2/3."; - if (ddp_config_.zero_stage >= 3) { + CHECK(ddp_config_.zero_stage >= 0 && ddp_config_.zero_stage <= 3) + << "DistributedDataParallel: zero_stage must be in 0/1/2/3."; + if (ddp_config_.zero_stage == 3) { LOG(FATAL) << "DistributedDataParallel: ZeRO-3 is not implemented yet."; } - if (!ddp_config_.use_distributed_optimizer && ddp_config_.zero_stage >= 1) { - LOG(WARNING) << "DistributedDataParallel: zero_stage is ignored because " - "use_distributed_optimizer is false."; - ddp_config_.zero_stage = 1; - } for (auto ¶m : module->Parameters()) { if (!param->requires_grad()) { continue; } auto device = param->GetDevice(); CHECK_EQ(device.index(), rank.thread_rank()) << "All parameters must be on the same device as the module"; - if (!ddp_config.gradient_bucketing_enabled && !ddp_config.use_distributed_optimizer) { + if (!ddp_config.gradient_bucketing_enabled && ddp_config.zero_stage < 1) { auto hook = std::make_unique( function::ReduceOpType::kAvg, ddp_pg_); param->RegisterPostAccumulateGradHook(std::move(hook)); @@ -95,7 +90,7 @@ DistributedDataParallel::DistributedDataParallel(std::shared_ptr mod } modules_[kModuleName] = std::move(module); - if (ddp_config.use_distributed_optimizer) { + if (ddp_config.zero_stage >= 1) { BuildParamAndGradBuffers(); RegisterBackwardHooks(); } else if (ddp_config.gradient_bucketing_enabled) { @@ -136,7 +131,6 @@ void DistributedDataParallel::BuildParamAndGradBuffers() { continue; } - // At the point, zero_stage is already aligned with use_distributed_optimizer. auto buffer = std::make_shared(param_list, param_dtype, grad_dtype, ddp_pg_, ddp_config_); param_grad_buffers_.push_back(buffer); @@ -145,7 +139,7 @@ void DistributedDataParallel::BuildParamAndGradBuffers() { // TODO(zbl): option for disable bucketing bucket_groups_ = PartitionBuckets(param_grad_buffers_, /*force_single_bucket_group=*/false); - if (ddp_config_.use_distributed_optimizer && ddp_config_.overlap_param_gather) { + if (ddp_config_.zero_stage >= 1 && ddp_config_.overlap_param_gather) { auto num_bucket_groups = bucket_groups_.size(); for (auto i = num_bucket_groups - 1; i > 0; --i) { bucket_groups_[i]->SetNextParamGatherBucketGroup(bucket_groups_[i - 1]); @@ -177,12 +171,11 @@ void DistributedDataParallel::RegisterBackwardHooks() { continue; } auto it = param_to_bucket_group_.find(param.get()); - if (it == param_to_bucket_group_.end()) { - continue; - } + CHECK(it != param_to_bucket_group_.end()); + std::weak_ptr weak_group = it->second; - auto hook = std::make_unique(weak_group); - param->RegisterPostAccumulateGradHook(std::move(hook)); + auto hook = std::make_unique(weak_group); + param->RegisterPreAccumulateGradHook(std::move(hook)); } return; } @@ -219,7 +212,7 @@ DistributedDataParallel::Forward(const std::vector> &inp if (reducer_) { reducer_->PrepareForBackward(); } - if (ddp_config_.use_distributed_optimizer) { + if (ddp_config_.zero_stage >= 1) { for (auto buffer : param_grad_buffers_) { buffer->RebindGradViews(); } } return outputs; diff --git a/infini_train/src/nn/parallel/ddp/param_and_grad_buffer.cc b/infini_train/src/nn/parallel/ddp/param_and_grad_buffer.cc index 691602bc..6771654f 100644 --- a/infini_train/src/nn/parallel/ddp/param_and_grad_buffer.cc +++ b/infini_train/src/nn/parallel/ddp/param_and_grad_buffer.cc @@ -283,7 +283,7 @@ void ParamAndGradBucketGroup::StartGradSync() { continue; } - if (ddp_config_.use_distributed_optimizer) { + if (ddp_config_.zero_stage >= 1) { if (grad_buffer_shard_list_[i].empty()) { grad_buffer_shard_list_[i] = ShardBuffer(grad_buffer, collective_pg_size_); } @@ -342,7 +342,7 @@ void ParamAndGradBucketGroup::FinishGradSync() { } void ParamAndGradBucketGroup::StartParamSync(bool force_sync) { - CHECK(ddp_config_.use_distributed_optimizer); + CHECK(ddp_config_.zero_stage >= 1); if (!collective_pg_) { LOG(ERROR) << "ParamAndGradBucketGroup: StartParamSync called with null collective_pg_."; @@ -378,7 +378,7 @@ void ParamAndGradBucketGroup::StartParamSync(bool force_sync) { } void ParamAndGradBucketGroup::FinishParamSync(bool skip_next_bucket_dispatch) { - if (!ddp_config_.use_distributed_optimizer || !ddp_config_.overlap_param_gather) { + if (ddp_config_.zero_stage < 1 || !ddp_config_.overlap_param_gather) { return; } @@ -427,7 +427,7 @@ void ParamAndGradBuffer::BuildBuckets(DataType param_dtype, DataType grad_dtype) // Param start must be multiple of 64 auto PadParamStartIfNeeded = [&](size_t start) -> size_t { - if (ddp_config_.use_distributed_optimizer) { + if (ddp_config_.zero_stage >= 1) { // According to Megatron-LM, make sure each param starts at 128B aligned address (by default align to 64 // elements for precision >=16-bit) return PadTo(start, kParamStartAlignElements); @@ -437,7 +437,7 @@ void ParamAndGradBuffer::BuildBuckets(DataType param_dtype, DataType grad_dtype) // Bucket size shoule be multiple of ddp size and 128 (sweet spot for NCCL) auto PadBucketEndIfNeeded = [&](size_t bucket_end_index) -> size_t { - if (ddp_config_.use_distributed_optimizer) { + if (ddp_config_.zero_stage >= 1) { // According to Megatron-LM, ensure that all buckets start at a memory address that is 256B // aligned(128 values since params and grads use >= 16-bit precision) size_t lcm_val = std::lcm(ddp_world_size_, kBucketEndAlignElements); @@ -509,7 +509,7 @@ void ParamAndGradBuffer::BuildBuckets(DataType param_dtype, DataType grad_dtype) static_cast(0), std::plus()); CHECK(numel_unpadded_ <= numel_); - if (ddp_config_.use_distributed_optimizer) { + if (ddp_config_.zero_stage >= 1) { // numel must be multiple of ddp size (so that reduce-scatter could easily shard the buffer among ranks) CHECK_EQ(numel_ % ddp_world_size_, 0); } else { @@ -518,10 +518,10 @@ void ParamAndGradBuffer::BuildBuckets(DataType param_dtype, DataType grad_dtype) // 2. Allocate buffer auto device = params_.front()->GetDevice(); - if (ddp_config_.use_distributed_optimizer) { + if (ddp_config_.zero_stage >= 1) { param_buffer_ = AllocateFlatBuffer(numel_, param_dtype, device); } else { - // No param buffer needed if optimzer is not distributed + // No param buffer needed if optimizer is not distributed param_buffer_.reset(); } if (ddp_config_.zero_stage >= 2) { @@ -541,7 +541,7 @@ void ParamAndGradBuffer::BuildBuckets(DataType param_dtype, DataType grad_dtype) auto NewBucket = [&](const std::vector> &bucket_params, size_t start_index, size_t end_index, size_t num_elements_unpadded, size_t bucket_id) -> std::shared_ptr { - if (ddp_config_.use_distributed_optimizer) { + if (ddp_config_.zero_stage >= 1) { CHECK_EQ(start_index % ddp_world_size_, 0); CHECK_EQ(end_index % ddp_world_size_, 0); CHECK_EQ(bucket_indices_.at(bucket_id).first, start_index); diff --git a/infini_train/src/tensor.cc b/infini_train/src/tensor.cc index f7947030..3c2ae69b 100644 --- a/infini_train/src/tensor.cc +++ b/infini_train/src/tensor.cc @@ -559,6 +559,16 @@ void Tensor::ResetAccumulator() { } } +void Tensor::RegisterPreAccumulateGradHook(std::shared_ptr hook) { + CHECK(requires_grad_) << "cannot register a hook on a tensor that doesn't require gradient"; + + CHECK_EQ(grad_fn_, nullptr) << "pre accumulate grad hooks cannot be registered on non-leaf tensors"; + + pre_accumulate_grad_hook_ = hook; +} + +autograd::PreAccumulateGradHook *Tensor::pre_accumulate_grad_hook() const { return pre_accumulate_grad_hook_.get(); } + void Tensor::RegisterPostAccumulateGradHook(std::shared_ptr hook) { CHECK(requires_grad_) << "cannot register a hook on a tensor that doesn't require gradient"; diff --git a/scripts/test_config.json b/scripts/test_config.json index 6a4f1d5d..74581649 100644 --- a/scripts/test_config.json +++ b/scripts/test_config.json @@ -210,7 +210,7 @@ "num_iteration": 10, "batch_size": 10, "total_batch_size": 5120, - "use_distributed_optimizer": true + "zero_stage": 1 } }, { @@ -221,7 +221,6 @@ "num_iteration": 10, "batch_size": 10, "total_batch_size": 5120, - "use_distributed_optimizer": true, "zero_stage": 2 } }, @@ -233,7 +232,7 @@ "num_iteration": 10, "batch_size": 10, "total_batch_size": 5120, - "use_distributed_optimizer": true + "zero_stage": 1 } }, { @@ -244,7 +243,6 @@ "num_iteration": 10, "batch_size": 10, "total_batch_size": 5120, - "use_distributed_optimizer": true, "zero_stage": 2 } }, @@ -257,7 +255,7 @@ "batch_size": 40, "total_batch_size": 5120, "tensor_parallel": 4, - "use_distributed_optimizer": true + "zero_stage": 1 } }, { @@ -269,7 +267,6 @@ "batch_size": 40, "total_batch_size": 5120, "tensor_parallel": 4, - "use_distributed_optimizer": true, "zero_stage": 2 } }, @@ -282,7 +279,7 @@ "batch_size": 40, "total_batch_size": 5120, "tensor_parallel": 4, - "use_distributed_optimizer": true + "zero_stage": 1 } }, { @@ -294,7 +291,6 @@ "batch_size": 40, "total_batch_size": 5120, "tensor_parallel": 4, - "use_distributed_optimizer": true, "zero_stage": 2 } }, @@ -308,7 +304,7 @@ "total_batch_size": 5120, "tensor_parallel": 4, "sequence_parallel": true, - "use_distributed_optimizer": true + "zero_stage": 1 } }, { @@ -321,7 +317,6 @@ "total_batch_size": 5120, "tensor_parallel": 4, "sequence_parallel": true, - "use_distributed_optimizer": true, "zero_stage": 2 } }, @@ -335,7 +330,7 @@ "total_batch_size": 5120, "tensor_parallel": 4, "sequence_parallel": true, - "use_distributed_optimizer": true + "zero_stage": 1 } }, { @@ -348,7 +343,6 @@ "total_batch_size": 5120, "tensor_parallel": 4, "sequence_parallel": true, - "use_distributed_optimizer": true, "zero_stage": 2 } }, @@ -364,7 +358,7 @@ "sequence_parallel": true, "pipeline_parallel": 2, "virtual_pipeline_parallel": 2, - "use_distributed_optimizer": true + "zero_stage": 1 } }, { @@ -379,7 +373,7 @@ "sequence_parallel": true, "pipeline_parallel": 2, "virtual_pipeline_parallel": 2, - "use_distributed_optimizer": true + "zero_stage": 1 } } ] From d4f7e7a735e620ac63f428d5305e0611fac0e82a Mon Sep 17 00:00:00 2001 From: bolunz Date: Tue, 2 Jun 2026 02:37:26 +0000 Subject: [PATCH 10/12] fix: remove unnecessary include --- example/gpt2/main.cc | 1 - example/llama3/main.cc | 1 - 2 files changed, 2 deletions(-) diff --git a/example/gpt2/main.cc b/example/gpt2/main.cc index 4e034894..67738e14 100644 --- a/example/gpt2/main.cc +++ b/example/gpt2/main.cc @@ -36,7 +36,6 @@ #include "example/common/tiny_shakespeare_dataset.h" #include "example/common/tokenizer.h" -#include "example/common/utils.h" #include "example/gpt2/checkpoint_loader.h" #include "example/gpt2/config.h" diff --git a/example/llama3/main.cc b/example/llama3/main.cc index 945367d8..fadf205e 100644 --- a/example/llama3/main.cc +++ b/example/llama3/main.cc @@ -35,7 +35,6 @@ #include "example/common/tiny_shakespeare_dataset.h" #include "example/common/tokenizer.h" -#include "example/common/utils.h" #include "example/llama3/checkpoint_loader.h" #include "example/llama3/config.h" From 77a4a09a451e5a733e0347d904307ea7e91d8d64 Mon Sep 17 00:00:00 2001 From: bolunz Date: Wed, 3 Jun 2026 16:52:09 +0800 Subject: [PATCH 11/12] fix: resolve comments --- .../ddp/distributed_data_parallel_config.h | 1 - infini_train/include/tensor.h | 1 - .../parallel/ddp/distributed_data_parallel.cc | 85 +++++++++---------- 3 files changed, 42 insertions(+), 45 deletions(-) diff --git a/infini_train/include/nn/parallel/ddp/distributed_data_parallel_config.h b/infini_train/include/nn/parallel/ddp/distributed_data_parallel_config.h index 9223631f..729456ce 100644 --- a/infini_train/include/nn/parallel/ddp/distributed_data_parallel_config.h +++ b/infini_train/include/nn/parallel/ddp/distributed_data_parallel_config.h @@ -59,7 +59,6 @@ class DistributedDataParallelConfig { // Maximum number of parameters in each ParamAndGradBucket. // NOTE(zbl): This is distinct from DDP Reducer's MB-based bucket caps. - // TODO(zbl): To unify the definition of bucket_size argument for users size_t bucket_size_in_elements = 1000000; // Whether to pad bucket sizes to improve NCCL bus bandwidth utilization. diff --git a/infini_train/include/tensor.h b/infini_train/include/tensor.h index 58011762..b6de3340 100644 --- a/infini_train/include/tensor.h +++ b/infini_train/include/tensor.h @@ -232,7 +232,6 @@ class Tensor : public std::enable_shared_from_this { void ResetAccumulator(); void RegisterPreAccumulateGradHook(std::shared_ptr hook); - autograd::PreAccumulateGradHook *pre_accumulate_grad_hook() const; void RegisterPostAccumulateGradHook(std::shared_ptr hook); diff --git a/infini_train/src/nn/parallel/ddp/distributed_data_parallel.cc b/infini_train/src/nn/parallel/ddp/distributed_data_parallel.cc index b149a690..a3bfe008 100644 --- a/infini_train/src/nn/parallel/ddp/distributed_data_parallel.cc +++ b/infini_train/src/nn/parallel/ddp/distributed_data_parallel.cc @@ -20,47 +20,6 @@ namespace infini_train::nn::parallel { namespace { constexpr char kModuleName[] = "module"; -// NOTE(zbl): ZeRO-2 bypasses Tensor::grad accumulation: stash grads in the bucket group's -// temporary full-grad buffer, then mark the bucket ready for reduce-scatter. -class Zero2PreAccumulateGradHook final : public autograd::PreAccumulateGradHook { -public: - explicit Zero2PreAccumulateGradHook(std::weak_ptr group) : group_(std::move(group)) {} - - bool TryBypassAccumulate(const std::shared_ptr ¶m, const std::shared_ptr &grad_output, - bool overwrite, float learning_rate) override { - if (auto group = group_.lock(); group) { - group->AccumulateParamGrad(param, grad_output, overwrite, learning_rate); - if (group->config().overlap_grad_reduce) { - group->RegisterGradReady(param); - } - return true; - } - return false; - } - - void operator()(const std::shared_ptr &) override {} - -private: - std::weak_ptr group_; -}; - -class DDPPostAccumulateHook final : public autograd::PostAccumulateGradHook { -public: - using Callback = std::function &)>; - - DDPPostAccumulateHook(const std::weak_ptr param, Callback callback) - : param_(param), callback_(std::move(callback)) {} - - void operator()(const std::shared_ptr &) override { - if (auto param = param_.lock()) { - callback_(param); - } - } - -private: - std::weak_ptr param_; - Callback callback_; -}; } // namespace DistributedDataParallel::DistributedDataParallel(std::shared_ptr module, const Rank &rank, @@ -165,6 +124,31 @@ void DistributedDataParallel::BuildParamAndGradBuffers() { void DistributedDataParallel::RegisterBackwardHooks() { if (ddp_config_.zero_stage >= 2) { + // NOTE(zbl): ZeRO-2 bypasses Tensor::grad accumulation: stash grads in the bucket group's + // temporary full-grad buffer, then mark the bucket ready for reduce-scatter. + class Zero2PreAccumulateGradHook final : public autograd::PreAccumulateGradHook { + public: + explicit Zero2PreAccumulateGradHook(std::weak_ptr group) + : group_(std::move(group)) {} + + bool TryBypassAccumulate(const std::shared_ptr ¶m, const std::shared_ptr &grad_output, + bool overwrite, float learning_rate) override { + if (auto group = group_.lock(); group) { + group->AccumulateParamGrad(param, grad_output, overwrite, learning_rate); + if (group->config().overlap_grad_reduce) { + group->RegisterGradReady(param); + } + return true; + } + return false; + } + + void operator()(const std::shared_ptr &) override {} + + private: + std::weak_ptr group_; + }; + auto &module = modules_.at(kModuleName); for (auto ¶m : module->Parameters()) { if (!param->requires_grad()) { @@ -180,14 +164,29 @@ void DistributedDataParallel::RegisterBackwardHooks() { return; } + class DDPPostAccumulateHook final : public autograd::PostAccumulateGradHook { + public: + DDPPostAccumulateHook(DistributedDataParallel *ddp, const std::weak_ptr param) + : ddp_(ddp), param_(param) {} + + void operator()(const std::shared_ptr &) override { + if (auto param = param_.lock()) { + ddp_->OnGradReady(param); + } + } + + private: + DistributedDataParallel *ddp_; + std::weak_ptr param_; + }; + auto &module = modules_.at(kModuleName); for (auto ¶m : module->Parameters()) { if (!param->requires_grad()) { continue; } - auto hook = std::make_unique( - param, [this](const std::shared_ptr ¶m) { OnGradReady(param); }); + auto hook = std::make_unique(this, param); param->RegisterPostAccumulateGradHook(std::move(hook)); } } From f590947fab4707426a713f94f32d90ff2d144d89 Mon Sep 17 00:00:00 2001 From: bolunz Date: Thu, 4 Jun 2026 11:39:01 +0800 Subject: [PATCH 12/12] fix: rename zero test cases --- scripts/test_config.json | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/scripts/test_config.json b/scripts/test_config.json index 74581649..32c3e202 100644 --- a/scripts/test_config.json +++ b/scripts/test_config.json @@ -214,7 +214,7 @@ } }, { - "id": "3_distopt_zero2", + "id": "3_zero2", "args": { "dtype": "float32", "nthread_per_process": 8, @@ -236,7 +236,7 @@ } }, { - "id": "3_bfloat16_distopt_zero2", + "id": "3_bfloat16_zero2", "args": { "dtype": "bfloat16", "nthread_per_process": 8, @@ -259,7 +259,7 @@ } }, { - "id": "4_distopt_zero2", + "id": "4_zero2", "args": { "dtype": "float32", "nthread_per_process": 8, @@ -283,7 +283,7 @@ } }, { - "id": "4_bfloat16_distopt_zero2", + "id": "4_bfloat16_zero2", "args": { "dtype": "bfloat16", "nthread_per_process": 8, @@ -308,7 +308,7 @@ } }, { - "id": "5_distopt_zero2", + "id": "5_zero2", "args": { "dtype": "float32", "nthread_per_process": 8, @@ -334,7 +334,7 @@ } }, { - "id": "5_bfloat16_distopt_zero2", + "id": "5_bfloat16_zero2", "args": { "dtype": "bfloat16", "nthread_per_process": 8,