Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ Also, that release drops support for Python 3.9, making Python 3.10 the minimum
* Updated `dpnp.fix` to reuse `dpnp.trunc` internally [#2722](https://github.com/IntelPython/dpnp/pull/2722)
* Changed the build scripts and documentation due to `python setup.py develop` deprecation notice [#2716](https://github.com/IntelPython/dpnp/pull/2716)
* Clarified behavior on repeated `axes` in `dpnp.tensordot` and `dpnp.linalg.tensordot` functions [#2733](https://github.com/IntelPython/dpnp/pull/2733)
* Aligned `dpnp.trim_zeros` with NumPy 2.4 to support a tuple of integers passed with `axis` keyword [#2746](https://github.com/IntelPython/dpnp/pull/2746)

### Deprecated

Expand Down
24 changes: 12 additions & 12 deletions dpnp/dpnp_iface_manipulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -3983,7 +3983,7 @@ def trim_zeros(filt, trim="fb", axis=None):
(or index -1).

Default: ``"fb"``.
axis : {None, int}, optional
axis : {None, int, tuple of ints}, optional
If ``None``, `filt` is cropped such that the smallest bounding box is
returned that still contains all values which are not zero.
If an `axis` is specified, `filt` will be sliced in that dimension only
Expand Down Expand Up @@ -4038,11 +4038,14 @@ def trim_zeros(filt, trim="fb", axis=None):
raise ValueError(f"unexpected character(s) in `trim`: {trim!r}")

nd = filt.ndim
if axis is not None:
axis = normalize_axis_index(axis, nd)
if axis is None:
Copy link
Contributor

@vlad-perevezentsev vlad-perevezentsev Jan 27, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Incorrect shape for empty input with Numpy 2.4

a = numpy.empty((0,3))
a_dp = dpnp.empty((0,3))

numpy.trim_zeros(a,axis=None).shape
# (0, 0)

dpnp.trim_zeros(a_dp, axis=None).shape
# (0, 3)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It seems more as a bug in NumPy implementation which exists before 2.4 release as well.
An empty array does not contain any zero elements, so there is nothing to trim. IMHO it is correct to return the input array unchanged then.

axis = tuple(range(nd))
else:
axis = normalize_axis_tuple(axis, nd, argname="axis")

if filt.size == 0:
return filt # no trailing zeros in empty array
# check if an empty array or no trimming requested
if filt.size == 0 or not axis:
return filt

non_zero = dpnp.argwhere(filt)
if non_zero.size == 0:
Expand All @@ -4061,13 +4064,10 @@ def trim_zeros(filt, trim="fb", axis=None):
else:
stop = (None,) * nd

if axis is None:
# trim all axes
sl = tuple(slice(*x) for x in zip(start, stop))
else:
# only trim single axis
sl = (slice(None),) * axis + (slice(start[axis], stop[axis]),) + (...,)

sl = tuple(
slice(start[ax], stop[ax]) if ax in axis else slice(None)
for ax in range(nd)
)
return filt[sl]


Expand Down
51 changes: 49 additions & 2 deletions dpnp/tests/test_manipulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -1432,6 +1432,8 @@ def test_usm_array(self):


class TestTrimZeros:
ALL_TRIMS = ["F", "B", "fb"]

@pytest.mark.parametrize("dtype", get_all_dtypes(no_none=True))
def test_basic(self, dtype):
a = numpy.array([0, 0, 1, 0, 2, 3, 4, 0], dtype=dtype)
Expand All @@ -1443,7 +1445,7 @@ def test_basic(self, dtype):

@testing.with_requires("numpy>=2.2")
@pytest.mark.parametrize("dtype", get_all_dtypes(no_none=True))
@pytest.mark.parametrize("trim", ["F", "B", "fb"])
@pytest.mark.parametrize("trim", ALL_TRIMS)
@pytest.mark.parametrize("ndim", [0, 1, 2, 3])
def test_basic_nd(self, dtype, trim, ndim):
a = numpy.ones((2,) * ndim, dtype=dtype)
Expand Down Expand Up @@ -1477,7 +1479,7 @@ def test_all_zero(self, dtype, trim):

@testing.with_requires("numpy>=2.2")
@pytest.mark.parametrize("dtype", get_all_dtypes(no_none=True))
@pytest.mark.parametrize("trim", ["F", "B", "fb"])
@pytest.mark.parametrize("trim", ALL_TRIMS)
@pytest.mark.parametrize("ndim", [0, 1, 2, 3])
def test_all_zero_nd(self, dtype, trim, ndim):
a = numpy.zeros((3,) * ndim, dtype=dtype)
Expand All @@ -1496,6 +1498,51 @@ def test_size_zero(self):
expected = numpy.trim_zeros(a)
assert_array_equal(result, expected)

@testing.with_requires("numpy>=2.4")
@pytest.mark.parametrize(
"shape, axis",
[
[(5,), None],
[(5,), ()],
[(5,), 0],
[(5, 6), None],
[(5, 6), ()],
[(5, 6), 0],
[(5, 6), (-1,)],
[(5, 6, 7), None],
[(5, 6, 7), ()],
[(5, 6, 7), 1],
[(5, 6, 7), (0, 2)],
[(5, 6, 7, 8), None],
[(5, 6, 7, 8), ()],
[(5, 6, 7, 8), -2],
[(5, 6, 7, 8), (0, 1, 3)],
],
)
@pytest.mark.parametrize("trim", ALL_TRIMS)
def test_multiple_axes(self, shape, axis, trim):
# standardize axis to a tuple
if axis is None:
axis = tuple(range(len(shape)))
elif isinstance(axis, int):
axis = (len(shape) + axis if axis < 0 else axis,)
else:
axis = tuple(len(shape) + ax if ax < 0 else ax for ax in axis)

# populate a random interior slice with nonzero entries
rng = numpy.random.default_rng(4321)
a = numpy.zeros(shape)
start = rng.integers(low=0, high=numpy.array(shape) - 1)
end = rng.integers(low=start + 1, high=shape)
shape = tuple(end - start)
data = 1 + rng.random(shape)
a[tuple(slice(i, j) for i, j in zip(start, end))] = data
ia = dpnp.array(a)

result = dpnp.trim_zeros(ia, axis=axis, trim=trim)
expected = numpy.trim_zeros(a, axis=axis, trim=trim)
assert_array_equal(result, expected)

@pytest.mark.parametrize(
"a", [numpy.array([0, 2**62, 0]), numpy.array([0, 2**63, 0])]
)
Expand Down
Loading