diff --git a/trace/exporter_type.go b/trace/exporter_type.go index 206731acc3bd..13337b4ebfd4 100644 --- a/trace/exporter_type.go +++ b/trace/exporter_type.go @@ -10,14 +10,22 @@ import ( ) const ( - GRPC ExporterType = iota + 1 + NoOp ExporterType = iota + GRPC HTTP ) -var errUnknownExporterType = errors.New("unknown exporter type") +var ( + errUnknownExporterType = errors.New("unknown exporter type") + errInvalidFormat = errors.New("invalid format") +) func ExporterTypeFromString(exporterTypeStr string) (ExporterType, error) { switch strings.ToLower(exporterTypeStr) { + case NoOp.String(): + return 0, nil + case "null": + return 0, nil case GRPC.String(): return GRPC, nil case HTTP.String(): @@ -33,8 +41,23 @@ func (t ExporterType) MarshalJSON() ([]byte, error) { return []byte(`"` + t.String() + `"`), nil } +func (t *ExporterType) UnmarshalJSON(b []byte) error { + exporterTypeStr, err := stripQuotes(string(b)) + if err != nil { + return err + } + exporterType, err := ExporterTypeFromString(exporterTypeStr) + if err != nil { + return err + } + *t = exporterType + return nil +} + func (t ExporterType) String() string { switch t { + case NoOp: + return "" case GRPC: return "grpc" case HTTP: @@ -43,3 +66,10 @@ func (t ExporterType) String() string { return "unknown" } } + +func stripQuotes(s string) (string, error) { + if len(s) < 2 || s[0] != '"' || s[len(s)-1] != '"' { + return "", errInvalidFormat + } + return s[1 : len(s)-1], nil +} diff --git a/trace/exporter_type_test.go b/trace/exporter_type_test.go new file mode 100644 index 000000000000..f2f5155abbec --- /dev/null +++ b/trace/exporter_type_test.go @@ -0,0 +1,96 @@ +// Copyright (C) 2019-2024, Ava Labs, Inc. All rights reserved. +// See the file LICENSE for licensing terms. + +package trace + +import ( + "encoding/json" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestMarshalUnmarshal(t *testing.T) { + tests := []struct { + Name string + ExporterType ExporterType + ExpectedErr error + }{ + { + Name: "GRPC", + ExporterType: GRPC, + }, + { + Name: "HTTP", + ExporterType: HTTP, + }, + { + Name: "NoOp", + ExporterType: NoOp, + }, + } + + for _, tt := range tests { + t.Run(tt.Name, func(t *testing.T) { + require := require.New(t) + + b, err := json.Marshal(tt.ExporterType) + require.NoError(err) + + var et ExporterType + require.NoError(json.Unmarshal(b, &et)) + + require.Equal(tt.ExporterType, et) + }) + } +} + +func TestUnmarshal(t *testing.T) { + tests := []struct { + Name string + Str string + ExpectedError error + }{ + { + Name: "NoQuotes", + Str: "grpc", + ExpectedError: errInvalidFormat, + }, + { + Name: "SingleLeftQuote", + Str: "\"grpc", + ExpectedError: errInvalidFormat, + }, + { + Name: "SingleRightQuote", + Str: "grpc\"", + ExpectedError: errInvalidFormat, + }, + { + Name: "MultipleQuotes", + Str: "\"\"grpc\"\"\"", + ExpectedError: errUnknownExporterType, + }, + { + Name: "NullString", + Str: "\"null\"", + ExpectedError: nil, + }, + { + Name: "EmptyString", + Str: "\"\"", + ExpectedError: nil, + }, + } + + for _, tt := range tests { + t.Run(tt.Name, func(t *testing.T) { + require := require.New(t) + + var et ExporterType + + err := et.UnmarshalJSON([]byte(tt.Str)) + require.ErrorIs(err, tt.ExpectedError) + }) + } +}