Commit d469b55
The dataclass_array Authors
Fix _ArrayField.inner_shape for scalar values in PyTorch contexts.
Corrected an edge case in `_ArrayField.inner_shape` where, when called on a scalar value (empty `full_shape`) within a PyTorch environment, it now returns an empty shape `()` to prevent shape validation errors, particularly when used inside `torch.vmap`.
PiperOrigin-RevId: 8558868311 parent a21b89b commit d469b55
1 file changed
Lines changed: 3 additions & 5 deletions
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
1122 | 1122 | | |
1123 | 1123 | | |
1124 | 1124 | | |
1125 | | - | |
1126 | | - | |
1127 | | - | |
1128 | | - | |
1129 | | - | |
| 1125 | + | |
| 1126 | + | |
| 1127 | + | |
1130 | 1128 | | |
1131 | 1129 | | |
1132 | 1130 | | |
| |||
0 commit comments