From 41d38e77c34332557da30793cdb6b79d31f28448 Mon Sep 17 00:00:00 2001 From: apkille Date: Thu, 27 Jun 2024 23:01:13 -0400 Subject: [PATCH] qsimplify updates --- ext/QuantumCliffordExt/QuantumCliffordExt.jl | 1 - src/QSymbolicsBase/QSymbolicsBase.jl | 4 +-- src/QSymbolicsBase/basic_ops_homogeneous.jl | 31 +++++++++++++++---- src/QSymbolicsBase/basic_ops_inhomogeneous.jl | 5 +++ src/QSymbolicsBase/predefined.jl | 3 ++ src/QSymbolicsBase/rules.jl | 10 +++--- test/test_express_cliff.jl | 16 +++++----- 7 files changed, 50 insertions(+), 20 deletions(-) diff --git a/ext/QuantumCliffordExt/QuantumCliffordExt.jl b/ext/QuantumCliffordExt/QuantumCliffordExt.jl index 70db332..dd1ee9b 100644 --- a/ext/QuantumCliffordExt/QuantumCliffordExt.jl +++ b/ext/QuantumCliffordExt/QuantumCliffordExt.jl @@ -56,7 +56,6 @@ express_nolookup(::YGate, ::CliffordRepr, ::UseAsObservable) = QuantumClifford.P express_nolookup(::ZGate, ::CliffordRepr, ::UseAsObservable) = QuantumClifford.P"Z" express_nolookup(op::SScaledOperator, r::CliffordRepr, u::UseAsObservable) = arguments(op)[1] * express(arguments(op)[2],r,u) express_nolookup(x::SMulOperator, r::CliffordRepr, u::UseAsObservable) = (*)((express(t,r,u) for t in arguments(x))...) -express_nolookup(x::STensorOperator, r::CliffordRepr, u::UseAsObservable) = QuantumClifford.tensor((express(t,r,u) for t in arguments(x))...) express_nolookup(op, ::CliffordRepr, ::UseAsObservable) = error("Can not convert $(op) into a `PauliOperator`, which is the only observable that can be computed for QuantumClifford objects. Consider defining `express_nolookup(op, ::CliffordRepr, ::UseAsObservable)::PauliOperator` for this object.") struct QCRandomSampler # TODO specify types diff --git a/src/QSymbolicsBase/QSymbolicsBase.jl b/src/QSymbolicsBase/QSymbolicsBase.jl index ac50922..3a78be5 100644 --- a/src/QSymbolicsBase/QSymbolicsBase.jl +++ b/src/QSymbolicsBase/QSymbolicsBase.jl @@ -1,9 +1,9 @@ using Symbolics import Symbolics: simplify using SymbolicUtils -import SymbolicUtils: Symbolic, _isone, flatten_term, isnotflat, Chain, Fixpoint +import SymbolicUtils: Symbolic, _isone, flatten_term, isnotflat, Chain, Fixpoint, Prewalk using TermInterface -import TermInterface: isexpr, head, iscall, children, operation, arguments, metadata +import TermInterface: isexpr, head, iscall, children, operation, arguments, metadata, maketerm using LinearAlgebra import LinearAlgebra: eigvecs, ishermitian, inv diff --git a/src/QSymbolicsBase/basic_ops_homogeneous.jl b/src/QSymbolicsBase/basic_ops_homogeneous.jl index 176feeb..af9c6a8 100644 --- a/src/QSymbolicsBase/basic_ops_homogeneous.jl +++ b/src/QSymbolicsBase/basic_ops_homogeneous.jl @@ -30,30 +30,39 @@ arguments(x::SScaled) = [x.coeff,x.obj] operation(x::SScaled) = * head(x::SScaled) = :* children(x::SScaled) = [:*,x.coeff,x.obj] -Base.:(*)(c, x::Symbolic{T}) where {T<:QObj} = iszero(c) || iszero(x) ? SZero{T}() : SScaled{T}(c, x) +function Base.:(*)(c, x::Symbolic{T}) where {T<:QObj} + if iszero(c) || iszero(x) + SZero{T}() + else + x isa SScaled ? SScaled{T}(c*x.coeff, x.obj) : SScaled{T}(c, x) + end +end Base.:(*)(x::Symbolic{T}, c) where {T<:QObj} = c*x Base.:(/)(x::Symbolic{T}, c) where {T<:QObj} = iszero(c) ? throw(DomainError(c,"cannot divide QSymbolics expressions by zero")) : (1/c)*x basis(x::SScaled) = basis(x.obj) const SScaledKet = SScaled{AbstractKet} +maketerm(::Type{SScaledKet}, f, a, t, m) = f(a...) function Base.show(io::IO, x::SScaledKet) - if x.coeff isa Number + if x.coeff isa Real print(io, "$(x.coeff)$(x.obj)") else print(io, "($(x.coeff))$(x.obj)") end end const SScaledOperator = SScaled{AbstractOperator} +maketerm(::Type{SScaledOperator}, f, a, t, m) = f(a...) function Base.show(io::IO, x::SScaledOperator) - if x.coeff isa Number + if x.coeff isa Real print(io, "$(x.coeff)$(x.obj)") else print(io, "($(x.coeff))$(x.obj)") end end const SScaledBra = SScaled{AbstractBra} +maketerm(::Type{SScaledBra}, f, a, t, m) = f(a...) function Base.show(io::IO, x::SScaledBra) - if x.coeff isa Number + if x.coeff isa Real print(io, "$(x.coeff)$(x.obj)") else print(io, "($(x.coeff))$(x.obj)") @@ -94,16 +103,19 @@ Base.:(+)(xs::Vararg{Symbolic{<:QObj},0}) = 0 # to avoid undefined type paramete basis(x::SAdd) = basis(first(x.dict).first) const SAddKet = SAdd{AbstractKet} +maketerm(::Type{SAddKet}, f, a, t, m) = f(a...) function Base.show(io::IO, x::SAddKet) ordered_terms = sort([repr(i) for i in arguments(x)]) print(io, "("*join(ordered_terms,"+")::String*")") # type assert to help inference end const SAddOperator = SAdd{AbstractOperator} +maketerm(::Type{SAddOperator}, f, a, t, m) = f(a...) function Base.show(io::IO, x::SAddOperator) ordered_terms = sort([repr(i) for i in arguments(x)]) print(io, "("*join(ordered_terms,"+")::String*")") # type assert to help inference end const SAddBra = SAdd{AbstractBra} +maketerm(::Type{SAddBra}, f, a, t, m) = f(a...) function Base.show(io::IO, x::SAddBra) ordered_terms = sort([repr(i) for i in arguments(x)]) print(io, "("*join(ordered_terms,"+")::String*")") # type assert to help inference @@ -131,6 +143,7 @@ arguments(x::SMulOperator) = x.terms operation(x::SMulOperator) = * head(x::SMulOperator) = :* children(x::SMulOperator) = [:*;x.terms] +maketerm(::Type{SMulOperator}, f, a, t, m) = f(a...) function Base.:(*)(xs::Symbolic{AbstractOperator}...) zero_ind = findfirst(x->iszero(x), xs) isnothing(zero_ind) ? SMulOperator(collect(xs)) : SZeroOperator() @@ -171,14 +184,18 @@ function ⊗(xs::Symbolic{T}...) where {T<:QObj} end basis(x::STensor) = tensor(basis.(x.terms)...) +const STensorBra = STensor{AbstractBra} +maketerm(::Type{STensorBra}, f, a, t, m) = f(a...) +Base.show(io::IO, x::STensorBra) = print(io, join(map(string, arguments(x)),"")) const STensorKet = STensor{AbstractKet} +maketerm(::Type{STensorKet}, f, a, t, m) = f(a...) Base.show(io::IO, x::STensorKet) = print(io, join(map(string, arguments(x)),"")) const STensorOperator = STensor{AbstractOperator} +maketerm(::Type{STensorOperator}, f, a, t, m) = f(a...) Base.show(io::IO, x::STensorOperator) = print(io, join(map(string, arguments(x)),"⊗")) const STensorSuperOperator = STensor{AbstractSuperOperator} +maketerm(::Type{STensorSuperOperator}, f, a, t, m) = f(a...) Base.show(io::IO, x::STensorSuperOperator) = print(io, join(map(string, arguments(x)),"⊗")) -const STensorBra = STensor{AbstractBra} -Base.show(io::IO, x::STensorBra) = print(io, join(map(string, arguments(x)),"")) """Symbolic commutator of two operators @@ -206,6 +223,7 @@ arguments(x::SCommutator) = [x.op1, x.op2] operation(x::SCommutator) = commutator head(x::SCommutator) = :commutator children(x::SCommutator) = [:commutator, x.op1, x.op2] +maketerm(::Type{SCommutator}, f, a, t, m) = f(a...) commutator(o1::Symbolic{AbstractOperator}, o2::Symbolic{AbstractOperator}) = SCommutator(o1, o2) commutator(o1::SZeroOperator, o2::Symbolic{AbstractOperator}) = SZeroOperator() commutator(o1::Symbolic{AbstractOperator}, o2::SZeroOperator) = SZeroOperator() @@ -237,6 +255,7 @@ arguments(x::SAnticommutator) = [x.op1, x.op2] operation(x::SAnticommutator) = anticommutator head(x::SAnticommutator) = :anticommutator children(x::SAnticommutator) = [:anticommutator, x.op1, x.op2] +maketerm(::Type{SAnticommutator}, f, a, t, m) = f(a...) anticommutator(o1::Symbolic{AbstractOperator}, o2::Symbolic{AbstractOperator}) = SAnticommutator(o1, o2) anticommutator(o1::SZeroOperator, o2::Symbolic{AbstractOperator}) = SZeroOperator() anticommutator(o1::Symbolic{AbstractOperator}, o2::SZeroOperator) = SZeroOperator() diff --git a/src/QSymbolicsBase/basic_ops_inhomogeneous.jl b/src/QSymbolicsBase/basic_ops_inhomogeneous.jl index 49c48ae..6679794 100644 --- a/src/QSymbolicsBase/basic_ops_inhomogeneous.jl +++ b/src/QSymbolicsBase/basic_ops_inhomogeneous.jl @@ -25,6 +25,7 @@ arguments(x::SApplyKet) = [x.op,x.ket] operation(x::SApplyKet) = * head(x::SApplyKet) = :* children(x::SApplyKet) = [:*,x.op,x.ket] +maketerm(::Type{SApplyKet}, f, a, t, m) = f(a...) Base.:(*)(op::Symbolic{AbstractOperator}, k::Symbolic{AbstractKet}) = SApplyKet(op,k) Base.:(*)(op::SZeroOperator, k::Symbolic{AbstractKet}) = SZeroKet() Base.:(*)(op::Symbolic{AbstractOperator}, k::SZeroKet) = SZeroKet() @@ -55,6 +56,7 @@ arguments(x::SApplyBra) = [x.bra,x.op] operation(x::SApplyBra) = * head(x::SApplyBra) = :* children(x::SApplyBra) = [:*,x.bra,x.op] +maketerm(::Type{SApplyBra}, f, a, t, m) = f(a...) Base.:(*)(b::Symbolic{AbstractBra}, op::Symbolic{AbstractOperator}) = SApplyBra(b,op) Base.:(*)(b::SZeroBra, op::Symbolic{AbstractOperator}) = SZeroBra() Base.:(*)(b::Symbolic{AbstractBra}, op::SZeroOperator) = SZeroBra() @@ -81,6 +83,7 @@ arguments(x::SBraKet) = [x.bra,x.ket] operation(x::SBraKet) = * head(x::SBraKet) = :* children(x::SBraKet) = [:*,x.bra,x.ket] +maketerm(::Type{SBraKet}, f, a, t, m) = f(a...) Base.:(*)(b::Symbolic{AbstractBra}, k::Symbolic{AbstractKet}) = SBraKet(b,k) Base.:(*)(b::SZeroBra, k::Symbolic{AbstractKet}) = 0 Base.:(*)(b::Symbolic{AbstractBra}, k::SZeroKet) = 0 @@ -99,6 +102,7 @@ arguments(x::SSuperOpApply) = [x.sop,x.op] operation(x::SSuperOpApply) = * head(x::SSuperOpApply) = :* children(x::SSuperOpApply) = [:*,x.sop,x.op] +maketerm(::Type{SSuperOpApply}, f, a, t, m) = f(a...) Base.:(*)(sop::Symbolic{AbstractSuperOperator}, op::Symbolic{AbstractOperator}) = SSuperOpApply(sop,op) Base.:(*)(sop::Symbolic{AbstractSuperOperator}, op::SZeroOperator) = SZeroOperator() Base.:(*)(sop::Symbolic{AbstractSuperOperator}, k::Symbolic{AbstractKet}) = SSuperOpApply(sop,SProjector(k)) @@ -128,6 +132,7 @@ arguments(x::SOuterKetBra) = [x.ket,x.bra] operation(x::SOuterKetBra) = * head(x::SOuterKetBra) = :* children(x::SOuterKetBra) = [:*,x.ket,x.bra] +maketerm(::Type{SOuterKetBra}, f, a, t, m) = f(a...) Base.:(*)(k::Symbolic{AbstractKet}, b::Symbolic{AbstractBra}) = SOuterKetBra(k,b) Base.:(*)(k::SZeroKet, b::Symbolic{AbstractBra}) = SZeroOperator() Base.:(*)(k::Symbolic{AbstractKet}, b::SZeroBra) = SZeroOperator() diff --git a/src/QSymbolicsBase/predefined.jl b/src/QSymbolicsBase/predefined.jl index 182e8bc..21a63ea 100644 --- a/src/QSymbolicsBase/predefined.jl +++ b/src/QSymbolicsBase/predefined.jl @@ -221,6 +221,7 @@ arguments(x::SProjector) = [x.ket] operation(x::SProjector) = projector head(x::SProjector) = :projector children(x::SProjector) = [:projector,x.ket] +maketerm(::Type{SProjector}, f, a, t, m) = f(a...) projector(x::Symbolic{AbstractKet}) = SProjector(x) projector(x::SZeroKet) = SZeroOperator() basis(x::SProjector) = basis(x.ket) @@ -261,6 +262,7 @@ arguments(x::SDagger) = [x.obj] operation(x::SDagger) = dagger head(x::SDagger) = :dagger children(x::SDagger) = [:dagger, x.obj] +maketerm(::Type{SDagger}, f, a, t, m) = f(a...) dagger(x::Symbolic{AbstractBra}) = SDagger{AbstractKet}(x) dagger(x::Symbolic{AbstractKet}) = SDagger{AbstractBra}(x) dagger(x::Symbolic{AbstractOperator}) = SDagger{AbstractOperator}(x) @@ -309,6 +311,7 @@ arguments(x::SInvOperator) = [x.op] operation(x::SInvOperator) = inv head(x::SInvOperator) = :inv children(x::SInvOperator) = [:inv, x.op] +maketerm(::Type{SInvOperator}, f, a, t, m) = f(a...) basis(x::SInvOperator) = basis(x.op) Base.show(io::IO, x::SInvOperator) = print(io, "$(x.op)⁻¹") Base.:(*)(invop::SInvOperator, op::SOperator) = isequal(invop.op, op) ? IdentityOp(basis(op)) : SMulOperator(invop, op) diff --git a/src/QSymbolicsBase/rules.jl b/src/QSymbolicsBase/rules.jl index 6a8d6b7..49c5df2 100644 --- a/src/QSymbolicsBase/rules.jl +++ b/src/QSymbolicsBase/rules.jl @@ -121,6 +121,9 @@ If the keyword `rewriter` is not specified, then `qsimplify` will apply every de For performance or single-purpose motivations, the user has the option to define a specific rewriter for `qsimplify` to apply to the expression. ```jldoctest +julia> qsimplify(Y*commutator(X*Z, Z)) +(0 - 2im)Z + julia> qsimplify(anticommutator(σˣ, σˣ), rewriter=qsimplify_anticommutator) 2𝕀 ``` @@ -128,12 +131,11 @@ julia> qsimplify(anticommutator(σˣ, σˣ), rewriter=qsimplify_anticommutator) function qsimplify(s; rewriter=nothing) if QuantumSymbolics.isexpr(s) if isnothing(rewriter) - Fixpoint(Chain(RULES_ALL))(s) + Fixpoint(Prewalk(Chain(RULES_ALL)))(s) else - Fixpoint(rewriter)(s) + Fixpoint(Prewalk(rewriter))(s) end else error("Object $(s) of type $(typeof(s)) is not an expression.") end -end - +end \ No newline at end of file diff --git a/test/test_express_cliff.jl b/test/test_express_cliff.jl index 4733874..3375283 100644 --- a/test/test_express_cliff.jl +++ b/test/test_express_cliff.jl @@ -30,14 +30,16 @@ UseObs = UseAsObservable() end @testset "Clifford representations as observables" begin - isequal(express(QuantumSymbolics.X, CR, UseObs), P"X") - isequal(express(QuantumSymbolics.Y, CR, UseObs), P"Y") - isequal(express(QuantumSymbolics.Z, CR, UseObs), P"Z") - isequal(express(im*QuantumSymbolics.X, CR, UseObs), im*P"X") + isequal(express(σˣ, CR, UseObs), P"X") + isequal(express(σʸ, CR, UseObs), P"Y") + isequal(express(σᶻ, CR, UseObs), P"Z") + isequal(express(im*σˣ, CR, UseObs), im*P"X") + isequal(express(σˣ⊗σʸ⊗σᶻ), P"X"⊗P"Y"⊗P"Z") + isequal(express(σˣ*σʸ*σᶻ), P"X"*P"Y"*P"Z") end @testset "Clifford representations as operations" begin - isequal(express(QuantumSymbolics.X, CR, UseOp), sX) - isequal(express(QuantumSymbolics.Y, CR, UseOp), sY) - isequal(express(QuantumSymbolics.Z, CR, UseOp), sZ) + isequal(express(σˣ, CR, UseOp), sX) + isequal(express(σʸ, CR, UseOp), sY) + isequal(express(σᶻ, CR, UseOp), sZ) end \ No newline at end of file