diff --git a/NDTensors/src/lib/LabelledNumbers/src/labelled_interface.jl b/NDTensors/src/lib/LabelledNumbers/src/labelled_interface.jl index 6cc65c46d6..3f267ce966 100644 --- a/NDTensors/src/lib/LabelledNumbers/src/labelled_interface.jl +++ b/NDTensors/src/lib/LabelledNumbers/src/labelled_interface.jl @@ -33,41 +33,30 @@ labelled_oneunit(x) = set_value(x, one(x)) # encoded in the type. labelled_oneunit(type::Type) = error("Not implemented.") -labelled_mul(x, y) = labelled_binary_op(*, x, y) -labelled_add(x, y) = labelled_binary_op(+, x, y) #labelled_add(LabelledStyle(x), x, LabelledStyle(y), y) -labelled_minus(x, y) = labelled_binary_op(-, x, y) #labelled_add(LabelledStyle(x), x, LabelledStyle(y), y) - function labelled_binary_op(f, x, y) return labelled_binary_op(f, LabelledStyle(x), x, LabelledStyle(y), y) end -labelled_binary_op(f, ::IsLabelled, x, ::IsLabelled, y) = f(unlabel(x), unlabel(y)) -labelled_binary_op(f, ::IsLabelled, x, ::NotLabelled, y) = set_value(x, f(unlabel(x), y)) -labelled_binary_op(f, ::NotLabelled, x, ::IsLabelled, y) = set_value(y, f(x, unlabel(y))) +labelled_binary_op(f, ::LabelledStyle, x, ::LabelledStyle, y) = f(unlabel(x), unlabel(y)) # TODO: This is only needed for older Julia versions, like Julia 1.6. # Delete once we drop support for older Julia versions. -# TODO: Define in terms of `set_value`? labelled_minus(x) = set_value(x, -unlabel(x)) # TODO: This is only needed for older Julia versions, like Julia 1.6. # Delete once we drop support for older Julia versions. labelled_hash(x, h::UInt64) = hash(unlabel(x), h) -for (f, labelled_f) in [(:div, :labelled_div), (:/, :labelled_division)] - @eval begin - $labelled_f(x, y) = $labelled_f(LabelledStyle(x), x, LabelledStyle(y), y) - $labelled_f(::IsLabelled, x, ::IsLabelled, y) = $f(unlabel(x), unlabel(y)) - $labelled_f(::IsLabelled, x, ::NotLabelled, y) = labelled($f(unlabel(x), y), label(x)) - $labelled_f(::NotLabelled, x, ::IsLabelled, y) = $f(x, unlabel(y)) - end -end - -for f in [:isequal, :isless] - labelled_f = Symbol(:labelled_, f) +for (fname, f) in [ + (:mul, :*), + (:add, :+), + (:minus, :-), + (:div, :/), + (:division, :÷), + (:isequal, :isequal), + (:isless, :isless), +] + labelled_fname = Symbol(:(labelled_), fname) @eval begin - $labelled_f(x, y) = $labelled_f(LabelledStyle(x), x, LabelledStyle(y), y) - $labelled_f(::IsLabelled, x, ::IsLabelled, y) = $f(unlabel(x), unlabel(y)) - $labelled_f(::IsLabelled, x, ::NotLabelled, y) = $f(unlabel(x), y) - $labelled_f(::NotLabelled, x, ::IsLabelled, y) = $f(x, unlabel(y)) + $labelled_fname(x, y) = labelled_binary_op($f, x, y) end end diff --git a/NDTensors/src/lib/LabelledNumbers/test/runtests.jl b/NDTensors/src/lib/LabelledNumbers/test/runtests.jl index bfb6983e79..cf3f87e86d 100644 --- a/NDTensors/src/lib/LabelledNumbers/test/runtests.jl +++ b/NDTensors/src/lib/LabelledNumbers/test/runtests.jl @@ -12,30 +12,30 @@ using Test: @test, @testset @test !islabelled(unlabel(x)) @test x * 2 == 4 - @test label(x * 2) == "x" + @test !islabelled(x * 2) @test 2 * x == 4 - @test label(2 * x) == "x" + @test !islabelled(2 * x) @test x * x == 4 @test !islabelled(x * x) @test x + 3 == 5 - @test label(x + 3) == "x" + @test !islabelled(x + 3) @test 3 + x == 5 - @test label(3 + x) == "x" + @test !islabelled(3 + x) @test x + x == 4 @test !islabelled(x + x) @test x - 3 == -1 - @test label(x - 3) == "x" + @test !islabelled(x - 3) @test 3 - x == 1 - @test label(3 - x) == "x" + @test !islabelled(3 - x) @test x - x == 0 @test !islabelled(x - x) @test x / 2 == 1 - @test label(x / 2) == "x" + @test !islabelled(x / 2) @test x ÷ 2 == 1 - @test label(x ÷ 2) == "x" + @test !islabelled(x ÷ 2) @test -x == -2 @test hash(x) == hash(2) @test zero(x) == false