From 58d813f803c27b421c1e6ed47f4987ff3c0daefd Mon Sep 17 00:00:00 2001 From: Sasha Melentyev Date: Fri, 11 Aug 2023 20:14:44 +0300 Subject: [PATCH] feat: add join --- go.mod | 2 +- join.go | 59 +++++++++++++++++++++++++++++++++++++++++++++++++++ join_go120.go | 8 +++++++ join_test.go | 34 +++++++++++++++++++++++++++++ 4 files changed, 102 insertions(+), 1 deletion(-) create mode 100644 join.go create mode 100644 join_go120.go create mode 100644 join_test.go diff --git a/go.mod b/go.mod index e421c82..5d65caa 100644 --- a/go.mod +++ b/go.mod @@ -1,3 +1,3 @@ module github.com/go-faster/errors -go 1.18 +go 1.20 diff --git a/join.go b/join.go new file mode 100644 index 0000000..d1c2a53 --- /dev/null +++ b/join.go @@ -0,0 +1,59 @@ +//go:build !go1.20 +// +build !go1.20 + +package errors + +import "unsafe" + +// Join returns an error that wraps the given errors. +// Any nil error values are discarded. +// Join returns nil if every value in errs is nil. +// The error formats as the concatenation of the strings obtained +// by calling the Error method of each element of errs, with a newline +// between each string. +// +// A non-nil error returned by Join implements the Unwrap() []error method. +func Join(errs ...error) error { + n := 0 + for _, err := range errs { + if err != nil { + n++ + } + } + if n == 0 { + return nil + } + e := &joinError{ + errs: make([]error, 0, n), + } + for _, err := range errs { + if err != nil { + e.errs = append(e.errs, err) + } + } + return e +} + +type joinError struct { + errs []error +} + +func (e *joinError) Error() string { + // Since Join returns nil if every value in errs is nil, + // e.errs cannot be empty. + if len(e.errs) == 1 { + return e.errs[0].Error() + } + + b := []byte(e.errs[0].Error()) + for _, err := range e.errs[1:] { + b = append(b, '\n') + b = append(b, err.Error()...) + } + // At this point, b has at least one byte '\n'. + return unsafe.String(&b[0], len(b)) +} + +func (e *joinError) Unwrap() []error { + return e.errs +} diff --git a/join_go120.go b/join_go120.go new file mode 100644 index 0000000..ecf930c --- /dev/null +++ b/join_go120.go @@ -0,0 +1,8 @@ +//go:build go1.20 +// +build go1.20 + +package errors + +import "errors" + +var Join = errors.Join diff --git a/join_test.go b/join_test.go new file mode 100644 index 0000000..bfc5667 --- /dev/null +++ b/join_test.go @@ -0,0 +1,34 @@ +package errors_test + +import ( + "reflect" + "testing" + + "github.com/go-faster/errors" +) + +func TestJoin(t *testing.T) { + err1 := errors.New("err1") + err2 := errors.New("err2") + for _, test := range []struct { + errs []error + want []error + }{{ + errs: []error{err1}, + want: []error{err1}, + }, { + errs: []error{err1, err2}, + want: []error{err1, err2}, + }, { + errs: []error{err1, nil, err2}, + want: []error{err1, err2}, + }} { + got := errors.Join(test.errs...).(interface{ Unwrap() []error }).Unwrap() + if !reflect.DeepEqual(got, test.want) { + t.Errorf("Join(%v) = %v; want %v", test.errs, got, test.want) + } + if len(got) != cap(got) { + t.Errorf("Join(%v) returns errors with len=%v, cap=%v; want len==cap", test.errs, len(got), cap(got)) + } + } +}