Skip to content

Commit

Permalink
Try fixing tests
Browse files Browse the repository at this point in the history
  • Loading branch information
mtfishman committed Dec 12, 2024
1 parent ba9c22f commit 1f0dea3
Show file tree
Hide file tree
Showing 2 changed files with 128 additions and 2 deletions.
127 changes: 125 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
[![Code Style: Blue](https://img.shields.io/badge/code%20style-blue-4495d1.svg)](https://github.com/invenia/BlueStyle)
[![Aqua](https://raw.githubusercontent.com/JuliaTesting/Aqua.jl/master/badge.svg)](https://github.com/JuliaTesting/Aqua.jl)

A block sparse array type in Julia based on the [`BlockArrays.jl`](https://github.com/JuliaArrays/BlockArrays.jl) interface.

## Installation instructions

This package resides in the `ITensor/ITensorRegistry` local registry.
Expand All @@ -32,10 +34,131 @@ julia> Pkg.add("BlockSparseArrays")
## Examples

````julia
using BlockSparseArrays: BlockSparseArrays
using BlockArrays: BlockArrays, BlockedVector, Block, blockedrange
using BlockSparseArrays: BlockSparseArray, block_stored_length
using Test: @test, @test_broken

function main()
# Block dimensions
i1 = [2, 3]
i2 = [2, 3]

i_axes = (blockedrange(i1), blockedrange(i2))

function block_size(axes, block)
return length.(getindex.(axes, Block.(block.n)))
end

# Data
nz_blocks = Block.([(1, 1), (2, 2)])
nz_block_sizes = [block_size(i_axes, nz_block) for nz_block in nz_blocks]
nz_block_lengths = prod.(nz_block_sizes)

# Blocks with contiguous underlying data
d_data = BlockedVector(randn(sum(nz_block_lengths)), nz_block_lengths)
d_blocks = [
reshape(@view(d_data[Block(i)]), block_size(i_axes, nz_blocks[i])) for
i in 1:length(nz_blocks)
]
b = BlockSparseArray(nz_blocks, d_blocks, i_axes)

@test block_stored_length(b) == 2

# Blocks with discontiguous underlying data
d_blocks = randn.(nz_block_sizes)
b = BlockSparseArray(nz_blocks, d_blocks, i_axes)

@test block_stored_length(b) == 2

# Access a block
@test b[Block(1, 1)] == d_blocks[1]

# Access a zero block, returns a zero matrix
@test b[Block(1, 2)] == zeros(2, 3)

# Set a zero block
a₁₂ = randn(2, 3)
b[Block(1, 2)] = a₁₂
@test b[Block(1, 2)] == a₁₂

# Matrix multiplication
# TODO: Fix this, broken.
@test_broken b * b Array(b) * Array(b)

permuted_b = permutedims(b, (2, 1))
@test permuted_b isa BlockSparseArray
@test permuted_b == permutedims(Array(b), (2, 1))

@test b + b Array(b) + Array(b)
@test b + b isa BlockSparseArray
# TODO: Fix this, broken.
@test_broken block_stored_length(b + b) == 2

scaled_b = 2b
@test scaled_b 2Array(b)
@test scaled_b isa BlockSparseArray

# TODO: Fix this, broken.
@test_broken reshape(b, ([4, 6, 6, 9],)) isa BlockSparseArray{<:Any,1}

return nothing
end

main()
````

Examples go here.
# BlockSparseArrays.jl and BlockArrays.jl interface

````julia
using BlockArrays: BlockArrays, Block
using BlockSparseArrays: BlockSparseArray

i1 = [2, 3]
i2 = [2, 3]
B = BlockSparseArray{Float64}(i1, i2)
B[Block(1, 1)] = randn(2, 2)
B[Block(2, 2)] = randn(3, 3)

# Minimal interface

# Specifies the block structure
@show collect.(BlockArrays.blockaxes(axes(B, 1)))

# Index range of a block
@show axes(B, 1)[Block(1)]

# Last index of each block
@show BlockArrays.blocklasts(axes(B, 1))

# Find the block containing the index
@show BlockArrays.findblock(axes(B, 1), 3)

# Retrieve a block
@show B[Block(1, 1)]
@show BlockArrays.viewblock(B, Block(1, 1))

# Check block bounds
@show BlockArrays.blockcheckbounds(B, 2, 2)
@show BlockArrays.blockcheckbounds(B, Block(2, 2))

# Derived interface

# Specifies the block structure
@show collect(Iterators.product(BlockArrays.blockaxes(B)...))

# Iterate over block views
@show sum.(BlockArrays.eachblock(B))

# Reshape into 1-d
# TODO: Fix this, broken.
# @show BlockArrays.blockvec(B)[Block(1)]

# Array-of-array view
@show BlockArrays.blocks(B)[1, 1] == B[Block(1, 1)]

# Access an index within a block
@show B[Block(1, 1)[1, 1]] == B[1, 1]
````

---

Expand Down
3 changes: 3 additions & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
[deps]
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
BlockArrays = "8e7c35d0-a365-5155-bbbb-fb81a777f24e"
BlockSparseArrays = "2c9a651f-6452-4ace-a6ac-809f4280fbb4"
BroadcastMapConversion = "4a4adec5-520f-4750-bb37-d5e66b4ddeb2"
Expand All @@ -11,7 +12,9 @@ NDTensors = "23ae76d9-e61a-49c4-8f12-3f1a16adf9cf"
NestedPermutedDimsArrays = "2c2a8ec4-3cfc-4276-aa3e-1307b4294e58"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
SparseArraysBase = "0d5efcca-f356-4864-8770-e1ed8d78f208"
Suppressor = "fd094767-a336-5f1f-9728-57cf17d0bbfb"
SymmetrySectors = "f8a8ad64-adbc-4fce-92f7-ffe2bb36a86e"
TensorAlgebra = "68bd88dc-f39d-4e12-b2ca-f046b68fcc6a"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Expand Down

0 comments on commit 1f0dea3

Please sign in to comment.