Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Entanglement simulation on a grid with custom predicates #90

Merged
merged 11 commits into from
Feb 12, 2024
2 changes: 1 addition & 1 deletion examples/firstgenrepeater_v2/2_swapper_example.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ for (;src, dst) in edges(network)
@process eprot()
end
for node in vertices(network)
sprot = SwapperProt(sim, network, node; nodeL = <(node), nodeR = >(node))
sprot = SwapperProt(sim, network, node; nodeL = <(node), nodeH = >(node), chooseL = argmin, chooseH = argmax)
@process sprot()
end

Expand Down
33 changes: 20 additions & 13 deletions src/ProtocolZoo/ProtocolZoo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,9 @@ end
end
end

function random_index(arr)
return rand(keys(arr))
end

"""
$TYPEDEF
Expand All @@ -146,17 +149,21 @@ A protocol, running at a given node, that finds swappable entangled pairs and pe

$FIELDS
"""
@kwdef struct SwapperProt{L,R,LT} <: AbstractProtocol where {L<:Union{Int,<:Function,Wildcard}, R<:Union{Int,<:Function,Wildcard}, LT<:Union{Float64,Nothing}}
@kwdef struct SwapperProt{NL,NH,CL,CH,LT} <: AbstractProtocol where {NL<:Union{Int,<:Function,Wildcard}, NH<:Union{Int,<:Function,Wildcard}, CL<:Function, CH<:Function, LT<:Union{Float64,Nothing}}
"""time-and-schedule-tracking instance from `ConcurrentSim`"""
sim::Simulation
"""a network graph of registers"""
net::RegisterNet
"""the vertex of the node where swapping is happening"""
node::Int
"""the vertex of one of the remote nodes (or a predicate function or a wildcard)"""
nodeL::L = ❓
"""the vertex of the other remote node (or a predicate function or a wildcard)"""
nodeR::R = ❓
"""the vertex of one of the remote nodes for the swap, arbitrarily referred to as the "low" node (or a predicate function or a wildcard); if you are working on a repeater chain, a good choice is `<(current_node)`, i.e. any node to the "left" of the current node"""
nodeL::NL = ❓
"""the vertex of the other remote node for the swap, the "high" counterpart of `nodeL`; if you are working on a repeater chain, a good choice is `>(current_node)`, i.e. any node to the "right" of the current node"""
nodeH::NH = ❓
"""the `nodeL` predicate can return many positive candidates; `chooseL` picks one of them (by index into the array of filtered `nodeL` results), defaults to a random pick `arr->rand(keys(arr))`; if you are working on a repeater chain a good choice is `argmin`, i.e. the node furthest to the "left" """
chooseL::CL = random_index
"""the `nodeH` counterpart for `chooseH`; if you are working on a repeater chain a good choice is `argmax`, i.e. the node furthest to the "right" """
chooseH::CH = random_index
"""fixed "busy time" duration immediately before starting entanglement generation attempts"""
local_busy_time::Float64 = 0.0 # TODO the gates should have that busy time built in
"""how long to wait before retrying to lock qubits if no qubits are available (`nothing` for queuing up and waiting)"""
Expand All @@ -175,7 +182,7 @@ end
rounds = prot.rounds
while rounds != 0
reg = prot.net[prot.node]
qubit_pair = findswapablequbits(prot.net,prot.node)
qubit_pair = findswapablequbits(prot.net, prot.node, prot.nodeL, prot.nodeH, prot.chooseL, prot.chooseH)
if isnothing(qubit_pair)
isnothing(prot.retry_lock_time) && error("We do not yet support waiting on register to make qubits available") # TODO
@yield timeout(prot.sim, prot.retry_lock_time)
Expand Down Expand Up @@ -212,16 +219,16 @@ end
end
end

function findswapablequbits(net,node) # TODO parameterize the query predicates and the findmin/findmax
function findswapablequbits(net, node, pred_low, pred_high, choose_low, choose_high)
reg = net[node]

leftnodes = queryall(reg, EntanglementCounterpart, <(node), ❓; locked=false, assigned=true)
rightnodes = queryall(reg, EntanglementCounterpart, >(node), ❓; locked=false, assigned=true)
low_nodes = queryall(reg, EntanglementCounterpart, pred_low, ❓; locked=false, assigned=true)
high_nodes = queryall(reg, EntanglementCounterpart, pred_high, ❓; locked=false, assigned=true)

(isempty(leftnodes) || isempty(rightnodes)) && return nothing
_, il = findmin(n->n.tag[2], leftnodes) # TODO make [2] into a nice named property
_, ir = findmax(n->n.tag[2], rightnodes)
return leftnodes[il], rightnodes[ir]
(isempty(low_nodes) || isempty(high_nodes)) && return nothing
il = choose_low((n.tag[2] for n in low_nodes)) # TODO make [2] into a nice named property
ih = choose_high((n.tag[2] for n in high_nodes))
return low_nodes[il], high_nodes[ih]
end
Krastanov marked this conversation as resolved.
Show resolved Hide resolved


Expand Down
4 changes: 2 additions & 2 deletions src/queries.jl
Original file line number Diff line number Diff line change
Expand Up @@ -257,8 +257,8 @@ for (tagsymbol, tagvariant) in pairs(tag_types)
sig_wild = collect(sig)
sig_wild[idx] .= Union{Wildcard,Function}
argssig_wild = [:($a::$t) for (a,t) in zip(args, sig_wild)]
wild_checks = [:(isa($(args[i]),Wildcard) || $(args[i])(tag.data[$i])) for i in idx]
nonwild_checks = [:(tag.data[$i]==$(args[i])) for i in complement_idx]
wild_checks = [:(isa($(args[i]),Wildcard) || $(args[i])(tag[$i])) for i in idx]
nonwild_checks = [:(tag[$i]==$(args[i])) for i in complement_idx]
newmethod_reg = quote function query(reg::Register, $(argssig_wild...), ::Val{allB}=Val{false}(); locked::Union{Nothing,Bool}=nothing, assigned::Union{Nothing,Bool}=nothing) where {allB}
res = NamedTuple{(:slot, :tag), Tuple{RegRef, Tag}}[]
for (reg_idx, tags) in enumerate(reg.tags)
Expand Down
6 changes: 3 additions & 3 deletions src/tags.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,9 @@ end
See also: [`query`](@ref), [`tag!`](@ref), [`Wildcard`](@ref)"""
const tag_types = Tag'

Base.getindex(tag::Tag, i::Int) = tag.data[i]
Base.length(tag::Tag) = length(tag.data.data)
Base.iterate(tag::Tag, state=1) = state > length(tag) ? nothing : (tag[state],state+1)
Base.getindex(tag::Tag, i::Int) = SumTypes.unwrap(tag)[i]
Base.length(tag::Tag) = length(SumTypes.unwrap(tag).data)
Base.iterate(tag::Tag, state=1) = state > length(tag) ? nothing : (SumTypes.unwrap(tag)[state],state+1)

function SumTypes.show_sumtype(io::IO, x::Tag)
data = SumTypes.unwrap(x)
Expand Down
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ println("Starting tests with $(Threads.nthreads()) threads out of `Sys.CPU_THREA
@doset "messagebuffer"
@doset "tags_and_queries"
@doset "entanglement_tracker"
@doset "entanglement_tracker_grid"

@doset "circuitzoo_api"
@doset "circuitzoo_ent_swap"
Expand Down
6 changes: 3 additions & 3 deletions test/test_entanglement_tracker.jl
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,8 @@ for i in 1:10
@test [islocked(ref) for i in vertices(net) for ref in net[i]] |> any == false


swapper2 = SwapperProt(sim, net, 2; rounds=1)
swapper3 = SwapperProt(sim, net, 3; rounds=1)
swapper2 = SwapperProt(sim, net, 2; nodeL = <(2), nodeH = >(2), chooseL = argmin, chooseH = argmax, rounds = 1)
swapper3 = SwapperProt(sim, net, 3; nodeL = <(3), nodeH = >(3), chooseL = argmin, chooseH = argmax, rounds = 1)
@process swapper2()
@process swapper3()
run(sim, 80)
Expand Down Expand Up @@ -102,7 +102,7 @@ for i in 1:30, n in 2:30
@process eprot()
end
for j in 2:n-1
swapper = SwapperProt(sim, net, j; rounds=1)
swapper = SwapperProt(sim, net, j; nodeL = <(j), nodeH = >(j), chooseL = argmin, chooseH = argmax, rounds = 1)
@process swapper()
end
run(sim, 200)
Expand Down
151 changes: 151 additions & 0 deletions test/test_entanglement_tracker_grid.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,151 @@
using Revise
using QuantumSavory
using ResumableFunctions
using ConcurrentSim
using QuantumSavory.ProtocolZoo
using QuantumSavory.ProtocolZoo: EntanglementCounterpart, EntanglementHistory, EntanglementUpdateX, EntanglementUpdateZ
using Graphs
using Test

if isinteractive()
using Logging
logger = ConsoleLogger(Logging.Debug; meta_formatter=(args...)->(:black,"",""))
global_logger(logger)
println("Logger set to debug")
end

##
# Here we test entanglement tracker and swapper protocols on an arbitrary hardcoded path of long-range connection inside of what is otherwise a 2D grid
# We do NOT test anything related to automatic routing on such a grid -- only the hardcoded path is tested
##

## Custom Predicates

function top_left(net, node, x)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am waiting on CI and will then merge. However, this is an O(n) check for something that should be O(1).

From a node index get its x y coordinates (currently you get only one of them).
Then do the same for the other node.

Then just compare them to verify that you are to the left (one comparison) and to the top (second comparison).

No need for a for loop.

Submit a PR with the fix and feel free to merge it. It should also make this function much simpler.

n = sqrt(size(net.graph)[1]) # grid size
a = (node ÷ n) + 1 # row number
for i in 1:a-1
if x == (i-1)*n + i
return true
end
end
return false
end

function bottom_right(net, node, x)
n = sqrt(size(net.graph)[1]) # grid size
a = (node ÷ n) + 1 # row number
for i in a+1:n
if x == (i-1)*n + i
return true
end
end
return false
end

## Simulation

## without entanglement tracker - this is almost the same test as the one in test_entanglement_tracker.jl which tests a simple chain -- the only difference is that we have picked a few hardcoded arbitrary nodes through a grid (creating an ad-hoc chain)
for i in 1:10
graph = grid([4, 4])
add_edge!(graph, 1, 6)
add_edge!(graph, 6, 11)
add_edge!(graph, 11, 16)

net = RegisterNet(graph, [Register(3) for i in 1:16])
sim = get_time_tracker(net)


entangler1 = EntanglerProt(sim, net, 1, 6; rounds=1)
@process entangler1()
run(sim, 20)

@test net[1].tags == [[Tag(EntanglementCounterpart, 6, 1)],[],[]]


entangler2 = EntanglerProt(sim, net, 6, 11; rounds=1)
@process entangler2()
run(sim, 40)
entangler3 = EntanglerProt(sim, net, 11, 16; rounds=1)
@process entangler3()
run(sim, 60)

@test net[1].tags == [[Tag(EntanglementCounterpart, 6, 1)],[],[]]
@test net[6].tags == [[Tag(EntanglementCounterpart, 1, 1)],[Tag(EntanglementCounterpart, 11, 1)],[]]
@test net[11].tags == [[Tag(EntanglementCounterpart, 6, 2)],[Tag(EntanglementCounterpart, 16, 1)], []]
@test net[16].tags == [[Tag(EntanglementCounterpart, 11, 2)],[],[]]

@test [islocked(ref) for i in vertices(net) for ref in net[i]] |> any == false

l1(x) = top_left(net, 6, x)
h1(x) = bottom_right(net, 6, x)
swapper2 = SwapperProt(sim, net, 6; nodeL=l1, nodeH=h1, rounds=1)
l2(x) = top_left(net, 11, x)
h2(x) = bottom_right(net, 11, x)
swapper3 = SwapperProt(sim, net, 11; nodeL=l2, nodeH=h2, rounds=1)
@process swapper2()
@process swapper3()
run(sim, 80)

# In the absence of an entanglement tracker the tags will not all be updated
@test net[1].tags == [[Tag(EntanglementCounterpart, 6, 1)],[],[]]
@test net[6].tags == [[Tag(EntanglementHistory, 1, 1, 11, 1, 2)],[Tag(EntanglementHistory, 11, 1, 1, 1, 1)],[]]
@test net[11].tags == [[Tag(EntanglementHistory, 6, 2, 16, 1, 2)],[Tag(EntanglementHistory, 16, 1, 6, 2, 1)], []]
@test net[16].tags == [[Tag(EntanglementCounterpart, 11, 2)],[],[]]

@test isassigned(net[1][1]) && isassigned(net[16][1])
@test !isassigned(net[6][1]) && !isassigned(net[11][1])
@test !isassigned(net[6][2]) && !isassigned(net[11][2])

@test [islocked(ref) for i in vertices(net) for ref in net[i]] |> any == false

end

## with entanglement tracker -- here we hardcode the diagonal of the grid as the path on which we are making connections
for n in 4:10
graph = grid([n,n])

diag_pairs = []
diag_nodes = []
reg_num = 1 # starting register
for i in 1:n-1 # a grid with n nodes has n-1 pairs of diagonal nodes
push!(diag_pairs, (reg_num, reg_num+n+1))
push!(diag_nodes, reg_num)
reg_num += n + 1
end
push!(diag_nodes, n^2)

for (src, dst) in diag_pairs # need edges down the diagonal to establish cchannels and qchannels between the diagonal nodes
add_edge!(graph, src, dst)
end

net = RegisterNet(graph, [Register(8) for i in 1:n^2])

sim = get_time_tracker(net)

for (src, dst) in diag_pairs
eprot = EntanglerProt(sim, net, src, dst; rounds=1, randomize=true)
@process eprot()
end

for i in 2:n-1
l(x) = top_left(net, diag_nodes[i], x)
h(x) = bottom_right(net, diag_nodes[i], x)
swapper = SwapperProt(sim, net, diag_nodes[i]; nodeL = l, nodeH = h, rounds = 1)
@process swapper()
end

for v in diag_nodes
tracker = EntanglementTracker(sim, net, v)
@process tracker()
end

run(sim, 200)

q1 = query(net[1], EntanglementCounterpart, diag_nodes[n], ❓)
q2 = query(net[diag_nodes[n]], EntanglementCounterpart, 1, ❓)
@test q1.tag[2] == diag_nodes[n]
@test q2.tag[2] == 1
@test observable((q1.slot, q2.slot), Z⊗Z) ≈ 1
@test observable((q1.slot, q2.slot), X⊗X) ≈ 1
end
Loading