Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions infini_train/include/autograd/comm.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ class ProcessGroup;
} // namespace nn::parallel
} // namespace infini_train

namespace infini_train::autograd {
namespace infini_train::autograd::comm {
class Scatter : public autograd::Function {
public:
static constexpr char kType[] = "ScatterFunction";
Expand Down Expand Up @@ -99,4 +99,4 @@ class ReduceAddCoalesced : public autograd::Function {
std::vector<Device> target_gpus_;
int64_t num_inputs_ = 0;
};
} // namespace infini_train::autograd
} // namespace infini_train::autograd::comm
30 changes: 30 additions & 0 deletions infini_train/include/autograd/gather.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
#pragma once

#include <memory>
#include <vector>

#include "infini_train/include/autograd/function.h"

namespace infini_train {
class Tensor;
}

namespace infini_train::autograd {

// Autograd function registered under the type name "GatherFunction".
//
// Declares overrides of the three Function hooks (Forward, SetupContext, Backward); their
// definitions live outside this header.
class Gather : public Function {
public:
    static constexpr char kType[] = "GatherFunction";

    // Stores `dim` (default 0) in dim_. Marked `explicit` so that a bare integer, or an empty
    // brace list, is never implicitly converted into a Gather object.
    explicit Gather(int64_t dim = 0) : Function(kType), dim_(dim) {}

    // Forward pass over `input_tensors`; returns the produced output tensors.
    std::vector<std::shared_ptr<Tensor>> Forward(const std::vector<std::shared_ptr<Tensor>> &input_tensors) override;

    // Hook receiving both the inputs and the outputs of Forward.
    // NOTE(review): presumably where input_dims_ is recorded for Backward -- confirm in the .cc file.
    void SetupContext(const std::vector<std::shared_ptr<Tensor>> &input_tensors,
                      const std::vector<std::shared_ptr<Tensor>> &output_tensors) override;

    // Backward pass: maps gradients w.r.t. the outputs to the returned gradient tensors.
    std::vector<std::shared_ptr<Tensor>> Backward(const std::vector<std::shared_ptr<Tensor>> &grad_outputs) override;

private:
    // Dimension supplied at construction; immutable afterwards.
    const int64_t dim_ = 0;
    // NOTE(review): presumably the shape of the forward input, kept for Backward -- TODO confirm.
    std::vector<int64_t> input_dims_;
};

} // namespace infini_train::autograd
113 changes: 0 additions & 113 deletions infini_train/include/autograd/misc.h

This file was deleted.

30 changes: 30 additions & 0 deletions infini_train/include/autograd/no_op.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
#pragma once

#include <memory>
#include <vector>

#include "infini_train/include/autograd/function.h"

namespace infini_train {
class Tensor;
}

namespace infini_train::autograd {

// Autograd function registered under the type name "NoOpFunction".
//
// Only the interface is declared here; Forward, SetupContext and Backward are defined elsewhere.
class NoOp : public Function {
public:
    static constexpr char kType[] = "NoOpFunction";

    // Copies `output_dims` into the immutable output_dims_ member.
    explicit NoOp(const std::vector<int64_t> &output_dims)
        : Function(kType),
          output_dims_(output_dims) {}

    // Forward pass over `input_tensors`; returns the produced output tensors.
    std::vector<std::shared_ptr<Tensor>>
    Forward(const std::vector<std::shared_ptr<Tensor>> &input_tensors) override;

    // Hook receiving both the inputs and the outputs of Forward.
    void SetupContext(const std::vector<std::shared_ptr<Tensor>> &input_tensors,
                      const std::vector<std::shared_ptr<Tensor>> &output_tensors) override;

    // Backward pass: maps gradients w.r.t. the outputs to the returned gradient tensors.
    std::vector<std::shared_ptr<Tensor>>
    Backward(const std::vector<std::shared_ptr<Tensor>> &grad_outputs) override;

private:
    // Output shape requested at construction time.
    const std::vector<int64_t> output_dims_;
    // NOTE(review): presumably the forward input's shape, saved for Backward -- TODO confirm.
    std::vector<int64_t> input_dims_;
};

} // namespace infini_train::autograd
29 changes: 29 additions & 0 deletions infini_train/include/autograd/scatter.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
#pragma once

#include <memory>
#include <vector>

#include "infini_train/include/autograd/function.h"

namespace infini_train {
class Tensor;
}

namespace infini_train::autograd {

// Autograd function registered under the type name "ScatterFunction".
//
// NOTE(review): autograd::comm::Scatter (autograd/comm.h) declares the very same kType string;
// confirm that nothing relies on kType being unique per class.
class Scatter : public Function {
public:
    static constexpr char kType[] = "ScatterFunction";

    // Copies `output_dims` into output_dims_.
    explicit Scatter(const std::vector<int64_t> &output_dims)
        : Function(kType),
          output_dims_(output_dims) {}

    // Forward pass over `input_tensors`; returns the produced output tensors.
    std::vector<std::shared_ptr<Tensor>>
    Forward(const std::vector<std::shared_ptr<Tensor>> &input_tensors) override;

    // Hook receiving both the inputs and the outputs of Forward.
    void SetupContext(const std::vector<std::shared_ptr<Tensor>> &input_tensors,
                      const std::vector<std::shared_ptr<Tensor>> &output_tensors) override;

    // Backward pass: maps gradients w.r.t. the outputs to the returned gradient tensors.
    std::vector<std::shared_ptr<Tensor>>
    Backward(const std::vector<std::shared_ptr<Tensor>> &grad_outputs) override;

private:
    // Output shape supplied at construction time.
    std::vector<int64_t> output_dims_;
};

} // namespace infini_train::autograd
40 changes: 40 additions & 0 deletions infini_train/include/autograd/topk.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
#pragma once

#include <memory>
#include <vector>

#include "infini_train/include/autograd/function.h"

namespace infini_train {
class Tensor;
}

namespace infini_train::autograd {

// FIXME(dcj): Align this API with torch.topk and return both values and indices from Forward once
// InfiniTrain autograd supports marking individual outputs as non-differentiable. Today indices
// are exposed through TopIndices() to avoid waiting for gradients on metadata outputs.
//
// Autograd function registered under the type name "TopKFunction". Only the interface is declared
// here; the member functions are defined elsewhere.
class TopK : public Function {
public:
    static constexpr char kType[] = "TopKFunction";

    // Stores the four selection parameters as given. Defaults: dim = -1, largest = true,
    // sorted = true.
    explicit TopK(int64_t topk, int64_t dim = -1, bool largest = true, bool sorted = true)
        : Function(kType),
          topk_(topk),
          dim_(dim),
          largest_(largest),
          sorted_(sorted) {}

    // Forward pass over `input_tensors`; returns the produced output tensors.
    std::vector<std::shared_ptr<Tensor>>
    Forward(const std::vector<std::shared_ptr<Tensor>> &input_tensors) override;

    // Hook receiving both the inputs and the outputs of Forward.
    void SetupContext(const std::vector<std::shared_ptr<Tensor>> &input_tensors,
                      const std::vector<std::shared_ptr<Tensor>> &output_tensors) override;

    // Backward pass: maps gradients w.r.t. the outputs to the returned gradient tensors.
    std::vector<std::shared_ptr<Tensor>>
    Backward(const std::vector<std::shared_ptr<Tensor>> &grad_outputs) override;

    // Accessor for the indices tensor (see the FIXME above for why it is not a Forward output).
    std::shared_ptr<Tensor> TopIndices() const;

private:
    int64_t topk_{1};
    int64_t dim_{-1};
    bool largest_{true};
    bool sorted_{true};
    // NOTE(review): presumably populated during Forward and returned by TopIndices() -- confirm.
    std::shared_ptr<Tensor> top_indices_;
    // NOTE(review): presumably the forward input's shape, saved for Backward -- TODO confirm.
    std::vector<int64_t> input_dims_;
};

} // namespace infini_train::autograd
66 changes: 66 additions & 0 deletions infini_train/include/autograd/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -78,4 +78,70 @@ class RepeatInterleave : public Function {
std::vector<int64_t> input_dims_;
};

// Autograd function registered under the type name "SplitFunction".
//
// Declares overrides of the three Function hooks; their definitions live outside this header.
class Split : public Function {
public:
    static constexpr char kType[] = "SplitFunction";

    // Stores `split_size` and `dim` (default 0) unchanged. Marked `explicit` because the
    // constructor is callable with a single argument and must not act as an implicit
    // int64_t -> Split conversion.
    // NOTE(review): `dim` is `int` here, whereas Stack and Concat in this header take int64_t --
    // confirm whether the narrower type is intentional.
    explicit Split(int64_t split_size, int dim = 0) : Function(kType), split_size_(split_size), dim_(dim) {}

    // Forward pass over `input_tensors`; returns the produced output tensors.
    std::vector<std::shared_ptr<Tensor>> Forward(const std::vector<std::shared_ptr<Tensor>> &input_tensors) override;

    // Hook receiving both the inputs and the outputs of Forward.
    void SetupContext(const std::vector<std::shared_ptr<Tensor>> &input_tensors,
                      const std::vector<std::shared_ptr<Tensor>> &output_tensors) override;

    // Backward pass: maps gradients w.r.t. the outputs to the returned gradient tensors.
    std::vector<std::shared_ptr<Tensor>> Backward(const std::vector<std::shared_ptr<Tensor>> &grad_outputs) override;

private:
    // Chunk size supplied at construction; immutable afterwards.
    const int64_t split_size_ = 0;
    // Dimension supplied at construction; immutable afterwards.
    const int dim_ = 0;
    // NOTE(review): presumably the forward input's shape, saved for Backward -- TODO confirm.
    std::vector<int64_t> input_dims_;
};

// Autograd function registered under the type name "StackFunction".
//
// Declares overrides of the three Function hooks; their definitions live outside this header.
class Stack : public Function {
public:
    static constexpr char kType[] = "StackFunction";

    // Stores `dim` in dim_. Marked `explicit` so that an integer is never implicitly converted
    // into a Stack object.
    explicit Stack(int64_t dim) : Function(kType), dim_(dim) {}

    // Forward pass over `input_tensors`; returns the produced output tensors.
    std::vector<std::shared_ptr<Tensor>> Forward(const std::vector<std::shared_ptr<Tensor>> &input_tensors) override;

    // Hook receiving both the inputs and the outputs of Forward.
    void SetupContext(const std::vector<std::shared_ptr<Tensor>> &input_tensors,
                      const std::vector<std::shared_ptr<Tensor>> &output_tensors) override;

    // Backward pass: maps gradients w.r.t. the outputs to the returned gradient tensors.
    std::vector<std::shared_ptr<Tensor>> Backward(const std::vector<std::shared_ptr<Tensor>> &grad_outputs) override;

private:
    // Dimension supplied at construction. Unlike Concat's dim_, this member is not const;
    // NOTE(review): left mutable in case the implementation rewrites it -- confirm.
    int64_t dim_ = 0;
    // NOTE(review): presumably the shape of the forward inputs, saved for Backward -- TODO confirm.
    std::vector<int64_t> input_dims_;
};

// Autograd function registered under the type name "ConcatFunction".
//
// Declares overrides of the three Function hooks; their definitions live outside this header.
class Concat : public Function {
public:
    static constexpr char kType[] = "ConcatFunction";

    // Stores `dim` in dim_. Marked `explicit` so that an integer is never implicitly converted
    // into a Concat object.
    explicit Concat(int64_t dim) : Function(kType), dim_(dim) {}

    // Forward pass over `input_tensors`; returns the produced output tensors.
    std::vector<std::shared_ptr<Tensor>> Forward(const std::vector<std::shared_ptr<Tensor>> &input_tensors) override;

    // Hook receiving both the inputs and the outputs of Forward.
    void SetupContext(const std::vector<std::shared_ptr<Tensor>> &input_tensors,
                      const std::vector<std::shared_ptr<Tensor>> &output_tensors) override;

    // Backward pass: maps gradients w.r.t. the outputs to the returned gradient tensors.
    std::vector<std::shared_ptr<Tensor>> Backward(const std::vector<std::shared_ptr<Tensor>> &grad_outputs) override;

private:
    // Dimension supplied at construction; immutable afterwards.
    const int64_t dim_ = 0;
    // One shape vector per entry.
    // NOTE(review): presumably one entry per forward input, saved for Backward -- TODO confirm.
    std::vector<std::vector<int64_t>> input_dims_list_;
};

// Autograd function registered under the type name "SliceFunction".
//
// Only the interface is declared here; Forward, SetupContext and Backward are defined elsewhere.
class Slice : public Function {
public:
    static constexpr char kType[] = "SliceFunction";

    // Copies the three vectors into the immutable starts_/ends_/steps_ members.
    // NOTE(review): presumably one entry per tensor dimension in each vector -- TODO confirm.
    Slice(const std::vector<int64_t> &starts,
          const std::vector<int64_t> &ends,
          const std::vector<int64_t> &steps)
        : Function(kType),
          starts_(starts),
          ends_(ends),
          steps_(steps) {}

    // Forward pass over `input_tensors`; returns the produced output tensors.
    std::vector<std::shared_ptr<Tensor>>
    Forward(const std::vector<std::shared_ptr<Tensor>> &input_tensors) override;

    // Hook receiving both the inputs and the outputs of Forward.
    void SetupContext(const std::vector<std::shared_ptr<Tensor>> &input_tensors,
                      const std::vector<std::shared_ptr<Tensor>> &output_tensors) override;

    // Backward pass: maps gradients w.r.t. the outputs to the returned gradient tensors.
    std::vector<std::shared_ptr<Tensor>>
    Backward(const std::vector<std::shared_ptr<Tensor>> &grad_outputs) override;

private:
    const std::vector<int64_t> starts_;
    const std::vector<int64_t> ends_;
    const std::vector<int64_t> steps_;
};

} // namespace infini_train::autograd
3 changes: 3 additions & 0 deletions infini_train/include/core/backend_type_map.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,9 @@ template <Device::DeviceType Dev, DataType DType> struct BackendTypeMap;
// -----------------------------------------------------------------------------
#define INFINI_REGISTER_STANDARD_BACKEND_TYPES(DEV) \
namespace infini_train::core { \
template <> struct BackendTypeMap<DEV, DataType::kBOOL> { \
using type = bool; \
}; \
template <> struct BackendTypeMap<DEV, DataType::kUINT8> { \
using type = uint8_t; \
}; \
Expand Down
17 changes: 10 additions & 7 deletions infini_train/include/datatype.h
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ struct alignas(2) BF16 {
// DataType enum and metadata tables
// -----------------------------------------------------------------------------
enum class DataType : int8_t {
kBOOL,
kUINT8,
kINT8,
kUINT16,
Expand All @@ -99,16 +100,18 @@ enum class DataType : int8_t {
};

// Per-DataType size table (the values match the byte widths of the corresponding element types).
// Each DataType appears exactly once; the stale pre-kBOOL rows that repeated every key have been
// dropped. Duplicate keys in an initializer list are ignored on insertion, so the resulting map
// contents are unchanged.
inline const std::unordered_map<DataType, size_t> kDataTypeToSize = {
    {DataType::kBOOL, 1},  {DataType::kUINT8, 1},    {DataType::kINT8, 1},    {DataType::kUINT16, 2},
    {DataType::kINT16, 2}, {DataType::kUINT32, 4},   {DataType::kINT32, 4},   {DataType::kUINT64, 8},
    {DataType::kINT64, 8}, {DataType::kBFLOAT16, 2}, {DataType::kFLOAT16, 2}, {DataType::kFLOAT32, 4},
    {DataType::kFLOAT64, 8},
};

// Short textual name for each DataType. Each DataType appears exactly once; the stale pre-kBOOL
// rows that repeated every key have been dropped. Duplicate keys in an initializer list are
// ignored on insertion, so the resulting map contents are unchanged.
inline const std::unordered_map<DataType, std::string> kDataTypeToDesc = {
    {DataType::kBOOL, "bool"},     {DataType::kUINT8, "uint8"},   {DataType::kINT8, "int8"},
    {DataType::kUINT16, "uint16"}, {DataType::kINT16, "int16"},   {DataType::kUINT32, "uint32"},
    {DataType::kINT32, "int32"},   {DataType::kUINT64, "uint64"}, {DataType::kINT64, "int64"},
    {DataType::kBFLOAT16, "bf16"}, {DataType::kFLOAT16, "fp16"},  {DataType::kFLOAT32, "fp32"},
    {DataType::kFLOAT64, "fp64"},
};

// =============================================================================
Expand Down
Loading
Loading