diff --git a/client/descriptions.go b/client/descriptions.go index 0b44f36b83..4f388fa7d3 100644 --- a/client/descriptions.go +++ b/client/descriptions.go @@ -52,13 +52,15 @@ func (col CollectionDescription) GetFieldByID(id FieldID) (FieldDescription, boo return FieldDescription{}, false } -// GetRelation returns the field that supports the relation of the given name. -func (col CollectionDescription) GetRelation(name string) (FieldDescription, bool) { - if !col.Schema.IsEmpty() { - for _, field := range col.Schema.Fields { - if field.RelationName == name { - return field, true - } +// GetFieldByRelation returns the field that supports the relation of the given name. +func (col CollectionDescription) GetFieldByRelation( + relationName string, + otherCollectionName string, + otherFieldName string, +) (FieldDescription, bool) { + for _, field := range col.Schema.Fields { + if field.RelationName == relationName && !(col.Name == otherCollectionName && otherFieldName == field.Name) { + return field, true } } return FieldDescription{}, false diff --git a/db/collection.go b/db/collection.go index cda5cbf584..f5dacccfb1 100644 --- a/db/collection.go +++ b/db/collection.go @@ -1062,7 +1062,7 @@ func (c *collection) save( if isSecondaryRelationID { primaryId := val.Value().(string) - err = c.patchPrimaryDoc(ctx, txn, relationFieldDescription, primaryKey.DocKey, primaryId) + err = c.patchPrimaryDoc(ctx, txn, c.Name(), relationFieldDescription, primaryKey.DocKey, primaryId) if err != nil { return cid.Undef, err } diff --git a/db/collection_update.go b/db/collection_update.go index 1a15482935..2e353dd0d3 100644 --- a/db/collection_update.go +++ b/db/collection_update.go @@ -350,6 +350,7 @@ func (c *collection) isSecondaryIDField(fieldDesc client.FieldDescription) (clie func (c *collection) patchPrimaryDoc( ctx context.Context, txn datastore.Txn, + secondaryCollectionName string, relationFieldDescription client.FieldDescription, docKey string, fieldValue string, @@ -365,7 +366,11 @@ func (c *collection) patchPrimaryDoc( } primaryCol = primaryCol.WithTxn(txn) - primaryField, ok := primaryCol.Description().GetRelation(relationFieldDescription.RelationName) + primaryField, ok := primaryCol.Description().GetFieldByRelation( + relationFieldDescription.RelationName, + secondaryCollectionName, + relationFieldDescription.Name, + ) if !ok { return client.NewErrFieldNotExist(relationFieldDescription.RelationName) } diff --git a/planner/type_join.go b/planner/type_join.go index f37437089e..ee771b01fc 100644 --- a/planner/type_join.go +++ b/planner/type_join.go @@ -259,7 +259,11 @@ func (p *Planner) makeTypeJoinOne( return nil, err } - subTypeField, subTypeFieldNameFound := subTypeCollectionDesc.GetRelation(subTypeFieldDesc.RelationName) + subTypeField, subTypeFieldNameFound := subTypeCollectionDesc.GetFieldByRelation( + subTypeFieldDesc.RelationName, + parent.sourceInfo.collectionDescription.Name, + subTypeFieldDesc.Name, + ) if !subTypeFieldNameFound { return nil, client.NewErrFieldNotExist(subTypeFieldDesc.RelationName) } @@ -481,7 +485,12 @@ func (p *Planner) makeTypeJoinMany( return nil, err } - rootField, rootNameFound := subTypeCollectionDesc.GetRelation(subTypeFieldDesc.RelationName) + rootField, rootNameFound := subTypeCollectionDesc.GetFieldByRelation( + subTypeFieldDesc.RelationName, + parent.sourceInfo.collectionDescription.Name, + subTypeFieldDesc.Name, + ) + if !rootNameFound { return nil, client.NewErrFieldNotExist(subTypeFieldDesc.RelationName) } diff --git a/tests/integration/mutation/update/field_kinds/one_to_one/with_self_ref_test.go b/tests/integration/mutation/update/field_kinds/one_to_one/with_self_ref_test.go new file mode 100644 index 0000000000..16225f4ab3 --- /dev/null +++ b/tests/integration/mutation/update/field_kinds/one_to_one/with_self_ref_test.go @@ -0,0 +1,191 @@ +// Copyright 2023 Democratized Data Foundation +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +package one_to_one + +import ( + "fmt" + "testing" + + testUtils "github.com/sourcenetwork/defradb/tests/integration" +) + +func TestMutationUpdateOneToOne_SelfReferencingFromPrimary(t *testing.T) { + user1ID := "bae-decf6467-4c7c-50d7-b09d-0a7097ef6bad" + + test := testUtils.TestCase{ + Description: "One to one update mutation, self referencing from primary", + Actions: []any{ + testUtils.SchemaUpdate{ + Schema: ` + type User { + name: String + boss: User @primary + underling: User + } + `, + }, + testUtils.CreateDoc{ + Doc: `{ + "name": "John" + }`, + }, + testUtils.CreateDoc{ + Doc: `{ + "name": "Fred" + }`, + }, + testUtils.UpdateDoc{ + DocID: 1, + Doc: fmt.Sprintf( + `{ + "boss_id": "%s" + }`, + user1ID, + ), + }, + testUtils.Request{ + Request: ` + query { + User { + name + boss { + name + } + } + }`, + Results: []map[string]any{ + { + "name": "Fred", + "boss": map[string]any{ + "name": "John", + }, + }, + { + "name": "John", + "boss": nil, + }, + }, + }, + testUtils.Request{ + Request: ` + query { + User { + name + underling { + name + } + } + }`, + Results: []map[string]any{ + { + "name": "Fred", + "underling": nil, + }, + { + "name": "John", + "underling": map[string]any{ + "name": "Fred", + }, + }, + }, + }, + }, + } + + testUtils.ExecuteTestCase(t, test) +} + +func TestMutationUpdateOneToOne_SelfReferencingFromSecondary(t *testing.T) { + user1ID := "bae-decf6467-4c7c-50d7-b09d-0a7097ef6bad" + + test := testUtils.TestCase{ + Description: "One to one update mutation, self referencing from secondary", + + Actions: []any{ + testUtils.SchemaUpdate{ + Schema: ` + type User { + name: String + boss: User + underling: User @primary + } + `, + }, + testUtils.CreateDoc{ + Doc: `{ + "name": "John" + }`, + }, + testUtils.CreateDoc{ + Doc: `{ + "name": "Fred" + }`, + }, + testUtils.UpdateDoc{ + DocID: 1, + Doc: fmt.Sprintf( + `{ + "boss_id": "%s" + }`, + user1ID, + ), + }, + testUtils.Request{ + Request: ` + query { + User { + name + boss { + name + } + } + }`, + Results: []map[string]any{ + { + "name": "Fred", + "boss": map[string]any{ + "name": "John", + }, + }, + { + "name": "John", + "boss": nil, + }, + }, + }, + testUtils.Request{ + Request: ` + query { + User { + name + underling { + name + } + } + }`, + Results: []map[string]any{ + { + "name": "Fred", + "underling": nil, + }, + { + "name": "John", + "underling": map[string]any{ + "name": "Fred", + }, + }, + }, + }, + }, + } + + testUtils.ExecuteTestCase(t, test) +}