Skip to content

Commit

Permalink
[LabelledNumbers] Drop labels more aggressively in binary operations (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
mtfishman authored Mar 27, 2024
1 parent b48b881 commit 88fa9e9
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 31 deletions.
35 changes: 12 additions & 23 deletions NDTensors/src/lib/LabelledNumbers/src/labelled_interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,41 +33,30 @@ labelled_oneunit(x) = set_value(x, one(x))
# encoded in the type.
labelled_oneunit(type::Type) = error("Not implemented.")

labelled_mul(x, y) = labelled_binary_op(*, x, y)
labelled_add(x, y) = labelled_binary_op(+, x, y) #labelled_add(LabelledStyle(x), x, LabelledStyle(y), y)
labelled_minus(x, y) = labelled_binary_op(-, x, y) #labelled_add(LabelledStyle(x), x, LabelledStyle(y), y)

function labelled_binary_op(f, x, y)
return labelled_binary_op(f, LabelledStyle(x), x, LabelledStyle(y), y)
end
labelled_binary_op(f, ::IsLabelled, x, ::IsLabelled, y) = f(unlabel(x), unlabel(y))
labelled_binary_op(f, ::IsLabelled, x, ::NotLabelled, y) = set_value(x, f(unlabel(x), y))
labelled_binary_op(f, ::NotLabelled, x, ::IsLabelled, y) = set_value(y, f(x, unlabel(y)))
labelled_binary_op(f, ::LabelledStyle, x, ::LabelledStyle, y) = f(unlabel(x), unlabel(y))

# TODO: This is only needed for older Julia versions, like Julia 1.6.
# Delete once we drop support for older Julia versions.
# TODO: Define in terms of `set_value`?
labelled_minus(x) = set_value(x, -unlabel(x))

# TODO: This is only needed for older Julia versions, like Julia 1.6.
# Delete once we drop support for older Julia versions.
labelled_hash(x, h::UInt64) = hash(unlabel(x), h)

for (f, labelled_f) in [(:div, :labelled_div), (:/, :labelled_division)]
@eval begin
$labelled_f(x, y) = $labelled_f(LabelledStyle(x), x, LabelledStyle(y), y)
$labelled_f(::IsLabelled, x, ::IsLabelled, y) = $f(unlabel(x), unlabel(y))
$labelled_f(::IsLabelled, x, ::NotLabelled, y) = labelled($f(unlabel(x), y), label(x))
$labelled_f(::NotLabelled, x, ::IsLabelled, y) = $f(x, unlabel(y))
end
end

for f in [:isequal, :isless]
labelled_f = Symbol(:labelled_, f)
for (fname, f) in [
(:mul, :*),
(:add, :+),
(:minus, :-),
(:div, :/),
(:division, :÷),
(:isequal, :isequal),
(:isless, :isless),
]
labelled_fname = Symbol(:(labelled_), fname)
@eval begin
$labelled_f(x, y) = $labelled_f(LabelledStyle(x), x, LabelledStyle(y), y)
$labelled_f(::IsLabelled, x, ::IsLabelled, y) = $f(unlabel(x), unlabel(y))
$labelled_f(::IsLabelled, x, ::NotLabelled, y) = $f(unlabel(x), y)
$labelled_f(::NotLabelled, x, ::IsLabelled, y) = $f(x, unlabel(y))
$labelled_fname(x, y) = labelled_binary_op($f, x, y)
end
end
16 changes: 8 additions & 8 deletions NDTensors/src/lib/LabelledNumbers/test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,30 +12,30 @@ using Test: @test, @testset
@test !islabelled(unlabel(x))

@test x * 2 == 4
@test label(x * 2) == "x"
@test !islabelled(x * 2)
@test 2 * x == 4
@test label(2 * x) == "x"
@test !islabelled(2 * x)
@test x * x == 4
@test !islabelled(x * x)

@test x + 3 == 5
@test label(x + 3) == "x"
@test !islabelled(x + 3)
@test 3 + x == 5
@test label(3 + x) == "x"
@test !islabelled(3 + x)
@test x + x == 4
@test !islabelled(x + x)

@test x - 3 == -1
@test label(x - 3) == "x"
@test !islabelled(x - 3)
@test 3 - x == 1
@test label(3 - x) == "x"
@test !islabelled(3 - x)
@test x - x == 0
@test !islabelled(x - x)

@test x / 2 == 1
@test label(x / 2) == "x"
@test !islabelled(x / 2)
@test x ÷ 2 == 1
@test label(x ÷ 2) == "x"
@test !islabelled(x ÷ 2)
@test -x == -2
@test hash(x) == hash(2)
@test zero(x) == false
Expand Down

0 comments on commit 88fa9e9

Please sign in to comment.