Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 26 additions & 8 deletions dataclass_array/array_dataclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

import dataclasses
import functools
import sys
import typing
from typing import Any, Callable, ClassVar, Generic, Iterator, Optional, Set, Tuple, Type, TypeVar, Union

Expand Down Expand Up @@ -1120,14 +1121,7 @@ def full_shape(self) -> DcOrArrayT:
@functools.cached_property
def inner_shape(self) -> Shape:
"""Returns the the static shape resolved for the current value."""
# torch.func.vmap calls `tree_unflatten([0] * num_leaves)` internally,
# messing up shape inference.
if (
enp.lazy.has_torch
and isinstance(self.value, int)
and self.value == 0
):
return ()

if not self.inner_shape_non_static:
return ()
static_shape = self.full_shape[-len(self.inner_shape_non_static) :]
Expand Down Expand Up @@ -1167,6 +1161,14 @@ def is_value_missing(self) -> bool:
elif enp.array_spec.is_fake_array(self.value):
# `etree.spec_like`, Flax summary, ShapeDtypeStruct, ... compatibility
return True
# torch.func.vmap calls `tree_unflatten([0] * num_leaves)` internally.
elif (
enp.lazy.has_torch
and isinstance(self.value, int)
and self.value == 0
and _is_called_inside_torch_vmap()
):
return True
return False

@property
Expand Down Expand Up @@ -1194,6 +1196,22 @@ def broadcast_to(self, shape: Shape) -> DcOrArrayT:
return self.xnp.broadcast_to(self.value, final_shape)


def _is_called_inside_torch_vmap() -> bool:
"""Returns `True` if the current call is inside a `torch.func.vmap`."""
frame = sys._getframe(1) # pylint: disable=protected-access

while frame:
code = frame.f_code
if (
code.co_name == '_broadcast_to_and_flatten'
and code.co_filename.endswith('torch/utils/_pytree.py')
):
return True
frame = frame.f_back

return False


def _make_field_metadata(
field: dataclasses.Field[Any],
hints: dict[str, TypeAlias],
Expand Down
1 change: 1 addition & 0 deletions dataclass_array/array_dataclass_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,6 +262,7 @@ def _assert_common(p: dca.DataclassArray, shape: Shape, xnp: enp.NpModule):
@pytest.mark.parametrize(
'x, y, shape',
[
(0, 0, ()),
(1, 2, ()),
([1, 2], [3, 4], (2,)),
([[1], [2]], [[3], [4]], (2, 1)),
Expand Down
Loading