Skip to content

Commit

Permalink
Added ability to do n ary decomposition. Binary, trinary etc
Browse files Browse the repository at this point in the history
  • Loading branch information
JoeyT1994 committed Apr 1, 2024
1 parent 7ecd5a5 commit a6f7217
Show file tree
Hide file tree
Showing 7 changed files with 118 additions and 25 deletions.
8 changes: 7 additions & 1 deletion src/TensorNetworkFunctionals.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,13 @@ include("itensornetworks_elementary_operators.jl")

export ITensorNetworkFunction
export BitMap,
default_dimension_map, vertex, calculate_xyz, calculate_x, calculate_bit_values, dimension
default_dimension_map,
vertex,
calculate_xyz,
calculate_x,
calculate_bit_values,
dimension,
base
export const_itensornetwork,
exp_itensornetwork,
cosh_itensornetwork,
Expand Down
53 changes: 35 additions & 18 deletions src/bitmaps.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,39 +2,51 @@ using Dictionaries: Dictionary, set!
using Graphs: Graphs

struct BitMap{VB,VD}
vertex_bit::VB
vertex_digit::VB
vertex_dimension::VD
base::Int64
end

vertex_bit(bm::BitMap) = bm.vertex_bit
default_base() = 2

vertex_digit(bm::BitMap) = bm.vertex_digit
vertex_dimension(bm::BitMap) = bm.vertex_dimension
base(bm::BitMap) = bm.base

default_bit_map(vertices::Vector) = Dictionary(vertices, [i for i in 1:length(vertices)])
function default_dimension_map(vertices::Vector)
return Dictionary(vertices, [1 for i in 1:length(vertices)])
end

BitMap(g) = BitMap(default_bit_map(vertices(g)), default_dimension_map(vertices(g)))
function BitMap(dimension_vertices::Vector{Vector{V}}) where {V}
vertex_bit = Dictionary()
function BitMap(g; base::Int64=default_base())
return BitMap(default_bit_map(vertices(g)), default_dimension_map(vertices(g)), base)
end
function BitMap(vertex_digit, vertex_dimension; base::Int64=default_base())
return BitMap(vertex_digit, vertex_dimension, base)
end
function BitMap(dimension_vertices::Vector{Vector{V}}; base::Int64=default_base()) where {V}
vertex_digit = Dictionary()
vertex_dimension = Dictionary()
for (dimension, vertices) in enumerate(dimension_vertices)
for (bit, v) in enumerate(vertices)
set!(vertex_bit, v, bit)
set!(vertex_digit, v, bit)
set!(vertex_dimension, v, dimension)
end
end
return BitMap(vertex_bit, vertex_dimension)
return BitMap(vertex_digit, vertex_dimension, base)
end

Base.copy(bm::BitMap) = BitMap(copy(vertex_bit(bm)), copy(vertex_dimension(bm)))
function Base.copy(bm::BitMap)
return BitMap(copy(vertex_digit(bm)), copy(vertex_dimension(bm)), copy(base(bm)))
end

dimension(bm::BitMap) = maximum(collect(values(vertex_dimension(bm))))
dimension(bm::BitMap, vertex) = vertex_dimension(bm)[vertex]
bit(bm::BitMap, vertex) = vertex_bit(bm)[vertex]
digit(bm::BitMap, vertex) = vertex_digit(bm)[vertex]
bit_value_to_scalar(bm::BitMap, vertex, value::Int64) = value / (base(bm)^digit(bm, vertex))

function Graphs.vertices(bm::BitMap)
@assert keys(vertex_dimension(bm)) == keys(vertex_bit(bm))
@assert keys(vertex_dimension(bm)) == keys(vertex_digit(bm))
return collect(keys(vertex_dimension(bm)))
end
function Graphs.vertices(bm::BitMap, dimension::Int64)
Expand All @@ -45,7 +57,7 @@ end
function vertex(bm::BitMap, dimension::Int64, bit::Int64)
return only(
filter(
v -> vertex_dimension(bm)[v] == dimension && vertex_bit(bm)[v] == bit,
v -> vertex_dimension(bm)[v] == dimension && vertex_digit(bm)[v] == bit,
keys(vertex_dimension(bm)),
),
)
Expand All @@ -55,7 +67,7 @@ function calculate_xyz(bm::BitMap, vertex_to_bit_value_map, dimensions::Vector{I
out = Float64[]
for dimension in dimensions
vs = vertices(bm, dimension)
push!(out, sum([vertex_to_bit_value_map[v] / (2^bit(bm, v)) for v in vs]))
push!(out, sum([bit_value_to_scalar(bm, v, vertex_to_bit_value_map[v]) for v in vs]))
end
return out
end
Expand All @@ -79,13 +91,18 @@ function calculate_bit_values(
dimension = dimensions[i]
x_rn = copy(x)
vs = vertices(bm, dimension)
sorted_vertices = sort(vs; by=vs -> bit(bm, vs))
sorted_vertices = sort(vs; by=vs -> digit(bm, vs))
for v in sorted_vertices
if (x_rn >= 1.0 / (2^bit(bm, v)))
set!(vertex_to_bit_value_map, v, 1)
x_rn -= 1.0 / (2^bit(bm, v))
else
set!(vertex_to_bit_value_map, v, 0)
i = base(bm) - 1
vertex_set = false
while (!vertex_set)
if x_rn >= bit_value_to_scalar(bm, v, i)
set!(vertex_to_bit_value_map, v, i)
x_rn -= bit_value_to_scalar(bm, v, i)
vertex_set = true
else
i = i - 1
end
end
end

Expand Down
1 change: 1 addition & 0 deletions src/itensornetworkfunction.jl
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ for f in [
:calculate_bit_values,
:calculate_x,
:calculate_xyz,
:base,
]
@eval begin
function $f(fitn::ITensorNetworkFunction, args...; kwargs...)
Expand Down
26 changes: 22 additions & 4 deletions src/itensornetworks_elementary_functions.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,13 @@
using Graphs: nv, vertices, edges, neighbors
using NamedGraphs:
random_bfs_tree, rem_edges, add_edges, undirected_graph, NamedEdge, AbstractGraph, leaf_vertices, a_star
random_bfs_tree,
rem_edges,
add_edges,
undirected_graph,
NamedEdge,
AbstractGraph,
leaf_vertices,
a_star
using ITensors: dim, commoninds
using ITensorNetworks: IndsNetwork, underlying_graph

Expand Down Expand Up @@ -35,7 +42,12 @@ function exp_itensornetwork(
ψ = const_itensornetwork(s, bit_map)
Lx = length(vertices(bit_map, dimension))
for v in vertices(bit_map, dimension)
ψ[v] = ITensor([exp(a / Lx), exp(a / Lx) * exp(k / (2^bit(bit_map, v)))], inds(ψ[v]))
ψ[v] = ITensor(
[
exp(a / Lx) * exp(k * bit_value_to_scalar(bit_map, v, i)) for i in 0:(dim(s[v]) - 1)
],
inds(ψ[v]),
)
end

return ψ
Expand Down Expand Up @@ -217,12 +229,18 @@ function polynomial_itensornetwork(
siteindex,
alphas,
betaindex,
[0.0, (1.0 / (2^bit(bit_map, v)))],
[bit_value_to_scalar(bit_map, v, i) for i in 0:(dim(siteindex) - 1)],
)
elseif v == root_vertex
betaindex = Index(n + 1, "DummyInd")
alphas = setdiff(inds(ψ[v]), Index[siteindex])
ψv = Q_N_tensor(2, siteindex, alphas, betaindex, [0.0, (1.0 / (2^bit(bit_map, v)))])
ψv = Q_N_tensor(
2,
siteindex,
alphas,
betaindex,
[bit_value_to_scalar(bit_map, v, i) for i in 0:(dim(siteindex) - 1)],
)
ψ[v] = ψv * ITensor(reverse(coeffs), betaindex)
end
end
Expand Down
2 changes: 1 addition & 1 deletion src/itensornetworks_elementary_operators.jl
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ function stencil(
stencil_op = truncate(stencil_op; truncate_kwargs...)

for v in vertices(bit_map, dimension)
stencil_op[v] = (2^delta_power) * stencil_op[v]
stencil_op[v] = (base(bit_map)^delta_power) * stencil_op[v]
end

return truncate(stencil_op; truncate_kwargs...)
Expand Down
24 changes: 24 additions & 0 deletions test/test_bitmaps.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ using Dictionaries: Dictionary

@test dimension(bit_map) == 1
@test Set(vertices(bit_map)) == Set(vertices(g))
@test base(bit_map) == 2

x = 0.625
vertex_to_bit_value_map = calculate_bit_values(bit_map, x)
Expand All @@ -40,3 +41,26 @@ end
xyzvals_approx = calculate_xyz(bit_map, vertex_to_bit_value_map)
xyzvals xyzvals_approx
end

@testset "test multi dimensional trinary bit map" begin
L = 50
b = 3
g = named_grid((L, L))
vertex_to_dimension_map = Dictionary(vertices(g), [v[1] for v in vertices(g)])
vertex_to_bit_map = Dictionary(vertices(g), [v[2] for v in vertices(g)])
bit_map = BitMap(vertex_to_bit_map, vertex_to_dimension_map, b)

@test base(bit_map) == b

x, y = (1.0 / 3.0), (5.0 / 9.0)
vertex_to_bit_value_map = calculate_bit_values(bit_map, [x, y], [1, 2])
xyz = calculate_xyz(bit_map, vertex_to_bit_value_map, [1, 2])
@test first(xyz) == x
@test last(xyz) == y

xyzvals = [rand() for i in 1:L]
vertex_to_bit_value_map = calculate_bit_values(bit_map, xyzvals)

xyzvals_approx = calculate_xyz(bit_map, vertex_to_bit_value_map)
xyzvals xyzvals_approx
end
29 changes: 28 additions & 1 deletion test/test_itensorfunction.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ using Distributions: Uniform

@test vertices(fψ, 1) == vertices(fψ)
@test dimension(fψ) == 1
@test base(fψ) == 2

dimension_vertices = collect(values(group(v -> first(v) < Int64(0.5 * L), vertices(ψ))))
= ITensorNetworkFunction(ψ, dimension_vertices)
Expand Down Expand Up @@ -54,7 +55,7 @@ end
("sin", sin_itn, sin),
]
for (name, net_func, func) in funcs
@testset "test $name" begin
@testset "test $name in binary" begin
Lx, Ly = 2, 3
g = named_comb_tree((2, 3))
a = 1.2
Expand All @@ -70,6 +71,32 @@ end
end
end

funcs = [
("cosh", cosh_itn, cosh),
("sinh", sinh_itn, sinh),
("exp", exp_itn, exp),
("cos", cos_itn, cos),
("sin", sin_itn, sin),
]
for (name, net_func, func) in funcs
@testset "test $name in trinary" begin
Lx, Ly = 2, 3
g = named_comb_tree((2, 3))
a = 1.2
k = 0.125
b = 3
s = siteinds("S=1", g)

bit_map = BitMap(g; base=b)

x = (5.0 / 9.0)
ψ_fx = net_func(s, bit_map; k, a)
@test base(ψ_fx) == 3
fx_x = calculate_fx(ψ_fx, x)
@test func(k * x + a) fx_x
end
end

@testset "test tanh" begin
L = 10
g = named_grid((L, 1))
Expand Down

0 comments on commit a6f7217

Please sign in to comment.