From ea231e85dcc5da7577e30dbaf8a1ff75acbc0aa5 Mon Sep 17 00:00:00 2001 From: Jiacheng Huang Date: Wed, 20 May 2026 09:59:36 +0800 Subject: [PATCH] feat(nvidia): add ntops rms norm backend --- CMakeLists.txt | 14 +++ scripts/generate_ninetoothed_ops.py | 107 +++++++++++++++++++++ scripts/generate_wrappers.py | 15 ++- src/CMakeLists.txt | 43 +++++++++ src/ninetoothed/ops/rms_norm/codegen.py | 73 ++++++++++++++ src/ninetoothed/ops/rms_norm/ninetoothed.h | 70 ++++++++++++++ src/ninetoothed/tensor.h | 62 ++++++++++++ tests/test_generate_ninetoothed_ops.py | 107 +++++++++++++++++++++ 8 files changed, 489 insertions(+), 2 deletions(-) create mode 100644 scripts/generate_ninetoothed_ops.py create mode 100644 src/ninetoothed/ops/rms_norm/codegen.py create mode 100644 src/ninetoothed/ops/rms_norm/ninetoothed.h create mode 100644 src/ninetoothed/tensor.h create mode 100644 tests/test_generate_ninetoothed_ops.py diff --git a/CMakeLists.txt b/CMakeLists.txt index 3ac4bd40..0baa3bdc 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -23,6 +23,8 @@ option(WITH_ASCEND "Enable Ascend backend" OFF) option(WITH_TORCH "Enable PyTorch C++ backend" OFF) +option(WITH_NINETOOTHED "Enable NineToothed-generated kernels" OFF) + # Default OFF until CANN's `extract_host_stub.py` path handling is fixed for # `scikit-build-core` temp-dir builds (triggers `KeyError` on the preprocessed # object path). Enable explicitly with `-DBUILD_CUSTOM_KERNEL=ON` when the @@ -290,6 +292,18 @@ if(_gpu_backend_count GREATER 1) message(FATAL_ERROR "`WITH_NVIDIA`, `WITH_ILUVATAR`, `WITH_HYGON`, `WITH_METAX`, `WITH_MOORE`, and `WITH_ASCEND` are mutually exclusive. Build one GPU backend at a time.") endif() +if(WITH_NINETOOTHED AND NOT WITH_NVIDIA) + message(FATAL_ERROR "`WITH_NINETOOTHED` currently requires `WITH_NVIDIA=ON` because ninetoothed AOT uses caller=`cuda`.") +endif() + +if(WITH_NINETOOTHED) + # NineToothed code generation configuration. 
+ set(NINETOOTHED_PYTHON_EXECUTABLE "" CACHE FILEPATH "Python executable used to run ninetoothed code generation") + set(INFINIOPS_NINETOOTHED_OPS "rms_norm" CACHE STRING "Semicolon- or comma-separated NineToothed ops to generate") + set(INFINIOPS_NINETOOTHED_DTYPES "float32;float16;bfloat16" CACHE STRING "Semicolon- or comma-separated NineToothed dtypes to generate") + set(INFINIOPS_NINETOOTHED_RMS_NORM_NDIMS "2;3" CACHE STRING "Semicolon- or comma-separated RmsNorm input ranks to generate with NineToothed") +endif() + if(WITH_NVIDIA) add_compile_definitions(WITH_NVIDIA=1) enable_language(CUDA) diff --git a/scripts/generate_ninetoothed_ops.py b/scripts/generate_ninetoothed_ops.py new file mode 100644 index 00000000..61887cd6 --- /dev/null +++ b/scripts/generate_ninetoothed_ops.py @@ -0,0 +1,107 @@ +import argparse +import importlib.util +import pathlib +import shutil +import sys + +import ninetoothed + +_PROJECT_DIR = pathlib.Path(__file__).resolve().parents[1] +_DEFAULT_DTYPES = ("float32", "float16", "bfloat16") +_DEFAULT_RMS_NORM_NDIMS = (2, 3) +_OP_MODULES = { + "rms_norm": _PROJECT_DIR + / "src" + / "ninetoothed" + / "ops" + / "rms_norm" + / "codegen.py", +} + + +def _build_manifest(output_dir): + return sorted( + path.as_posix()  # Forward slashes: CMake treats "\" as an escape. + for path in pathlib.Path(output_dir).rglob("*.cpp") + if not path.name.endswith(".tmp.cpp") + ) + + +def _write_cmake_manifest(output_dir, sources): + output_dir = pathlib.Path(output_dir) + include_dir = output_dir.as_posix()  # CMake treats "\" as an escape. + lines = ["set(INFINIOPS_NINETOOTHED_SOURCES"] + lines.extend(f' "{source}"' for source in sources) + lines.extend([")", ""]) + lines.append(f'set(INFINIOPS_NINETOOTHED_INCLUDE_DIRS "{include_dir}")') + lines.append("") + (output_dir / "manifest.cmake").write_text("\n".join(lines) + "\n") + + +def _load_op_module(op): + path = _OP_MODULES[op] + sys.path.insert(0, str(path.parent)) + spec = importlib.util.spec_from_file_location(path.stem, path) + module = importlib.util.module_from_spec(spec) + assert spec.loader is not None +
sys.modules[spec.name] = module + spec.loader.exec_module(module) + + return module + + +def generate( + ops, + *, + output_dir, + dtypes=_DEFAULT_DTYPES, + rms_norm_ndims=_DEFAULT_RMS_NORM_NDIMS, +): + unknown_ops = tuple(op for op in ops if op not in _OP_MODULES) + + if unknown_ops: + raise ValueError(f"unsupported ninetoothed ops: {', '.join(unknown_ops)}") + + output_dir = pathlib.Path(output_dir) + shutil.rmtree(output_dir, ignore_errors=True) + output_dir.mkdir(parents=True, exist_ok=True) + + for op in ops: + module = _load_op_module(op) + module.generate( + ninetoothed, + output_dir, + dtypes=dtypes, + rms_norm_ndims=rms_norm_ndims, + ) + + sources = _build_manifest(output_dir) + _write_cmake_manifest(output_dir, sources) + + return sources + + +def _parse_args(): + parser = argparse.ArgumentParser( + description="Generate ninetoothed operator sources for InfiniOps." + ) + parser.add_argument("--output-dir", required=True) + parser.add_argument("--ops", nargs="+", default=tuple(_OP_MODULES)) + parser.add_argument("--dtypes", nargs="+", default=_DEFAULT_DTYPES) + parser.add_argument("--rms-norm-ndims", nargs="+", default=_DEFAULT_RMS_NORM_NDIMS) + + return parser.parse_args() + + +def main(): + args = _parse_args() + generate( + args.ops, + output_dir=args.output_dir, + dtypes=tuple(args.dtypes), + rms_norm_ndims=tuple(args.rms_norm_ndims), + ) + + +if __name__ == "__main__": + main() diff --git a/scripts/generate_wrappers.py b/scripts/generate_wrappers.py index 4eaa3474..0302dd55 100644 --- a/scripts/generate_wrappers.py +++ b/scripts/generate_wrappers.py @@ -765,11 +765,13 @@ def _index_impl_headers(impl_roots, scan_dirs): return by_operator -def _get_all_ops(devices, with_torch=False): +def _get_all_ops(devices, with_torch=False, with_ninetoothed=False): scan_dirs = set(devices) if with_torch: scan_dirs.add("torch") + if with_ninetoothed: + scan_dirs.add("ninetoothed") ops = {} @@ -883,6 +885,11 @@ def _dispatch_gen_batch_size(): action="store_true", 
help="Include PyTorch C++ backend implementations.", ) + parser.add_argument( + "--with-ninetoothed", + action="store_true", + help="Include NineToothed backend implementations.", + ) args = parser.parse_args() @@ -900,7 +907,11 @@ def _dispatch_gen_batch_size(): if ops_json.exists(): ops = json.loads(ops_json.read_text()) else: - ops = _get_all_ops(args.devices, with_torch=args.with_torch) + ops = _get_all_ops( + args.devices, + with_torch=args.with_torch, + with_ninetoothed=args.with_ninetoothed, + ) bind_func_names = [] diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 4361ba38..36aaf224 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -49,6 +49,46 @@ if(WITH_NVIDIA) ) endif() +if(WITH_NINETOOTHED) + find_package(Python COMPONENTS Interpreter REQUIRED) + + if(NINETOOTHED_PYTHON_EXECUTABLE) + set(_ninetoothed_python "${NINETOOTHED_PYTHON_EXECUTABLE}") + elseif(_TORCH_PYTHON) + set(_ninetoothed_python "${_TORCH_PYTHON}") + else() + set(_ninetoothed_python "${Python_EXECUTABLE}") + endif() + message(STATUS "NineToothed codegen Python: ${_ninetoothed_python}") + + string(REPLACE "," ";" _ninetoothed_ops "${INFINIOPS_NINETOOTHED_OPS}") + string(REPLACE "," ";" _ninetoothed_dtypes "${INFINIOPS_NINETOOTHED_DTYPES}") + string(REPLACE "," ";" _ninetoothed_rms_norm_ndims "${INFINIOPS_NINETOOTHED_RMS_NORM_NDIMS}") + + set(_ninetoothed_output_dir "${CMAKE_CURRENT_BINARY_DIR}/ninetoothed") + set(_ninetoothed_generator_args + "${PROJECT_SOURCE_DIR}/scripts/generate_ninetoothed_ops.py" + --output-dir "${_ninetoothed_output_dir}" + --ops ${_ninetoothed_ops} + --dtypes ${_ninetoothed_dtypes} + --rms-norm-ndims ${_ninetoothed_rms_norm_ndims}) + + execute_process( + COMMAND "${_ninetoothed_python}" ${_ninetoothed_generator_args} + WORKING_DIRECTORY "${PROJECT_SOURCE_DIR}" + RESULT_VARIABLE _ninetoothed_generation_result + ) + + if(NOT _ninetoothed_generation_result EQUAL 0) + message(FATAL_ERROR "Generating NineToothed operator sources failed with 
`${_ninetoothed_python}`. Set `NINETOOTHED_PYTHON_EXECUTABLE` to a Python with `ninetoothed`, `ntops`, `triton`, `sympy`, and CUDA dependencies installed.") + endif() + + include("${_ninetoothed_output_dir}/manifest.cmake") + target_include_directories(infiniops PUBLIC + ${INFINIOPS_NINETOOTHED_INCLUDE_DIRS}) + target_sources(infiniops PRIVATE ${INFINIOPS_NINETOOTHED_SOURCES}) +endif() + if(WITH_ILUVATAR) set(ILUVATAR_PATTERNS "native/cuda/*.cc" @@ -480,6 +520,9 @@ if(GENERATE_PYTHON_BINDINGS) if(WITH_TORCH) list(APPEND GENERATOR_ARGS --with-torch) endif() + if(WITH_NINETOOTHED) + list(APPEND GENERATOR_ARGS --with-ninetoothed) + endif() execute_process( COMMAND ${Python_EXECUTABLE} ${PROJECT_SOURCE_DIR}/scripts/generate_wrappers.py ${GENERATOR_ARGS} diff --git a/src/ninetoothed/ops/rms_norm/codegen.py b/src/ninetoothed/ops/rms_norm/codegen.py new file mode 100644 index 00000000..e623da31 --- /dev/null +++ b/src/ninetoothed/ops/rms_norm/codegen.py @@ -0,0 +1,73 @@ +_BLOCK_SIZE = 256 +_DEFAULT_NDIMS = (2, 3) + + +def _premake( + ndim, + num_normalized_dims, + input_dtype, + weight_dtype, + output_dtype, +): + import ntops + + return ntops.kernels.rms_norm.premake( + ndim, + num_normalized_dims, + input_dtype=input_dtype, + weight_dtype=weight_dtype, + output_dtype=output_dtype, + block_size=_BLOCK_SIZE, + ) + + +def _normalize_ndims(values): + ndims = [] + + for value in values: + ndim = int(value) + + if ndim not in _DEFAULT_NDIMS: + raise ValueError(f"`RmsNorm` currently supports rank 2 and 3: {value!r}") + + if ndim not in ndims: + ndims.append(ndim) + + return tuple(ndims) + + +def _configs(ninetoothed, dtypes, ndims): + configs = [] + + for ndim in _normalize_ndims(ndims): + for dtype_name in dtypes: + dtype = getattr(ninetoothed, dtype_name) + configs.append( + ( + (), + { + "ndim": ndim, + "num_normalized_dims": 1, + "input_dtype": dtype, + "weight_dtype": dtype, + "output_dtype": dtype, + }, + {}, + ) + ) + + return tuple(configs) + + +def 
generate(ninetoothed, output_dir, *, dtypes, rms_norm_ndims): + variant_dir = output_dir / "rms_norm" + variant_dir.mkdir(parents=True, exist_ok=True) + ninetoothed.build( + _premake, + _configs(ninetoothed, dtypes, rms_norm_ndims), + meta_parameters=None, + caller="cuda", + kernel_name="infiniops_ninetoothed_rms_norm", + output_dir=variant_dir, + lazy=False, + ) diff --git a/src/ninetoothed/ops/rms_norm/ninetoothed.h b/src/ninetoothed/ops/rms_norm/ninetoothed.h new file mode 100644 index 00000000..a8934002 --- /dev/null +++ b/src/ninetoothed/ops/rms_norm/ninetoothed.h @@ -0,0 +1,70 @@ +#ifndef INFINI_OPS_NINETOOTHED_RMS_NORM_H_ +#define INFINI_OPS_NINETOOTHED_RMS_NORM_H_ + +#include <cassert> +#include <cstdint> +#include <vector> + +#include "base/rms_norm.h" +#include "data_type.h" +#include "ninetoothed/tensor.h" +#include "rms_norm/infiniops_ninetoothed_rms_norm.h" + +namespace infini::ops { + +template <> +class Operator : public RmsNorm { + public: + using RmsNorm::RmsNorm; + using RmsNorm::operator(); + + void operator()(const Tensor input, const Tensor weight, float eps, + Tensor out) const override { + assert(input.dtype() == out.dtype() && out.dtype() == weight.dtype() && + "operator `RmsNorm` requires all input and output tensors to have " + "the same dtype"); + assert(input.shape() == out.shape() && + "ninetoothed `RmsNorm` requires input and output tensors with the " + "same shape"); + assert(weight.ndim() == 1 && weight.size(-1) == out.size(-1) && + "ninetoothed `RmsNorm` requires a 1D weight matching the last " + "dimension"); + assert((out.ndim() == 2 || out.ndim() == 3) && + "ninetoothed `RmsNorm` currently supports rank-2 and rank-3 " + "tensors"); + + std::vector<std::uint64_t> weight_sizes; + std::vector<std::int64_t> weight_strides; + double eps_value = static_cast<double>(eps); + std::int64_t num_normalized_elements = + static_cast<std::int64_t>(out.size(-1)); + std::uint64_t empty_shape[1] = {}; + std::int64_t empty_strides[1] = {}; + + weight_sizes.assign(out.shape().begin(), out.shape().end()); +
weight_strides.assign(out.ndim(), 0); + weight_strides.back() = + weight.strides().empty() ? 1 : weight.strides().back(); + + const int dtype_index = ninetoothed::DataTypeIndex(out.dtype()); + assert( + dtype_index >= 0 && + "ninetoothed `RmsNorm` supports only float16, bfloat16, and float32"); + + auto result = launch_infiniops_ninetoothed_rms_norm( + static_cast(stream_), ninetoothed::Tensor(input), + ninetoothed::Tensor(const_cast<void*>(weight.data()), + weight_sizes.data(), weight_strides.data()), + ninetoothed::Tensor(eps_value, empty_shape, empty_strides), + ninetoothed::Tensor(out), + ninetoothed::Tensor(num_normalized_elements, empty_shape, + empty_strides), + static_cast(out.ndim()), 1, dtype_index, dtype_index, dtype_index); + + assert(result == 0 && "ninetoothed `RmsNorm` launch failed"); + } +}; + +} // namespace infini::ops + +#endif diff --git a/src/ninetoothed/tensor.h b/src/ninetoothed/tensor.h new file mode 100644 index 00000000..cf9a79a0 --- /dev/null +++ b/src/ninetoothed/tensor.h @@ -0,0 +1,62 @@ +#ifndef INFINI_OPS_NINETOOTHED_TENSOR_H_ +#define INFINI_OPS_NINETOOTHED_TENSOR_H_ + +#include <cstdint> +#include <type_traits> + +#include "data_type.h" +#include "tensor.h" + +namespace infini::ops::ninetoothed { + +inline int DataTypeIndex(DataType dtype) { + switch (dtype) { + case DataType::kFloat16: + return 8; + case DataType::kBFloat16: + return 9; + case DataType::kFloat32: + return 10; + default: + return -1; + } +} + +class Tensor { + public: + explicit Tensor(const ::infini::ops::Tensor& tensor) + : Tensor(const_cast<void*>(tensor.data()), + reinterpret_cast<std::uint64_t*>( + const_cast<::infini::ops::Tensor::Size*>( + tensor.shape().data())), + reinterpret_cast<std::int64_t*>( + const_cast<::infini::ops::Tensor::Stride*>( + tensor.strides().data()))) { + static_assert(sizeof(::infini::ops::Tensor::Size) == sizeof(std::uint64_t)); + static_assert(sizeof(::infini::ops::Tensor::Stride) == + sizeof(std::int64_t)); + static_assert(std::is_unsigned_v<::infini::ops::Tensor::Size>); +
static_assert(std::is_signed_v<::infini::ops::Tensor::Stride>); + } + + Tensor(void* data, std::uint64_t* shape, std::int64_t* strides) + : data_(data), shape_(shape), strides_(strides) {} + + template <typename T> + Tensor(T& value, std::uint64_t* shape, std::int64_t* strides) + : Tensor(static_cast<void*>(&value), shape, strides) {} + + template <typename NineToothedTensor> + operator NineToothedTensor() const { + return NineToothedTensor{data_, shape_, strides_}; + } + + private: + void* data_; + std::uint64_t* shape_; + std::int64_t* strides_; +}; + +} // namespace infini::ops::ninetoothed + +#endif diff --git a/tests/test_generate_ninetoothed_ops.py b/tests/test_generate_ninetoothed_ops.py new file mode 100644 index 00000000..c19dfdff --- /dev/null +++ b/tests/test_generate_ninetoothed_ops.py @@ -0,0 +1,107 @@ +import importlib.util +import pathlib +import sys +import tempfile +import types + + +def _load_generator_module(fake_ninetoothed, monkeypatch): + path = ( + pathlib.Path(__file__).resolve().parents[1] + / "scripts" + / "generate_ninetoothed_ops.py" + ) + spec = importlib.util.spec_from_file_location( + "ninetoothed_codegen_under_test", path + ) + module = importlib.util.module_from_spec(spec) + assert spec.loader is not None + monkeypatch.setitem(sys.modules, "ninetoothed", fake_ninetoothed) + monkeypatch.setitem(sys.modules, spec.name, module) + spec.loader.exec_module(module) + + return module + + +def test_generate_rms_norm_uses_ntops_premake_with_rank_configs(monkeypatch): + calls = [] + + fake_ninetoothed = types.SimpleNamespace( + float32="nt.float32", + ) + fake_ninetoothed.build = lambda *args, **kwargs: calls.append((args, kwargs)) + module = _load_generator_module(fake_ninetoothed, monkeypatch) + + fake_arrangement = object() + fake_application = object() + fake_tensors = object() + premake_calls = [] + + def fake_ntops_premake(*args, **kwargs): + premake_calls.append((args, kwargs)) + return fake_arrangement, fake_application, fake_tensors + + fake_ntops = types.SimpleNamespace( +
kernels=types.SimpleNamespace( + rms_norm=types.SimpleNamespace(premake=fake_ntops_premake) + ) + ) + + monkeypatch.setattr(module, "_build_manifest", lambda output_dir: ["kernel.cpp"]) + + with tempfile.TemporaryDirectory() as tmpdir: + tmp_path = pathlib.Path(tmpdir) + manifest = module.generate( + ["rms_norm"], + output_dir=tmp_path, + dtypes=("float32",), + rms_norm_ndims=(2,), + ) + + assert manifest == ["kernel.cpp"] + assert len(calls) == 1 + + args, kwargs = calls[0] + premake, configs = args + assert configs == ( + ( + (), + { + "ndim": 2, + "num_normalized_dims": 1, + "input_dtype": "nt.float32", + "weight_dtype": "nt.float32", + "output_dtype": "nt.float32", + }, + {}, + ), + ) + assert kwargs["caller"] == "cuda" + assert kwargs["kernel_name"] == "infiniops_ninetoothed_rms_norm" + assert kwargs["output_dir"] == tmp_path / "rms_norm" + assert kwargs["lazy"] is False + assert kwargs["meta_parameters"] is None + + monkeypatch.setitem(sys.modules, "ntops", fake_ntops) + arrangement, application, tensors = premake( + ndim=2, + num_normalized_dims=1, + input_dtype="nt.float32", + weight_dtype="nt.float32", + output_dtype="nt.float32", + ) + + assert arrangement is fake_arrangement + assert application is fake_application + assert tensors is fake_tensors + assert premake_calls == [ + ( + (2, 1), + { + "input_dtype": "nt.float32", + "weight_dtype": "nt.float32", + "output_dtype": "nt.float32", + "block_size": 256, + }, + ) + ]