Skip to content

Commit

Permalink
fix dagger
Browse files Browse the repository at this point in the history
  • Loading branch information
apkille committed Jun 19, 2024
1 parent 9de9bb3 commit af3c073
Show file tree
Hide file tree
Showing 4 changed files with 12 additions and 17 deletions.
12 changes: 6 additions & 6 deletions src/QSymbolicsBase/basic_ops_homogeneous.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,8 @@ arguments(x::SScaled) = [x.coeff,x.obj]
operation(x::SScaled) = *
head(x::SScaled) = :*
children(x::SScaled) = [:*,x.coeff,x.obj]
Base.:(*)(c, x::Symbolic{T}) where {T<:QObj} = c == 0 ? 0 : SScaled{T}(c,x)
Base.:(*)(x::Symbolic{T}, c) where {T<:QObj} = c == 0 ? 0 : SScaled{T}(c,x)
Base.:(*)(c, x::Symbolic{T}) where {T<:QObj} = SScaled{T}(c,x)
Base.:(*)(x::Symbolic{T}, c) where {T<:QObj} = SScaled{T}(c,x)
Base.:(/)(x::Symbolic{T}, c) where {T<:QObj} = SScaled{T}(1/c,x)
basis(x::SScaled) = basis(x.obj)

Expand Down Expand Up @@ -85,17 +85,17 @@ basis(x::SAdd) = basis(first(x.dict).first)
const SAddKet = SAdd{AbstractKet}
function Base.show(io::IO, x::SAddKet)
ordered_terms = sort([repr(i) for i in arguments(x)])
print(io, "("*join(map(string, ordered_terms),"+")::String*")") # type assert to help inference
print(io, "("*join(ordered_terms,"+")::String*")") # type assert to help inference
end
const SAddOperator = SAdd{AbstractOperator}
function Base.show(io::IO, x::SAddOperator)
ordered_terms = sort([repr(i) for i in arguments(x)])
print(io, "("*join(map(string, ordered_terms),"+")::String*")") # type assert to help inference
print(io, "("*join(ordered_terms,"+")::String*")") # type assert to help inference

Check warning on line 93 in src/QSymbolicsBase/basic_ops_homogeneous.jl

View check run for this annotation

Codecov / codecov/patch

src/QSymbolicsBase/basic_ops_homogeneous.jl#L91-L93

Added lines #L91 - L93 were not covered by tests
end
const SAddBra = SAdd{AbstractBra}
function Base.show(io::IO, x::SAddBra)
ordered_terms = sort([repr(i) for i in arguments(x)])
print(io, "("*join(map(string, ordered_terms),"+")::String*")") # type assert to help inference
print(io, "("*join(ordered_terms,"+")::String*")") # type assert to help inference

Check warning on line 98 in src/QSymbolicsBase/basic_ops_homogeneous.jl

View check run for this annotation

Codecov / codecov/patch

src/QSymbolicsBase/basic_ops_homogeneous.jl#L96-L98

Added lines #L96 - L98 were not covered by tests
end

"""Symbolic application of operator on operator
Expand All @@ -119,7 +119,7 @@ iscall(::SMulOperator) = true
arguments(x::SMulOperator) = x.terms
operation(x::SMulOperator) = *
head(x::SMulOperator) = :*
children(x::SMulOperator) = pushfirst!(x.terms,:*)
children(x::SMulOperator) = [:*;x.terms]

Check warning on line 122 in src/QSymbolicsBase/basic_ops_homogeneous.jl

View check run for this annotation

Codecov / codecov/patch

src/QSymbolicsBase/basic_ops_homogeneous.jl#L121-L122

Added lines #L121 - L122 were not covered by tests
Base.:(*)(xs::Symbolic{AbstractOperator}...) = SMulOperator(collect(xs))
Base.show(io::IO, x::SMulOperator) = print(io, join(map(string, arguments(x)),""))
basis(x::SMulOperator) = basis(x.terms)

Check warning on line 125 in src/QSymbolicsBase/basic_ops_homogeneous.jl

View check run for this annotation

Codecov / codecov/patch

src/QSymbolicsBase/basic_ops_homogeneous.jl#L125

Added line #L125 was not covered by tests
Expand Down
5 changes: 1 addition & 4 deletions src/QSymbolicsBase/predefined.jl
Original file line number Diff line number Diff line change
Expand Up @@ -235,7 +235,7 @@ end
julia> a = SKet(:a, SpinBasis(1//2)); A = SOperator(:A, SpinBasis(1//2));
julia> dagger(2*im*A*a)
0 - 2im⟨a|A†
0 - 2im|a⟩†A†
julia> B = SOperator(:B, SpinBasis(1//2));
Expand Down Expand Up @@ -263,13 +263,10 @@ children(x::SDagger) = [:dagger, x.obj]
dagger(x::Symbolic{AbstractBra}) = SDagger{AbstractKet}(x)
dagger(x::Symbolic{AbstractKet}) = SDagger{AbstractBra}(x)
dagger(x::Symbolic{AbstractOperator}) = SDagger{AbstractOperator}(x)
dagger(x::SKet) = SBra(x.name, x.basis)
dagger(x::SScaledKet) = SScaledBra(conj(x.coeff), dagger(x.obj))
dagger(x::SAddKet) = SAddBra(Dict(dagger(k)=>v for (k,v) in pairs(x.dict)))
dagger(x::SBra) = SKet(x.name, x.basis)
dagger(x::SScaledBra) = SScaledKet(conj(x.coeff), dagger(x.obj))
dagger(x::SAddBra) = SAddKet(Dict(dagger(b)=>v for (b,v) in pairs(x.dict)))
dagger(x::SOperator) = SDagger{AbstractOperator}(x)
dagger(x::SAddOperator) = SAddOperator(Dict(dagger(o)=>v for (o,v) in pairs(x.dict)))
dagger(x::SHermitianOperator) = x
dagger(x::SHermitianUnitaryOperator) = x

Check warning on line 272 in src/QSymbolicsBase/predefined.jl

View check run for this annotation

Codecov / codecov/patch

src/QSymbolicsBase/predefined.jl#L272

Added line #L272 was not covered by tests
Expand Down
6 changes: 3 additions & 3 deletions src/QSymbolicsBase/rules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,10 @@ function prefactorscalings(xs; scalar=false) # If the scalar keyword is true, th
for x in xs
if isexpr(x) && operation(x) == *
c,t = arguments(x)
if scalar == false
if !scalar
coeff *= c
push!(terms,t)
elseif scalar == true && c isa Number
elseif scalar && c isa Number
coeff *= c
push!(terms, t)
else
Expand Down Expand Up @@ -127,7 +127,7 @@ julia> qsimplify(anticommutator(σˣ, σˣ), rewriter=qsimplify_anticommutator)
"""
function qsimplify(s; rewriter=nothing)
if QuantumSymbolics.isexpr(s)
if rewriter == nothing
if isnothing(rewriter)
Fixpoint(Chain(RULES_ALL))(s)

Check warning on line 131 in src/QSymbolicsBase/rules.jl

View check run for this annotation

Codecov / codecov/patch

src/QSymbolicsBase/rules.jl#L131

Added line #L131 was not covered by tests
else
Fixpoint(rewriter)(s)
Expand Down
6 changes: 2 additions & 4 deletions test/test_dagger.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,9 @@ U = SUnitaryOperator(:U, SpinBasis(1//2))
= SHermitianOperator(:ℋ, SpinBasis(1//2))

@testset "symbolic dagger tests" begin
@test isequal(dagger(k₁), SBra(:k₁, SpinBasis(1//2)))
@test isequal(dagger(im*k₁), -im*SBra(:k₁, SpinBasis(1//2)))
@test isequal(dagger(im*k₁), -im*dagger(k₁))
@test isequal(dagger(k₁+k₂), dagger(k₁)+dagger(k₂))
@test isequal(dagger(b₁), SKet(:b₁, SpinBasis(1//2)))
@test isequal(dagger(im*b₁), -im*SKet(:b₁, SpinBasis(1//2)))
@test isequal(dagger(im*b₁), -im*dagger(b₁))
@test isequal(dagger(b₁+b₂), dagger(b₁)+dagger(b₂))
@test isequal(dagger(A+B), dagger(A) + dagger(B))
@test isequal(dagger(ℋ), ℋ)
Expand Down

0 comments on commit af3c073

Please sign in to comment.