diff --git a/infini_train/include/autograd/comm.h b/infini_train/include/autograd/comm.h index ec3cfe4a..c67372ee 100644 --- a/infini_train/include/autograd/comm.h +++ b/infini_train/include/autograd/comm.h @@ -15,7 +15,7 @@ class ProcessGroup; } // namespace nn::parallel } // namespace infini_train -namespace infini_train::autograd { +namespace infini_train::autograd::comm { class Scatter : public autograd::Function { public: static constexpr char kType[] = "ScatterFunction"; @@ -99,4 +99,4 @@ class ReduceAddCoalesced : public autograd::Function { std::vector target_gpus_; int64_t num_inputs_ = 0; }; -} // namespace infini_train::autograd +} // namespace infini_train::autograd::comm diff --git a/infini_train/include/autograd/gather.h b/infini_train/include/autograd/gather.h new file mode 100644 index 00000000..0fb44c51 --- /dev/null +++ b/infini_train/include/autograd/gather.h @@ -0,0 +1,30 @@ +#pragma once + +#include +#include + +#include "infini_train/include/autograd/function.h" + +namespace infini_train { +class Tensor; +} + +namespace infini_train::autograd { + +class Gather : public Function { +public: + static constexpr char kType[] = "GatherFunction"; + + Gather(int64_t dim = 0) : Function(kType), dim_(dim) {} + + std::vector> Forward(const std::vector> &input_tensors) override; + void SetupContext(const std::vector> &input_tensors, + const std::vector> &output_tensors) override; + std::vector> Backward(const std::vector> &grad_outputs) override; + +private: + const int64_t dim_ = 0; + std::vector input_dims_; +}; + +} // namespace infini_train::autograd diff --git a/infini_train/include/autograd/misc.h b/infini_train/include/autograd/misc.h deleted file mode 100644 index ccfca22d..00000000 --- a/infini_train/include/autograd/misc.h +++ /dev/null @@ -1,113 +0,0 @@ -#pragma once - -#include -#include - -#include "infini_train/include/autograd/function.h" - -namespace infini_train { -class Tensor; -} - -namespace infini_train::autograd { -class Split : public Function { -public: - static constexpr char kType[] = "SplitFunction"; - - Split(int64_t split_size, int dim = 0) : Function(kType), split_size_(split_size), dim_(dim) {} - - std::vector> Forward(const std::vector> &input_tensors) override; - void SetupContext(const std::vector> &input_tensors, - const std::vector> &output_tensors) override; - std::vector> Backward(const std::vector> &grad_outputs) override; - -private: - const int64_t split_size_ = 0; - const int dim_ = 0; - std::vector input_dims_; -}; - -// FIXME(zbl): This function aligns with torch.gather -// Currently named IndexGather to avoid conflict with communication operators -// Should be renamed to Gather later for interface consistency -class IndexGather : public Function { -public: - static constexpr char kType[] = "IndexGatherFunction"; - - IndexGather(int64_t dim = 0) : Function(kType), dim_(dim) {} - - std::vector> Forward(const std::vector> &input_tensors) override; - void SetupContext(const std::vector> &input_tensors, - const std::vector> &output_tensors) override; - std::vector> Backward(const std::vector> &grad_outputs) override; - -private: - const int64_t dim_ = 0; - std::vector input_dims_; -}; - -class NoOp : public Function { -public: - static constexpr char kType[] = "NoOpFunction"; - - explicit NoOp(const std::vector &output_dims) : Function(kType), output_dims_(output_dims) {} - - std::vector> Forward(const std::vector> &input_tensors) override; - void SetupContext(const std::vector> &input_tensors, - const std::vector> &output_tensors) override; - std::vector> Backward(const std::vector> &grad_outputs) override; - -private: - const std::vector output_dims_; - std::vector input_dims_; -}; - -class Slice : public Function { -public: - static constexpr char kType[] = "SliceFunction"; - - Slice(const std::vector &starts, const std::vector &ends, const std::vector &steps) - : Function(kType), starts_(starts), ends_(ends), steps_(steps) {} - std::vector> Forward(const std::vector> &input_tensors) override; - void SetupContext(const std::vector> &input_tensors, - const std::vector> &output_tensors) override; - std::vector> Backward(const std::vector> &grad_outputs) override; - -private: - const std::vector starts_; - const std::vector ends_; - const std::vector steps_; -}; - -class Stack : public Function { -public: - static constexpr char kType[] = "StackFunction"; - - Stack(int64_t dim) : Function(kType), dim_(dim) {} - - std::vector> Forward(const std::vector> &input_tensors) override; - void SetupContext(const std::vector> &input_tensors, - const std::vector> &output_tensors) override; - std::vector> Backward(const std::vector> &grad_outputs) override; - -private: - int64_t dim_ = 0; - std::vector input_dims_; -}; - -class Concat : public Function { -public: - static constexpr char kType[] = "ConcatFunction"; - - Concat(int64_t dim) : Function(kType), dim_(dim) {} - - std::vector> Forward(const std::vector> &input_tensors) override; - void SetupContext(const std::vector> &input_tensors, - const std::vector> &output_tensors) override; - std::vector> Backward(const std::vector> &grad_outputs) override; - -private: - const int64_t dim_ = 0; - std::vector> input_dims_list_; -}; -} // namespace infini_train::autograd diff --git a/infini_train/include/autograd/no_op.h b/infini_train/include/autograd/no_op.h new file mode 100644 index 00000000..a097393d --- /dev/null +++ b/infini_train/include/autograd/no_op.h @@ -0,0 +1,30 @@ +#pragma once + +#include +#include + +#include "infini_train/include/autograd/function.h" + +namespace infini_train { +class Tensor; +} + +namespace infini_train::autograd { + +class NoOp : public Function { +public: + static constexpr char kType[] = "NoOpFunction"; + + explicit NoOp(const std::vector &output_dims) : Function(kType), output_dims_(output_dims) {} + + std::vector> Forward(const std::vector> &input_tensors) override; + void SetupContext(const std::vector> &input_tensors, + const std::vector> &output_tensors) override; + std::vector> Backward(const std::vector> &grad_outputs) override; + +private: + const std::vector output_dims_; + std::vector input_dims_; +}; + +} // namespace infini_train::autograd diff --git a/infini_train/include/autograd/scatter.h b/infini_train/include/autograd/scatter.h new file mode 100644 index 00000000..3d6f830a --- /dev/null +++ b/infini_train/include/autograd/scatter.h @@ -0,0 +1,29 @@ +#pragma once + +#include +#include + +#include "infini_train/include/autograd/function.h" + +namespace infini_train { +class Tensor; +} + +namespace infini_train::autograd { + +class Scatter : public Function { +public: + static constexpr char kType[] = "ScatterFunction"; + + explicit Scatter(const std::vector &output_dims) : Function(kType), output_dims_(output_dims) {} + + std::vector> Forward(const std::vector> &input_tensors) override; + void SetupContext(const std::vector> &input_tensors, + const std::vector> &output_tensors) override; + std::vector> Backward(const std::vector> &grad_outputs) override; + +private: + std::vector output_dims_; +}; + +} // namespace infini_train::autograd diff --git a/infini_train/include/autograd/topk.h b/infini_train/include/autograd/topk.h new file mode 100644 index 00000000..7752efca --- /dev/null +++ b/infini_train/include/autograd/topk.h @@ -0,0 +1,40 @@ +#pragma once + +#include +#include + +#include "infini_train/include/autograd/function.h" + +namespace infini_train { +class Tensor; +} + +namespace infini_train::autograd { + +// FIXME(dcj): Align this API with torch.topk and return both values and indices from Forward once +// InfiniTrain autograd supports marking individual outputs as non-differentiable. Today indices +// are exposed through TopIndices() to avoid waiting for gradients on metadata outputs. +class TopK : public Function { +public: + static constexpr char kType[] = "TopKFunction"; + + explicit TopK(int64_t topk, int64_t dim = -1, bool largest = true, bool sorted = true) + : Function(kType), topk_(topk), dim_(dim), largest_(largest), sorted_(sorted) {} + + std::vector> Forward(const std::vector> &input_tensors) override; + void SetupContext(const std::vector> &input_tensors, + const std::vector> &output_tensors) override; + std::vector> Backward(const std::vector> &grad_outputs) override; + + std::shared_ptr TopIndices() const; + +private: + int64_t topk_ = 1; + int64_t dim_ = -1; + bool largest_ = true; + bool sorted_ = true; + std::shared_ptr top_indices_; + std::vector input_dims_; +}; + +} // namespace infini_train::autograd diff --git a/infini_train/include/autograd/transform.h b/infini_train/include/autograd/transform.h index 92ce71ea..88b7d56e 100644 --- a/infini_train/include/autograd/transform.h +++ b/infini_train/include/autograd/transform.h @@ -78,4 +78,70 @@ class RepeatInterleave : public Function { std::vector input_dims_; }; +class Split : public Function { +public: + static constexpr char kType[] = "SplitFunction"; + + Split(int64_t split_size, int dim = 0) : Function(kType), split_size_(split_size), dim_(dim) {} + + std::vector> Forward(const std::vector> &input_tensors) override; + void SetupContext(const std::vector> &input_tensors, + const std::vector> &output_tensors) override; + std::vector> Backward(const std::vector> &grad_outputs) override; + +private: + const int64_t split_size_ = 0; + const int dim_ = 0; + std::vector input_dims_; +}; + +class Stack : public Function { +public: + static constexpr char kType[] = "StackFunction"; + + Stack(int64_t dim) : Function(kType), dim_(dim) {} + + std::vector> Forward(const std::vector> &input_tensors) override; + void SetupContext(const std::vector> &input_tensors, + const std::vector> &output_tensors) override; + std::vector> Backward(const std::vector> &grad_outputs) override; + +private: + int64_t dim_ = 0; + std::vector input_dims_; +}; + +class Concat : public Function { +public: + static constexpr char kType[] = "ConcatFunction"; + + Concat(int64_t dim) : Function(kType), dim_(dim) {} + + std::vector> Forward(const std::vector> &input_tensors) override; + void SetupContext(const std::vector> &input_tensors, + const std::vector> &output_tensors) override; + std::vector> Backward(const std::vector> &grad_outputs) override; + +private: + const int64_t dim_ = 0; + std::vector> input_dims_list_; +}; + +class Slice : public Function { +public: + static constexpr char kType[] = "SliceFunction"; + + Slice(const std::vector &starts, const std::vector &ends, const std::vector &steps) + : Function(kType), starts_(starts), ends_(ends), steps_(steps) {} + std::vector> Forward(const std::vector> &input_tensors) override; + void SetupContext(const std::vector> &input_tensors, + const std::vector> &output_tensors) override; + std::vector> Backward(const std::vector> &grad_outputs) override; + +private: + const std::vector starts_; + const std::vector ends_; + const std::vector steps_; +}; + } // namespace infini_train::autograd diff --git a/infini_train/include/core/backend_type_map.h b/infini_train/include/core/backend_type_map.h index f67b8da7..38fea110 100644 --- a/infini_train/include/core/backend_type_map.h +++ b/infini_train/include/core/backend_type_map.h @@ -48,6 +48,9 @@ template struct BackendTypeMap; // ----------------------------------------------------------------------------- #define INFINI_REGISTER_STANDARD_BACKEND_TYPES(DEV) \ namespace infini_train::core { \ + template <> struct BackendTypeMap { \ + using type = bool; \ + }; \ template <> struct BackendTypeMap { \ using type = uint8_t; \ }; \ diff --git a/infini_train/include/datatype.h b/infini_train/include/datatype.h index cf637300..6efa849c 100644 --- a/infini_train/include/datatype.h +++ b/infini_train/include/datatype.h @@ -84,6 +84,7 @@ struct alignas(2) BF16 { // DataType enum and metadata tables // ----------------------------------------------------------------------------- enum class DataType : int8_t { + kBOOL, kUINT8, kINT8, kUINT16, @@ -99,16 +100,18 @@ enum class DataType : int8_t { }; inline const std::unordered_map kDataTypeToSize = { - {DataType::kUINT8, 1}, {DataType::kINT8, 1}, {DataType::kUINT16, 2}, {DataType::kINT16, 2}, - {DataType::kUINT32, 4}, {DataType::kINT32, 4}, {DataType::kUINT64, 8}, {DataType::kINT64, 8}, - {DataType::kBFLOAT16, 2}, {DataType::kFLOAT16, 2}, {DataType::kFLOAT32, 4}, {DataType::kFLOAT64, 8}, + {DataType::kBOOL, 1}, {DataType::kUINT8, 1}, {DataType::kINT8, 1}, {DataType::kUINT16, 2}, + {DataType::kINT16, 2}, {DataType::kUINT32, 4}, {DataType::kINT32, 4}, {DataType::kUINT64, 8}, + {DataType::kINT64, 8}, {DataType::kBFLOAT16, 2}, {DataType::kFLOAT16, 2}, {DataType::kFLOAT32, 4}, + {DataType::kFLOAT64, 8}, }; inline const std::unordered_map kDataTypeToDesc = { - {DataType::kUINT8, "uint8"}, {DataType::kINT8, "int8"}, {DataType::kUINT16, "uint16"}, - {DataType::kINT16, "int16"}, {DataType::kUINT32, "uint32"}, {DataType::kINT32, "int32"}, - {DataType::kUINT64, "uint64"}, {DataType::kINT64, "int64"}, {DataType::kBFLOAT16, "bf16"}, - {DataType::kFLOAT16, "fp16"}, {DataType::kFLOAT32, "fp32"}, {DataType::kFLOAT64, "fp64"}, + {DataType::kBOOL, "bool"}, {DataType::kUINT8, "uint8"}, {DataType::kINT8, "int8"}, + {DataType::kUINT16, "uint16"}, {DataType::kINT16, "int16"}, {DataType::kUINT32, "uint32"}, + {DataType::kINT32, "int32"}, {DataType::kUINT64, "uint64"}, {DataType::kINT64, "int64"}, + {DataType::kBFLOAT16, "bf16"}, {DataType::kFLOAT16, "fp16"}, {DataType::kFLOAT32, "fp32"}, + {DataType::kFLOAT64, "fp64"}, }; // ============================================================================= diff --git a/infini_train/include/dtype_dispatch.h b/infini_train/include/dtype_dispatch.h index e3db38b8..8bd5054b 100644 --- a/infini_train/include/dtype_dispatch.h +++ b/infini_train/include/dtype_dispatch.h @@ -180,10 +180,11 @@ namespace infini_train { #define INFINI_FLOATING_TYPES DataType::kFLOAT32, DataType::kFLOAT64 #define INFINI_REDUCED_FLOATING_TYPES DataType::kFLOAT16, DataType::kBFLOAT16 #define INFINI_ALL_FLOATING_TYPES INFINI_FLOATING_TYPES, INFINI_REDUCED_FLOATING_TYPES +#define INFINI_LOGICAL_TYPES DataType::kBOOL #define INFINI_SIGNED_INTEGRAL_TYPES DataType::kINT8, DataType::kINT16, DataType::kINT32, DataType::kINT64 #define INFINI_UNSIGNED_INTEGRAL_TYPES DataType::kUINT8, DataType::kUINT16, DataType::kUINT32, DataType::kUINT64 #define INFINI_ALL_INTEGRAL_TYPES INFINI_SIGNED_INTEGRAL_TYPES, INFINI_UNSIGNED_INTEGRAL_TYPES -#define INFINI_ALL_TYPES INFINI_ALL_FLOATING_TYPES, INFINI_ALL_INTEGRAL_TYPES +#define INFINI_ALL_NUMERIC_TYPES INFINI_ALL_FLOATING_TYPES, INFINI_ALL_INTEGRAL_TYPES #define INFINI_8_BIT_TYPES DataType::kINT8, DataType::kUINT8 #define INFINI_16_BIT_TYPES DataType::kINT16, DataType::kUINT16, DataType::kFLOAT16, DataType::kBFLOAT16 #define INFINI_32_BIT_TYPES DataType::kINT32, DataType::kUINT32, DataType::kFLOAT32 @@ -242,6 +243,7 @@ auto DispatchByTypeMap(DataType dtype, Functor &&func, std::string_view context_ } \ } + CASE_FOR_TYPE(DataType::kBOOL) CASE_FOR_TYPE(DataType::kUINT8) CASE_FOR_TYPE(DataType::kINT8) CASE_FOR_TYPE(DataType::kUINT16) @@ -290,6 +292,7 @@ struct TypeMapDispatcher { break; \ } + CASE_FOR_TYPE(DataType::kBOOL) CASE_FOR_TYPE(DataType::kUINT8) CASE_FOR_TYPE(DataType::kINT8) CASE_FOR_TYPE(DataType::kUINT16) diff --git a/infini_train/include/nn/modules/transformer/moe/experts.h b/infini_train/include/nn/modules/transformer/moe/experts.h new file mode 100644 index 00000000..a3dda7f0 --- /dev/null +++ b/infini_train/include/nn/modules/transformer/moe/experts.h @@ -0,0 +1,25 @@ +#pragma once + +#include +#include + +#include "infini_train/include/nn/modules/module.h" +#include "infini_train/include/nn/modules/transformer/transformer_config.h" + +namespace infini_train::nn::moe { + +class SequentialMLP : public CloneableModule { +public: + static constexpr char kType[] = "SequentialMLP"; + static constexpr char kExpertNamePrefix[] = "expert_"; + + explicit SequentialMLP(const TransformerConfig &config); + + std::vector> Forward(const std::vector> &input_tensors) override; + +private: + TransformerConfig config_; + int64_t num_local_experts_ = 0; +}; + +} // namespace infini_train::nn::moe diff --git a/infini_train/include/nn/modules/transformer/moe/moe_layer.h b/infini_train/include/nn/modules/transformer/moe/moe_layer.h new file mode 100644 index 00000000..e5fdb3ab --- /dev/null +++ b/infini_train/include/nn/modules/transformer/moe/moe_layer.h @@ -0,0 +1,25 @@ +#pragma once + +#include +#include + +#include "infini_train/include/nn/modules/module.h" +#include "infini_train/include/nn/modules/transformer/transformer_config.h" + +namespace infini_train::nn::moe { + +class MoELayer : public CloneableModule { +public: + static constexpr char kType[] = "MoELayer"; + static constexpr char kRouterLayerName[] = "router"; + static constexpr char kExpertsLayerName[] = "experts"; + + explicit MoELayer(const TransformerConfig &config); + + std::vector> Forward(const std::vector> &input_tensors) override; + +private: + TransformerConfig config_; +}; + +} // namespace infini_train::nn::moe diff --git a/infini_train/include/nn/modules/transformer/moe/moe_utils.h b/infini_train/include/nn/modules/transformer/moe/moe_utils.h new file mode 100644 index 00000000..6ce26f44 --- /dev/null +++ b/infini_train/include/nn/modules/transformer/moe/moe_utils.h @@ -0,0 +1,19 @@ +#pragma once + +#include +#include +#include + +#include "infini_train/include/nn/modules/transformer/transformer_config.h" +#include "infini_train/include/tensor.h" + +namespace infini_train::nn::moe { + +std::vector> TopkRoutingWithScoreFunction(const std::shared_ptr &logits, int64_t topk, + bool use_pre_softmax, + std::optional scaling_factor, + const MoEConfig::RouterScoreFunction &score_function); + +const MoEConfig &RequireMoEConfig(const TransformerConfig &config); + +} // namespace infini_train::nn::moe diff --git a/infini_train/include/nn/modules/transformer/moe/router.h b/infini_train/include/nn/modules/transformer/moe/router.h new file mode 100644 index 00000000..1279c217 --- /dev/null +++ b/infini_train/include/nn/modules/transformer/moe/router.h @@ -0,0 +1,25 @@ +#pragma once + +#include +#include + +#include "infini_train/include/nn/modules/module.h" +#include "infini_train/include/nn/modules/transformer/transformer_config.h" + +namespace infini_train::nn::moe { + +class TopKRouter : public CloneableModule { +public: + static constexpr char kType[] = "TopKRouter"; + static constexpr char kParamWeightName[] = "weight"; + static constexpr char kParamBiasName[] = "bias"; + + explicit TopKRouter(const TransformerConfig &config); + + std::vector> Forward(const std::vector> &input_tensors) override; + +private: + TransformerConfig config_; +}; + +} // namespace infini_train::nn::moe diff --git a/infini_train/include/nn/modules/transformer/transformer_config.h b/infini_train/include/nn/modules/transformer/transformer_config.h index 62379666..8c440d16 100644 --- a/infini_train/include/nn/modules/transformer/transformer_config.h +++ b/infini_train/include/nn/modules/transformer/transformer_config.h @@ -20,11 +20,45 @@ enum class MLPType { kSwiGLU // SwiGLU activation }; +enum class FFNType { + kDense, // Standard dense MLP + kMoE // Mixture-of-Experts MLP +}; + enum class NormType { kLayerNorm, // LayerNorm kRMSNorm // RMSNorm }; +struct MoEConfig { + enum class RouterScoreFunction { + kSoftmax, + kSigmoid, + }; + + enum class DispatcherType { + kAllGather, // Megatron-style AllGather dispatcher. Degenerates to local dispatch when TP=EP=1. + kAllToAll // Megatron-style AllToAll dispatcher for expert parallel MoE. + }; + + enum class ExpertImpl { + kSequential // Run local experts sequentially + }; + + int64_t num_experts = 0; + int64_t expert_parallel_size = 1; + int64_t router_topk = 1; + bool router_pre_softmax = false; + std::optional router_topk_scaling_factor = std::nullopt; + RouterScoreFunction router_score_function = RouterScoreFunction::kSoftmax; + float aux_loss_coeff = 0.0f; + std::optional expert_capacity_factor = std::nullopt; + bool pad_expert_input_to_capacity = false; + int64_t moe_ffn_hidden_size = 0; + DispatcherType dispatcher_type = DispatcherType::kAllGather; + ExpertImpl expert_impl = ExpertImpl::kSequential; +}; + struct TransformerConfig { int64_t block_size = 1024; // Max seq_len int64_t vocab_size = 50304; // Vocab size @@ -36,6 +70,7 @@ struct TransformerConfig { AttentionType attention_type = AttentionType::kStandard; // Attention mechanism type MLPType activation_type = MLPType::kGELU; // MLP activation type + FFNType ffn_type = FFNType::kDense; // Feed-forward module type NormType norm_type = NormType::kLayerNorm; // Normalization type bool add_bias_linear = true; // Whether to add learnable bias to all Linear layers in the Transformer block, @@ -48,6 +83,7 @@ struct TransformerConfig { float ffn_expansion_ratio = 4.0f; // MLP output: n_embd * ffn_expansion_ratio std::optional ffn_dim_multiplier = 1.5f; // FFN dim multiplier int64_t multiple_of = 256; // FFN dims must be multiple of this number + std::optional moe_config = std::nullopt; // RoPE config float rope_theta = 500000.0f; // theta in RoPE diff --git a/infini_train/src/autograd/comm.cc b/infini_train/src/autograd/comm.cc index d524088a..325422b3 100644 --- a/infini_train/src/autograd/comm.cc +++ b/infini_train/src/autograd/comm.cc @@ -8,7 +8,7 @@ #include "infini_train/include/nn/parallel/process_group.h" #include "infini_train/include/tensor.h" -namespace infini_train::autograd { +namespace infini_train::autograd::comm { Scatter::Scatter(const std::vector &target_gpus, int64_t dim, const infini_train::nn::parallel::ProcessGroup *pg) @@ -122,4 +122,4 @@ std::vector> ReduceAddCoalesced::Backward(const std::vector> &grad_outputs) { return std::make_shared(target_gpus_)->Apply(grad_outputs); } -} // namespace infini_train::autograd +} // namespace infini_train::autograd::comm diff --git a/infini_train/src/autograd/gather.cc b/infini_train/src/autograd/gather.cc new file mode 100644 index 00000000..a30cb013 --- /dev/null +++ b/infini_train/src/autograd/gather.cc @@ -0,0 +1,37 @@ +#include "infini_train/include/autograd/gather.h" + +#include "glog/logging.h" + +#include "infini_train/include/dispatcher.h" +#include "infini_train/include/tensor.h" + +namespace infini_train::autograd { +std::vector> Gather::Forward(const std::vector> &input_tensors) { + CHECK_EQ(input_tensors.size(), 2); + const auto &input = input_tensors[0]; + const auto &index = input_tensors[1]; + + auto device = input->GetDevice().type(); + auto kernel = Dispatcher::Instance().GetKernel({device, "GatherForward"}); + return {kernel.Call>(input, index, dim_)}; +} + +void Gather::SetupContext(const std::vector> &input_tensors, + const std::vector> &) { + const auto &input = input_tensors[0]; + const auto &index = input_tensors[1]; + input_dims_ = input->Dims(); + saved_tensors_ = {index}; +} + +std::vector> Gather::Backward(const std::vector> &grad_outputs) { + CHECK_EQ(grad_outputs.size(), 1); + const auto &grad_output = grad_outputs[0]; + const auto &index = saved_tensors_[0]; + + auto device = grad_outputs[0]->GetDevice(); + auto kernel = Dispatcher::Instance().GetKernel({device.type(), "GatherBackward"}); + return {kernel.Call>(grad_output, index, dim_, input_dims_), nullptr}; +} + +} // namespace infini_train::autograd diff --git a/infini_train/src/autograd/misc.cc b/infini_train/src/autograd/misc.cc deleted file mode 100644 index 601258eb..00000000 --- a/infini_train/src/autograd/misc.cc +++ /dev/null @@ -1,147 +0,0 @@ -#include "infini_train/include/autograd/misc.h" - -#include "glog/logging.h" - -#include "infini_train/include/dispatcher.h" -#include "infini_train/include/tensor.h" - -namespace infini_train::autograd { -std::vector> Split::Forward(const std::vector> &input_tensors) { - CHECK_EQ(input_tensors.size(), 1); - const auto &input = input_tensors[0]; - - auto device = input->GetDevice().type(); - return {Dispatcher::Instance().Call>>({device, "SplitForward"}, input, - split_size_, dim_)}; -} - -void Split::SetupContext(const std::vector> &input_tensors, - const std::vector> &) { - const auto &input = input_tensors[0]; - input_dims_ = input->Dims(); -} - -std::vector> Split::Backward(const std::vector> &grad_outputs) { - auto device = grad_outputs[0]->GetDevice(); - return {Dispatcher::Instance().Call>({device.type(), "SplitBackward"}, input_dims_, - split_size_, dim_, grad_outputs)}; -} - -std::vector> IndexGather::Forward(const std::vector> &input_tensors) { - CHECK_EQ(input_tensors.size(), 2); - const auto &input = input_tensors[0]; - const auto &index = input_tensors[1]; - - auto device = input->GetDevice().type(); - auto kernel = Dispatcher::Instance().GetKernel({device, "IndexGatherForward"}); - return {kernel.Call>(input, index, dim_)}; -} - -void IndexGather::SetupContext(const std::vector> &input_tensors, - const std::vector> &) { - const auto &input = input_tensors[0]; - const auto &index = input_tensors[1]; - input_dims_ = input->Dims(); - saved_tensors_ = {index}; -} - -std::vector> IndexGather::Backward(const std::vector> &grad_outputs) { - CHECK_EQ(grad_outputs.size(), 1); - const auto &grad_output = grad_outputs[0]; - const auto &index = saved_tensors_[0]; - - auto device = grad_outputs[0]->GetDevice(); - auto kernel = Dispatcher::Instance().GetKernel({device.type(), "IndexGatherBackward"}); - return {kernel.Call>(grad_output, index, dim_, input_dims_)}; -} - -std::vector> NoOp::Forward(const std::vector> &input_tensors) { - CHECK_EQ(input_tensors.size(), 1); - const auto &input = input_tensors[0]; - - auto device = input->GetDevice().type(); - return {Dispatcher::Instance().Call>({device, "NoOpForward"}, input, output_dims_)}; -} - -void NoOp::SetupContext(const std::vector> &input_tensors, - const std::vector> &output_tensors) { - const auto &input = input_tensors[0]; - input_dims_ = input->Dims(); -} - -std::vector> NoOp::Backward(const std::vector> &grad_outputs) { - CHECK_EQ(grad_outputs.size(), 1); - const auto &grad_output = grad_outputs[0]; - - auto device = grad_output->GetDevice().type(); - return {Dispatcher::Instance().Call>({device, "NoOpBackward"}, input_dims_, grad_output)}; -} - -std::vector> Slice::Forward(const std::vector> &input_tensors) { - CHECK_EQ(input_tensors.size(), 1); - const auto &input = input_tensors[0]; - - auto device = input->GetDevice().type(); - return { - Dispatcher::Instance().Call>({device, "SliceForward"}, input, starts_, ends_, steps_)}; -} - -void Slice::SetupContext(const std::vector> &input_tensors, - const std::vector> &) { - // FIXME(dcj): only input's dim need to be saved - const auto &input = input_tensors[0]; - saved_tensors_ = {input}; -} - -std::vector> Slice::Backward(const std::vector> &grad_outputs) { - CHECK_EQ(saved_tensors_.size(), 1); - const auto &input = saved_tensors_[0]; - const auto &grad_output = grad_outputs[0]; - - auto device = input->GetDevice().type(); - return {Dispatcher::Instance().Call>({device, "SliceBackward"}, grad_output, input, starts_, - ends_, steps_)}; -} - -std::vector> Stack::Forward(const std::vector> &input_tensors) { - CHECK_GE(input_tensors.size(), 2); - const auto device = input_tensors[0]->GetDevice().type(); - - return {Dispatcher::Instance().Call>({device, "StackForward"}, input_tensors, dim_)}; -} - -void Stack::SetupContext(const std::vector> &input_tensors, - const std::vector> &) { - const auto &input = input_tensors[0]; - input_dims_ = input->Dims(); -} - -std::vector> Stack::Backward(const std::vector> &grad_outputs) { - const auto &grad_output = grad_outputs[0]; - - auto device = grad_output->GetDevice().type(); - return {Dispatcher::Instance().Call>>({device, "StackBackward"}, input_dims_, - dim_, grad_output)}; -} - -std::vector> Concat::Forward(const std::vector> &input_tensors) { - CHECK_GE(input_tensors.size(), 2); - const auto device = input_tensors[0]->GetDevice().type(); - - auto kernel = Dispatcher::Instance().GetKernel({device, "ConcatForward"}); - return {kernel.Call>(input_tensors, dim_)}; -} - -void Concat::SetupContext(const std::vector> &input_tensors, - const std::vector> &) { - for (auto input : input_tensors) { input_dims_list_.push_back(input->Dims()); } -} - -std::vector> Concat::Backward(const std::vector> &grad_outputs) { - const auto &grad_output = grad_outputs[0]; - - auto device = grad_output->GetDevice().type(); - auto kernel = Dispatcher::Instance().GetKernel({device, "ConcatBackward"}); - return kernel.Call>>(grad_output, input_dims_list_, dim_); -} -} // namespace infini_train::autograd diff --git a/infini_train/src/autograd/no_op.cc b/infini_train/src/autograd/no_op.cc new file mode 100644 index 00000000..b4247dec --- /dev/null +++ b/infini_train/src/autograd/no_op.cc @@ -0,0 +1,31 @@ +#include "infini_train/include/autograd/no_op.h" + +#include "glog/logging.h" + +#include "infini_train/include/dispatcher.h" +#include "infini_train/include/tensor.h" + +namespace infini_train::autograd { +std::vector> NoOp::Forward(const std::vector> &input_tensors) { + CHECK_EQ(input_tensors.size(), 1); + const auto &input = input_tensors[0]; + + auto device = input->GetDevice().type(); + return {Dispatcher::Instance().Call>({device, "NoOpForward"}, input, output_dims_)}; +} + +void NoOp::SetupContext(const std::vector> &input_tensors, + const std::vector> &output_tensors) { + const auto &input = input_tensors[0]; + input_dims_ = input->Dims(); +} + +std::vector> NoOp::Backward(const std::vector> &grad_outputs) { + CHECK_EQ(grad_outputs.size(), 1); + const auto &grad_output = grad_outputs[0]; + + auto device = grad_output->GetDevice().type(); + return {Dispatcher::Instance().Call>({device, "NoOpBackward"}, input_dims_, grad_output)}; +} + +} // namespace infini_train::autograd diff --git a/infini_train/src/autograd/scatter.cc b/infini_train/src/autograd/scatter.cc new file mode 100644 index 00000000..472fd543 --- /dev/null +++ b/infini_train/src/autograd/scatter.cc @@ -0,0 +1,34 @@ +#include "infini_train/include/autograd/scatter.h" + +#include "glog/logging.h" + +#include "infini_train/include/dispatcher.h" +#include "infini_train/include/tensor.h" + +namespace infini_train::autograd { + +std::vector> Scatter::Forward(const std::vector> &input_tensors) { + CHECK_EQ(input_tensors.size(), 2); + const auto &values = input_tensors[0]; + const auto &indices = input_tensors[1]; + auto device = values->GetDevice().type(); + return {Dispatcher::Instance().Call>({device, "ScatterForward"}, values, indices, + output_dims_)}; +} + +void Scatter::SetupContext(const std::vector> &input_tensors, + const std::vector> &) { + saved_tensors_ = {input_tensors[1]}; +} + +std::vector> Scatter::Backward(const std::vector> &grad_outputs) { + CHECK_EQ(grad_outputs.size(), 1); + const auto &grad_output = grad_outputs[0]; + const auto &indices = saved_tensors_[0]; + auto device = grad_output->GetDevice().type(); + auto grad_values + = Dispatcher::Instance().Call>({device, "ScatterBackward"}, grad_output, indices); + return {grad_values, nullptr}; +} + +} // namespace infini_train::autograd diff --git a/infini_train/src/autograd/topk.cc b/infini_train/src/autograd/topk.cc new file mode 100644 index 00000000..4e0420b8 --- /dev/null +++ b/infini_train/src/autograd/topk.cc @@ -0,0 +1,39 @@ +#include "infini_train/include/autograd/topk.h" + +#include "glog/logging.h" + +#include "infini_train/include/dispatcher.h" +#include "infini_train/include/tensor.h" + +namespace infini_train::autograd { + +std::vector> TopK::Forward(const std::vector> &input_tensors) { + CHECK_EQ(input_tensors.size(), 1); + CHECK_GT(topk_, 0); + const auto &input = input_tensors[0]; + auto device = input->GetDevice().type(); + auto topk_outputs = Dispatcher::Instance().Call>>( + {device, "TopKForward"}, input, topk_, dim_, largest_, sorted_); + CHECK_EQ(topk_outputs.size(), 2); + top_indices_ = topk_outputs[1]; + return {topk_outputs[0]}; +} + +void TopK::SetupContext(const std::vector> &input_tensors, + const std::vector> &) { + input_dims_ = input_tensors[0]->Dims(); + saved_tensors_ = {top_indices_}; +} + +std::vector> TopK::Backward(const std::vector> &grad_outputs) { + CHECK_EQ(grad_outputs.size(), 1); + const auto &top_grad = grad_outputs[0]; + const auto &top_indices = saved_tensors_[0]; + auto device = top_grad->GetDevice().type(); + return {Dispatcher::Instance().Call>({device, "TopKBackward"}, top_grad, top_indices, + input_dims_, dim_)}; +} + +std::shared_ptr TopK::TopIndices() const { return top_indices_; } + +} // namespace infini_train::autograd diff --git a/infini_train/src/autograd/transform.cc b/infini_train/src/autograd/transform.cc index 4fae05bb..e38d5616 100644 --- a/infini_train/src/autograd/transform.cc +++ b/infini_train/src/autograd/transform.cc @@ -89,4 +89,94 @@ RepeatInterleave::Backward(const std::vector> &grad_outp return {Dispatcher::Instance().Call>({device, "RepeatInterleaveBackward"}, grad_output, input_dims_, dim_)}; } + +std::vector> Split::Forward(const std::vector> &input_tensors) { + CHECK_EQ(input_tensors.size(), 1); + const auto &input = input_tensors[0]; + + auto device = input->GetDevice().type(); + return {Dispatcher::Instance().Call>>({device, "SplitForward"}, input, + split_size_, dim_)}; +} + +void Split::SetupContext(const std::vector> &input_tensors, + const std::vector> &) { + const auto &input = input_tensors[0]; + input_dims_ = input->Dims(); +} + +std::vector> Split::Backward(const std::vector> &grad_outputs) { + auto device = grad_outputs[0]->GetDevice(); + return {Dispatcher::Instance().Call>({device.type(), "SplitBackward"}, input_dims_, + split_size_, dim_, grad_outputs)}; +} + +std::vector> Stack::Forward(const std::vector> &input_tensors) { + CHECK_GE(input_tensors.size(), 2); + const auto device = input_tensors[0]->GetDevice().type(); + + return {Dispatcher::Instance().Call>({device, "StackForward"}, input_tensors, dim_)}; +} + +void Stack::SetupContext(const std::vector> &input_tensors, + const std::vector> &) { + const auto &input = input_tensors[0]; + input_dims_ = input->Dims(); +} + +std::vector> Stack::Backward(const std::vector> &grad_outputs) { + const auto &grad_output = grad_outputs[0]; + + auto device = grad_output->GetDevice().type(); + return {Dispatcher::Instance().Call>>({device, "StackBackward"}, input_dims_, + dim_, grad_output)}; +} + +std::vector> Concat::Forward(const std::vector> &input_tensors) { + CHECK_GE(input_tensors.size(), 2); + const auto device = input_tensors[0]->GetDevice().type(); + + auto kernel = Dispatcher::Instance().GetKernel({device, "ConcatForward"}); + return {kernel.Call>(input_tensors, dim_)}; +} + +void Concat::SetupContext(const std::vector> &input_tensors, + const std::vector> &) { + for (auto input : input_tensors) { input_dims_list_.push_back(input->Dims()); } +} + +std::vector> Concat::Backward(const std::vector> &grad_outputs) { + const auto &grad_output = grad_outputs[0]; + + auto device = grad_output->GetDevice().type(); + auto kernel = Dispatcher::Instance().GetKernel({device, "ConcatBackward"}); + return kernel.Call>>(grad_output, input_dims_list_, dim_); +} + +std::vector> Slice::Forward(const std::vector> &input_tensors) { + CHECK_EQ(input_tensors.size(), 1); + const auto &input = input_tensors[0]; + + auto device = input->GetDevice().type(); + return { + Dispatcher::Instance().Call>({device, "SliceForward"}, input, starts_, ends_, steps_)}; +} + +void Slice::SetupContext(const std::vector> &input_tensors, + const std::vector> &) { + // FIXME(dcj): only input's dim need to be saved + const auto &input = input_tensors[0]; + saved_tensors_ = {input}; +} + +std::vector> Slice::Backward(const std::vector> &grad_outputs) { + CHECK_EQ(saved_tensors_.size(), 1); + const auto &input = saved_tensors_[0]; + const auto &grad_output = grad_outputs[0]; + + auto device = input->GetDevice().type(); + return {Dispatcher::Instance().Call>({device, "SliceBackward"}, grad_output, input, starts_, + ends_, steps_)}; +} + } // namespace infini_train::autograd diff --git a/infini_train/src/kernels/cpu/cast.cc b/infini_train/src/kernels/cpu/cast.cc index 114a5597..c3a2e595 100644 --- a/infini_train/src/kernels/cpu/cast.cc +++ b/infini_train/src/kernels/cpu/cast.cc @@ -13,7 +13,8 @@ std::shared_ptr Cast(std::shared_ptr input, DataType dtype) { auto device = input->GetDevice(); auto dst_tensor = std::make_shared(input->Dims(), dtype, device); - core::cpu::DispatchCpuFunc, DataTypeList>( + core::cpu::DispatchCpuFunc, + DataTypeList>( {dtype, input->Dtype()}, [=]() { auto dst = static_cast(dst_tensor->DataPtr()); diff --git a/infini_train/src/kernels/cpu/fill.cc b/infini_train/src/kernels/cpu/fill.cc index 5f8b7cd3..7bcda6bd 100644 --- a/infini_train/src/kernels/cpu/fill.cc +++ b/infini_train/src/kernels/cpu/fill.cc @@ -8,7 +8,7 @@ namespace infini_train::kernels::cpu { void Fill(std::shared_ptr tensor, Scalar scalar) { - core::cpu::DispatchCpuFunc( + core::cpu::DispatchCpuFunc( tensor->Dtype(), [=]() { auto data = reinterpret_cast(tensor->DataPtr()); diff --git a/infini_train/src/kernels/cpu/gather.cc b/infini_train/src/kernels/cpu/gather.cc index b59fd45f..af39fc0f 100644 --- a/infini_train/src/kernels/cpu/gather.cc +++ b/infini_train/src/kernels/cpu/gather.cc @@ -8,11 +8,8 @@ #include "infini_train/include/tensor.h" namespace infini_train::kernels::cpu { -// FIXME(zbl): This kernel aligns with torch.gather -// Currently named IndexGather to avoid conflict with communication operators -// Should be renamed to Gather later for interface consistency -std::shared_ptr IndexGatherForward(const std::shared_ptr &input, const std::shared_ptr &index, - int64_t dim) { +std::shared_ptr GatherForward(const std::shared_ptr &input, const std::shared_ptr &index, + int64_t dim) { const auto &in_dims = input->Dims(); const auto &idx_dims = index->Dims(); CHECK_EQ(in_dims.size(), idx_dims.size()); @@ -103,9 +100,8 @@ std::shared_ptr IndexGatherForward(const std::shared_ptr &input, return out; } -std::shared_ptr IndexGatherBackward(const std::shared_ptr &grad_output, - const std::shared_ptr &index, int64_t dim, - const std::vector &input_dims) { +std::shared_ptr GatherBackward(const std::shared_ptr &grad_output, const std::shared_ptr &index, + int64_t dim, const std::vector &input_dims) { const auto &in_dims = input_dims; const auto &idx_dims = index->Dims(); CHECK_EQ(in_dims.size(), idx_dims.size()); @@ -199,7 +195,7 @@ std::shared_ptr IndexGatherBackward(const std::shared_ptr &grad_ #define REGISTER_CPU_GATHER_KERNEL(kernel_name) \ REGISTER_KERNEL(infini_train::Device::DeviceType::kCPU, kernel_name, infini_train::kernels::cpu::kernel_name) -REGISTER_CPU_GATHER_KERNEL(IndexGatherForward) -REGISTER_CPU_GATHER_KERNEL(IndexGatherBackward) +REGISTER_CPU_GATHER_KERNEL(GatherForward) +REGISTER_CPU_GATHER_KERNEL(GatherBackward) #undef REGISTER_CPU_GATHER_KERNEL diff --git a/infini_train/src/kernels/cpu/scatter.cc b/infini_train/src/kernels/cpu/scatter.cc new file mode 100644 index 00000000..1a9cf62e --- /dev/null +++ b/infini_train/src/kernels/cpu/scatter.cc @@ -0,0 +1,86 @@ +#include +#include +#include +#include + +#include "glog/logging.h" + +#include "infini_train/include/dispatcher.h" +#include "infini_train/include/tensor.h" + +namespace infini_train::kernels::cpu { + +std::shared_ptr ScatterForward(const std::shared_ptr &values, const std::shared_ptr &indices, + const std::vector &output_dims) { + CHECK(indices->Dtype() == DataType::kINT64) << "CPU ScatterForward expects int64 indices"; + CHECK(values->Dims() == indices->Dims()); + CHECK(!output_dims.empty()); + CHECK_EQ(values->Dims().size(), output_dims.size()); + CHECK_GT(values->Dims().back(), 0); + CHECK_GT(output_dims.back(), 0); + + const int64_t topk = values->Dims().back(); + const int64_t num_experts = output_dims.back(); + const int64_t rows = values->NumElements() / topk; + size_t output_numel = 1; + for (const auto dim : output_dims) { output_numel *= static_cast(dim); } + CHECK_EQ(output_numel, static_cast(rows * num_experts)); + + auto output = std::make_shared(output_dims, values->Dtype(), values->GetDevice()); + std::memset(output->DataPtr(), 0, output->SizeInBytes()); + + const size_t elem_size = kDataTypeToSize.at(values->Dtype()); + const auto *src = static_cast(values->DataPtr()); + auto *dst = static_cast(output->DataPtr()); + const auto *idx = static_cast(indices->DataPtr()); + for (int64_t row = 0; row < rows; ++row) { + for (int64_t selected = 0; selected < topk; ++selected) { + const int64_t expert_idx = idx[row * topk + selected]; + CHECK_GE(expert_idx, 0); + CHECK_LT(expert_idx, num_experts); + std::memcpy(dst + (row * num_experts + expert_idx) * elem_size, src + (row * topk + selected) * elem_size, + elem_size); + } + } + + return output; +} + +std::shared_ptr ScatterBackward(const std::shared_ptr &grad_output, + const std::shared_ptr &indices) { + CHECK(indices->Dtype() == DataType::kINT64) << "CPU ScatterBackward expects int64 indices"; + CHECK_GE(grad_output->Dims().size(), 1); + CHECK_GE(indices->Dims().size(), 1); + + const int64_t num_experts = grad_output->Dims().back(); + const int64_t topk = indices->Dims().back(); + const int64_t rows = indices->NumElements() / topk; + CHECK_EQ(grad_output->NumElements(), static_cast(rows * num_experts)); + + auto grad_values = std::make_shared(indices->Dims(), grad_output->Dtype(), grad_output->GetDevice()); + const size_t elem_size = kDataTypeToSize.at(grad_output->Dtype()); + const auto *src = static_cast(grad_output->DataPtr()); + auto *dst = static_cast(grad_values->DataPtr()); + const auto *idx = static_cast(indices->DataPtr()); + for (int64_t row = 0; row < rows; ++row) { + for (int64_t selected = 0; selected < topk; ++selected) { + const int64_t expert_idx = idx[row * topk + selected]; + CHECK_GE(expert_idx, 0); + CHECK_LT(expert_idx, num_experts); + std::memcpy(dst + (row * topk + selected) * elem_size, src + (row * num_experts + expert_idx) * elem_size, + elem_size); + } + } + + return grad_values; +} + +} // namespace infini_train::kernels::cpu + +#define REGISTER_CPU_SCATTER_KERNEL(kernel_name) \ + REGISTER_KERNEL(infini_train::Device::DeviceType::kCPU, kernel_name, infini_train::kernels::cpu::kernel_name) + +REGISTER_CPU_SCATTER_KERNEL(ScatterForward) +REGISTER_CPU_SCATTER_KERNEL(ScatterBackward) + +#undef REGISTER_CPU_SCATTER_KERNEL diff --git a/infini_train/src/kernels/cpu/topk.cc b/infini_train/src/kernels/cpu/topk.cc new file mode 100644 index 00000000..9e191143 --- /dev/null +++ b/infini_train/src/kernels/cpu/topk.cc @@ -0,0 +1,124 @@ +#include +#include +#include +#include +#include + +#include "glog/logging.h" + +#include "infini_train/include/dispatcher.h" +#include "infini_train/include/tensor.h" + +namespace infini_train::kernels::cpu { + +std::vector> TopKForward(const std::shared_ptr &input, int64_t topk, int64_t dim, + bool largest, bool sorted) { + CHECK(input->Dtype() == DataType::kFLOAT32) << "CPU TopKForward currently supports float32 only"; + CHECK_GE(input->Dims().size(), 1); + (void)sorted; + + const auto &dims = input->Dims(); + if (dim < 0) { + dim += static_cast(dims.size()); + } + CHECK_GE(dim, 0); + CHECK_LT(dim, static_cast(dims.size())); + + const int64_t dim_size = dims[dim]; + CHECK_GT(dim_size, 0); + CHECK_GT(topk, 0); + CHECK_LE(topk, dim_size); + + int64_t outer_size = 1; + for (int64_t idx = 0; idx < dim; ++idx) { outer_size *= dims[idx]; } + int64_t inner_size = 1; + for (size_t idx = static_cast(dim) + 1; idx < dims.size(); ++idx) { inner_size *= dims[idx]; } + + auto topk_dims = dims; + topk_dims[dim] = topk; + auto top_values = std::make_shared(topk_dims, input->Dtype(), input->GetDevice()); + auto top_indices = std::make_shared(topk_dims, DataType::kINT64, input->GetDevice()); + + const float *in = static_cast(input->DataPtr()); + float *values = static_cast(top_values->DataPtr()); + int64_t *indices = static_cast(top_indices->DataPtr()); + for (int64_t outer = 0; outer < outer_size; ++outer) { + for (int64_t inner = 0; inner < inner_size; ++inner) { + std::vector selected_indices(dim_size, false); + for (int64_t selected = 0; selected < topk; ++selected) { + int64_t best_idx = -1; + float best_value + = largest ? -std::numeric_limits::infinity() : std::numeric_limits::infinity(); + for (int64_t idx = 0; idx < dim_size; ++idx) { + if (selected_indices[idx]) { + continue; + } + const float value = in[outer * dim_size * inner_size + idx * inner_size + inner]; + const bool better = largest ? value > best_value : value < best_value; + if (better) { + best_value = value; + best_idx = idx; + } + } + CHECK_GE(best_idx, 0); + selected_indices[best_idx] = true; + const int64_t out_offset = outer * topk * inner_size + selected * inner_size + inner; + values[out_offset] = best_value; + indices[out_offset] = best_idx; + } + } + } + + return {top_values, top_indices}; +} + +std::shared_ptr TopKBackward(const std::shared_ptr &grad_values, const std::shared_ptr &indices, + const std::vector &input_dims, int64_t dim) { + CHECK(indices->Dtype() == DataType::kINT64) << "CPU TopKBackward expects int64 indices"; + CHECK(grad_values->Dims() == indices->Dims()); + CHECK(!input_dims.empty()); + if (dim < 0) { + dim += static_cast(input_dims.size()); + } + CHECK_GE(dim, 0); + CHECK_LT(dim, static_cast(input_dims.size())); + + const int64_t dim_size = input_dims[dim]; + const int64_t topk = indices->Dims()[dim]; + int64_t outer_size = 1; + for (int64_t idx = 0; idx < dim; ++idx) { outer_size *= input_dims[idx]; } + int64_t inner_size = 1; + for (size_t idx = static_cast(dim) + 1; idx < input_dims.size(); ++idx) { inner_size *= input_dims[idx]; } + + auto grad_input = std::make_shared(input_dims, grad_values->Dtype(), grad_values->GetDevice()); + std::memset(grad_input->DataPtr(), 0, grad_input->SizeInBytes()); + + const size_t elem_size = kDataTypeToSize.at(grad_values->Dtype()); + const auto *src = static_cast(grad_values->DataPtr()); + auto *dst = static_cast(grad_input->DataPtr()); + const auto *idx_ptr = static_cast(indices->DataPtr()); + for (int64_t outer = 0; outer < outer_size; ++outer) { + for (int64_t inner = 0; inner < inner_size; ++inner) { + for (int64_t selected = 0; selected < topk; ++selected) { + const int64_t out_offset = outer * topk * inner_size + selected * inner_size + inner; + const int64_t selected_idx = idx_ptr[out_offset]; + CHECK_GE(selected_idx, 0); + CHECK_LT(selected_idx, dim_size); + std::memcpy(dst + (outer * dim_size * inner_size + selected_idx * inner_size + inner) * elem_size, + src + out_offset * elem_size, elem_size); + } + } + } + + return grad_input; +} + +} // namespace infini_train::kernels::cpu + +#define REGISTER_CPU_TOPK_KERNEL(kernel_name) \ + REGISTER_KERNEL(infini_train::Device::DeviceType::kCPU, kernel_name, infini_train::kernels::cpu::kernel_name) + +REGISTER_CPU_TOPK_KERNEL(TopKForward) +REGISTER_CPU_TOPK_KERNEL(TopKBackward) + +#undef REGISTER_CPU_TOPK_KERNEL diff --git a/infini_train/src/kernels/cuda/cast.cu b/infini_train/src/kernels/cuda/cast.cu index 16190912..96a70ae2 100644 --- a/infini_train/src/kernels/cuda/cast.cu +++ b/infini_train/src/kernels/cuda/cast.cu @@ -34,7 +34,8 @@ std::shared_ptr Cast(std::shared_ptr input, DataType dtype) { dim3 grid_dims(CEIL_DIV(num_elements, block_dims.x)); const size_t step = grid_dims.x * block_dims.x; - core::cuda::DispatchCudaFunc, DataTypeList>( + core::cuda::DispatchCudaFunc, + DataTypeList>( {dtype, input->Dtype()}, [=]() { auto dst = static_cast(dst_tensor->DataPtr()); diff --git a/infini_train/src/kernels/cuda/comm.cu b/infini_train/src/kernels/cuda/comm.cu index b4bdafd8..6300ffdb 100644 --- a/infini_train/src/kernels/cuda/comm.cu +++ b/infini_train/src/kernels/cuda/comm.cu @@ -9,7 +9,7 @@ #include "infini_train/include/nn/functional.h" #include "infini_train/include/tensor.h" -namespace infini_train::kernels::cuda { +namespace infini_train::kernels::cuda::comm { std::vector> Broadcast(const std::vector> &input_tensors, const std::vector &devices) { @@ -69,11 +69,11 @@ std::shared_ptr Gather(const std::vector> &tenso auto view_kernel = Dispatcher::Instance().GetKernel({destination.type(), "NoOpForward"}); return view_kernel.Call>(gathered_tensor, new_dims); } -} // namespace infini_train::kernels::cuda +} // namespace infini_train::kernels::cuda::comm #define REGISTER_CUDA_COMM_KERNEL(kernel_name) \ - REGISTER_KERNEL(infini_train::Device::DeviceType::kCUDA, Comm##kernel_name, \ - infini_train::kernels::cuda::kernel_name) + REGISTER_KERNEL(infini_train::Device::DeviceType::kCUDA, kernel_name, \ + infini_train::kernels::cuda::comm::kernel_name) REGISTER_CUDA_COMM_KERNEL(Broadcast) REGISTER_CUDA_COMM_KERNEL(Scatter) diff --git a/infini_train/src/kernels/cuda/concat.cu b/infini_train/src/kernels/cuda/concat.cu index c158a5c3..a7fa7490 100644 --- a/infini_train/src/kernels/cuda/concat.cu +++ b/infini_train/src/kernels/cuda/concat.cu @@ -103,7 +103,7 @@ std::shared_ptr ConcatForward(const std::vector> int threads_per_block = 256; int num_blocks = static_cast((total + threads_per_block - 1) / threads_per_block); - core::cuda::DispatchCudaFunc( + core::cuda::DispatchCudaFunc( dtype, [=, &inputs, &host_offsets]() { std::vector host_input_ptrs; @@ -208,7 +208,7 @@ std::vector> ConcatBackward(const std::shared_ptr((total + threads_per_block - 1) / threads_per_block); - core::cuda::DispatchCudaFunc( + core::cuda::DispatchCudaFunc( dtype, [=, &grads, &host_offsets]() { std::vector host_ptrs; diff --git a/infini_train/src/kernels/cuda/elementwise.cu b/infini_train/src/kernels/cuda/elementwise.cu index fe63e0b2..92ce9915 100644 --- a/infini_train/src/kernels/cuda/elementwise.cu +++ b/infini_train/src/kernels/cuda/elementwise.cu @@ -1018,7 +1018,7 @@ std::shared_ptr EqualsForward(const std::shared_ptr &a, const st DISPATCH(a->Dtype(), return BinaryForward(a, b, [] __device__(auto x, auto y) { return (x == y) ? decltype(x){1} : decltype(x){0}; }); - , INFINI_ALL_TYPES) + , INFINI_ALL_NUMERIC_TYPES) } std::shared_ptr EqualsScalarForward(const std::shared_ptr &a, float scalar) { @@ -1033,7 +1033,7 @@ std::shared_ptr EqualsScalarForward(const std::shared_ptr &a, fl std::shared_ptr LtForward(const std::shared_ptr &a, const std::shared_ptr &b) { DISPATCH(a->Dtype(), return BinaryForward( a, b, [] __device__(auto x, auto y) { return x < y ? decltype(x){1} : decltype(x){0}; }); - , INFINI_ALL_TYPES) + , INFINI_ALL_NUMERIC_TYPES) } std::shared_ptr LtScalarForward(const std::shared_ptr &a, float scalar) { @@ -1042,14 +1042,14 @@ std::shared_ptr LtScalarForward(const std::shared_ptr &a, float return (x < static_cast(scalar)) ? decltype(x){1} : decltype(x){0}; }); - , INFINI_ALL_TYPES) + , INFINI_ALL_NUMERIC_TYPES) } std::shared_ptr LeForward(const std::shared_ptr &a, const std::shared_ptr &b) { DISPATCH(a->Dtype(), return BinaryForward(a, b, [] __device__(auto x, auto y) { return (x <= y) ? decltype(x){1} : decltype(x){0}; }); - , INFINI_ALL_TYPES) + , INFINI_ALL_NUMERIC_TYPES) } std::shared_ptr LeScalarForward(const std::shared_ptr &a, float scalar) { @@ -1058,13 +1058,13 @@ std::shared_ptr LeScalarForward(const std::shared_ptr &a, float return (x <= static_cast(scalar)) ? decltype(x){1} : decltype(x){0}; }); - , INFINI_ALL_TYPES) + , INFINI_ALL_NUMERIC_TYPES) } std::shared_ptr GtForward(const std::shared_ptr &a, const std::shared_ptr &b) { DISPATCH(a->Dtype(), return BinaryForward( a, b, [] __device__(auto x, auto y) { return x > y ? decltype(x){1} : decltype(x){0}; }); - , INFINI_ALL_TYPES) + , INFINI_ALL_NUMERIC_TYPES) } std::shared_ptr GtScalarForward(const std::shared_ptr &a, float scalar) { @@ -1073,14 +1073,14 @@ std::shared_ptr GtScalarForward(const std::shared_ptr &a, float return (x > static_cast(scalar)) ? decltype(x){1} : decltype(x){0}; }); - , INFINI_ALL_TYPES) + , INFINI_ALL_NUMERIC_TYPES) } std::shared_ptr GeForward(const std::shared_ptr &a, const std::shared_ptr &b) { DISPATCH(a->Dtype(), return BinaryForward(a, b, [] __device__(auto x, auto y) { return (x >= y) ? decltype(x){1} : decltype(x){0}; }); - , INFINI_ALL_TYPES) + , INFINI_ALL_NUMERIC_TYPES) } std::shared_ptr GeScalarForward(const std::shared_ptr &a, float scalar) { @@ -1089,7 +1089,7 @@ std::shared_ptr GeScalarForward(const std::shared_ptr &a, float return (x >= static_cast(scalar)) ? decltype(x){1} : decltype(x){0}; }); - , INFINI_ALL_TYPES) + , INFINI_ALL_NUMERIC_TYPES) } std::shared_ptr OrForward(const std::shared_ptr &a, const std::shared_ptr &b) { @@ -1098,7 +1098,7 @@ std::shared_ptr OrForward(const std::shared_ptr &a, const std::s return (x != decltype(x){0} || y != decltype(y){0}) ? decltype(x){1} : decltype(x){0}; }); - , INFINI_ALL_TYPES) + , INFINI_ALL_NUMERIC_TYPES) } std::shared_ptr AndForward(const std::shared_ptr &a, const std::shared_ptr &b) { @@ -1107,7 +1107,7 @@ std::shared_ptr AndForward(const std::shared_ptr &a, const std:: return (x != decltype(x){0} && y != decltype(y){0}) ? decltype(x){1} : decltype(x){0}; }); - , INFINI_ALL_TYPES) + , INFINI_ALL_NUMERIC_TYPES) } std::shared_ptr AddForward(const std::shared_ptr &a, const std::shared_ptr &b) { @@ -1125,19 +1125,19 @@ std::pair, std::shared_ptr> AddBackward(const st std::shared_ptr AddScalarForward(const std::shared_ptr &a, float scalar) { DISPATCH(a->Dtype(), return UnaryForward(a, [scalar] __device__(auto x) { return Add(x, static_cast(scalar)); }); - , INFINI_ALL_TYPES) + , INFINI_ALL_NUMERIC_TYPES) } std::shared_ptr AddScalarBackward(const std::shared_ptr &grad_output) { DISPATCH(grad_output->Dtype(), return UnaryBackward(grad_output, nullptr, [] __device__(auto x) { return common::cuda::Cast(1); }); - , INFINI_ALL_TYPES) + , INFINI_ALL_NUMERIC_TYPES) } std::shared_ptr SubForward(const std::shared_ptr &a, const std::shared_ptr &b) { DISPATCH(a->Dtype(), return BinaryForward(a, b, [] __device__(auto x, auto y) { return Sub(x, y); }); - , INFINI_ALL_TYPES) + , INFINI_ALL_NUMERIC_TYPES) } std::pair, std::shared_ptr> SubBackward(const std::shared_ptr &grad_output, diff --git a/infini_train/src/kernels/cuda/fill.cu b/infini_train/src/kernels/cuda/fill.cu index f5532779..3ddead5c 100644 --- a/infini_train/src/kernels/cuda/fill.cu +++ b/infini_train/src/kernels/cuda/fill.cu @@ -28,7 +28,7 @@ void Fill(std::shared_ptr tensor, Scalar scalar) { infini_train::core::GetDeviceGuardImpl(device.type())->GetStream(device)) ->cuda_stream(); - core::cuda::DispatchCudaFunc( + core::cuda::DispatchCudaFunc( tensor->Dtype(), [=]() { const T casted_value = scalar.to(); diff --git a/infini_train/src/kernels/cuda/gather.cu b/infini_train/src/kernels/cuda/gather.cu index 12d0567d..5216f28e 100644 --- a/infini_train/src/kernels/cuda/gather.cu +++ b/infini_train/src/kernels/cuda/gather.cu @@ -9,15 +9,11 @@ #include "infini_train/src/core/runtime/cuda/cuda_runtime_common.h" namespace infini_train::kernels::cuda { -// FIXME(zbl): This kernel aligns with torch.gather -// Currently named IndexGather to avoid conflict with communication operators -// Should be renamed to Gather later for interface consistency template -__global__ void IndexGatherForwardKernel(const T *__restrict__ input, const int64_t *__restrict__ norm_index, - T *__restrict__ output, const int64_t *__restrict__ out_dims, - const int64_t *__restrict__ in_strides, - const int64_t *__restrict__ out_strides, int num_dims, int gather_dim, - int64_t dim_size_gather, int64_t total_elements) { +__global__ void GatherForwardKernel(const T *__restrict__ input, const int64_t *__restrict__ norm_index, + T *__restrict__ output, const int64_t *__restrict__ out_dims, + const int64_t *__restrict__ in_strides, const int64_t *__restrict__ out_strides, + int num_dims, int gather_dim, int64_t dim_size_gather, int64_t total_elements) { int64_t out_idx = blockIdx.x * blockDim.x + threadIdx.x; if (out_idx >= total_elements) { return; @@ -43,8 +39,8 @@ __global__ void IndexGatherForwardKernel(const T *__restrict__ input, const int6 output[out_idx] = input[in_linear]; } -std::shared_ptr IndexGatherForward(const std::shared_ptr &input, const std::shared_ptr &index, - int64_t dim) { +std::shared_ptr GatherForward(const std::shared_ptr &input, const std::shared_ptr &index, + int64_t dim) { const auto &in_dims = input->Dims(); const auto &idx_dims = index->Dims(); CHECK_EQ(in_dims.size(), idx_dims.size()); @@ -103,23 +99,22 @@ std::shared_ptr IndexGatherForward(const std::shared_ptr &input, core::cuda::DispatchCudaFunc( dtype, [=]() { - IndexGatherForwardKernel<<>>( + GatherForwardKernel<<>>( static_cast(input->DataPtr()), static_cast(index->DataPtr()), static_cast(out->DataPtr()), out_dims_dev, in_strides_dev, out_strides_dev, (int)num_dims, (int)dim, gather_dim_size, total_elements); }, - "CUDA IndexGatherForward"); + "CUDA GatherForward"); CUDA_CHECK(cudaFreeAsync(dev_buf, stream)); return out; } template -__global__ void IndexGatherBackwardKernel(const T *__restrict__ grad_output, const int64_t *__restrict__ index, - T *__restrict__ grad_input, const int64_t *__restrict__ out_dims, - const int64_t *__restrict__ in_strides, - const int64_t *__restrict__ out_strides, int num_dims, int gather_dim, - int64_t dim_size_gather, int64_t total_elements) { +__global__ void GatherBackwardKernel(const T *__restrict__ grad_output, const int64_t *__restrict__ index, + T *__restrict__ grad_input, const int64_t *__restrict__ out_dims, + const int64_t *__restrict__ in_strides, const int64_t *__restrict__ out_strides, + int num_dims, int gather_dim, int64_t dim_size_gather, int64_t total_elements) { int64_t out_idx = blockIdx.x * blockDim.x + threadIdx.x; if (out_idx >= total_elements) { return; @@ -149,9 +144,8 @@ __global__ void IndexGatherBackwardKernel(const T *__restrict__ grad_output, con atomicAdd(&grad_input[in_linear], grad_output[out_idx]); } -std::shared_ptr IndexGatherBackward(const std::shared_ptr &grad_output, - const std::shared_ptr &index, int64_t dim, - const std::vector &input_dims) { +std::shared_ptr GatherBackward(const std::shared_ptr &grad_output, const std::shared_ptr &index, + int64_t dim, const std::vector &input_dims) { const auto &in_dims = input_dims; const auto &idx_dims = index->Dims(); CHECK_EQ(in_dims.size(), idx_dims.size()); @@ -210,12 +204,12 @@ std::shared_ptr IndexGatherBackward(const std::shared_ptr &grad_ core::cuda::DispatchCudaFunc( dtype, [=]() { - IndexGatherBackwardKernel<<>>( + GatherBackwardKernel<<>>( static_cast(grad_output->DataPtr()), static_cast(index->DataPtr()), static_cast(grad_input->DataPtr()), out_dims_dev, in_strides_dev, out_strides_dev, (int)num_dims, (int)dim, gather_dim_size, total_elements); }, - "CUDA IndexGatherBackward"); + "CUDA GatherBackward"); CUDA_CHECK(cudaFreeAsync(dev_buf, stream)); return grad_input; @@ -226,7 +220,7 @@ std::shared_ptr IndexGatherBackward(const std::shared_ptr &grad_ #define REGISTER_CUDA_GATHER_KERNEL(kernel_name) \ REGISTER_KERNEL(infini_train::Device::DeviceType::kCUDA, kernel_name, infini_train::kernels::cuda::kernel_name) -REGISTER_CUDA_GATHER_KERNEL(IndexGatherForward) -REGISTER_CUDA_GATHER_KERNEL(IndexGatherBackward) +REGISTER_CUDA_GATHER_KERNEL(GatherForward) +REGISTER_CUDA_GATHER_KERNEL(GatherBackward) #undef REGISTER_CUDA_GATHER_KERNEL diff --git a/infini_train/src/kernels/cuda/scatter.cu b/infini_train/src/kernels/cuda/scatter.cu new file mode 100644 index 00000000..9ebb173a --- /dev/null +++ b/infini_train/src/kernels/cuda/scatter.cu @@ -0,0 +1,120 @@ +#include "glog/logging.h" + +#include "infini_train/include/common/cuda/common_cuda.h" +#include "infini_train/include/core/runtime/device_guard.h" +#include "infini_train/include/dispatcher.h" +#include "infini_train/include/tensor.h" + +#include "infini_train/src/core/runtime/cuda/cuda_dispatch.h" +#include "infini_train/src/core/runtime/cuda/cuda_runtime_common.h" + +namespace infini_train::kernels::cuda { + +template +__global__ void ScatterForwardKernel(const T *__restrict__ values, const int64_t *__restrict__ indices, + T *__restrict__ output, int64_t rows, int64_t topk, int64_t num_experts) { + int64_t idx = blockIdx.x * blockDim.x + threadIdx.x; + const int64_t total = rows * topk; + if (idx >= total) { + return; + } + + const int64_t row = idx / topk; + const int64_t expert_idx = indices[idx]; + output[row * num_experts + expert_idx] = values[idx]; +} + +std::shared_ptr ScatterForward(const std::shared_ptr &values, const std::shared_ptr &indices, + const std::vector &output_dims) { + CHECK(indices->Dtype() == DataType::kINT64) << "CUDA ScatterForward expects int64 indices"; + CHECK(values->Dims() == indices->Dims()); + CHECK(!output_dims.empty()); + CHECK_EQ(values->Dims().size(), output_dims.size()); + CHECK_GT(values->Dims().back(), 0); + CHECK_GT(output_dims.back(), 0); + + const int64_t topk = values->Dims().back(); + const int64_t num_experts = output_dims.back(); + CHECK_GT(num_experts, 0); + const int64_t rows = values->NumElements() / topk; + + auto output = std::make_shared(output_dims, values->Dtype(), values->GetDevice()); + + auto device = values->GetDevice(); + const auto &stream = dynamic_cast( + infini_train::core::GetDeviceGuardImpl(device.type())->GetStream(device)) + ->cuda_stream(); + CUDA_CHECK(cudaMemsetAsync(output->DataPtr(), 0, output->SizeInBytes(), stream)); + const int threads = 256; + const int blocks = static_cast(((rows * topk) + threads - 1) / threads); + if (values->Dtype() == DataType::kBOOL) { + ScatterForwardKernel<<>>( + static_cast(values->DataPtr()), static_cast(indices->DataPtr()), + static_cast(output->DataPtr()), rows, topk, num_experts); + CUDA_CHECK(cudaGetLastError()); + } else { + core::cuda::DispatchCudaFunc( + values->Dtype(), + [=]() { + ScatterForwardKernel<<>>( + static_cast(values->DataPtr()), static_cast(indices->DataPtr()), + static_cast(output->DataPtr()), rows, topk, num_experts); + }, + "CUDA ScatterForward"); + } + return output; +} + +template +__global__ void ScatterBackwardKernel(const T *__restrict__ grad_output, const int64_t *__restrict__ indices, + T *__restrict__ grad_values, int64_t rows, int64_t topk, int64_t num_experts) { + int64_t idx = blockIdx.x * blockDim.x + threadIdx.x; + const int64_t total = rows * topk; + if (idx >= total) { + return; + } + const int64_t row = idx / topk; + const int64_t expert_idx = indices[idx]; + grad_values[idx] = grad_output[row * num_experts + expert_idx]; +} + +std::shared_ptr ScatterBackward(const std::shared_ptr &grad_output, + const std::shared_ptr &indices) { + CHECK(indices->Dtype() == DataType::kINT64) << "CUDA ScatterBackward expects int64 indices"; + CHECK_GE(grad_output->Dims().size(), 1); + CHECK_GE(indices->Dims().size(), 1); + const int64_t num_experts = grad_output->Dims().back(); + const int64_t topk = indices->Dims().back(); + const int64_t rows = indices->NumElements() / topk; + CHECK_EQ(grad_output->NumElements(), static_cast(rows * num_experts)); + + auto grad_values = std::make_shared(indices->Dims(), grad_output->Dtype(), grad_output->GetDevice()); + + auto device = grad_output->GetDevice(); + const auto &stream = dynamic_cast( + infini_train::core::GetDeviceGuardImpl(device.type())->GetStream(device)) + ->cuda_stream(); + const int threads = 256; + const int blocks = static_cast(((rows * topk) + threads - 1) / threads); + + core::cuda::DispatchCudaFunc( + grad_output->Dtype(), + [=]() { + ScatterBackwardKernel<<>>( + static_cast(grad_output->DataPtr()), static_cast(indices->DataPtr()), + static_cast(grad_values->DataPtr()), rows, topk, num_experts); + }, + "CUDA ScatterBackward"); + + return grad_values; +} + +} // namespace infini_train::kernels::cuda + +#define REGISTER_CUDA_SCATTER_KERNEL(kernel_name) \ + REGISTER_KERNEL(infini_train::Device::DeviceType::kCUDA, kernel_name, infini_train::kernels::cuda::kernel_name) + +REGISTER_CUDA_SCATTER_KERNEL(ScatterForward) +REGISTER_CUDA_SCATTER_KERNEL(ScatterBackward) + +#undef REGISTER_CUDA_SCATTER_KERNEL diff --git a/infini_train/src/kernels/cuda/slice.cu b/infini_train/src/kernels/cuda/slice.cu index 35bd2ac5..d030d73a 100644 --- a/infini_train/src/kernels/cuda/slice.cu +++ b/infini_train/src/kernels/cuda/slice.cu @@ -92,7 +92,7 @@ std::shared_ptr SliceForward(const std::shared_ptr &input, const int threads_per_block = 256; int num_blocks = (total_elements + threads_per_block - 1) / threads_per_block; - core::cuda::DispatchCudaFunc( + core::cuda::DispatchCudaFunc( dtype, [=]() { SliceForwardKernel<<>>( @@ -185,7 +185,7 @@ std::shared_ptr SliceBackward(const std::shared_ptr &grad_output int threads_per_block = 256; int num_blocks = (total_elements + threads_per_block - 1) / threads_per_block; - core::cuda::DispatchCudaFunc( + core::cuda::DispatchCudaFunc( grad_output_dtype, [=]() { SliceBackwardKernel<<>>( diff --git a/infini_train/src/kernels/cuda/split.cu b/infini_train/src/kernels/cuda/split.cu index f208695f..bda0dd70 100644 --- a/infini_train/src/kernels/cuda/split.cu +++ b/infini_train/src/kernels/cuda/split.cu @@ -59,7 +59,7 @@ std::vector> SplitForward(const std::shared_ptr infini_train::core::GetDeviceGuardImpl(device.type())->GetStream(device)) ->cuda_stream(); - core::cuda::DispatchCudaFunc( + core::cuda::DispatchCudaFunc( dtype, [=]() { SplitForwardKernel<<>>( @@ -166,7 +166,7 @@ std::shared_ptr SplitBackward(const std::vector &input_dims, in CHECK_GE(dim, 0) << "Currently we do not support negative dimension"; CHECK_LT(dim, input_dims.size()); - return core::cuda::DispatchCudaFunc( + return core::cuda::DispatchCudaFunc( grad_outputs[0]->Dtype(), [=]() { return LaunchSplitBackward(input_dims, split_size, dim, grad_outputs); }, "CUDA SplitBackward"); diff --git a/infini_train/src/kernels/cuda/stack.cu b/infini_train/src/kernels/cuda/stack.cu index 562fa5ec..841940ea 100644 --- a/infini_train/src/kernels/cuda/stack.cu +++ b/infini_train/src/kernels/cuda/stack.cu @@ -61,7 +61,7 @@ std::shared_ptr StackForward(const std::vector> int threads_per_block = 256; int num_blocks = (total + threads_per_block - 1) / threads_per_block; - core::cuda::DispatchCudaFunc( + core::cuda::DispatchCudaFunc( dtype, [=]() { std::vector host_input_ptrs; @@ -129,7 +129,7 @@ std::vector> StackBackward(const std::vector &i int threads_per_block = 256; int num_blocks = (total + threads_per_block - 1) / threads_per_block; - core::cuda::DispatchCudaFunc( + core::cuda::DispatchCudaFunc( dtype, [=]() { std::vector host_ptrs; diff --git a/infini_train/src/kernels/cuda/topk.cu b/infini_train/src/kernels/cuda/topk.cu new file mode 100644 index 00000000..32044c3f --- /dev/null +++ b/infini_train/src/kernels/cuda/topk.cu @@ -0,0 +1,155 @@ +#include "glog/logging.h" + +#include "infini_train/include/common/cuda/common_cuda.h" +#include "infini_train/include/core/runtime/device_guard.h" +#include "infini_train/include/dispatcher.h" +#include "infini_train/include/tensor.h" + +#include "infini_train/src/core/runtime/cuda/cuda_dispatch.h" +#include "infini_train/src/core/runtime/cuda/cuda_runtime_common.h" + +namespace infini_train::kernels::cuda { + +template +__global__ void TopKForwardKernel(const T *__restrict__ input, T *__restrict__ top_values, + int64_t *__restrict__ top_indices, int64_t rows, int64_t dim_size, int64_t inner_size, + int64_t topk, bool largest) { + int64_t row = blockIdx.x * blockDim.x + threadIdx.x; + if (row >= rows) { + return; + } + + const int64_t outer = row / inner_size; + const int64_t inner = row % inner_size; + for (int64_t idx = 0; idx < dim_size; ++idx) { + const float value = static_cast(input[outer * dim_size * inner_size + idx * inner_size + inner]); + int64_t rank = 0; + for (int64_t other_idx = 0; other_idx < dim_size; ++other_idx) { + const float other_value + = static_cast(input[outer * dim_size * inner_size + other_idx * inner_size + inner]); + const bool ranks_before = largest ? (other_value > value || (other_value == value && other_idx < idx)) + : (other_value < value || (other_value == value && other_idx < idx)); + if (ranks_before) { + ++rank; + } + } + if (rank < topk) { + const int64_t out_offset = outer * topk * inner_size + rank * inner_size + inner; + top_values[out_offset] = input[outer * dim_size * inner_size + idx * inner_size + inner]; + top_indices[out_offset] = idx; + } + } +} + +std::vector> TopKForward(const std::shared_ptr &input, int64_t topk, int64_t dim, + bool largest, bool sorted) { + CHECK_GE(input->Dims().size(), 1); + (void)sorted; + const auto &dims = input->Dims(); + if (dim < 0) { + dim += static_cast(dims.size()); + } + CHECK_GE(dim, 0); + CHECK_LT(dim, static_cast(dims.size())); + + const int64_t dim_size = dims[dim]; + CHECK_GT(dim_size, 0); + CHECK_GT(topk, 0); + CHECK_LE(topk, dim_size); + int64_t outer_size = 1; + for (int64_t idx = 0; idx < dim; ++idx) { outer_size *= dims[idx]; } + int64_t inner_size = 1; + for (size_t idx = static_cast(dim) + 1; idx < dims.size(); ++idx) { inner_size *= dims[idx]; } + const int64_t rows = outer_size * inner_size; + + auto topk_dims = dims; + topk_dims[dim] = topk; + auto top_values = std::make_shared(topk_dims, input->Dtype(), input->GetDevice()); + auto top_indices = std::make_shared(topk_dims, DataType::kINT64, input->GetDevice()); + + auto device = input->GetDevice(); + const auto &stream = dynamic_cast( + infini_train::core::GetDeviceGuardImpl(device.type())->GetStream(device)) + ->cuda_stream(); + const int threads = 256; + const int blocks = static_cast((rows + threads - 1) / threads); + + core::cuda::DispatchCudaFunc( + input->Dtype(), + [=]() { + TopKForwardKernel<<>>( + static_cast(input->DataPtr()), static_cast(top_values->DataPtr()), + static_cast(top_indices->DataPtr()), rows, dim_size, inner_size, topk, largest); + }, + "CUDA TopKForward"); + + return {top_values, top_indices}; +} + +template +__global__ void TopKBackwardKernel(const T *__restrict__ grad_values, const int64_t *__restrict__ indices, + T *__restrict__ grad_input, int64_t rows, int64_t dim_size, int64_t inner_size, + int64_t topk) { + int64_t row = blockIdx.x * blockDim.x + threadIdx.x; + if (row >= rows) { + return; + } + + const int64_t outer = row / inner_size; + const int64_t inner = row % inner_size; + for (int64_t selected = 0; selected < topk; ++selected) { + const int64_t out_offset = outer * topk * inner_size + selected * inner_size + inner; + const int64_t selected_idx = indices[out_offset]; + grad_input[outer * dim_size * inner_size + selected_idx * inner_size + inner] = grad_values[out_offset]; + } +} + +std::shared_ptr TopKBackward(const std::shared_ptr &grad_values, const std::shared_ptr &indices, + const std::vector &input_dims, int64_t dim) { + CHECK(indices->Dtype() == DataType::kINT64) << "CUDA TopKBackward expects int64 indices"; + CHECK(grad_values->Dims() == indices->Dims()); + CHECK(!input_dims.empty()); + if (dim < 0) { + dim += static_cast(input_dims.size()); + } + CHECK_GE(dim, 0); + CHECK_LT(dim, static_cast(input_dims.size())); + + const int64_t dim_size = input_dims[dim]; + const int64_t topk = indices->Dims()[dim]; + int64_t outer_size = 1; + for (int64_t idx = 0; idx < dim; ++idx) { outer_size *= input_dims[idx]; } + int64_t inner_size = 1; + for (size_t idx = static_cast(dim) + 1; idx < input_dims.size(); ++idx) { inner_size *= input_dims[idx]; } + const int64_t rows = outer_size * inner_size; + + auto grad_input = std::make_shared(input_dims, grad_values->Dtype(), grad_values->GetDevice()); + auto device = grad_values->GetDevice(); + const auto &stream = dynamic_cast( + infini_train::core::GetDeviceGuardImpl(device.type())->GetStream(device)) + ->cuda_stream(); + CUDA_CHECK(cudaMemsetAsync(grad_input->DataPtr(), 0, grad_input->SizeInBytes(), stream)); + + const int threads = 256; + const int blocks = static_cast((rows + threads - 1) / threads); + core::cuda::DispatchCudaFunc( + grad_values->Dtype(), + [=]() { + TopKBackwardKernel<<>>( + static_cast(grad_values->DataPtr()), static_cast(indices->DataPtr()), + static_cast(grad_input->DataPtr()), rows, dim_size, inner_size, topk); + }, + "CUDA TopKBackward"); + + return grad_input; +} + +} // namespace infini_train::kernels::cuda + +#define REGISTER_CUDA_TOPK_KERNEL(kernel_name) \ + REGISTER_KERNEL(infini_train::Device::DeviceType::kCUDA, kernel_name, infini_train::kernels::cuda::kernel_name) + +REGISTER_CUDA_TOPK_KERNEL(TopKForward) +REGISTER_CUDA_TOPK_KERNEL(TopKBackward) + +#undef REGISTER_CUDA_TOPK_KERNEL diff --git a/infini_train/src/kernels/cuda/transform.cu b/infini_train/src/kernels/cuda/transform.cu index 2bb35598..88f0e10f 100644 --- a/infini_train/src/kernels/cuda/transform.cu +++ b/infini_train/src/kernels/cuda/transform.cu @@ -47,7 +47,7 @@ std::shared_ptr TrilForward(const std::shared_ptr &input, int64_ infini_train::core::GetDeviceGuardImpl(device.type())->GetStream(device)) ->cuda_stream(); - core::cuda::DispatchCudaFunc( + core::cuda::DispatchCudaFunc( input->Dtype(), [=]() { TrilForwardKernel<<>>( @@ -90,7 +90,7 @@ std::shared_ptr TrilBackward(const std::shared_ptr &grad_output, infini_train::core::GetDeviceGuardImpl(device.type())->GetStream(device)) ->cuda_stream(); - core::cuda::DispatchCudaFunc( + core::cuda::DispatchCudaFunc( dtype, [=]() { grad_input->Fill(0.0); @@ -135,7 +135,7 @@ std::shared_ptr TriuForward(const std::shared_ptr &input, int64_ infini_train::core::GetDeviceGuardImpl(device.type())->GetStream(device)) ->cuda_stream(); - core::cuda::DispatchCudaFunc( + core::cuda::DispatchCudaFunc( input->Dtype(), [=]() { TriuForwardKernel<<>>( @@ -177,7 +177,7 @@ std::shared_ptr TriuBackward(const std::shared_ptr &grad_output, infini_train::core::GetDeviceGuardImpl(device.type())->GetStream(device)) ->cuda_stream(); - core::cuda::DispatchCudaFunc( + core::cuda::DispatchCudaFunc( dtype, [=]() { grad_input->Fill(0.0); @@ -269,7 +269,7 @@ std::shared_ptr TransposeForward(const std::shared_ptr &input, i int threads_per_block = 256; int num_blocks = (num_elements + threads_per_block - 1) / threads_per_block; - core::cuda::DispatchCudaFunc( + core::cuda::DispatchCudaFunc( dtype, [=]() { output->Fill(0.0); @@ -371,7 +371,7 @@ std::shared_ptr MaskForward(const std::shared_ptr &input, const int64_t inner = input->NumElements() / rows; int num_blocks = static_cast((input->NumElements() + threads_per_block - 1) / threads_per_block); - core::cuda::DispatchCudaFunc( + core::cuda::DispatchCudaFunc( dtype, [=]() { MaskLeadsForwardKernel<<>>( @@ -384,7 +384,7 @@ std::shared_ptr MaskForward(const std::shared_ptr &input, const int64_t batch_size = input->NumElements() / mask_size; int num_blocks = static_cast((input->NumElements() + threads_per_block - 1) / threads_per_block); - core::cuda::DispatchCudaFunc( + core::cuda::DispatchCudaFunc( dtype, [=]() { MaskForwardKernel<<>>( @@ -435,7 +435,7 @@ std::shared_ptr MaskBackward(const std::shared_ptr &grad_output, int64_t inner = grad_output->NumElements() / rows; int num_blocks = static_cast((grad_output->NumElements() + threads_per_block - 1) / threads_per_block); - core::cuda::DispatchCudaFunc( + core::cuda::DispatchCudaFunc( dtype, [=]() { grad_input->Fill(0.0); @@ -449,7 +449,7 @@ std::shared_ptr MaskBackward(const std::shared_ptr &grad_output, int64_t batch_size = grad_output->NumElements() / mask_size; int num_blocks = static_cast((grad_output->NumElements() + threads_per_block - 1) / threads_per_block); - core::cuda::DispatchCudaFunc( + core::cuda::DispatchCudaFunc( dtype, [=]() { grad_input->Fill(0.0); @@ -504,7 +504,7 @@ std::shared_ptr RepeatInterleaveForward(const std::shared_ptr &i infini_train::core::GetDeviceGuardImpl(device.type())->GetStream(device)) ->cuda_stream(); - core::cuda::DispatchCudaFunc( + core::cuda::DispatchCudaFunc( input->Dtype(), [=]() { RepeatInterleaveForwardKernel<<>>( @@ -562,7 +562,7 @@ std::shared_ptr RepeatInterleaveBackward(const std::shared_ptr & infini_train::core::GetDeviceGuardImpl(device.type())->GetStream(device)) ->cuda_stream(); - core::cuda::DispatchCudaFunc( + core::cuda::DispatchCudaFunc( grad_output->Dtype(), [=]() { grad_input->Fill(0.0); diff --git a/infini_train/src/nn/functional.cc b/infini_train/src/nn/functional.cc index b02f185a..c33e2368 100644 --- a/infini_train/src/nn/functional.cc +++ b/infini_train/src/nn/functional.cc @@ -6,7 +6,6 @@ #include "infini_train/include/autograd/activations.h" #include "infini_train/include/autograd/elementwise.h" -#include "infini_train/include/autograd/misc.h" #include "infini_train/include/autograd/reduction.h" #include "infini_train/include/autograd/softmax.h" #include "infini_train/include/autograd/transform.h" diff --git a/infini_train/src/nn/modules/transformer/mlp.cc b/infini_train/src/nn/modules/transformer/mlp.cc index 3af341b2..ac35d144 100644 --- a/infini_train/src/nn/modules/transformer/mlp.cc +++ b/infini_train/src/nn/modules/transformer/mlp.cc @@ -35,9 +35,14 @@ MLP::MLP(const TransformerConfig &config) : CloneableModule(kType) { } // Round up to multiple_of - int64_t before_round = ffn_hidden; ffn_hidden = (ffn_hidden + config.multiple_of - 1) / config.multiple_of * config.multiple_of; + if (config.ffn_type == FFNType::kMoE && config.moe_config.has_value() + && config.moe_config->moe_ffn_hidden_size > 0) { + ffn_hidden = config.moe_config->moe_ffn_hidden_size; + } + CHECK_GT(ffn_hidden, 0); + // c_fc: ColumnParallel (input full, output parallel) modules_[kCFcLayerName] = std::make_shared( /*in_features=*/config.n_embd, /*out_features=*/ffn_hidden, diff --git a/infini_train/src/nn/modules/transformer/moe/experts.cc b/infini_train/src/nn/modules/transformer/moe/experts.cc new file mode 100644 index 00000000..8f3b1be8 --- /dev/null +++ b/infini_train/src/nn/modules/transformer/moe/experts.cc @@ -0,0 +1,50 @@ +#include "infini_train/include/nn/modules/transformer/moe/experts.h" + +#include +#include +#include + +#include "glog/logging.h" + +#include "infini_train/include/nn/modules/transformer/mlp.h" +#include "infini_train/include/nn/modules/transformer/moe/moe_utils.h" +#include "infini_train/include/tensor.h" + +namespace infini_train::nn::moe { + +SequentialMLP::SequentialMLP(const TransformerConfig &config) : CloneableModule(kType), config_(config) { + const auto &moe_config = RequireMoEConfig(config_); + CHECK(moe_config.expert_impl == MoEExpertImpl::kSequential); + CHECK_EQ(moe_config.expert_parallel_size, 1) + << "Current InfiniTrain MoE implementation supports expert_parallel_size=1 only"; + CHECK(moe_config.dispatcher_type == MoEDispatcherType::kLocal) + << "Current InfiniTrain MoE implementation supports local dispatch only"; + + num_local_experts_ = moe_config.num_experts; + CHECK_GT(num_local_experts_, 0); + + for (int64_t expert_idx = 0; expert_idx < num_local_experts_; ++expert_idx) { + modules_[std::string(kExpertNamePrefix) + std::to_string(expert_idx)] = std::make_shared(config_); + } +} + +std::vector> SequentialMLP::Forward(const std::vector> &input_tensors) { + CHECK_EQ(input_tensors.size(), 2); + auto hidden_states = input_tensors[0]; + auto routing_probs = input_tensors[1]; + CHECK_EQ(routing_probs->Dims().back(), num_local_experts_); + + std::shared_ptr output = nullptr; + const int64_t expert_dim = static_cast(routing_probs->Dims().size()) - 1; + for (int64_t expert_idx = 0; expert_idx < num_local_experts_; ++expert_idx) { + auto expert_name = std::string(kExpertNamePrefix) + std::to_string(expert_idx); + auto expert_output = (*modules_.at(expert_name))({hidden_states})[0]; + auto expert_prob = routing_probs->Slice(expert_dim, expert_idx, expert_idx + 1); + auto weighted_output = expert_output * expert_prob; + output = output == nullptr ? weighted_output : output + weighted_output; + } + + return {output}; +} + +} // namespace infini_train::nn::moe diff --git a/infini_train/src/nn/modules/transformer/moe/moe_layer.cc b/infini_train/src/nn/modules/transformer/moe/moe_layer.cc new file mode 100644 index 00000000..6add37ef --- /dev/null +++ b/infini_train/src/nn/modules/transformer/moe/moe_layer.cc @@ -0,0 +1,33 @@ +#include "infini_train/include/nn/modules/transformer/moe/moe_layer.h" + +#include +#include + +#include "glog/logging.h" + +#include "infini_train/include/nn/modules/transformer/moe/experts.h" +#include "infini_train/include/nn/modules/transformer/moe/moe_utils.h" +#include "infini_train/include/nn/modules/transformer/moe/router.h" +#include "infini_train/include/tensor.h" + +namespace infini_train::nn::moe { + +MoELayer::MoELayer(const TransformerConfig &config) : CloneableModule(kType), config_(config) { + const auto &moe_config = RequireMoEConfig(config_); + CHECK(config_.ffn_type == FFNType::kMoE); + CHECK(moe_config.dispatcher_type == MoEConfig::DispatcherType::kAllGather) + << "Current InfiniTrain MoE implementation supports AllGather dispatcher only"; + + modules_[kRouterLayerName] = std::make_shared(config_); + modules_[kExpertsLayerName] = std::make_shared(config_); +} + +std::vector> MoELayer::Forward(const std::vector> &input_tensors) { + CHECK_EQ(input_tensors.size(), 1); + auto hidden_states = input_tensors[0]; + auto router_output = (*modules_.at(kRouterLayerName))({hidden_states}); + CHECK_EQ(router_output.size(), 2); + return (*modules_.at(kExpertsLayerName))({hidden_states, router_output[0], router_output[1]}); +} + +} // namespace infini_train::nn::moe diff --git a/infini_train/src/nn/modules/transformer/moe/moe_utils.cc b/infini_train/src/nn/modules/transformer/moe/moe_utils.cc new file mode 100644 index 00000000..976e9eff --- /dev/null +++ b/infini_train/src/nn/modules/transformer/moe/moe_utils.cc @@ -0,0 +1,64 @@ +#include "infini_train/include/nn/modules/transformer/moe/moe_utils.h" + +#include "glog/logging.h" + +#include "infini_train/include/autograd/local_token_dispatcher.h" +#include "infini_train/include/autograd/scatter.h" +#include "infini_train/include/autograd/topk.h" +#include "infini_train/include/dispatcher.h" +#include "infini_train/include/nn/functional.h" + +namespace infini_train::nn::moe { + +std::vector> +TopkRoutingWithScoreFunction(const std::shared_ptr &logits, int64_t topk, bool use_pre_softmax, + std::optional scaling_factor, + const MoEConfig::RouterScoreFunction &score_function) { + + // Megatron TopKRouter returns dense tensors: + // routing_probs: [num_tokens, num_experts] + // routing_map: [num_tokens, num_experts], bool + std::shared_ptr top_probs; + std::shared_ptr top_indices; + + if (score_function == MoEConfig::RouterScoreFunction::kSoftmax) { + if (use_pre_softmax) { + auto scores = function::Softmax(logits, -1); + auto topk_function = std::make_shared(topk); + top_probs = topk_function->Apply({scores})[0]; + top_indices = topk_function->TopIndices(); + } else { + auto topk_function = std::make_shared(topk); + auto top_scores = topk_function->Apply({logits})[0]; + top_indices = topk_function->TopIndices(); + top_probs = function::Softmax(top_scores, -1); + } + } else if (score_function == MoEConfig::RouterScoreFunction::kSigmoid) { + auto sigmoid_scores = function::Sigmoid(logits); + auto topk_function = std::make_shared(topk); + top_probs = topk_function->Apply({sigmoid_scores})[0]; + top_indices = topk_function->TopIndices(); + if (topk > 1) { + top_probs = top_probs / (top_probs->Sum(-1, true) + 1e-20f); + } + } else { + LOG(FATAL) << "Unsupported MoE router score function"; + } + + if (scaling_factor.has_value()) { + top_probs = top_probs * scaling_factor.value(); + } + + auto routing_probs = std::make_shared(logits->Dims())->Apply({top_probs, top_indices})[0]; + auto routing_map_values = std::make_shared(top_indices->Equals(top_indices)->To(DataType::kBOOL)); + auto routing_map = Dispatcher::Instance().Call>( + {logits->GetDevice().type(), "ScatterForward"}, routing_map_values, top_indices, logits->Dims()); + return {routing_probs, routing_map}; +} + +const MoEConfig &RequireMoEConfig(const TransformerConfig &config) { + CHECK(config.moe_config.has_value()) << "MoE layer requires TransformerConfig::moe_config"; + return config.moe_config.value(); +} + +} // namespace infini_train::nn::moe diff --git a/infini_train/src/nn/modules/transformer/moe/router.cc b/infini_train/src/nn/modules/transformer/moe/router.cc new file mode 100644 index 00000000..25208684 --- /dev/null +++ b/infini_train/src/nn/modules/transformer/moe/router.cc @@ -0,0 +1,57 @@ +#include "infini_train/include/nn/modules/transformer/moe/router.h" + +#include +#include + +#include "glog/logging.h" + +#include "infini_train/include/autograd/linear.h" +#include "infini_train/include/autograd/scatter.h" +#include "infini_train/include/autograd/topk.h" +#include "infini_train/include/nn/functional.h" +#include "infini_train/include/nn/init.h" +#include "infini_train/include/nn/modules/transformer/moe/moe_utils.h" +#include "infini_train/include/tensor.h" + +namespace infini_train::nn::moe { + +TopKRouter::TopKRouter(const TransformerConfig &config) : CloneableModule(kType), config_(config) { + const auto &moe_config = RequireMoEConfig(config_); + CHECK_GT(moe_config.num_experts, 0); + CHECK_GT(moe_config.router_topk, 0); + CHECK_LE(moe_config.router_topk, moe_config.num_experts); + parameters_[kParamWeightName] + = std::make_shared(std::vector{moe_config.num_experts, config_.n_embd}, DataType::kFLOAT32, + device_) + ->RequiresGrad(); + init::KaimingUniform(parameters_[kParamWeightName]); + + if (config_.add_bias_linear) { + parameters_[kParamBiasName] + = std::make_shared(std::vector{moe_config.num_experts}, DataType::kFLOAT32, device_) + ->RequiresGrad(); + parameters_[kParamBiasName]->Fill(0.0f); + } +} + +std::vector> TopKRouter::Forward(const std::vector> &input_tensors) { + CHECK_EQ(input_tensors.size(), 1); + std::vector> linear_inputs{input_tensors[0], parameters_.at(kParamWeightName)}; + if (parameters_.contains(kParamBiasName)) { + linear_inputs.push_back(parameters_.at(kParamBiasName)); + } + + auto logits = std::make_shared()->Apply(linear_inputs)[0]; + + const auto &moe_config = RequireMoEConfig(config_); + + auto routing_results + = TopkRoutingWithScoreFunction(logits, moe_config.router_topk, moe_config.router_pre_softmax, + moe_config.router_topk_scaling_factor, moe_config.router_score_function); + + auto routing_probs = routing_results[0]; + auto routing_map = routing_results[1]; + return {routing_probs, routing_map}; +} + +} // namespace infini_train::nn::moe diff --git a/infini_train/src/nn/modules/transformer/transformer.cc b/infini_train/src/nn/modules/transformer/transformer.cc index c7e0f28c..bdcde449 100644 --- a/infini_train/src/nn/modules/transformer/transformer.cc +++ b/infini_train/src/nn/modules/transformer/transformer.cc @@ -15,6 +15,7 @@ #include "infini_train/include/nn/modules/sparse.h" #include "infini_train/include/nn/modules/transformer/causal_self_attention.h" #include "infini_train/include/nn/modules/transformer/mlp.h" +#include "infini_train/include/nn/modules/transformer/moe/moe_layer.h" #include "infini_train/include/nn/modules/transformer/utils.h" #include "infini_train/include/nn/parallel/global.h" #include "infini_train/include/nn/parallel/tensor_parallel.h" @@ -86,7 +87,11 @@ TransformerLayer::TransformerLayer(const nn::TransformerConfig &config) : Clonea } modules_[kAttnLayerName] = std::make_shared(config); - modules_[kMlpLayerName] = std::make_shared(config); + if (config.ffn_type == FFNType::kMoE) { + modules_[kMlpLayerName] = std::make_shared(config); + } else { + modules_[kMlpLayerName] = std::make_shared(config); + } } std::vector> TransformerLayer::Forward(const std::vector> &x) { diff --git a/infini_train/src/nn/parallel/parallel_functional.cc b/infini_train/src/nn/parallel/parallel_functional.cc index 31db11ec..ffd218d7 100644 --- a/infini_train/src/nn/parallel/parallel_functional.cc +++ b/infini_train/src/nn/parallel/parallel_functional.cc @@ -44,7 +44,7 @@ std::vector>> Scatter(const std::vector &devices, int dim) { std::vector>> output_tensors; for (const auto &tensor : input_tensors) { - output_tensors.emplace_back(std::make_shared(devices, dim)->Apply({tensor})); + output_tensors.emplace_back(std::make_shared(devices, dim)->Apply({tensor})); } std::vector>> transposed_output_tensors; transposed_output_tensors.resize(devices.size()); @@ -59,7 +59,7 @@ std::vector> Gather(const std::vector> gather_tensors; for (const auto &tensor : tensors) { gather_tensors.push_back(tensor[0]); } - return std::make_shared(target_device, dim)->Apply(gather_tensors); + return std::make_shared(target_device, dim)->Apply(gather_tensors); } std::vector>> @@ -67,7 +67,7 @@ BroadcastCoalescedReshape(const std::vector> &tensors, c if (tensors.empty()) { return {}; } - auto tensor_copies = std::make_shared(devices)->Apply(tensors); + auto tensor_copies = std::make_shared(devices)->Apply(tensors); std::vector>> tensor_copies_reshaped(devices.size()); for (int replica_idx = 0; replica_idx < devices.size(); ++replica_idx) { tensor_copies_reshaped[replica_idx].resize(tensors.size()); diff --git a/infini_train/src/tensor.cc b/infini_train/src/tensor.cc index f7947030..44860a0f 100644 --- a/infini_train/src/tensor.cc +++ b/infini_train/src/tensor.cc @@ -13,8 +13,9 @@ #include "infini_train/include/autograd/elementwise.h" #include "infini_train/include/autograd/function.h" #include "infini_train/include/autograd/function_hook.h" +#include "infini_train/include/autograd/gather.h" #include "infini_train/include/autograd/matmul.h" -#include "infini_train/include/autograd/misc.h" +#include "infini_train/include/autograd/no_op.h" #include "infini_train/include/autograd/outer.h" #include "infini_train/include/autograd/reduction.h" #include "infini_train/include/autograd/transform.h" @@ -356,7 +357,7 @@ std::vector> Tensor::Split(int split_size, int dim) { std::shared_ptr Tensor::Gather(int dim, const std::shared_ptr &index) { CHECK(GetDevice() == index->GetDevice()) << "index must be on the same device as input."; - return std::make_shared(dim)->Apply({shared_from_this(), index})[0]; + return std::make_shared(dim)->Apply({shared_from_this(), index})[0]; } std::shared_ptr Tensor::RepeatInterleave(int64_t repeat, int64_t dim) { diff --git a/infini_train/src/utils/precision_checker.cc b/infini_train/src/utils/precision_checker.cc index d2cbd16a..2965284e 100644 --- a/infini_train/src/utils/precision_checker.cc +++ b/infini_train/src/utils/precision_checker.cc @@ -193,6 +193,8 @@ std::string FormatShape(const std::vector &shape) { std::string DataTypeToString(DataType dtype) { switch (dtype) { + case DataType::kBOOL: + return "bool"; case DataType::kFLOAT32: return "float32"; case DataType::kFLOAT16: diff --git a/tests/transformer/test_transformer_architecture.cc b/tests/transformer/test_transformer_architecture.cc index ba62e1e3..ad7a9da3 100644 --- a/tests/transformer/test_transformer_architecture.cc +++ b/tests/transformer/test_transformer_architecture.cc @@ -4,10 +4,13 @@ #include "gtest/gtest.h" +#include "infini_train/include/autograd/topk.h" #include "infini_train/include/nn/modules/normalization.h" #include "infini_train/include/nn/modules/sparse.h" #include "infini_train/include/nn/modules/transformer/causal_self_attention.h" #include "infini_train/include/nn/modules/transformer/mlp.h" +#include "infini_train/include/nn/modules/transformer/moe/moe_layer.h" +#include "infini_train/include/nn/modules/transformer/moe/router.h" #include "infini_train/include/nn/modules/transformer/transformer.h" #include "infini_train/include/nn/modules/transformer/transformer_config.h" #include "infini_train/include/nn/modules/transformer/utils.h" @@ -189,4 +192,160 @@ TEST_P(TransformerModuleTest, StateDict) { EXPECT_GE(state_dict.size(), params.size()); } + +TEST_P(TransformerModuleTest, MoELayerTop1) { + nn::TransformerConfig config; + config.n_embd = 32; + config.n_head = 2; + config.n_kv_head = 2; + config.activation_type = nn::MLPType::kGELU; + config.add_bias_linear = true; + config.ffn_type = nn::FFNType::kMoE; + config.moe_config = nn::MoEConfig{}; + config.moe_config->num_experts = 2; + config.moe_config->router_topk = 1; + config.moe_config->router_pre_softmax = true; + + auto moe = std::make_shared(config); + moe->To(GetDevice()); + auto input = std::make_shared(std::vector{2, 4, config.n_embd}, DataType::kFLOAT32, GetDevice()); + input->Uniform(); + + auto output = (*moe)({input}); + ASSERT_EQ(output.size(), 1); + EXPECT_EQ(output[0]->Dims(), input->Dims()); + EXPECT_FALSE(moe->Parameters().empty()); +} + +TEST_P(TransformerModuleTest, MoELayerTop2SwiGLU) { + nn::TransformerConfig config; + config.n_embd = 32; + config.n_head = 2; + config.n_kv_head = 2; + config.activation_type = nn::MLPType::kSwiGLU; + config.add_bias_linear = false; + config.ffn_type = nn::FFNType::kMoE; + config.moe_config = nn::MoEConfig{}; + config.moe_config->num_experts = 4; + config.moe_config->router_topk = 2; + config.moe_config->moe_ffn_hidden_size = 48; + + auto moe = std::make_shared(config); + moe->To(GetDevice()); + auto input = std::make_shared(std::vector{2, 4, config.n_embd}, DataType::kFLOAT32, GetDevice()); + input->Uniform(); + + auto output = (*moe)({input}); + ASSERT_EQ(output.size(), 1); + EXPECT_EQ(output[0]->Dims(), input->Dims()); + + auto state = moe->StateDict(); + ASSERT_TRUE(state.contains("experts.expert_0.c_fc.weight")); + ASSERT_TRUE(state.contains("experts.expert_0.c_fc2.weight")); + ASSERT_TRUE(state.contains("experts.expert_0.c_proj.weight")); + EXPECT_EQ(state.at("experts.expert_0.c_fc.weight")->Dims(), (std::vector{48, config.n_embd})); + EXPECT_EQ(state.at("experts.expert_0.c_fc2.weight")->Dims(), (std::vector{48, config.n_embd})); + EXPECT_EQ(state.at("experts.expert_0.c_proj.weight")->Dims(), (std::vector{config.n_embd, 48})); +} + +TEST_P(TransformerModuleTest, TopKRouterMegatronOutputs) { + nn::TransformerConfig config; + config.n_embd = 32; + config.add_bias_linear = false; + config.ffn_type = nn::FFNType::kMoE; + config.moe_config = nn::MoEConfig{}; + config.moe_config->num_experts = 4; + config.moe_config->router_topk = 2; + + auto router = std::make_shared(config); + router->To(GetDevice()); + auto input = std::make_shared(std::vector{2, 4, config.n_embd}, DataType::kFLOAT32, GetDevice()); + input->Uniform(); + + auto output = (*router)({input}); + ASSERT_EQ(output.size(), 2); + EXPECT_EQ(output[0]->Dims(), (std::vector{2, 4, 4})); + EXPECT_EQ(output[1]->Dims(), (std::vector{2, 4, 4})); + EXPECT_EQ(output[0]->Dtype(), DataType::kFLOAT32); + EXPECT_EQ(output[1]->Dtype(), DataType::kBOOL); +} + +TEST_P(TransformerModuleTest, TopKTorchInterface) { + ONLY_CPU(); + const float data[] = {1.0f, 5.0f, 2.0f, 4.0f, 3.0f, 0.0f}; + auto input = std::make_shared(data, std::vector{2, 3}, DataType::kFLOAT32); + + auto largest_topk = std::make_shared(2, 1, true, true); + auto largest_values = largest_topk->Apply({input})[0]; + auto largest_indices = largest_topk->TopIndices(); + ASSERT_EQ(largest_values->Dims(), (std::vector{2, 2})); + ASSERT_EQ(largest_indices->Dims(), (std::vector{2, 2})); + const auto *largest_values_ptr = static_cast(largest_values->DataPtr()); + const auto *largest_indices_ptr = static_cast(largest_indices->DataPtr()); + EXPECT_FLOAT_EQ(largest_values_ptr[0], 5.0f); + EXPECT_FLOAT_EQ(largest_values_ptr[1], 2.0f); + EXPECT_FLOAT_EQ(largest_values_ptr[2], 4.0f); + EXPECT_FLOAT_EQ(largest_values_ptr[3], 3.0f); + EXPECT_EQ(largest_indices_ptr[0], 1); + EXPECT_EQ(largest_indices_ptr[1], 2); + EXPECT_EQ(largest_indices_ptr[2], 0); + EXPECT_EQ(largest_indices_ptr[3], 1); + + auto smallest_topk = std::make_shared(1, 0, false, true); + auto smallest_values = smallest_topk->Apply({input})[0]; + auto smallest_indices = smallest_topk->TopIndices(); + ASSERT_EQ(smallest_values->Dims(), (std::vector{1, 3})); + ASSERT_EQ(smallest_indices->Dims(), (std::vector{1, 3})); + const auto *smallest_values_ptr = static_cast(smallest_values->DataPtr()); + const auto *smallest_indices_ptr = static_cast(smallest_indices->DataPtr()); + EXPECT_FLOAT_EQ(smallest_values_ptr[0], 1.0f); + EXPECT_FLOAT_EQ(smallest_values_ptr[1], 3.0f); + EXPECT_FLOAT_EQ(smallest_values_ptr[2], 0.0f); + EXPECT_EQ(smallest_indices_ptr[0], 0); + EXPECT_EQ(smallest_indices_ptr[1], 1); + EXPECT_EQ(smallest_indices_ptr[2], 1); +} + +TEST_P(TransformerModuleTest, TopKRouterNormalization) { + ONLY_CPU(); + auto make_router = [](nn::MoEConfig::RouterScoreFunction score_function, bool pre_softmax) { + nn::TransformerConfig config; + config.n_embd = 2; + config.add_bias_linear = false; + config.ffn_type = nn::FFNType::kMoE; + config.moe_config = nn::MoEConfig{}; + config.moe_config->num_experts = 3; + config.moe_config->router_topk = 2; + config.moe_config->router_score_function = score_function; + config.moe_config->router_pre_softmax = pre_softmax; + auto router = std::make_shared(config); + auto weight = router->parameter(nn::moe::TopKRouter::kParamWeightName); + auto *weight_ptr = static_cast(weight->DataPtr()); + weight_ptr[0] = 1.0f; + weight_ptr[1] = 0.0f; + weight_ptr[2] = 2.0f; + weight_ptr[3] = 0.0f; + weight_ptr[4] = 0.0f; + weight_ptr[5] = 0.0f; + return router; + }; + + const float input_data[] = {1.0f, 1.0f}; + auto input = std::make_shared(input_data, std::vector{1, 1, 2}, DataType::kFLOAT32); + + auto softmax_router = make_router(nn::MoEConfig::RouterScoreFunction::kSoftmax, false); + auto softmax_output = (*softmax_router)({input}); + const auto *softmax_probs = static_cast(softmax_output[0]->DataPtr()); + EXPECT_NEAR(softmax_probs[0] + softmax_probs[1] + softmax_probs[2], 1.0f, 1e-5f); + EXPECT_GT(softmax_probs[1], softmax_probs[0]); + EXPECT_FLOAT_EQ(softmax_probs[2], 0.0f); + + auto sigmoid_router = make_router(nn::MoEConfig::RouterScoreFunction::kSigmoid, true); + auto sigmoid_output = (*sigmoid_router)({input}); + const auto *sigmoid_probs = static_cast(sigmoid_output[0]->DataPtr()); + EXPECT_NEAR(sigmoid_probs[0] + sigmoid_probs[1] + sigmoid_probs[2], 1.0f, 1e-5f); + EXPECT_GT(sigmoid_probs[1], sigmoid_probs[0]); + EXPECT_FLOAT_EQ(sigmoid_probs[2], 0.0f); +} + INFINI_TRAIN_REGISTER_TEST(TransformerModuleTest);