diff --git a/.github/workflows/pytest_and_autopublish.yml b/.github/workflows/pytest_and_autopublish.yml index df1d0c5..f68faa6 100644 --- a/.github/workflows/pytest_and_autopublish.yml +++ b/.github/workflows/pytest_and_autopublish.yml @@ -22,8 +22,8 @@ jobs: - run: pip --version # As changes can be pushed to both etils and visu3d, we install etils from `main` branch # If modifying this, also modify `pyproject.toml` - - run: pip install "etils[array_types,edc,enp,epath,epy,etree] @ git+https://github.com/google/etils" - - run: pip install -e .[dev] + - run: pip install --no-cache-dir "etils[array_types,edc,enp,epath,epy,etree] @ git+https://github.com/google/etils" + - run: pip install --no-cache-dir -e .[dev] - run: pip freeze # Run tests (in parallel) diff --git a/dataclass_array/__init__.py b/dataclass_array/__init__.py index bf3a277..a0547be 100644 --- a/dataclass_array/__init__.py +++ b/dataclass_array/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2024 The dataclass_array Authors. +# Copyright 2025 The dataclass_array Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/dataclass_array/array_dataclass.py b/dataclass_array/array_dataclass.py index 6c65ab8..de884b3 100644 --- a/dataclass_array/array_dataclass.py +++ b/dataclass_array/array_dataclass.py @@ -1,4 +1,4 @@ -# Copyright 2024 The dataclass_array Authors. +# Copyright 2025 The dataclass_array Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -1120,6 +1120,14 @@ def full_shape(self) -> DcOrArrayT: @functools.cached_property def inner_shape(self) -> Shape: """Returns the the static shape resolved for the current value.""" + # torch.func.vmap calls `tree_unflatten([0] * num_leaves)` internally, + # messing up shape inference. + if ( + enp.lazy.has_torch + and isinstance(self.value, int) + and self.value == 0 + ): + return () if not self.inner_shape_non_static: return () static_shape = self.full_shape[-len(self.inner_shape_non_static) :] diff --git a/dataclass_array/array_dataclass_test.py b/dataclass_array/array_dataclass_test.py index dfdbcd7..5556238 100644 --- a/dataclass_array/array_dataclass_test.py +++ b/dataclass_array/array_dataclass_test.py @@ -1,4 +1,4 @@ -# Copyright 2024 The dataclass_array Authors. +# Copyright 2025 The dataclass_array Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/dataclass_array/conftest.py b/dataclass_array/conftest.py index d0e4bd9..1eabf57 100644 --- a/dataclass_array/conftest.py +++ b/dataclass_array/conftest.py @@ -1,4 +1,4 @@ -# Copyright 2024 The dataclass_array Authors. +# Copyright 2025 The dataclass_array Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/dataclass_array/field_utils.py b/dataclass_array/field_utils.py index d8630e2..8a053d8 100644 --- a/dataclass_array/field_utils.py +++ b/dataclass_array/field_utils.py @@ -1,4 +1,4 @@ -# Copyright 2024 The dataclass_array Authors. +# Copyright 2025 The dataclass_array Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/dataclass_array/import_test.py b/dataclass_array/import_test.py index 2afbeea..fb39a9d 100644 --- a/dataclass_array/import_test.py +++ b/dataclass_array/import_test.py @@ -1,4 +1,4 @@ -# Copyright 2024 The dataclass_array Authors. +# Copyright 2025 The dataclass_array Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/dataclass_array/ops.py b/dataclass_array/ops.py index da52fa7..73f405d 100644 --- a/dataclass_array/ops.py +++ b/dataclass_array/ops.py @@ -1,4 +1,4 @@ -# Copyright 2024 The dataclass_array Authors. +# Copyright 2025 The dataclass_array Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/dataclass_array/shape_parsing.py b/dataclass_array/shape_parsing.py index e09b213..f5212de 100644 --- a/dataclass_array/shape_parsing.py +++ b/dataclass_array/shape_parsing.py @@ -1,4 +1,4 @@ -# Copyright 2024 The dataclass_array Authors. +# Copyright 2025 The dataclass_array Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/dataclass_array/shape_parsing_test.py b/dataclass_array/shape_parsing_test.py index a4ac186..b24da34 100644 --- a/dataclass_array/shape_parsing_test.py +++ b/dataclass_array/shape_parsing_test.py @@ -1,4 +1,4 @@ -# Copyright 2024 The dataclass_array Authors. +# Copyright 2025 The dataclass_array Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/dataclass_array/testing.py b/dataclass_array/testing.py index 3f6b118..7b04109 100644 --- a/dataclass_array/testing.py +++ b/dataclass_array/testing.py @@ -1,4 +1,4 @@ -# Copyright 2024 The dataclass_array Authors. +# Copyright 2025 The dataclass_array Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/dataclass_array/type_parsing.py b/dataclass_array/type_parsing.py index b1d92fc..92156b9 100644 --- a/dataclass_array/type_parsing.py +++ b/dataclass_array/type_parsing.py @@ -1,4 +1,4 @@ -# Copyright 2024 The dataclass_array Authors. +# Copyright 2025 The dataclass_array Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/dataclass_array/type_parsing_test.py b/dataclass_array/type_parsing_test.py index 9b512ff..6c7ea8b 100644 --- a/dataclass_array/type_parsing_test.py +++ b/dataclass_array/type_parsing_test.py @@ -1,4 +1,4 @@ -# Copyright 2024 The dataclass_array Authors. +# Copyright 2025 The dataclass_array Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/dataclass_array/typing.py b/dataclass_array/typing.py index 1b20495..a435690 100644 --- a/dataclass_array/typing.py +++ b/dataclass_array/typing.py @@ -1,4 +1,4 @@ -# Copyright 2024 The dataclass_array Authors. +# Copyright 2025 The dataclass_array Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/dataclass_array/utils/file_utils.py b/dataclass_array/utils/file_utils.py index 39020c5..4825591 100644 --- a/dataclass_array/utils/file_utils.py +++ b/dataclass_array/utils/file_utils.py @@ -1,4 +1,4 @@ -# Copyright 2024 The dataclass_array Authors. +# Copyright 2025 The dataclass_array Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/dataclass_array/utils/inspect_utils.py b/dataclass_array/utils/inspect_utils.py index a056a60..1c4ba85 100644 --- a/dataclass_array/utils/inspect_utils.py +++ b/dataclass_array/utils/inspect_utils.py @@ -1,4 +1,4 @@ -# Copyright 2024 The dataclass_array Authors. +# Copyright 2025 The dataclass_array Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/dataclass_array/utils/inspect_utils_test.py b/dataclass_array/utils/inspect_utils_test.py index a9a8691..7757432 100644 --- a/dataclass_array/utils/inspect_utils_test.py +++ b/dataclass_array/utils/inspect_utils_test.py @@ -1,4 +1,4 @@ -# Copyright 2024 The dataclass_array Authors. +# Copyright 2025 The dataclass_array Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/dataclass_array/utils/np_utils.py b/dataclass_array/utils/np_utils.py index 4924744..d7cbb57 100644 --- a/dataclass_array/utils/np_utils.py +++ b/dataclass_array/utils/np_utils.py @@ -1,4 +1,4 @@ -# Copyright 2024 The dataclass_array Authors. +# Copyright 2025 The dataclass_array Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/dataclass_array/utils/np_utils_test.py b/dataclass_array/utils/np_utils_test.py index 4a8a8e0..a6aef5c 100644 --- a/dataclass_array/utils/np_utils_test.py +++ b/dataclass_array/utils/np_utils_test.py @@ -1,4 +1,4 @@ -# Copyright 2024 The dataclass_array Authors. +# Copyright 2025 The dataclass_array Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/dataclass_array/utils/py_utils.py b/dataclass_array/utils/py_utils.py index 8fbfe23..9796ab8 100644 --- a/dataclass_array/utils/py_utils.py +++ b/dataclass_array/utils/py_utils.py @@ -1,4 +1,4 @@ -# Copyright 2024 The dataclass_array Authors. +# Copyright 2025 The dataclass_array Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/dataclass_array/utils/tree_utils.py b/dataclass_array/utils/tree_utils.py index b05a17c..13989c0 100644 --- a/dataclass_array/utils/tree_utils.py +++ b/dataclass_array/utils/tree_utils.py @@ -1,4 +1,4 @@ -# Copyright 2024 The dataclass_array Authors. +# Copyright 2025 The dataclass_array Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/dataclass_array/utils/tree_utils_test.py b/dataclass_array/utils/tree_utils_test.py index c7fe3e3..c20fae1 100644 --- a/dataclass_array/utils/tree_utils_test.py +++ b/dataclass_array/utils/tree_utils_test.py @@ -1,4 +1,4 @@ -# Copyright 2024 The dataclass_array Authors. +# Copyright 2025 The dataclass_array Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/dataclass_array/vectorization.py b/dataclass_array/vectorization.py index 7244989..ca4ea0c 100644 --- a/dataclass_array/vectorization.py +++ b/dataclass_array/vectorization.py @@ -1,4 +1,4 @@ -# Copyright 2024 The dataclass_array Authors. +# Copyright 2025 The dataclass_array Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/dataclass_array/vectorization_test.py b/dataclass_array/vectorization_test.py index 8053138..bef7c5c 100644 --- a/dataclass_array/vectorization_test.py +++ b/dataclass_array/vectorization_test.py @@ -1,4 +1,4 @@ -# Copyright 2024 The dataclass_array Authors. +# Copyright 2025 The dataclass_array Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/docs/conf.py b/docs/conf.py index 0dc3ddc..1cf22b2 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -1,4 +1,4 @@ -# Copyright 2024 The dataclass_array Authors. +# Copyright 2025 The dataclass_array Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License.