diff --git a/compute.go b/compute.go index 283ab65..353fa0f 100644 --- a/compute.go +++ b/compute.go @@ -4,11 +4,26 @@ package stl4go // returns the result as type R, this is useful when T is too small to hold the result. // Complexity: O(len(a)). func SumAs[R, T Numeric](a []T) R { - var total R - for _, v := range a { - total += R(v) + switch zero := T(0); any(zero).(type) { + case int8, int16, int32, int, int64: + var total int64 + for _, v := range a { + total += int64(v) + } + return R(total) + case uint8, uint16, uint32, uint, uint64, uintptr: + var total uint64 + for _, v := range a { + total += uint64(v) + } + return R(total) + default: + var total float64 + for _, v := range a { + total += float64(v) + } + return R(total) } - return total } // Sum summarize all elements in a. diff --git a/compute_test.go b/compute_test.go index b09e85f..7fca033 100644 --- a/compute_test.go +++ b/compute_test.go @@ -1,10 +1,29 @@ package stl4go -import "testing" +import ( + "testing" +) func Test_SumAs(t *testing.T) { - a := Range[uint8](1, 101) - expectEq(t, SumAs[int](a), 5050) + t.Run("sum uint8 to int", func(t *testing.T) { + a := Range[uint8](1, 101) + expectEq(t, SumAs[int](a), 5050) + }) + + t.Run("sum int to uint8", func(t *testing.T) { + a := Range[int](1, 101) + expectEq(t, SumAs[uint8](a), uint8(5050%256)) + }) + + t.Run("sum int64 to float64", func(t *testing.T) { + a := Range[int64](1, 101) + expectEq(t, SumAs[float64](a), 5050.) + }) + + t.Run("sum float64 to int64", func(t *testing.T) { + a := Range[float64](1.1, 101.1) + expectEq(t, SumAs[int](a), 101.2*50) // 5060 + }) } func Test_Sum(t *testing.T) {