Skip to content

Commit 83f662f

Browse files
committed
fix: Specialized ReshapedArray dispatch to resolve setindex! ambiguities
1 parent 88dcf9c commit 83f662f

2 files changed

Lines changed: 78 additions & 2 deletions

File tree

src/host/indexing.jl

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -167,8 +167,14 @@ end
167167
function Base._unsafe_setindex!(::IndexStyle, A::WrappedGPUArray, x, Is::Vararg{Union{Real,AbstractArray}, N}) where N
168168
return vectorized_setindex!(A, x, Base.ensure_indexable(Is)...)
169169
end
170-
# And allow one more `ReshapedArray` wrapper to handle the `_maybe_reshape` optimization.
171-
function Base._unsafe_setindex!(::IndexStyle, A::Base.ReshapedArray{<:Any, <:Any, <:WrappedGPUArray}, x, Is::Vararg{Union{Real,AbstractArray}, N}) where N
170+
171+
#Implementation for ReshapedArrays using Cartesian indexing to resolve dispatch ties.
172+
function Base._unsafe_setindex!(::Base.IndexCartesian, A::Base.ReshapedArray{T, N, <:WrappedGPUArray}, x, Is::Vararg{Union{Real, AbstractArray}, M}) where {T, N, M}
173+
return vectorized_setindex!(A, x, Base.ensure_indexable(Is)...)
174+
end
175+
176+
#Implementation for ReshapedArrays using Linear indexing to resolve dispatch ties.
177+
function Base._unsafe_setindex!(::Base.IndexLinear, A::Base.ReshapedArray{T, N, <:WrappedGPUArray}, x, Is::Vararg{Union{Real, AbstractArray}, M}) where {T, N, M}
172178
return vectorized_setindex!(A, x, Base.ensure_indexable(Is)...)
173179
end
174180

test/testsuite/indexing.jl

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -284,3 +284,73 @@ end
284284
@test compare(argmin, AT, -rand(Int, 10))
285285
end
286286
end
287+
288+
@testsuite "indexing combinatorial" (AT, eltypes) -> begin
289+
@testset "Reshaped SubArray dispatch" for T in eltypes
290+
@testset "3D slice assignment" begin
291+
A = AT(ones(T, 4, 4, 4))
292+
@views V = A[:, :, 1:2]
293+
@allowscalar begin
294+
@test_nowarn V .= zero(T)
295+
@test all(Array(V) .== zero(T))
296+
end
297+
end
298+
299+
@testset "Logical mask view (dim = 3) — GPU safe" begin
300+
A = AT(ones(T, 4, 4, 4))
301+
idx = findall(Bool[true, false, true, false])
302+
@views V = A[:, :, idx]
303+
@allowscalar begin
304+
@test_nowarn V .+= T(2)
305+
@test all(Array(V) .== T(3))
306+
end
307+
end
308+
309+
@testset "Nested Reshape" begin
310+
A = AT(ones(T, 4, 4, 4))
311+
V = view(A, 1:2, 1:2, 1:2)
312+
R1 = reshape(V, 4, 2)
313+
R2 = reshape(R1, :)
314+
@allowscalar begin
315+
@test_nowarn R2 .+= one(T)
316+
@test all(Array(R2) .== T(2))
317+
end
318+
end
319+
end
320+
321+
@testset "Permuted and Reinterpreted Views" for T in eltypes
322+
@testset "Reshaped PermutedDims" begin
323+
A = AT(ones(T, 4, 4))
324+
P = PermutedDimsArray(A, (2, 1))
325+
R = reshape(P, :)
326+
@allowscalar begin
327+
@test_nowarn R[1:2] .= zero(T)
328+
# Check the full assigned range.
329+
@test all(Array(R)[1:2] .== zero(T))
330+
end
331+
end
332+
333+
@testset "Reshaped Reinterpreted" begin
334+
T_base = real(T)
335+
if T <: Complex
336+
A = AT(ones(T, 4, 4))
337+
IT = Complex{Int16}
338+
R = reshape(reinterpret(IT, A), :)
339+
@allowscalar begin
340+
@test_nowarn R[1:2] .= zero(IT)
341+
@test all(Array(R)[1:2] .== zero(IT))
342+
end
343+
end
344+
end
345+
end
346+
347+
@testset "Data parity with compare() — GPU safe" for T in eltypes
348+
idx = 2:4
349+
@test compare(AT, rand(T, 8, 8, 8)) do A
350+
# compare() handles CPU/GPU execution no @allowscalar needed here
351+
V = view(A, :, idx, :)
352+
V .+= one(T)
353+
A
354+
end
355+
end
356+
end

0 commit comments

Comments
 (0)