From 023e090a9ce65223ff7df75e558edd74aced13bb Mon Sep 17 00:00:00 2001 From: Anton Volkov Date: Mon, 26 Jan 2026 12:17:26 -0800 Subject: [PATCH 1/2] Add support for tuple of ints for exis keyword --- dpnp/dpnp_iface_manipulation.py | 24 ++++++++-------- dpnp/tests/test_manipulation.py | 51 +++++++++++++++++++++++++++++++-- 2 files changed, 61 insertions(+), 14 deletions(-) diff --git a/dpnp/dpnp_iface_manipulation.py b/dpnp/dpnp_iface_manipulation.py index 9df5278bd16b..dd872485a602 100644 --- a/dpnp/dpnp_iface_manipulation.py +++ b/dpnp/dpnp_iface_manipulation.py @@ -3983,7 +3983,7 @@ def trim_zeros(filt, trim="fb", axis=None): (or index -1). Default: ``"fb"``. - axis : {None, int}, optional + axis : {None, int, tuple of ints}, optional If ``None``, `filt` is cropped such that the smallest bounding box is returned that still contains all values which are not zero. If an `axis` is specified, `filt` will be sliced in that dimension only @@ -4038,11 +4038,14 @@ def trim_zeros(filt, trim="fb", axis=None): raise ValueError(f"unexpected character(s) in `trim`: {trim!r}") nd = filt.ndim - if axis is not None: - axis = normalize_axis_index(axis, nd) + if axis is None: + axis = tuple(range(nd)) + else: + axis = normalize_axis_tuple(axis, nd, argname="axis") - if filt.size == 0: - return filt # no trailing zeros in empty array + # check if an empty array or no trimming requested + if filt.size == 0 or not axis: + return filt non_zero = dpnp.argwhere(filt) if non_zero.size == 0: @@ -4061,13 +4064,10 @@ def trim_zeros(filt, trim="fb", axis=None): else: stop = (None,) * nd - if axis is None: - # trim all axes - sl = tuple(slice(*x) for x in zip(start, stop)) - else: - # only trim single axis - sl = (slice(None),) * axis + (slice(start[axis], stop[axis]),) + (...,) - + sl = tuple( + slice(start[ax], stop[ax]) if ax in axis else slice(None) + for ax in range(nd) + ) return filt[sl] diff --git a/dpnp/tests/test_manipulation.py b/dpnp/tests/test_manipulation.py index 373817466f5b..0ae5bf88d818 100644 --- a/dpnp/tests/test_manipulation.py +++ b/dpnp/tests/test_manipulation.py @@ -1432,6 +1432,8 @@ def test_usm_array(self): class TestTrimZeros: + ALL_TRIMS = ["F", "B", "fb"] + @pytest.mark.parametrize("dtype", get_all_dtypes(no_none=True)) def test_basic(self, dtype): a = numpy.array([0, 0, 1, 0, 2, 3, 4, 0], dtype=dtype) @@ -1443,7 +1445,7 @@ def test_basic(self, dtype): @testing.with_requires("numpy>=2.2") @pytest.mark.parametrize("dtype", get_all_dtypes(no_none=True)) - @pytest.mark.parametrize("trim", ["F", "B", "fb"]) + @pytest.mark.parametrize("trim", ALL_TRIMS) @pytest.mark.parametrize("ndim", [0, 1, 2, 3]) def test_basic_nd(self, dtype, trim, ndim): a = numpy.ones((2,) * ndim, dtype=dtype) @@ -1477,7 +1479,7 @@ def test_all_zero(self, dtype, trim): @testing.with_requires("numpy>=2.2") @pytest.mark.parametrize("dtype", get_all_dtypes(no_none=True)) - @pytest.mark.parametrize("trim", ["F", "B", "fb"]) + @pytest.mark.parametrize("trim", ALL_TRIMS) @pytest.mark.parametrize("ndim", [0, 1, 2, 3]) def test_all_zero_nd(self, dtype, trim, ndim): a = numpy.zeros((3,) * ndim, dtype=dtype) @@ -1496,6 +1498,51 @@ def test_size_zero(self): expected = numpy.trim_zeros(a) assert_array_equal(result, expected) + @testing.with_requires("numpy>=2.4") + @pytest.mark.parametrize( + "shape, axis", + [ + [(5,), None], + [(5,), ()], + [(5,), 0], + [(5, 6), None], + [(5, 6), ()], + [(5, 6), 0], + [(5, 6), (-1,)], + [(5, 6, 7), None], + [(5, 6, 7), ()], + [(5, 6, 7), 1], + [(5, 6, 7), (0, 2)], + [(5, 6, 7, 8), None], + [(5, 6, 7, 8), ()], + [(5, 6, 7, 8), -2], + [(5, 6, 7, 8), (0, 1, 3)], + ], + ) + @pytest.mark.parametrize("trim", ALL_TRIMS) + def test_multiple_axes(self, shape, axis, trim): + # standardize axis to a tuple + if axis is None: + axis = tuple(range(len(shape))) + elif isinstance(axis, int): + axis = (len(shape) + axis if axis < 0 else axis,) + else: + axis = tuple(len(shape) + ax if ax < 0 else ax for ax in axis) + + # populate a random interior slice with nonzero entries + rng = numpy.random.default_rng(4321) + a = numpy.zeros(shape) + start = rng.integers(low=0, high=numpy.array(shape) - 1) + end = rng.integers(low=start + 1, high=shape) + shape = tuple(end - start) + data = 1 + rng.random(shape) + a[tuple(slice(i, j) for i, j in zip(start, end))] = data + ia = dpnp.array(a) + + result = dpnp.trim_zeros(ia, axis=axis, trim=trim) + expected = numpy.trim_zeros(a, axis=axis, trim=trim) + assert_array_equal(result, expected) + @pytest.mark.parametrize( "a", [numpy.array([0, 2**62, 0]), numpy.array([0, 2**63, 0])] ) From f1e6ec58e07199ca3a03809dc9d6449d8230c383 Mon Sep 17 00:00:00 2001 From: Anton Volkov Date: Mon, 26 Jan 2026 12:22:13 -0800 Subject: [PATCH 2/2] Add PR to the changelog --- CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 69b06cb64bf8..d52206d1e5a3 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -45,6 +45,7 @@ Also, that release drops support for Python 3.9, making Python 3.10 the minimum * Updated `dpnp.fix` to reuse `dpnp.trunc` internally [#2722](https://github.com/IntelPython/dpnp/pull/2722) * Changed the build scripts and documentation due to `python setup.py develop` deprecation notice [#2716](https://github.com/IntelPython/dpnp/pull/2716) * Clarified behavior on repeated `axes` in `dpnp.tensordot` and `dpnp.linalg.tensordot` functions [#2733](https://github.com/IntelPython/dpnp/pull/2733) +* Aligned `dpnp.trim_zeros` with NumPy 2.4 to support a tuple of integers passed with `axis` keyword [#2746](https://github.com/IntelPython/dpnp/pull/2746) ### Deprecated