diff --git a/internal/planner/multi.go b/internal/planner/multi.go index ac564c4ed1..4b82826118 100644 --- a/internal/planner/multi.go +++ b/internal/planner/multi.go @@ -136,8 +136,15 @@ func (p *parallelNode) nextMerge(_ int, plan planNode) (bool, error) { return false, err } - doc := plan.Value() - copy(p.currentValue.Fields, doc.Fields) + // Field-by-fields check is necessary because parallelNode can have multiple children, and + // each child can return the same doc, but with different related fields available + // depending on what is requested. + newFields := plan.Value().Fields + for i := range newFields { + if p.currentValue.Fields[i] == nil { + p.currentValue.Fields[i] = newFields[i] + } + } return true, nil } diff --git a/internal/planner/scan.go b/internal/planner/scan.go index 151705a698..019cd1dee2 100644 --- a/internal/planner/scan.go +++ b/internal/planner/scan.go @@ -92,10 +92,10 @@ func (n *scanNode) initFields(fields []mapper.Requestable) error { switch requestable := r.(type) { // field is simple as its just a base level field case *mapper.Field: - n.tryAddField(requestable.GetName()) + n.tryAddFieldWithName(requestable.GetName()) // select might have its own select fields and filters fields case *mapper.Select: - n.tryAddField(requestable.Field.Name + request.RelatedObjectID) // foreign key for type joins + n.tryAddFieldWithName(requestable.Field.Name + request.RelatedObjectID) // foreign key for type joins err := n.initFields(requestable.Fields) if err != nil { return err @@ -112,13 +112,13 @@ func (n *scanNode) initFields(fields []mapper.Requestable) error { return err } for _, fd := range fieldDescs { - n.tryAddField(fd.Name) + n.tryAddFieldWithName(fd.Name) } } if target.ChildTarget.HasValue { - n.tryAddField(target.ChildTarget.Name) + n.tryAddFieldWithName(target.ChildTarget.Name) } else { - n.tryAddField(target.Field.Name) + n.tryAddFieldWithName(target.Field.Name) } } } @@ -126,7 +126,7 @@ func (n *scanNode) initFields(fields []mapper.Requestable) error { return nil } -func (n *scanNode) tryAddField(fieldName string) bool { +func (n *scanNode) tryAddFieldWithName(fieldName string) bool { fd, ok := n.col.Definition().GetFieldByName(fieldName) if !ok { // skip fields that are not part of the @@ -134,10 +134,25 @@ func (n *scanNode) tryAddField(fieldName string) bool { // is only responsible for basic fields return false } - n.fields = append(n.fields, fd) + n.addField(fd) return true } +// addField adds a field to the list of fields to be fetched. +// It will not add the field if it is already in the list. +func (n *scanNode) addField(field client.FieldDefinition) { + found := false + for i := range n.fields { + if n.fields[i].Name == field.Name { + found = true + break + } + } + if !found { + n.fields = append(n.fields, field) + } +} + func (scan *scanNode) initFetcher( cid immutable.Option[string], index immutable.Option[client.IndexDescription], diff --git a/internal/planner/type_join.go b/internal/planner/type_join.go index 2102c74479..fc5eb9bbaf 100644 --- a/internal/planner/type_join.go +++ b/internal/planner/type_join.go @@ -531,43 +531,35 @@ func newPrimaryObjectsRetriever( return j } -func (j *primaryObjectsRetriever) retrievePrimaryDocsReferencingSecondaryDoc() error { - relIDFieldDef, ok := j.primarySide.col.Definition().GetFieldByName( - j.primarySide.relFieldDef.Value().Name + request.RelatedObjectID) +func (r *primaryObjectsRetriever) retrievePrimaryDocsReferencingSecondaryDoc() error { + relIDFieldDef, ok := r.primarySide.col.Definition().GetFieldByName( + r.primarySide.relFieldDef.Value().Name + request.RelatedObjectID) if !ok { - return client.NewErrFieldNotExist(j.primarySide.relFieldDef.Value().Name + request.RelatedObjectID) + return client.NewErrFieldNotExist(r.primarySide.relFieldDef.Value().Name + request.RelatedObjectID) } - j.primaryScan = getScanNode(j.primarySide.plan) + r.primaryScan = getScanNode(r.primarySide.plan) - j.relIDFieldDef = relIDFieldDef + r.relIDFieldDef = relIDFieldDef - primaryDocs, err := j.retrievePrimaryDocs() + primaryDocs, err := r.retrievePrimaryDocs() if err != nil { return err } - j.resultPrimaryDocs, j.resultSecondaryDoc = joinPrimaryDocs(primaryDocs, j.secondarySide, j.primarySide) + r.resultPrimaryDocs, r.resultSecondaryDoc = joinPrimaryDocs(primaryDocs, r.secondarySide, r.primarySide) return nil } -func (j *primaryObjectsRetriever) addIDFieldToScanner() { - found := false - for i := range j.primaryScan.fields { - if j.primaryScan.fields[i].Name == j.relIDFieldDef.Name { - found = true - break - } - } - if !found { - j.primaryScan.fields = append(j.primaryScan.fields, j.relIDFieldDef) +func (r *primaryObjectsRetriever) collectDocs(numDocs int) ([]core.Doc, error) { + p := r.primarySide.plan + // If the primary side is a multiScanNode, we need to get the source node, as we are the only + // consumer (one, not multiple) of it. + if multiScan, ok := p.(*multiScanNode); ok { + p = multiScan.Source() } -} - -func (j *primaryObjectsRetriever) collectDocs(numDocs int) ([]core.Doc, error) { - p := j.primarySide.plan if err := p.Init(); err != nil { return nil, NewErrSubTypeInit(err) } @@ -591,28 +583,28 @@ func (j *primaryObjectsRetriever) collectDocs(numDocs int) ([]core.Doc, error) { return docs, nil } -func (j *primaryObjectsRetriever) retrievePrimaryDocs() ([]core.Doc, error) { - j.addIDFieldToScanner() +func (r *primaryObjectsRetriever) retrievePrimaryDocs() ([]core.Doc, error) { + r.primaryScan.addField(r.relIDFieldDef) - secondaryDoc := j.secondarySide.plan.Value() - addFilterOnIDField(j.primaryScan, j.primarySide.relIDFieldMapIndex.Value(), secondaryDoc.GetID()) + secondaryDoc := r.secondarySide.plan.Value() + addFilterOnIDField(r.primaryScan, r.primarySide.relIDFieldMapIndex.Value(), secondaryDoc.GetID()) - oldFetcher := j.primaryScan.fetcher + oldFetcher := r.primaryScan.fetcher - indexOnRelation := findIndexByFieldName(j.primaryScan.col, j.relIDFieldDef.Name) - j.primaryScan.initFetcher(immutable.None[string](), indexOnRelation) + indexOnRelation := findIndexByFieldName(r.primaryScan.col, r.relIDFieldDef.Name) + r.primaryScan.initFetcher(immutable.None[string](), indexOnRelation) - docs, err := j.collectDocs(0) + docs, err := r.collectDocs(0) if err != nil { return nil, err } - err = j.primaryScan.fetcher.Close() + err = r.primaryScan.fetcher.Close() if err != nil { return nil, err } - j.primaryScan.fetcher = oldFetcher + r.primaryScan.fetcher = oldFetcher return docs, nil } @@ -780,7 +772,7 @@ func (join *invertibleTypeJoin) invertJoinDirectionWithIndex( ) error { p := join.childSide.plan s := getScanNode(p) - s.tryAddField(join.childSide.relFieldDef.Value().Name + request.RelatedObjectID) + s.tryAddFieldWithName(join.childSide.relFieldDef.Value().Name + request.RelatedObjectID) s.filter = fieldFilter s.initFetcher(immutable.Option[string]{}, immutable.Some(index)) diff --git a/tests/integration/index/query_with_index_only_filter_test.go b/tests/integration/index/query_with_index_only_filter_test.go index 5de362ec81..f0aab40546 100644 --- a/tests/integration/index/query_with_index_only_filter_test.go +++ b/tests/integration/index/query_with_index_only_filter_test.go @@ -718,3 +718,89 @@ func TestQueryWithIndex_EmptyFilterOnIndexedField_ShouldSucceed(t *testing.T) { testUtils.ExecuteTestCase(t, test) } + +// This test checks if a query with a filter on 2 relations (one of which is indexed) works. +// Because of 2 relations in the query a parallelNode will be used with each child focusing +// on fetching one of the relations. This test makes sure the result of the second child +// (say Device with manufacturer) doesn't overwrite the result of the first child (say Device with owner). +// Also as the fetching is inverted (because of the index) we fetch first the secondary doc which +// is User and fetch all primary docs (Device) that reference that User. For fetching the primary +// docs we use the same planNode which in this case happens to be multiscanNode (source of parallelNode). +// For every second call multiscanNode will return the result of the first call, but in this case +// we have only one consumer, so take the source of the multiscanNode and use it to fetch the primary docs +// to avoid having all docs doubled. +func TestQueryWithIndex_WithFilterOn2Relations_ShouldFilter(t *testing.T) { + test := testUtils.TestCase{ + Actions: []any{ + testUtils.SchemaUpdate{ + Schema: ` + type User { + name: String @index + devices: [Device] + } + + type Manufacturer { + name: String + devices: [Device] + } + + type Device { + owner: User + manufacturer: Manufacturer + model: String + } + `, + }, + testUtils.CreateDoc{ + CollectionID: 0, + DocMap: map[string]any{ + "name": "John", + }, + }, + testUtils.CreateDoc{ + CollectionID: 1, + DocMap: map[string]any{ + "name": "Apple", + }, + }, + testUtils.CreateDoc{ + CollectionID: 2, + DocMap: map[string]any{ + "model": "iPhone", + "owner_id": testUtils.NewDocIndex(0, 0), + "manufacturer_id": testUtils.NewDocIndex(1, 0), + }, + }, + testUtils.CreateDoc{ + CollectionID: 2, + DocMap: map[string]any{ + "model": "MacBook", + "owner_id": testUtils.NewDocIndex(0, 0), + "manufacturer_id": testUtils.NewDocIndex(1, 0), + }, + }, + testUtils.Request{ + Request: `query { + Device (filter: { + manufacturer: {name: {_eq: "Apple"}}, + owner: {name: {_eq: "John"}} + }) { + model + } + }`, + Results: map[string]any{ + "Device": []map[string]any{ + { + "model": "iPhone", + }, + { + "model": "MacBook", + }, + }, + }, + }, + }, + } + + testUtils.ExecuteTestCase(t, test) +}