diff --git a/example/cmd/main.go b/example/cmd/main.go index 8c83c2c99..61c0503c4 100644 --- a/example/cmd/main.go +++ b/example/cmd/main.go @@ -66,7 +66,7 @@ func dumpMetadata(ctx context.Context) { if !ok { panic("no metadata") } - if err := json.NewEncoder(os.Stdout).Encode(md); err != nil { + if err := json.NewEncoder(os.Stdout).Encode(md.GetCopy()); err != nil { panic(err) } } @@ -109,7 +109,7 @@ func client() error { } ctx := context.Background() - md := ttrpc.MD{} + md := ttrpc.NewMD(make(map[string][]string)) md.Set("name", "koye") ctx = ttrpc.WithMetadata(ctx, md) diff --git a/metadata.go b/metadata.go index ce8c0d13c..aa58e3bfa 100644 --- a/metadata.go +++ b/metadata.go @@ -19,16 +19,30 @@ package ttrpc import ( "context" "strings" + "sync" ) // MD is the user type for ttrpc metadata -type MD map[string][]string +type MD struct { + data map[string][]string + mu sync.RWMutex +} + +// NewMD creates a metadata object. +func NewMD(data map[string][]string) *MD { + return &MD{ + data: data, + } +} // Get returns the metadata for a given key when they exist. // If there is no metadata, a nil slice and false are returned. -func (m MD) Get(key string) ([]string, bool) { +func (m *MD) Get(key string) ([]string, bool) { + m.mu.RLock() + defer m.mu.RUnlock() + key = strings.ToLower(key) - list, ok := m[key] + list, ok := m.data[key] if !ok || len(list) == 0 { return nil, false } @@ -39,31 +53,54 @@ func (m MD) Get(key string) ([]string, bool) { // Set sets the provided values for a given key. // The values will overwrite any existing values. // If no values provided, a key will be deleted. -func (m MD) Set(key string, values ...string) { +func (m *MD) Set(key string, values ...string) { + m.mu.Lock() + defer m.mu.Unlock() + key = strings.ToLower(key) if len(values) == 0 { - delete(m, key) + delete(m.data, key) return } - m[key] = values + m.data[key] = values } // Append appends additional values to the given key. -func (m MD) Append(key string, values ...string) { +func (m *MD) Append(key string, values ...string) { key = strings.ToLower(key) if len(values) == 0 { return } - current, ok := m[key] + + m.mu.Lock() + defer m.mu.Unlock() + + current, ok := m.data[key] if ok { - m.Set(key, append(current, values...)...) + m.data[key] = append(current, values...) } else { - m.Set(key, values...) + m.data[key] = values } } -func (m MD) setRequest(r *Request) { - for k, values := range m { +// GetCopy returns the metadata for a given key when they exist. +func (m *MD) GetCopy() map[string][]string { + m.mu.RLock() + defer m.mu.RUnlock() + + mCopy := make(map[string][]string, len(m.data)) + for key, value := range m.data { + mCopy[key] = value + } + + return mCopy +} + +func (m *MD) setRequest(r *Request) { + m.mu.RLock() + defer m.mu.RUnlock() + + for k, values := range m.data { for _, v := range values { r.Metadata = append(r.Metadata, &KeyValue{ Key: k, @@ -73,17 +110,20 @@ func (m MD) setRequest(r *Request) { } } -func (m MD) fromRequest(r *Request) { +func (m *MD) fromRequest(r *Request) { + m.mu.Lock() + defer m.mu.Unlock() + for _, kv := range r.Metadata { - m[kv.Key] = append(m[kv.Key], kv.Value) + m.data[kv.Key] = append(m.data[kv.Key], kv.Value) } } type metadataKey struct{} // GetMetadata retrieves metadata from context.Context (previously attached with WithMetadata) -func GetMetadata(ctx context.Context) (MD, bool) { - metadata, ok := ctx.Value(metadataKey{}).(MD) +func GetMetadata(ctx context.Context) (*MD, bool) { + metadata, ok := ctx.Value(metadataKey{}).(*MD) return metadata, ok } @@ -102,6 +142,6 @@ func GetMetadataValue(ctx context.Context, name string) (string, bool) { } // WithMetadata attaches metadata map to a context.Context -func WithMetadata(ctx context.Context, md MD) context.Context { +func WithMetadata(ctx context.Context, md *MD) context.Context { return context.WithValue(ctx, metadataKey{}, md) } diff --git a/metadata_test.go b/metadata_test.go index d7fc09559..d1e932cdf 100644 --- a/metadata_test.go +++ b/metadata_test.go @@ -22,7 +22,7 @@ import ( ) func TestMetadataGet(t *testing.T) { - metadata := make(MD) + metadata := NewMD(make(map[string][]string)) metadata.Set("foo", "1", "2") if list, ok := metadata.Get("foo"); !ok { @@ -37,7 +37,7 @@ func TestMetadataGet(t *testing.T) { } func TestMetadataGetInvalidKey(t *testing.T) { - metadata := make(MD) + metadata := NewMD(make(map[string][]string)) metadata.Set("foo", "1", "2") if _, ok := metadata.Get("invalid"); ok { @@ -46,7 +46,7 @@ func TestMetadataGetInvalidKey(t *testing.T) { } func TestMetadataUnset(t *testing.T) { - metadata := make(MD) + metadata := NewMD(make(map[string][]string)) metadata.Set("foo", "1", "2") metadata.Set("foo") @@ -56,7 +56,7 @@ func TestMetadataUnset(t *testing.T) { } func TestMetadataReplace(t *testing.T) { - metadata := make(MD) + metadata := NewMD(make(map[string][]string)) metadata.Set("foo", "1", "2") metadata.Set("foo", "3", "4") @@ -72,7 +72,7 @@ func TestMetadataReplace(t *testing.T) { } func TestMetadataAppend(t *testing.T) { - metadata := make(MD) + metadata := NewMD(make(map[string][]string)) metadata.Set("foo", "1") metadata.Append("foo", "2") metadata.Append("bar", "3") @@ -95,7 +95,7 @@ func TestMetadataAppend(t *testing.T) { } func TestMetadataContext(t *testing.T) { - metadata := make(MD) + metadata := NewMD(make(map[string][]string)) metadata.Set("foo", "bar") ctx := WithMetadata(context.Background(), metadata) diff --git a/server.go b/server.go index bb71de677..50b0f34c4 100644 --- a/server.go +++ b/server.go @@ -571,7 +571,7 @@ var noopFunc = func() {} func getRequestContext(ctx context.Context, req *Request) (retCtx context.Context, cancel func()) { if len(req.Metadata) > 0 { - md := MD{} + md := NewMD(make(map[string][]string)) md.fromRequest(req) ctx = WithMetadata(ctx, md) } diff --git a/server_test.go b/server_test.go index 4a1561df8..537418125 100644 --- a/server_test.go +++ b/server_test.go @@ -550,7 +550,7 @@ func roundTrip(ctx context.Context, client *testingClient, name string) callResu } ) - ctx = WithMetadata(ctx, MD{"foo": []string{name}}) + ctx = WithMetadata(ctx, NewMD(map[string][]string{"foo": {name}})) resp, err := client.Test(ctx, tp) if err != nil {