From a12089817c9de5d0d4bd895c21b9c23726b2be7a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Thu, 20 Jul 2023 17:13:44 +0200 Subject: [PATCH] Fix permutation search on `Tensor` with repeated indices --- src/Tensor.jl | 17 +++++++++++++++-- 1 file changed, 15 insertions(+), 2 deletions(-) diff --git a/src/Tensor.jl b/src/Tensor.jl index c683f20..3f86b6e 100644 --- a/src/Tensor.jl +++ b/src/Tensor.jl @@ -60,6 +60,19 @@ function Base.similar(t::Tensor, ::Type{T}, dims::Int64...; labels = labels(t), Tensor(data, labels; meta...) end +function __find_index_permutation(a, b) + labels_b = collect(Union{Missing,Symbol}, b) + + Iterators.map(a) do label + i = findfirst(isequal(label), labels_b) + + # mark element as used + labels_b[i] = missing + + i + end |> collect +end + Base.:(==)(a::AbstractArray, b::Tensor) = isequal(b, a) Base.:(==)(a::Tensor, b::AbstractArray) = isequal(a, b) Base.:(==)(a::Tensor, b::Tensor) = isequal(a, b) @@ -67,7 +80,7 @@ Base.isequal(a::AbstractArray, b::Tensor) = false Base.isequal(a::Tensor, b::AbstractArray) = false function Base.isequal(a::Tensor, b::Tensor) issetequal(labels(a), labels(b)) || return false - perm = [findfirst(==(label), labels(b)) for label in labels(a)] + perm = __find_index_permutation(labels(a), labels(b)) return all(eachindex(IndexCartesian(), a)) do i j = CartesianIndex(Tuple(permute!(collect(Tuple(i)), invperm(perm)))) isequal(a[i], b[j]) @@ -78,7 +91,7 @@ Base.isapprox(a::AbstractArray, b::Tensor) = false Base.isapprox(a::Tensor, b::AbstractArray) = false function Base.isapprox(a::Tensor, b::Tensor) issetequal(labels(a), labels(b)) || return false - perm = [findfirst(==(label), labels(b)) for label in labels(a)] + perm = __find_index_permutation(labels(a), labels(b)) return all(eachindex(IndexCartesian(), a)) do i j = CartesianIndex(Tuple(permute!(collect(Tuple(i)), invperm(perm)))) isapprox(a[i], b[j])