Skip to content

Commit

Permalink
Add SVD support for BlockDiagonal
Browse files Browse the repository at this point in the history
  • Loading branch information
lkdvos committed Oct 31, 2024
1 parent f382da6 commit 30a0295
Show file tree
Hide file tree
Showing 3 changed files with 49 additions and 1 deletion.
2 changes: 1 addition & 1 deletion src/BlockArrays.jl
Original file line number Diff line number Diff line change
Expand Up @@ -66,14 +66,14 @@ include("blocks.jl")

include("blockbroadcast.jl")
include("blockcholesky.jl")
include("blocksvd.jl")
include("blocklinalg.jl")
include("blockproduct.jl")
include("show.jl")
include("blockreduce.jl")
include("blockdeque.jl")
include("blockarrayinterface.jl")
include("blockbanded.jl")
include("blocksvd.jl")

@deprecate getblock(A::AbstractBlockArray{T,N}, I::Vararg{Integer, N}) where {T,N} view(A, Block(I))
@deprecate getblock!(X, A::AbstractBlockArray{T,N}, I::Vararg{Integer, N}) where {T,N} copyto!(X, view(A, Block(I)))
Expand Down
14 changes: 14 additions & 0 deletions src/blocksvd.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,16 @@ SVD on blockmatrices:
Interpret the matrix as a linear map acting on vector spaces with a direct sum structure, which is reflected in the structure of U and V.
In the generic case, the SVD does not preserve this structure, and can mix up the blocks, so S becomes a single block.
(Implementation-wise, also most efficiently done by first mapping to a `BlockedArray`)
In the case of `BlockDiagonal` however, the structure is preserved and carried over to the structure of `S`.
=#

LinearAlgebra.eigencopy_oftype(A::AbstractBlockMatrix, S) = BlockedArray(Array{S}(A), blocksizes(A, 1), blocksizes(A, 2))

Check warning on line 9 in src/blocksvd.jl

View check run for this annotation

Codecov / codecov/patch

src/blocksvd.jl#L9

Added line #L9 was not covered by tests

function LinearAlgebra.eigencopy_oftype(A::BlockDiagonal, S)
diag = map(Base.Fix2(LinearAlgebra.eigencopy_oftype, S), A.blocks.diag)
return BlockDiagonal(diag)

Check warning on line 13 in src/blocksvd.jl

View check run for this annotation

Codecov / codecov/patch

src/blocksvd.jl#L11-L13

Added lines #L11 - L13 were not covered by tests
end

function LinearAlgebra.svd!(A::BlockedMatrix; full::Bool=false, alg::LinearAlgebra.Algorithm=default_svd_alg(A))
USV = svd!(parent(A); full, alg)

Check warning on line 17 in src/blocksvd.jl

View check run for this annotation

Codecov / codecov/patch

src/blocksvd.jl#L16-L17

Added lines #L16 - L17 were not covered by tests

Expand All @@ -19,3 +25,11 @@ function LinearAlgebra.svd!(A::BlockedMatrix; full::Bool=false, alg::LinearAlgeb
vt = BlockedArray(USV.Vt, bsz2, bsz3)
return SVD(u, s, vt)

Check warning on line 26 in src/blocksvd.jl

View check run for this annotation

Codecov / codecov/patch

src/blocksvd.jl#L23-L26

Added lines #L23 - L26 were not covered by tests
end

function LinearAlgebra.svd!(A::BlockDiagonal; full::Bool=false, alg::LinearAlgebra.Algorithm=default_svd_alg(A))
USVs = map(a -> svd!(a; full, alg), A.blocks.diag)
Us = map(Base.Fix2(getproperty, :U), USVs)
Ss = map(Base.Fix2(getproperty, :S), USVs)
Vts = map(Base.Fix2(getproperty, :Vt), USVs)
return SVD(BlockDiagonal(Us), mortar(Ss), BlockDiagonal(Vts))

Check warning on line 34 in src/blocksvd.jl

View check run for this annotation

Codecov / codecov/patch

src/blocksvd.jl#L29-L34

Added lines #L29 - L34 were not covered by tests
end
34 changes: 34 additions & 0 deletions test/test_blocksvd.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
module TestBlockSVD

using BlockArrays, Test, LinearAlgebra, Random
using BlockArrays: BlockDiagonal

Random.seed!(0)

Expand Down Expand Up @@ -70,4 +71,37 @@ end
@test U_blocked * Diagonal(S_blocked) * Vt_blocked y
end

@testset "BlockDiagonal SVD ($T)" for T in eltypes
blocksz = (2, 3, 1)
y = BlockDiagonal([rand(T, d, d) for d in blocksz])
x = Array(y)

USV = svd(x)
U, S, Vt = USV.U, USV.S, USV.Vt

# https://github.com/JuliaArrays/BlockArrays.jl/issues/425
# USV_blocked = @inferred svd(y)
USV_block = svd(y)
U_block, S_block, Vt_block = USV_block.U, USV_block.S, USV_block.Vt

# test types
@test U_block isa BlockDiagonal
@test eltype(U_block) == float(T)
@test S_block isa BlockVector
@test eltype(S_block) == real(float(T))
@test Vt_block isa BlockDiagonal
@test eltype(Vt_block) == float(T)

# test structure
@test blocksizes(U_block, 1) == blocksizes(y, 1)
@test length(blocksizes(U_block, 2)) == length(blocksz)
@test blocksizes(Vt_block, 2) == blocksizes(y, 2)
@test length(blocksizes(Vt_block, 1)) == length(blocksz)

# test correctness: SVD is not unique, so cannot compare to dense
@test U_block * BlockDiagonal(Diagonal.(S_block.blocks)) * Vt_block y
@test U_block' * U_block LinearAlgebra.I
@test Vt_block * Vt_block' LinearAlgebra.I
end

end # module

0 comments on commit 30a0295

Please sign in to comment.