diff --git a/docs/data_format_changes/i2696-support-encryption-for-counters.md b/docs/data_format_changes/i2696-support-encryption-for-counters.md new file mode 100644 index 0000000000..dd53e57898 --- /dev/null +++ b/docs/data_format_changes/i2696-support-encryption-for-counters.md @@ -0,0 +1,2 @@ +# Support encryption for counters +We changed the data format of counters from int64 and float64 to bytes to support encryption. This changes the generated CIDs for counters. \ No newline at end of file diff --git a/internal/core/block/block.go b/internal/core/block/block.go index 6be17908be..c9a3f629c2 100644 --- a/internal/core/block/block.go +++ b/internal/core/block/block.go @@ -40,7 +40,7 @@ func init() { &crdt.CRDT{}, &crdt.LWWRegDelta{}, &crdt.CompositeDAGDelta{}, - &crdt.CounterDelta[int64]{}, // Only need to call one of the CounterDelta types. + &crdt.CounterDelta{}, ) } @@ -149,10 +149,8 @@ func New(delta core.Delta, links []DAGLink, heads ...cid.Cid) *Block { crdtDelta = crdt.CRDT{LWWRegDelta: delta} case *crdt.CompositeDAGDelta: crdtDelta = crdt.CRDT{CompositeDAGDelta: delta} - case *crdt.CounterDelta[int64]: - crdtDelta = crdt.CRDT{CounterDeltaInt: delta} - case *crdt.CounterDelta[float64]: - crdtDelta = crdt.CRDT{CounterDeltaFloat: delta} + case *crdt.CounterDelta: + crdtDelta = crdt.CRDT{CounterDelta: delta} } return &Block{ diff --git a/internal/core/crdt/counter.go b/internal/core/crdt/counter.go index c87c7d6da6..4aa9a40793 100644 --- a/internal/core/crdt/counter.go +++ b/internal/core/crdt/counter.go @@ -33,7 +33,7 @@ type Incrementable interface { } // CounterDelta is a single delta operation for a Counter -type CounterDelta[T Incrementable] struct { +type CounterDelta struct { DocID []byte FieldName string Priority uint64 @@ -44,69 +44,60 @@ type CounterDelta[T Incrementable] struct { // // It can be used to identify the collection datastructure state at the time of commit. SchemaVersionID string - Data T + Data []byte } -var _ core.Delta = (*CounterDelta[float64])(nil) -var _ core.Delta = (*CounterDelta[int64])(nil) +var _ core.Delta = (*CounterDelta)(nil) // IPLDSchemaBytes returns the IPLD schema representation for the type. // -// This needs to match the [CounterDelta[T]] struct or [coreblock.mustSetSchema] will panic on init. -func (delta *CounterDelta[T]) IPLDSchemaBytes() []byte { +// This needs to match the [CounterDelta] struct or [coreblock.mustSetSchema] will panic on init. +func (delta *CounterDelta) IPLDSchemaBytes() []byte { return []byte(` - type CounterDeltaFloat struct { + type CounterDelta struct { docID Bytes fieldName String priority Int nonce Int schemaVersionID String - data Float - } - - type CounterDeltaInt struct { - docID Bytes - fieldName String - priority Int - nonce Int - schemaVersionID String - data Int + data Bytes }`) } // GetPriority gets the current priority for this delta. -func (delta *CounterDelta[T]) GetPriority() uint64 { +func (delta *CounterDelta) GetPriority() uint64 { return delta.Priority } // SetPriority will set the priority for this delta. -func (delta *CounterDelta[T]) SetPriority(prio uint64) { +func (delta *CounterDelta) SetPriority(prio uint64) { delta.Priority = prio } // Counter, is a simple CRDT type that allows increment/decrement // of an Int and Float data types that ensures convergence. -type Counter[T Incrementable] struct { +type Counter struct { baseCRDT AllowDecrement bool + Kind client.ScalarKind } -var _ core.ReplicatedData = (*Counter[float64])(nil) -var _ core.ReplicatedData = (*Counter[int64])(nil) +var _ core.ReplicatedData = (*Counter)(nil) // NewCounter returns a new instance of the Counter with the given ID. -func NewCounter[T Incrementable]( +func NewCounter( store datastore.DSReaderWriter, schemaVersionKey core.CollectionSchemaVersionKey, key core.DataStoreKey, fieldName string, allowDecrement bool, -) Counter[T] { - return Counter[T]{newBaseCRDT(store, key, schemaVersionKey, fieldName), allowDecrement} + kind client.ScalarKind, +) Counter { + return Counter{newBaseCRDT(store, key, schemaVersionKey, fieldName), allowDecrement, kind} } // Value gets the current counter value -func (c Counter[T]) Value(ctx context.Context) ([]byte, error) { +func (c Counter) Value(ctx context.Context) ([]byte, error) { valueK := c.key.WithValueFlag() buf, err := c.store.Get(ctx, valueK.ToDS()) if err != nil { @@ -120,7 +111,7 @@ func (c Counter[T]) Value(ctx context.Context) ([]byte, error) { // WARNING: Incrementing an integer and causing it to overflow the int64 max value // will cause the value to roll over to the int64 min value. Incremeting a float and // causing it to overflow the float64 max value will act like a no-op. -func (c Counter[T]) Increment(ctx context.Context, value T) (*CounterDelta[T], error) { +func (c Counter) Increment(ctx context.Context, value []byte) (*CounterDelta, error) { // To ensure that the dag block is unique, we add a random number to the delta. // This is done only on update (if the doc doesn't already exist) to ensure that the // initial dag block of a document can be reproducible. @@ -137,7 +128,7 @@ func (c Counter[T]) Increment(ctx context.Context, value T) (*CounterDelta[T], e nonce = r.Int64() } - return &CounterDelta[T]{ + return &CounterDelta{ DocID: []byte(c.key.DocID), FieldName: c.fieldName, Data: value, @@ -148,8 +139,8 @@ func (c Counter[T]) Increment(ctx context.Context, value T) (*CounterDelta[T], e // Merge implements ReplicatedData interface. // It merges two CounterRegisty by adding the values together. -func (c Counter[T]) Merge(ctx context.Context, delta core.Delta) error { - d, ok := delta.(*CounterDelta[T]) +func (c Counter) Merge(ctx context.Context, delta core.Delta) error { + d, ok := delta.(*CounterDelta) if !ok { return ErrMismatchedMergeType } @@ -157,10 +148,11 @@ func (c Counter[T]) Merge(ctx context.Context, delta core.Delta) error { return c.incrementValue(ctx, d.Data, d.GetPriority()) } -func (c Counter[T]) incrementValue(ctx context.Context, value T, priority uint64) error { - if !c.AllowDecrement && value < 0 { - return NewErrNegativeValue(value) - } +func (c Counter) incrementValue( + ctx context.Context, + valueAsBytes []byte, + priority uint64, +) error { key := c.key.WithValueFlag() marker, err := c.store.Get(ctx, c.key.ToPrimaryDataStoreKey().ToDS()) if err != nil && !errors.Is(err, ds.ErrNotFound) { @@ -170,27 +162,69 @@ func (c Counter[T]) incrementValue(ctx context.Context, value T, priority uint64 key = key.WithDeletedFlag() } - curValue, err := c.getCurrentValue(ctx, key) + var resultAsBytes []byte + + switch c.Kind { + case client.FieldKind_NILLABLE_INT: + resultAsBytes, err = validateAndIncrement[int64](ctx, c.store, key, valueAsBytes, c.AllowDecrement) + if err != nil { + return err + } + case client.FieldKind_NILLABLE_FLOAT: + resultAsBytes, err = validateAndIncrement[float64](ctx, c.store, key, valueAsBytes, c.AllowDecrement) + if err != nil { + return err + } + default: + return NewErrUnsupportedCounterType(c.Kind) + } + + err = c.store.Put(ctx, key.ToDS(), resultAsBytes) if err != nil { - return err + return NewErrFailedToStoreValue(err) } - newValue := curValue + value - b, err := cbor.Marshal(newValue) + return c.setPriority(ctx, c.key, priority) +} + +func (c Counter) CType() client.CType { + if c.AllowDecrement { + return client.PN_COUNTER + } + return client.P_COUNTER +} + +func validateAndIncrement[T Incrementable]( + ctx context.Context, + store datastore.DSReaderWriter, + key core.DataStoreKey, + valueAsBytes []byte, + allowDecrement bool, +) ([]byte, error) { + value, err := getNumericFromBytes[T](valueAsBytes) if err != nil { - return err + return nil, err } - err = c.store.Put(ctx, key.ToDS(), b) + if !allowDecrement && value < 0 { + return nil, NewErrNegativeValue(value) + } + + curValue, err := getCurrentValue[T](ctx, store, key) if err != nil { - return NewErrFailedToStoreValue(err) + return nil, err } - return c.setPriority(ctx, c.key, priority) + newValue := curValue + value + return cbor.Marshal(newValue) } -func (c Counter[T]) getCurrentValue(ctx context.Context, key core.DataStoreKey) (T, error) { - curValue, err := c.store.Get(ctx, key.ToDS()) +func getCurrentValue[T Incrementable]( + ctx context.Context, + store datastore.DSReaderWriter, + key core.DataStoreKey, +) (T, error) { + curValue, err := store.Get(ctx, key.ToDS()) if err != nil { if errors.Is(err, ds.ErrNotFound) { return 0, nil @@ -201,13 +235,6 @@ func (c Counter[T]) getCurrentValue(ctx context.Context, key core.DataStoreKey) return getNumericFromBytes[T](curValue) } -func (c Counter[T]) CType() client.CType { - if c.AllowDecrement { - return client.PN_COUNTER - } - return client.P_COUNTER -} - func getNumericFromBytes[T Incrementable](b []byte) (T, error) { var val T err := cbor.Unmarshal(b, &val) diff --git a/internal/core/crdt/errors.go b/internal/core/crdt/errors.go index 75af579850..43bc9c565c 100644 --- a/internal/core/crdt/errors.go +++ b/internal/core/crdt/errors.go @@ -11,13 +11,15 @@ package crdt import ( + "github.com/sourcenetwork/defradb/client" "github.com/sourcenetwork/defradb/errors" ) const ( - errFailedToGetPriority string = "failed to get priority" - errFailedToStoreValue string = "failed to store value" - errNegativeValue string = "value cannot be negative" + errFailedToGetPriority string = "failed to get priority" + errFailedToStoreValue string = "failed to store value" + errNegativeValue string = "value cannot be negative" + errUnsupportedCounterType string = "unsupported counter type. Valid types are int64 and float64" ) // Errors returnable from this package. @@ -31,7 +33,8 @@ var ( ErrEncodingPriority = errors.New("error encoding priority") ErrDecodingPriority = errors.New("error decoding priority") // ErrMismatchedMergeType - Tying to merge two ReplicatedData of different types - ErrMismatchedMergeType = errors.New("given type to merge does not match source") + ErrMismatchedMergeType = errors.New("given type to merge does not match source") + ErrUnsupportedCounterType = errors.New(errUnsupportedCounterType) ) // NewErrFailedToGetPriority returns an error indicating that the priority could not be retrieved. @@ -47,3 +50,7 @@ func NewErrFailedToStoreValue(inner error) error { func NewErrNegativeValue[T Incrementable](value T) error { return errors.New(errNegativeValue, errors.NewKV("Value", value)) } + +func NewErrUnsupportedCounterType(valueType client.ScalarKind) error { + return errors.New(errUnsupportedCounterType, errors.NewKV("Type", valueType)) +} diff --git a/internal/core/crdt/ipld_union.go b/internal/core/crdt/ipld_union.go index 5d4cfc9f9e..361a41b150 100644 --- a/internal/core/crdt/ipld_union.go +++ b/internal/core/crdt/ipld_union.go @@ -16,8 +16,7 @@ import "github.com/sourcenetwork/defradb/internal/core" type CRDT struct { LWWRegDelta *LWWRegDelta CompositeDAGDelta *CompositeDAGDelta - CounterDeltaInt *CounterDelta[int64] - CounterDeltaFloat *CounterDelta[float64] + CounterDelta *CounterDelta } // IPLDSchemaBytes returns the IPLD schema representation for the CRDT. @@ -28,8 +27,7 @@ func (c CRDT) IPLDSchemaBytes() []byte { type CRDT union { | LWWRegDelta "lww" | CompositeDAGDelta "composite" - | CounterDeltaInt "counterInt" - | CounterDeltaFloat "counterFloat" + | CounterDelta "counter" } representation keyed`) } @@ -40,10 +38,8 @@ func (c CRDT) GetDelta() core.Delta { return c.LWWRegDelta case c.CompositeDAGDelta != nil: return c.CompositeDAGDelta - case c.CounterDeltaFloat != nil: - return c.CounterDeltaFloat - case c.CounterDeltaInt != nil: - return c.CounterDeltaInt + case c.CounterDelta != nil: + return c.CounterDelta } return nil } @@ -55,10 +51,8 @@ func (c CRDT) GetPriority() uint64 { return c.LWWRegDelta.GetPriority() case c.CompositeDAGDelta != nil: return c.CompositeDAGDelta.GetPriority() - case c.CounterDeltaFloat != nil: - return c.CounterDeltaFloat.GetPriority() - case c.CounterDeltaInt != nil: - return c.CounterDeltaInt.GetPriority() + case c.CounterDelta != nil: + return c.CounterDelta.GetPriority() } return 0 } @@ -70,10 +64,8 @@ func (c CRDT) GetFieldName() string { return c.LWWRegDelta.FieldName case c.CompositeDAGDelta != nil: return c.CompositeDAGDelta.FieldName - case c.CounterDeltaFloat != nil: - return c.CounterDeltaFloat.FieldName - case c.CounterDeltaInt != nil: - return c.CounterDeltaInt.FieldName + case c.CounterDelta != nil: + return c.CounterDelta.FieldName } return "" } @@ -85,10 +77,8 @@ func (c CRDT) GetDocID() []byte { return c.LWWRegDelta.DocID case c.CompositeDAGDelta != nil: return c.CompositeDAGDelta.DocID - case c.CounterDeltaFloat != nil: - return c.CounterDeltaFloat.DocID - case c.CounterDeltaInt != nil: - return c.CounterDeltaInt.DocID + case c.CounterDelta != nil: + return c.CounterDelta.DocID } return nil } @@ -100,10 +90,8 @@ func (c CRDT) GetSchemaVersionID() string { return c.LWWRegDelta.SchemaVersionID case c.CompositeDAGDelta != nil: return c.CompositeDAGDelta.SchemaVersionID - case c.CounterDeltaFloat != nil: - return c.CounterDeltaFloat.SchemaVersionID - case c.CounterDeltaInt != nil: - return c.CounterDeltaInt.SchemaVersionID + case c.CounterDelta != nil: + return c.CounterDelta.SchemaVersionID } return "" } diff --git a/internal/merkle/crdt/counter.go b/internal/merkle/crdt/counter.go index 4501de326c..c5d3a7e8dd 100644 --- a/internal/merkle/crdt/counter.go +++ b/internal/merkle/crdt/counter.go @@ -22,37 +22,42 @@ import ( ) // MerkleCounter is a MerkleCRDT implementation of the Counter using MerkleClocks. -type MerkleCounter[T crdt.Incrementable] struct { +type MerkleCounter struct { *baseMerkleCRDT - reg crdt.Counter[T] + reg crdt.Counter } // NewMerkleCounter creates a new instance (or loaded from DB) of a MerkleCRDT // backed by a Counter CRDT. -func NewMerkleCounter[T crdt.Incrementable]( +func NewMerkleCounter( store Stores, schemaVersionKey core.CollectionSchemaVersionKey, key core.DataStoreKey, fieldName string, allowDecrement bool, -) *MerkleCounter[T] { - register := crdt.NewCounter[T](store.Datastore(), schemaVersionKey, key, fieldName, allowDecrement) + kind client.ScalarKind, +) *MerkleCounter { + register := crdt.NewCounter(store.Datastore(), schemaVersionKey, key, fieldName, allowDecrement, kind) clk := clock.NewMerkleClock(store.Headstore(), store.DAGstore(), key.ToHeadStoreKey(), register) base := &baseMerkleCRDT{clock: clk, crdt: register} - return &MerkleCounter[T]{ + return &MerkleCounter{ baseMerkleCRDT: base, reg: register, } } // Save the value of the Counter to the DAG. -func (mc *MerkleCounter[T]) Save(ctx context.Context, data any) (cidlink.Link, []byte, error) { +func (mc *MerkleCounter) Save(ctx context.Context, data any) (cidlink.Link, []byte, error) { value, ok := data.(*client.FieldValue) if !ok { return cidlink.Link{}, nil, NewErrUnexpectedValueType(mc.reg.CType(), &client.FieldValue{}, data) } - delta, err := mc.reg.Increment(ctx, value.Value().(T)) + bytes, err := value.Bytes() + if err != nil { + return cidlink.Link{}, nil, err + } + delta, err := mc.reg.Increment(ctx, bytes) if err != nil { return cidlink.Link{}, nil, err } diff --git a/internal/merkle/crdt/merklecrdt.go b/internal/merkle/crdt/merklecrdt.go index ed8452195f..abc0ffeb51 100644 --- a/internal/merkle/crdt/merklecrdt.go +++ b/internal/merkle/crdt/merklecrdt.go @@ -88,24 +88,14 @@ func InstanceWithStore( fieldName, ), nil case client.PN_COUNTER, client.P_COUNTER: - switch kind { - case client.FieldKind_NILLABLE_INT: - return NewMerkleCounter[int64]( - store, - schemaVersionKey, - key, - fieldName, - cType == client.PN_COUNTER, - ), nil - case client.FieldKind_NILLABLE_FLOAT: - return NewMerkleCounter[float64]( - store, - schemaVersionKey, - key, - fieldName, - cType == client.PN_COUNTER, - ), nil - } + return NewMerkleCounter( + store, + schemaVersionKey, + key, + fieldName, + cType == client.PN_COUNTER, + kind.(client.ScalarKind), + ), nil case client.COMPOSITE: return NewMerkleCompositeDAG( store, diff --git a/tests/integration/query/simple/with_cid_doc_id_test.go b/tests/integration/query/simple/with_cid_doc_id_test.go index 97791ce993..dcf0d1a1d1 100644 --- a/tests/integration/query/simple/with_cid_doc_id_test.go +++ b/tests/integration/query/simple/with_cid_doc_id_test.go @@ -324,7 +324,7 @@ func TestCidAndDocIDQuery_ContainsPNCounterWithIntKind_NoError(t *testing.T) { testUtils.Request{ Request: `query { Users ( - cid: "bafyreicsx7flfz4b6iwfmwgrnrnd2klxrbg6yojuffh4ia3lrrqcph5q7a", + cid: "bafyreienkinjn7cvsonvhs4tslqvmmcnezuu4aif57jn75cyp6i3vdvkpm", docID: "bae-d8cb53d4-ac5a-5c55-8306-64df633d400d" ) { name @@ -376,7 +376,7 @@ func TestCidAndDocIDQuery_ContainsPNCounterWithFloatKind_NoError(t *testing.T) { testUtils.Request{ Request: `query { Users ( - cid: "bafyreidwtowbnmdfshq3dptfdggzswtdftyh5374ohfcmqki4ad2wd4m64", + cid: "bafyreiceodj32fyhq3v7ryk6mmcjanwx3zr7ajl2k47w4setngmyx7nc3e", docID: "bae-d420ebcd-023a-5800-ae2e-8ea89442318e" ) { name @@ -423,7 +423,7 @@ func TestCidAndDocIDQuery_ContainsPCounterWithIntKind_NoError(t *testing.T) { testUtils.Request{ Request: `query { Users ( - cid: "bafyreifngcu76fxe3dtjee556hwymfjgsm3sqhxned4cykit5lcsyy3ope", + cid: "bafyreieypgt2mq43g4ute2hkzombdqw5v6wctleyxyy6vdkzitrfje636i", docID: "bae-d8cb53d4-ac5a-5c55-8306-64df633d400d" ) { name @@ -470,7 +470,7 @@ func TestCidAndDocIDQuery_ContainsPCounterWithFloatKind_NoError(t *testing.T) { testUtils.Request{ Request: `query { Users ( - cid: "bafyreigih3wl4ycq5lktczydbecvcvlmdsy5jzarx2l6hcqdcrqkoranny", + cid: "bafyreigb3ujvnxie7kwl53w4chiq6cjcyuhranchseo5gmx5i6vfje67da", docID: "bae-d420ebcd-023a-5800-ae2e-8ea89442318e" ) { name