diff --git a/pkg/utils/grpcutil/grpcutil.go b/pkg/utils/grpcutil/grpcutil.go index 9b8cc2feb49..5633533ae4a 100644 --- a/pkg/utils/grpcutil/grpcutil.go +++ b/pkg/utils/grpcutil/grpcutil.go @@ -186,13 +186,9 @@ func ResetForwardContext(ctx context.Context) context.Context { // GetForwardedHost returns the forwarded host in metadata. func GetForwardedHost(ctx context.Context) string { - md, ok := metadata.FromIncomingContext(ctx) - if !ok { - log.Debug("failed to get gRPC incoming metadata when getting forwarded host") - return "" - } - if t, ok := md[ForwardMetadataKey]; ok { - return t[0] + s := metadata.ValueFromIncomingContext(ctx, ForwardMetadataKey) + if len(s) > 0 { + return s[0] } return "" } diff --git a/pkg/utils/grpcutil/grpcutil_test.go b/pkg/utils/grpcutil/grpcutil_test.go index 2cbff4f3ebc..99cbeae6cde 100644 --- a/pkg/utils/grpcutil/grpcutil_test.go +++ b/pkg/utils/grpcutil/grpcutil_test.go @@ -1,6 +1,7 @@ package grpcutil import ( + "context" "os" "os/exec" "path" @@ -9,6 +10,7 @@ import ( "github.com/pingcap/errors" "github.com/stretchr/testify/require" "github.com/tikv/pd/pkg/errs" + "google.golang.org/grpc/metadata" ) var ( @@ -66,3 +68,14 @@ func TestToTLSConfig(t *testing.T) { _, err = tlsConfig.ToTLSConfig() re.True(errors.ErrorEqual(err, errs.ErrCryptoAppendCertsFromPEM)) } + +func BenchmarkGetForwardedHost(b *testing.B) { + // Without forwarded host key + md := metadata.Pairs("test", "example.com") + ctx := metadata.NewIncomingContext(context.Background(), md) + + // Run the GetForwardedHost function b.N times + for i := 0; i < b.N; i++ { + GetForwardedHost(ctx) + } +}