diff --git a/err_impl.go b/err_impl.go index a68e51c..f8ef34b 100644 --- a/err_impl.go +++ b/err_impl.go @@ -1,7 +1,7 @@ package errors import ( - "errors" + "cmp" ) func newE(opts ...any) error { @@ -9,6 +9,9 @@ func newE(opts ...any) error { skipFrames := FrameSkips(3) for _, o := range opts { + if o == nil { + continue + } switch arg := o.(type) { case string: err.msg = arg @@ -83,19 +86,22 @@ func (err *e) Fields() []any { out []any kind Kind ) - for err := error(err); err != nil; err = errors.Unwrap(err) { - ee, ok := err.(*e) - if !ok { - continue - } - - for _, kv := range ee.kvs { + for err := error(err); err != nil; err = Unwrap(err) { + em := getErrMeta(err) + for _, kv := range em.kvs { out = append(out, kv.K, kv.V) } - if kind == "" { - kind = ee.kind + kind = cmp.Or(kind, em.kind) + if ej, ok := err.(*joinE); ok { + innerKind, multiErrFields := ej.subErrFields() + kind = cmp.Or(kind, innerKind) + if len(multiErrFields) > 0 { + out = append(out, "multi_err", multiErrFields) + } + break } } + if kind != "" { out = append(out, "err_kind", string(kind)) } @@ -110,18 +116,18 @@ func (err *e) Fields() []any { return out } +func (err *e) Is(target error) bool { + kind, ok := target.(Kind) + return ok && err.kind == kind +} + func (err *e) Unwrap() error { return err.wrappedErr } func (err *e) V(key string) (any, bool) { - for err := error(err); err != nil; err = errors.Unwrap(err) { - ee, ok := err.(*e) - if !ok { - continue - } - - for _, kv := range ee.kvs { + for err := error(err); err != nil; err = Unwrap(err) { + for _, kv := range getErrMeta(err).kvs { if kv.K == key { return kv.V, true } @@ -132,10 +138,47 @@ func (err *e) V(key string) (any, bool) { func (err *e) stackTrace() StackFrames { var out StackFrames - for err := error(err); err != nil; err = errors.Unwrap(err) { - if ee, ok := err.(*e); ok && ee.frame.FilePath != "" { - out = append(out, ee.frame) + for err := error(err); err != nil; err = Unwrap(err) { + em := getErrMeta(err) + if em.frame.FilePath == "" { + continue + } + out = append(out, em.frame) + if em.errType == errTypeJoin { + break } } return out } + +type errMeta struct { + kind Kind + frame Frame + kvs []KV + errType string +} + +const ( + errTypeE = "e" + errTypeJoin = "j" +) + +func getErrMeta(err error) errMeta { + var em errMeta + switch err := err.(type) { + case *e: + em.kind, em.frame, em.kvs, em.errType = err.kind, err.frame, err.kvs, errTypeE + case *joinE: + em.kind, em.frame, em.kvs, em.errType = err.kind, err.frame, err.kvs, errTypeJoin + } + return em +} + +func getKind(err error) Kind { + for ; err != nil; err = Unwrap(err) { + if em := getErrMeta(err); em.kind != "" { + return em.kind + } + } + return "" +} diff --git a/errors.go b/errors.go index 8b21784..7eef3d0 100644 --- a/errors.go +++ b/errors.go @@ -23,6 +23,16 @@ func Wrap(err error, opts ...any) error { return newE(passedOpts...) } +// Join returns a new multi error. +// +// TODO: +// - play with Join(opts ...any) and Join(errs []error, opts ...any) sigs +// and ask for feedback regarding tradeoffs with type safety of first arg. +// As of writing some tests, I kind of dig the loose Join(opts ...any). +func Join(opts ...any) error { + return newJoinE(opts...) +} + // Fields returns logging fields for a given error. func Fields(err error) []any { if err == nil { @@ -41,8 +51,10 @@ func Fields(err error) []any { // TODO: // 1. make this more robust with Is // 2. determine if its even worth exposing an accessor for this private method +// 3. allow for StackTraces() to accommodate joined errors, perhaps returning a map[string]StackFrames +// or some graph representation would be awesome. func StackTrace(err error) StackFrames { - ee, ok := err.(*e) + ee, ok := err.(interface{ stackTrace() StackFrames }) if !ok { return nil } diff --git a/errors_stack_traces_test.go b/errors_stack_traces_test.go index 0fd5812..7b44dc5 100644 --- a/errors_stack_traces_test.go +++ b/errors_stack_traces_test.go @@ -181,6 +181,43 @@ func Test_Errors(t *testing.T) { }, }, }, + { + name: "with wrapped joined errors error with inner kind", + input: errors.Wrap( + errors.Wrap( + errors.Join( + errors.Wrap( + errors.New("first error", errors.Kind("inner")), + ), + errors.New("second error"), + ), + ), + ), + want: wants{ + msg: `2 errors occurred: + * first error + * second error +`, + fields: []any{ + "multi_err", []any{ + "err_0", []any{ + "err_kind", "inner", + "stack_trace", []string{ + "github.com/jsteenb2/errors/errors_stack_traces_test.go:189[Test_Errors]", + "github.com/jsteenb2/errors/errors_stack_traces_test.go:190[Test_Errors]", + }, + }, + "err_1", []any{"stack_trace", []string{"github.com/jsteenb2/errors/errors_stack_traces_test.go:192[Test_Errors]"}}, + }, + "err_kind", "inner", + "stack_trace", []string{ + "github.com/jsteenb2/errors/errors_stack_traces_test.go:186[Test_Errors]", + "github.com/jsteenb2/errors/errors_stack_traces_test.go:187[Test_Errors]", + "github.com/jsteenb2/errors/errors_stack_traces_test.go:188[Test_Errors]", + }, + }, + }, + }, } for _, tt := range tests { diff --git a/errors_test.go b/errors_test.go index e8352d6..4c0efbd 100644 --- a/errors_test.go +++ b/errors_test.go @@ -1,6 +1,7 @@ package errors_test import ( + "encoding/json" "reflect" "testing" @@ -77,6 +78,13 @@ func eqV[T comparable](t *testing.T, err error, key string, want T) bool { func eqFields(t *testing.T, want, got []any) bool { t.Helper() + defer func() { + if t.Failed() { + b, _ := json.MarshalIndent(got, "", " ") + t.Logf("got: %s", string(b)) + } + }() + if matches := eqLen(t, len(want), got); !matches { return matches } diff --git a/join.go b/join.go new file mode 100644 index 0000000..40fa792 --- /dev/null +++ b/join.go @@ -0,0 +1,234 @@ +package errors + +import ( + "cmp" + "errors" + "fmt" + "strings" +) + +func newJoinE(opts ...any) error { + var ( + baseOpts = make([]any, 1, len(opts)+1) + errs []error + formatFn = listFormatFn + ) + // since we're calling newE from 3 frames away instead of 2 + baseOpts[0] = SkipCaller + + // here we'll make use of a split loop, so that we aren't + // polluting the newE with multi-err concerns it does not + // need to be bothered with. + for _, o := range opts { + if o == nil { + continue + } + switch v := o.(type) { + case error: + errs = append(errs, v) + case []error: + errs = append(errs, v...) + case JoinFormatFn: + if v != nil { + formatFn = v + } + default: + baseOpts = append(baseOpts, o) + } + } + if len(errs) == 0 { + return nil + } + + ee := newE(baseOpts...).(*e) + return &joinE{ + msg: ee.msg, + formatFn: formatFn, + frame: ee.frame, + kind: ee.kind, + errs: errs, + kvs: ee.kvs, + } +} + +type joinE struct { + msg string + + formatFn JoinFormatFn + frame Frame + kind Kind + errs []error + + // TODO: + // 1. should kvs be a map instead? aka unique by key name? + // * if unique by name... what to do with collisions, last write wins? combine values into slice? + // or have some other way to signal what to do with collisions via an additional option? + // 2. if slice of KVs, do we separate the stack frames from the output when + // calling something like Meta/Fields on the error? Then have a specific + // function for getting the logging fields (i.e. everything to []any) + kvs []KV +} + +func (err *joinE) Error() string { + return err.formatFn(err.msg, err.errs) +} + +func (err *joinE) Fields() []any { + var ( + out []any + kind = err.kind + ) + for _, kv := range err.kvs { + out = append(out, kv.K, kv.V) + } + + innerKind, subErrFields := err.subErrFields() + kind = cmp.Or(kind, innerKind) + if kind != "" { + out = append(out, "err_kind", string(kind)) + } + if stackFrames := err.stackTrace(); len(stackFrames) > 0 { + var simplified []string + for _, frame := range stackFrames { + simplified = append(simplified, frame.String()) + } + out = append(out, "stack_trace", simplified) + } + for _, v := range subErrFields { + out = append(out, v) + } + + return out +} + +func (err *joinE) subErrFields() (Kind, []any) { + var ( + kind Kind + subErrFields []any + ) + for i, err := range err.errs { + var errFields []any + switch err := err.(type) { + case *e: + errFields = err.Fields() + case *joinE: + errFields = err.Fields() + } + if len(errFields) > 0 { + subErrFields = append(subErrFields, fmt.Sprintf("err_%d", i), errFields) + } + if innerKind := getKind(err); kind == "" && innerKind != "" { + kind = innerKind + } + } + return kind, subErrFields +} + +func (err *joinE) stackTrace() StackFrames { + if err.frame.FilePath == "" { + return nil + } + return StackFrames{err.frame} +} + +// Unwrap returns an error from Error (or nil if there are no errors). +// This error returned will further support Unwrap to get the next error, +// etc. The order will match the order of errors provided when calling Join. +// +// The resulting error supports errors.As/Is/Unwrap so you can continue +// to use the stdlib errors package to introspect further. +// +// The is borrowed from hashi/go-multierror module. +func (err *joinE) Unwrap() error { + if err == nil || len(err.errs) == 0 { + return nil + } + + if len(err.errs) == 1 { + return err.errs[0] + } + + // Shallow copy the slice + errs := make([]error, len(err.errs)) + copy(errs, err.errs) + return chain(errs) +} + +// chain implements the interfaces necessary for errors.Is/As/Unwrap to +// work in a deterministic way with multierror. A chain tracks a list of +// errors while accounting for the current represented error. This lets +// Is/As be meaningful. +// +// Unwrap returns the next error. In the cleanest form, Unwrap would return +// the wrapped error here but we can't do that if we want to properly +// get access to all the errors. Instead, users are recommended to use +// Is/As to get the correct error type out. +// +// Precondition: []error is non-empty (len > 0) +// +// TODO: +// - add support for Fields +// - add support stack trace +// - question is, do we make these show fields/stack trace for +// each individual error similar to how the Unwrapping is forcing +// users to interact with the unwrapped Join error, or make it list +// all fields/stack traces (not sure what stack trace would look like here)? +type chain []error + +// Error implements the error interface +func (e chain) Error() string { + return e[0].Error() +} + +func (e chain) Fields() []any { + fielder, ok := e[0].(interface{ Fields() []any }) + if !ok { + return nil + } + return fielder.Fields() +} + +func (e chain) stackTrace() StackFrames { + st, ok := e[0].(interface{ stackTrace() StackFrames }) + if !ok { + return nil + } + return st.stackTrace() +} + +// Unwrap implements errors.Unwrap by returning the next error in the +// chain or nil if there are no more errors. +func (e chain) Unwrap() error { + if len(e) == 1 { + return nil + } + + return e[1:] +} + +// As implements errors.As by attempting to map to the current value. +func (e chain) As(target interface{}) bool { + return errors.As(e[0], target) +} + +// Is implements errors.Is by comparing the current value directly. +func (e chain) Is(target error) bool { + return errors.Is(e[0], target) +} + +// listFormatFn borrowed from hashi go-multierror module. +func listFormatFn(msg string, errs []error) string { + if msg == "" && len(errs) == 1 { + return fmt.Sprintf("1 error occurred:\n\t* %s\n", errs[0]) + } + + points := make([]string, len(errs)) + for i, err := range errs { + points[i] = fmt.Sprintf("* %s", err) + } + + if msg == "" { + msg = fmt.Sprintf("%d errors occurred:\n\t", len(errs)) + } + return fmt.Sprintf("%s%s\n", msg, strings.Join(points, "\n\t")) +} diff --git a/join_test.go b/join_test.go new file mode 100644 index 0000000..7dc9426 --- /dev/null +++ b/join_test.go @@ -0,0 +1,105 @@ +package errors_test + +import ( + "fmt" + "testing" + + "github.com/jsteenb2/errors" +) + +var sentinelErr = fmt.Errorf("sentinel err") + +func TestJoin(t *testing.T) { + t.Run("single error joined error can be unwrapped", func(t *testing.T) { + err := errors.Join(errors.New("first multi error")) + + gotMsg := err.Error() + eq(t, "1 error occurred:\n\t* first multi error\n", gotMsg) + + unwrappedErr := errors.Unwrap(err) + if unwrappedErr == nil { + t.Fatal("unexpected nil unwrapped error") + } + + gotMsg = unwrappedErr.Error() + eq(t, "first multi error", gotMsg) + }) + + t.Run("multiple joined errors can be unwrapped", func(t *testing.T) { + err := errors.Join( + errors.New("err 1"), + errors.New("err 2"), + ) + + wantMsg := `2 errors occurred: + * err 1 + * err 2 +` + eq(t, wantMsg, err.Error()) + + unwrappedErr := errors.Unwrap(err) + if unwrappedErr == nil { + t.Fatal("unexpected nil unwrapped error") + } + eq(t, "err 1", unwrappedErr.Error()) + + unwrappedErr = errors.Unwrap(unwrappedErr) + if unwrappedErr == nil { + t.Fatal("unexpected nil unwrapped error") + } + eq(t, "err 2", unwrappedErr.Error()) + }) + + t.Run("multiple joined errors can be used with Is and As", func(t *testing.T) { + err := errors.Join( + errors.New("err 1", errors.Kind("foo")), + sentinelErr, + ) + + wantMsg := `2 errors occurred: + * err 1 + * sentinel err +` + eq(t, wantMsg, err.Error()) + + if !errors.Is(err, sentinelErr) { + t.Errorf("failed to identify sentinel error") + } + if !errors.Is(err, errors.Kind("foo")) { + t.Error("failed to find matching kind error") + } + }) + + t.Run("multiple joined errors can be used with Fields", func(t *testing.T) { + err := errors.Join( + errors.New("err 1", errors.Kind("foo"), errors.KVs("ki1", "vi1")), + sentinelErr, + errors.New("err 3", errors.KVs("ki3", "vi3")), + errors.Join( + errors.New("err 4"), + ), + errors.KVs("kj1", "vj1"), + ) + wantFields := []any{ + // parent Join error + "kj1", "vj1", "err_kind", "foo", "stack_trace", []string{"github.com/jsteenb2/errors/join_test.go:74[TestJoin.func4]"}, + // first err + "err_0", []any{"ki1", "vi1", "err_kind", "foo", "stack_trace", []string{"github.com/jsteenb2/errors/join_test.go:75[TestJoin.func4]"}}, + // third err + "err_2", []any{"ki3", "vi3", "stack_trace", []string{"github.com/jsteenb2/errors/join_test.go:77[TestJoin.func4]"}}, + // fourth err + "err_3", []any{ + "stack_trace", []string{"github.com/jsteenb2/errors/join_test.go:78[TestJoin.func4]"}, + "err_0", []any{"stack_trace", []string{"github.com/jsteenb2/errors/join_test.go:79[TestJoin.func4]"}}, + }, + } + eqFields(t, wantFields, errors.Fields(err)) + + unwrapped := errors.Unwrap(err) + wantFields = []any{"ki1", "vi1", "err_kind", "foo", "stack_trace", []string{"github.com/jsteenb2/errors/join_test.go:75[TestJoin.func4]"}} + eqFields(t, wantFields, errors.Fields(unwrapped)) + + sentinelUnwrapped := errors.Unwrap(unwrapped) + eqFields(t, nil, errors.Fields(sentinelUnwrapped)) + }) +} diff --git a/options.go b/options.go index 6d7204d..e529292 100644 --- a/options.go +++ b/options.go @@ -19,6 +19,10 @@ const ( SkipCaller FrameSkips = 1 ) +// JoinFormatFn is the join errors formatter. This allows the user to customize +// the text output when calling Error() on the join error. +type JoinFormatFn func(msg string, errs []error) string + // Kind represents the category of the error type. A few examples of // error kinds are as follows: // @@ -40,7 +44,7 @@ const ( // target error is of kind "first": // // err := errors.New("some error", errors.Kind("first")) -// stderrors.Is(errors.Kind("first"), err) // output is true +// errors.Is(err, errors.Kind("first")) // output is true type Kind string // Error returns the error string indicating the kind's error. This is