Skip to content

Commit

Permalink
fix(metadata): mutex for concurrency
Browse files Browse the repository at this point in the history
  • Loading branch information
just1not2 committed Dec 11, 2024
1 parent b71d9de commit fa7f73d
Show file tree
Hide file tree
Showing 5 changed files with 67 additions and 27 deletions.
4 changes: 2 additions & 2 deletions example/cmd/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ func dumpMetadata(ctx context.Context) {
if !ok {
panic("no metadata")
}
if err := json.NewEncoder(os.Stdout).Encode(md); err != nil {
if err := json.NewEncoder(os.Stdout).Encode(md.GetCopy()); err != nil {
panic(err)
}
}
Expand Down Expand Up @@ -109,7 +109,7 @@ func client() error {
}

ctx := context.Background()
md := ttrpc.MD{}
md := ttrpc.NewMD(make(map[string][]string))
md.Set("name", "koye")
ctx = ttrpc.WithMetadata(ctx, md)

Expand Down
74 changes: 57 additions & 17 deletions metadata.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,16 +19,30 @@ package ttrpc
import (
"context"
"strings"
"sync"
)

// MD is the user type for ttrpc metadata
type MD map[string][]string
type MD struct {
data map[string][]string
mu sync.RWMutex
}

// NewMD creates a metadata object.
func NewMD(data map[string][]string) *MD {
return &MD{
data: data,
}
}

// Get returns the metadata for a given key when they exist.
// If there is no metadata, a nil slice and false are returned.
func (m MD) Get(key string) ([]string, bool) {
func (m *MD) Get(key string) ([]string, bool) {
m.mu.RLock()
defer m.mu.RUnlock()

key = strings.ToLower(key)
list, ok := m[key]
list, ok := m.data[key]
if !ok || len(list) == 0 {
return nil, false
}
Expand All @@ -39,31 +53,54 @@ func (m MD) Get(key string) ([]string, bool) {
// Set sets the provided values for a given key.
// The values will overwrite any existing values.
// If no values provided, a key will be deleted.
func (m MD) Set(key string, values ...string) {
func (m *MD) Set(key string, values ...string) {
m.mu.Lock()
defer m.mu.Unlock()

key = strings.ToLower(key)
if len(values) == 0 {
delete(m, key)
delete(m.data, key)
return
}
m[key] = values
m.data[key] = values
}

// Append appends additional values to the given key.
func (m MD) Append(key string, values ...string) {
func (m *MD) Append(key string, values ...string) {
key = strings.ToLower(key)
if len(values) == 0 {
return
}
current, ok := m[key]

m.mu.Lock()
defer m.mu.Unlock()

current, ok := m.data[key]
if ok {
m.Set(key, append(current, values...)...)
m.data[key] = append(current, values...)
} else {
m.Set(key, values...)
m.data[key] = values
}
}

func (m MD) setRequest(r *Request) {
for k, values := range m {
// GetCopy returns the metadata for a given key when they exist.
func (m *MD) GetCopy() map[string][]string {
m.mu.RLock()
defer m.mu.RUnlock()

mCopy := make(map[string][]string, len(m.data))
for key, value := range m.data {
mCopy[key] = value
}

return mCopy
}

func (m *MD) setRequest(r *Request) {
m.mu.RLock()
defer m.mu.RUnlock()

for k, values := range m.data {
for _, v := range values {
r.Metadata = append(r.Metadata, &KeyValue{
Key: k,
Expand All @@ -73,17 +110,20 @@ func (m MD) setRequest(r *Request) {
}
}

func (m MD) fromRequest(r *Request) {
func (m *MD) fromRequest(r *Request) {
m.mu.Lock()
defer m.mu.Unlock()

for _, kv := range r.Metadata {
m[kv.Key] = append(m[kv.Key], kv.Value)
m.data[kv.Key] = append(m.data[kv.Key], kv.Value)
}
}

type metadataKey struct{}

// GetMetadata retrieves metadata from context.Context (previously attached with WithMetadata)
func GetMetadata(ctx context.Context) (MD, bool) {
metadata, ok := ctx.Value(metadataKey{}).(MD)
func GetMetadata(ctx context.Context) (*MD, bool) {
metadata, ok := ctx.Value(metadataKey{}).(*MD)
return metadata, ok
}

Expand All @@ -102,6 +142,6 @@ func GetMetadataValue(ctx context.Context, name string) (string, bool) {
}

// WithMetadata attaches metadata map to a context.Context
func WithMetadata(ctx context.Context, md MD) context.Context {
func WithMetadata(ctx context.Context, md *MD) context.Context {
return context.WithValue(ctx, metadataKey{}, md)
}
12 changes: 6 additions & 6 deletions metadata_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ import (
)

func TestMetadataGet(t *testing.T) {
metadata := make(MD)
metadata := NewMD(make(map[string][]string))
metadata.Set("foo", "1", "2")

if list, ok := metadata.Get("foo"); !ok {
Expand All @@ -37,7 +37,7 @@ func TestMetadataGet(t *testing.T) {
}

func TestMetadataGetInvalidKey(t *testing.T) {
metadata := make(MD)
metadata := NewMD(make(map[string][]string))
metadata.Set("foo", "1", "2")

if _, ok := metadata.Get("invalid"); ok {
Expand All @@ -46,7 +46,7 @@ func TestMetadataGetInvalidKey(t *testing.T) {
}

func TestMetadataUnset(t *testing.T) {
metadata := make(MD)
metadata := NewMD(make(map[string][]string))
metadata.Set("foo", "1", "2")
metadata.Set("foo")

Expand All @@ -56,7 +56,7 @@ func TestMetadataUnset(t *testing.T) {
}

func TestMetadataReplace(t *testing.T) {
metadata := make(MD)
metadata := NewMD(make(map[string][]string))
metadata.Set("foo", "1", "2")
metadata.Set("foo", "3", "4")

Expand All @@ -72,7 +72,7 @@ func TestMetadataReplace(t *testing.T) {
}

func TestMetadataAppend(t *testing.T) {
metadata := make(MD)
metadata := NewMD(make(map[string][]string))
metadata.Set("foo", "1")
metadata.Append("foo", "2")
metadata.Append("bar", "3")
Expand All @@ -95,7 +95,7 @@ func TestMetadataAppend(t *testing.T) {
}

func TestMetadataContext(t *testing.T) {
metadata := make(MD)
metadata := NewMD(make(map[string][]string))
metadata.Set("foo", "bar")

ctx := WithMetadata(context.Background(), metadata)
Expand Down
2 changes: 1 addition & 1 deletion server.go
Original file line number Diff line number Diff line change
Expand Up @@ -571,7 +571,7 @@ var noopFunc = func() {}

func getRequestContext(ctx context.Context, req *Request) (retCtx context.Context, cancel func()) {
if len(req.Metadata) > 0 {
md := MD{}
md := NewMD(make(map[string][]string))
md.fromRequest(req)
ctx = WithMetadata(ctx, md)
}
Expand Down
2 changes: 1 addition & 1 deletion server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -550,7 +550,7 @@ func roundTrip(ctx context.Context, client *testingClient, name string) callResu
}
)

ctx = WithMetadata(ctx, MD{"foo": []string{name}})
ctx = WithMetadata(ctx, NewMD(map[string][]string{"foo": {name}}))

resp, err := client.Test(ctx, tp)
if err != nil {
Expand Down

0 comments on commit fa7f73d

Please sign in to comment.