From 8600ee7957c649ba994c8c1a67a83f6478214477 Mon Sep 17 00:00:00 2001 From: Christopher Rackauckas Date: Wed, 5 Jun 2024 18:41:58 -0400 Subject: [PATCH] Fix prewalk_if dropping metadata Update diff.jl Update rewrite-helpers.jl Update utils.jl Revert "Update rewrite-helpers.jl" This reverts commit 9b98a03537c6e6f320baf9dcb58c92fb8c80a2a7. Revert "Update utils.jl" This reverts commit a7f3d281a0db1ad641fb638bbdacd25faba2626b. Revert "Update diff.jl" This reverts commit 9e7b05b65b5a232066d1a64323d90f80081760da. --- src/arrays.jl | 21 ++++++++------------- 1 file changed, 8 insertions(+), 13 deletions(-) diff --git a/src/arrays.jl b/src/arrays.jl index dcba2dd86..084bf8620 100644 --- a/src/arrays.jl +++ b/src/arrays.jl @@ -588,14 +588,6 @@ function replace_by_scalarizing(ex, dict) rule = @rule(getindex(~x, ~~i) => scalarize(~x, (map(j->substitute(j, dict), ~~i)...,))) - simterm = (x, f, args; kws...) -> begin - if metadata(x) !== nothing - maketerm(typeof(x), f, args, symtype(x), metadata(x)) - else - f(args...) - end - end - function rewrite_operation(x) if iscall(x) && iscall(operation(x)) f = operation(x) @@ -612,20 +604,23 @@ function replace_by_scalarizing(ex, dict) prewalk_if(x->!(x isa ArrayOp || x isa ArrayMaker), Rewriters.PassThrough(Chain([rewrite_operation, rule])), - ex, simterm) + ex) end -function prewalk_if(cond, f, t, maketerm) +function prewalk_if(cond, f, t) t′ = cond(t) ? f(t) : return t if iscall(t′) - return maketerm(typeof(t′), TermInterface.head(t′), - map(x->prewalk_if(cond, f, x, maketerm), children(t′))) + if metadata(t′) !== nothing + return maketerm(typeof(t′), TermInterface.head(t′), + map(x->prewalk_if(cond, f, x), children(t′)), symtype(t′), metadata(t′)) + else + TermInterface.head(t′)(map(x->prewalk_if(cond, f, x), children(t′))...) + end else return t′ end end - function scalarize(arr::AbstractArray, idx) arr[idx...] end