Skip to content

Commit 37d7b79

Browse files
authored
Fix rewrite of sum(x; init) when x is not a generator (#344)
1 parent 0b37017 commit 37d7b79

2 files changed

Lines changed: 50 additions & 2 deletions

File tree

src/rewrite_generic.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,8 @@ function _is_flatten(expr)
4747
return Meta.isexpr(expr, :call, 2) && Meta.isexpr(expr.args[2], :flatten)
4848
end
4949

50+
_is_generator_or_flatten(expr) = _is_generator(expr) || _is_flatten(expr)
51+
5052
function _is_parameters(expr)
5153
return Meta.isexpr(expr, :call, 3) && Meta.isexpr(expr.args[2], :parameters)
5254
end
@@ -111,7 +113,7 @@ function _rewrite_generic(stack::Expr, expr::Expr)
111113
elseif Meta.isexpr(expr.args[2], :(...))
112114
# If the first argument is a splat.
113115
return esc(expr), false
114-
elseif _is_generator(expr) || _is_flatten(expr) || _is_parameters(expr)
116+
elseif _is_generator_or_flatten(expr) || _is_parameters(expr)
115117
if !(expr.args[1] in (:sum, , :∑))
116118
# We don't know what this is. Return the expression and don't let
117119
# future callers mutate.
@@ -126,7 +128,7 @@ function _rewrite_generic(stack::Expr, expr::Expr)
126128
# not any of the others.
127129
p = expr.args[2]
128130
is_init = length(p.args) == 1 && _is_kwarg(p.args[1], :init)
129-
if is_init && expr.args[3] isa Expr
131+
if is_init && _is_generator_or_flatten(expr.args[3])
130132
# sum(iter ; init) form!
131133
# We rewrite only if `iter` is an Expr; if it's just a Symbol,
132134
# we don't enter this branch.

test/rewrite_generic.jl

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -518,6 +518,52 @@ function test_rewrite_init_symbol()
518518
return
519519
end
520520

521+
function test_issue_343()
522+
x = [1, 3, 5, 7, 8, 2]
523+
y = sum(x)
524+
@test MA.@rewrite(sum(x[:]), move_factors_into_sums = false) == y
525+
@test MA.@rewrite(sum(x[i] for i in 1:6), move_factors_into_sums = false) ==
526+
y
527+
@test MA.@rewrite(
528+
sum(x[i+j] for i in 1:2:6 for j in 0:1),
529+
move_factors_into_sums = false
530+
) == y
531+
@test MA.@rewrite(
532+
sum(x[i+j] for i in 1:2:6, j in 0:1),
533+
move_factors_into_sums = false
534+
) == y
535+
# Turn formatting off here so we preserve `init = 0`
536+
#!format:off
537+
@test MA.@rewrite(sum(x[:], init = 0), move_factors_into_sums = false) == y
538+
@test MA.@rewrite(
539+
sum(x[i] for i in 1:6, init = 0),
540+
move_factors_into_sums = false
541+
) == y
542+
@test MA.@rewrite(
543+
sum(x[i+j] for i in 1:2:6 for j in 0:1, init = 0),
544+
move_factors_into_sums = false
545+
) == y
546+
@test MA.@rewrite(
547+
sum(x[i+j] for i in 1:2:6, j in 0:1, init = 0),
548+
move_factors_into_sums = false
549+
) == y
550+
#!format:on
551+
@test MA.@rewrite(sum(x[:]; init = 0), move_factors_into_sums = false) == y
552+
@test MA.@rewrite(
553+
sum(x[i] for i in 1:6; init = 0),
554+
move_factors_into_sums = false
555+
) == y
556+
@test MA.@rewrite(
557+
sum(x[i+j] for i in 1:2:6 for j in 0:1; init = 0),
558+
move_factors_into_sums = false
559+
) == y
560+
@test MA.@rewrite(
561+
sum(x[i+j] for i in 1:2:6, j in 0:1; init = 0),
562+
move_factors_into_sums = false
563+
) == y
564+
return
565+
end
566+
521567
end # module
522568

523569
TestRewriteGeneric.runtests()

0 commit comments

Comments
 (0)