diff --git a/cmd/agent/app/options/options.go b/cmd/agent/app/options/options.go index a0e58f5dc..37d14ced7 100644 --- a/cmd/agent/app/options/options.go +++ b/cmd/agent/app/options/options.go @@ -92,7 +92,7 @@ func (o *GrpcProxyAgentOptions) Flags() *pflag.FlagSet { flags.DurationVar(&o.SyncIntervalCap, "sync-interval-cap", o.SyncIntervalCap, "The maximum interval for the SyncInterval to back off to when unable to connect to the proxy server") flags.DurationVar(&o.KeepaliveTime, "keepalive-time", o.KeepaliveTime, "Time for gRPC agent server keepalive.") flags.StringVar(&o.ServiceAccountTokenPath, "service-account-token-path", o.ServiceAccountTokenPath, "If non-empty proxy agent uses this token to prove its identity to the proxy server.") - flags.StringVar(&o.AgentIdentifiers, "agent-identifiers", o.AgentIdentifiers, "Identifiers of the agent that will be used by the server when choosing agent. N.B. the list of identifiers must be in URL encoded format. e.g.,host=localhost&host=node1.mydomain.com&cidr=127.0.0.1/16&ipv4=1.2.3.4&ipv4=5.6.7.8&ipv6=:::::&default-route=true") + flags.StringVar(&o.AgentIdentifiers, "agent-identifiers", o.AgentIdentifiers, "Identifiers of the agent that will be used by the server when choosing agent. N.B. the list of identifiers must be in URL encoded format. e.g.,host=localhost&host=node1.mydomain.com&cidr=127.0.0.1/16&ipv4=1.2.3.4&ipv4=5.6.7.8&ipv6=:::::&default-route=true. If a host starts with . it will be treated as wildcard") flags.BoolVar(&o.WarnOnChannelLimit, "warn-on-channel-limit", o.WarnOnChannelLimit, "Turns on a warning if the system is going to push to a full channel. The check involves an unsafe read.") flags.BoolVar(&o.SyncForever, "sync-forever", o.SyncForever, "If true, the agent continues syncing, in order to support server count changes.") return flags diff --git a/pkg/server/desthost_backend_manager.go b/pkg/server/desthost_backend_manager.go index 1d7dd89cd..73b08bb75 100644 --- a/pkg/server/desthost_backend_manager.go +++ b/pkg/server/desthost_backend_manager.go @@ -18,6 +18,7 @@ package server import ( "context" + "strings" "k8s.io/klog/v2" "sigs.k8s.io/apiserver-network-proxy/pkg/agent" @@ -49,6 +50,18 @@ func (dibm *DestHostBackendManager) Backend(ctx context.Context) (Backend, error klog.V(5).InfoS("Get the backend through the DestHostBackendManager", "destHost", destHost) return dibm.backends[destHost][0], nil } + + for backend := range dibm.backends { + // Ignore backends that do not have a leading dot after stripping a leading *, we don't want foo.com to match for barfoo.com + if !strings.HasPrefix(backend, ".") { + continue + } + klog.V(5).Infof("Checking for wildcard match", "backend", backend, "destHost", destHost) + if strings.HasSuffix(destHost, backend) && len(dibm.backends[backend]) > 0 { + klog.V(5).InfoS("Get the backend through wildcardmatching in the DestHostBackendManager", "destHost", destHost, "backend", backend) + return dibm.backends[backend][0], nil + } + } } return nil, &ErrNotFound{} } diff --git a/pkg/server/desthost_backend_manager_test.go b/pkg/server/desthost_backend_manager_test.go new file mode 100644 index 000000000..7ac4e3bac --- /dev/null +++ b/pkg/server/desthost_backend_manager_test.go @@ -0,0 +1,94 @@ +package server + +import ( + "context" + "testing" + + "sigs.k8s.io/apiserver-network-proxy/proto/agent" +) + +type fakeAgent struct { + ctx context.Context + agent.AgentService_ConnectServer +} + +func (fa *fakeAgent) Context() context.Context { + return fa.ctx +} + +const contextNameKey key = iota + +func newNamedBackend(name string) *backend { + return &backend{ + conn: &fakeAgent{ + ctx: context.WithValue(context.Background(), contextNameKey, name), + }, + } +} + +func TestBackend(t *testing.T) { + + testCases := []struct { + name string + destHost string + backends map[string][]*backend + expectedBackendName string + }{ + { + name: "Literal match", + destHost: "sts.amazonaws.com", + backends: map[string][]*backend{ + "sts.amazonaws.com": {newNamedBackend("sts.amazonaws.com")}, + }, + expectedBackendName: "sts.amazonaws.com", + }, + { + name: "Wildcard match", + destHost: "sts.amazonaws.com", + backends: map[string][]*backend{ + ".amazonaws.com": {newNamedBackend(".amazonaws.com")}, + }, + expectedBackendName: ".amazonaws.com", + }, + { + name: "Both literal and wildcard match, literal match takes precedence", + destHost: "sts.amazonaws.com", + backends: map[string][]*backend{ + "sts.amazonaws.com": {newNamedBackend("sts.amazonaws.com")}, + ".amazonaws.com": {newNamedBackend(".amazonaws.com")}, + }, + expectedBackendName: "sts.amazonaws.com", + }, + { + name: "No wildcard match when entry doesn't start with dot", + destHost: "foo-bar.com", + backends: map[string][]*backend{ + "bar.com": {newNamedBackend("bar.com")}, + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + + request := context.WithValue(context.Background(), destHost, tc.destHost) + mgr := NewDestHostBackendManager() + mgr.backends = tc.backends + + matchExpected := tc.expectedBackendName != "" + match, err := mgr.Backend(request) + hadMatch := err == nil + if matchExpected != hadMatch { + t.Fatalf("expected a match: %t, got a match: %t", matchExpected, hadMatch) + } + if !matchExpected { + return + } + + matchName := match.Context().Value(contextNameKey).(string) + if matchName != tc.expectedBackendName { + t.Errorf("expected to get backend %s, got %s", tc.expectedBackendName, matchName) + } + }) + } +}