Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ option(WITH_ASCEND "Enable Ascend backend" OFF)

option(WITH_TORCH "Enable PyTorch C++ backend" OFF)

# Opt-in switch for NineToothed-generated kernels. Further down in this file
# the configure step aborts when this is ON while `WITH_NVIDIA` is not.
option(WITH_NINETOOTHED "Enable NineToothed-generated kernels" OFF)

# Default OFF until CANN's `extract_host_stub.py` path handling is fixed for
# `scikit-build-core` temp-dir builds (triggers `KeyError` on the preprocessed
# object path). Enable explicitly with `-DBUILD_CUSTOM_KERNEL=ON` when the
Expand Down Expand Up @@ -290,6 +292,18 @@ if(_gpu_backend_count GREATER 1)
message(FATAL_ERROR "`WITH_NVIDIA`, `WITH_ILUVATAR`, `WITH_HYGON`, `WITH_METAX`, `WITH_MOORE`, and `WITH_ASCEND` are mutually exclusive. Build one GPU backend at a time.")
endif()

if(WITH_NINETOOTHED)
  # The NineToothed backend is only usable together with the NVIDIA backend.
  if(NOT WITH_NVIDIA)
    message(FATAL_ERROR "`WITH_NINETOOTHED` currently requires `WITH_NVIDIA=ON` because ninetoothed AOT uses caller=`cuda`.")
  endif()

  # User-tunable cache entries that drive NineToothed code generation.
  set(NINETOOTHED_PYTHON_EXECUTABLE
      ""
      CACHE FILEPATH "Python executable used to run ninetoothed code generation")
  set(INFINIOPS_NINETOOTHED_OPS
      "rms_norm"
      CACHE STRING "Semicolon- or comma-separated NineToothed ops to generate")
  set(INFINIOPS_NINETOOTHED_DTYPES
      "float32;float16;bfloat16"
      CACHE STRING "Semicolon- or comma-separated NineToothed dtypes to generate")
  set(INFINIOPS_NINETOOTHED_RMS_NORM_NDIMS
      "2;3"
      CACHE STRING "Semicolon- or comma-separated RmsNorm input ranks to generate with NineToothed")
endif()

if(WITH_NVIDIA)
add_compile_definitions(WITH_NVIDIA=1)
enable_language(CUDA)
Expand Down
107 changes: 107 additions & 0 deletions scripts/generate_ninetoothed_ops.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
import argparse
import importlib.util
import pathlib
import shutil
import sys

import ninetoothed

_PROJECT_DIR = pathlib.Path(__file__).resolve().parents[1]
_DEFAULT_DTYPES = ("float32", "float16", "bfloat16")
_DEFAULT_RMS_NORM_NDIMS = (2, 3)
_OP_MODULES = {
"rms_norm": _PROJECT_DIR
/ "src"
/ "ninetoothed"
/ "ops"
/ "rms_norm"
/ "codegen.py",
}


def _build_manifest(output_dir):
return sorted(
str(path)
for path in pathlib.Path(output_dir).rglob("*.cpp")
if not path.name.endswith(".tmp.cpp")
)


def _write_cmake_manifest(output_dir, sources):
manifest_path = pathlib.Path(output_dir) / "manifest.cmake"
lines = ["set(INFINIOPS_NINETOOTHED_SOURCES"]
lines.extend(f' "{source}"' for source in sources)
lines.append(")")
lines.append("")
lines.append(f'set(INFINIOPS_NINETOOTHED_INCLUDE_DIRS "{output_dir}")')
lines.append("")
manifest_path.write_text("\n".join(lines) + "\n")


def _load_op_module(op):
    """Load the codegen module registered for `op` and return it.

    The module file is looked up in `_OP_MODULES`. Its directory is put at
    the front of `sys.path`, and the module is registered in `sys.modules`
    under the file's stem before its code is executed.
    """
    module_path = _OP_MODULES[op]
    sys.path.insert(0, str(module_path.parent))

    module_spec = importlib.util.spec_from_file_location(
        module_path.stem, module_path
    )
    loaded = importlib.util.module_from_spec(module_spec)
    assert module_spec.loader is not None

    # Register first, then execute, so the module is already visible in
    # `sys.modules` while its top-level code runs.
    sys.modules[module_spec.name] = loaded
    module_spec.loader.exec_module(loaded)

    return loaded


def generate(
    ops,
    *,
    output_dir,
    dtypes=_DEFAULT_DTYPES,
    rms_norm_ndims=_DEFAULT_RMS_NORM_NDIMS,
):
    """Generate the sources of the requested ops and write their manifest.

    Args:
        ops: Names of the ops to generate; each must be a key of
            `_OP_MODULES`.
        output_dir: Directory for the generated files. It is removed
            (removal errors are ignored) and recreated before generation.
        dtypes: Dtype names forwarded to every op's `generate`.
        rms_norm_ndims: RmsNorm ranks forwarded to every op's `generate`.

    Returns:
        The list of generated source paths, as produced by `_build_manifest`.

    Raises:
        ValueError: If `ops` names an op that is not in `_OP_MODULES`.
    """
    rejected = tuple(name for name in ops if name not in _OP_MODULES)

    if rejected:
        raise ValueError(f"unsupported ninetoothed ops: {', '.join(rejected)}")

    destination = pathlib.Path(output_dir)
    # Start from an empty directory so files of a previous run cannot leak
    # into the manifest.
    shutil.rmtree(destination, ignore_errors=True)
    destination.mkdir(parents=True, exist_ok=True)

    op_options = {"dtypes": dtypes, "rms_norm_ndims": rms_norm_ndims}

    for name in ops:
        _load_op_module(name).generate(ninetoothed, destination, **op_options)

    generated = _build_manifest(destination)
    _write_cmake_manifest(destination, generated)

    return generated


def _parse_args():
    """Parse the command-line options of the NineToothed source generator.

    Returns:
        An `argparse.Namespace` with `output_dir`, `ops`, `dtypes`, and
        `rms_norm_ndims`. `rms_norm_ndims` holds integers both when it comes
        from the command line and when the default is used.
    """
    parser = argparse.ArgumentParser(
        description="Generate ninetoothed operator sources for InfiniOps."
    )
    parser.add_argument(
        "--output-dir",
        required=True,
        help="Directory that receives the generated sources and manifest.",
    )
    parser.add_argument(
        "--ops",
        nargs="+",
        default=tuple(_OP_MODULES),
        help="Ops to generate (default: all supported ops).",
    )
    parser.add_argument(
        "--dtypes",
        nargs="+",
        default=_DEFAULT_DTYPES,
        help="Dtype names to generate.",
    )
    # Convert ranks to `int` while parsing: the default is a tuple of ints,
    # and a malformed rank is then reported as a usage error instead of
    # surfacing later from inside an op's code generator.
    parser.add_argument(
        "--rms-norm-ndims",
        nargs="+",
        type=int,
        default=_DEFAULT_RMS_NORM_NDIMS,
        help="RmsNorm input ranks to generate.",
    )

    return parser.parse_args()


def main():
    """Command-line entry point: parse the options and run `generate`."""
    cli = _parse_args()
    selected_dtypes = tuple(cli.dtypes)
    selected_ndims = tuple(cli.rms_norm_ndims)

    generate(
        cli.ops,
        output_dir=cli.output_dir,
        dtypes=selected_dtypes,
        rms_norm_ndims=selected_ndims,
    )


if __name__ == "__main__":
main()
15 changes: 13 additions & 2 deletions scripts/generate_wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -765,11 +765,13 @@ def _index_impl_headers(impl_roots, scan_dirs):
return by_operator


def _get_all_ops(devices, with_torch=False):
def _get_all_ops(devices, with_torch=False, with_ninetoothed=False):
scan_dirs = set(devices)

if with_torch:
scan_dirs.add("torch")
if with_ninetoothed:
scan_dirs.add("ninetoothed")

ops = {}

Expand Down Expand Up @@ -883,6 +885,11 @@ def _dispatch_gen_batch_size():
action="store_true",
help="Include PyTorch C++ backend implementations.",
)
parser.add_argument(
"--with-ninetoothed",
action="store_true",
help="Include NineToothed backend implementations.",
)

args = parser.parse_args()

Expand All @@ -900,7 +907,11 @@ def _dispatch_gen_batch_size():
if ops_json.exists():
ops = json.loads(ops_json.read_text())
else:
ops = _get_all_ops(args.devices, with_torch=args.with_torch)
ops = _get_all_ops(
args.devices,
with_torch=args.with_torch,
with_ninetoothed=args.with_ninetoothed,
)

bind_func_names = []

Expand Down
43 changes: 43 additions & 0 deletions src/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,46 @@ if(WITH_NVIDIA)
)
endif()

if(WITH_NINETOOTHED)
  # Generate the NineToothed kernel sources at configure time and add them to
  # the `infiniops` target.
  find_package(Python COMPONENTS Interpreter REQUIRED)

  # Interpreter precedence: an explicit `NINETOOTHED_PYTHON_EXECUTABLE`, then
  # `_TORCH_PYTHON` when it is set, then the interpreter found above.
  if(NINETOOTHED_PYTHON_EXECUTABLE)
    set(_ninetoothed_python "${NINETOOTHED_PYTHON_EXECUTABLE}")
  elseif(_TORCH_PYTHON)
    set(_ninetoothed_python "${_TORCH_PYTHON}")
  else()
    set(_ninetoothed_python "${Python_EXECUTABLE}")
  endif()
  message(STATUS "NineToothed codegen Python: ${_ninetoothed_python}")

  # Accept comma-separated values as well as CMake lists.
  string(REPLACE "," ";" _ninetoothed_ops "${INFINIOPS_NINETOOTHED_OPS}")
  string(REPLACE "," ";" _ninetoothed_dtypes "${INFINIOPS_NINETOOTHED_DTYPES}")
  string(REPLACE "," ";" _ninetoothed_rms_norm_ndims "${INFINIOPS_NINETOOTHED_RMS_NORM_NDIMS}")

  # Tolerate whitespace around items (e.g. "float32, float16"); without this
  # the generator would receive names with a leading space.
  list(TRANSFORM _ninetoothed_ops STRIP)
  list(TRANSFORM _ninetoothed_dtypes STRIP)
  list(TRANSFORM _ninetoothed_rms_norm_ndims STRIP)

  set(_ninetoothed_generator_script
      "${PROJECT_SOURCE_DIR}/scripts/generate_ninetoothed_ops.py")
  set(_ninetoothed_output_dir "${CMAKE_CURRENT_BINARY_DIR}/ninetoothed")
  set(_ninetoothed_generator_args
      "${_ninetoothed_generator_script}"
      --output-dir "${_ninetoothed_output_dir}"
      --ops ${_ninetoothed_ops}
      --dtypes ${_ninetoothed_dtypes}
      --rms-norm-ndims ${_ninetoothed_rms_norm_ndims})

  # The sources are produced while configuring, so the generator inputs must
  # be configure-time dependencies: editing them re-runs CMake and therefore
  # the generation, instead of leaving stale generated sources in the build.
  set_property(DIRECTORY APPEND PROPERTY CMAKE_CONFIGURE_DEPENDS
               "${_ninetoothed_generator_script}")
  foreach(_ninetoothed_op IN LISTS _ninetoothed_ops)
    set(_ninetoothed_codegen
        "${PROJECT_SOURCE_DIR}/src/ninetoothed/ops/${_ninetoothed_op}/codegen.py")
    # Unknown op names are rejected by the generator itself; only track
    # modules that actually exist.
    if(EXISTS "${_ninetoothed_codegen}")
      set_property(DIRECTORY APPEND PROPERTY CMAKE_CONFIGURE_DEPENDS
                   "${_ninetoothed_codegen}")
    endif()
  endforeach()

  execute_process(
    COMMAND "${_ninetoothed_python}" ${_ninetoothed_generator_args}
    WORKING_DIRECTORY "${PROJECT_SOURCE_DIR}"
    RESULT_VARIABLE _ninetoothed_generation_result
  )

  if(NOT _ninetoothed_generation_result EQUAL 0)
    message(FATAL_ERROR "Generating NineToothed operator sources failed with `${_ninetoothed_python}`. Set `NINETOOTHED_PYTHON_EXECUTABLE` to a Python with `ninetoothed`, `ntops`, `triton`, `sympy`, and CUDA dependencies installed.")
  endif()

  # The manifest written by the generator defines
  # `INFINIOPS_NINETOOTHED_SOURCES` and `INFINIOPS_NINETOOTHED_INCLUDE_DIRS`.
  include("${_ninetoothed_output_dir}/manifest.cmake")
  target_include_directories(infiniops PUBLIC
                             ${INFINIOPS_NINETOOTHED_INCLUDE_DIRS})
  target_sources(infiniops PRIVATE ${INFINIOPS_NINETOOTHED_SOURCES})
endif()

if(WITH_ILUVATAR)
set(ILUVATAR_PATTERNS
"native/cuda/*.cc"
Expand Down Expand Up @@ -480,6 +520,9 @@ if(GENERATE_PYTHON_BINDINGS)
# Forward the optional backend switches to `generate_wrappers.py` through
# `GENERATOR_ARGS`.
if(WITH_TORCH)
list(APPEND GENERATOR_ARGS --with-torch)
endif()
# `--with-ninetoothed` makes the wrapper generator include the NineToothed
# implementations.
if(WITH_NINETOOTHED)
list(APPEND GENERATOR_ARGS --with-ninetoothed)
endif()

execute_process(
COMMAND ${Python_EXECUTABLE} ${PROJECT_SOURCE_DIR}/scripts/generate_wrappers.py ${GENERATOR_ARGS}
Expand Down
73 changes: 73 additions & 0 deletions src/ninetoothed/ops/rms_norm/codegen.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
# Block size passed to `ntops.kernels.rms_norm.premake` for every variant.
_BLOCK_SIZE = 256
# Tensor ranks accepted by `_normalize_ndims`; any other rank is rejected.
_DEFAULT_NDIMS = (2, 3)


def _premake(
    ndim,
    num_normalized_dims,
    input_dtype,
    weight_dtype,
    output_dtype,
):
    """Call `ntops.kernels.rms_norm.premake` for one RmsNorm variant.

    All arguments are forwarded unchanged; the block size is always
    `_BLOCK_SIZE`. Returns whatever the `ntops` premake function returns.
    """
    # Imported here so that importing this module does not itself import
    # `ntops`.
    import ntops

    make = ntops.kernels.rms_norm.premake
    options = {
        "input_dtype": input_dtype,
        "weight_dtype": weight_dtype,
        "output_dtype": output_dtype,
        "block_size": _BLOCK_SIZE,
    }

    return make(ndim, num_normalized_dims, **options)


def _normalize_ndims(values):
    """Convert `values` to a tuple of unique ranks in first-seen order.

    Every value is passed through `int`, so numeric strings are accepted.

    Raises:
        ValueError: If a converted rank is not in `_DEFAULT_NDIMS`.
    """
    # A dict keeps insertion order, so its keys double as an ordered set.
    unique = {}

    for raw in values:
        rank = int(raw)

        if rank not in _DEFAULT_NDIMS:
            raise ValueError(f"`RmsNorm` currently supports rank 2 and 3: {raw!r}")

        unique.setdefault(rank, None)

    return tuple(unique)


def _configs(ninetoothed, dtypes, ndims):
    """Build the variant configurations handed to `ninetoothed.build`.

    One configuration is produced per (rank, dtype) pair, ranks varying
    slowest. Each dtype name is resolved as an attribute of `ninetoothed`,
    and the same dtype is used for input, weight, and output. Every
    configuration is a 3-tuple: an empty tuple, the keyword arguments for
    `_premake`, and an empty dict.
    """

    def variant(rank, dtype):
        premake_kwargs = {
            "ndim": rank,
            "num_normalized_dims": 1,
            "input_dtype": dtype,
            "weight_dtype": dtype,
            "output_dtype": dtype,
        }

        return ((), premake_kwargs, {})

    return tuple(
        variant(rank, getattr(ninetoothed, name))
        for rank in _normalize_ndims(ndims)
        for name in dtypes
    )


def generate(ninetoothed, output_dir, *, dtypes, rms_norm_ndims):
    """Build all requested RmsNorm variants into `output_dir / "rms_norm"`.

    Args:
        ninetoothed: The `ninetoothed` module; used to resolve dtype names
            and to run `build`.
        output_dir: Path object of the parent output directory.
        dtypes: Dtype names to generate.
        rms_norm_ndims: Input ranks to generate.
    """
    target_dir = output_dir / "rms_norm"
    target_dir.mkdir(parents=True, exist_ok=True)

    variants = _configs(ninetoothed, dtypes, rms_norm_ndims)
    build_options = {
        "meta_parameters": None,
        "caller": "cuda",
        "kernel_name": "infiniops_ninetoothed_rms_norm",
        "output_dir": target_dir,
        "lazy": False,
    }

    ninetoothed.build(_premake, variants, **build_options)
70 changes: 70 additions & 0 deletions src/ninetoothed/ops/rms_norm/ninetoothed.h
Comment thread
voltjia marked this conversation as resolved.
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
#ifndef INFINI_OPS_NINETOOTHED_RMS_NORM_H_
#define INFINI_OPS_NINETOOTHED_RMS_NORM_H_

#include <cassert>
#include <cstdint>
#include <vector>

#include "base/rms_norm.h"
#include "data_type.h"
#include "ninetoothed/tensor.h"
#include "rms_norm/infiniops_ninetoothed_rms_norm.h"

namespace infini::ops {

// `RmsNorm` for NVIDIA devices that forwards to the NineToothed-generated
// launcher `launch_infiniops_ninetoothed_rms_norm`.
//
// NOTE(review): the third template argument (`9`) is presumably the
// implementation index assigned to NineToothed -- confirm against the primary
// `Operator` template.
template <>
class Operator<RmsNorm, Device::Type::kNvidia, 9> : public RmsNorm {
 public:
  using RmsNorm::RmsNorm;
  using RmsNorm::operator();

  // Launches the generated kernel for `input`, `weight`, and `eps`, writing
  // to `out`.
  //
  // Preconditions (checked with `assert` only):
  //   - `input`, `weight`, and `out` share one dtype;
  //   - `input` and `out` have the same shape;
  //   - `weight` is 1-D and matches the last dimension of `out`;
  //   - `out` has rank 2 or 3.
  //
  // NOTE(review): every check here, including the launch status, is an
  // `assert` and is compiled out when `NDEBUG` is defined. In that
  // configuration a failed launch is not reported.
  void operator()(const Tensor input, const Tensor weight, float eps,
                  Tensor out) const override {
    assert(input.dtype() == out.dtype() && out.dtype() == weight.dtype() &&
           "operator `RmsNorm` requires all input and output tensors to have "
           "the same dtype");
    assert(input.shape() == out.shape() &&
           "ninetoothed `RmsNorm` requires input and output tensors with the "
           "same shape");
    assert(weight.ndim() == 1 && weight.size(-1) == out.size(-1) &&
           "ninetoothed `RmsNorm` requires a 1D weight matching the last "
           "dimension");
    assert((out.ndim() == 2 || out.ndim() == 3) &&
           "ninetoothed `RmsNorm` currently supports rank-2 and rank-3 "
           "tensors");

    // Bind the shape once so `begin()` and `end()` below refer to the same
    // object even if `shape()` returns by value.
    const auto& out_shape = out.shape();

    // Present `weight` with the shape of `out`: every stride is zero except
    // the last, which keeps the stride of the 1-D weight.
    std::vector<std::uint64_t> weight_sizes(out_shape.begin(),
                                            out_shape.end());
    std::vector<std::int64_t> weight_strides;
    weight_strides.assign(out.ndim(), 0);
    weight_strides.back() =
        weight.strides().empty() ? 1 : weight.strides().back();

    // `eps` and the normalized element count are passed as value-holding
    // tensors that share these zero-filled shape/stride arrays.
    double eps_value = static_cast<double>(eps);
    std::int64_t num_normalized_elements =
        static_cast<std::int64_t>(out.size(-1));
    std::uint64_t empty_shape[1] = {};
    std::int64_t empty_strides[1] = {};

    const int dtype_index = ninetoothed::DataTypeIndex(out.dtype());
    assert(
        dtype_index >= 0 &&
        "ninetoothed `RmsNorm` supports only float16, bfloat16, and float32");

    // The trailing integers are the rank, the number of normalized
    // dimensions (1), and the dtype index repeated for input, weight, and
    // output. NOTE(review): presumably they select the pre-built variant --
    // confirm against the generated header.
    //
    // `result` is only read by the `assert` below; `[[maybe_unused]]` keeps
    // `NDEBUG` builds free of an unused-variable warning.
    [[maybe_unused]] auto result = launch_infiniops_ninetoothed_rms_norm(
        static_cast<NineToothedStream>(stream_), ninetoothed::Tensor(input),
        ninetoothed::Tensor(const_cast<void*>(weight.data()),
                            weight_sizes.data(), weight_strides.data()),
        ninetoothed::Tensor(eps_value, empty_shape, empty_strides),
        ninetoothed::Tensor(out),
        ninetoothed::Tensor(num_normalized_elements, empty_shape,
                            empty_strides),
        static_cast<int>(out.ndim()), 1, dtype_index, dtype_index, dtype_index);

    assert(result == 0 && "ninetoothed `RmsNorm` launch failed");
  }
};

} // namespace infini::ops

#endif
Loading
Loading