diff --git a/dataclass_array/array_dataclass.py b/dataclass_array/array_dataclass.py index de884b3..4f9b01e 100644 --- a/dataclass_array/array_dataclass.py +++ b/dataclass_array/array_dataclass.py @@ -18,6 +18,7 @@ import dataclasses import functools +import sys import typing from typing import Any, Callable, ClassVar, Generic, Iterator, Optional, Set, Tuple, Type, TypeVar, Union @@ -1120,14 +1121,7 @@ def full_shape(self) -> DcOrArrayT: @functools.cached_property def inner_shape(self) -> Shape: """Returns the the static shape resolved for the current value.""" - # torch.func.vmap calls `tree_unflatten([0] * num_leaves)` internally, - # messing up shape inference. - if ( - enp.lazy.has_torch - and isinstance(self.value, int) - and self.value == 0 - ): - return () + if not self.inner_shape_non_static: return () static_shape = self.full_shape[-len(self.inner_shape_non_static) :] @@ -1167,6 +1161,14 @@ def is_value_missing(self) -> bool: elif enp.array_spec.is_fake_array(self.value): # `etree.spec_like`, Flax summary, ShapeDtypeStruct, ... compatibility return True + # torch.func.vmap calls `tree_unflatten([0] * num_leaves)` internally. + elif ( + enp.lazy.has_torch + and isinstance(self.value, int) + and self.value == 0 + and _is_called_inside_torch_vmap() + ): + return True return False @property @@ -1194,6 +1196,22 @@ def broadcast_to(self, shape: Shape) -> DcOrArrayT: return self.xnp.broadcast_to(self.value, final_shape) +def _is_called_inside_torch_vmap() -> bool: + """Returns `True` if the current call is inside a `torch.func.vmap`.""" + frame = sys._getframe(1) # pylint: disable=protected-access + + while frame: + code = frame.f_code + if ( + code.co_name == '_broadcast_to_and_flatten' + and code.co_filename.endswith('torch/utils/_pytree.py') + ): + return True + frame = frame.f_back + + return False + + def _make_field_metadata( field: dataclasses.Field[Any], hints: dict[str, TypeAlias], diff --git a/dataclass_array/array_dataclass_test.py b/dataclass_array/array_dataclass_test.py index 5556238..76a26b4 100644 --- a/dataclass_array/array_dataclass_test.py +++ b/dataclass_array/array_dataclass_test.py @@ -262,6 +262,7 @@ def _assert_common(p: dca.DataclassArray, shape: Shape, xnp: enp.NpModule): @pytest.mark.parametrize( 'x, y, shape', [ + (0, 0, ()), (1, 2, ()), ([1, 2], [3, 4], (2,)), ([[1], [2]], [[3], [4]], (2, 1)),