Skip to content

Commit

Permalink
[NDTensors] Fix a bug contracting a tensor wrapping Adjoint (#1242)
Browse files Browse the repository at this point in the history
  • Loading branch information
mtfishman authored Nov 8, 2023
1 parent 6d3e62c commit 57c33b9
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 11 deletions.
2 changes: 1 addition & 1 deletion NDTensors/src/dense/dense.jl
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ function promote_rule(
::Type{<:Dense{ElT1,DataT1}}, ::Type{<:Dense{ElT2,DataT2}}
) where {ElT1,DataT1,ElT2,DataT2}
ElR = promote_type(ElT1, ElT2)
VecR = promote_type(DataT1, DataT2)
VecR = promote_type(unwrap_type(DataT1), unwrap_type(DataT2))
VecR = similartype(VecR, ElR)
return Dense{ElR,VecR}
end
Expand Down
32 changes: 22 additions & 10 deletions NDTensors/test/diag.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,31 +4,43 @@ using Test
@testset "DiagTensor basic functionality" begin
include("device_list.jl")
devs = devices_list(copy(ARGS))
@testset "test device: $dev" for dev in devs
t = dev(tensor(Diag(rand(ComplexF64, 100)), (100, 100)))
@testset "test device: $dev" for dev in devs,
elt in (Float32, ComplexF32, Float64, ComplexF64)

if dev == NDTensors.mtl && real(elt) Float32
# Metal doesn't support double precision
continue
end
t = dev(tensor(Diag(rand(elt, 100)), (100, 100)))
@test conj(data(store(t))) == data(store(conj(t)))
@test typeof(conj(t)) <: DiagTensor

d = rand(Float32, 10)
D = dev(Diag{ComplexF64}(d))
@test eltype(D) == ComplexF64
@test dev(Array(dense(D))) == convert.(ComplexF64, d)
d = rand(real(elt), 10)
D = dev(Diag{elt}(d))
@test eltype(D) == elt
@test dev(Array(dense(D))) == convert.(elt, d)
simD = similar(D)
@test length(simD) == length(D)
@test eltype(simD) == eltype(D)
D = dev(Diag(1.0))
@test eltype(D) == Float64
@test complex(D) == Diag(one(ComplexF64))
D = dev(Diag(one(elt)))
@test eltype(D) == elt
@test complex(D) == Diag(one(complex(elt)))
@test similar(D) == Diag(0.0)

D = Tensor(Diag(1), (2, 2))
@test norm(D) == 2
d = 3
vr = rand(d)
vr = rand(elt, d)
D = dev(tensor(Diag(vr), (d, d)))
@test Array(D) == NDTensors.LinearAlgebra.diagm(0 => vr)
@test matrix(D) == NDTensors.LinearAlgebra.diagm(0 => vr)
@test permutedims(D, (2, 1)) == D

# Regression test for https://github.com/ITensor/ITensors.jl/issues/1199
S = dev(tensor(Diag(randn(elt, 2)), (2, 2)))
V = dev(tensor(Dense(randn(elt, 12, 2)'), (3, 4, 2)))
@test contract(S, (2, -1), V, (3, 4, -1))
contract(dense(S), (2, -1), copy(V), (3, 4, -1))
end
end
@testset "DiagTensor contractions" begin
Expand Down

0 comments on commit 57c33b9

Please sign in to comment.