diff --git a/cli/request.go b/cli/request.go index 796a5091c0..b6ec8e05ce 100644 --- a/cli/request.go +++ b/cli/request.go @@ -77,12 +77,12 @@ To learn more about the DefraDB GraphQL Query Language, refer to https://docs.so for _, err := range result.GQL.Errors { errors = append(errors, err.Error()) } - if result.Pub == nil { + if result.Subscription == nil { cmd.Print(REQ_RESULTS_HEADER) return writeJSON(cmd, map[string]any{"data": result.GQL.Data, "errors": errors}) } cmd.Print(SUB_RESULTS_HEADER) - for item := range result.Pub.Stream() { + for item := range result.Subscription { writeJSON(cmd, item) //nolint:errcheck } return nil diff --git a/client/db.go b/client/db.go index 6c530dd419..6ab945a815 100644 --- a/client/db.go +++ b/client/db.go @@ -265,9 +265,9 @@ type RequestResult struct { // GQL contains the immediate results of the GQL request. GQL GQLResult - // Pub contains a pointer to an event stream which channels any subscription results - // if the request was a GQL subscription. - Pub *events.Publisher[events.Update] + // Subscription is an optional channel which returns results + // from a subscription request. + Subscription <-chan GQLResult } // CollectionFetchOptions represents a set of options used for fetching collections. diff --git a/http/client.go b/http/client.go index 49982bad2a..9792208214 100644 --- a/http/client.go +++ b/http/client.go @@ -366,7 +366,7 @@ func (c *Client) ExecRequest( return result } if res.Header.Get("Content-Type") == "text/event-stream" { - result.Pub = c.execRequestSubscription(res.Body) + result.Subscription = c.execRequestSubscription(res.Body) return result } // ignore close errors because they have @@ -389,19 +389,17 @@ func (c *Client) ExecRequest( return result } -func (c *Client) execRequestSubscription(r io.ReadCloser) *events.Publisher[events.Update] { - pubCh := events.New[events.Update](0, 0) - pub, err := events.NewPublisher[events.Update](pubCh, 0) - if err != nil { - return nil - } - +func (c *Client) execRequestSubscription(r io.ReadCloser) chan client.GQLResult { + resCh := make(chan client.GQLResult) go func() { eventReader := sse.NewReadCloser(r) - // ignore close errors because the status - // and body of the request are already - // checked and it cannot be handled properly - defer eventReader.Close() //nolint:errcheck + defer func() { + // ignore close errors because the status + // and body of the request are already + // checked and it cannot be handled properly + eventReader.Close() //nolint:errcheck + close(resCh) + }() for { evt, err := eventReader.Next() @@ -412,14 +410,14 @@ func (c *Client) execRequestSubscription(r io.ReadCloser) *events.Publisher[even if err := json.Unmarshal(evt.Data, &response); err != nil { return } - pub.Publish(client.GQLResult{ + resCh <- client.GQLResult{ Errors: response.Errors, Data: response.Data, - }) + } } }() - return pub + return resCh } func (c *Client) PrintDump(ctx context.Context) error { diff --git a/http/handler_ccip.go b/http/handler_ccip.go index 01597377e2..5b9aeb5402 100644 --- a/http/handler_ccip.go +++ b/http/handler_ccip.go @@ -61,7 +61,7 @@ func (c *ccipHandler) ExecCCIP(rw http.ResponseWriter, req *http.Request) { } result := store.ExecRequest(req.Context(), request.Query) - if result.Pub != nil { + if result.Subscription != nil { responseJSON(rw, http.StatusBadRequest, errorResponse{ErrStreamingNotSupported}) return } diff --git a/http/handler_store.go b/http/handler_store.go index 521aa13775..de534a8c1d 100644 --- a/http/handler_store.go +++ b/http/handler_store.go @@ -314,7 +314,7 @@ func (s *storeHandler) ExecRequest(rw http.ResponseWriter, req *http.Request) { result := store.ExecRequest(req.Context(), request.Query) - if result.Pub == nil { + if result.Subscription == nil { responseJSON(rw, http.StatusOK, GraphQLResponse{result.GQL.Data, result.GQL.Errors}) return } @@ -335,7 +335,7 @@ func (s *storeHandler) ExecRequest(rw http.ResponseWriter, req *http.Request) { select { case <-req.Context().Done(): return - case item, open := <-result.Pub.Stream(): + case item, open := <-result.Subscription: if !open { return } diff --git a/internal/db/request.go b/internal/db/request.go index ff60c0835f..e5ba3d5cf5 100644 --- a/internal/db/request.go +++ b/internal/db/request.go @@ -35,27 +35,20 @@ func (db *db) execRequest(ctx context.Context, request string) *client.RequestRe return res } - pub, subRequest, err := db.checkForClientSubscriptions(parsedRequest) + pub, err := db.handleSubscription(ctx, parsedRequest) if err != nil { res.GQL.Errors = []error{err} return res } if pub != nil { - res.Pub = pub - go db.handleSubscription(ctx, pub, subRequest) + res.Subscription = pub return res } txn := mustGetContextTxn(ctx) identity := GetContextIdentity(ctx) - planner := planner.New( - ctx, - identity, - db.acp, - db, - txn, - ) + planner := planner.New(ctx, identity, db.acp, db, txn) results, err := planner.RunRequest(ctx, parsedRequest) if err != nil { diff --git a/internal/db/subscriptions.go b/internal/db/subscriptions.go index b52504467e..a1b0147df4 100644 --- a/internal/db/subscriptions.go +++ b/internal/db/subscriptions.go @@ -19,83 +19,78 @@ import ( "github.com/sourcenetwork/defradb/internal/planner" ) -func (db *db) checkForClientSubscriptions(r *request.Request) ( - *events.Publisher[events.Update], - *request.ObjectSubscription, - error, -) { +// handleSubscription checks for a subscription within the given request and +// starts a new go routine that will return all subscription results on the returned +// channel. If a subscription does not exist on the given request nil will be returned. +func (db *db) handleSubscription(ctx context.Context, r *request.Request) (<-chan client.GQLResult, error) { if len(r.Subscription) == 0 || len(r.Subscription[0].Selections) == 0 { - // This is not a subscription request and we have nothing to do here - return nil, nil, nil + return nil, nil // This is not a subscription request and we have nothing to do here } - if !db.events.Updates.HasValue() { - return nil, nil, ErrSubscriptionsNotAllowed + return nil, ErrSubscriptionsNotAllowed } - - s := r.Subscription[0].Selections[0] - if subRequest, ok := s.(*request.ObjectSubscription); ok { - pub, err := events.NewPublisher(db.events.Updates.Value(), 5) - if err != nil { - return nil, nil, err - } - - return pub, subRequest, nil + selections := r.Subscription[0].Selections[0] + subRequest, ok := selections.(*request.ObjectSubscription) + if !ok { + return nil, client.NewErrUnexpectedType[request.ObjectSubscription]("SubscriptionSelection", selections) + } + // unsubscribing from this publisher will cause a race condition + // https://github.com/sourcenetwork/defradb/issues/2687 + pub, err := events.NewPublisher(db.events.Updates.Value(), 5) + if err != nil { + return nil, err } - return nil, nil, client.NewErrUnexpectedType[request.ObjectSubscription]("SubscriptionSelection", s) -} + resCh := make(chan client.GQLResult) + go func() { + defer close(resCh) -func (db *db) handleSubscription( - ctx context.Context, - pub *events.Publisher[events.Update], - r *request.ObjectSubscription, -) { - for evt := range pub.Event() { - txn, err := db.NewTxn(ctx, false) - if err != nil { - log.ErrorContext(ctx, err.Error()) - continue - } + // listen for events and send to the result channel + for { + var evt events.Update + select { + case <-ctx.Done(): + return // context cancelled + case val, ok := <-pub.Event(): + if !ok { + return // channel closed + } + evt = val + } - ctx := SetContextTxn(ctx, txn) - db.handleEvent(ctx, pub, evt, r) - txn.Discard(ctx) - } -} + txn, err := db.NewTxn(ctx, false) + if err != nil { + log.ErrorContext(ctx, err.Error()) + continue + } -func (db *db) handleEvent( - ctx context.Context, - pub *events.Publisher[events.Update], - evt events.Update, - r *request.ObjectSubscription, -) { - txn := mustGetContextTxn(ctx) - identity := GetContextIdentity(ctx) - p := planner.New( - ctx, - identity, - db.acp, - db, - txn, - ) + ctx := SetContextTxn(ctx, txn) + identity := GetContextIdentity(ctx) - s := r.ToSelect(evt.DocID, evt.Cid.String()) + p := planner.New(ctx, identity, db.acp, db, txn) + s := subRequest.ToSelect(evt.DocID, evt.Cid.String()) - result, err := p.RunSubscriptionRequest(ctx, s) - if err != nil { - pub.Publish(client.GQLResult{ - Errors: []error{err}, - }) - return - } + result, err := p.RunSubscriptionRequest(ctx, s) + if err == nil && len(result) == 0 { + txn.Discard(ctx) + continue // Don't send anything back to the client if the request yields an empty dataset. + } + res := client.GQLResult{ + Data: result, + } + if err != nil { + res.Errors = []error{err} + } - // Don't send anything back to the client if the request yields an empty dataset. - if len(result) == 0 { - return - } + select { + case <-ctx.Done(): + txn.Discard(ctx) + return // context cancelled + case resCh <- res: + txn.Discard(ctx) + } + } + }() - pub.Publish(client.GQLResult{ - Data: result, - }) + return resCh, nil } diff --git a/tests/clients/cli/wrapper.go b/tests/clients/cli/wrapper.go index 18e560b16c..25e4c177bf 100644 --- a/tests/clients/cli/wrapper.go +++ b/tests/clients/cli/wrapper.go @@ -411,7 +411,7 @@ func (w *Wrapper) ExecRequest( return result } if header == cli.SUB_RESULTS_HEADER { - result.Pub = w.execRequestSubscription(buffer) + result.Subscription = w.execRequestSubscription(buffer) return result } data, err := io.ReadAll(buffer) @@ -439,29 +439,24 @@ func (w *Wrapper) ExecRequest( return result } -func (w *Wrapper) execRequestSubscription(r io.Reader) *events.Publisher[events.Update] { - pubCh := events.New[events.Update](0, 0) - pub, err := events.NewPublisher[events.Update](pubCh, 0) - if err != nil { - return nil - } - +func (w *Wrapper) execRequestSubscription(r io.Reader) chan client.GQLResult { + resCh := make(chan client.GQLResult) go func() { dec := json.NewDecoder(r) + defer close(resCh) for { var response http.GraphQLResponse if err := dec.Decode(&response); err != nil { return } - pub.Publish(client.GQLResult{ + resCh <- client.GQLResult{ Errors: response.Errors, Data: response.Data, - }) + } } }() - - return pub + return resCh } func (w *Wrapper) NewTxn(ctx context.Context, readOnly bool) (datastore.Txn, error) { diff --git a/tests/integration/utils2.go b/tests/integration/utils2.go index 00c47fcfc2..041b553548 100644 --- a/tests/integration/utils2.go +++ b/tests/integration/utils2.go @@ -1718,13 +1718,11 @@ func executeSubscriptionRequest( allActionsAreDone := false expectedDataRecieved := len(action.Results) == 0 - stream := result.Pub.Stream() for { select { - case s := <-stream: - sResult, _ := s.(client.GQLResult) - sData, _ := sResult.Data.([]map[string]any) - errs = append(errs, sResult.Errors...) + case s := <-result.Subscription: + sData, _ := s.Data.([]map[string]any) + errs = append(errs, s.Errors...) data = append(data, sData...) if len(data) >= len(action.Results) {