Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 73 additions & 0 deletions coremltools/converters/mil/backend/mil/load.py
Original file line number Diff line number Diff line change
Expand Up @@ -1050,6 +1050,77 @@ def export(
return model


def _validate_no_state_write_aliased_with_output(prog: Program) -> None:
"""Reject programs whose model output Var also feeds a state write.

The pattern that triggers this comes from a forward like::

merged = self.cache + x
self.cache[:] = merged
return merged

After conversion, ``merged`` is a model output AND is consumed by the
``slice_update`` op whose result feeds ``coreml_update_state``. Loading
such a model in the Core ML runtime currently crashes with a
segmentation fault (no Python traceback), so the converter rejects it
here instead with a clear error and workaround.
"""
for func_name, func in prog.functions.items():
output_var_ids = {id(v) for v in func.outputs}
for op in func.operations:
if op.op_type != "coreml_update_state":
continue
value_var = op.inputs.get("value")
if value_var is None:
continue
# Walk back from the value of the state write to see whether
# any of its source Vars is also exposed as a model output.
offending = _find_aliased_output(value_var, output_var_ids)
if offending is None:
continue
state_var = op.inputs.get("state")
state_name = state_var.name if state_var is not None else "<state>"
raise ValueError(
"Function {!r} has a model output {!r} that is also a "
"source of the value written into state {!r}. Loading "
"this model in the Core ML runtime currently crashes "
"with a segmentation fault, so the converter rejects it "
"here instead. Workaround: return a tensor that does not "
"feed the state-write chain, e.g. "
"`return value.sum(dim=-1, keepdim=True)` or "
"`return value * other_tensor`.".format(
func_name, offending.name, state_name
)
)


def _find_aliased_output(start_var, output_var_ids, max_depth: int = 32):
"""Return the first ancestor of ``start_var`` whose id is in
``output_var_ids``, or ``None`` if no such ancestor exists within
``max_depth`` hops backwards through the op graph."""
seen: set = set()
stack = [(start_var, 0)]
while stack:
var, depth = stack.pop()
if id(var) in output_var_ids:
return var
if id(var) in seen or depth >= max_depth:
continue
seen.add(id(var))
producing_op = getattr(var, "op", None)
if producing_op is None:
continue
for input_value in producing_op.inputs.values():
# Inputs may be a Var or a list of Vars; handle both.
if hasattr(input_value, "op"):
stack.append((input_value, depth + 1))
elif isinstance(input_value, (list, tuple)):
for entry in input_value:
if hasattr(entry, "op"):
stack.append((entry, depth + 1))
return None


def load(
prog: Program,
weights_dir: str,
Expand All @@ -1060,6 +1131,8 @@ def load(
if prog.default_function_name not in prog.functions:
raise ValueError(f"Default function {prog.default_function_name} not found in program")

_validate_no_state_write_aliased_with_output(prog)

# if user has specified "mil_input_types.ClassifierConfig", then add the "classify" op to the prog
classifier_config = kwargs.get("classifier_config", None)
predicted_feature_name, predicted_probabilities_name = None, None
Expand Down
90 changes: 90 additions & 0 deletions coremltools/test/ml_program/test_stateful_output_alias_guard.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
# Copyright (c) 2026, Apple Inc. All rights reserved.
#
# Use of this source code is governed by a BSD-3-clause license that can be
# found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause

"""Regression test for the stateful-output aliasing guard.

Background: the Core ML runtime proxy crashes with a segmentation fault
(no Python traceback) when loading an mlprogram whose function output Var
is the same Var that feeds a ``coreml_update_state`` op. The pattern is
trivial to write when porting a torch decoder transformer: assign to a KV
cache via ``self.cache[:] = merged`` then ``return merged``.

The converter now rejects this case with a clear ``ValueError`` that names
both the offending output and the affected state and points at the
workaround. This test pins both behaviours.
"""

import numpy as np
import pytest
import torch
import torch.nn as nn

import coremltools as ct


EMBED = 8
MAX_SEQ = 16


class _AliasingNet(nn.Module):
def __init__(self):
super().__init__()
self.register_buffer("cache", torch.zeros(1, MAX_SEQ, EMBED))

def forward(self, x):
# The merged tensor is BOTH written to the cache AND returned —
# the runtime-crashing pattern.
merged = self.cache + x
self.cache[:] = merged
return merged


class _NonAliasingNet(nn.Module):
def __init__(self):
super().__init__()
self.register_buffer("cache", torch.zeros(1, MAX_SEQ, EMBED))

def forward(self, x):
# The returned tensor is reduced — it does not feed the state-write
# chain, so the converter accepts the program.
merged = self.cache + x
self.cache[:] = merged
return merged.sum(dim=-1, keepdim=True)


def _convert(model):
model.eval()
model.cache.zero_()
traced = torch.jit.trace(model, (torch.randn(1, MAX_SEQ, EMBED),))
return ct.convert(
traced,
inputs=[ct.TensorType(name="x", shape=(1, MAX_SEQ, EMBED), dtype=np.float16)],
states=[
ct.StateType(
wrapped_type=ct.TensorType(
shape=(1, MAX_SEQ, EMBED), dtype=np.float16
),
name="cache",
)
],
convert_to="mlprogram",
minimum_deployment_target=ct.target.iOS18,
)


class TestStatefulOutputAliasGuard:
def test_aliasing_pattern_raises_clear_error(self):
with pytest.raises(ValueError) as excinfo:
_convert(_AliasingNet())
message = str(excinfo.value)
# Error must name the offending output, the affected state, and
# point at the workaround.
assert "merged" in message
assert "cache" in message
assert "Workaround" in message or "workaround" in message

def test_non_aliasing_pattern_converts(self):
mlmodel = _convert(_NonAliasingNet())
assert isinstance(mlmodel, ct.models.MLModel)