From b730d3fea2b6dff19504ef4c7f7a53e846cde1e8 Mon Sep 17 00:00:00 2001
From: Islam Aliev
Date: Fri, 1 Mar 2024 22:20:34 +0100
Subject: [PATCH] feat: Reverted order for indexed fields (#2335)

## Relevant issue(s)

Resolves #2229

## Description

Enable reversed (descending) ordering for indexed fields, which is mostly relevant to composite indexes at this point.

This PR introduces a whole new package: `encoding`. A significant part of it is taken from CockroachDB, whose order-preserving encoding fits our needs well (a toy sketch of the idea follows the client/index.go diff below). Other encoding approaches (mostly for integers) were also considered: fixed-length encoding, Avro's zigzag encoding, base-128 varint encoding, and a few others. The encoding package can later be used for encoding other values and might eventually replace CBOR.

---
 client/document.go                            |   5 +-
 client/field.go                               |  16 +-
 client/index.go                               |  14 +-
 client/index_test.go                          |  18 +-
 core/encoding.go                              | 139 +++-
 core/errors.go                                |   7 +
 core/key.go                                   | 105 +--
 core/key_test.go                              | 387 +++++-----
 db/collection.go                              |  12 +-
 db/collection_delete.go                       |   7 +-
 db/collection_index.go                        |   9 +-
 db/errors.go                                  |  36 +-
 db/fetcher/encoded_doc.go                     |   2 +-
 db/fetcher/errors.go                          |  10 +
 db/fetcher/indexer.go                         |  27 +-
 db/fetcher/indexer_iterators.go               | 670 ++++++++++--------
 db/index.go                                   | 144 ++--
 db/index_test.go                              |  51 +-
 db/indexed_docs_test.go                       |  92 +--
 ...2229-order-direction-for-indexed-fields.md |   3 +
 encoding/bytes.go                             | 154 ++++
 encoding/bytes_test.go                        | 205 ++++++
 encoding/encoding.go                          |  55 ++
 encoding/errors.go                            |  91 +++
 encoding/field_value.go                       | 114 +++
 encoding/field_value_test.go                  | 142 ++++
 encoding/float.go                             |  97 +++
 encoding/float_test.go                        | 127 ++++
 encoding/int.go                               | 246 +++++++
 encoding/int_test.go                          | 223 ++++++
 encoding/null.go                              |  41 ++
 encoding/null_test.go                         |  40 ++
 encoding/string.go                            |  52 ++
 encoding/string_test.go                       | 122 ++++
 encoding/type.go                              |  46 ++
 encoding/type_test.go                         |  41 ++
 go.mod                                        |   2 +-
 lens/fetcher.go                               |   2 +-
 request/graphql/schema/collection.go          |  33 +-
 request/graphql/schema/generate.go            |  10 +-
 request/graphql/schema/index_parse_test.go    |  63 +-
 request/graphql/schema/types/types.go         |  10 +
 .../index/create_composite_test.go            |   8 +-
 tests/integration/index/create_get_test.go    |   6 +-
 .../index/create_unique_composite_test.go     |  10 +-
 tests/integration/index/create_unique_test.go |   9 +-
 ...y_with_composite_index_field_order_test.go | 435 ++++++++++++
 ...y_with_composite_index_only_filter_test.go |  10 +-
 .../query_with_index_only_field_order_test.go | 180 +++++
 .../query_with_index_only_filter_test.go      |  10 +-
 ...with_unique_composite_index_filter_test.go |  14 +-
 ...uery_with_unique_index_only_filter_test.go |  10 +-
 tests/integration/test_case.go                |  15 +-
 tests/integration/utils2.go                   |  12 +-
 tools/configs/golangci.yaml                   |   2 +-
 55 files changed, 3466 insertions(+), 925 deletions(-)
 create mode 100644 docs/data_format_changes/i2229-order-direction-for-indexed-fields.md
 create mode 100644 encoding/bytes.go
 create mode 100644 encoding/bytes_test.go
 create mode 100644 encoding/encoding.go
 create mode 100644 encoding/errors.go
 create mode 100644 encoding/field_value.go
 create mode 100644 encoding/field_value_test.go
 create mode 100644 encoding/float.go
 create mode 100644 encoding/float_test.go
 create mode 100644 encoding/int.go
 create mode 100644 encoding/int_test.go
 create mode 100644 encoding/null.go
 create mode 100644 encoding/null_test.go
 create mode 100644 encoding/string.go
 create mode 100644 encoding/string_test.go
 create mode 100644 encoding/type.go
 create mode 100644 encoding/type_test.go
 create mode 100644
tests/integration/index/query_with_composite_index_field_order_test.go create mode 100644 tests/integration/index/query_with_index_only_field_order_test.go diff --git a/client/document.go b/client/document.go index 866910e89c..d455949bd5 100644 --- a/client/document.go +++ b/client/document.go @@ -172,7 +172,8 @@ func NewDocsFromJSON(obj []byte, sd SchemaDescription) ([]*Document, error) { return docs, nil } -func isNillableKind(kind FieldKind) bool { +// IsNillableKind returns true if the given FieldKind is nillable. +func IsNillableKind(kind FieldKind) bool { switch kind { case FieldKind_NILLABLE_STRING, FieldKind_NILLABLE_BLOB, FieldKind_NILLABLE_JSON, FieldKind_NILLABLE_BOOL, FieldKind_NILLABLE_FLOAT, FieldKind_NILLABLE_DATETIME, @@ -188,7 +189,7 @@ func isNillableKind(kind FieldKind) bool { // It will do any minor parsing, like dates, and return // the typed value again as an interface. func validateFieldSchema(val any, field SchemaFieldDescription) (any, error) { - if isNillableKind(field.Kind) { + if IsNillableKind(field.Kind) { if val == nil { return nil, nil } diff --git a/client/field.go b/client/field.go index a2c80c7ff8..40f130e428 100644 --- a/client/field.go +++ b/client/field.go @@ -14,23 +14,18 @@ package client type Field interface { Name() string Type() CType //TODO Abstract into a Field Type interface - SchemaType() string } type simpleField struct { - name string - crdtType CType - schemaType string + name string + crdtType CType } -func (doc *Document) newField(t CType, name string, schemaType ...string) Field { +func (doc *Document) newField(t CType, name string) Field { f := simpleField{ name: name, crdtType: t, } - if len(schemaType) > 0 { - f.schemaType = schemaType[0] - } return f } @@ -43,8 +38,3 @@ func (field simpleField) Name() string { func (field simpleField) Type() CType { return field.crdtType } - -// SchemaType returns the schema type of the field. -func (field simpleField) SchemaType() string { - return field.schemaType -} diff --git a/client/index.go b/client/index.go index cfc0b2ef01..9175cf7c0d 100644 --- a/client/index.go +++ b/client/index.go @@ -10,22 +10,12 @@ package client -// IndexDirection is the direction of an index. -type IndexDirection string - -const ( - // Ascending is the value to use for an ascending fields - Ascending IndexDirection = "ASC" - // Descending is the value to use for an descending fields - Descending IndexDirection = "DESC" -) - // IndexFieldDescription describes how a field is being indexed. type IndexedFieldDescription struct { // Name contains the name of the field. Name string - // Direction contains the direction of the index. - Direction IndexDirection + // Descending indicates whether the field is indexed in descending order. + Descending bool } // IndexDescription describes an index. 
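As promised in the description above, here is a toy sketch of the order-preserving encoding idea this PR borrows from CockroachDB. It is illustrative only: the helper names below are invented for this sketch and are not the actual `encoding` package API. Ascending values are encoded so that `bytes.Compare` agrees with numeric order, and the descending encoding is the ones'-complement of the ascending one, which exactly reverses that byte-wise order.

```go
// Toy model of order-preserving key encoding (names are illustrative,
// not the real encoding package API).
package main

import (
	"bytes"
	"encoding/binary"
	"fmt"
)

// encodeUint64Ascending appends a big-endian encoding of v, so byte-wise
// comparison of encoded keys matches numeric comparison of the values.
func encodeUint64Ascending(b []byte, v uint64) []byte {
	var buf [8]byte
	binary.BigEndian.PutUint64(buf[:], v)
	return append(b, buf[:]...)
}

// encodeUint64Descending flips every bit of the ascending encoding,
// which exactly reverses the byte-wise ordering.
func encodeUint64Descending(b []byte, v uint64) []byte {
	return encodeUint64Ascending(b, ^v)
}

func main() {
	fmt.Println(bytes.Compare(encodeUint64Ascending(nil, 5), encodeUint64Ascending(nil, 7)))   // -1: 5 sorts before 7
	fmt.Println(bytes.Compare(encodeUint64Descending(nil, 5), encodeUint64Descending(nil, 7))) // 1: 5 sorts after 7
}
```

This is what allows the new `Descending` flag on `IndexedFieldDescription` (see the client/index.go hunk above) to be honored purely at key-encoding time, while the underlying datastore keeps iterating keys in plain byte order.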
diff --git a/client/index_test.go b/client/index_test.go index feb8ccdd69..d10450ece1 100644 --- a/client/index_test.go +++ b/client/index_test.go @@ -38,7 +38,7 @@ func TestCollectIndexesOnField(t *testing.T) { { Name: "index1", Fields: []IndexedFieldDescription{ - {Name: "test", Direction: Ascending}, + {Name: "test"}, }, }, }, @@ -48,7 +48,7 @@ func TestCollectIndexesOnField(t *testing.T) { { Name: "index1", Fields: []IndexedFieldDescription{ - {Name: "test", Direction: Ascending}, + {Name: "test"}, }, }, }, @@ -60,13 +60,13 @@ func TestCollectIndexesOnField(t *testing.T) { { Name: "index1", Fields: []IndexedFieldDescription{ - {Name: "test", Direction: Ascending}, + {Name: "test"}, }, }, { Name: "index2", Fields: []IndexedFieldDescription{ - {Name: "test", Direction: Descending}, + {Name: "test", Descending: true}, }, }, }, @@ -76,13 +76,13 @@ func TestCollectIndexesOnField(t *testing.T) { { Name: "index1", Fields: []IndexedFieldDescription{ - {Name: "test", Direction: Ascending}, + {Name: "test"}, }, }, { Name: "index2", Fields: []IndexedFieldDescription{ - {Name: "test", Direction: Descending}, + {Name: "test", Descending: true}, }, }, }, @@ -94,7 +94,7 @@ func TestCollectIndexesOnField(t *testing.T) { { Name: "index1", Fields: []IndexedFieldDescription{ - {Name: "other", Direction: Ascending}, + {Name: "other"}, }, }, }, @@ -109,8 +109,8 @@ func TestCollectIndexesOnField(t *testing.T) { { Name: "index1", Fields: []IndexedFieldDescription{ - {Name: "other", Direction: Ascending}, - {Name: "test", Direction: Ascending}, + {Name: "other"}, + {Name: "test"}, }, }, }, diff --git a/core/encoding.go b/core/encoding.go index d4ac66bc04..40e74915b8 100644 --- a/core/encoding.go +++ b/core/encoding.go @@ -17,11 +17,13 @@ import ( "github.com/sourcenetwork/immutable" "github.com/sourcenetwork/defradb/client" + "github.com/sourcenetwork/defradb/client/request" + "github.com/sourcenetwork/defradb/encoding" ) -// DecodeFieldValue takes a field value and description and converts it to the +// NormalizeFieldValue takes a field value and description and converts it to the // standardized Defra Go type. -func DecodeFieldValue(fieldDesc client.FieldDefinition, val any) (any, error) { +func NormalizeFieldValue(fieldDesc client.FieldDefinition, val any) (any, error) { if val == nil { return nil, nil } @@ -125,6 +127,16 @@ func DecodeFieldValue(fieldDesc client.FieldDefinition, val any) (any, error) { case string: return time.Parse(time.RFC3339, v) } + case client.FieldKind_NILLABLE_BOOL: + switch v := val.(type) { + case int64: + return v != 0, nil + } + case client.FieldKind_NILLABLE_STRING: + switch v := val.(type) { + case []byte: + return string(v), nil + } } } @@ -179,3 +191,126 @@ func convertToInt(propertyName string, untypedValue any) (int64, error) { return 0, client.NewErrUnexpectedType[string](propertyName, untypedValue) } } + +// DecodeIndexDataStoreKey decodes an IndexDataStoreKey from bytes. +// It expects the input bytes to be in the following format: +// +// /[CollectionID]/[IndexID]/[FieldValue](/[FieldValue]...) +// +// Where [CollectionID] and [IndexID] are integers. +// +// All field values are converted to the standardized Defra Go type +// according to the field descriptions.
+func DecodeIndexDataStoreKey( + data []byte, + indexDesc *client.IndexDescription, + fields []client.FieldDefinition, +) (IndexDataStoreKey, error) { + if len(data) == 0 { + return IndexDataStoreKey{}, ErrEmptyKey + } + + if data[0] != '/' { + return IndexDataStoreKey{}, ErrInvalidKey + } + data = data[1:] + + data, colID, err := encoding.DecodeUvarintAscending(data) + if err != nil { + return IndexDataStoreKey{}, err + } + + key := IndexDataStoreKey{CollectionID: uint32(colID)} + + if len(data) == 0 || data[0] != '/' { + return IndexDataStoreKey{}, ErrInvalidKey + } + data = data[1:] + + data, indID, err := encoding.DecodeUvarintAscending(data) + if err != nil { + return IndexDataStoreKey{}, err + } + key.IndexID = uint32(indID) + + if len(data) == 0 { + return key, nil + } + + for len(data) > 0 { + if data[0] != '/' { + return IndexDataStoreKey{}, ErrInvalidKey + } + data = data[1:] + + i := len(key.Fields) + descending := false + // If the key has more values encoded than there are fields in the index description, the last + // value must be the docID and we treat it as a string. + if i < len(indexDesc.Fields) { + descending = indexDesc.Fields[i].Descending + } else if i > len(indexDesc.Fields) { + return IndexDataStoreKey{}, ErrInvalidKey + } + + var val any + data, val, err = encoding.DecodeFieldValue(data, descending) + if err != nil { + return IndexDataStoreKey{}, err + } + + key.Fields = append(key.Fields, IndexedField{Value: val, Descending: descending}) + } + + err = normalizeIndexDataStoreKeyValues(&key, fields) + return key, err +} + +// normalizeIndexDataStoreKeyValues converts all field values to the standardized +// Defra Go type according to the field descriptions. +func normalizeIndexDataStoreKeyValues(key *IndexDataStoreKey, fields []client.FieldDefinition) error { + for i := range key.Fields { + if key.Fields[i].Value == nil { + continue + } + var err error + var val any + if i == len(key.Fields)-1 && len(key.Fields)-len(fields) == 1 { + bytes, ok := key.Fields[i].Value.([]byte) + if !ok { + return client.NewErrUnexpectedType[[]byte](request.DocIDArgName, key.Fields[i].Value) + } + val = string(bytes) + } else { + val, err = NormalizeFieldValue(fields[i], key.Fields[i].Value) + } + if err != nil { + return err + } + key.Fields[i].Value = val + } + return nil +} + +// EncodeIndexDataStoreKey encodes an IndexDataStoreKey to bytes to be stored as a key +// for secondary indexes. +func EncodeIndexDataStoreKey(key *IndexDataStoreKey) []byte { + if key.CollectionID == 0 { + return []byte{} + } + + b := encoding.EncodeUvarintAscending([]byte{'/'}, uint64(key.CollectionID)) + + if key.IndexID == 0 { + return b + } + b = append(b, '/') + b = encoding.EncodeUvarintAscending(b, uint64(key.IndexID)) + + for _, field := range key.Fields { + b = append(b, '/') + b = encoding.EncodeFieldValue(b, field.Value, field.Descending) + } + + return b +} diff --git a/core/errors.go b/core/errors.go index b672c1ed00..440e5778ac 100644 --- a/core/errors.go +++ b/core/errors.go @@ -16,15 +16,22 @@ import ( const ( errFailedToGetFieldIdOfKey string = "failed to get FieldID of Key" + errInvalidFieldIndex string = "invalid field index" ) var ( ErrFailedToGetFieldIdOfKey = errors.New(errFailedToGetFieldIdOfKey) ErrEmptyKey = errors.New("received empty key string") ErrInvalidKey = errors.New("invalid key string") + ErrInvalidFieldIndex = errors.New(errInvalidFieldIndex) ) // NewErrFailedToGetFieldIdOfKey returns the error indicating failure to get FieldID of Key.
func NewErrFailedToGetFieldIdOfKey(inner error) error { return errors.Wrap(errFailedToGetFieldIdOfKey, inner) } + +// NewErrInvalidFieldIndex returns the error indicating invalid field index. +func NewErrInvalidFieldIndex(i int) error { + return errors.New(errInvalidFieldIndex, errors.NewKV("index", i)) +} diff --git a/core/key.go b/core/key.go index b9bea29e41..4017d445b0 100644 --- a/core/key.go +++ b/core/key.go @@ -75,14 +75,23 @@ type DataStoreKey struct { var _ Key = (*DataStoreKey)(nil) +// IndexedField contains information necessary for storing a single +// value of a field in an index. +type IndexedField struct { + // Value is the value of the field in the index + Value any + // Descending is true if the field is sorted in descending order + Descending bool +} + // IndexDataStoreKey is key of an indexed document in the database. type IndexDataStoreKey struct { // CollectionID is the id of the collection CollectionID uint32 // IndexID is the id of the index IndexID uint32 - // FieldValues is the values of the fields in the index - FieldValues [][]byte + // Fields is the values of the fields in the index + Fields []IndexedField } var _ Key = (*IndexDataStoreKey)(nil) @@ -496,52 +505,19 @@ func (k DataStoreKey) ToPrimaryDataStoreKey() PrimaryDataStoreKey { } } -// NewIndexDataStoreKey creates a new IndexDataStoreKey from a string. -// It expects the input string is in the following format: -// -// /[CollectionID]/[IndexID]/[FieldValue](/[FieldValue]...) -// -// Where [CollectionID] and [IndexID] are integers -func NewIndexDataStoreKey(key string) (IndexDataStoreKey, error) { - if key == "" { - return IndexDataStoreKey{}, ErrEmptyKey - } - - if !strings.HasPrefix(key, "/") { - return IndexDataStoreKey{}, ErrInvalidKey - } - - elements := strings.Split(key[1:], "/") - - // With less than 3 elements, we know it's an invalid key - if len(elements) < 3 { - return IndexDataStoreKey{}, ErrInvalidKey - } - - colID, err := strconv.Atoi(elements[0]) - if err != nil { - return IndexDataStoreKey{}, ErrInvalidKey - } - - indexKey := IndexDataStoreKey{CollectionID: uint32(colID)} - - indID, err := strconv.Atoi(elements[1]) - if err != nil { - return IndexDataStoreKey{}, ErrInvalidKey - } - indexKey.IndexID = uint32(indID) - - // first 2 elements are the collection and index IDs, the rest are field values - for i := 2; i < len(elements); i++ { - indexKey.FieldValues = append(indexKey.FieldValues, []byte(elements[i])) +// NewIndexDataStoreKey creates a new IndexDataStoreKey from a collection ID, index ID and fields. +// It also validates values of the fields. 
+func NewIndexDataStoreKey(collectionID, indexID uint32, fields []IndexedField) IndexDataStoreKey { + return IndexDataStoreKey{ + CollectionID: collectionID, + IndexID: indexID, + Fields: fields, } - - return indexKey, nil } // Bytes returns the byte representation of the key func (k *IndexDataStoreKey) Bytes() []byte { - return []byte(k.ToString()) + return EncodeIndexDataStoreKey(k) } // ToDS returns the datastore key @@ -555,48 +531,7 @@ func (k *IndexDataStoreKey) ToDS() ds.Key { // If while composing the string from left to right, a component // is empty, the string is returned up to that point func (k *IndexDataStoreKey) ToString() string { - sb := strings.Builder{} - - if k.CollectionID == 0 { - return "" - } - sb.WriteByte('/') - sb.WriteString(strconv.Itoa(int(k.CollectionID))) - - if k.IndexID == 0 { - return sb.String() - } - sb.WriteByte('/') - sb.WriteString(strconv.Itoa(int(k.IndexID))) - - for _, v := range k.FieldValues { - if len(v) == 0 { - break - } - sb.WriteByte('/') - sb.WriteString(string(v)) - } - - return sb.String() -} - -// Equal returns true if the two keys are equal -func (k IndexDataStoreKey) Equal(other IndexDataStoreKey) bool { - if k.CollectionID != other.CollectionID { - return false - } - if k.IndexID != other.IndexID { - return false - } - if len(k.FieldValues) != len(other.FieldValues) { - return false - } - for i := range k.FieldValues { - if string(k.FieldValues[i]) != string(other.FieldValues[i]) { - return false - } - } - return true + return string(k.Bytes()) } func (k PrimaryDataStoreKey) ToDataStoreKey() DataStoreKey { diff --git a/core/key_test.go b/core/key_test.go index 50bf1198c4..3fa7f41a63 100644 --- a/core/key_test.go +++ b/core/key_test.go @@ -17,6 +17,9 @@ import ( ds "github.com/ipfs/go-datastore" "github.com/sourcenetwork/immutable" "github.com/stretchr/testify/assert" + + "github.com/sourcenetwork/defradb/client" + "github.com/sourcenetwork/defradb/encoding" ) func TestNewDataStoreKey_ReturnsEmptyStruct_GivenEmptyString(t *testing.T) { @@ -159,267 +162,243 @@ func TestNewIndexKeyFromString_IfFullKeyString_ReturnKey(t *testing.T) { assert.Equal(t, "idx", key.IndexName) } -func toFieldValues(values ...string) [][]byte { - var result [][]byte = make([][]byte, 0, len(values)) - for _, value := range values { - result = append(result, []byte(value)) +func encodePrefix(colID, indexID uint32) []byte { + return encoding.EncodeUvarintAscending(append(encoding.EncodeUvarintAscending( + []byte{'/'}, uint64(colID)), '/'), uint64(indexID)) +} + +func encodeKey(colID, indexID uint32, fieldParts ...any) []byte { + b := encodePrefix(colID, indexID) + const partSize = 2 + if len(fieldParts)%partSize != 0 { + panic(fmt.Sprintf("fieldParts must be a multiple of %d: value, descending", partSize)) } - return result + for i := 0; i < len(fieldParts)/partSize; i++ { + b = append(b, '/') + isDescending := fieldParts[i*partSize+1].(bool) + if fieldParts[i*partSize] == nil { + if isDescending { + b = encoding.EncodeNullDescending(b) + } else { + b = encoding.EncodeNullAscending(b) + } + } else { + if isDescending { + b = encoding.EncodeUvarintDescending(b, uint64(fieldParts[i*partSize].(int))) + } else { + b = encoding.EncodeUvarintAscending(b, uint64(fieldParts[i*partSize].(int))) + } + } + } + return b } -func TestIndexDatastoreKey_ToString(t *testing.T) { +func TestIndexDatastoreKey_Bytes(t *testing.T) { cases := []struct { - Key IndexDataStoreKey - Expected string + Name string + CollectionID uint32 + IndexID uint32 + Fields []IndexedField + Expected 
[]byte }{ { - Key: IndexDataStoreKey{}, - Expected: "", - }, - { - Key: IndexDataStoreKey{ - CollectionID: 1, - }, - Expected: "/1", - }, - { - Key: IndexDataStoreKey{ - CollectionID: 1, - IndexID: 2, - }, - Expected: "/1/2", - }, - { - Key: IndexDataStoreKey{ - CollectionID: 1, - IndexID: 2, - FieldValues: toFieldValues("3"), - }, - Expected: "/1/2/3", - }, - { - Key: IndexDataStoreKey{ - CollectionID: 1, - IndexID: 2, - FieldValues: toFieldValues("3", "4"), - }, - Expected: "/1/2/3/4", + Name: "empty", + Expected: []byte{}, }, { - Key: IndexDataStoreKey{ - CollectionID: 1, - FieldValues: toFieldValues("3"), - }, - Expected: "/1", + Name: "only collection", + CollectionID: 1, + Expected: encoding.EncodeUvarintAscending([]byte{'/'}, 1), }, { - Key: IndexDataStoreKey{ - IndexID: 2, - FieldValues: toFieldValues("3"), - }, - Expected: "", + Name: "only collection and index", + CollectionID: 1, + IndexID: 2, + Expected: encodePrefix(1, 2), }, { - Key: IndexDataStoreKey{ - FieldValues: toFieldValues("3"), - }, - Expected: "", + Name: "collection, index and one field", + CollectionID: 1, + IndexID: 2, + Fields: []IndexedField{{Value: 5}}, + Expected: encodeKey(1, 2, 5, false), }, { - Key: IndexDataStoreKey{ - CollectionID: 1, - IndexID: 2, - FieldValues: toFieldValues("", ""), - }, - Expected: "/1/2", + Name: "collection, index and two fields", + CollectionID: 1, + IndexID: 2, + Fields: []IndexedField{{Value: 5}, {Value: 7}}, + Expected: encodeKey(1, 2, 5, false, 7, false), }, { - Key: IndexDataStoreKey{ - CollectionID: 1, - IndexID: 2, - FieldValues: toFieldValues("", "3"), - }, - Expected: "/1/2", + Name: "no index", + CollectionID: 1, + Fields: []IndexedField{{Value: 5}}, + Expected: encoding.EncodeUvarintAscending([]byte{'/'}, 1), }, { - Key: IndexDataStoreKey{ - CollectionID: 1, - IndexID: 2, - FieldValues: toFieldValues("3", "", "4"), - }, - Expected: "/1/2/3", + Name: "no collection", + IndexID: 2, + Fields: []IndexedField{{Value: 5}}, + Expected: []byte{}, }, } - for i, c := range cases { - assert.Equal(t, c.Key.ToString(), c.Expected, "case %d", i) + for _, c := range cases { + t.Run(c.Name, func(t *testing.T) { + key := NewIndexDataStoreKey(c.CollectionID, c.IndexID, c.Fields) + actual := key.Bytes() + assert.Equal(t, c.Expected, actual, "upon calling key.Bytes()") + encKey := EncodeIndexDataStoreKey(&key) + assert.Equal(t, c.Expected, encKey, "upon calling EncodeIndexDataStoreKey") + }) } } -func TestIndexDatastoreKey_Bytes(t *testing.T) { - key := IndexDataStoreKey{ - CollectionID: 1, - IndexID: 2, - FieldValues: toFieldValues("3", "4"), - } - assert.Equal(t, key.Bytes(), []byte("/1/2/3/4")) +func TestIndexDatastoreKey_ToString(t *testing.T) { + key := NewIndexDataStoreKey(1, 2, []IndexedField{{Value: 5}}) + assert.Equal(t, key.ToString(), string(encodeKey(1, 2, 5, false))) } func TestIndexDatastoreKey_ToDS(t *testing.T) { - key := IndexDataStoreKey{ - CollectionID: 1, - IndexID: 2, - FieldValues: toFieldValues("3", "4"), + key := NewIndexDataStoreKey(1, 2, []IndexedField{{Value: 5}}) + assert.Equal(t, key.ToDS(), ds.NewKey(string(encodeKey(1, 2, 5, false)))) +} + +func TestCollectionIndexKey_Bytes(t *testing.T) { + key := CollectionIndexKey{ + CollectionID: immutable.Some[uint32](1), + IndexName: "idx", } - assert.Equal(t, key.ToDS(), ds.NewKey("/1/2/3/4")) + assert.Equal(t, []byte(COLLECTION_INDEX+"/1/idx"), key.Bytes()) } -func TestIndexDatastoreKey_EqualTrue(t *testing.T) { - cases := [][]IndexDataStoreKey{ +func TestDecodeIndexDataStoreKey(t *testing.T) { + const colID, indexID 
= 1, 2 + cases := []struct { + name string + desc client.IndexDescription + inputBytes []byte + expectedFields []IndexedField + fieldKinds []client.FieldKind + }{ { - { - CollectionID: 1, - IndexID: 2, - FieldValues: toFieldValues("3", "4"), - }, - { - CollectionID: 1, - IndexID: 2, - FieldValues: toFieldValues("3", "4"), + name: "one field", + desc: client.IndexDescription{ + ID: indexID, + Fields: []client.IndexedFieldDescription{{}}, }, + inputBytes: encodeKey(colID, indexID, 5, false), + expectedFields: []IndexedField{{Value: int64(5)}}, }, { - { - CollectionID: 1, - FieldValues: toFieldValues("3", "4"), - }, - { - CollectionID: 1, - FieldValues: toFieldValues("3", "4"), + name: "two fields (one descending)", + desc: client.IndexDescription{ + ID: indexID, + Fields: []client.IndexedFieldDescription{{}, {Descending: true}}, }, + inputBytes: encodeKey(colID, indexID, 5, false, 7, true), + expectedFields: []IndexedField{{Value: int64(5)}, {Value: int64(7), Descending: true}}, }, { - { - CollectionID: 1, - }, - { - CollectionID: 1, + name: "last encoded value without matching field description is docID", + desc: client.IndexDescription{ + ID: indexID, + Fields: []client.IndexedFieldDescription{{}}, }, + inputBytes: encoding.EncodeStringAscending(append(encodeKey(1, indexID, 5, false), '/'), "docID"), + expectedFields: []IndexedField{{Value: int64(5)}, {Value: "docID"}}, + fieldKinds: []client.FieldKind{client.FieldKind_NILLABLE_INT}, }, } - for i, c := range cases { - assert.True(t, c[0].Equal(c[1]), "case %d", i) + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + expectedKey := NewIndexDataStoreKey(colID, indexID, tc.expectedFields) + fieldDescs := make([]client.FieldDefinition, len(tc.desc.Fields)) + for i := range tc.fieldKinds { + fieldDescs[i] = client.FieldDefinition{Kind: tc.fieldKinds[i]} + } + key, err := DecodeIndexDataStoreKey(tc.inputBytes, &tc.desc, fieldDescs) + assert.NoError(t, err) + assert.Equal(t, expectedKey, key) + }) } } -func TestCollectionIndexKey_Bytes(t *testing.T) { - key := CollectionIndexKey{ - CollectionID: immutable.Some[uint32](1), - IndexName: "idx", +func TestDecodeIndexDataStoreKey_InvalidKey(t *testing.T) { + replace := func(b []byte, i int, v byte) []byte { + b = append([]byte{}, b...) 
+ b[i] = v + return b } - assert.Equal(t, []byte(COLLECTION_INDEX+"/1/idx"), key.Bytes()) -} + cutEnd := func(b []byte, l int) []byte { + return b[:len(b)-l] + } + + const colID, indexID = 1, 2 -func TestIndexDatastoreKey_EqualFalse(t *testing.T) { - cases := [][]IndexDataStoreKey{ + cases := []struct { + name string + val []byte + numFields int + }{ { - { - CollectionID: 1, - }, - { - CollectionID: 2, - }, + name: "empty", + val: []byte{}, }, { - { - CollectionID: 1, - IndexID: 2, - }, - { - CollectionID: 1, - IndexID: 3, - }, + name: "only slash", + val: []byte{'/'}, }, { - { - CollectionID: 1, - }, - { - IndexID: 1, - }, + name: "slash after collection", + val: append(encoding.EncodeUvarintAscending([]byte{'/'}, colID), '/'), }, { - { - CollectionID: 1, - IndexID: 2, - FieldValues: toFieldValues("4", "3"), - }, - { - CollectionID: 1, - IndexID: 2, - FieldValues: toFieldValues("3", "4"), - }, + name: "wrong prefix", + val: replace(encodeKey(colID, indexID, 5, false), 0, ' '), + numFields: 1, }, { - { - CollectionID: 1, - IndexID: 2, - FieldValues: toFieldValues("3"), - }, - { - CollectionID: 1, - IndexID: 2, - FieldValues: toFieldValues("3", "4"), - }, + name: "no slash before collection", + val: encodeKey(colID, indexID, 5, false)[1:], + numFields: 1, }, { - { - CollectionID: 1, - FieldValues: toFieldValues("3", "", "4"), - }, - { - CollectionID: 1, - FieldValues: toFieldValues("3", "4"), - }, + name: "no slash before index", + val: replace(encodeKey(colID, indexID, 5, false), 2, ' '), + numFields: 1, + }, + { + name: "no slash before field value", + val: replace(encodeKey(colID, indexID, 5, false), 4, ' '), + numFields: 1, + }, + { + name: "no field value", + val: cutEnd(encodeKey(colID, indexID, 5, false), 1), + numFields: 1, + }, + { + name: "no field description", + val: encodeKey(colID, indexID, 5, false, 7, false, 9, false), + numFields: 2, + }, + { + name: "invalid docID value", + val: encoding.EncodeUvarintAscending(append(encodeKey(colID, indexID, 5, false), '/'), 5), + numFields: 1, }, } - - for i, c := range cases { - assert.False(t, c[0].Equal(c[1]), "case %d", i) - } -} - -func TestNewIndexDataStoreKey_ValidKey(t *testing.T) { - str, err := NewIndexDataStoreKey("/1/2/3") - assert.NoError(t, err) - assert.Equal(t, str, IndexDataStoreKey{ - CollectionID: 1, - IndexID: 2, - FieldValues: toFieldValues("3"), - }) - - str, err = NewIndexDataStoreKey("/1/2/3/4") - assert.NoError(t, err) - assert.Equal(t, str, IndexDataStoreKey{ - CollectionID: 1, - IndexID: 2, - FieldValues: toFieldValues("3", "4"), - }) -} - -func TestNewIndexDataStoreKey_InvalidKey(t *testing.T) { - keys := []string{ - "", - "/", - "/1", - "/1/2", - " /1/2/3", - "1/2/3", - "/a/2/3", - "/1/b/3", - } - for i, key := range keys { - _, err := NewIndexDataStoreKey(key) - assert.Error(t, err, "case %d: %s", i, key) + indexDesc := client.IndexDescription{ID: indexID, Fields: []client.IndexedFieldDescription{{}}} + for _, c := range cases { + t.Run(c.name, func(t *testing.T) { + fieldDescs := make([]client.FieldDefinition, c.numFields) + for i := 0; i < c.numFields; i++ { + fieldDescs[i] = client.FieldDefinition{Kind: client.FieldKind_NILLABLE_INT} + } + _, err := DecodeIndexDataStoreKey(c.val, &indexDesc, fieldDescs) + assert.Error(t, err, c.name) + }) } } diff --git a/db/collection.go b/db/collection.go index ea231c0448..99b1bdc595 100644 --- a/db/collection.go +++ b/db/collection.go @@ -64,7 +64,7 @@ type collection struct { // to be auto generated based on a more controllable and user friendly // CollectionOptions 
object. -// NewCollection returns a pointer to a newly instanciated DB Collection +// newCollection returns a pointer to a newly instantiated DB Collection func (db *db) newCollection(desc client.CollectionDescription, schema client.SchemaDescription) *collection { return &collection{ db: db, @@ -88,7 +88,7 @@ func (c *collection) newFetcher() fetcher.Fetcher { } // createCollection creates a collection and saves it to the database in its system store. -// Note: Collection.ID is an autoincrementing value that is generated by the database. +// Note: Collection.ID is an auto-incrementing value that is generated by the database. func (db *db) createCollection( ctx context.Context, txn datastore.Txn, @@ -185,8 +185,6 @@ func (db *db) updateSchema( setAsActiveVersion bool, ) error { hasChanged, err := db.validateUpdateSchema( - ctx, - txn, existingSchemaByName, proposedDescriptionsByName, schema, @@ -375,8 +373,6 @@ func (db *db) updateSchema( // Will return true if the given description differs from the current persisted state of the // schema. Will return an error if it fails validation. func (db *db) validateUpdateSchema( - ctx context.Context, - txn datastore.Txn, existingDescriptionsByName map[string]client.SchemaDescription, proposedDescriptionsByName map[string]client.SchemaDescription, proposedDesc client.SchemaDescription, @@ -405,7 +401,7 @@ func (db *db) validateUpdateSchema( } if proposedDesc.VersionID != "" && proposedDesc.VersionID != existingDesc.VersionID { - // If users specify this it will be overwritten, an error is prefered to quietly ignoring it. + // If users specify this it will be overwritten, an error is preferred to quietly ignoring it. return false, ErrCannotSetVersionID } @@ -592,7 +588,7 @@ func (db *db) setActiveSchemaVersion( activeCol, rootCol, isActiveFound = db.getActiveCollectionDown(ctx, txn, colsByID, sources[0].SourceCollectionID) } if !isActiveFound { - // We need to look both down and up for the active version - the most recent is not nessecarily the active one. + // We need to look both down and up for the active version - the most recent is not necessarily the active one. activeCol, isActiveFound = db.getActiveCollectionUp(ctx, txn, colsBySourceID, rootCol.ID) } diff --git a/db/collection_delete.go b/db/collection_delete.go index 6c360d09c0..785b2830d7 100644 --- a/db/collection_delete.go +++ b/db/collection_delete.go @@ -57,7 +57,7 @@ func (c *collection) DeleteWithDocID( defer c.discardImplicitTxn(ctx, txn) dsKey := c.getPrimaryKeyFromDocID(docID) - res, err := c.deleteWithKey(ctx, txn, dsKey, client.Deleted) + res, err := c.deleteWithKey(ctx, txn, dsKey) if err != nil { return nil, err } @@ -109,7 +109,6 @@ func (c *collection) deleteWithKey( ctx context.Context, txn datastore.Txn, key core.PrimaryDataStoreKey, - status client.DocumentStatus, ) (*client.DeleteResult, error) { // Check the key we have been given to delete with actually has a corresponding // document (i.e. document actually exists in the collection). 
@@ -131,7 +130,7 @@ func (c *collection) deleteWithIDs( ctx context.Context, txn datastore.Txn, docIDs []client.DocID, - status client.DocumentStatus, + _ client.DocumentStatus, ) (*client.DeleteResult, error) { results := &client.DeleteResult{ DocIDs: make([]string, 0), @@ -160,7 +159,7 @@ func (c *collection) deleteWithFilter( ctx context.Context, txn datastore.Txn, filter any, - status client.DocumentStatus, + _ client.DocumentStatus, ) (*client.DeleteResult, error) { // Make a selection plan that will scan through only the documents with matching filter. selectionPlan, err := c.makeSelectionPlan(ctx, txn, filter) diff --git a/db/collection_index.go b/db/collection_index.go index 09e0b1c3ec..b95154097c 100644 --- a/db/collection_index.go +++ b/db/collection_index.go @@ -193,7 +193,7 @@ func (c *collection) createIndex( return nil, err } - err = c.checkExistingFields(ctx, desc.Fields) + err = c.checkExistingFields(desc.Fields) if err != nil { return nil, err } @@ -408,7 +408,6 @@ func (c *collection) GetIndexes(ctx context.Context) ([]client.IndexDescription, } func (c *collection) checkExistingFields( - ctx context.Context, fields []client.IndexedFieldDescription, ) error { collectionFields := c.Schema().Fields @@ -467,16 +466,10 @@ func validateIndexDescription(desc client.IndexDescription) error { if len(desc.Fields) == 0 { return ErrIndexMissingFields } - if len(desc.Fields) == 1 && desc.Fields[0].Direction == client.Descending { - return ErrIndexSingleFieldWrongDirection - } for i := range desc.Fields { if desc.Fields[i].Name == "" { return ErrIndexFieldMissingName } - if desc.Fields[i].Direction == "" { - desc.Fields[i].Direction = client.Ascending - } } return nil } diff --git a/db/errors.go b/db/errors.go index b907331950..34dd0d53b5 100644 --- a/db/errors.go +++ b/db/errors.go @@ -51,7 +51,6 @@ const ( errNonZeroIndexIDProvided string = "non-zero index ID provided" errIndexFieldMissingName string = "index field missing name" errIndexFieldMissingDirection string = "index field missing direction" - errIndexSingleFieldWrongDirection string = "wrong direction for index with a single field" errIndexWithNameAlreadyExists string = "index with name already exists" errInvalidStoredIndex string = "invalid stored index" errInvalidStoredIndexKey string = "invalid stored index key" @@ -89,24 +88,23 @@ const ( ) var ( - ErrFailedToGetCollection = errors.New(errFailedToGetCollection) - ErrSubscriptionsNotAllowed = errors.New("server does not accept subscriptions") - ErrInvalidFilter = errors.New("invalid filter") - ErrCollectionAlreadyExists = errors.New("collection already exists") - ErrCollectionNameEmpty = errors.New("collection name can't be empty") - ErrSchemaNameEmpty = errors.New("schema name can't be empty") - ErrSchemaRootEmpty = errors.New("schema root can't be empty") - ErrSchemaVersionIDEmpty = errors.New("schema version ID can't be empty") - ErrKeyEmpty = errors.New("key cannot be empty") - ErrCannotSetVersionID = errors.New(errCannotSetVersionID) - ErrIndexMissingFields = errors.New(errIndexMissingFields) - ErrIndexFieldMissingName = errors.New(errIndexFieldMissingName) - ErrIndexSingleFieldWrongDirection = errors.New(errIndexSingleFieldWrongDirection) - ErrCorruptedIndex = errors.New(errCorruptedIndex) - ErrExpectedJSONObject = errors.New(errExpectedJSONObject) - ErrExpectedJSONArray = errors.New(errExpectedJSONArray) - ErrInvalidViewQuery = errors.New(errInvalidViewQuery) - ErrCanNotIndexNonUniqueFields = errors.New(errCanNotIndexNonUniqueFields) + ErrFailedToGetCollection = 
errors.New(errFailedToGetCollection) + ErrSubscriptionsNotAllowed = errors.New("server does not accept subscriptions") + ErrInvalidFilter = errors.New("invalid filter") + ErrCollectionAlreadyExists = errors.New("collection already exists") + ErrCollectionNameEmpty = errors.New("collection name can't be empty") + ErrSchemaNameEmpty = errors.New("schema name can't be empty") + ErrSchemaRootEmpty = errors.New("schema root can't be empty") + ErrSchemaVersionIDEmpty = errors.New("schema version ID can't be empty") + ErrKeyEmpty = errors.New("key cannot be empty") + ErrCannotSetVersionID = errors.New(errCannotSetVersionID) + ErrIndexMissingFields = errors.New(errIndexMissingFields) + ErrIndexFieldMissingName = errors.New(errIndexFieldMissingName) + ErrCorruptedIndex = errors.New(errCorruptedIndex) + ErrExpectedJSONObject = errors.New(errExpectedJSONObject) + ErrExpectedJSONArray = errors.New(errExpectedJSONArray) + ErrInvalidViewQuery = errors.New(errInvalidViewQuery) + ErrCanNotIndexNonUniqueFields = errors.New(errCanNotIndexNonUniqueFields) ) // NewErrFailedToGetHeads returns a new error indicating that the heads of a document diff --git a/db/fetcher/encoded_doc.go b/db/fetcher/encoded_doc.go index 031ebe091f..889aea848a 100644 --- a/db/fetcher/encoded_doc.go +++ b/db/fetcher/encoded_doc.go @@ -60,7 +60,7 @@ func (e encProperty) Decode() (any, error) { return nil, err } - return core.DecodeFieldValue(e.Desc, val) + return core.NormalizeFieldValue(e.Desc, val) } // @todo: Implement Encoded Document type diff --git a/db/fetcher/errors.go b/db/fetcher/errors.go index 2ff87225c7..2a2967bbdb 100644 --- a/db/fetcher/errors.go +++ b/db/fetcher/errors.go @@ -11,6 +11,8 @@ package fetcher import ( + "fmt" + "github.com/sourcenetwork/defradb/errors" ) @@ -28,6 +30,7 @@ const ( errMissingMapper string = "missing document mapper" errInvalidInOperatorValue string = "invalid _in/_nin value" errInvalidFilterOperator string = "invalid filter operator is provided" + errUnexpectedTypeValue string = "unexpected type value" ) var ( @@ -45,6 +48,7 @@ var ( ErrSingleSpanOnly = errors.New("spans must contain only a single entry") ErrInvalidInOperatorValue = errors.New(errInvalidInOperatorValue) ErrInvalidFilterOperator = errors.New(errInvalidFilterOperator) + ErrUnexpectedTypeValue = errors.New(errUnexpectedTypeValue) ) // NewErrFieldIdNotFound returns an error indicating that the given FieldId was not found. @@ -102,3 +106,9 @@ func NewErrFailedToGetDagNode(inner error) error { func NewErrInvalidFilterOperator(operator string) error { return errors.New(errInvalidFilterOperator, errors.NewKV("Operator", operator)) } + +// NewErrUnexpectedTypeValue returns an error indicating that the given value is of an unexpected type. 
+func NewErrUnexpectedTypeValue[T any](value any) error { + var t T + return errors.New(errUnexpectedTypeValue, errors.NewKV("Value", value), errors.NewKV("Type", fmt.Sprintf("%T", t))) +} diff --git a/db/fetcher/indexer.go b/db/fetcher/indexer.go index 3dc39d2f9e..158c7cb88d 100644 --- a/db/fetcher/indexer.go +++ b/db/fetcher/indexer.go @@ -123,26 +123,35 @@ func (f *IndexFetcher) FetchNext(ctx context.Context) (EncodedDocument, ExecInfo return nil, f.execInfo, nil } - // This CBOR-specific value will be gone soon once we implement - // our own encryption package hasNilField := false - const cborNil = 0xf6 for i, indexedField := range f.indexedFields { - property := &encProperty{ - Desc: indexedField, - Raw: res.key.FieldValues[i], - } - if len(res.key.FieldValues[i]) == 1 && res.key.FieldValues[i][0] == cborNil { + property := &encProperty{Desc: indexedField} + + field := res.key.Fields[i] + if field.Value == nil { hasNilField = true } + // We need to convert it to CBOR bytes, as this is what the value will be decoded from on retrieval. + // In the future we will have to either get rid of CBOR or properly handle different + // encodings for properties in a single document. + fieldBytes, err := client.NewFieldValue(client.NONE_CRDT, field.Value).Bytes() + if err != nil { + return nil, ExecInfo{}, err + } + property.Raw = fieldBytes + f.doc.properties[indexedField] = property } if f.indexDesc.Unique && !hasNilField { f.doc.id = res.value } else { - f.doc.id = res.key.FieldValues[len(res.key.FieldValues)-1] + docID, ok := res.key.Fields[len(res.key.Fields)-1].Value.(string) + if !ok { + return nil, ExecInfo{}, NewErrUnexpectedTypeValue[string](res.key.Fields[len(res.key.Fields)-1].Value) + } + f.doc.id = []byte(docID) } if f.docFetcher != nil && len(f.docFields) > 0 { diff --git a/db/fetcher/indexer_iterators.go b/db/fetcher/indexer_iterators.go index 66ad4c967c..b79e3bc9d7 100644 --- a/db/fetcher/indexer_iterators.go +++ b/db/fetcher/indexer_iterators.go @@ -11,12 +11,11 @@ package fetcher import ( - "bytes" + "cmp" "context" "errors" "strings" - "github.com/fxamacker/cbor/v2" ds "github.com/ipfs/go-datastore" "github.com/sourcenetwork/defradb/client" @@ -63,55 +62,57 @@ type indexIterResult struct { } type queryResultIterator struct { - resultIter query.Results + resultIter query.Results + indexDesc client.IndexDescription + indexedFields []client.FieldDefinition } -func (i *queryResultIterator) Next() (indexIterResult, error) { - res, hasVal := i.resultIter.NextSync() +func (iter *queryResultIterator) Next() (indexIterResult, error) { + res, hasVal := iter.resultIter.NextSync() if res.Error != nil { return indexIterResult{}, res.Error } if !hasVal { return indexIterResult{}, nil } - key, err := core.NewIndexDataStoreKey(res.Key) + key, err := core.DecodeIndexDataStoreKey([]byte(res.Key), &iter.indexDesc, iter.indexedFields) if err != nil { return indexIterResult{}, err } + return indexIterResult{key: key, value: res.Value, foundKey: true}, nil } -func (i *queryResultIterator) Close() error { - return i.resultIter.Close() +func (iter *queryResultIterator) Close() error { + return iter.resultIter.Close() } type eqPrefixIndexIterator struct { + queryResultIterator indexKey core.IndexDataStoreKey execInfo *ExecInfo matchers []valueMatcher - - queryResultIterator } -func (i *eqPrefixIndexIterator) Init(ctx context.Context, store datastore.DSReaderWriter) error { +func (iter *eqPrefixIndexIterator) Init(ctx context.Context, store datastore.DSReaderWriter) error { resultIter, err := store.Query(ctx, query.Query{ - Prefix: i.indexKey.ToString(), + Prefix:
iter.indexKey.ToString(), }) if err != nil { return err } - i.resultIter = resultIter + iter.resultIter = resultIter return nil } -func (i *eqPrefixIndexIterator) Next() (indexIterResult, error) { +func (iter *eqPrefixIndexIterator) Next() (indexIterResult, error) { for { - res, err := i.queryResultIterator.Next() + res, err := iter.queryResultIterator.Next() if err != nil || !res.foundKey { return res, err } - i.execInfo.IndexesFetched++ - doesMatch, err := executeValueMatchers(i.matchers, res.key.FieldValues) + iter.execInfo.IndexesFetched++ + doesMatch, err := executeValueMatchers(iter.matchers, res.key.Fields) if err != nil { return indexIterResult{}, err } @@ -123,39 +124,33 @@ func (i *eqPrefixIndexIterator) Next() (indexIterResult, error) { } type eqSingleIndexIterator struct { - indexKey core.IndexDataStoreKey - keyFieldValues [][]byte - execInfo *ExecInfo + indexKey core.IndexDataStoreKey + execInfo *ExecInfo ctx context.Context store datastore.DSReaderWriter } -func (i *eqSingleIndexIterator) SetKeyFieldValue(value []byte) { - i.keyFieldValues = [][]byte{value} -} - -func (i *eqSingleIndexIterator) Init(ctx context.Context, store datastore.DSReaderWriter) error { - i.ctx = ctx - i.store = store +func (iter *eqSingleIndexIterator) Init(ctx context.Context, store datastore.DSReaderWriter) error { + iter.ctx = ctx + iter.store = store return nil } -func (i *eqSingleIndexIterator) Next() (indexIterResult, error) { - if i.store == nil { +func (iter *eqSingleIndexIterator) Next() (indexIterResult, error) { + if iter.store == nil { return indexIterResult{}, nil } - i.indexKey.FieldValues = i.keyFieldValues - val, err := i.store.Get(i.ctx, i.indexKey.ToDS()) + val, err := iter.store.Get(iter.ctx, iter.indexKey.ToDS()) if err != nil { if errors.Is(err, ds.ErrNotFound) { - return indexIterResult{key: i.indexKey}, nil + return indexIterResult{key: iter.indexKey}, nil } return indexIterResult{}, err } - i.store = nil - i.execInfo.IndexesFetched++ - return indexIterResult{key: i.indexKey, value: val, foundKey: true}, nil + iter.store = nil + iter.execInfo.IndexesFetched++ + return indexIterResult{key: iter.indexKey, value: val, foundKey: true}, nil } func (i *eqSingleIndexIterator) Close() error { @@ -164,55 +159,55 @@ func (i *eqSingleIndexIterator) Close() error { type inIndexIterator struct { indexIterator - inValues [][]byte + inValues []any nextValIndex int ctx context.Context store datastore.DSReaderWriter hasIterator bool } -func (i *inIndexIterator) nextIterator() (bool, error) { - if i.nextValIndex > 0 { - err := i.indexIterator.Close() +func (iter *inIndexIterator) nextIterator() (bool, error) { + if iter.nextValIndex > 0 { + err := iter.indexIterator.Close() if err != nil { return false, err } } - if i.nextValIndex >= len(i.inValues) { + if iter.nextValIndex >= len(iter.inValues) { return false, nil } - switch fieldIter := i.indexIterator.(type) { + switch fieldIter := iter.indexIterator.(type) { case *eqPrefixIndexIterator: - fieldIter.indexKey.FieldValues[0] = i.inValues[i.nextValIndex] + fieldIter.indexKey.Fields[0].Value = iter.inValues[iter.nextValIndex] case *eqSingleIndexIterator: - fieldIter.keyFieldValues[0] = i.inValues[i.nextValIndex] + fieldIter.indexKey.Fields[0].Value = iter.inValues[iter.nextValIndex] } - err := i.indexIterator.Init(i.ctx, i.store) + err := iter.indexIterator.Init(iter.ctx, iter.store) if err != nil { return false, err } - i.nextValIndex++ + iter.nextValIndex++ return true, nil } -func (i *inIndexIterator) Init(ctx context.Context, store 
datastore.DSReaderWriter) error { - i.ctx = ctx - i.store = store +func (iter *inIndexIterator) Init(ctx context.Context, store datastore.DSReaderWriter) error { + iter.ctx = ctx + iter.store = store var err error - i.hasIterator, err = i.nextIterator() + iter.hasIterator, err = iter.nextIterator() return err } -func (i *inIndexIterator) Next() (indexIterResult, error) { - for i.hasIterator { - res, err := i.indexIterator.Next() +func (iter *inIndexIterator) Next() (indexIterResult, error) { + for iter.hasIterator { + res, err := iter.indexIterator.Next() if err != nil { return indexIterResult{}, err } if !res.foundKey { - i.hasIterator, err = i.nextIterator() + iter.hasIterator, err = iter.nextIterator() if err != nil { return indexIterResult{}, err } @@ -223,19 +218,13 @@ func (i *inIndexIterator) Next() (indexIterResult, error) { return indexIterResult{}, nil } -func (i *inIndexIterator) Close() error { +func (iter *inIndexIterator) Close() error { return nil } -type errorCheckingFilter struct { - matchers []valueMatcher - err error - execInfo *ExecInfo -} - -func executeValueMatchers(matchers []valueMatcher, values [][]byte) (bool, error) { +func executeValueMatchers(matchers []valueMatcher, fields []core.IndexedField) (bool, error) { for i := range matchers { - res, err := matchers[i].Match(values[i]) + res, err := matchers[i].Match(fields[i].Value) if err != nil { return false, err } @@ -246,90 +235,149 @@ func executeValueMatchers(matchers []valueMatcher, values [][]byte) (bool, error return true, nil } -func (f *errorCheckingFilter) Filter(e query.Entry) bool { - if f.err != nil { - return false - } - f.execInfo.IndexesFetched++ - - indexKey, err := core.NewIndexDataStoreKey(e.Key) - if err != nil { - f.err = err - return false - } - - var res bool - res, f.err = executeValueMatchers(f.matchers, indexKey.FieldValues) - return res -} - type scanningIndexIterator struct { queryResultIterator indexKey core.IndexDataStoreKey matchers []valueMatcher - filter errorCheckingFilter execInfo *ExecInfo } -func (i *scanningIndexIterator) Init(ctx context.Context, store datastore.DSReaderWriter) error { - i.filter.matchers = i.matchers - i.filter.execInfo = i.execInfo - - iter, err := store.Query(ctx, query.Query{ - Prefix: i.indexKey.ToString(), - Filters: []query.Filter{&i.filter}, +func (iter *scanningIndexIterator) Init(ctx context.Context, store datastore.DSReaderWriter) error { + resultIter, err := store.Query(ctx, query.Query{ + Prefix: iter.indexKey.ToString(), }) if err != nil { return err } - i.resultIter = iter + iter.resultIter = resultIter return nil } -func (i *scanningIndexIterator) Next() (indexIterResult, error) { - res, err := i.queryResultIterator.Next() - if i.filter.err != nil { - return indexIterResult{}, i.filter.err +func (iter *scanningIndexIterator) Next() (indexIterResult, error) { + for { + res, err := iter.queryResultIterator.Next() + if err != nil || !res.foundKey { + return indexIterResult{}, err + } + iter.execInfo.IndexesFetched++ + + didMatch, err := executeValueMatchers(iter.matchers, res.key.Fields) + + if didMatch { + return res, err + } } - return res, err } // checks if the value satisfies the condition type valueMatcher interface { - Match([]byte) (bool, error) + Match(any) (bool, error) +} + +type intMatcher struct { + value int64 + evalFunc func(int64, int64) bool } -// indexByteValuesMatcher is a filter that compares the index value with a given value. -// It uses bytes.Compare to compare the values and evaluate the result with evalFunc. 
-type indexByteValuesMatcher struct { - value []byte - // evalFunc receives a result of bytes.Compare - evalFunc func(int) bool +func (m *intMatcher) Match(value any) (bool, error) { + intVal, ok := value.(int64) + if !ok { + return false, NewErrUnexpectedTypeValue[int64](value) + } + return m.evalFunc(intVal, m.value), nil } -func (m *indexByteValuesMatcher) Match(value []byte) (bool, error) { - res := bytes.Compare(value, m.value) - return m.evalFunc(res), nil +type floatMatcher struct { + value float64 + evalFunc func(float64, float64) bool +} + +func (m *floatMatcher) Match(value any) (bool, error) { + floatVal, ok := value.(float64) + if !ok { + return false, NewErrUnexpectedTypeValue[float64](value) + } + return m.evalFunc(m.value, floatVal), nil +} + +type stringMatcher struct { + value string + evalFunc func(string, string) bool +} + +func (m *stringMatcher) Match(value any) (bool, error) { + stringVal, ok := value.(string) + if !ok { + return false, NewErrUnexpectedTypeValue[string](value) + } + return m.evalFunc(m.value, stringVal), nil +} + +type nilMatcher struct{} + +func (m *nilMatcher) Match(value any) (bool, error) { + return value == nil, nil } // checks if the index value is or is not in the given array type indexInArrayMatcher struct { - values map[string]bool - isIn bool + inValues []any + isIn bool +} + +func newNinIndexCmp(values []any, kind client.FieldKind, isIn bool) (*indexInArrayMatcher, error) { + normalizeValueFunc := getNormalizeValueFunc(kind) + for i := range values { + normalized, err := normalizeValueFunc(values[i]) + if err != nil { + return nil, err + } + values[i] = normalized + } + return &indexInArrayMatcher{inValues: values, isIn: isIn}, nil } -func newNinIndexCmp(values [][]byte, isIn bool) *indexInArrayMatcher { - valuesMap := make(map[string]bool) - for _, v := range values { - valuesMap[string(v)] = true +func getNormalizeValueFunc(kind client.FieldKind) func(any) (any, error) { + switch kind { + case client.FieldKind_NILLABLE_INT: + return func(value any) (any, error) { + if v, ok := value.(int64); ok { + return v, nil + } + if v, ok := value.(int32); ok { + return int64(v), nil + } + return nil, ErrInvalidInOperatorValue + } + case client.FieldKind_NILLABLE_FLOAT: + return func(value any) (any, error) { + if v, ok := value.(float64); ok { + return v, nil + } + if v, ok := value.(float32); ok { + return float64(v), nil + } + return nil, ErrInvalidInOperatorValue + } + case client.FieldKind_NILLABLE_STRING: + return func(value any) (any, error) { + if v, ok := value.(string); ok { + return v, nil + } + return nil, ErrInvalidInOperatorValue + } } - return &indexInArrayMatcher{values: valuesMap, isIn: isIn} + return nil } -func (m *indexInArrayMatcher) Match(value []byte) (bool, error) { - _, found := m.values[string(value)] - return found == m.isIn, nil +func (m *indexInArrayMatcher) Match(value any) (bool, error) { + for _, inVal := range m.inValues { + if inVal == value { + return m.isIn, nil + } + } + return !m.isIn, nil } // checks if the index value satisfies the LIKE condition @@ -341,7 +389,7 @@ type indexLikeMatcher struct { value string } -func newLikeIndexCmp(filterValue string, isLike bool) *indexLikeMatcher { +func newLikeIndexCmp(filterValue string, isLike bool) (*indexLikeMatcher, error) { matcher := &indexLikeMatcher{ isLike: isLike, } @@ -360,14 +408,13 @@ func newLikeIndexCmp(filterValue string, isLike bool) *indexLikeMatcher { } matcher.value = filterValue - return matcher + return matcher, nil } -func (m *indexLikeMatcher) 
Match(value []byte) (bool, error) { - var currentVal string - err := cbor.Unmarshal(value, ¤tVal) - if err != nil { - return false, err +func (m *indexLikeMatcher) Match(value any) (bool, error) { + currentVal, ok := value.(string) + if !ok { + return false, NewErrUnexpectedTypeValue[string](currentVal) } return m.doesMatch(currentVal) == m.isLike, nil @@ -392,136 +439,7 @@ func (m *indexLikeMatcher) doesMatch(currentVal string) bool { type anyMatcher struct{} -func (m *anyMatcher) Match([]byte) (bool, error) { return true, nil } - -func createValueMatcher(op string, filterVal any) (valueMatcher, error) { - switch op { - case opEq, opGt, opGe, opLt, opLe, opNe: - fieldValue := client.NewFieldValue(client.LWW_REGISTER, filterVal) - - valueBytes, err := fieldValue.Bytes() - if err != nil { - return nil, err - } - - m := &indexByteValuesMatcher{value: valueBytes} - switch op { - case opEq: - m.evalFunc = func(res int) bool { return res == 0 } - case opGt: - m.evalFunc = func(res int) bool { return res > 0 } - case opGe: - m.evalFunc = func(res int) bool { return res > 0 || res == 0 } - case opLt: - m.evalFunc = func(res int) bool { return res < 0 } - case opLe: - m.evalFunc = func(res int) bool { return res < 0 || res == 0 } - case opNe: - m.evalFunc = func(res int) bool { return res != 0 } - } - return m, nil - case opIn, opNin: - inArr, ok := filterVal.([]any) - if !ok { - return nil, ErrInvalidInOperatorValue - } - valArr := make([][]byte, 0, len(inArr)) - for _, v := range inArr { - fieldValue := client.NewFieldValue(client.LWW_REGISTER, v) - valueBytes, err := fieldValue.Bytes() - if err != nil { - return nil, err - } - valArr = append(valArr, valueBytes) - } - return newNinIndexCmp(valArr, op == opIn), nil - case opLike, opNlike: - return newLikeIndexCmp(filterVal.(string), op == opLike), nil - case opAny: - return &anyMatcher{}, nil - } - - return nil, NewErrInvalidFilterOperator(op) -} - -func createValueMatchers(conditions []fieldFilterCond) ([]valueMatcher, error) { - matchers := make([]valueMatcher, 0, len(conditions)) - for i := range conditions { - m, err := createValueMatcher(conditions[i].op, conditions[i].val) - if err != nil { - return nil, err - } - matchers = append(matchers, m) - } - return matchers, nil -} - -type fieldFilterCond struct { - op string - val any -} - -func (f *IndexFetcher) determineFieldFilterConditions() []fieldFilterCond { - result := make([]fieldFilterCond, 0, len(f.indexedFields)) - for i := range f.indexedFields { - fieldInd := f.mapping.FirstIndexOfName(f.indexedFields[i].Name) - found := false - // iterate through conditions and find the one that matches the current field - for filterKey, indexFilterCond := range f.indexFilter.Conditions { - propKey, ok := filterKey.(*mapper.PropertyIndex) - if !ok || fieldInd != propKey.Index { - continue - } - - found = true - - condMap := indexFilterCond.(map[connor.FilterKey]any) - for key, filterVal := range condMap { - opKey := key.(*mapper.Operator) - result = append(result, fieldFilterCond{op: opKey.Operation, val: filterVal}) - break - } - break - } - if !found { - result = append(result, fieldFilterCond{op: opAny}) - } - } - return result -} - -// isUniqueFetchByFullKey checks if the only index key can be fetched by the full index key. -// -// This method ignores the first condition (unless it's nil) because it's expected to be called only -// when the first field is used as a prefix in the index key. So we only check if the -// rest of the conditions are _eq. 
-func isUniqueFetchByFullKey(indexDesc *client.IndexDescription, conditions []fieldFilterCond) bool { - // we need to check length of conditions because full key fetch is only possible - // if all fields of the index are specified in the filter - res := indexDesc.Unique && len(conditions) == len(indexDesc.Fields) - - // first condition is not required to be _eq, but if is, val must be not nil - res = res && (conditions[0].op != opEq || conditions[0].val != nil) - - // for the rest it must be _eq and val must be not nil - for i := 1; i < len(conditions); i++ { - res = res && (conditions[i].op == opEq && conditions[i].val != nil) - } - return res -} - -func getFieldsBytes(conditions []fieldFilterCond) ([][]byte, error) { - result := make([][]byte, 0, len(conditions)) - for i := range conditions { - fieldVal := client.NewFieldValue(client.LWW_REGISTER, conditions[i].val) - keyFieldBytes, err := fieldVal.Bytes() - if err != nil { - return nil, err - } - result = append(result, keyFieldBytes) - } - return result, nil -} +func (m *anyMatcher) Match(any) (bool, error) { return true, nil } // newPrefixIndexIterator creates a new eqPrefixIndexIterator for fetching indexed data. // It can modify the input matchers slice. @@ -529,7 +447,7 @@ func (f *IndexFetcher) newPrefixIndexIterator( fieldConditions []fieldFilterCond, matchers []valueMatcher, ) (*eqPrefixIndexIterator, error) { - keyFieldValues := make([][]byte, 0, len(fieldConditions)) + keyFieldValues := make([]any, 0, len(fieldConditions)) for i := range fieldConditions { if fieldConditions[i].op != opEq { // prefix can be created only for subsequent _eq conditions @@ -537,14 +455,7 @@ func (f *IndexFetcher) newPrefixIndexIterator( break } - fieldVal := client.NewFieldValue(client.LWW_REGISTER, fieldConditions[i].val) - - keyValueBytes, err := fieldVal.Bytes() - if err != nil { - return nil, err - } - - keyFieldValues = append(keyFieldValues, keyValueBytes) + keyFieldValues = append(keyFieldValues, fieldConditions[i].val) } // iterators for _eq filter already iterate over keys with first field value @@ -553,15 +464,20 @@ func (f *IndexFetcher) newPrefixIndexIterator( matchers[0] = &anyMatcher{} } - indexKey := f.newIndexDataStoreKey() - indexKey.FieldValues = keyFieldValues + key := f.newIndexDataStoreKeyWithValues(keyFieldValues) + return &eqPrefixIndexIterator{ - indexKey: indexKey, - execInfo: &f.execInfo, - matchers: matchers, + queryResultIterator: f.newQueryResultIterator(), + indexKey: key, + execInfo: &f.execInfo, + matchers: matchers, }, nil } +func (f *IndexFetcher) newQueryResultIterator() queryResultIterator { + return queryResultIterator{indexDesc: f.indexDesc, indexedFields: f.indexedFields} +} + // newInIndexIterator creates a new inIndexIterator for fetching indexed data. // It can modify the input matchers slice. 
func (f *IndexFetcher) newInIndexIterator( @@ -572,14 +488,10 @@ func (f *IndexFetcher) newInIndexIterator( if !ok { return nil, ErrInvalidInOperatorValue } - keyFieldArr := make([][]byte, 0, len(inArr)) + inValues := make([]any, 0, len(inArr)) for _, v := range inArr { - fieldVal := client.NewFieldValue(client.LWW_REGISTER, v) - keyFieldBytes, err := fieldVal.Bytes() - if err != nil { - return nil, err - } - keyFieldArr = append(keyFieldArr, keyFieldBytes) + fieldVal := client.NewFieldValue(client.NONE_CRDT, v) + inValues = append(inValues, fieldVal.Value()) } // iterators for _in filter already iterate over keys with first field value @@ -590,33 +502,46 @@ func (f *IndexFetcher) newInIndexIterator( var iter indexIterator if isUniqueFetchByFullKey(&f.indexDesc, fieldConditions) { - keyFieldValues, e := getFieldsBytes(fieldConditions[1:]) - if e != nil { - return nil, e + keyFieldValues := make([]any, len(fieldConditions)) + for i := range fieldConditions { + keyFieldValues[i] = fieldConditions[i].val } - keyFieldValues = append([][]byte{{}}, keyFieldValues...) + + key := f.newIndexDataStoreKeyWithValues(keyFieldValues) + iter = &eqSingleIndexIterator{ - indexKey: f.newIndexDataStoreKey(), - execInfo: &f.execInfo, - keyFieldValues: keyFieldValues, + indexKey: key, + execInfo: &f.execInfo, } } else { indexKey := f.newIndexDataStoreKey() - indexKey.FieldValues = [][]byte{{}} + indexKey.Fields = []core.IndexedField{{Descending: f.indexDesc.Fields[0].Descending}} + iter = &eqPrefixIndexIterator{ - indexKey: indexKey, - execInfo: &f.execInfo, - matchers: matchers, + queryResultIterator: f.newQueryResultIterator(), + indexKey: indexKey, + execInfo: &f.execInfo, + matchers: matchers, } } return &inIndexIterator{ indexIterator: iter, - inValues: keyFieldArr, + inValues: inValues, }, nil } func (f *IndexFetcher) newIndexDataStoreKey() core.IndexDataStoreKey { - return core.IndexDataStoreKey{CollectionID: f.col.ID(), IndexID: f.indexDesc.ID} + key := core.IndexDataStoreKey{CollectionID: f.col.ID(), IndexID: f.indexDesc.ID} + return key +} + +func (f *IndexFetcher) newIndexDataStoreKeyWithValues(values []any) core.IndexDataStoreKey { + fields := make([]core.IndexedField, len(values)) + for i := range values { + fields[i].Value = values[i] + fields[i].Descending = f.indexDesc.Fields[i].Descending + } + return core.NewIndexDataStoreKey(f.col.ID(), f.indexDesc.ID, fields) } func (f *IndexFetcher) createIndexIterator() (indexIterator, error) { @@ -630,14 +555,16 @@ func (f *IndexFetcher) createIndexIterator() (indexIterator, error) { switch fieldConditions[0].op { case opEq: if isUniqueFetchByFullKey(&f.indexDesc, fieldConditions) { - keyFieldsBytes, err := getFieldsBytes(fieldConditions) - if err != nil { - return nil, err + keyFieldValues := make([]any, len(fieldConditions)) + for i := range fieldConditions { + keyFieldValues[i] = fieldConditions[i].val } + + key := f.newIndexDataStoreKeyWithValues(keyFieldValues) + return &eqSingleIndexIterator{ - indexKey: f.newIndexDataStoreKey(), - keyFieldValues: keyFieldsBytes, - execInfo: &f.execInfo, + indexKey: key, + execInfo: &f.execInfo, }, nil } else { return f.newPrefixIndexIterator(fieldConditions, matchers) @@ -646,11 +573,168 @@ func (f *IndexFetcher) createIndexIterator() (indexIterator, error) { return f.newInIndexIterator(fieldConditions, matchers) case opGt, opGe, opLt, opLe, opNe, opNin, opLike, opNlike: return &scanningIndexIterator{ - indexKey: f.newIndexDataStoreKey(), - matchers: matchers, - execInfo: &f.execInfo, + queryResultIterator: 
f.newQueryResultIterator(), + indexKey: f.newIndexDataStoreKey(), + matchers: matchers, + execInfo: &f.execInfo, }, nil } return nil, NewErrInvalidFilterOperator(fieldConditions[0].op) } + +func createValueMatcher(condition *fieldFilterCond) (valueMatcher, error) { + if condition.op == "" { + return &anyMatcher{}, nil + } + + if client.IsNillableKind(condition.kind) && condition.val == nil { + return &nilMatcher{}, nil + } + + switch condition.op { + case opEq, opGt, opGe, opLt, opLe, opNe: + switch condition.kind { + case client.FieldKind_NILLABLE_INT: + var intVal int64 + switch v := condition.val.(type) { + case int64: + intVal = v + case int32: + intVal = int64(v) + case int: + intVal = int64(v) + default: + return nil, NewErrUnexpectedTypeValue[int64](condition.val) + } + return &intMatcher{value: intVal, evalFunc: getCompareValsFunc[int64](condition.op)}, nil + case client.FieldKind_NILLABLE_FLOAT: + floatVal, ok := condition.val.(float64) + if !ok { + return nil, NewErrUnexpectedTypeValue[float64](condition.val) + } + return &floatMatcher{value: floatVal, evalFunc: getCompareValsFunc[float64](condition.op)}, nil + case client.FieldKind_DocID, client.FieldKind_NILLABLE_STRING: + strVal, ok := condition.val.(string) + if !ok { + return nil, NewErrUnexpectedTypeValue[string](condition.val) + } + return &stringMatcher{value: strVal, evalFunc: getCompareValsFunc[string](condition.op)}, nil + } + case opIn, opNin: + inArr, ok := condition.val.([]any) + if !ok { + return nil, ErrInvalidInOperatorValue + } + return newNinIndexCmp(inArr, condition.kind, condition.op == opIn) + case opLike, opNlike: + strVal, ok := condition.val.(string) + if !ok { + return nil, NewErrUnexpectedTypeValue[string](condition.val) + } + return newLikeIndexCmp(strVal, condition.op == opLike) + case opAny: + return &anyMatcher{}, nil + } + + return nil, NewErrInvalidFilterOperator(condition.op) +} + +func createValueMatchers(conditions []fieldFilterCond) ([]valueMatcher, error) { + matchers := make([]valueMatcher, 0, len(conditions)) + for i := range conditions { + m, err := createValueMatcher(&conditions[i]) + if err != nil { + return nil, err + } + matchers = append(matchers, m) + } + return matchers, nil +} + +type fieldFilterCond struct { + op string + val any + kind client.FieldKind +} + +// determineFieldFilterConditions determines the conditions and their corresponding operation +// for each indexed field. +// It returns a slice of fieldFilterCond, where each element corresponds to a field in the index. +func (f *IndexFetcher) determineFieldFilterConditions() []fieldFilterCond { + result := make([]fieldFilterCond, 0, len(f.indexedFields)) + for i := range f.indexedFields { + fieldInd := f.mapping.FirstIndexOfName(f.indexedFields[i].Name) + found := false + // iterate through conditions and find the one that matches the current field + for filterKey, indexFilterCond := range f.indexFilter.Conditions { + propKey, ok := filterKey.(*mapper.PropertyIndex) + if !ok || fieldInd != propKey.Index { + continue + } + + found = true + + condMap := indexFilterCond.(map[connor.FilterKey]any) + for key, filterVal := range condMap { + opKey := key.(*mapper.Operator) + result = append(result, fieldFilterCond{ + op: opKey.Operation, + val: filterVal, + kind: f.indexedFields[i].Kind, + }) + break + } + break + } + if !found { + result = append(result, fieldFilterCond{op: opAny}) + } + } + return result +} + +// isUniqueFetchByFullKey checks if the only index key can be fetched by the full index key. 
+// +// This method ignores the first condition (unless it's nil) because it's expected to be called only +// when the first field is used as a prefix in the index key. So we only check if the +// rest of the conditions are _eq. +func isUniqueFetchByFullKey(indexDesc *client.IndexDescription, conditions []fieldFilterCond) bool { + // we need to check length of conditions because full key fetch is only possible + // if all fields of the index are specified in the filter + res := indexDesc.Unique && len(conditions) == len(indexDesc.Fields) + + // the first condition is not required to be _eq, but if it is, val must not be nil + res = res && (conditions[0].op != opEq || conditions[0].val != nil) + + // for the rest it must be _eq and val must not be nil + for i := 1; i < len(conditions); i++ { + res = res && (conditions[i].op == opEq && conditions[i].val != nil) + } + return res +} + +func getCompareValsFunc[T cmp.Ordered](op string) func(T, T) bool { + switch op { + case opGt: + return checkGT + case opGe: + return checkGE + case opLt: + return checkLT + case opLe: + return checkLE + case opEq: + return checkEQ + case opNe: + return checkNE + } + return nil +} + +func checkGE[T cmp.Ordered](a, b T) bool { return a >= b } +func checkGT[T cmp.Ordered](a, b T) bool { return a > b } +func checkLE[T cmp.Ordered](a, b T) bool { return a <= b } +func checkLT[T cmp.Ordered](a, b T) bool { return a < b } +func checkEQ[T cmp.Ordered](a, b T) bool { return a == b } +func checkNE[T cmp.Ordered](a, b T) bool { return a != b } diff --git a/db/index.go b/db/index.go index 03246fc3d1..ddec525598 100644 --- a/db/index.go +++ b/db/index.go @@ -91,19 +91,19 @@ func NewCollectionIndex( return nil, NewErrIndexDescHasNoFields(desc) } base := collectionBaseIndex{collection: collection, desc: desc} - base.validateFieldFuncs = make([]func(any) bool, 0, len(desc.Fields)) - base.fieldsDescs = make([]client.SchemaFieldDescription, 0, len(desc.Fields)) - for _, fieldDesc := range desc.Fields { - field, foundField := collection.Schema().GetFieldByName(fieldDesc.Name) + base.validateFieldFuncs = make([]func(any) bool, len(desc.Fields)) + base.fieldsDescs = make([]client.SchemaFieldDescription, len(desc.Fields)) + for i := range desc.Fields { + field, foundField := collection.Schema().GetFieldByName(desc.Fields[i].Name) if !foundField { - return nil, client.NewErrFieldNotExist(desc.Fields[0].Name) + return nil, client.NewErrFieldNotExist(desc.Fields[i].Name) } - base.fieldsDescs = append(base.fieldsDescs, field) + base.fieldsDescs[i] = field validateFunc, err := getFieldValidateFunc(field.Kind) if err != nil { return nil, err } - base.validateFieldFuncs = append(base.validateFieldFuncs, validateFunc) + base.validateFieldFuncs[i] = validateFunc } if desc.Unique { return &collectionUniqueIndex{collectionBaseIndex: base}, nil @@ -119,50 +119,39 @@ type collectionBaseIndex struct { fieldsDescs []client.SchemaFieldDescription } -func (i *collectionBaseIndex) getDocFieldValue(doc *client.Document) ([][]byte, error) { - result := make([][]byte, 0, len(i.fieldsDescs)) - for iter := range i.fieldsDescs { - fieldVal, err := doc.TryGetValue(i.fieldsDescs[iter].Name) +func (index *collectionBaseIndex) getDocFieldValues(doc *client.Document) ([]*client.FieldValue, error) { + result := make([]*client.FieldValue, 0, len(index.fieldsDescs)) + for iter := range index.fieldsDescs { + fieldVal, err := doc.TryGetValue(index.fieldsDescs[iter].Name) if err != nil { return nil, err } if fieldVal == nil || fieldVal.Value() == nil { - // this will be
gone very soon with new encoding of secondary indexes - valBytes, err := client.NewFieldValue(client.LWW_REGISTER, nil).Bytes() - if err != nil { - return nil, err - } - result = append(result, valBytes) + result = append(result, client.NewFieldValue(client.NONE_CRDT, nil)) continue } - if !i.validateFieldFuncs[iter](fieldVal.Value()) { - return nil, NewErrInvalidFieldValue(i.fieldsDescs[iter].Kind, fieldVal) - } - valBytes, err := fieldVal.Bytes() - if err != nil { - return nil, err - } - result = append(result, valBytes) + result = append(result, fieldVal) } return result, nil } -func (i *collectionBaseIndex) getDocumentsIndexKey( +func (index *collectionBaseIndex) getDocumentsIndexKey( doc *client.Document, ) (core.IndexDataStoreKey, error) { - fieldValues, err := i.getDocFieldValue(doc) + fieldValues, err := index.getDocFieldValues(doc) if err != nil { return core.IndexDataStoreKey{}, err } - indexDataStoreKey := core.IndexDataStoreKey{} - indexDataStoreKey.CollectionID = i.collection.ID() - indexDataStoreKey.IndexID = i.desc.ID - indexDataStoreKey.FieldValues = fieldValues - return indexDataStoreKey, nil + fields := make([]core.IndexedField, len(index.fieldsDescs)) + for i := range index.fieldsDescs { + fields[i].Value = fieldValues[i].Value() + fields[i].Descending = index.desc.Fields[i].Descending + } + return core.NewIndexDataStoreKey(index.collection.ID(), index.desc.ID, fields), nil } -func (i *collectionBaseIndex) deleteIndexKey( +func (index *collectionBaseIndex) deleteIndexKey( ctx context.Context, txn datastore.Txn, key core.IndexDataStoreKey, @@ -172,17 +161,17 @@ func (i *collectionBaseIndex) deleteIndexKey( return err } if !exists { - return NewErrCorruptedIndex(i.desc.Name) + return NewErrCorruptedIndex(index.desc.Name) } return txn.Datastore().Delete(ctx, key.ToDS()) } // RemoveAll remove all artifacts of the index from the storage, i.e. all index // field values for all documents. -func (i *collectionBaseIndex) RemoveAll(ctx context.Context, txn datastore.Txn) error { +func (index *collectionBaseIndex) RemoveAll(ctx context.Context, txn datastore.Txn) error { prefixKey := core.IndexDataStoreKey{} - prefixKey.CollectionID = i.collection.ID() - prefixKey.IndexID = i.desc.ID + prefixKey.CollectionID = index.collection.ID() + prefixKey.IndexID = index.desc.ID keys, err := datastore.FetchKeysForPrefix(ctx, prefixKey.ToString(), txn.Datastore()) if err != nil { @@ -200,13 +189,13 @@ func (i *collectionBaseIndex) RemoveAll(ctx context.Context, txn datastore.Txn) } // Name returns the name of the index -func (i *collectionBaseIndex) Name() string { - return i.desc.Name +func (index *collectionBaseIndex) Name() string { + return index.desc.Name } // Description returns the description of the index -func (i *collectionBaseIndex) Description() client.IndexDescription { - return i.desc +func (index *collectionBaseIndex) Description() client.IndexDescription { + return index.desc } // collectionSimpleIndex is an non-unique index that indexes documents by a single field. 
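As a minimal sketch of the new key shape assembled above (collectionID, indexID and docID are hypothetical placeholders, not names from this diff):

	// A composite (name ascending, age descending) entry for one document.
	fields := []core.IndexedField{
		{Value: "John"},
		{Value: int64(21), Descending: true},
	}
	key := core.NewIndexDataStoreKey(collectionID, indexID, fields)
	// Non-unique indexes append the document ID as a trailing field.
	key.Fields = append(key.Fields, core.IndexedField{Value: docID.String()})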
@@ -217,65 +206,64 @@ type collectionSimpleIndex struct { var _ CollectionIndex = (*collectionSimpleIndex)(nil) -func (i *collectionSimpleIndex) getDocumentsIndexKey( +func (index *collectionSimpleIndex) getDocumentsIndexKey( doc *client.Document, ) (core.IndexDataStoreKey, error) { - key, err := i.collectionBaseIndex.getDocumentsIndexKey(doc) + key, err := index.collectionBaseIndex.getDocumentsIndexKey(doc) if err != nil { return core.IndexDataStoreKey{}, err } - key.FieldValues = append(key.FieldValues, []byte(doc.ID().String())) + key.Fields = append(key.Fields, core.IndexedField{Value: doc.ID().String()}) return key, nil } // Save indexes a document by storing the indexed field value. -func (i *collectionSimpleIndex) Save( +func (index *collectionSimpleIndex) Save( ctx context.Context, txn datastore.Txn, doc *client.Document, ) error { - key, err := i.getDocumentsIndexKey(doc) + key, err := index.getDocumentsIndexKey(doc) if err != nil { return err } err = txn.Datastore().Put(ctx, key.ToDS(), []byte{}) if err != nil { - return NewErrFailedToStoreIndexedField(key.ToDS().String(), err) + return NewErrFailedToStoreIndexedField(key.ToString(), err) } return nil } -func (i *collectionSimpleIndex) Update( +func (index *collectionSimpleIndex) Update( ctx context.Context, txn datastore.Txn, oldDoc *client.Document, newDoc *client.Document, ) error { - err := i.deleteDocIndex(ctx, txn, oldDoc) + err := index.deleteDocIndex(ctx, txn, oldDoc) if err != nil { return err } - return i.Save(ctx, txn, newDoc) + return index.Save(ctx, txn, newDoc) } -func (i *collectionSimpleIndex) deleteDocIndex( +func (index *collectionSimpleIndex) deleteDocIndex( ctx context.Context, txn datastore.Txn, doc *client.Document, ) error { - key, err := i.getDocumentsIndexKey(doc) + key, err := index.getDocumentsIndexKey(doc) if err != nil { return err } - return i.deleteIndexKey(ctx, txn, key) + return index.deleteIndexKey(ctx, txn, key) } // hasIndexKeyNilField returns true if the index key has a field with nil value func hasIndexKeyNilField(key *core.IndexDataStoreKey) bool { - const cborNil = 0xf6 - for i := range key.FieldValues { - if len(key.FieldValues[i]) == 1 && key.FieldValues[i][0] == cborNil { + for i := range key.Fields { + if key.Fields[i].Value == nil { return true } } @@ -288,7 +276,7 @@ type collectionUniqueIndex struct { var _ CollectionIndex = (*collectionUniqueIndex)(nil) -func (i *collectionUniqueIndex) save( +func (index *collectionUniqueIndex) save( ctx context.Context, txn datastore.Txn, key *core.IndexDataStoreKey, @@ -301,24 +289,24 @@ func (i *collectionUniqueIndex) save( return nil } -func (i *collectionUniqueIndex) Save( +func (index *collectionUniqueIndex) Save( ctx context.Context, txn datastore.Txn, doc *client.Document, ) error { - key, val, err := i.prepareIndexRecordToStore(ctx, txn, doc) + key, val, err := index.prepareIndexRecordToStore(ctx, txn, doc) if err != nil { return err } - return i.save(ctx, txn, &key, val) + return index.save(ctx, txn, &key, val) } -func (i *collectionUniqueIndex) newUniqueIndexError( +func (index *collectionUniqueIndex) newUniqueIndexError( doc *client.Document, ) error { - kvs := make([]errors.KV, 0, len(i.fieldsDescs)) - for iter := range i.fieldsDescs { - fieldVal, err := doc.TryGetValue(i.fieldsDescs[iter].Name) + kvs := make([]errors.KV, 0, len(index.fieldsDescs)) + for iter := range index.fieldsDescs { + fieldVal, err := doc.TryGetValue(index.fieldsDescs[iter].Name) var val any if err != nil { return err @@ -327,33 +315,33 @@ func (i 
*collectionUniqueIndex) newUniqueIndexError( if fieldVal != nil { val = fieldVal.Value() } - kvs = append(kvs, errors.NewKV(i.fieldsDescs[iter].Name, val)) + kvs = append(kvs, errors.NewKV(index.fieldsDescs[iter].Name, val)) } return NewErrCanNotIndexNonUniqueFields(doc.ID().String(), kvs...) } -func (i *collectionUniqueIndex) getDocumentsIndexRecord( +func (index *collectionUniqueIndex) getDocumentsIndexRecord( doc *client.Document, ) (core.IndexDataStoreKey, []byte, error) { - key, err := i.getDocumentsIndexKey(doc) + key, err := index.getDocumentsIndexKey(doc) if err != nil { return core.IndexDataStoreKey{}, nil, err } if hasIndexKeyNilField(&key) { - key.FieldValues = append(key.FieldValues, []byte(doc.ID().String())) + key.Fields = append(key.Fields, core.IndexedField{Value: doc.ID().String()}) return key, []byte{}, nil } else { return key, []byte(doc.ID().String()), nil } } -func (i *collectionUniqueIndex) prepareIndexRecordToStore( +func (index *collectionUniqueIndex) prepareIndexRecordToStore( ctx context.Context, txn datastore.Txn, doc *client.Document, ) (core.IndexDataStoreKey, []byte, error) { - key, val, err := i.getDocumentsIndexRecord(doc) + key, val, err := index.getDocumentsIndexRecord(doc) if err != nil { return core.IndexDataStoreKey{}, nil, err } @@ -364,37 +352,37 @@ func (i *collectionUniqueIndex) prepareIndexRecordToStore( return core.IndexDataStoreKey{}, nil, err } if exists { - return core.IndexDataStoreKey{}, nil, i.newUniqueIndexError(doc) + return core.IndexDataStoreKey{}, nil, index.newUniqueIndexError(doc) } } return key, val, nil } -func (i *collectionUniqueIndex) Update( +func (index *collectionUniqueIndex) Update( ctx context.Context, txn datastore.Txn, oldDoc *client.Document, newDoc *client.Document, ) error { - newKey, newVal, err := i.prepareIndexRecordToStore(ctx, txn, newDoc) + newKey, newVal, err := index.prepareIndexRecordToStore(ctx, txn, newDoc) if err != nil { return err } - err = i.deleteDocIndex(ctx, txn, oldDoc) + err = index.deleteDocIndex(ctx, txn, oldDoc) if err != nil { return err } - return i.save(ctx, txn, &newKey, newVal) + return index.save(ctx, txn, &newKey, newVal) } -func (i *collectionUniqueIndex) deleteDocIndex( +func (index *collectionUniqueIndex) deleteDocIndex( ctx context.Context, txn datastore.Txn, doc *client.Document, ) error { - key, _, err := i.getDocumentsIndexRecord(doc) + key, _, err := index.getDocumentsIndexRecord(doc) if err != nil { return err } - return i.deleteIndexKey(ctx, txn, key) + return index.deleteIndexKey(ctx, txn, key) } diff --git a/db/index_test.go b/db/index_test.go index 19787cfe93..2b0dbdc8b6 100644 --- a/db/index_test.go +++ b/db/index_test.go @@ -148,7 +148,7 @@ func getUsersIndexDescOnName() client.IndexDescription { return client.IndexDescription{ Name: testUsersColIndexName, Fields: []client.IndexedFieldDescription{ - {Name: usersNameFieldName, Direction: client.Ascending}, + {Name: usersNameFieldName}, }, } } @@ -157,7 +157,7 @@ func getUsersIndexDescOnAge() client.IndexDescription { return client.IndexDescription{ Name: testUsersColIndexAge, Fields: []client.IndexedFieldDescription{ - {Name: usersAgeFieldName, Direction: client.Ascending}, + {Name: usersAgeFieldName}, }, } } @@ -166,7 +166,7 @@ func getUsersIndexDescOnWeight() client.IndexDescription { return client.IndexDescription{ Name: testUsersColIndexWeight, Fields: []client.IndexedFieldDescription{ - {Name: usersWeightFieldName, Direction: client.Ascending}, + {Name: usersWeightFieldName}, }, } } @@ -175,7 +175,7 @@ func 
getProductsIndexDescOnCategory() client.IndexDescription { return client.IndexDescription{ Name: testUsersColIndexAge, Fields: []client.IndexedFieldDescription{ - {Name: productsCategoryFieldName, Direction: client.Ascending}, + {Name: productsCategoryFieldName}, }, } } @@ -200,7 +200,7 @@ func (f *indexTestFixture) createUserCollectionUniqueIndexOnName() client.IndexD func addFieldToIndex(indexDesc client.IndexDescription, fieldName string) client.IndexDescription { indexDesc.Fields = append(indexDesc.Fields, client.IndexedFieldDescription{ - Name: fieldName, Direction: client.Ascending, + Name: fieldName, }) return indexDesc } @@ -222,7 +222,7 @@ func (f *indexTestFixture) dropIndex(colName, indexName string) error { return f.db.dropCollectionIndex(f.ctx, f.txn, colName, indexName) } -func (f *indexTestFixture) countIndexPrefixes(colName, indexName string) int { +func (f *indexTestFixture) countIndexPrefixes(indexName string) int { prefix := core.NewCollectionIndexKey(immutable.Some(f.users.ID()), indexName) q, err := f.txn.Systemstore().Query(f.ctx, query.Query{ Prefix: prefix.ToString(), @@ -289,7 +289,7 @@ func TestCreateIndex_IfIndexDescriptionIDIsNotZero_ReturnError(t *testing.T) { Name: "some_index_name", ID: id, Fields: []client.IndexedFieldDescription{ - {Name: usersNameFieldName, Direction: client.Ascending}, + {Name: usersNameFieldName}, }, } _, err := f.createCollectionIndex(desc) @@ -304,7 +304,7 @@ func TestCreateIndex_IfValidInput_CreateIndex(t *testing.T) { desc := client.IndexDescription{ Name: "some_index_name", Fields: []client.IndexedFieldDescription{ - {Name: usersNameFieldName, Direction: client.Ascending}, + {Name: usersNameFieldName}, }, } resultDesc, err := f.createCollectionIndex(desc) @@ -321,7 +321,7 @@ func TestCreateIndex_IfFieldNameIsEmpty_ReturnError(t *testing.T) { desc := client.IndexDescription{ Name: "some_index_name", Fields: []client.IndexedFieldDescription{ - {Name: "", Direction: client.Ascending}, + {Name: ""}, }, } _, err := f.createCollectionIndex(desc) @@ -338,20 +338,7 @@ func TestCreateIndex_IfFieldHasNoDirection_DefaultToAsc(t *testing.T) { } newDesc, err := f.createCollectionIndex(desc) assert.NoError(t, err) - assert.Equal(t, client.Ascending, newDesc.Fields[0].Direction) -} - -func TestCreateIndex_IfSingleFieldInDescOrder_ReturnError(t *testing.T) { - f := newIndexTestFixture(t) - defer f.db.Close() - - desc := client.IndexDescription{ - Fields: []client.IndexedFieldDescription{ - {Name: usersNameFieldName, Direction: client.Descending}, - }, - } - _, err := f.createCollectionIndex(desc) - assert.EqualError(t, err, errIndexSingleFieldWrongDirection) + assert.False(t, newDesc.Fields[0].Descending) } func TestCreateIndex_IfIndexWithNameAlreadyExists_ReturnError(t *testing.T) { @@ -454,7 +441,7 @@ func TestCreateIndex_WithMultipleCollectionsAndIndexes_AssignIncrementedIDPerCol makeIndex := func(fieldName string) client.IndexDescription { return client.IndexDescription{ Fields: []client.IndexedFieldDescription{ - {Name: fieldName, Direction: client.Ascending}, + {Name: fieldName}, }, } } @@ -536,7 +523,7 @@ func TestCreateIndex_IfAttemptToIndexOnUnsupportedType_ReturnError(t *testing.T) indexDesc := client.IndexDescription{ Fields: []client.IndexedFieldDescription{ - {Name: "field", Direction: client.Ascending}, + {Name: "field"}, }, } @@ -598,7 +585,7 @@ func TestGetIndexes_IfInvalidIndexKeyIsStored_ReturnError(t *testing.T) { desc := client.IndexDescription{ Name: "some_index_name", Fields: []client.IndexedFieldDescription{ - {Name: 
usersNameFieldName, Direction: client.Ascending}, + {Name: usersNameFieldName}, }, } descData, _ := json.Marshal(desc) @@ -904,7 +891,7 @@ func TestCollectionGetIndexes_IfStoredIndexWithUnsupportedType_ReturnError(t *te indexDesc := client.IndexDescription{ Fields: []client.IndexedFieldDescription{ - {Name: "field", Direction: client.Ascending}, + {Name: "field"}, }, } indexDescData, err := json.Marshal(indexDesc) @@ -1025,7 +1012,7 @@ func TestCollectionGetIndexes_ShouldReturnIndexesInOrderedByName(t *testing.T) { indexDesc := client.IndexDescription{ Name: indexNamePrefix + iStr, Fields: []client.IndexedFieldDescription{ - {Name: fieldNamePrefix + iStr, Direction: client.Ascending}, + {Name: fieldNamePrefix + iStr}, }, } @@ -1166,24 +1153,24 @@ func TestDropAllIndexes_ShouldDeleteAllIndexes(t *testing.T) { defer f.db.Close() _, err := f.createCollectionIndexFor(usersColName, client.IndexDescription{ Fields: []client.IndexedFieldDescription{ - {Name: usersNameFieldName, Direction: client.Ascending}, + {Name: usersNameFieldName}, }, }) assert.NoError(f.t, err) _, err = f.createCollectionIndexFor(usersColName, client.IndexDescription{ Fields: []client.IndexedFieldDescription{ - {Name: usersAgeFieldName, Direction: client.Ascending}, + {Name: usersAgeFieldName}, }, }) assert.NoError(f.t, err) - assert.Equal(t, 2, f.countIndexPrefixes(usersColName, "")) + assert.Equal(t, 2, f.countIndexPrefixes("")) err = f.users.(*collection).dropAllIndexes(f.ctx, f.txn) assert.NoError(t, err) - assert.Equal(t, 0, f.countIndexPrefixes(usersColName, "")) + assert.Equal(t, 0, f.countIndexPrefixes("")) } func TestDropAllIndexes_IfStorageFails_ReturnError(t *testing.T) { diff --git a/db/indexed_docs_test.go b/db/indexed_docs_test.go index 1d8a1ce803..d10ad8eb5b 100644 --- a/db/indexed_docs_test.go +++ b/db/indexed_docs_test.go @@ -77,12 +77,12 @@ func (f *indexTestFixture) newProdDoc(id int, price float64, cat string, col cli // The format of the non-unique index key is: "////" // Example: "/5/1/12/bae-61cd6879-63ca-5ca9-8731-470a3c1dac69" type indexKeyBuilder struct { - f *indexTestFixture - colName string - fieldsNames []string - doc *client.Document - values [][]byte - isUnique bool + f *indexTestFixture + colName string + fieldsNames []string + descendingFields []bool + doc *client.Document + isUnique bool } func newIndexKeyBuilder(f *indexTestFixture) *indexKeyBuilder { @@ -102,6 +102,12 @@ func (b *indexKeyBuilder) Fields(fieldsNames ...string) *indexKeyBuilder { return b } +// DescendingFields sets which fields of the index key are indexed in descending order. +func (b *indexKeyBuilder) DescendingFields(descending ...bool) *indexKeyBuilder { + b.descendingFields = descending + return b +} + // Doc sets the document for the index key. // For non-unique index keys, it will try to find the field value in the document // corresponding to the field name set in the builder. @@ -111,13 +117,6 @@ func (b *indexKeyBuilder) Doc(doc *client.Document) *indexKeyBuilder { return b } -// Values sets the values for the index key. -// It will override the field values stored in the document.
-func (b *indexKeyBuilder) Values(values ...[]byte) *indexKeyBuilder { - b.values = values - return b -} - func (b *indexKeyBuilder) Unique() *indexKeyBuilder { b.isUnique = true return b @@ -164,41 +163,30 @@ indexLoop: } if b.doc != nil { - // This CBOR-specific value will be gone soon once we implement - // our own encryption package - const cborNil = 0xf6 hasNilValue := false for i, fieldName := range b.fieldsNames { - var fieldBytesVal []byte - var fieldValue *client.FieldValue - var err error - if len(b.values) <= i { - fieldValue, err = b.doc.GetValue(fieldName) - if err != nil { - if errors.Is(err, client.ErrFieldNotExist) { - fieldValue = client.NewFieldValue(client.LWW_REGISTER, nil) - } else { - require.NoError(b.f.t, err) - } - } else if fieldValue != nil && fieldValue.Value() == nil { - fieldValue = client.NewFieldValue(client.LWW_REGISTER, nil) + fieldValue, err := b.doc.GetValue(fieldName) + var val any + if err != nil { + if !errors.Is(err, client.ErrFieldNotExist) { + require.NoError(b.f.t, err) } - } else { - fieldValue = client.NewFieldValue(client.LWW_REGISTER, b.values[i]) + } else if fieldValue != nil { + val = fieldValue.Value() } - fieldBytesVal, err = fieldValue.Bytes() - require.NoError(b.f.t, err) - if len(fieldBytesVal) == 1 && fieldBytesVal[0] == cborNil { + if val == nil { hasNilValue = true } - key.FieldValues = append(key.FieldValues, fieldBytesVal) + descending := false + if i < len(b.descendingFields) { + descending = b.descendingFields[i] + } + key.Fields = append(key.Fields, core.IndexedField{Value: val, Descending: descending}) } if !b.isUnique || hasNilValue { - key.FieldValues = append(key.FieldValues, []byte(b.doc.ID().String())) + key.Fields = append(key.Fields, core.IndexedField{Value: b.doc.ID().String()}) } - } else if len(b.values) > 0 { - key.FieldValues = b.values } return key @@ -288,6 +276,25 @@ func TestNonUnique_IfDocIsAdded_ShouldBeIndexed(t *testing.T) { assert.Len(t, data, 0) } +func TestNonUnique_IfDocWithDescendingOrderIsAdded_ShouldBeIndexed(t *testing.T) { + f := newIndexTestFixture(t) + defer f.db.Close() + + indexDesc := getUsersIndexDescOnName() + indexDesc.Fields[0].Descending = true + _, err := f.createCollectionIndexFor(f.users.Name().Value(), indexDesc) + require.NoError(f.t, err) + + doc := f.newUserDoc("John", 21, f.users) + f.saveDocToCollection(doc, f.users) + + key := newIndexKeyBuilder(f).Col(usersColName).Fields(usersNameFieldName).DescendingFields(true).Doc(doc).Build() + + data, err := f.txn.Datastore().Get(f.ctx, key.ToDS()) + require.NoError(t, err) + assert.Len(t, data, 0) +} + func TestNonUnique_IfFailsToStoredIndexedDoc_Error(t *testing.T) { f := newIndexTestFixture(t) defer f.db.Close() @@ -531,8 +538,7 @@ func TestNonUnique_IfIndexedFieldIsNil_StoreItAsNil(t *testing.T) { f.saveDocToCollection(doc, f.users) - key := newIndexKeyBuilder(f).Col(usersColName).Fields(usersNameFieldName).Doc(doc). - Values([]byte(nil)).Build() + key := newIndexKeyBuilder(f).Col(usersColName).Fields(usersNameFieldName).Doc(doc).Build() data, err := f.txn.Datastore().Get(f.ctx, key.ToDS()) require.NoError(t, err) @@ -982,8 +988,7 @@ func TestNonUpdate_IfIndexedFieldWasNil_ShouldDeleteIt(t *testing.T) { f.saveDocToCollection(doc, f.users) - oldKey := newIndexKeyBuilder(f).Col(usersColName).Fields(usersNameFieldName).Doc(doc). 
- Values([]byte(nil)).Build() + oldKey := newIndexKeyBuilder(f).Col(usersColName).Fields(usersNameFieldName).Doc(doc).Build() err = doc.Set(usersNameFieldName, "John") require.NoError(f.t, err) @@ -1069,8 +1074,7 @@ func TestUnique_IfIndexedFieldIsNil_StoreItAsNil(t *testing.T) { f.saveDocToCollection(doc, f.users) - key := newIndexKeyBuilder(f).Col(usersColName).Fields(usersNameFieldName).Unique().Doc(doc). - Values([]byte(nil)).Build() + key := newIndexKeyBuilder(f).Col(usersColName).Fields(usersNameFieldName).Unique().Doc(doc).Build() data, err := f.txn.Datastore().Get(f.ctx, key.ToDS()) require.NoError(t, err) diff --git a/docs/data_format_changes/i2229-order-direction-for-indexed-fields.md b/docs/data_format_changes/i2229-order-direction-for-indexed-fields.md new file mode 100644 index 0000000000..b1260e77e5 --- /dev/null +++ b/docs/data_format_changes/i2229-order-direction-for-indexed-fields.md @@ -0,0 +1,3 @@ +# Order directions for indexed fields + +Secondary indexes now use an entirely different way of encoding indexed fields. \ No newline at end of file diff --git a/encoding/bytes.go b/encoding/bytes.go new file mode 100644 index 0000000000..ac390f1bd3 --- /dev/null +++ b/encoding/bytes.go @@ -0,0 +1,154 @@ +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License + +package encoding + +import ( + "bytes" ) + +const ( + // All terminators are encoded as the \x00\x01 sequence. + // To distinguish a literal \x00 byte, it is escaped as \x00\xff + escape byte = 0x00 + escapedTerm byte = 0x01 + escaped00 byte = 0xff + escapedFF byte = 0x00 + escapeDesc byte = ^escape + escapedTermDesc byte = ^escapedTerm + escaped00Desc byte = ^escaped00 + escapedFFDesc byte = ^escapedFF +) + +type escapes struct { + escape byte + escapedTerm byte + escaped00 byte + escapedFF byte + marker byte +} + +var ( + ascendingBytesEscapes = escapes{escape, escapedTerm, escaped00, escapedFF, bytesMarker} + descendingBytesEscapes = escapes{escapeDesc, escapedTermDesc, escaped00Desc, escapedFFDesc, bytesDescMarker} +) + +// EncodeBytesAscending encodes the []byte value using an escape-based +// encoding. The encoded value is terminated with the sequence +// "\x00\x01" which is guaranteed to not occur elsewhere in the +// encoded value. The encoded bytes are appended to the supplied buffer +// and the resulting buffer is returned. +func EncodeBytesAscending(b []byte, data []byte) []byte { + return encodeBytesAscendingWithTerminatorAndPrefix(b, data, ascendingBytesEscapes.escapedTerm, bytesMarker) +} + +// encodeBytesAscendingWithTerminatorAndPrefix encodes the []byte value using an escape-based +// encoding. The encoded value is terminated with the sequence +// "\x00\terminator". The encoded bytes are appended to the supplied buffer +// and the resulting buffer is returned. The terminator allows us to pass +// different terminators for things such as JSON key encoding.
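+// For example (mirroring the test vectors later in this patch): with the bytes
+// marker as prefix and the default terminator, the input {0x00, 'a'} encodes to
+// {bytesMarker, 0x00, 0xff, 'a', 0x00, 0x01}; the embedded zero byte is escaped
+// as \x00\xff and the value is closed with the \x00\x01 terminator.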
+func encodeBytesAscendingWithTerminatorAndPrefix( + b []byte, data []byte, terminator byte, prefix byte, +) []byte { + b = append(b, prefix) + return encodeBytesAscendingWithTerminator(b, data, terminator) +} + +// encodeBytesAscendingWithTerminator encodes the []byte value using an escape-based +// encoding. The encoded value is terminated with the sequence +// "\x00\terminator". The encoded bytes are appended to the supplied buffer +// and the resulting buffer is returned. The terminator allows us to pass +// different terminators for things such as JSON key encoding. +func encodeBytesAscendingWithTerminator(b []byte, data []byte, terminator byte) []byte { + bs := encodeBytesAscendingWithoutTerminatorOrPrefix(b, data) + return append(bs, escape, terminator) +} + +// encodeBytesAscendingWithoutTerminatorOrPrefix encodes the []byte value using an escape-based +// encoding. +func encodeBytesAscendingWithoutTerminatorOrPrefix(b []byte, data []byte) []byte { + for { + // IndexByte is implemented by the go runtime in assembly and is + // much faster than looping over the bytes in the slice. + i := bytes.IndexByte(data, escape) + if i == -1 { + break + } + b = append(b, data[:i]...) + b = append(b, escape, escaped00) + data = data[i+1:] + } + return append(b, data...) +} + +// EncodeBytesDescending encodes the []byte value using an +// escape-based encoding and then inverts (ones complement) the result +// so that it sorts in reverse order, from larger to smaller +// lexicographically. +func EncodeBytesDescending(b []byte, data []byte) []byte { + n := len(b) + b = EncodeBytesAscending(b, data) + b[n] = bytesDescMarker + onesComplement(b[n+1:]) + return b +} + +// DecodeBytesAscending decodes a []byte value from the input buffer +// which was encoded using EncodeBytesAscending. The remainder of the +// input buffer and the decoded []byte are returned. +func DecodeBytesAscending(b []byte) ([]byte, []byte, error) { + return decodeBytesInternal(b, ascendingBytesEscapes, true /* expectMarker */) +} + +// DecodeBytesDescending decodes a []byte value from the input buffer +// which was encoded using EncodeBytesDescending. The remainder of the +// input buffer and the decoded []byte are returned. +func DecodeBytesDescending(b []byte) ([]byte, []byte, error) { + b, r, err := decodeBytesInternal(b, descendingBytesEscapes, true /* expectMarker */) + onesComplement(r) + return b, r, err +} + +func decodeBytesInternal(b []byte, e escapes, expectMarker bool) ([]byte, []byte, error) { + if expectMarker { + if len(b) == 0 || b[0] != e.marker { + return nil, nil, NewErrMarkersNotFound(b, e.marker) + } + b = b[1:] + } + + var r []byte + for { + i := bytes.IndexByte(b, e.escape) + if i == -1 { + return nil, nil, NewErrTerminatorNotFound(b, e.escape) + } + if i+1 >= len(b) { + return nil, nil, NewErrMalformedEscape(b) + } + v := b[i+1] + if v == e.escapedTerm { + r = append(r, b[:i]...) + return b[i+2:], r, nil + } + + if v != e.escaped00 { + return nil, nil, NewErrUnknownEscapeSequence(b[i:i+2], e.escape) + } + + r = append(r, b[:i]...) + r = append(r, e.escapedFF) + b = b[i+2:] + } +} diff --git a/encoding/bytes_test.go b/encoding/bytes_test.go new file mode 100644 index 0000000000..ba29239530 --- /dev/null +++ b/encoding/bytes_test.go @@ -0,0 +1,205 @@ +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License + +package encoding + +import ( + "bytes" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestEncodeDecodeBytes(t *testing.T) { + testCases := []struct { + value []byte + encoded []byte + }{ + {[]byte{0, 1, 'a'}, []byte{bytesMarker, 0x00, escaped00, 1, 'a', escape, escapedTerm}}, + {[]byte{0, 'a'}, []byte{bytesMarker, 0x00, escaped00, 'a', escape, escapedTerm}}, + {[]byte{0, 0xff, 'a'}, []byte{bytesMarker, 0x00, escaped00, 0xff, 'a', escape, escapedTerm}}, + {[]byte{'a'}, []byte{bytesMarker, 'a', escape, escapedTerm}}, + {[]byte{'b'}, []byte{bytesMarker, 'b', escape, escapedTerm}}, + {[]byte{'b', 0}, []byte{bytesMarker, 'b', 0x00, escaped00, escape, escapedTerm}}, + {[]byte{'b', 0, 0}, []byte{bytesMarker, 'b', 0x00, escaped00, 0x00, escaped00, escape, escapedTerm}}, + {[]byte{'b', 0, 0, 'a'}, []byte{bytesMarker, 'b', 0x00, escaped00, 0x00, escaped00, 'a', escape, escapedTerm}}, + {[]byte{'b', 0xff}, []byte{bytesMarker, 'b', 0xff, escape, escapedTerm}}, + {[]byte("hello"), []byte{bytesMarker, 'h', 'e', 'l', 'l', 'o', escape, escapedTerm}}, + } + for i, c := range testCases { + enc := EncodeBytesAscending(nil, c.value) + if !bytes.Equal(enc, c.encoded) { + t.Errorf("unexpected encoding mismatch for %v. expected [% x], got [% x]", + c.value, c.encoded, enc) + } + if i > 0 { + if bytes.Compare(testCases[i-1].encoded, enc) >= 0 { + t.Errorf("%v: expected [% x] to be less than [% x]", + c.value, testCases[i-1].encoded, enc) + } + } + remainder, dec, err := DecodeBytesAscending(enc) + if err != nil { + t.Error(err) + continue + } + if !bytes.Equal(c.value, dec) { + t.Errorf("unexpected decoding mismatch for %v. got %v", c.value, dec) + } + if len(remainder) != 0 { + t.Errorf("unexpected remaining bytes: %v", remainder) + } + + enc = append(enc, []byte("remainder")...) 
+ remainder, _, err = DecodeBytesAscending(enc) + if err != nil { + t.Error(err) + continue + } + if string(remainder) != "remainder" { + t.Errorf("unexpected remaining bytes: %v", remainder) + } + } +} + +func TestEncodeDecodeBytesDescending(t *testing.T) { + testCases := []struct { + value []byte + encoded []byte + }{ + {[]byte("hello"), []byte{bytesDescMarker, ^byte('h'), ^byte('e'), ^byte('l'), ^byte('l'), ^byte('o'), escapeDesc, escapedTermDesc}}, + {[]byte{'b', 0xff}, []byte{bytesDescMarker, ^byte('b'), 0x00, escapeDesc, escapedTermDesc}}, + {[]byte{'b', 0, 0, 'a'}, []byte{bytesDescMarker, ^byte('b'), 0xff, escaped00Desc, 0xff, escaped00Desc, ^byte('a'), escapeDesc, escapedTermDesc}}, + {[]byte{'b', 0, 0}, []byte{bytesDescMarker, ^byte('b'), 0xff, escaped00Desc, 0xff, escaped00Desc, escapeDesc, escapedTermDesc}}, + {[]byte{'b', 0}, []byte{bytesDescMarker, ^byte('b'), 0xff, escaped00Desc, escapeDesc, escapedTermDesc}}, + {[]byte{'b'}, []byte{bytesDescMarker, ^byte('b'), escapeDesc, escapedTermDesc}}, + {[]byte{'a'}, []byte{bytesDescMarker, ^byte('a'), escapeDesc, escapedTermDesc}}, + {[]byte{0, 0xff, 'a'}, []byte{bytesDescMarker, 0xff, escaped00Desc, 0x00, ^byte('a'), escapeDesc, escapedTermDesc}}, + {[]byte{0, 'a'}, []byte{bytesDescMarker, 0xff, escaped00Desc, ^byte('a'), escapeDesc, escapedTermDesc}}, + {[]byte{0, 1, 'a'}, []byte{bytesDescMarker, 0xff, escaped00Desc, ^byte(1), ^byte('a'), escapeDesc, escapedTermDesc}}, + } + for i, c := range testCases { + enc := EncodeBytesDescending(nil, c.value) + if !bytes.Equal(enc, c.encoded) { + t.Errorf("%d: unexpected encoding mismatch for %v ([% x]). expected [% x], got [% x]", + i, c.value, c.value, c.encoded, enc) + } + if i > 0 { + if bytes.Compare(testCases[i-1].encoded, enc) >= 0 { + t.Errorf("%v: expected [% x] to be less than [% x]", + c.value, testCases[i-1].encoded, enc) + } + } + remainder, dec, err := DecodeBytesDescending(enc) + if err != nil { + t.Error(err) + continue + } + if !bytes.Equal(c.value, dec) { + t.Errorf("unexpected decoding mismatch for %v. got %v", c.value, dec) + } + if len(remainder) != 0 { + t.Errorf("unexpected remaining bytes: %v", remainder) + } + + enc = append(enc, []byte("remainder")...) + remainder, _, err = DecodeBytesDescending(enc) + if err != nil { + t.Error(err) + continue + } + if string(remainder) != "remainder" { + t.Errorf("unexpected remaining bytes: %v", remainder) + } + } +} + +// TestDecodeInvalid tests that decoding invalid bytes returns the expected errors. +func TestDecodeInvalid(t *testing.T) { + tests := []struct { + name string // name printed with errors. + buf []byte // buf contains invalid encoded data to decode. + expectedErr error // expectedErr is the expected error. + decode func([]byte) error // decode is called with buf.
+ }{ + { + name: "DecodeVarint, overflows int64", + buf: []byte{IntMax, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff}, + expectedErr: ErrVarintOverflow, + decode: func(b []byte) error { _, _, err := DecodeVarintAscending(b); return err }, + }, + { + name: "Bytes, no marker", + buf: []byte{'a'}, + expectedErr: ErrMarkersNotFound, + decode: func(b []byte) error { _, _, err := DecodeBytesAscending(b); return err }, + }, + { + name: "Bytes, no terminator", + buf: []byte{bytesMarker, 'a'}, + expectedErr: ErrTerminatorNotFound, + decode: func(b []byte) error { _, _, err := DecodeBytesAscending(b); return err }, + }, + { + name: "Bytes, malformed escape", + buf: []byte{bytesMarker, 'a', 0x00}, + expectedErr: ErrMalformedEscape, + decode: func(b []byte) error { _, _, err := DecodeBytesAscending(b); return err }, + }, + { + name: "Bytes, invalid escape 1", + buf: []byte{bytesMarker, 'a', 0x00, 0x00}, + expectedErr: ErrUnknownEscapeSequence, + decode: func(b []byte) error { _, _, err := DecodeBytesAscending(b); return err }, + }, + { + name: "Bytes, invalid escape 2", + buf: []byte{bytesMarker, 'a', 0x00, 0x02}, + expectedErr: ErrUnknownEscapeSequence, + decode: func(b []byte) error { _, _, err := DecodeBytesAscending(b); return err }, + }, + { + name: "BytesDescending, no marker", + buf: []byte{'a'}, + expectedErr: ErrMarkersNotFound, + decode: func(b []byte) error { _, _, err := DecodeBytesDescending(b); return err }, + }, + { + name: "BytesDescending, no terminator", + buf: []byte{bytesDescMarker, ^byte('a')}, + expectedErr: ErrTerminatorNotFound, + decode: func(b []byte) error { _, _, err := DecodeBytesDescending(b); return err }, + }, + { + name: "BytesDescending, malformed escape", + buf: []byte{bytesDescMarker, ^byte('a'), 0xff}, + expectedErr: ErrMalformedEscape, + decode: func(b []byte) error { _, _, err := DecodeBytesDescending(b); return err }, + }, + { + name: "BytesDescending, invalid escape 1", + buf: []byte{bytesDescMarker, ^byte('a'), 0xff, 0xff}, + expectedErr: ErrUnknownEscapeSequence, + decode: func(b []byte) error { _, _, err := DecodeBytesDescending(b); return err }, + }, + { + name: "BytesDescending, invalid escape 2", + buf: []byte{bytesDescMarker, ^byte('a'), 0xff, 0xfd}, + expectedErr: ErrUnknownEscapeSequence, + decode: func(b []byte) error { _, _, err := DecodeBytesDescending(b); return err }, + }, + } + for _, test := range tests { + err := test.decode(test.buf) + assert.ErrorIs(t, err, test.expectedErr) + } +} diff --git a/encoding/encoding.go b/encoding/encoding.go new file mode 100644 index 0000000000..164706d922 --- /dev/null +++ b/encoding/encoding.go @@ -0,0 +1,55 @@ +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License + +// Portions of this code are adapted from basic functionality found in the CockroachDB project, +// specifically within the encoding package at: +// https://github.com/cockroachdb/cockroach/tree/v20.2.19/pkg/util/encoding +// +// Our use of this code is in compliance with the Apache License 2.0, under which it is shared.
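+// Package encoding provides order-preserving binary encodings for the values
+// stored in secondary index keys: nulls, integers (varints), floats and byte
+// strings, each in an ascending and a descending variant. The property the
+// index relies on, and which the tests above verify, is that comparing two
+// encoded values byte-wise yields the same ordering as comparing the values.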
+ +package encoding + +const ( + encodedNull = iota + floatNaN + floatNeg + floatZero + floatPos + floatNaNDesc + bytesMarker + bytesDescMarker + + // These constants define a range of values and are used to determine how many bytes are + // needed to represent the given uint64 value. The constants IntMin and IntMax define the + // lower and upper bounds of the range, while intMaxWidth is the maximum width (in bytes) + // for encoding an integer. intZero is the starting point for encoding small integers, + // and intSmall represents the threshold below which a value can be encoded in a single byte. + + // IntMin is set to 0x80 (128) to avoid overlap with the ASCII range, enhancing testing clarity. + IntMin = 0x80 // 128 + // Maximum number of bytes to represent an integer, affecting encoding size. + intMaxWidth = 8 + // intZero is the base value for encoding non-negative integers, calculated to avoid ASCII conflicts. + intZero = IntMin + intMaxWidth // 136 + // intSmall defines the upper limit for integers that can be encoded in a single byte, considering offset. + intSmall = IntMax - intZero - intMaxWidth // 109 + // IntMax marks the upper bound for integer tag values, reserved for encoding use. + IntMax = 0xfd // 253 + + encodedNullDesc = 0xff +) + +func onesComplement(b []byte) { + for i := range b { + b[i] = ^b[i] + } +} diff --git a/encoding/errors.go b/encoding/errors.go new file mode 100644 index 0000000000..38b4671633 --- /dev/null +++ b/encoding/errors.go @@ -0,0 +1,91 @@ +// Copyright 2024 Democratized Data Foundation +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +package encoding + +import ( + "github.com/sourcenetwork/defradb/client" + "github.com/sourcenetwork/defradb/errors" +) + +const ( + errInsufficientBytesToDecode = "insufficient bytes to decode buffer into a target type" + errCanNotDecodeFieldValue = "can not decode field value" + errMarkersNotFound = "did not find any of required markers in buffer" + errTerminatorNotFound = "did not find required terminator in buffer" + errMalformedEscape = "malformed escape in buffer" + errUnknownEscapeSequence = "unknown escape sequence" + errInvalidUvarintLength = "invalid length for uvarint" + errVarintOverflow = "varint overflows a 64-bit integer" +) + +var ( + ErrInsufficientBytesToDecode = errors.New(errInsufficientBytesToDecode) + ErrCanNotDecodeFieldValue = errors.New(errCanNotDecodeFieldValue) + ErrMarkersNotFound = errors.New(errMarkersNotFound) + ErrTerminatorNotFound = errors.New(errTerminatorNotFound) + ErrMalformedEscape = errors.New(errMalformedEscape) + ErrUnknownEscapeSequence = errors.New(errUnknownEscapeSequence) + ErrInvalidUvarintLength = errors.New(errInvalidUvarintLength) + ErrVarintOverflow = errors.New(errVarintOverflow) +) + +// NewErrInsufficientBytesToDecode returns a new error indicating that the provided +// bytes are not sufficient to decode into a target type. 
+func NewErrInsufficientBytesToDecode(b []byte, decodeTarget string) error { + return errors.New(errInsufficientBytesToDecode, + errors.NewKV("Buffer", b), errors.NewKV("Decode Target", decodeTarget)) +} + +// NewErrCanNotDecodeFieldValue returns a new error indicating that the encoded +// bytes could not be decoded into a client.FieldValue of a certain kind. +func NewErrCanNotDecodeFieldValue(b []byte, kind client.FieldKind, innerErr ...error) error { + kvs := []errors.KV{errors.NewKV("Buffer", b), errors.NewKV("Kind", kind)} + if len(innerErr) > 0 { + kvs = append(kvs, errors.NewKV("InnerErr", innerErr[0])) + } + return errors.New(errCanNotDecodeFieldValue, kvs...) +} + +// NewErrMarkersNotFound returns a new error indicating that the required +// marker was not found in the buffer. +func NewErrMarkersNotFound(b []byte, markers ...byte) error { + return errors.New(errMarkersNotFound, errors.NewKV("Markers", markers), errors.NewKV("Buffer", b)) +} + +// NewErrTerminatorNotFound returns a new error indicating that the required +// terminator was not found in the buffer. +func NewErrTerminatorNotFound(b []byte, terminator byte) error { + return errors.New(errTerminatorNotFound, errors.NewKV("Terminator", terminator), errors.NewKV("Buffer", b)) +} + +// NewErrMalformedEscape returns a new error indicating that the buffer +// contains a malformed escape sequence. +func NewErrMalformedEscape(b []byte) error { + return errors.New(errMalformedEscape, errors.NewKV("Buffer", b)) +} + +// NewErrUnknownEscapeSequence returns a new error indicating that the buffer +// contains an unknown escape sequence. +func NewErrUnknownEscapeSequence(b []byte, escape byte) error { + return errors.New(errUnknownEscapeSequence, errors.NewKV("Escape", escape), errors.NewKV("Buffer", b)) +} + +// NewErrInvalidUvarintLength returns a new error indicating that the buffer +// contains an invalid length for a uvarint. +func NewErrInvalidUvarintLength(b []byte, length int) error { + return errors.New(errInvalidUvarintLength, errors.NewKV("Buffer", b), errors.NewKV("Length", length)) +} + +// NewErrVarintOverflow returns a new error indicating that the buffer +// contains a varint that overflows a 64-bit integer. +func NewErrVarintOverflow(b []byte, value uint64) error { + return errors.New(errVarintOverflow, errors.NewKV("Buffer", b), errors.NewKV("Value", value)) +} diff --git a/encoding/field_value.go b/encoding/field_value.go new file mode 100644 index 0000000000..9c8cd5589f --- /dev/null +++ b/encoding/field_value.go @@ -0,0 +1,114 @@ +// Copyright 2024 Democratized Data Foundation +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +package encoding + +import ( + "golang.org/x/exp/constraints" + + "github.com/sourcenetwork/defradb/client" +) + +func encodeIntFieldValue[T constraints.Integer](b []byte, val T, descending bool) []byte { + if descending { + return EncodeVarintDescending(b, int64(val)) + } + return EncodeVarintAscending(b, int64(val)) +} + +// EncodeFieldValue encodes a FieldValue into a byte slice. +// The encoded value is appended to the supplied buffer and the resulting buffer is returned. 
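+// A few illustrative mappings, derived from the cases below: a bool is
+// encoded as a varint 0 or 1, so EncodeFieldValue(nil, true, false) returns
+// the ascending varint encoding of 1; EncodeFieldValue(nil, "ab", true)
+// returns the descending string encoding of "ab"; values of unsupported
+// types fall through and b is returned unchanged.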
+func EncodeFieldValue(b []byte, val any, descending bool) []byte { + if val == nil { + if descending { + return EncodeNullDescending(b) + } else { + return EncodeNullAscending(b) + } + } + switch v := val.(type) { + case bool: + var boolInt int64 = 0 + if v { + boolInt = 1 + } + if descending { + return EncodeVarintDescending(b, boolInt) + } + return EncodeVarintAscending(b, boolInt) + case int: + return encodeIntFieldValue(b, v, descending) + case int32: + return encodeIntFieldValue(b, v, descending) + case int64: + return encodeIntFieldValue(b, v, descending) + case float64: + if descending { + return EncodeFloatDescending(b, v) + } + return EncodeFloatAscending(b, v) + case string: + if descending { + return EncodeStringDescending(b, v) + } + return EncodeStringAscending(b, v) + } + + return b +} + +// DecodeFieldValue decodes a FieldValue from a byte slice. +// The decoded value is returned along with the remaining byte slice. +func DecodeFieldValue(b []byte, descending bool) ([]byte, any, error) { + typ := PeekType(b) + switch typ { + case Null: + b, _ = DecodeIfNull(b) + return b, nil, nil + case Int: + var v int64 + var err error + if descending { + b, v, err = DecodeVarintDescending(b) + } else { + b, v, err = DecodeVarintAscending(b) + } + if err != nil { + return nil, nil, NewErrCanNotDecodeFieldValue(b, client.FieldKind_NILLABLE_INT, err) + } + return b, v, nil + case Float: + var v float64 + var err error + if descending { + b, v, err = DecodeFloatDescending(b) + } else { + b, v, err = DecodeFloatAscending(b) + } + if err != nil { + return nil, nil, NewErrCanNotDecodeFieldValue(b, client.FieldKind_NILLABLE_FLOAT, err) + } + return b, v, nil + case Bytes, BytesDesc: + var v []byte + var err error + if descending { + b, v, err = DecodeBytesDescending(b) + } else { + b, v, err = DecodeBytesAscending(b) + } + if err != nil { + return nil, nil, NewErrCanNotDecodeFieldValue(b, client.FieldKind_NILLABLE_STRING, err) + } + return b, v, nil + } + + return nil, nil, NewErrCanNotDecodeFieldValue(b, client.FieldKind_NILLABLE_STRING) +} diff --git a/encoding/field_value_test.go b/encoding/field_value_test.go new file mode 100644 index 0000000000..a08446cb1f --- /dev/null +++ b/encoding/field_value_test.go @@ -0,0 +1,142 @@ +// Copyright 2024 Democratized Data Foundation +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. 
+ +package encoding + +import ( + "reflect" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestEncodeDecodeFieldValue(t *testing.T) { + tests := []struct { + name string + inputVal any + expectedBytes []byte + expectedBytesDesc []byte + expectedDecodedVal any + }{ + { + name: "nil", + inputVal: nil, + expectedBytes: EncodeNullAscending(nil), + expectedBytesDesc: EncodeNullDescending(nil), + expectedDecodedVal: nil, + }, + { + name: "bool true", + inputVal: true, + expectedBytes: EncodeVarintAscending(nil, 1), + expectedBytesDesc: EncodeVarintDescending(nil, 1), + expectedDecodedVal: int64(1), + }, + { + name: "bool false", + inputVal: false, + expectedBytes: EncodeVarintAscending(nil, 0), + expectedBytesDesc: EncodeVarintDescending(nil, 0), + expectedDecodedVal: int64(0), + }, + { + name: "int", + inputVal: int64(55), + expectedBytes: EncodeVarintAscending(nil, 55), + expectedBytesDesc: EncodeVarintDescending(nil, 55), + expectedDecodedVal: int64(55), + }, + { + name: "float", + inputVal: 0.2, + expectedBytes: EncodeFloatAscending(nil, 0.2), + expectedBytesDesc: EncodeFloatDescending(nil, 0.2), + expectedDecodedVal: 0.2, + }, + { + name: "string", + inputVal: "str", + expectedBytes: EncodeBytesAscending(nil, []byte("str")), + expectedBytesDesc: EncodeBytesDescending(nil, []byte("str")), + expectedDecodedVal: []byte("str"), + }, + } + + for _, tt := range tests { + for _, descending := range []bool{false, true} { + label := " (ascending)" + if descending { + label = " (descending)" + } + t.Run(tt.name+label, func(t *testing.T) { + encoded := EncodeFieldValue(nil, tt.inputVal, descending) + expectedBytes := tt.expectedBytes + if descending { + expectedBytes = tt.expectedBytesDesc + } + if !reflect.DeepEqual(encoded, expectedBytes) { + t.Errorf("EncodeFieldValue() = %v, want %v", encoded, expectedBytes) + } + + _, decodedFieldVal, err := DecodeFieldValue(encoded, descending) + assert.NoError(t, err) + if !reflect.DeepEqual(decodedFieldVal, tt.expectedDecodedVal) { + t.Errorf("DecodeFieldValue() = %v, want %v", decodedFieldVal, tt.expectedDecodedVal) + } + }) + } + } +} + +func TestDecodeInvalidFieldValue(t *testing.T) { + tests := []struct { + name string + inputBytes []byte + inputBytesDesc []byte + }{ + { + name: "invalid int value", + inputBytes: []byte{IntMax, 2}, + inputBytesDesc: []byte{^byte(IntMax), 2}, + }, + { + name: "invalid float value", + inputBytes: []byte{floatPos, 2}, + inputBytesDesc: []byte{floatPos, 2}, + }, + { + name: "invalid bytes value", + inputBytes: []byte{bytesMarker, 2}, + inputBytesDesc: []byte{bytesMarker, 2}, + }, + { + name: "invalid data", + inputBytes: []byte{IntMin - 1, 2}, + inputBytesDesc: []byte{^byte(IntMin - 1), 2}, + }, + } + + for _, tt := range tests { + for _, descending := range []bool{false, true} { + label := " (ascending)" + if descending { + label = " (descending)" + } + t.Run(tt.name+label, func(t *testing.T) { + inputBytes := tt.inputBytes + if descending { + inputBytes = tt.inputBytesDesc + } + _, _, err := DecodeFieldValue(inputBytes, descending) + assert.ErrorIs(t, err, ErrCanNotDecodeFieldValue) + }) + } + } +} diff --git a/encoding/float.go b/encoding/float.go new file mode 100644 index 0000000000..322ea9f9b8 --- /dev/null +++ b/encoding/float.go @@ -0,0 +1,97 @@ +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License
+
+package encoding
+
+import (
+	"math"
+)
+
+// EncodeFloatAscending returns the resulting byte slice with the encoded float64
+// appended to b. The encoded format for a float64 value f is, for positive f, the
+// encoding of the 64 bits (in IEEE 754 format) re-interpreted as an int64 and
+// encoded using EncodeUint64Ascending. For negative f, we keep the sign bit and
+// invert all other bits, encoding this value using EncodeUint64Descending. This
+// approach was inspired by github.com/google/orderedcode/orderedcode.go.
+//
+// One of five single-byte prefix tags is appended to the front of the encoding.
+// These tags enforce logical ordering of keys for both ascending and descending
+// encoding directions. The tags split the encoded floats into five categories:
+// - NaN for an ascending encoding direction
+// - Negative valued floats
+// - Zero (positive and negative)
+// - Positive valued floats
+// - NaN for a descending encoding direction
+// This ordering ensures that NaNs are always sorted first in either encoding
+// direction, and that after them a logical ordering is followed.
+func EncodeFloatAscending(b []byte, f float64) []byte {
+	// Handle the simple cases first.
+	switch {
+	case math.IsNaN(f):
+		return append(b, floatNaN)
+	case f == 0:
+		// This encodes both positive and negative zero the same. Negative zero uses
+		// composite indexes to decode itself correctly.
+		return append(b, floatZero)
+	}
+	u := math.Float64bits(f)
+	if u&(1<<63) != 0 {
+		u = ^u
+		b = append(b, floatNeg)
+	} else {
+		b = append(b, floatPos)
+	}
+	return EncodeUint64Ascending(b, u)
+}
+
+// EncodeFloatDescending is the descending version of EncodeFloatAscending.
+func EncodeFloatDescending(b []byte, f float64) []byte {
+	if math.IsNaN(f) {
+		return append(b, floatNaNDesc)
+	}
+	return EncodeFloatAscending(b, -f)
+}
+
+// DecodeFloatAscending returns the remaining byte slice after decoding and the decoded
+// float64 from buf.
+func DecodeFloatAscending(buf []byte) ([]byte, float64, error) {
+	if PeekType(buf) != Float {
+		return buf, 0, NewErrMarkersNotFound(buf, floatNaN, floatNeg, floatZero, floatPos, floatNaNDesc)
+	}
+	switch buf[0] {
+	case floatNaN, floatNaNDesc:
+		return buf[1:], math.NaN(), nil
+	case floatNeg:
+		b, u, err := DecodeUint64Ascending(buf[1:])
+		if err != nil {
+			return b, 0, err
+		}
+		u = ^u
+		return b, math.Float64frombits(u), nil
+	case floatZero:
+		return buf[1:], 0, nil
+	case floatPos:
+		b, u, err := DecodeUint64Ascending(buf[1:])
+		if err != nil {
+			return b, 0, err
+		}
+		return b, math.Float64frombits(u), nil
+	default:
+		return nil, 0, NewErrMarkersNotFound(buf, floatNaN, floatNeg, floatZero, floatPos, floatNaNDesc)
+	}
+}
+
+// DecodeFloatDescending decodes floats encoded with EncodeFloatDescending.
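+// As an illustrative consequence of the definitions above (not extra API),
+// descending encoding is simply ascending encoding of the negated value, so:
+//
+//	desc := EncodeFloatDescending(nil, 1.0)
+//	asc := EncodeFloatAscending(nil, -1.0)
+//	// bytes.Equal(desc, asc) == true; NaN is the only special case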
+func DecodeFloatDescending(buf []byte) ([]byte, float64, error) { + b, r, err := DecodeFloatAscending(buf) + return b, -r, err +} diff --git a/encoding/float_test.go b/encoding/float_test.go new file mode 100644 index 0000000000..6fc610db24 --- /dev/null +++ b/encoding/float_test.go @@ -0,0 +1,127 @@ +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License + +package encoding + +import ( + "bytes" + "math" + "testing" +) + +func TestEncodeFloatOrdered(t *testing.T) { + testCases := []struct { + Value float64 + Encoding []byte + }{ + {math.NaN(), []byte{floatNaN}}, + {math.Inf(-1), []byte{floatNeg, 0x00, 0x0f, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff}}, + {-math.MaxFloat64, []byte{floatNeg, 0x00, 0x10, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}}, + {-1e308, []byte{floatNeg, 0x00, 0x1e, 0x33, 0x0c, 0x7a, 0x14, 0x37, 0x5f}}, + {-10000.0, []byte{floatNeg, 0x3f, 0x3c, 0x77, 0xff, 0xff, 0xff, 0xff, 0xff}}, + {-9999.0, []byte{floatNeg, 0x3f, 0x3c, 0x78, 0x7f, 0xff, 0xff, 0xff, 0xff}}, + {-100.0, []byte{floatNeg, 0x3f, 0xa6, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff}}, + {-99.0, []byte{floatNeg, 0x3f, 0xa7, 0x3f, 0xff, 0xff, 0xff, 0xff, 0xff}}, + {-1.0, []byte{floatNeg, 0x40, 0x0f, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff}}, + {-0.00123, []byte{floatNeg, 0x40, 0xab, 0xd9, 0x01, 0x8e, 0x75, 0x79, 0x28}}, + {-1e-307, []byte{floatNeg, 0x7f, 0xce, 0x05, 0xe7, 0xd3, 0xbf, 0x39, 0xf2}}, + {-math.SmallestNonzeroFloat64, []byte{floatNeg, 0x7f, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xfe}}, + {math.Copysign(0, -1), []byte{floatZero}}, + {0, []byte{floatZero}}, + {math.SmallestNonzeroFloat64, []byte{floatPos, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01}}, + {1e-307, []byte{floatPos, 0x00, 0x31, 0xfa, 0x18, 0x2c, 0x40, 0xc6, 0x0d}}, + {0.00123, []byte{floatPos, 0x3f, 0x54, 0x26, 0xfe, 0x71, 0x8a, 0x86, 0xd7}}, + {0.0123, []byte{floatPos, 0x3f, 0x89, 0x30, 0xbe, 0x0d, 0xed, 0x28, 0x8d}}, + {0.123, []byte{floatPos, 0x3f, 0xbf, 0x7c, 0xed, 0x91, 0x68, 0x72, 0xb0}}, + {1.0, []byte{floatPos, 0x3f, 0xf0, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}}, + {10.0, []byte{floatPos, 0x40, 0x24, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}}, + {12.345, []byte{floatPos, 0x40, 0x28, 0xb0, 0xa3, 0xd7, 0x0a, 0x3d, 0x71}}, + {99.0, []byte{floatPos, 0x40, 0x58, 0xc0, 0x00, 0x00, 0x00, 0x00, 0x00}}, + {99.0001, []byte{floatPos, 0x40, 0x58, 0xc0, 0x01, 0xa3, 0x6e, 0x2e, 0xb2}}, + {99.01, []byte{floatPos, 0x40, 0x58, 0xc0, 0xa3, 0xd7, 0x0a, 0x3d, 0x71}}, + {100.0, []byte{floatPos, 0x40, 0x59, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}}, + {100.01, []byte{floatPos, 0x40, 0x59, 0x00, 0xa3, 0xd7, 0x0a, 0x3d, 0x71}}, + {100.1, []byte{floatPos, 0x40, 0x59, 0x06, 0x66, 0x66, 0x66, 0x66, 0x66}}, + {1234, []byte{floatPos, 0x40, 0x93, 0x48, 0x00, 0x00, 0x00, 0x00, 0x00}}, + {1234.5, []byte{floatPos, 0x40, 0x93, 0x4a, 0x00, 0x00, 0x00, 0x00, 0x00}}, + {9999, []byte{floatPos, 0x40, 0xc3, 0x87, 0x80, 0x00, 0x00, 0x00, 0x00}}, + {9999.000001, []byte{floatPos, 0x40, 0xc3, 0x87, 0x80, 0x00, 0x08, 0x63, 0x7c}}, + {9999.000009, []byte{floatPos, 0x40, 0xc3, 0x87, 0x80, 0x00, 0x4b, 0x7f, 0x5a}}, + 
{9999.00001, []byte{floatPos, 0x40, 0xc3, 0x87, 0x80, 0x00, 0x53, 0xe2, 0xd6}}, + {9999.00009, []byte{floatPos, 0x40, 0xc3, 0x87, 0x80, 0x02, 0xf2, 0xf9, 0x87}}, + {9999.000099, []byte{floatPos, 0x40, 0xc3, 0x87, 0x80, 0x03, 0x3e, 0x78, 0xe2}}, + {9999.0001, []byte{floatPos, 0x40, 0xc3, 0x87, 0x80, 0x03, 0x46, 0xdc, 0x5d}}, + {9999.001, []byte{floatPos, 0x40, 0xc3, 0x87, 0x80, 0x20, 0xc4, 0x9b, 0xa6}}, + {9999.01, []byte{floatPos, 0x40, 0xc3, 0x87, 0x81, 0x47, 0xae, 0x14, 0x7b}}, + {9999.1, []byte{floatPos, 0x40, 0xc3, 0x87, 0x8c, 0xcc, 0xcc, 0xcc, 0xcd}}, + {10000, []byte{floatPos, 0x40, 0xc3, 0x88, 0x00, 0x00, 0x00, 0x00, 0x00}}, + {10001, []byte{floatPos, 0x40, 0xc3, 0x88, 0x80, 0x00, 0x00, 0x00, 0x00}}, + {12345, []byte{floatPos, 0x40, 0xc8, 0x1c, 0x80, 0x00, 0x00, 0x00, 0x00}}, + {123450, []byte{floatPos, 0x40, 0xfe, 0x23, 0xa0, 0x00, 0x00, 0x00, 0x00}}, + {1e308, []byte{floatPos, 0x7f, 0xe1, 0xcc, 0xf3, 0x85, 0xeb, 0xc8, 0xa0}}, + {math.MaxFloat64, []byte{floatPos, 0x7f, 0xef, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff}}, + {math.Inf(1), []byte{floatPos, 0x7f, 0xf0, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}}, + } + + var lastEncoded []byte + for _, isAscending := range []bool{true, false} { + for i, c := range testCases { + var enc []byte + var err error + var dec float64 + if isAscending { + enc = EncodeFloatAscending(nil, c.Value) + _, dec, err = DecodeFloatAscending(enc) + } else { + enc = EncodeFloatDescending(nil, c.Value) + _, dec, err = DecodeFloatDescending(enc) + } + if isAscending && !bytes.Equal(enc, c.Encoding) { + t.Errorf("unexpected mismatch for %v. expected [% x], got [% x]", + c.Value, c.Encoding, enc) + } + if i > 0 { + if (bytes.Compare(lastEncoded, enc) > 0 && isAscending) || + (bytes.Compare(lastEncoded, enc) < 0 && !isAscending) { + t.Errorf("%v: expected [% x] to be less than or equal to [% x]", + c.Value, testCases[i-1].Encoding, enc) + } + } + if err != nil { + t.Error(err) + continue + } + if math.IsNaN(c.Value) { + if !math.IsNaN(dec) { + t.Errorf("unexpected mismatch for %v. got %v", c.Value, dec) + } + } else if dec != c.Value { + t.Errorf("unexpected mismatch for %v. got %v", c.Value, dec) + } + lastEncoded = enc + } + + // Test that appending the float to an existing buffer works. + var enc []byte + var dec float64 + if isAscending { + enc = EncodeFloatAscending([]byte("hello"), 1.23) + _, dec, _ = DecodeFloatAscending(enc[5:]) + } else { + enc = EncodeFloatDescending([]byte("hello"), 1.23) + _, dec, _ = DecodeFloatDescending(enc[5:]) + } + if dec != 1.23 { + t.Errorf("unexpected mismatch for %v. got %v", 1.23, dec) + } + } +} diff --git a/encoding/int.go b/encoding/int.go new file mode 100644 index 0000000000..733ed94b12 --- /dev/null +++ b/encoding/int.go @@ -0,0 +1,246 @@ +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License + +package encoding + +import ( + "encoding/binary" + "math" +) + +// EncodeUint64Ascending encodes the uint64 value using a big-endian 8 byte +// representation. The bytes are appended to the supplied buffer and +// the final buffer is returned. 
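+// For example (illustrative, matching the test vectors in int_test.go):
+//
+//	EncodeUint64Ascending(nil, 1)
+//	// => []byte{0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01}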
+func EncodeUint64Ascending(b []byte, v uint64) []byte { + return append(b, + byte(v>>56), byte(v>>48), byte(v>>40), byte(v>>32), + byte(v>>24), byte(v>>16), byte(v>>8), byte(v)) +} + +// EncodeUint64Descending encodes the uint64 value so that it sorts in +// reverse order, from largest to smallest. +func EncodeUint64Descending(b []byte, v uint64) []byte { + return EncodeUint64Ascending(b, ^v) +} + +// DecodeUint64Ascending decodes a uint64 from the input buffer, treating +// the input as a big-endian 8 byte uint64 representation. The remainder +// of the input buffer and the decoded uint64 are returned. +func DecodeUint64Ascending(b []byte) ([]byte, uint64, error) { + if len(b) < 8 { + return nil, 0, NewErrInsufficientBytesToDecode(b, "uint64") + } + v := binary.BigEndian.Uint64(b) + return b[8:], v, nil +} + +// DecodeUint64Descending decodes a uint64 value which was encoded +// using EncodeUint64Descending. +func DecodeUint64Descending(b []byte) ([]byte, uint64, error) { + leftover, v, err := DecodeUint64Ascending(b) + return leftover, ^v, err +} + +// EncodeVarintAscending encodes the int64 value using a variable length +// (length-prefixed) representation. The length is encoded as a single +// byte. If the value to be encoded is negative the length is encoded +// as 8-numBytes. If the value is positive it is encoded as +// 8+numBytes. The encoded bytes are appended to the supplied buffer +// and the final buffer is returned. +func EncodeVarintAscending(b []byte, v int64) []byte { + if v < 0 { + switch { + case v >= -0xff: + return append(b, IntMin+7, byte(v)) + case v >= -0xffff: + return append(b, IntMin+6, byte(v>>8), byte(v)) + case v >= -0xffffff: + return append(b, IntMin+5, byte(v>>16), byte(v>>8), byte(v)) + case v >= -0xffffffff: + return append(b, IntMin+4, byte(v>>24), byte(v>>16), byte(v>>8), byte(v)) + case v >= -0xffffffffff: + return append(b, IntMin+3, byte(v>>32), byte(v>>24), byte(v>>16), byte(v>>8), + byte(v)) + case v >= -0xffffffffffff: + return append(b, IntMin+2, byte(v>>40), byte(v>>32), byte(v>>24), byte(v>>16), + byte(v>>8), byte(v)) + case v >= -0xffffffffffffff: + return append(b, IntMin+1, byte(v>>48), byte(v>>40), byte(v>>32), byte(v>>24), + byte(v>>16), byte(v>>8), byte(v)) + default: + return append(b, IntMin, byte(v>>56), byte(v>>48), byte(v>>40), byte(v>>32), + byte(v>>24), byte(v>>16), byte(v>>8), byte(v)) + } + } + return EncodeUvarintAscending(b, uint64(v)) +} + +// EncodeVarintDescending encodes the int64 value so that it sorts in reverse +// order, from largest to smallest. +func EncodeVarintDescending(b []byte, v int64) []byte { + return EncodeVarintAscending(b, ^v) +} + +// DecodeVarintAscending decodes a value encoded by EncodeVarintAscending. +func DecodeVarintAscending(b []byte) ([]byte, int64, error) { + if len(b) == 0 { + return nil, 0, NewErrInsufficientBytesToDecode(b, "varint") + } + length := int(b[0]) - intZero + if length < 0 { + length = -length + remB := b[1:] + if len(remB) < length { + return nil, 0, NewErrInsufficientBytesToDecode(b, "varint") + } + var v int64 + // Use the ones-complement of each encoded byte in order to build + // up a positive number, then take the ones-complement again to + // arrive at our negative value. 
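+		// Worked example (illustrative): -1 is encoded as {0x87, 0xff}; the tag
+		// 0x87 yields length 1, the loop builds v == 0 from ^0xff, and ^v == -1.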
+		for _, t := range remB[:length] {
+			v = (v << 8) | int64(^t)
+		}
+		return remB[length:], ^v, nil
+	}
+
+	remB, v, err := DecodeUvarintAscending(b)
+	if err != nil {
+		return remB, 0, err
+	}
+	if v > math.MaxInt64 {
+		return nil, 0, NewErrVarintOverflow(b, v)
+	}
+	return remB, int64(v), nil
+}
+
+// DecodeVarintDescending decodes an int64 value which was encoded
+// using EncodeVarintDescending.
+func DecodeVarintDescending(b []byte) ([]byte, int64, error) {
+	leftover, v, err := DecodeVarintAscending(b)
+	return leftover, ^v, err
+}
+
+// EncodeUvarintAscending encodes the uint64 value using a variable length
+// (length-prefixed) representation. The length is encoded as a single
+// byte indicating the number of encoded bytes (-8) to follow. See
+// EncodeVarintAscending for rationale. The encoded bytes are appended to the
+// supplied buffer and the final buffer is returned.
+func EncodeUvarintAscending(b []byte, v uint64) []byte {
+	switch {
+	case v <= intSmall:
+		return append(b, intZero+byte(v))
+	case v <= 0xff:
+		return append(b, IntMax-7, byte(v))
+	case v <= 0xffff:
+		return append(b, IntMax-6, byte(v>>8), byte(v))
+	case v <= 0xffffff:
+		return append(b, IntMax-5, byte(v>>16), byte(v>>8), byte(v))
+	case v <= 0xffffffff:
+		return append(b, IntMax-4, byte(v>>24), byte(v>>16), byte(v>>8), byte(v))
+	case v <= 0xffffffffff:
+		return append(b, IntMax-3, byte(v>>32), byte(v>>24), byte(v>>16), byte(v>>8),
+			byte(v))
+	case v <= 0xffffffffffff:
+		return append(b, IntMax-2, byte(v>>40), byte(v>>32), byte(v>>24), byte(v>>16),
+			byte(v>>8), byte(v))
+	case v <= 0xffffffffffffff:
+		return append(b, IntMax-1, byte(v>>48), byte(v>>40), byte(v>>32), byte(v>>24),
+			byte(v>>16), byte(v>>8), byte(v))
+	default:
+		return append(b, IntMax, byte(v>>56), byte(v>>48), byte(v>>40), byte(v>>32),
+			byte(v>>24), byte(v>>16), byte(v>>8), byte(v))
+	}
+}
+
+// EncodeUvarintDescending encodes the uint64 value so that it sorts in
+// reverse order, from largest to smallest.
+func EncodeUvarintDescending(b []byte, v uint64) []byte {
+	switch {
+	case v == 0:
+		return append(b, IntMin+8)
+	case v <= 0xff:
+		v = ^v
+		return append(b, IntMin+7, byte(v))
+	case v <= 0xffff:
+		v = ^v
+		return append(b, IntMin+6, byte(v>>8), byte(v))
+	case v <= 0xffffff:
+		v = ^v
+		return append(b, IntMin+5, byte(v>>16), byte(v>>8), byte(v))
+	case v <= 0xffffffff:
+		v = ^v
+		return append(b, IntMin+4, byte(v>>24), byte(v>>16), byte(v>>8), byte(v))
+	case v <= 0xffffffffff:
+		v = ^v
+		return append(b, IntMin+3, byte(v>>32), byte(v>>24), byte(v>>16), byte(v>>8),
+			byte(v))
+	case v <= 0xffffffffffff:
+		v = ^v
+		return append(b, IntMin+2, byte(v>>40), byte(v>>32), byte(v>>24), byte(v>>16),
+			byte(v>>8), byte(v))
+	case v <= 0xffffffffffffff:
+		v = ^v
+		return append(b, IntMin+1, byte(v>>48), byte(v>>40), byte(v>>32), byte(v>>24),
+			byte(v>>16), byte(v>>8), byte(v))
+	default:
+		v = ^v
+		return append(b, IntMin, byte(v>>56), byte(v>>48), byte(v>>40), byte(v>>32),
+			byte(v>>24), byte(v>>16), byte(v>>8), byte(v))
+	}
+}
+
+// DecodeUvarintAscending decodes a uvarint encoded uint64 from the input
+// buffer. The remainder of the input buffer and the decoded uint64
+// are returned.
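+// For example (illustrative): {0xf7, 0x01, 0x00} decodes to 256; the tag 0xf7
+// announces two payload bytes, which big-endian-decode to 0x0100.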
+func DecodeUvarintAscending(b []byte) ([]byte, uint64, error) { + if len(b) == 0 { + return nil, 0, NewErrInsufficientBytesToDecode(b, "uvarint") + } + length := int(b[0]) - intZero + b = b[1:] // skip length byte + if length <= intSmall { + return b, uint64(length), nil + } + length -= intSmall + if length < 0 || length > 8 { + return nil, 0, NewErrInvalidUvarintLength(b, length) + } else if len(b) < length { + return nil, 0, NewErrInsufficientBytesToDecode(b, "uvarint") + } + var v uint64 + // It is faster to range over the elements in a slice than to index + // into the slice on each loop iteration. + for _, t := range b[:length] { + v = (v << 8) | uint64(t) + } + return b[length:], v, nil +} + +// DecodeUvarintDescending decodes a uint64 value which was encoded +// using EncodeUvarintDescending. +func DecodeUvarintDescending(b []byte) ([]byte, uint64, error) { + if len(b) == 0 { + return nil, 0, NewErrInsufficientBytesToDecode(b, "uvarint") + } + length := intZero - int(b[0]) + b = b[1:] // skip length byte + if length < 0 || length > 8 { + return nil, 0, NewErrInvalidUvarintLength(b, length) + } else if len(b) < length { + return nil, 0, NewErrInsufficientBytesToDecode(b, "uvarint") + } + var x uint64 + for _, t := range b[:length] { + x = (x << 8) | uint64(^t) + } + return b[length:], x, nil +} diff --git a/encoding/int_test.go b/encoding/int_test.go new file mode 100644 index 0000000000..80c3f502c4 --- /dev/null +++ b/encoding/int_test.go @@ -0,0 +1,223 @@ +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License + +package encoding + +import ( + "bytes" + "math" + "testing" +) + +func testBasicEncodeDecodeUint64( + encFunc func([]byte, uint64) []byte, + decFunc func([]byte) ([]byte, uint64, error), + descending bool, + t *testing.T, +) { + testCases := []uint64{ + 0, 1, + 1<<8 - 1, 1 << 8, + 1<<16 - 1, 1 << 16, + 1<<24 - 1, 1 << 24, + 1<<32 - 1, 1 << 32, + 1<<40 - 1, 1 << 40, + 1<<48 - 1, 1 << 48, + 1<<56 - 1, 1 << 56, + math.MaxUint64 - 1, math.MaxUint64, + } + + var lastEnc []byte + for i, v := range testCases { + enc := encFunc(nil, v) + if i > 0 { + if (descending && bytes.Compare(enc, lastEnc) >= 0) || + (!descending && bytes.Compare(enc, lastEnc) < 0) { + t.Errorf("ordered constraint violated for %d: [% x] vs. [% x]", v, enc, lastEnc) + } + } + b, decode, err := decFunc(enc) + if err != nil { + t.Error(err) + continue + } + if len(b) != 0 { + t.Errorf("leftover bytes: [% x]", b) + } + if decode != v { + t.Errorf("decode yielded different value than input: %d vs. 
%d", decode, v) + } + lastEnc = enc + } +} + +var int64TestCases = [...]int64{ + math.MinInt64, math.MinInt64 + 1, + -1<<56 - 1, -1 << 56, + -1<<48 - 1, -1 << 48, + -1<<40 - 1, -1 << 40, + -1<<32 - 1, -1 << 32, + -1<<24 - 1, -1 << 24, + -1<<16 - 1, -1 << 16, + -1<<8 - 1, -1 << 8, + -1, 0, 1, + 1<<8 - 1, 1 << 8, + 1<<16 - 1, 1 << 16, + 1<<24 - 1, 1 << 24, + 1<<32 - 1, 1 << 32, + 1<<40 - 1, 1 << 40, + 1<<48 - 1, 1 << 48, + 1<<56 - 1, 1 << 56, + math.MaxInt64 - 1, math.MaxInt64, +} + +func testBasicEncodeDecodeInt64( + encFunc func([]byte, int64) []byte, + decFunc func([]byte) ([]byte, int64, error), + descending bool, + t *testing.T, +) { + var lastEnc []byte + for i, v := range int64TestCases { + enc := encFunc(nil, v) + if i > 0 { + if (descending && bytes.Compare(enc, lastEnc) >= 0) || + (!descending && bytes.Compare(enc, lastEnc) < 0) { + t.Errorf("ordered constraint violated for %d: [% x] vs. [% x]", v, enc, lastEnc) + } + } + b, decode, err := decFunc(enc) + if err != nil { + t.Errorf("%v: %d [%x]", err, v, enc) + continue + } + if len(b) != 0 { + t.Errorf("leftover bytes: [% x]", b) + } + if decode != v { + t.Errorf("decode yielded different value than input: %d vs. %d [%x]", decode, v, enc) + } + lastEnc = enc + } +} + +type testCaseInt64 struct { + value int64 + expEnc []byte +} + +func testCustomEncodeInt64( + testCases []testCaseInt64, encFunc func([]byte, int64) []byte, t *testing.T, +) { + for _, test := range testCases { + enc := encFunc(nil, test.value) + if !bytes.Equal(enc, test.expEnc) { + t.Errorf("expected [% x]; got [% x] (value: %d)", test.expEnc, enc, test.value) + } + } +} + +type testCaseUint64 struct { + value uint64 + expEnc []byte +} + +func testCustomEncodeUint64( + testCases []testCaseUint64, encFunc func([]byte, uint64) []byte, t *testing.T, +) { + for _, test := range testCases { + enc := encFunc(nil, test.value) + if !bytes.Equal(enc, test.expEnc) { + t.Errorf("expected [% x]; got [% x] (value: %d)", test.expEnc, enc, test.value) + } + } +} + +func TestEncodeDecodeUint64(t *testing.T) { + testBasicEncodeDecodeUint64(EncodeUint64Ascending, DecodeUint64Ascending, false, t) + testCases := []testCaseUint64{ + {0, []byte{0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}}, + {1, []byte{0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01}}, + {1 << 8, []byte{0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00}}, + {math.MaxUint64, []byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff}}, + } + testCustomEncodeUint64(testCases, EncodeUint64Ascending, t) +} + +func TestEncodeDecodeUint64Descending(t *testing.T) { + testBasicEncodeDecodeUint64(EncodeUint64Descending, DecodeUint64Descending, true, t) + testCases := []testCaseUint64{ + {0, []byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff}}, + {1, []byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xfe}}, + {1 << 8, []byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xfe, 0xff}}, + {math.MaxUint64, []byte{0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}}, + } + testCustomEncodeUint64(testCases, EncodeUint64Descending, t) +} + +func TestEncodeDecodeVarint(t *testing.T) { + testBasicEncodeDecodeInt64(EncodeVarintAscending, DecodeVarintAscending, false, t) + testCases := []testCaseInt64{ + {math.MinInt64, []byte{0x80, 0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}}, + {math.MinInt64 + 1, []byte{0x80, 0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01}}, + {-1 << 8, []byte{0x86, 0xff, 0x00}}, + {-1, []byte{0x87, 0xff}}, + {0, []byte{0x88}}, + {1, []byte{0x89}}, + {109, []byte{0xf5}}, + {112, []byte{0xf6, 0x70}}, + {1 << 8, []byte{0xf7, 0x01, 
0x00}},
+		{math.MaxInt64, []byte{0xfd, 0x7f, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff}},
+	}
+	testCustomEncodeInt64(testCases, EncodeVarintAscending, t)
+}
+
+func TestEncodeDecodeVarintDescending(t *testing.T) {
+	testBasicEncodeDecodeInt64(EncodeVarintDescending, DecodeVarintDescending, true, t)
+	testCases := []testCaseInt64{
+		{math.MinInt64, []byte{0xfd, 0x7f, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff}},
+		{math.MinInt64 + 1, []byte{0xfd, 0x7f, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xfe}},
+		{-1 << 8, []byte{0xf6, 0xff}},
+		{-110, []byte{0xf5}},
+		{-1, []byte{0x88}},
+		{0, []byte{0x87, 0xff}},
+		{1, []byte{0x87, 0xfe}},
+		{1 << 8, []byte{0x86, 0xfe, 0xff}},
+		{math.MaxInt64, []byte{0x80, 0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}},
+	}
+	testCustomEncodeInt64(testCases, EncodeVarintDescending, t)
+}
+
+func TestEncodeDecodeUvarint(t *testing.T) {
+	testBasicEncodeDecodeUint64(EncodeUvarintAscending, DecodeUvarintAscending, false, t)
+	testCases := []testCaseUint64{
+		{0, []byte{0x88}},
+		{1, []byte{0x89}},
+		{109, []byte{0xf5}},
+		{110, []byte{0xf6, 0x6e}},
+		{1 << 8, []byte{0xf7, 0x01, 0x00}},
+		{math.MaxUint64, []byte{0xfd, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff}},
+	}
+	testCustomEncodeUint64(testCases, EncodeUvarintAscending, t)
+}
+
+func TestEncodeDecodeUvarintDescending(t *testing.T) {
+	testBasicEncodeDecodeUint64(EncodeUvarintDescending, DecodeUvarintDescending, true, t)
+	testCases := []testCaseUint64{
+		{0, []byte{0x88}},
+		{1, []byte{0x87, 0xfe}},
+		{1 << 8, []byte{0x86, 0xfe, 0xff}},
+		{math.MaxUint64 - 1, []byte{0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01}},
+		{math.MaxUint64, []byte{0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}},
+	}
+	testCustomEncodeUint64(testCases, EncodeUvarintDescending, t)
+}
diff --git a/encoding/null.go b/encoding/null.go
new file mode 100644
index 0000000000..067c348122
--- /dev/null
+++ b/encoding/null.go
@@ -0,0 +1,41 @@
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License
+
+package encoding
+
+// EncodeNullAscending encodes a NULL value. The encoded bytes are appended to the
+// supplied buffer and the final buffer is returned. The encoded value for a
+// NULL is guaranteed to not be a prefix for the EncodeVarint, EncodeFloat,
+// EncodeBytes and EncodeString encodings.
+func EncodeNullAscending(b []byte) []byte {
+	return append(b, encodedNull)
+}
+
+// EncodeNullDescending is the descending equivalent of EncodeNullAscending.
+func EncodeNullDescending(b []byte) []byte {
+	return append(b, encodedNullDesc)
+}
+
+// DecodeIfNull decodes a NULL value from the input buffer. If the input buffer
+// contains a null at the start of the buffer then it is removed from the
+// buffer and true is returned for the second result. Otherwise, the buffer is
+// returned unchanged and false is returned for the second result.
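+// For example (illustrative): DecodeIfNull([]byte{0x00, 'a'}) returns
+// ([]byte{'a'}, true), while DecodeIfNull([]byte{'a'}) returns the input
+// unchanged together with false.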
+// Since the NULL value encoding is guaranteed to never occur as the prefix
+// for the EncodeVarint, EncodeFloat, EncodeBytes and EncodeString encodings,
+// it is safe to call DecodeIfNull on their encoded values.
+// This function handles NULLs encoded in both ascending and descending order.
+func DecodeIfNull(b []byte) ([]byte, bool) {
+	if PeekType(b) == Null {
+		return b[1:], true
+	}
+	return b, false
+}
diff --git a/encoding/null_test.go b/encoding/null_test.go
new file mode 100644
index 0000000000..fb18d2ac64
--- /dev/null
+++ b/encoding/null_test.go
@@ -0,0 +1,40 @@
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License
+
+package encoding
+
+import (
+	"bytes"
+	"testing"
+)
+
+func TestEncodeDecodeNull(t *testing.T) {
+	const hello = "hello"
+
+	buf := EncodeNullAscending([]byte(hello))
+	expected := []byte(hello + "\x00")
+	if !bytes.Equal(expected, buf) {
+		t.Fatalf("expected %q, but found %q", expected, buf)
+	}
+
+	if remaining, isNull := DecodeIfNull([]byte(hello)); isNull {
+		t.Fatalf("expected isNull=false, but found isNull=%v", isNull)
+	} else if hello != string(remaining) {
+		t.Fatalf("expected %q, but found %q", hello, remaining)
+	}
+
+	if remaining, isNull := DecodeIfNull([]byte("\x00" + hello)); !isNull {
+		t.Fatalf("expected isNull=true, but found isNull=%v", isNull)
+	} else if hello != string(remaining) {
+		t.Fatalf("expected %q, but found %q", hello, remaining)
+	}
+}
diff --git a/encoding/string.go b/encoding/string.go
new file mode 100644
index 0000000000..23b6d379ae
--- /dev/null
+++ b/encoding/string.go
@@ -0,0 +1,52 @@
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License
+
+package encoding
+
+import (
+	"unsafe"
+)
+
+// unsafeConvertStringToBytes converts a string to a byte array to be used with
+// string encoding functions. Note that the output byte array should not be
+// modified if the input string is expected to be used again - doing so could
+// violate Go semantics.
+func unsafeConvertStringToBytes(s string) []byte {
+	if len(s) == 0 {
+		return nil
+	}
+	return unsafe.Slice(unsafe.StringData(s), len(s))
+}
+
+// EncodeStringAscending encodes the string value using an escape-based encoding. See
+// EncodeBytes for details. The encoded bytes are appended to the supplied buffer
+// and the resulting buffer is returned.
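+// For example (illustrative, matching the test vectors in string_test.go):
+//
+//	EncodeStringAscending(nil, "hello")
+//	// => []byte{bytesMarker, 'h', 'e', 'l', 'l', 'o', escape, escapedTerm}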
+func EncodeStringAscending(b []byte, s string) []byte {
	return encodeStringAscendingWithTerminatorAndPrefix(b, s, ascendingBytesEscapes.escapedTerm, bytesMarker)
+}
+
+// encodeStringAscendingWithTerminatorAndPrefix encodes the string value using an escape-based encoding. See
+// EncodeBytes for details. The encoded bytes are appended to the supplied buffer
+// and the resulting buffer is returned. We can also pass a terminator byte to be used with
+// JSON key encoding.
+func encodeStringAscendingWithTerminatorAndPrefix(
+	b []byte, s string, terminator byte, prefix byte,
+) []byte {
+	unsafeString := unsafeConvertStringToBytes(s)
+	return encodeBytesAscendingWithTerminatorAndPrefix(b, unsafeString, terminator, prefix)
+}
+
+// EncodeStringDescending is the descending version of EncodeStringAscending.
+func EncodeStringDescending(b []byte, s string) []byte {
+	unsafeString := unsafeConvertStringToBytes(s)
+	return EncodeBytesDescending(b, unsafeString)
+}
diff --git a/encoding/string_test.go b/encoding/string_test.go
new file mode 100644
index 0000000000..9304b8303e
--- /dev/null
+++ b/encoding/string_test.go
@@ -0,0 +1,122 @@
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License
+
+package encoding
+
+import (
+	"bytes"
+	"testing"
+)
+
+func TestEncodeDecodeUnsafeString(t *testing.T) {
+	testCases := []struct {
+		value   string
+		encoded []byte
+	}{
+		{"\x00\x01a", []byte{bytesMarker, 0x00, escaped00, 1, 'a', escape, escapedTerm}},
+		{"\x00a", []byte{bytesMarker, 0x00, escaped00, 'a', escape, escapedTerm}},
+		{"\x00\xffa", []byte{bytesMarker, 0x00, escaped00, 0xff, 'a', escape, escapedTerm}},
+		{"a", []byte{bytesMarker, 'a', escape, escapedTerm}},
+		{"b", []byte{bytesMarker, 'b', escape, escapedTerm}},
+		{"b\x00", []byte{bytesMarker, 'b', 0x00, escaped00, escape, escapedTerm}},
+		{"b\x00\x00", []byte{bytesMarker, 'b', 0x00, escaped00, 0x00, escaped00, escape, escapedTerm}},
+		{"b\x00\x00a", []byte{bytesMarker, 'b', 0x00, escaped00, 0x00, escaped00, 'a', escape, escapedTerm}},
+		{"b\xff", []byte{bytesMarker, 'b', 0xff, escape, escapedTerm}},
+		{"hello", []byte{bytesMarker, 'h', 'e', 'l', 'l', 'o', escape, escapedTerm}},
+	}
+	for i, c := range testCases {
+		enc := EncodeStringAscending(nil, c.value)
+		if !bytes.Equal(enc, c.encoded) {
+			t.Errorf("unexpected encoding mismatch for %v. expected [% x], got [% x]",
+				c.value, c.encoded, enc)
+		}
+		if i > 0 {
+			if bytes.Compare(testCases[i-1].encoded, enc) >= 0 {
+				t.Errorf("%v: expected [% x] to be less than [% x]",
+					c.value, testCases[i-1].encoded, enc)
+			}
+		}
+		remainder, dec, err := DecodeBytesAscending(enc)
+		if err != nil {
+			t.Error(err)
+			continue
+		}
+		if c.value != string(dec) {
+			t.Errorf("unexpected decoding mismatch for %v. got %v", c.value, string(dec))
+		}
+		if len(remainder) != 0 {
+			t.Errorf("unexpected remaining bytes: %v", remainder)
+		}
+
+		enc = append(enc, "remainder"...)
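+		// The appended suffix is not part of the escaped encoding, so decoding
+		// should stop at the terminator and return the suffix as the remainder.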
+ remainder, _, err = DecodeBytesAscending(enc) + if err != nil { + t.Error(err) + continue + } + if string(remainder) != "remainder" { + t.Errorf("unexpected remaining bytes: %v", remainder) + } + } +} + +func TestEncodeDecodeUnsafeStringDescending(t *testing.T) { + testCases := []struct { + value string + encoded []byte + }{ + {"hello", []byte{bytesDescMarker, ^byte('h'), ^byte('e'), ^byte('l'), ^byte('l'), ^byte('o'), escapeDesc, escapedTermDesc}}, + {"b\xff", []byte{bytesDescMarker, ^byte('b'), ^byte(0xff), escapeDesc, escapedTermDesc}}, + {"b\x00\x00a", []byte{bytesDescMarker, ^byte('b'), ^byte(0), escaped00Desc, ^byte(0), escaped00Desc, ^byte('a'), escapeDesc, escapedTermDesc}}, + {"b\x00\x00", []byte{bytesDescMarker, ^byte('b'), ^byte(0), escaped00Desc, ^byte(0), escaped00Desc, escapeDesc, escapedTermDesc}}, + {"b\x00", []byte{bytesDescMarker, ^byte('b'), ^byte(0), escaped00Desc, escapeDesc, escapedTermDesc}}, + {"b", []byte{bytesDescMarker, ^byte('b'), escapeDesc, escapedTermDesc}}, + {"a", []byte{bytesDescMarker, ^byte('a'), escapeDesc, escapedTermDesc}}, + {"\x00\xffa", []byte{bytesDescMarker, ^byte(0), escaped00Desc, ^byte(0xff), ^byte('a'), escapeDesc, escapedTermDesc}}, + {"\x00a", []byte{bytesDescMarker, ^byte(0), escaped00Desc, ^byte('a'), escapeDesc, escapedTermDesc}}, + {"\x00\x01a", []byte{bytesDescMarker, ^byte(0), escaped00Desc, ^byte(1), ^byte('a'), escapeDesc, escapedTermDesc}}, + } + for i, c := range testCases { + enc := EncodeStringDescending(nil, c.value) + if !bytes.Equal(enc, c.encoded) { + t.Errorf("unexpected encoding mismatch for %v. expected [% x], got [% x]", + c.value, c.encoded, enc) + } + if i > 0 { + if bytes.Compare(testCases[i-1].encoded, enc) >= 0 { + t.Errorf("%v: expected [% x] to be less than [% x]", + c.value, testCases[i-1].encoded, enc) + } + } + remainder, dec, err := DecodeBytesDescending(enc) + if err != nil { + t.Error(err) + continue + } + if c.value != string(dec) { + t.Errorf("unexpected decoding mismatch for %v. got [% x]", c.value, string(dec)) + } + if len(remainder) != 0 { + t.Errorf("unexpected remaining bytes: %v", remainder) + } + + enc = append(enc, "remainder"...) + remainder, _, err = DecodeBytesDescending(enc) + if err != nil { + t.Error(err) + continue + } + if string(remainder) != "remainder" { + t.Errorf("unexpected remaining bytes: %v", remainder) + } + } +} diff --git a/encoding/type.go b/encoding/type.go new file mode 100644 index 0000000000..b4b85cf7bf --- /dev/null +++ b/encoding/type.go @@ -0,0 +1,46 @@ +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License + +package encoding + +// Type represents the type of a value encoded by +// Encode{Null,Varint,Uvarint,Float,Bytes}. +type Type int + +const ( + Unknown Type = 0 + Null Type = 1 + Int Type = 3 + Float Type = 4 + Bytes Type = 6 + BytesDesc Type = 7 +) + +// PeekType peeks at the type of the value encoded at the start of b. 
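+// For example (illustrative, mirroring type_test.go):
+//
+//	PeekType(EncodeNullAscending(nil))               // Null
+//	PeekType(EncodeVarintAscending(nil, 42))         // Int
+//	PeekType(EncodeBytesDescending(nil, []byte(""))) // BytesDesc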
+func PeekType(b []byte) Type { + if len(b) >= 1 { + m := b[0] + switch { + case m == encodedNull, m == encodedNullDesc: + return Null + case m == bytesMarker: + return Bytes + case m == bytesDescMarker: + return BytesDesc + case m >= IntMin && m <= IntMax: + return Int + case m >= floatNaN && m <= floatNaNDesc: + return Float + } + } + return Unknown +} diff --git a/encoding/type_test.go b/encoding/type_test.go new file mode 100644 index 0000000000..f3114858bd --- /dev/null +++ b/encoding/type_test.go @@ -0,0 +1,41 @@ +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License + +package encoding + +import ( + "testing" +) + +func TestPeekType(t *testing.T) { + testCases := []struct { + enc []byte + typ Type + }{ + {EncodeNullAscending(nil), Null}, + {EncodeNullDescending(nil), Null}, + {EncodeVarintAscending(nil, 0), Int}, + {EncodeVarintDescending(nil, 0), Int}, + {EncodeUvarintAscending(nil, 0), Int}, + {EncodeUvarintDescending(nil, 0), Int}, + {EncodeFloatAscending(nil, 0), Float}, + {EncodeFloatDescending(nil, 0), Float}, + {EncodeBytesAscending(nil, []byte("")), Bytes}, + {EncodeBytesDescending(nil, []byte("")), BytesDesc}, + } + for i, c := range testCases { + typ := PeekType(c.enc) + if c.typ != typ { + t.Fatalf("%d: expected %d, but found %d", i, c.typ, typ) + } + } +} diff --git a/go.mod b/go.mod index 42b089729f..8872605e0a 100644 --- a/go.mod +++ b/go.mod @@ -30,6 +30,7 @@ require ( github.com/multiformats/go-multiaddr v0.12.2 github.com/multiformats/go-multibase v0.2.0 github.com/multiformats/go-multihash v0.2.3 + github.com/pkg/errors v0.9.1 github.com/sourcenetwork/badger/v4 v4.2.1-0.20231113215945-a63444ca5276 github.com/sourcenetwork/go-libp2p-pubsub-rpc v0.0.13 github.com/sourcenetwork/graphql-go v0.7.10-0.20231113214537-a9560c1898dd @@ -147,7 +148,6 @@ require ( github.com/pbnjay/memory v0.0.0-20210728143218-7b4eea64cf58 // indirect github.com/pelletier/go-toml/v2 v2.1.0 // indirect github.com/perimeterx/marshmallow v1.1.5 // indirect - github.com/pkg/errors v0.9.1 // indirect github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect github.com/polydawn/refmt v0.89.0 // indirect github.com/prometheus/client_golang v1.18.0 // indirect diff --git a/lens/fetcher.go b/lens/fetcher.go index 3e430dd377..1e093f3966 100644 --- a/lens/fetcher.go +++ b/lens/fetcher.go @@ -222,7 +222,7 @@ func (f *lensedFetcher) lensDocToEncodedDoc(docAsMap LensDoc) (fetcher.EncodedDo continue } - fieldValue, err := core.DecodeFieldValue(fieldDesc, fieldByteValue) + fieldValue, err := core.NormalizeFieldValue(fieldDesc, fieldByteValue) if err != nil { return nil, err } diff --git a/request/graphql/schema/collection.go b/request/graphql/schema/collection.go index 996b88495a..4680bdceff 100644 --- a/request/graphql/schema/collection.go +++ b/request/graphql/schema/collection.go @@ -44,11 +44,11 @@ func FromString(ctx context.Context, schemaString string) ( return nil, err } - return fromAst(ctx, doc) + return fromAst(doc) } // fromAst parses a GQL AST into a set of collection 
descriptions. -func fromAst(ctx context.Context, doc *ast.Document) ( +func fromAst(doc *ast.Document) ( []client.CollectionDefinition, error, ) { @@ -58,7 +58,7 @@ func fromAst(ctx context.Context, doc *ast.Document) ( for _, def := range doc.Definitions { switch defType := def.(type) { case *ast.ObjectDefinition: - description, err := collectionFromAstDefinition(ctx, relationManager, defType) + description, err := collectionFromAstDefinition(relationManager, defType) if err != nil { return nil, err } @@ -66,7 +66,7 @@ func fromAst(ctx context.Context, doc *ast.Document) ( definitions = append(definitions, description) case *ast.InterfaceDefinition: - description, err := schemaFromAstDefinition(ctx, relationManager, defType) + description, err := schemaFromAstDefinition(relationManager, defType) if err != nil { return nil, err } @@ -98,7 +98,6 @@ func fromAst(ctx context.Context, doc *ast.Document) ( // collectionFromAstDefinition parses a AST object definition into a set of collection descriptions. func collectionFromAstDefinition( - ctx context.Context, relationManager *RelationManager, def *ast.ObjectDefinition, ) (client.CollectionDefinition, error) { @@ -164,7 +163,6 @@ func collectionFromAstDefinition( } func schemaFromAstDefinition( - ctx context.Context, relationManager *RelationManager, def *ast.InterfaceDefinition, ) (client.SchemaDescription, error) { @@ -213,8 +211,7 @@ func fieldIndexFromAST(field *ast.FieldDefinition, directive *ast.Directive) (cl desc := client.IndexDescription{ Fields: []client.IndexedFieldDescription{ { - Name: field.Name.Value, - Direction: client.Ascending, + Name: field.Name.Value, }, }, } @@ -235,6 +232,14 @@ func fieldIndexFromAST(field *ast.FieldDefinition, directive *ast.Directive) (cl return client.IndexDescription{}, ErrIndexWithInvalidArg } desc.Unique = boolVal.Value + case types.IndexDirectivePropDirection: + dirVal, ok := arg.Value.(*ast.EnumValue) + if !ok { + return client.IndexDescription{}, ErrIndexWithInvalidArg + } + if dirVal.Value == types.FieldOrderDESC { + desc.Fields[0].Descending = true + } default: return client.IndexDescription{}, ErrIndexWithUnknownArg } @@ -298,16 +303,12 @@ func indexFromAST(directive *ast.Directive) (client.IndexDescription, error) { if !ok { return client.IndexDescription{}, ErrIndexWithInvalidArg } - if dirVal.Value == string(client.Ascending) { - desc.Fields[i].Direction = client.Ascending - } else if dirVal.Value == string(client.Descending) { - desc.Fields[i].Direction = client.Descending + if dirVal.Value == types.FieldOrderASC { + desc.Fields[i].Descending = false + } else if dirVal.Value == types.FieldOrderDESC { + desc.Fields[i].Descending = true } } - } else { - for i := range desc.Fields { - desc.Fields[i].Direction = client.Ascending - } } return desc, nil } diff --git a/request/graphql/schema/generate.go b/request/graphql/schema/generate.go index 32fe562cff..e4397e2e40 100644 --- a/request/graphql/schema/generate.go +++ b/request/graphql/schema/generate.go @@ -82,7 +82,7 @@ func (g *Generator) Generate(ctx context.Context, collections []client.Collectio // the given CollectionDescriptions. 
func (g *Generator) generate(ctx context.Context, collections []client.CollectionDefinition) ([]*gql.Object, error) { // build base types - defs, err := g.buildTypes(ctx, collections) + defs, err := g.buildTypes(collections) if err != nil { return nil, err } @@ -129,7 +129,7 @@ func (g *Generator) generate(ctx context.Context, collections []client.Collectio return nil, err } - if err := g.genAggregateFields(ctx); err != nil { + if err := g.genAggregateFields(); err != nil { return nil, err } // resolve types @@ -403,7 +403,6 @@ func (g *Generator) createExpandedFieldList( // Given a set of developer defined collection types // extract and return the correct gql.Object type(s) func (g *Generator) buildTypes( - ctx context.Context, collections []client.CollectionDefinition, ) ([]*gql.Object, error) { // @todo: Check for duplicate named defined types in the TypeMap @@ -594,7 +593,7 @@ func (g *Generator) buildMutationInputTypes(collections []client.CollectionDefin return nil } -func (g *Generator) genAggregateFields(ctx context.Context) error { +func (g *Generator) genAggregateFields() error { topLevelCountInputs := map[string]*gql.InputObject{} topLevelNumericAggInputs := map[string]*gql.InputObject{} @@ -1014,7 +1013,7 @@ func (g *Generator) GenerateQueryInputForGQLType( types.groupBy = g.genTypeFieldsEnum(obj) types.order = g.genTypeOrderArgInput(obj) - queryField := g.genTypeQueryableFieldList(ctx, obj, types) + queryField := g.genTypeQueryableFieldList(obj, types) return queryField, nil } @@ -1249,7 +1248,6 @@ type queryInputTypeConfig struct { } func (g *Generator) genTypeQueryableFieldList( - ctx context.Context, obj *gql.Object, config queryInputTypeConfig, ) *gql.Field { diff --git a/request/graphql/schema/index_parse_test.go b/request/graphql/schema/index_parse_test.go index ca1ce32696..8204c2d0ec 100644 --- a/request/graphql/schema/index_parse_test.go +++ b/request/graphql/schema/index_parse_test.go @@ -28,7 +28,7 @@ func TestParseIndexOnStruct(t *testing.T) { { Name: "", Fields: []client.IndexedFieldDescription{ - {Name: "name", Direction: client.Ascending}, + {Name: "name"}, }, Unique: false, }, @@ -41,7 +41,7 @@ func TestParseIndexOnStruct(t *testing.T) { { Name: "userIndex", Fields: []client.IndexedFieldDescription{ - {Name: "name", Direction: client.Ascending}, + {Name: "name"}, }, }, }, @@ -52,7 +52,7 @@ func TestParseIndexOnStruct(t *testing.T) { targetDescriptions: []client.IndexDescription{ { Fields: []client.IndexedFieldDescription{ - {Name: "name", Direction: client.Ascending}, + {Name: "name"}, }, Unique: true, }, @@ -64,7 +64,7 @@ func TestParseIndexOnStruct(t *testing.T) { targetDescriptions: []client.IndexDescription{ { Fields: []client.IndexedFieldDescription{ - {Name: "name", Direction: client.Ascending}, + {Name: "name"}, }, Unique: false, }, @@ -76,7 +76,7 @@ func TestParseIndexOnStruct(t *testing.T) { targetDescriptions: []client.IndexDescription{ { Fields: []client.IndexedFieldDescription{ - {Name: "name", Direction: client.Ascending}}, + {Name: "name"}}, }, }, }, @@ -86,7 +86,7 @@ func TestParseIndexOnStruct(t *testing.T) { targetDescriptions: []client.IndexDescription{ { Fields: []client.IndexedFieldDescription{ - {Name: "name", Direction: client.Descending}}, + {Name: "name", Descending: true}}, }, }, }, @@ -96,8 +96,8 @@ func TestParseIndexOnStruct(t *testing.T) { targetDescriptions: []client.IndexDescription{ { Fields: []client.IndexedFieldDescription{ - {Name: "name", Direction: client.Ascending}, - {Name: "age", Direction: client.Ascending}, + {Name: 
"name"}, + {Name: "age"}, }, }, }, @@ -108,8 +108,8 @@ func TestParseIndexOnStruct(t *testing.T) { targetDescriptions: []client.IndexDescription{ { Fields: []client.IndexedFieldDescription{ - {Name: "name", Direction: client.Ascending}, - {Name: "age", Direction: client.Descending}, + {Name: "name"}, + {Name: "age", Descending: true}, }, }, }, @@ -216,7 +216,7 @@ func TestParseIndexOnField(t *testing.T) { { Name: "", Fields: []client.IndexedFieldDescription{ - {Name: "name", Direction: client.Ascending}, + {Name: "name"}, }, Unique: false, }, @@ -231,7 +231,7 @@ func TestParseIndexOnField(t *testing.T) { { Name: "nameIndex", Fields: []client.IndexedFieldDescription{ - {Name: "name", Direction: client.Ascending}, + {Name: "name"}, }, Unique: false, }, @@ -245,7 +245,7 @@ func TestParseIndexOnField(t *testing.T) { targetDescriptions: []client.IndexDescription{ { Fields: []client.IndexedFieldDescription{ - {Name: "name", Direction: client.Ascending}, + {Name: "name"}, }, Unique: true, }, @@ -259,7 +259,35 @@ func TestParseIndexOnField(t *testing.T) { targetDescriptions: []client.IndexDescription{ { Fields: []client.IndexedFieldDescription{ - {Name: "name", Direction: client.Ascending}, + {Name: "name"}, + }, + Unique: false, + }, + }, + }, + { + description: "field index in ASC order", + sdl: `type user { + name: String @index(direction: ASC) + }`, + targetDescriptions: []client.IndexDescription{ + { + Fields: []client.IndexedFieldDescription{ + {Name: "name"}, + }, + Unique: false, + }, + }, + }, + { + description: "field index in DESC order", + sdl: `type user { + name: String @index(direction: DESC) + }`, + targetDescriptions: []client.IndexDescription{ + { + Fields: []client.IndexedFieldDescription{ + {Name: "name", Descending: true}, }, Unique: false, }, @@ -281,13 +309,6 @@ func TestParseInvalidIndexOnField(t *testing.T) { }`, expectedErr: errIndexUnknownArgument, }, - { - description: "forbidden 'direction' argument", - sdl: `type user { - name: String @index(direction: ASC) - }`, - expectedErr: errIndexUnknownArgument, - }, { description: "invalid field index name type", sdl: `type user { diff --git a/request/graphql/schema/types/types.go b/request/graphql/schema/types/types.go index 065dadaa6d..199f21deb3 100644 --- a/request/graphql/schema/types/types.go +++ b/request/graphql/schema/types/types.go @@ -28,7 +28,11 @@ const ( IndexDirectivePropName = "name" IndexDirectivePropUnique = "unique" IndexDirectivePropFields = "fields" + IndexDirectivePropDirection = "direction" IndexDirectivePropDirections = "directions" + + FieldOrderASC = "ASC" + FieldOrderDESC = "DESC" ) var ( @@ -111,6 +115,12 @@ var ( IndexDirectivePropName: &gql.ArgumentConfig{ Type: gql.String, }, + IndexDirectivePropUnique: &gql.ArgumentConfig{ + Type: gql.Boolean, + }, + IndexDirectivePropDirection: &gql.ArgumentConfig{ + Type: OrderingEnum, + }, }, Locations: []string{ gql.DirectiveLocationField, diff --git a/tests/integration/index/create_composite_test.go b/tests/integration/index/create_composite_test.go index c20b1b1240..e9a83f1d15 100644 --- a/tests/integration/index/create_composite_test.go +++ b/tests/integration/index/create_composite_test.go @@ -48,7 +48,7 @@ func TestCompositeIndexCreate_WhenCreated_CanRetrieve(t *testing.T) { testUtils.CreateIndex{ CollectionID: 0, IndexName: "name_age_index", - FieldsNames: []string{"name", "age"}, + Fields: []testUtils.IndexedField{{Name: "name"}, {Name: "age"}}, }, testUtils.GetIndexes{ CollectionID: 0, @@ -58,12 +58,10 @@ func 
TestCompositeIndexCreate_WhenCreated_CanRetrieve(t *testing.T) { ID: 1, Fields: []client.IndexedFieldDescription{ { - Name: "name", - Direction: client.Ascending, + Name: "name", }, { - Name: "age", - Direction: client.Ascending, + Name: "age", }, }, }, diff --git a/tests/integration/index/create_get_test.go b/tests/integration/index/create_get_test.go index 6ec0962c17..3ba27cfa9e 100644 --- a/tests/integration/index/create_get_test.go +++ b/tests/integration/index/create_get_test.go @@ -37,8 +37,7 @@ func TestIndexGet_ShouldReturnListOfExistingIndexes(t *testing.T) { ID: 1, Fields: []client.IndexedFieldDescription{ { - Name: "name", - Direction: client.Ascending, + Name: "name", }, }, }, @@ -47,8 +46,7 @@ func TestIndexGet_ShouldReturnListOfExistingIndexes(t *testing.T) { ID: 2, Fields: []client.IndexedFieldDescription{ { - Name: "age", - Direction: client.Ascending, + Name: "age", }, }, }, diff --git a/tests/integration/index/create_unique_composite_test.go b/tests/integration/index/create_unique_composite_test.go index 3d146eb591..2f0ed96488 100644 --- a/tests/integration/index/create_unique_composite_test.go +++ b/tests/integration/index/create_unique_composite_test.go @@ -52,7 +52,7 @@ func TestCreateUniqueCompositeIndex_IfFieldValuesAreNotUnique_ReturnError(t *tes }, testUtils.CreateIndex{ CollectionID: 0, - FieldsNames: []string{"name", "age"}, + Fields: []testUtils.IndexedField{{Name: "name"}, {Name: "age"}}, Unique: true, ExpectedError: db.NewErrCanNotIndexNonUniqueFields( "bae-cae3deac-d371-5a1f-93b4-ede69042f79b", @@ -151,7 +151,7 @@ func TestUniqueCompositeIndexCreate_IfFieldValuesAreUnique_Succeed(t *testing.T) }, testUtils.CreateIndex{ CollectionID: 0, - FieldsNames: []string{"name", "age"}, + Fields: []testUtils.IndexedField{{Name: "name"}, {Name: "age"}}, IndexName: "name_age_unique_index", Unique: true, }, @@ -164,12 +164,10 @@ func TestUniqueCompositeIndexCreate_IfFieldValuesAreUnique_Succeed(t *testing.T) Unique: true, Fields: []client.IndexedFieldDescription{ { - Name: "name", - Direction: client.Ascending, + Name: "name", }, { - Name: "age", - Direction: client.Ascending, + Name: "age", }, }, }, diff --git a/tests/integration/index/create_unique_test.go b/tests/integration/index/create_unique_test.go index a0ecf34482..e9b2d41753 100644 --- a/tests/integration/index/create_unique_test.go +++ b/tests/integration/index/create_unique_test.go @@ -121,8 +121,7 @@ func TestUniqueIndexCreate_UponAddingDocWithExistingFieldValue_ReturnError(t *te Unique: true, Fields: []client.IndexedFieldDescription{ { - Name: "age", - Direction: client.Ascending, + Name: "age", }, }, }, @@ -177,8 +176,7 @@ func TestUniqueIndexCreate_IfFieldValuesAreUnique_Succeed(t *testing.T) { Unique: true, Fields: []client.IndexedFieldDescription{ { - Name: "age", - Direction: client.Ascending, + Name: "age", }, }, }, @@ -239,8 +237,7 @@ func TestUniqueIndexCreate_WithMultipleNilFields_ShouldSucceed(t *testing.T) { Unique: true, Fields: []client.IndexedFieldDescription{ { - Name: "age", - Direction: client.Ascending, + Name: "age", }, }, }, diff --git a/tests/integration/index/query_with_composite_index_field_order_test.go b/tests/integration/index/query_with_composite_index_field_order_test.go new file mode 100644 index 0000000000..d3b1beee16 --- /dev/null +++ b/tests/integration/index/query_with_composite_index_field_order_test.go @@ -0,0 +1,435 @@ +// Copyright 2024 Democratized Data Foundation +// +// Use of this software is governed by the Business Source License +// included in the file 
licenses/BSL.txt. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +package index + +import ( + "testing" + + testUtils "github.com/sourcenetwork/defradb/tests/integration" +) + +func TestQueryWithCompositeIndex_WithDefaultOrder_ShouldFetchInDefaultOrder(t *testing.T) { + test := testUtils.TestCase{ + Description: "Test composite index in default order", + Actions: []any{ + testUtils.SchemaUpdate{ + Schema: ` + type User @index(fields: ["name", "age"]) { + name: String + age: Int + }`, + }, + testUtils.CreateDoc{ + CollectionID: 0, + Doc: ` + { + "name": "Alice", + "age": 22 + }`, + }, + testUtils.CreateDoc{ + CollectionID: 0, + Doc: ` + { + "name": "Alan", + "age": 29 + }`, + }, + testUtils.CreateDoc{ + CollectionID: 0, + Doc: ` + { + "name": "Alice", + "age": 38 + }`, + }, + testUtils.CreateDoc{ + CollectionID: 0, + Doc: ` + { + "name": "Alice", + "age": 24 + }`, + }, + testUtils.Request{ + Request: ` + query { + User(filter: {name: {_like: "Al%"}}) { + name + age + } + }`, + Results: []map[string]any{ + { + "name": "Alan", + "age": 29, + }, + { + "name": "Alice", + "age": 22, + }, + { + "name": "Alice", + "age": 24, + }, + { + "name": "Alice", + "age": 38, + }, + }, + }, + }, + } + + testUtils.ExecuteTestCase(t, test) +} + +func TestQueryWithCompositeIndex_WithRevertedOrderOnFirstField_ShouldFetchInRevertedOrder(t *testing.T) { + test := testUtils.TestCase{ + Description: "Test composite index with reverted order on first field", + Actions: []any{ + testUtils.SchemaUpdate{ + Schema: ` + type User @index(fields: ["name", "age"], directions: [DESC, ASC]) { + name: String + age: Int + }`, + }, + testUtils.CreateDoc{ + CollectionID: 0, + Doc: ` + { + "name": "Alice", + "age": 22 + }`, + }, + testUtils.CreateDoc{ + CollectionID: 0, + Doc: ` + { + "name": "Alan", + "age": 29 + }`, + }, + testUtils.CreateDoc{ + CollectionID: 0, + Doc: ` + { + "name": "Alice", + "age": 38 + }`, + }, + testUtils.CreateDoc{ + CollectionID: 0, + Doc: ` + { + "name": "Andy", + "age": 24 + }`, + }, + testUtils.CreateDoc{ + CollectionID: 0, + Doc: ` + { + "name": "Alice", + "age": 24 + }`, + }, + testUtils.Request{ + Request: ` + query { + User(filter: {name: {_like: "A%"}}) { + name + age + } + }`, + Results: []map[string]any{ + { + "name": "Andy", + "age": 24, + }, + { + "name": "Alice", + "age": 22, + }, + { + "name": "Alice", + "age": 24, + }, + { + "name": "Alice", + "age": 38, + }, + { + "name": "Alan", + "age": 29, + }, + }, + }, + }, + } + + testUtils.ExecuteTestCase(t, test) +} + +func TestQueryWithCompositeIndex_WithRevertedOrderOnSecondField_ShouldFetchInRevertedOrder(t *testing.T) { + test := testUtils.TestCase{ + Description: "Test composite index with reverted order on second field", + Actions: []any{ + testUtils.SchemaUpdate{ + Schema: ` + type User @index(fields: ["name", "age"], directions: [ASC, DESC]) { + name: String + age: Int + }`, + }, + testUtils.CreateDoc{ + CollectionID: 0, + Doc: ` + { + "name": "Alice", + "age": 22 + }`, + }, + testUtils.CreateDoc{ + CollectionID: 0, + Doc: ` + { + "name": "Alan", + "age": 29 + }`, + }, + testUtils.CreateDoc{ + CollectionID: 0, + Doc: ` + { + "name": "Alice", + "age": 38 + }`, + }, + testUtils.CreateDoc{ + CollectionID: 0, + Doc: ` + { + "name": "Alice", + "age": 24 + }`, + }, + testUtils.Request{ + Request: ` + query { + User(filter: {name: {_like: "Al%"}}) { + name + age + } + 
}`, + Results: []map[string]any{ + { + "name": "Alan", + "age": 29, + }, + { + "name": "Alice", + "age": 38, + }, + { + "name": "Alice", + "age": 24, + }, + { + "name": "Alice", + "age": 22, + }, + }, + }, + }, + } + + testUtils.ExecuteTestCase(t, test) +} + +func TestQueryWithCompositeIndex_IfExactMatchWithRevertedOrderOnFirstField_ShouldFetch(t *testing.T) { + test := testUtils.TestCase{ + Description: "Test composite index with reverted order on first field and filter with exact match", + Actions: []any{ + testUtils.SchemaUpdate{ + Schema: ` + type User @index(fields: ["name", "age"], directions: [DESC, ASC]) { + name: String + age: Int + }`, + }, + testUtils.CreateDoc{ + CollectionID: 0, + Doc: ` + { + "name": "Alice", + "age": 38 + }`, + }, + testUtils.CreateDoc{ + CollectionID: 0, + Doc: ` + { + "name": "Alice", + "age": 22 + }`, + }, + testUtils.CreateDoc{ + CollectionID: 0, + Doc: ` + { + "name": "Alan", + "age": 29 + }`, + }, + testUtils.Request{ + Request: ` + query { + User(filter: {name: {_eq: "Alice"}, age: {_eq: 22}}) { + name + age + } + }`, + Results: []map[string]any{ + { + "name": "Alice", + "age": 22, + }, + }, + }, + }, + } + + testUtils.ExecuteTestCase(t, test) +} + +func TestQueryWithCompositeIndex_IfExactMatchWithRevertedOrderOnSecondField_ShouldFetch(t *testing.T) { + test := testUtils.TestCase{ + Description: "Test composite index with reverted order on second field and filter with exact match", + Actions: []any{ + testUtils.SchemaUpdate{ + Schema: ` + type User @index(fields: ["name", "age"], directions: [ASC, DESC]) { + name: String + age: Int + }`, + }, + testUtils.CreateDoc{ + CollectionID: 0, + Doc: ` + { + "name": "Alice", + "age": 38 + }`, + }, + testUtils.CreateDoc{ + CollectionID: 0, + Doc: ` + { + "name": "Alice", + "age": 22 + }`, + }, + testUtils.CreateDoc{ + CollectionID: 0, + Doc: ` + { + "name": "Alan", + "age": 29 + }`, + }, + testUtils.Request{ + Request: ` + query { + User(filter: {name: {_eq: "Alice"}, age: {_eq: 22}}) { + name + age + } + }`, + Results: []map[string]any{ + { + "name": "Alice", + "age": 22, + }, + }, + }, + }, + } + + testUtils.ExecuteTestCase(t, test) +} + +func TestQueryWithCompositeIndex_WithInFilterOnFirstFieldWithRevertedOrder_ShouldFetch(t *testing.T) { + test := testUtils.TestCase{ + Description: "Test composite index with reverted order on first field and filtering with _in filter", + Actions: []any{ + testUtils.SchemaUpdate{ + Schema: ` + type User @index(fields: ["name", "age"], directions: [DESC, ASC]) { + name: String + age: Int + email: String + }`, + }, + testUtils.CreatePredefinedDocs{ + Docs: getUserDocs(), + }, + testUtils.Request{ + Request: `query { + User(filter: {name: {_in: ["Addo", "Andy", "Fred"]}}) { + name + } + }`, + Results: []map[string]any{ + {"name": "Addo"}, + {"name": "Andy"}, + {"name": "Fred"}, + }, + }, + }, + } + + testUtils.ExecuteTestCase(t, test) +} + +func TestQueryWithCompositeIndex_WithInFilterOnSecondFieldWithRevertedOrder_ShouldFetch(t *testing.T) { + test := testUtils.TestCase{ + Description: "Test composite index with reverted order on second field and filtering with _in filter", + Actions: []any{ + testUtils.SchemaUpdate{ + Schema: ` + type User @index(fields: ["name", "age"], directions: [ASC, DESC]) { + name: String + age: Int + email: String + }`, + }, + testUtils.CreatePredefinedDocs{ + Docs: getUserDocs(), + }, + testUtils.Request{ + Request: `query { + User(filter: {age: {_in: [20, 28, 33]}}) { + name + } + }`, + Results: []map[string]any{ + {"name": "Shahzad"}, + {"name": 
"Andy"}, + {"name": "Fred"}, + }, + }, + }, + } + + testUtils.ExecuteTestCase(t, test) +} diff --git a/tests/integration/index/query_with_composite_index_only_filter_test.go b/tests/integration/index/query_with_composite_index_only_filter_test.go index dfcc7d1299..bf7e8b17c3 100644 --- a/tests/integration/index/query_with_composite_index_only_filter_test.go +++ b/tests/integration/index/query_with_composite_index_only_filter_test.go @@ -211,8 +211,8 @@ func TestQueryWithCompositeIndex_WithGreaterOrEqualFilterOnSecondField_ShouldFet testUtils.Request{ Request: req, Results: []map[string]any{ - {"name": "Roy"}, {"name": "Chris"}, + {"name": "Roy"}, }, }, testUtils.Request{ @@ -394,13 +394,13 @@ func TestQueryWithCompositeIndex_WithNotEqualFilter_ShouldFetch(t *testing.T) { testUtils.Request{ Request: req, Results: []map[string]any{ - {"name": "Roy"}, {"name": "Addo"}, {"name": "Andy"}, - {"name": "John"}, {"name": "Bruno"}, {"name": "Chris"}, + {"name": "John"}, {"name": "Keenan"}, + {"name": "Roy"}, {"name": "Shahzad"}, }, }, @@ -474,9 +474,9 @@ func TestQueryWithCompositeIndex_WithNotInFilter_ShouldFetch(t *testing.T) { testUtils.Request{ Request: req, Results: []map[string]any{ - {"name": "Roy"}, {"name": "Islam"}, {"name": "Keenan"}, + {"name": "Roy"}, }, }, testUtils.Request{ @@ -624,10 +624,10 @@ func TestQueryWithCompositeIndex_WithNotLikeFilter_ShouldFetch(t *testing.T) { testUtils.Request{ Request: req, Results: []map[string]any{ - {"name": "Roy"}, {"name": "Bruno"}, {"name": "Islam"}, {"name": "Keenan"}, + {"name": "Roy"}, }, }, testUtils.Request{ diff --git a/tests/integration/index/query_with_index_only_field_order_test.go b/tests/integration/index/query_with_index_only_field_order_test.go new file mode 100644 index 0000000000..ae46213533 --- /dev/null +++ b/tests/integration/index/query_with_index_only_field_order_test.go @@ -0,0 +1,180 @@ +// Copyright 2023 Democratized Data Foundation +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. 
+ +package index + +import ( + "testing" + + testUtils "github.com/sourcenetwork/defradb/tests/integration" +) + +func TestQueryWithIndex_IfIntFieldInDescOrder_ShouldFetchInRevertedOrder(t *testing.T) { + test := testUtils.TestCase{ + Description: "If indexed int field is in DESC order, it should be fetched in reverted order", + Actions: []any{ + testUtils.SchemaUpdate{ + Schema: ` + type User { + name: String + age: Int @index(direction: DESC) + }`, + }, + testUtils.CreateDoc{ + Doc: ` + { + "name": "Alice", + "age": 22 + }`, + }, + testUtils.CreateDoc{ + Doc: ` + { + "name": "Bob", + "age": 24 + }`, + }, + testUtils.CreateDoc{ + Doc: ` + { + "name": "Kate", + "age": 23 + }`, + }, + testUtils.Request{ + Request: ` + query { + User(filter: {age: {_gt: 1}}) { + name + age + } + }`, + Results: []map[string]any{{ + "name": "Bob", + "age": 24, + }, { + "name": "Kate", + "age": 23, + }, { + "name": "Alice", + "age": 22, + }}, + }, + }, + } + + testUtils.ExecuteTestCase(t, test) +} + +func TestQueryWithIndex_IfFloatFieldInDescOrder_ShouldFetchInRevertedOrder(t *testing.T) { + test := testUtils.TestCase{ + Description: "If indexed float field is in DESC order, it should be fetched in reverted order", + Actions: []any{ + testUtils.SchemaUpdate{ + Schema: ` + type User { + name: String + iq: Float @index(direction: DESC) + }`, + }, + testUtils.CreateDoc{ + Doc: ` + { + "name": "Alice", + "iq": 0.2 + }`, + }, + testUtils.CreateDoc{ + Doc: ` + { + "name": "Bob", + "iq": 0.4 + }`, + }, + testUtils.CreateDoc{ + Doc: ` + { + "name": "Kate", + "iq": 0.3 + }`, + }, + testUtils.Request{ + Request: ` + query { + User(filter: {iq: {_gt: 0}}) { + name + iq + } + }`, + Results: []map[string]any{{ + "name": "Bob", + "iq": 0.4, + }, { + "name": "Kate", + "iq": 0.3, + }, { + "name": "Alice", + "iq": 0.2, + }}, + }, + }, + } + + testUtils.ExecuteTestCase(t, test) +} + +func TestQueryWithIndex_IfStringFieldInDescOrder_ShouldFetchInRevertedOrder(t *testing.T) { + test := testUtils.TestCase{ + Description: "If indexed string field is in DESC order, it should be fetched in reverted order", + Actions: []any{ + testUtils.SchemaUpdate{ + Schema: ` + type User { + name: String @index(direction: DESC) + }`, + }, + testUtils.CreateDoc{ + Doc: ` + { + "name": "Alice" + }`, + }, + testUtils.CreateDoc{ + Doc: ` + { + "name": "Aaron" + }`, + }, + testUtils.CreateDoc{ + Doc: ` + { + "name": "Andy" + }`, + }, + testUtils.Request{ + Request: ` + query { + User(filter: {name: {_like: "A%"}}) { + name + } + }`, + Results: []map[string]any{{ + "name": "Andy", + }, { + "name": "Alice", + }, { + "name": "Aaron", + }}, + }, + }, + } + + testUtils.ExecuteTestCase(t, test) +} diff --git a/tests/integration/index/query_with_index_only_filter_test.go b/tests/integration/index/query_with_index_only_filter_test.go index 9615461757..0c2c337398 100644 --- a/tests/integration/index/query_with_index_only_filter_test.go +++ b/tests/integration/index/query_with_index_only_filter_test.go @@ -294,14 +294,14 @@ func TestQueryWithIndex_WithNotEqualFilter_ShouldFetch(t *testing.T) { testUtils.Request{ Request: req, Results: []map[string]any{ - {"name": "Roy"}, {"name": "Addo"}, {"name": "Andy"}, - {"name": "Fred"}, - {"name": "John"}, {"name": "Bruno"}, {"name": "Chris"}, + {"name": "Fred"}, + {"name": "John"}, {"name": "Keenan"}, + {"name": "Roy"}, {"name": "Shahzad"}, }, }, @@ -566,13 +566,13 @@ func TestQueryWithIndex_WithNotLikeFilter_ShouldFetch(t *testing.T) { testUtils.Request{ Request: req, Results: []map[string]any{ - {"name": "Roy"}, {"name": 
"Addo"}, {"name": "Andy"}, - {"name": "Fred"}, {"name": "Bruno"}, + {"name": "Fred"}, {"name": "Islam"}, {"name": "Keenan"}, + {"name": "Roy"}, }, }, testUtils.Request{ diff --git a/tests/integration/index/query_with_unique_composite_index_filter_test.go b/tests/integration/index/query_with_unique_composite_index_filter_test.go index 744cb76615..8d37c141d4 100644 --- a/tests/integration/index/query_with_unique_composite_index_filter_test.go +++ b/tests/integration/index/query_with_unique_composite_index_filter_test.go @@ -229,8 +229,8 @@ func TestQueryWithUniqueCompositeIndex_WithGreaterOrEqualFilterOnSecondField_Sho testUtils.Request{ Request: req, Results: []map[string]any{ - {"name": "Roy"}, {"name": "Chris"}, + {"name": "Roy"}, }, }, testUtils.Request{ @@ -412,13 +412,13 @@ func TestQueryWithUniqueCompositeIndex_WithNotEqualFilter_ShouldFetch(t *testing testUtils.Request{ Request: req, Results: []map[string]any{ - {"name": "Roy"}, {"name": "Addo"}, {"name": "Andy"}, - {"name": "John"}, {"name": "Bruno"}, {"name": "Chris"}, + {"name": "John"}, {"name": "Keenan"}, + {"name": "Roy"}, {"name": "Shahzad"}, }, }, @@ -583,9 +583,9 @@ func TestQueryWithUniqueCompositeIndex_WithNotInFilter_ShouldFetch(t *testing.T) testUtils.Request{ Request: req, Results: []map[string]any{ - {"name": "Roy"}, {"name": "Islam"}, {"name": "Keenan"}, + {"name": "Roy"}, }, }, testUtils.Request{ @@ -733,10 +733,10 @@ func TestQueryWithUniqueCompositeIndex_WithNotLikeFilter_ShouldFetch(t *testing. testUtils.Request{ Request: req, Results: []map[string]any{ - {"name": "Roy"}, {"name": "Bruno"}, {"name": "Islam"}, {"name": "Keenan"}, + {"name": "Roy"}, }, }, testUtils.Request{ @@ -1083,9 +1083,9 @@ func TestQueryWithUniqueCompositeIndex_WithMultipleNilOnBothFieldsAndNilFilter_S } }`, Results: []map[string]any{ - {"about": "nil_22"}, {"about": "nil_nil_2"}, {"about": "nil_nil_1"}, + {"about": "nil_22"}, }, }, testUtils.Request{ @@ -1218,8 +1218,8 @@ func TestQueryWithUniqueCompositeIndex_AfterUpdateOnNilFields_ShouldFetch(t *tes } }`, Results: []map[string]any{ - {"about": "nil_nil -> nil_22"}, {"about": "bob_nil -> nil_nil"}, + {"about": "nil_nil -> nil_22"}, }, }, testUtils.Request{ diff --git a/tests/integration/index/query_with_unique_index_only_filter_test.go b/tests/integration/index/query_with_unique_index_only_filter_test.go index 002022b058..c9bcd027a9 100644 --- a/tests/integration/index/query_with_unique_index_only_filter_test.go +++ b/tests/integration/index/query_with_unique_index_only_filter_test.go @@ -214,14 +214,14 @@ func TestQueryWithUniqueIndex_WithNotEqualFilter_ShouldFetch(t *testing.T) { testUtils.Request{ Request: req, Results: []map[string]any{ - {"name": "Roy"}, {"name": "Addo"}, {"name": "Andy"}, - {"name": "Fred"}, - {"name": "John"}, {"name": "Bruno"}, {"name": "Chris"}, + {"name": "Fred"}, + {"name": "John"}, {"name": "Keenan"}, + {"name": "Roy"}, {"name": "Shahzad"}, }, }, @@ -443,13 +443,13 @@ func TestQueryWithUniqueIndex_WithNotLikeFilter_ShouldFetch(t *testing.T) { testUtils.Request{ Request: req, Results: []map[string]any{ - {"name": "Roy"}, {"name": "Addo"}, {"name": "Andy"}, - {"name": "Fred"}, {"name": "Bruno"}, + {"name": "Fred"}, {"name": "Islam"}, {"name": "Keenan"}, + {"name": "Roy"}, }, }, testUtils.Request{ diff --git a/tests/integration/test_case.go b/tests/integration/test_case.go index 346f5abec7..ce6e456fbb 100644 --- a/tests/integration/test_case.go +++ b/tests/integration/test_case.go @@ -261,6 +261,14 @@ type UpdateDoc struct { DontSync bool } +// IndexField describes 
a field to be indexed. +type IndexedField struct { + // Name contains the name of the field. + Name string + // Descending indicates whether the field is indexed in descending order. + Descending bool +} + // CreateIndex will attempt to create the given secondary index for the given collection // using the collection api. type CreateIndex struct { @@ -278,11 +286,8 @@ type CreateIndex struct { // The name of the field to index. Used only for single field indexes. FieldName string - // The names of the fields to index. Used only for composite indexes. - FieldsNames []string - // The directions of the 'FieldsNames' to index. Used only for composite indexes. - // If not provided all fields will be indexed in ascending order. - Directions []client.IndexDirection + // The fields to index. Used only for composite indexes. + Fields []IndexedField // If Unique is true, the index will be created as a unique index. Unique bool diff --git a/tests/integration/utils2.go b/tests/integration/utils2.go index 016559280d..d5cdcbd01d 100644 --- a/tests/integration/utils2.go +++ b/tests/integration/utils2.go @@ -1380,15 +1380,11 @@ func createIndex( Name: action.FieldName, }, } - } else if len(action.FieldsNames) > 0 { - for i := range action.FieldsNames { - dir := client.Ascending - if len(action.Directions) > i { - dir = action.Directions[i] - } + } else if len(action.Fields) > 0 { + for i := range action.Fields { indexDesc.Fields = append(indexDesc.Fields, client.IndexedFieldDescription{ - Name: action.FieldsNames[i], - Direction: dir, + Name: action.Fields[i].Name, + Descending: action.Fields[i].Descending, }) } } diff --git a/tools/configs/golangci.yaml b/tools/configs/golangci.yaml index 561cfd7138..1b6b76aa37 100644 --- a/tools/configs/golangci.yaml +++ b/tools/configs/golangci.yaml @@ -149,7 +149,7 @@ issues: - errorlint # Exclude running header check in these paths - - path: "(net|datastore/badger/v4/compat_logger.go|datastore/badger/v4/datastore.go|connor)" + - path: "(net|datastore/badger/v4/compat_logger.go|datastore/badger/v4/datastore.go|connor|encoding)" linters: - goheader
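A note on why the DESC tests above can assert a reverted fetch order without any explicit ordering in the query: with an order-aware key encoding, direction becomes a property of the stored key bytes, so a plain ascending scan over the index returns DESC-indexed values largest-first. The sketch below illustrates the byte-complement idea only; `encodeUint64Ascending` and `encodeUint64Descending` are hypothetical stand-ins for illustration, not the actual `encoding` package API.

```go
package main

import (
	"bytes"
	"encoding/binary"
	"fmt"
	"sort"
)

// encodeUint64Ascending appends v in big-endian form, so byte-wise
// comparison of keys matches numeric comparison of values.
func encodeUint64Ascending(b []byte, v uint64) []byte {
	return binary.BigEndian.AppendUint64(b, v)
}

// encodeUint64Descending appends the ones' complement of v, which
// inverts every byte of the ascending encoding and thereby reverses
// the lexicographic order of the resulting keys.
func encodeUint64Descending(b []byte, v uint64) []byte {
	return encodeUint64Ascending(b, ^v)
}

func main() {
	// Ages from the tests above, indexed with direction: DESC.
	ages := []uint64{22, 38, 24}

	keys := make([][]byte, 0, len(ages))
	for _, age := range ages {
		keys = append(keys, encodeUint64Descending(nil, age))
	}

	// A datastore iterator always scans in ascending byte order...
	sort.Slice(keys, func(i, j int) bool {
		return bytes.Compare(keys[i], keys[j]) < 0
	})

	// ...which decodes back to the values in descending order: 38 24 22.
	for _, k := range keys {
		fmt.Print(^binary.BigEndian.Uint64(k), " ")
	}
	fmt.Println()
}
```

The same trick composes per field: a composite key concatenates one encoded segment per indexed field, so `directions: [DESC, ASC]` flips only the bytes of the first field's segment, which is exactly the mixed ordering the composite-index tests above assert.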
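On the test-harness side, the `test_case.go` change above replaces the parallel `FieldsNames`/`Directions` slices with a single `Fields []IndexedField`, so direction is declared per field. A sketch of how a test might use the new shape (only `Fields` and `IndexedField` are taken from this diff; `CollectionID` on `CreateIndex` is assumed here by analogy with the `CreateDoc` actions above):

```go
package index

import (
	"testing"

	testUtils "github.com/sourcenetwork/defradb/tests/integration"
)

// Hypothetical test showing the updated CreateIndex action shape.
func TestCreateCompositeIndexWithMixedDirections(t *testing.T) {
	test := testUtils.TestCase{
		Actions: []any{
			testUtils.SchemaUpdate{
				Schema: `
					type User {
						name: String
						age: Int
					}`,
			},
			// Equivalent to @index(fields: ["name", "age"], directions: [DESC, ASC]).
			testUtils.CreateIndex{
				CollectionID: 0, // assumed field, by analogy with CreateDoc
				Fields: []testUtils.IndexedField{
					{Name: "name", Descending: true},
					{Name: "age"}, // Descending defaults to false, i.e. ASC
				},
			},
		},
	}

	testUtils.ExecuteTestCase(t, test)
}
```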