From 1f0dea3bf54dde875f9bfe551f67b6c277c0263b Mon Sep 17 00:00:00 2001 From: mtfishman Date: Thu, 12 Dec 2024 09:18:34 -0500 Subject: [PATCH] Try fixing tests --- README.md | 127 +++++++++++++++++++++++++++++++++++++++++++++- test/Project.toml | 3 ++ 2 files changed, 128 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 7d6a93e..19c582b 100644 --- a/README.md +++ b/README.md @@ -7,6 +7,8 @@ [![Code Style: Blue](https://img.shields.io/badge/code%20style-blue-4495d1.svg)](https://github.com/invenia/BlueStyle) [![Aqua](https://raw.githubusercontent.com/JuliaTesting/Aqua.jl/master/badge.svg)](https://github.com/JuliaTesting/Aqua.jl) +A block sparse array type in Julia based on the [`BlockArrays.jl`](https://github.com/JuliaArrays/BlockArrays.jl) interface. + ## Installation instructions This package resides in the `ITensor/ITensorRegistry` local registry. @@ -32,10 +34,131 @@ julia> Pkg.add("BlockSparseArrays") ## Examples ````julia -using BlockSparseArrays: BlockSparseArrays +using BlockArrays: BlockArrays, BlockedVector, Block, blockedrange +using BlockSparseArrays: BlockSparseArray, block_stored_length +using Test: @test, @test_broken + +function main() + # Block dimensions + i1 = [2, 3] + i2 = [2, 3] + + i_axes = (blockedrange(i1), blockedrange(i2)) + + function block_size(axes, block) + return length.(getindex.(axes, Block.(block.n))) + end + + # Data + nz_blocks = Block.([(1, 1), (2, 2)]) + nz_block_sizes = [block_size(i_axes, nz_block) for nz_block in nz_blocks] + nz_block_lengths = prod.(nz_block_sizes) + + # Blocks with contiguous underlying data + d_data = BlockedVector(randn(sum(nz_block_lengths)), nz_block_lengths) + d_blocks = [ + reshape(@view(d_data[Block(i)]), block_size(i_axes, nz_blocks[i])) for + i in 1:length(nz_blocks) + ] + b = BlockSparseArray(nz_blocks, d_blocks, i_axes) + + @test block_stored_length(b) == 2 + + # Blocks with discontiguous underlying data + d_blocks = randn.(nz_block_sizes) + b = BlockSparseArray(nz_blocks, d_blocks, i_axes) + + @test block_stored_length(b) == 2 + + # Access a block + @test b[Block(1, 1)] == d_blocks[1] + + # Access a zero block, returns a zero matrix + @test b[Block(1, 2)] == zeros(2, 3) + + # Set a zero block + a₁₂ = randn(2, 3) + b[Block(1, 2)] = a₁₂ + @test b[Block(1, 2)] == a₁₂ + + # Matrix multiplication + # TODO: Fix this, broken. + @test_broken b * b ≈ Array(b) * Array(b) + + permuted_b = permutedims(b, (2, 1)) + @test permuted_b isa BlockSparseArray + @test permuted_b == permutedims(Array(b), (2, 1)) + + @test b + b ≈ Array(b) + Array(b) + @test b + b isa BlockSparseArray + # TODO: Fix this, broken. + @test_broken block_stored_length(b + b) == 2 + + scaled_b = 2b + @test scaled_b ≈ 2Array(b) + @test scaled_b isa BlockSparseArray + + # TODO: Fix this, broken. + @test_broken reshape(b, ([4, 6, 6, 9],)) isa BlockSparseArray{<:Any,1} + + return nothing +end + +main() ```` -Examples go here. +# BlockSparseArrays.jl and BlockArrays.jl interface + +````julia +using BlockArrays: BlockArrays, Block +using BlockSparseArrays: BlockSparseArray + +i1 = [2, 3] +i2 = [2, 3] +B = BlockSparseArray{Float64}(i1, i2) +B[Block(1, 1)] = randn(2, 2) +B[Block(2, 2)] = randn(3, 3) + +# Minimal interface + +# Specifies the block structure +@show collect.(BlockArrays.blockaxes(axes(B, 1))) + +# Index range of a block +@show axes(B, 1)[Block(1)] + +# Last index of each block +@show BlockArrays.blocklasts(axes(B, 1)) + +# Find the block containing the index +@show BlockArrays.findblock(axes(B, 1), 3) + +# Retrieve a block +@show B[Block(1, 1)] +@show BlockArrays.viewblock(B, Block(1, 1)) + +# Check block bounds +@show BlockArrays.blockcheckbounds(B, 2, 2) +@show BlockArrays.blockcheckbounds(B, Block(2, 2)) + +# Derived interface + +# Specifies the block structure +@show collect(Iterators.product(BlockArrays.blockaxes(B)...)) + +# Iterate over block views +@show sum.(BlockArrays.eachblock(B)) + +# Reshape into 1-d +# TODO: Fix this, broken. +# @show BlockArrays.blockvec(B)[Block(1)] + +# Array-of-array view +@show BlockArrays.blocks(B)[1, 1] == B[Block(1, 1)] + +# Access an index within a block +@show B[Block(1, 1)[1, 1]] == B[1, 1] +```` --- diff --git a/test/Project.toml b/test/Project.toml index 36367a6..dec0051 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -1,4 +1,5 @@ [deps] +Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" BlockArrays = "8e7c35d0-a365-5155-bbbb-fb81a777f24e" BlockSparseArrays = "2c9a651f-6452-4ace-a6ac-809f4280fbb4" BroadcastMapConversion = "4a4adec5-520f-4750-bb37-d5e66b4ddeb2" @@ -11,7 +12,9 @@ NDTensors = "23ae76d9-e61a-49c4-8f12-3f1a16adf9cf" NestedPermutedDimsArrays = "2c2a8ec4-3cfc-4276-aa3e-1307b4294e58" Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" +SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f" SparseArraysBase = "0d5efcca-f356-4864-8770-e1ed8d78f208" +Suppressor = "fd094767-a336-5f1f-9728-57cf17d0bbfb" SymmetrySectors = "f8a8ad64-adbc-4fce-92f7-ffe2bb36a86e" TensorAlgebra = "68bd88dc-f39d-4e12-b2ca-f046b68fcc6a" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"