Skip to content

Commit d469b55

Browse files
author
The dataclass_array Authors
committed
Fix _ArrayField.inner_shape for scalar values in PyTorch contexts.
Corrected an edge case in `_ArrayField.inner_shape` where, when called on a scalar value (empty `full_shape`) within a PyTorch environment, it now returns an empty shape `()` to prevent shape validation errors, particularly when used inside `torch.vmap`. PiperOrigin-RevId: 855886831
1 parent a21b89b commit d469b55

1 file changed

Lines changed: 3 additions & 5 deletions

File tree

dataclass_array/array_dataclass.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1122,11 +1122,9 @@ def inner_shape(self) -> Shape:
11221122
"""Returns the the static shape resolved for the current value."""
11231123
# torch.func.vmap calls `tree_unflatten([0] * num_leaves)` internally,
11241124
# messing up shape inference.
1125-
if (
1126-
enp.lazy.has_torch
1127-
and isinstance(self.value, int)
1128-
and self.value == 0
1129-
):
1125+
if not self.full_shape and self.inner_shape_non_static:
1126+
if enp.lazy.has_torch:
1127+
return ()
11301128
return ()
11311129
if not self.inner_shape_non_static:
11321130
return ()

0 commit comments

Comments
 (0)