diff --git a/sluice/request.go b/sluice/request.go index c4ffd24..849799a 100644 --- a/sluice/request.go +++ b/sluice/request.go @@ -11,7 +11,7 @@ import ( ) const ( - defaultSluiceTTL = 5 * time.Minute + defaultSluiceTTL = 30 * time.Second ) var ( @@ -44,12 +44,11 @@ func AwaitRequest(connInfo *network.Connection, callbackFn RequestCallbackFunc) return fmt.Errorf("sluice for network %s %w", network, ErrSluiceOffline) } - sluice.AwaitRequest(&Request{ + return sluice.AwaitRequest(&Request{ ConnInfo: connInfo, CallbackFn: callbackFn, Expires: time.Now().Add(defaultSluiceTTL), }) - return nil } func getNetworkFromConnInfo(connInfo *network.Connection) string { diff --git a/sluice/sluice.go b/sluice/sluice.go index 4f6d47e..f6c6f00 100644 --- a/sluice/sluice.go +++ b/sluice/sluice.go @@ -54,7 +54,7 @@ func StartSluice(network, address string) { } // AwaitRequest pre-registers a connection. -func (s *Sluice) AwaitRequest(r *Request) { +func (s *Sluice) AwaitRequest(r *Request) error { // Set default expiry. if r.Expires.IsZero() { r.Expires = time.Now().Add(defaultSluiceTTL) @@ -63,8 +63,16 @@ func (s *Sluice) AwaitRequest(r *Request) { s.lock.Lock() defer s.lock.Unlock() + // Check if a pending request already exists for this local address. key := net.JoinHostPort(r.ConnInfo.LocalIP.String(), strconv.Itoa(int(r.ConnInfo.LocalPort))) + _, exists := s.pendingRequests[key] + if exists { + return fmt.Errorf("a pending request for %s already exists", key) + } + + // Add to pending requests. s.pendingRequests[key] = r + return nil } func (s *Sluice) getRequest(address string) (r *Request, ok bool) {