TVM shows inconsistent ONNX Cast(to=BOOL) behavior for NaN depending on how NaN is produced:
- Direct NaN constant:
Constant(NaN) -> Cast returns True (matches ONNX Runtime / PyTorch).
- NaN produced by computation:
x -> (NaN-producing op) -> Cast returns False in TVM, while ONNX Runtime / PyTorch return True.
Expected behavior
Per the ONNX Cast operator spec, when casting from floating point to bool:
- +0.0 / -0.0 → False
- all else → True
Therefore:
- Cast(NaN -> bool) should be True (NaN is neither +0.0 nor -0.0, so it falls under “all else”).
- In this repro, Asin(5.0) is NaN because the real domain of arcsine is [-1, 1], so the final output should be True.
Actual behavior
Taking this model as an example:
Repro model (computed NaN → Cast): Constant(5.0) -> Asin -> Cast(to=BOOL) (opset 18, input-free)
- ONNX Runtime: True
- PyTorch: True
- TVM (Relax, LLVM target): False
We also tried other ways of producing NaN:
- Asin(x) with x = 5.0
- Acos(x) with x = 2.0
- Sqrt(x) with x = -1.0
- Log(x) with x = -1.0
- Div(x, x) with x = 0.0 (0/0)
All of them give the same results as above (ONNX Runtime and PyTorch: True; TVM: False).
Environment
- Operating system: Ubuntu 22.04.4 LTS
- TVM version: 0.23.0dev
- PyTorch version: 2.9.1
- ONNX Runtime version: 1.23.2
- ONNX version: 1.20.0
- Python: 3.11.14
Steps to reproduce
Attachment: model.zip
Download and extract the model, save the script below as cast_compare.py, and run:
python cast_compare.py --model model.onnx
from __future__ import annotations
import argparse
import os
import sys
from pathlib import Path
import numpy as np
import onnx
def _ensure_repo_tvm() -> None:
"""
Avoid mixing:
- repo TVM python (newer)
- site-packages TVM runtime (older)
Force-import TVM from this repo's `tvm/python`, and point TVM to `tvm/build`.
"""
repo_root = Path(__file__).resolve().parents[3]
tvm_python = repo_root / "tvm" / "python"
tvm_build = repo_root / "tvm" / "build"
if tvm_python.exists():
sys.path.insert(0, tvm_python.as_posix())
if "TVM_LIBRARY_PATH" not in os.environ and tvm_build.exists():
os.environ["TVM_LIBRARY_PATH"] = tvm_build.as_posix()
for k in list(sys.modules.keys()):
if k == "tvm" or k.startswith("tvm."):
del sys.modules[k]
def _run_torch() -> bool | None:
try:
import torch
except Exception:
return None
# Directly test the Cast semantics on NaN.
a = torch.tensor(float("nan"), dtype=torch.float32)
y = a.to(torch.bool)
return bool(y.item())
def _run_ort(model_bytes: bytes) -> bool:
    """Execute the serialized, input-free model with ONNX Runtime on CPU.

    Args:
        model_bytes: Raw bytes of the ONNX model file.

    Returns:
        The model's single output, converted to a Python ``bool``.

    Raises:
        RuntimeError: If the session does not return exactly one output.
    """
    import onnxruntime as ort  # type: ignore

    session = ort.InferenceSession(model_bytes, providers=["CPUExecutionProvider"])
    # The repro model has no graph inputs, so the feed dict is empty.
    results = session.run(None, {})
    count = len(results)
    if count != 1:
        raise RuntimeError(f"ORT returned {count} outputs, expected 1")
    (only,) = results
    return bool(np.array(only).item())
def _run_tvm(model_path: Path) -> bool:
    """Import the ONNX model with the Relax frontend, build it for LLVM, and run it.

    Args:
        model_path: Path to the ONNX model file.

    Returns:
        The model's output converted to a Python ``bool``.
    """
    _ensure_repo_tvm()
    import tvm  # type: ignore
    from tvm import relax  # type: ignore
    from tvm.relax.frontend import onnx as rx_onnx  # type: ignore

    graph = onnx.load(model_path.as_posix())
    imported = rx_onnx.from_onnx(graph, shape_dict={})
    # If the frontend hands back a list/tuple, the module is taken to be element 0.
    if isinstance(imported, (list, tuple)):
        imported = imported[0]

    target = tvm.target.Target("llvm")
    default_pipeline = relax.pipeline.get_default_pipeline(target)
    with tvm.transform.PassContext(opt_level=3, config={"tir.enable_debug": False}):
        executable = relax.build(imported, target=target, relax_pipeline=default_pipeline)

    machine = relax.VirtualMachine(executable, tvm.cpu())
    # Stateful calling convention: bind (no) inputs, invoke, then fetch outputs.
    machine.set_input("main")
    machine.invoke_stateful("main")
    result = machine.get_outputs("main")
    if isinstance(result, tuple):
        result = result[0]

    array = result.numpy() if hasattr(result, "numpy") else np.array(result)
    return bool(np.array(array).item())
def main() -> int:
ap = argparse.ArgumentParser()
ap.add_argument("--model", type=Path, default=Path("cast_nan_to_bool.onnx"))
args = ap.parse_args()
model_path = args.model.resolve()
if not model_path.exists():
print("error: model not found:", model_path)
return 1
model_bytes = model_path.read_bytes()
y_ort = _run_ort(model_bytes)
y_torch = _run_torch()
y_tvm = _run_tvm(model_path)
# Minimal output: just the three backend results.
print("ort :", y_ort)
print("torch:", "skip" if y_torch is None else y_torch)
print("tvm :", y_tvm)
return 0
if __name__ == "__main__":
    # Propagate main()'s return value as the process exit status.
    sys.exit(main())
Triage
TVM shows inconsistent ONNX Cast(to=BOOL) behavior for NaN depending on how NaN is produced:
- Direct NaN constant: Constant(NaN) -> Cast returns True (matches ONNX Runtime / PyTorch).
- NaN produced by computation: x -> (NaN-producing op) -> Cast returns False in TVM, while ONNX Runtime / PyTorch return True.

Expected behavior
Per the ONNX Cast operator spec, when casting from floating point to bool: +/-0.0 → False; all else → True. Therefore:
Cast(NaN -> bool) should be True (NaN is not +0.0/-0.0, so it falls under “all else”). Asin(5.0) is NaN because arcsine’s real domain is [-1, 1], so the final output should be True.

Actual behavior
Taking this model as an example:
Repro model (computed NaN → Cast):
Constant(5.0) -> Asin -> Cast(to=BOOL) (opset 18, input-free)
- ONNX Runtime: True
- PyTorch: True
- TVM (Relax, LLVM target): False

We also tried other ways of producing NaN: Asin(x) with x=5.0, Acos(x) with x=2.0, Sqrt(x) with x=-1.0, Log(x) with x=-1.0, and Div(x, x) with x=0.0 (0/0). The results are consistent with the above.
Environment
- Operating system: Ubuntu 22.04.4 LTS
- TVM version: 0.23.0dev
- PyTorch version: 2.9.1
- ONNX Runtime version: 1.23.2
- ONNX version: 1.20.0
- Python: 3.11.14
Steps to reproduce
model.zip
Download the model and run the following code to obtain the results.
python cast_compare.py --model model.onnx
Triage