def _validate_no_state_write_aliased_with_output(prog: "Program") -> None:
    """Reject programs whose model output Var also feeds a state write.

    The pattern that triggers this comes from a forward like::

        merged = self.cache + x
        self.cache[:] = merged
        return merged

    After conversion, ``merged`` is a model output AND is consumed by the
    ``slice_update`` op whose result feeds ``coreml_update_state``. Loading
    such a model in the Core ML runtime currently crashes with a
    segmentation fault (no Python traceback), so the converter rejects it
    here instead with a clear error and workaround.

    Args:
        prog: Program to check. Every entry of ``prog.functions`` is
            inspected through its ``outputs`` and ``operations``.

    Raises:
        ValueError: If the ``value`` input of a ``coreml_update_state`` op
            is a function output, or has a function output among its
            ancestors within the default hop limit of
            ``_find_aliased_output``.
    """
    # NOTE(review): only the ops listed in ``func.operations`` are scanned.
    # If state writes can also live inside nested blocks they are not seen
    # here -- confirm whether that case needs handling.
    for func_name, func in prog.functions.items():
        # Vars are matched by identity, not by name.
        output_var_ids = {id(v) for v in func.outputs}
        for op in func.operations:
            if op.op_type != "coreml_update_state":
                continue
            value_var = op.inputs.get("value")
            if value_var is None:
                continue
            # Walk back from the value of the state write to see whether
            # any of its source Vars is also exposed as a model output.
            offending = _find_aliased_output(value_var, output_var_ids)
            if offending is None:
                continue
            state_var = op.inputs.get("state")
            state_name = state_var.name if state_var is not None else ""
            raise ValueError(
                f"Function {func_name!r} has a model output "
                f"{offending.name!r} that is also a "
                f"source of the value written into state {state_name!r}. "
                "Loading this model in the Core ML runtime currently crashes "
                "with a segmentation fault, so the converter rejects it "
                "here instead. Workaround: return a tensor that does not "
                "feed the state-write chain, e.g. "
                "`return value.sum(dim=-1, keepdim=True)` or "
                "`return value * other_tensor`."
            )


def _find_aliased_output(start_var, output_var_ids, max_depth: int = 32):
    """Return the nearest ancestor of ``start_var`` that is a model output.

    The producer graph is walked backwards breadth-first, one hop level at
    a time, starting at ``start_var`` itself (depth 0). Breadth-first order
    matters: every Var is first expanded at its minimum distance from
    ``start_var``, so the visited set can never hide an ancestor that is
    reachable within ``max_depth`` hops via a shorter path.

    Args:
        start_var: Var-like object to start from. Its producer is read from
            the ``op`` attribute; a missing or ``None`` producer ends that
            branch.
        output_var_ids: Container of ``id()`` values of the output Vars.
        max_depth: Maximum number of hops to walk. Vars exactly
            ``max_depth`` hops away are still tested against
            ``output_var_ids`` but are not expanded further.

    Returns:
        The first matching Var found (``start_var`` itself if it is an
        output), or ``None`` if no output is found within ``max_depth``
        hops.
    """
    visited = set()
    frontier = [start_var]
    depth = 0
    while frontier:
        next_frontier = []
        for var in frontier:
            if id(var) in output_var_ids:
                return var
            if depth >= max_depth or id(var) in visited:
                continue
            visited.add(id(var))
            producing_op = getattr(var, "op", None)
            if producing_op is None:
                continue
            for input_value in producing_op.inputs.values():
                # Inputs may be a Var or a list/tuple of Vars; handle both.
                if isinstance(input_value, (list, tuple)):
                    candidates = input_value
                else:
                    candidates = (input_value,)
                # Anything without an ``op`` attribute is not a Var-like
                # node and is ignored.
                next_frontier.extend(
                    candidate
                    for candidate in candidates
                    if hasattr(candidate, "op")
                )
        frontier = next_frontier
        depth += 1
    return None
# Copyright (c) 2026, Apple Inc. All rights reserved.
#
# Use of this source code is governed by a BSD-3-clause license that can be
# found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause

"""Regression tests for the stateful-output aliasing guard.

Background: the Core ML runtime proxy crashes with a segmentation fault
(no Python traceback) when it loads an mlprogram in which a function output
Var is the same Var that feeds a ``coreml_update_state`` op. The pattern is
easy to write when porting a torch decoder transformer: store into a KV
cache with ``self.cache[:] = merged`` and then ``return merged``.

The converter rejects this case with a ``ValueError`` that names the
offending output and the affected state and points at the workaround. The
tests below pin both the rejection and the accepted workaround.
"""

import numpy as np
import pytest
import torch
import torch.nn as nn

import coremltools as ct


EMBED = 8
MAX_SEQ = 16


class _AliasingNet(nn.Module):
    """Module that returns the same tensor it writes into its cache."""

    def __init__(self):
        super().__init__()
        self.register_buffer("cache", torch.zeros(1, MAX_SEQ, EMBED))

    def forward(self, x):
        # ``merged`` is stored in the cache AND returned unchanged: this is
        # the pattern that crashes the runtime. The local name is kept as
        # is because the test asserts on "merged" in the error message.
        merged = self.cache + x
        self.cache[:] = merged
        return merged


class _NonAliasingNet(nn.Module):
    """Module that writes its cache but returns a reduced tensor."""

    def __init__(self):
        super().__init__()
        self.register_buffer("cache", torch.zeros(1, MAX_SEQ, EMBED))

    def forward(self, x):
        # The value handed back is a reduction of ``merged``, so the
        # returned tensor does not feed the state-write chain and the
        # converter accepts the program.
        merged = self.cache + x
        self.cache[:] = merged
        return merged.sum(dim=-1, keepdim=True)


def _convert(model):
    """Trace ``model`` and convert it to a stateful iOS18 mlprogram.

    The ``cache`` buffer is zeroed before tracing and is declared to the
    converter as a float16 state with the same shape as the input ``x``.
    """
    tensor_shape = (1, MAX_SEQ, EMBED)

    model.eval()
    model.cache.zero_()
    example_input = torch.randn(*tensor_shape)
    traced = torch.jit.trace(model, (example_input,))

    input_spec = ct.TensorType(name="x", shape=tensor_shape, dtype=np.float16)
    cache_spec = ct.StateType(
        wrapped_type=ct.TensorType(shape=tensor_shape, dtype=np.float16),
        name="cache",
    )
    return ct.convert(
        traced,
        inputs=[input_spec],
        states=[cache_spec],
        convert_to="mlprogram",
        minimum_deployment_target=ct.target.iOS18,
    )


class TestStatefulOutputAliasGuard:
    """Pins the converter behaviour for aliased and non-aliased outputs."""

    def test_aliasing_pattern_raises_clear_error(self):
        with pytest.raises(ValueError) as excinfo:
            _convert(_AliasingNet())

        text = str(excinfo.value)
        # The error must name the offending output and the affected state...
        for required in ("merged", "cache"):
            assert required in text
        # ...and point at the workaround, in either capitalisation.
        assert any(spelling in text for spelling in ("Workaround", "workaround"))

    def test_non_aliasing_pattern_converts(self):
        converted = _convert(_NonAliasingNet())
        assert isinstance(converted, ct.models.MLModel)