Skip to content

Commit

Permalink
Add SVD support for BlockDiagonal
Browse files Browse the repository at this point in the history
  • Loading branch information
lkdvos committed Oct 31, 2024
1 parent f382da6 commit f28ef12
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 0 deletions.
14 changes: 14 additions & 0 deletions src/blocksvd.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,16 @@ SVD on blockmatrices:
Interpret the matrix as a linear map acting on vector spaces with a direct sum structure, which is reflected in the structure of U and V.
In the generic case, the SVD does not preserve this structure, and can mix up the blocks, so S becomes a single block.
(Implementation-wise, also most efficiently done by first mapping to a `BlockedArray`)
In the case of `BlockDiagonal` however, the structure is preserved and carried over to the structure of `S`.
=#

LinearAlgebra.eigencopy_oftype(A::AbstractBlockMatrix, S) = BlockedArray(Array{S}(A), blocksizes(A, 1), blocksizes(A, 2))

function LinearAlgebra.eigencopy_oftype(A::BlockDiagonal, S)
diag = map(Base.Fix2(LinearAlgebra.eigencopy_oftype, S), A.blocks.diag)
return BlockDiagonal(diag)
end

function LinearAlgebra.svd!(A::BlockedMatrix; full::Bool=false, alg::LinearAlgebra.Algorithm=default_svd_alg(A))
USV = svd!(parent(A); full, alg)

Expand All @@ -19,3 +25,11 @@ function LinearAlgebra.svd!(A::BlockedMatrix; full::Bool=false, alg::LinearAlgeb
vt = BlockedArray(USV.Vt, bsz2, bsz3)
return SVD(u, s, vt)
end

function LinearAlgebra.svd!(A::BlockDiagonal; full::Bool=false, alg::LinearAlgebra.Algorithm=default_svd_alg(A))
USVs = map(a -> svd!(a; full, alg), A.blocks.diag)
Us = map(Base.Fix2(getproperty, :U), USVs)
Ss = map(Base.Fix2(getproperty, :S), USVs)
Vts = map(Base.Fix2(getproperty, :Vt), USVs)
return SVD(BlockDiagonal(Us), mortar(Ss), BlockDiagonal(Vts))
end
34 changes: 34 additions & 0 deletions test/test_blocksvd.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
module TestBlockSVD

using BlockArrays, Test, LinearAlgebra, Random
using BlockArrays: BlockDiagonal

Random.seed!(0)

Expand Down Expand Up @@ -70,4 +71,37 @@ end
@test U_blocked * Diagonal(S_blocked) * Vt_blocked y
end

@testset "BlockDiagonal SVD ($T)" for T in eltypes
blocksz = (2, 3, 1)
y = BlockDiagonal([rand(T, d, d) for d in blocksz])
x = Array(y)

USV = svd(x)
U, S, Vt = USV.U, USV.S, USV.Vt

# https://github.com/JuliaArrays/BlockArrays.jl/issues/425
# USV_blocked = @inferred svd(y)
USV_block = svd(y)
U_block, S_block, Vt_block = USV_block.U, USV_block.S, USV_block.Vt

# test types
@test U_block isa BlockDiagonal
@test eltype(U_block) == float(T)
@test S_block isa BlockVector
@test eltype(S_block) == real(float(T))
@test Vt_block isa BlockDiagonal
@test eltype(Vt_block) == float(T)

# test structure
@test blocksizes(U_block, 1) == blocksizes(y, 1)
@test length(blocksizes(U_block, 2)) == length(blocksz)
@test blocksizes(Vt_block, 2) == blocksizes(y, 2)
@test length(blocksizes(Vt_block, 1)) == length(blocksz)

# test correctness: SVD is not unique, so cannot compare to dense
@test U_block * BlockDiagonal(Diagonal.(S_block.blocks)) * Vt_block y
@test U_block' * U_block LinearAlgebra.I
@test Vt_block * Vt_block' LinearAlgebra.I
end

end # module

0 comments on commit f28ef12

Please sign in to comment.