Skip to content

Commit

Permalink
Replace Backend interface with a struct.
Browse files Browse the repository at this point in the history
Comparisons (like those in DefaultBackendStorage) should be by pointer value.
  • Loading branch information
jkh52 committed Apr 12, 2024
1 parent 457af6a commit 6ebfd2b
Show file tree
Hide file tree
Showing 6 changed files with 57 additions and 67 deletions.
54 changes: 22 additions & 32 deletions pkg/server/backend_manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,17 +77,7 @@ func GenProxyStrategiesFromStr(proxyStrategies string) ([]ProxyStrategy, error)
// In the only currently supported case (gRPC), it wraps an
// agent.AgentService_ConnectServer, provides synchronization and
// emits common stream metrics.
type Backend interface {
Send(p *client.Packet) error
Recv() (*client.Packet, error)
Context() context.Context
GetAgentID() string
GetAgentIdentifiers() header.Identifiers
}

var _ Backend = &backend{}

type backend struct {
type Backend struct {
sendLock sync.Mutex
recvLock sync.Mutex
conn agent.AgentService_ConnectServer
Expand All @@ -97,7 +87,7 @@ type backend struct {
idents header.Identifiers
}

func (b *backend) Send(p *client.Packet) error {
func (b *Backend) Send(p *client.Packet) error {
b.sendLock.Lock()
defer b.sendLock.Unlock()

Expand All @@ -110,7 +100,7 @@ func (b *backend) Send(p *client.Packet) error {
return err
}

func (b *backend) Recv() (*client.Packet, error) {
func (b *Backend) Recv() (*client.Packet, error) {
b.recvLock.Lock()
defer b.recvLock.Unlock()

Expand All @@ -126,16 +116,16 @@ func (b *backend) Recv() (*client.Packet, error) {
return pkt, nil
}

func (b *backend) Context() context.Context {
func (b *Backend) Context() context.Context {
// TODO: does Context require lock protection?
return b.conn.Context()
}

func (b *backend) GetAgentID() string {
func (b *Backend) GetAgentID() string {
return b.id
}

func (b *backend) GetAgentIdentifiers() header.Identifiers {
func (b *Backend) GetAgentIdentifiers() header.Identifiers {
return b.idents
}

Expand Down Expand Up @@ -168,7 +158,7 @@ func getAgentIdentifiers(conn agent.AgentService_ConnectServer) (header.Identifi
return header.GenAgentIdentifiers(agentIdent[0])
}

func NewBackend(conn agent.AgentService_ConnectServer) (Backend, error) {
func NewBackend(conn agent.AgentService_ConnectServer) (*Backend, error) {
agentID, err := getAgentID(conn)
if err != nil {
return nil, err
Expand All @@ -177,16 +167,16 @@ func NewBackend(conn agent.AgentService_ConnectServer) (Backend, error) {
if err != nil {
return nil, err
}
return &backend{conn: conn, id: agentID, idents: agentIdentifiers}, nil
return &Backend{conn: conn, id: agentID, idents: agentIdentifiers}, nil
}

// BackendStorage is an interface to manage the storage of the backend
// connections, i.e., get, add and remove
type BackendStorage interface {
// addBackend adds a backend.
addBackend(identifier string, idType header.IdentifierType, backend Backend)
addBackend(identifier string, idType header.IdentifierType, backend *Backend)
// removeBackend removes a backend.
removeBackend(identifier string, idType header.IdentifierType, backend Backend)
removeBackend(identifier string, idType header.IdentifierType, backend *Backend)
// NumBackends returns the number of backends.
NumBackends() int
}
Expand All @@ -199,11 +189,11 @@ type BackendManager interface {
// context instead of a request-scoped context, as the backend manager will
// pick a backend for every tunnel session and each tunnel session may
// contains multiple requests.
Backend(ctx context.Context) (Backend, error)
Backend(ctx context.Context) (*Backend, error)
// AddBackend adds a backend.
AddBackend(backend Backend)
AddBackend(backend *Backend)
// RemoveBackend adds a backend.
RemoveBackend(backend Backend)
RemoveBackend(backend *Backend)
BackendStorage
ReadinessManager
}
Expand All @@ -215,18 +205,18 @@ type DefaultBackendManager struct {
*DefaultBackendStorage
}

func (dbm *DefaultBackendManager) Backend(_ context.Context) (Backend, error) {
func (dbm *DefaultBackendManager) Backend(_ context.Context) (*Backend, error) {
klog.V(5).InfoS("Get a random backend through the DefaultBackendManager")
return dbm.DefaultBackendStorage.GetRandomBackend()
}

func (dbm *DefaultBackendManager) AddBackend(backend Backend) {
func (dbm *DefaultBackendManager) AddBackend(backend *Backend) {
agentID := backend.GetAgentID()
klog.V(5).InfoS("Add the agent to DefaultBackendManager", "agentID", agentID)
dbm.addBackend(agentID, header.UID, backend)
}

func (dbm *DefaultBackendManager) RemoveBackend(backend Backend) {
func (dbm *DefaultBackendManager) RemoveBackend(backend *Backend) {
agentID := backend.GetAgentID()
klog.V(5).InfoS("Remove the agent from the DefaultBackendManager", "agentID", agentID)
dbm.removeBackend(agentID, header.UID, backend)
Expand All @@ -242,7 +232,7 @@ type DefaultBackendStorage struct {
//
// TODO: fix documentation. This is not always agentID, e.g. in
// the case of DestHostBackendManager.
backends map[string][]Backend
backends map[string][]*Backend
// agentID is tracked in this slice to enable randomly picking an
// agentID in the Backend() method. There is no reliable way to
// randomly pick a key from a map (in this case, the backends) in
Expand Down Expand Up @@ -272,7 +262,7 @@ func NewDefaultBackendStorage(idTypes []header.IdentifierType) *DefaultBackendSt
// no agent ever successfully connects.
metrics.Metrics.SetBackendCount(0)
return &DefaultBackendStorage{
backends: make(map[string][]Backend),
backends: make(map[string][]*Backend),
random: rand.New(rand.NewSource(time.Now().UnixNano())),
idTypes: idTypes,
} /* #nosec G404 */
Expand All @@ -283,7 +273,7 @@ func containIDType(idTypes []header.IdentifierType, idType header.IdentifierType
}

// addBackend adds a backend.
func (s *DefaultBackendStorage) addBackend(identifier string, idType header.IdentifierType, backend Backend) {
func (s *DefaultBackendStorage) addBackend(identifier string, idType header.IdentifierType, backend *Backend) {
if !containIDType(s.idTypes, idType) {
klog.V(4).InfoS("fail to add backend", "backend", identifier, "error", &ErrWrongIDType{idType, s.idTypes})
return
Expand All @@ -302,7 +292,7 @@ func (s *DefaultBackendStorage) addBackend(identifier string, idType header.Iden
s.backends[identifier] = append(s.backends[identifier], backend)
return
}
s.backends[identifier] = []Backend{backend}
s.backends[identifier] = []*Backend{backend}
metrics.Metrics.SetBackendCount(len(s.backends))
s.agentIDs = append(s.agentIDs, identifier)
if idType == header.DefaultRoute {
Expand All @@ -311,7 +301,7 @@ func (s *DefaultBackendStorage) addBackend(identifier string, idType header.Iden
}

// removeBackend removes a backend.
func (s *DefaultBackendStorage) removeBackend(identifier string, idType header.IdentifierType, backend Backend) {
func (s *DefaultBackendStorage) removeBackend(identifier string, idType header.IdentifierType, backend *Backend) {
if !containIDType(s.idTypes, idType) {
klog.ErrorS(&ErrWrongIDType{idType, s.idTypes}, "fail to remove backend")
return
Expand Down Expand Up @@ -390,7 +380,7 @@ func ignoreNotFound(err error) error {
}

// GetRandomBackend returns a random backend connection from all connected agents.
func (s *DefaultBackendStorage) GetRandomBackend() (Backend, error) {
func (s *DefaultBackendStorage) GetRandomBackend() (*Backend, error) {
s.mu.Lock()
defer s.mu.Unlock()
if len(s.backends) == 0 {
Expand Down
24 changes: 12 additions & 12 deletions pkg/server/backend_manager_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ func TestDefaultBackendManager_AddRemoveBackends(t *testing.T) {

p.AddBackend(backend1)
p.RemoveBackend(backend1)
expectedBackends := make(map[string][]Backend)
expectedBackends := make(map[string][]*Backend)
expectedAgentIDs := []string{}
expectedDefaultRouteAgentIDs := []string(nil)
if e, a := expectedBackends, p.backends; !reflect.DeepEqual(e, a) {
Expand All @@ -143,7 +143,7 @@ func TestDefaultBackendManager_AddRemoveBackends(t *testing.T) {
p.RemoveBackend(backend22)
p.RemoveBackend(backend2)
p.RemoveBackend(backend1)
expectedBackends = map[string][]Backend{
expectedBackends = map[string][]*Backend{
"agent1": {backend12},
"agent3": {backend3},
}
Expand Down Expand Up @@ -174,7 +174,7 @@ func TestDefaultRouteBackendManager_AddRemoveBackends(t *testing.T) {

p.AddBackend(backend1)
p.RemoveBackend(backend1)
expectedBackends := make(map[string][]Backend)
expectedBackends := make(map[string][]*Backend)
expectedAgentIDs := []string{}
expectedDefaultRouteAgentIDs := []string{}
if e, a := expectedBackends, p.backends; !reflect.DeepEqual(e, a) {
Expand All @@ -199,7 +199,7 @@ func TestDefaultRouteBackendManager_AddRemoveBackends(t *testing.T) {
p.RemoveBackend(backend2)
p.RemoveBackend(backend1)

expectedBackends = map[string][]Backend{
expectedBackends = map[string][]*Backend{
"agent1": {backend12},
"agent3": {backend3},
}
Expand Down Expand Up @@ -231,7 +231,7 @@ func TestDestHostBackendManager_AddRemoveBackends(t *testing.T) {

p.AddBackend(backend1)
p.RemoveBackend(backend1)
expectedBackends := make(map[string][]Backend)
expectedBackends := make(map[string][]*Backend)
expectedAgentIDs := []string{}
expectedDefaultRouteAgentIDs := []string(nil)
if e, a := expectedBackends, p.backends; !reflect.DeepEqual(e, a) {
Expand All @@ -247,7 +247,7 @@ func TestDestHostBackendManager_AddRemoveBackends(t *testing.T) {
p = NewDestHostBackendManager()
p.AddBackend(backend1)

expectedBackends = map[string][]Backend{
expectedBackends = map[string][]*Backend{
"localhost": {backend1},
"1.2.3.4": {backend1},
"9878::7675:1292:9183:7562": {backend1},
Expand All @@ -273,7 +273,7 @@ func TestDestHostBackendManager_AddRemoveBackends(t *testing.T) {
p.AddBackend(backend2)
p.AddBackend(backend3)

expectedBackends = map[string][]Backend{
expectedBackends = map[string][]*Backend{
"localhost": {backend1},
"node1.mydomain.com": {backend1},
"node2.mydomain.com": {backend3},
Expand Down Expand Up @@ -306,7 +306,7 @@ func TestDestHostBackendManager_AddRemoveBackends(t *testing.T) {
p.RemoveBackend(backend2)
p.RemoveBackend(backend1)

expectedBackends = map[string][]Backend{
expectedBackends = map[string][]*Backend{
"node2.mydomain.com": {backend3},
"5.6.7.8": {backend3},
"::": {backend3},
Expand All @@ -328,7 +328,7 @@ func TestDestHostBackendManager_AddRemoveBackends(t *testing.T) {
}

p.RemoveBackend(backend3)
expectedBackends = map[string][]Backend{}
expectedBackends = map[string][]*Backend{}
expectedAgentIDs = []string{}

if e, a := expectedBackends, p.backends; !reflect.DeepEqual(e, a) {
Expand Down Expand Up @@ -356,7 +356,7 @@ func TestDestHostBackendManager_WithDuplicateIdents(t *testing.T) {
p.AddBackend(backend2)
p.AddBackend(backend3)

expectedBackends := map[string][]Backend{
expectedBackends := map[string][]*Backend{
"localhost": {backend1, backend2, backend3},
"1.2.3.4": {backend1, backend2},
"5.6.7.8": {backend3},
Expand Down Expand Up @@ -389,7 +389,7 @@ func TestDestHostBackendManager_WithDuplicateIdents(t *testing.T) {
p.RemoveBackend(backend1)
p.RemoveBackend(backend3)

expectedBackends = map[string][]Backend{
expectedBackends = map[string][]*Backend{
"localhost": {backend2},
"1.2.3.4": {backend2},
"9878::7675:1292:9183:7562": {backend2},
Expand All @@ -413,7 +413,7 @@ func TestDestHostBackendManager_WithDuplicateIdents(t *testing.T) {
}

p.RemoveBackend(backend2)
expectedBackends = map[string][]Backend{}
expectedBackends = map[string][]*Backend{}
expectedAgentIDs = []string{}

if e, a := expectedBackends, p.backends; !reflect.DeepEqual(e, a) {
Expand Down
6 changes: 3 additions & 3 deletions pkg/server/default_route_backend_manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,11 +36,11 @@ func NewDefaultRouteBackendManager() *DefaultRouteBackendManager {
}

// Backend tries to get a backend that advertises default route, with random selection.
func (dibm *DefaultRouteBackendManager) Backend(_ context.Context) (Backend, error) {
func (dibm *DefaultRouteBackendManager) Backend(_ context.Context) (*Backend, error) {
return dibm.GetRandomBackend()
}

func (dibm *DefaultRouteBackendManager) AddBackend(backend Backend) {
func (dibm *DefaultRouteBackendManager) AddBackend(backend *Backend) {
agentID := backend.GetAgentID()
agentIdentifiers := backend.GetAgentIdentifiers()
if agentIdentifiers.DefaultRoute {
Expand All @@ -49,7 +49,7 @@ func (dibm *DefaultRouteBackendManager) AddBackend(backend Backend) {
}
}

func (dibm *DefaultRouteBackendManager) RemoveBackend(backend Backend) {
func (dibm *DefaultRouteBackendManager) RemoveBackend(backend *Backend) {
agentID := backend.GetAgentID()
agentIdentifiers := backend.GetAgentIdentifiers()
if agentIdentifiers.DefaultRoute {
Expand Down
6 changes: 3 additions & 3 deletions pkg/server/desthost_backend_manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ func NewDestHostBackendManager() *DestHostBackendManager {
[]header.IdentifierType{header.IPv4, header.IPv6, header.Host})}
}

func (dibm *DestHostBackendManager) AddBackend(backend Backend) {
func (dibm *DestHostBackendManager) AddBackend(backend *Backend) {
agentIdentifiers := backend.GetAgentIdentifiers()
for _, ipv4 := range agentIdentifiers.IPv4 {
klog.V(5).InfoS("Add the agent to DestHostBackendManager", "agent address", ipv4)
Expand All @@ -51,7 +51,7 @@ func (dibm *DestHostBackendManager) AddBackend(backend Backend) {
}
}

func (dibm *DestHostBackendManager) RemoveBackend(backend Backend) {
func (dibm *DestHostBackendManager) RemoveBackend(backend *Backend) {
agentIdentifiers := backend.GetAgentIdentifiers()
for _, ipv4 := range agentIdentifiers.IPv4 {
klog.V(5).InfoS("Remove the agent from the DestHostBackendManager", "agentHost", ipv4)
Expand All @@ -68,7 +68,7 @@ func (dibm *DestHostBackendManager) RemoveBackend(backend Backend) {
}

// Backend tries to get a backend associating to the request destination host.
func (dibm *DestHostBackendManager) Backend(ctx context.Context) (Backend, error) {
func (dibm *DestHostBackendManager) Backend(ctx context.Context) (*Backend, error) {
dibm.mu.RLock()
defer dibm.mu.RUnlock()
if len(dibm.backends) == 0 {
Expand Down
Loading

0 comments on commit 6ebfd2b

Please sign in to comment.