From 4e348319bca52ce23de178ca616d98517e035a3c Mon Sep 17 00:00:00 2001 From: mtfishman Date: Wed, 8 Nov 2023 16:46:09 -0500 Subject: [PATCH 1/2] [NDTensors] Fix a bug contracting a tensor wrapping Adjoint --- NDTensors/src/dense/dense.jl | 2 +- NDTensors/test/diag.jl | 31 +++++++++++++++++++++---------- 2 files changed, 22 insertions(+), 11 deletions(-) diff --git a/NDTensors/src/dense/dense.jl b/NDTensors/src/dense/dense.jl index 1c41553eb2..39b6dadecf 100644 --- a/NDTensors/src/dense/dense.jl +++ b/NDTensors/src/dense/dense.jl @@ -123,7 +123,7 @@ function promote_rule( ::Type{<:Dense{ElT1,DataT1}}, ::Type{<:Dense{ElT2,DataT2}} ) where {ElT1,DataT1,ElT2,DataT2} ElR = promote_type(ElT1, ElT2) - VecR = promote_type(DataT1, DataT2) + VecR = promote_type(unwrap_type(DataT1), unwrap_type(DataT2)) VecR = similartype(VecR, ElR) return Dense{ElR,VecR} end diff --git a/NDTensors/test/diag.jl b/NDTensors/test/diag.jl index 8422082e92..3b909a5411 100644 --- a/NDTensors/test/diag.jl +++ b/NDTensors/test/diag.jl @@ -4,31 +4,42 @@ using Test @testset "DiagTensor basic functionality" begin include("device_list.jl") devs = devices_list(copy(ARGS)) - @testset "test device: $dev" for dev in devs - t = dev(tensor(Diag(rand(ComplexF64, 100)), (100, 100))) + @testset "test device: $dev" for dev in devs, + elt in (Float32, ComplexF32, Float64, ComplexF64) + if dev == NDTensors.mtl && real(elt) ≠ Float32 + # Metal doesn't support double precision + continue + end + t = dev(tensor(Diag(rand(elt, 100)), (100, 100))) @test conj(data(store(t))) == data(store(conj(t))) @test typeof(conj(t)) <: DiagTensor - d = rand(Float32, 10) - D = dev(Diag{ComplexF64}(d)) - @test eltype(D) == ComplexF64 - @test dev(Array(dense(D))) == convert.(ComplexF64, d) + d = rand(real(elt), 10) + D = dev(Diag{elt}(d)) + @test eltype(D) == elt + @test dev(Array(dense(D))) == convert.(elt, d) simD = similar(D) @test length(simD) == length(D) @test eltype(simD) == eltype(D) - D = dev(Diag(1.0)) - @test eltype(D) == Float64 - @test complex(D) == Diag(one(ComplexF64)) + D = dev(Diag(one(elt))) + @test eltype(D) == elt + @test complex(D) == Diag(one(complex(elt))) @test similar(D) == Diag(0.0) D = Tensor(Diag(1), (2, 2)) @test norm(D) == √2 d = 3 - vr = rand(d) + vr = rand(elt, d) D = dev(tensor(Diag(vr), (d, d))) @test Array(D) == NDTensors.LinearAlgebra.diagm(0 => vr) @test matrix(D) == NDTensors.LinearAlgebra.diagm(0 => vr) @test permutedims(D, (2, 1)) == D + + # Regression test for https://github.com/ITensor/ITensors.jl/issues/1199 + S = dev(tensor(Diag(randn(elt, 2)), (2, 2))) + V = dev(tensor(Dense(randn(elt, 12, 2)'), (3, 4, 2))) + @test contract(S, (2, -1), V, (3, 4, -1)) ≈ + contract(dense(S), (2, -1), copy(V), (3, 4, -1)) end end @testset "DiagTensor contractions" begin From 5d467a32d8f51d177163fb2211dca25350dc91ad Mon Sep 17 00:00:00 2001 From: mtfishman Date: Wed, 8 Nov 2023 16:52:45 -0500 Subject: [PATCH 2/2] Format --- NDTensors/test/diag.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/NDTensors/test/diag.jl b/NDTensors/test/diag.jl index 3b909a5411..a712afbc70 100644 --- a/NDTensors/test/diag.jl +++ b/NDTensors/test/diag.jl @@ -6,6 +6,7 @@ using Test devs = devices_list(copy(ARGS)) @testset "test device: $dev" for dev in devs, elt in (Float32, ComplexF32, Float64, ComplexF64) + if dev == NDTensors.mtl && real(elt) ≠ Float32 # Metal doesn't support double precision continue