Skip to content

Commit

Permalink
Merge pull request #140 from GianlucaFuwa/master
Browse files Browse the repository at this point in the history
Fix undefined nested iterating variable when using `Base.Generator`s and iterating over vector of functions
  • Loading branch information
chriselrod authored Apr 1, 2024
2 parents 1296bb3 + e870b29 commit a096315
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 1 deletion.
8 changes: 7 additions & 1 deletion src/closure.jl
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,11 @@ function extractargs!(

startind = 1
if head === :call
startind = 2
if args[1] isa Symbol
startind = isdefined(mod, args[1]) ? 2 : 1
else
startind = 2
end
elseif head === :(=)
extractargs_equal!(arguments, defined, args, mod)
elseif head (:inbounds, :loopinfo)#, :(->))
Expand All @@ -121,6 +125,8 @@ function extractargs!(
extractargs!(arguments, td, args[1], mod)
extractargs!(arguments, td, args[2], mod)
return
elseif head === :generator
extractargs_equal!(arguments, defined, args[2].args, mod)
elseif (head === :local) || (head === :global)
for (i, arg) in enumerate(args)
if Meta.isexpr(arg, :(=))
Expand Down
32 changes: 32 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,20 @@ function issue25_but_with_strides!(dest, x, y)
dest
end

function issue108!(y::Vector{T1}, x::Vector{T2}) where {T1,T2}
@batch for i in eachindex(y)
y[i] = sum(x[j] for j in 2i-oneunit(i):2i)
end
end

function issue108_comment!(data::Vector{T}, functions) where {T}
@batch for i in eachindex(data)
for f in functions
data[i] += f(data[i])
end
end
end

function issue116!(y::Vector{T}, x::Vector{T}) where {T}
@batch for i in 1:length(x)
y[i] = exp(x[i] + one(T))
Expand Down Expand Up @@ -262,6 +276,24 @@ end
@test reduce(hcat, arrayofarrays) == (xref .*= 2)
end

@testset "Generators and looping over array of functions" begin
x = collect(1:12)
y = zeros(6)
issue108!(y, x)
@test y == [sum(x[j] for j in 2i-oneunit(i):2i) for i in 1:6]

functions = [x -> n*x for n in 1:3]
data = rand(100)
data1 = deepcopy(data)
issue108_comment!(data, functions)
for i in eachindex(data1)
for f in functions
data1[i] += f(data1[i])
end
end
@test data == data1
end

println("Issue 245...")

import Polyester: splitloop, combine, NoLoop, @batch
Expand Down

0 comments on commit a096315

Please sign in to comment.