Skip to content
18 changes: 9 additions & 9 deletions src/bioimageio/core/_resource_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,7 @@ def test_model(
*,
determinism: Literal["seed_only", "full"] = "seed_only",
sha256: Optional[Sha256] = None,
stop_early: bool = True,
stop_early: bool = False,
working_dir: Optional[Union[os.PathLike[str], str]] = None,
**deprecated: Unpack[DeprecatedKwargs],
) -> ValidationSummary:
Expand Down Expand Up @@ -212,7 +212,7 @@ def test_description(
determinism: Literal["seed_only", "full"] = "seed_only",
expected_type: Optional[str] = None,
sha256: Optional[Sha256] = None,
stop_early: bool = True,
stop_early: bool = False,
runtime_env: Union[
Literal["currently-active", "as-described"], Path, BioimageioCondaEnv
] = ("currently-active"),
Expand Down Expand Up @@ -560,7 +560,7 @@ def load_description_and_test(
determinism: Literal["seed_only", "full"] = "seed_only",
expected_type: Literal["model"],
sha256: Optional[Sha256] = None,
stop_early: bool = True,
stop_early: bool = False,
working_dir: Optional[Union[os.PathLike[str], str]] = None,
**deprecated: Unpack[DeprecatedKwargs],
) -> Union[ModelDescr, InvalidDescr]: ...
Expand All @@ -576,7 +576,7 @@ def load_description_and_test(
determinism: Literal["seed_only", "full"] = "seed_only",
expected_type: Literal["dataset"],
sha256: Optional[Sha256] = None,
stop_early: bool = True,
stop_early: bool = False,
working_dir: Optional[Union[os.PathLike[str], str]] = None,
**deprecated: Unpack[DeprecatedKwargs],
) -> Union[DatasetDescr, InvalidDescr]: ...
Expand All @@ -592,7 +592,7 @@ def load_description_and_test(
determinism: Literal["seed_only", "full"] = "seed_only",
expected_type: Optional[str] = None,
sha256: Optional[Sha256] = None,
stop_early: bool = True,
stop_early: bool = False,
working_dir: Optional[Union[os.PathLike[str], str]] = None,
**deprecated: Unpack[DeprecatedKwargs],
) -> Union[LatestResourceDescr, InvalidDescr]: ...
Expand All @@ -608,7 +608,7 @@ def load_description_and_test(
determinism: Literal["seed_only", "full"] = "seed_only",
expected_type: Literal["model"],
sha256: Optional[Sha256] = None,
stop_early: bool = True,
stop_early: bool = False,
working_dir: Optional[Union[os.PathLike[str], str]] = None,
**deprecated: Unpack[DeprecatedKwargs],
) -> Union[AnyModelDescr, InvalidDescr]: ...
Expand All @@ -624,7 +624,7 @@ def load_description_and_test(
determinism: Literal["seed_only", "full"] = "seed_only",
expected_type: Literal["dataset"],
sha256: Optional[Sha256] = None,
stop_early: bool = True,
stop_early: bool = False,
working_dir: Optional[Union[os.PathLike[str], str]] = None,
**deprecated: Unpack[DeprecatedKwargs],
) -> Union[AnyDatasetDescr, InvalidDescr]: ...
Expand All @@ -640,7 +640,7 @@ def load_description_and_test(
determinism: Literal["seed_only", "full"] = "seed_only",
expected_type: Optional[str] = None,
sha256: Optional[Sha256] = None,
stop_early: bool = True,
stop_early: bool = False,
working_dir: Optional[Union[os.PathLike[str], str]] = None,
**deprecated: Unpack[DeprecatedKwargs],
) -> Union[ResourceDescr, InvalidDescr]: ...
Expand All @@ -655,7 +655,7 @@ def load_description_and_test(
determinism: Literal["seed_only", "full"] = "seed_only",
expected_type: Optional[str] = None,
sha256: Optional[Sha256] = None,
stop_early: bool = True,
stop_early: bool = False,
working_dir: Optional[Union[os.PathLike[str], str]] = None,
**deprecated: Unpack[DeprecatedKwargs],
) -> Union[ResourceDescr, InvalidDescr]:
Expand Down
28 changes: 26 additions & 2 deletions src/bioimageio/core/_settings.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
from typing import Literal
import os
import platform
from typing import Literal, Optional

from pydantic import Field
from loguru import logger
from pydantic import Field, field_validator
from typing_extensions import Annotated

from bioimageio.spec._internal._settings import Settings as SpecSettings
Expand All @@ -13,6 +16,27 @@ class Settings(SpecSettings):
Literal["torch", "tensorflow", "jax"], Field(alias="KERAS_BACKEND")
] = "torch"

pytorch_enable_mps_fallback: Annotated[
Optional[bool], Field(alias="PYTORCH_ENABLE_MPS_FALLBACK")
] = None

@field_validator("pytorch_enable_mps_fallback", mode="after")
@classmethod
def _set_default_mps_fallback(cls, value: Optional[bool]):
# pytorch versions up to the 2.6 don't support all operations (esp 3d) on MPS
# this env variable allows falling back to CPU for those networks instead of failing
# see for current status https://github.com/pytorch/pytorch/issues/141287
if (
value is None
and platform.system().lower() == "darwin"
and platform.machine().lower() == "arm64"
):
logger.info("Set environment variable 'PYTORCH_ENABLE_MPS_FALLBACK=1'.")
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
Comment thread
k-dominik marked this conversation as resolved.
return True

return value


settings = Settings()
"""parsed environment variables for bioimageio.spec and bioimageio.core"""
3 changes: 3 additions & 0 deletions src/bioimageio/core/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,9 @@ def __init__(
)
self._data = xr.DataArray(array, dims=axes)

def __repr__(self) -> str:
return f"<Tensor {repr(self._data)}>"

def __array__(self, dtype: DTypeLike = None):
return np.asarray(self._data, dtype=dtype)

Expand Down
Loading