Skip to content

Commit

Permalink
refactor: Change counters to support encryption (#2698)
Browse files Browse the repository at this point in the history
## Relevant issue(s)

Resolves #2696 

## Description
This PR changes the counter CRDTs to make them support the document
encryption feature. They previously stored their values as concrete
types (int64 and float64) instead of bytes. Storing them as bytes allow
them to be stored plainly or encrypted.
  • Loading branch information
fredcarle authored Jun 10, 2024
1 parent ea68087 commit 8d0c756
Show file tree
Hide file tree
Showing 8 changed files with 131 additions and 114 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
# Support encryption for counters
We changed the data format of counters from int64 and float64 to bytes to support encryption. This changes the generated CIDs for counters.
8 changes: 3 additions & 5 deletions internal/core/block/block.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ func init() {
&crdt.CRDT{},
&crdt.LWWRegDelta{},
&crdt.CompositeDAGDelta{},
&crdt.CounterDelta[int64]{}, // Only need to call one of the CounterDelta types.
&crdt.CounterDelta{},
)
}

Expand Down Expand Up @@ -149,10 +149,8 @@ func New(delta core.Delta, links []DAGLink, heads ...cid.Cid) *Block {
crdtDelta = crdt.CRDT{LWWRegDelta: delta}
case *crdt.CompositeDAGDelta:
crdtDelta = crdt.CRDT{CompositeDAGDelta: delta}
case *crdt.CounterDelta[int64]:
crdtDelta = crdt.CRDT{CounterDeltaInt: delta}
case *crdt.CounterDelta[float64]:
crdtDelta = crdt.CRDT{CounterDeltaFloat: delta}
case *crdt.CounterDelta:
crdtDelta = crdt.CRDT{CounterDelta: delta}
}

return &Block{
Expand Down
129 changes: 78 additions & 51 deletions internal/core/crdt/counter.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ type Incrementable interface {
}

// CounterDelta is a single delta operation for a Counter
type CounterDelta[T Incrementable] struct {
type CounterDelta struct {
DocID []byte
FieldName string
Priority uint64
Expand All @@ -44,69 +44,60 @@ type CounterDelta[T Incrementable] struct {
//
// It can be used to identify the collection datastructure state at the time of commit.
SchemaVersionID string
Data T
Data []byte
}

var _ core.Delta = (*CounterDelta[float64])(nil)
var _ core.Delta = (*CounterDelta[int64])(nil)
var _ core.Delta = (*CounterDelta)(nil)

// IPLDSchemaBytes returns the IPLD schema representation for the type.
//
// This needs to match the [CounterDelta[T]] struct or [coreblock.mustSetSchema] will panic on init.
func (delta *CounterDelta[T]) IPLDSchemaBytes() []byte {
// This needs to match the [CounterDelta] struct or [coreblock.mustSetSchema] will panic on init.
func (delta *CounterDelta) IPLDSchemaBytes() []byte {
return []byte(`
type CounterDeltaFloat struct {
type CounterDelta struct {
docID Bytes
fieldName String
priority Int
nonce Int
schemaVersionID String
data Float
}
type CounterDeltaInt struct {
docID Bytes
fieldName String
priority Int
nonce Int
schemaVersionID String
data Int
data Bytes
}`)
}

// GetPriority gets the current priority for this delta.
func (delta *CounterDelta[T]) GetPriority() uint64 {
func (delta *CounterDelta) GetPriority() uint64 {
return delta.Priority
}

// SetPriority will set the priority for this delta.
func (delta *CounterDelta[T]) SetPriority(prio uint64) {
func (delta *CounterDelta) SetPriority(prio uint64) {
delta.Priority = prio
}

// Counter, is a simple CRDT type that allows increment/decrement
// of an Int and Float data types that ensures convergence.
type Counter[T Incrementable] struct {
type Counter struct {
baseCRDT
AllowDecrement bool
Kind client.ScalarKind
}

var _ core.ReplicatedData = (*Counter[float64])(nil)
var _ core.ReplicatedData = (*Counter[int64])(nil)
var _ core.ReplicatedData = (*Counter)(nil)

// NewCounter returns a new instance of the Counter with the given ID.
func NewCounter[T Incrementable](
func NewCounter(
store datastore.DSReaderWriter,
schemaVersionKey core.CollectionSchemaVersionKey,
key core.DataStoreKey,
fieldName string,
allowDecrement bool,
) Counter[T] {
return Counter[T]{newBaseCRDT(store, key, schemaVersionKey, fieldName), allowDecrement}
kind client.ScalarKind,
) Counter {
return Counter{newBaseCRDT(store, key, schemaVersionKey, fieldName), allowDecrement, kind}
}

// Value gets the current counter value
func (c Counter[T]) Value(ctx context.Context) ([]byte, error) {
func (c Counter) Value(ctx context.Context) ([]byte, error) {
valueK := c.key.WithValueFlag()
buf, err := c.store.Get(ctx, valueK.ToDS())
if err != nil {
Expand All @@ -120,7 +111,7 @@ func (c Counter[T]) Value(ctx context.Context) ([]byte, error) {
// WARNING: Incrementing an integer and causing it to overflow the int64 max value
// will cause the value to roll over to the int64 min value. Incremeting a float and
// causing it to overflow the float64 max value will act like a no-op.
func (c Counter[T]) Increment(ctx context.Context, value T) (*CounterDelta[T], error) {
func (c Counter) Increment(ctx context.Context, value []byte) (*CounterDelta, error) {
// To ensure that the dag block is unique, we add a random number to the delta.
// This is done only on update (if the doc doesn't already exist) to ensure that the
// initial dag block of a document can be reproducible.
Expand All @@ -137,7 +128,7 @@ func (c Counter[T]) Increment(ctx context.Context, value T) (*CounterDelta[T], e
nonce = r.Int64()
}

return &CounterDelta[T]{
return &CounterDelta{
DocID: []byte(c.key.DocID),
FieldName: c.fieldName,
Data: value,
Expand All @@ -148,19 +139,20 @@ func (c Counter[T]) Increment(ctx context.Context, value T) (*CounterDelta[T], e

// Merge implements ReplicatedData interface.
// It merges two CounterRegisty by adding the values together.
func (c Counter[T]) Merge(ctx context.Context, delta core.Delta) error {
d, ok := delta.(*CounterDelta[T])
func (c Counter) Merge(ctx context.Context, delta core.Delta) error {
d, ok := delta.(*CounterDelta)
if !ok {
return ErrMismatchedMergeType
}

return c.incrementValue(ctx, d.Data, d.GetPriority())
}

func (c Counter[T]) incrementValue(ctx context.Context, value T, priority uint64) error {
if !c.AllowDecrement && value < 0 {
return NewErrNegativeValue(value)
}
func (c Counter) incrementValue(
ctx context.Context,
valueAsBytes []byte,
priority uint64,
) error {
key := c.key.WithValueFlag()
marker, err := c.store.Get(ctx, c.key.ToPrimaryDataStoreKey().ToDS())
if err != nil && !errors.Is(err, ds.ErrNotFound) {
Expand All @@ -170,27 +162,69 @@ func (c Counter[T]) incrementValue(ctx context.Context, value T, priority uint64
key = key.WithDeletedFlag()
}

curValue, err := c.getCurrentValue(ctx, key)
var resultAsBytes []byte

switch c.Kind {
case client.FieldKind_NILLABLE_INT:
resultAsBytes, err = validateAndIncrement[int64](ctx, c.store, key, valueAsBytes, c.AllowDecrement)
if err != nil {
return err
}
case client.FieldKind_NILLABLE_FLOAT:
resultAsBytes, err = validateAndIncrement[float64](ctx, c.store, key, valueAsBytes, c.AllowDecrement)
if err != nil {
return err
}
default:
return NewErrUnsupportedCounterType(c.Kind)
}

err = c.store.Put(ctx, key.ToDS(), resultAsBytes)
if err != nil {
return err
return NewErrFailedToStoreValue(err)
}

newValue := curValue + value
b, err := cbor.Marshal(newValue)
return c.setPriority(ctx, c.key, priority)
}

func (c Counter) CType() client.CType {
if c.AllowDecrement {
return client.PN_COUNTER
}
return client.P_COUNTER
}

func validateAndIncrement[T Incrementable](
ctx context.Context,
store datastore.DSReaderWriter,
key core.DataStoreKey,
valueAsBytes []byte,
allowDecrement bool,
) ([]byte, error) {
value, err := getNumericFromBytes[T](valueAsBytes)
if err != nil {
return err
return nil, err
}

err = c.store.Put(ctx, key.ToDS(), b)
if !allowDecrement && value < 0 {
return nil, NewErrNegativeValue(value)
}

curValue, err := getCurrentValue[T](ctx, store, key)
if err != nil {
return NewErrFailedToStoreValue(err)
return nil, err
}

return c.setPriority(ctx, c.key, priority)
newValue := curValue + value
return cbor.Marshal(newValue)
}

func (c Counter[T]) getCurrentValue(ctx context.Context, key core.DataStoreKey) (T, error) {
curValue, err := c.store.Get(ctx, key.ToDS())
func getCurrentValue[T Incrementable](
ctx context.Context,
store datastore.DSReaderWriter,
key core.DataStoreKey,
) (T, error) {
curValue, err := store.Get(ctx, key.ToDS())
if err != nil {
if errors.Is(err, ds.ErrNotFound) {
return 0, nil
Expand All @@ -201,13 +235,6 @@ func (c Counter[T]) getCurrentValue(ctx context.Context, key core.DataStoreKey)
return getNumericFromBytes[T](curValue)
}

func (c Counter[T]) CType() client.CType {
if c.AllowDecrement {
return client.PN_COUNTER
}
return client.P_COUNTER
}

func getNumericFromBytes[T Incrementable](b []byte) (T, error) {
var val T
err := cbor.Unmarshal(b, &val)
Expand Down
15 changes: 11 additions & 4 deletions internal/core/crdt/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,15 @@
package crdt

import (
"github.com/sourcenetwork/defradb/client"
"github.com/sourcenetwork/defradb/errors"
)

const (
errFailedToGetPriority string = "failed to get priority"
errFailedToStoreValue string = "failed to store value"
errNegativeValue string = "value cannot be negative"
errFailedToGetPriority string = "failed to get priority"
errFailedToStoreValue string = "failed to store value"
errNegativeValue string = "value cannot be negative"
errUnsupportedCounterType string = "unsupported counter type. Valid types are int64 and float64"
)

// Errors returnable from this package.
Expand All @@ -31,7 +33,8 @@ var (
ErrEncodingPriority = errors.New("error encoding priority")
ErrDecodingPriority = errors.New("error decoding priority")
// ErrMismatchedMergeType - Tying to merge two ReplicatedData of different types
ErrMismatchedMergeType = errors.New("given type to merge does not match source")
ErrMismatchedMergeType = errors.New("given type to merge does not match source")
ErrUnsupportedCounterType = errors.New(errUnsupportedCounterType)
)

// NewErrFailedToGetPriority returns an error indicating that the priority could not be retrieved.
Expand All @@ -47,3 +50,7 @@ func NewErrFailedToStoreValue(inner error) error {
func NewErrNegativeValue[T Incrementable](value T) error {
return errors.New(errNegativeValue, errors.NewKV("Value", value))
}

func NewErrUnsupportedCounterType(valueType client.ScalarKind) error {
return errors.New(errUnsupportedCounterType, errors.NewKV("Type", valueType))
}
36 changes: 12 additions & 24 deletions internal/core/crdt/ipld_union.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,7 @@ import "github.com/sourcenetwork/defradb/internal/core"
type CRDT struct {
LWWRegDelta *LWWRegDelta
CompositeDAGDelta *CompositeDAGDelta
CounterDeltaInt *CounterDelta[int64]
CounterDeltaFloat *CounterDelta[float64]
CounterDelta *CounterDelta
}

// IPLDSchemaBytes returns the IPLD schema representation for the CRDT.
Expand All @@ -28,8 +27,7 @@ func (c CRDT) IPLDSchemaBytes() []byte {
type CRDT union {
| LWWRegDelta "lww"
| CompositeDAGDelta "composite"
| CounterDeltaInt "counterInt"
| CounterDeltaFloat "counterFloat"
| CounterDelta "counter"
} representation keyed`)
}

Expand All @@ -40,10 +38,8 @@ func (c CRDT) GetDelta() core.Delta {
return c.LWWRegDelta
case c.CompositeDAGDelta != nil:
return c.CompositeDAGDelta
case c.CounterDeltaFloat != nil:
return c.CounterDeltaFloat
case c.CounterDeltaInt != nil:
return c.CounterDeltaInt
case c.CounterDelta != nil:
return c.CounterDelta
}
return nil
}
Expand All @@ -55,10 +51,8 @@ func (c CRDT) GetPriority() uint64 {
return c.LWWRegDelta.GetPriority()
case c.CompositeDAGDelta != nil:
return c.CompositeDAGDelta.GetPriority()
case c.CounterDeltaFloat != nil:
return c.CounterDeltaFloat.GetPriority()
case c.CounterDeltaInt != nil:
return c.CounterDeltaInt.GetPriority()
case c.CounterDelta != nil:
return c.CounterDelta.GetPriority()
}
return 0
}
Expand All @@ -70,10 +64,8 @@ func (c CRDT) GetFieldName() string {
return c.LWWRegDelta.FieldName
case c.CompositeDAGDelta != nil:
return c.CompositeDAGDelta.FieldName
case c.CounterDeltaFloat != nil:
return c.CounterDeltaFloat.FieldName
case c.CounterDeltaInt != nil:
return c.CounterDeltaInt.FieldName
case c.CounterDelta != nil:
return c.CounterDelta.FieldName
}
return ""
}
Expand All @@ -85,10 +77,8 @@ func (c CRDT) GetDocID() []byte {
return c.LWWRegDelta.DocID
case c.CompositeDAGDelta != nil:
return c.CompositeDAGDelta.DocID
case c.CounterDeltaFloat != nil:
return c.CounterDeltaFloat.DocID
case c.CounterDeltaInt != nil:
return c.CounterDeltaInt.DocID
case c.CounterDelta != nil:
return c.CounterDelta.DocID
}
return nil
}
Expand All @@ -100,10 +90,8 @@ func (c CRDT) GetSchemaVersionID() string {
return c.LWWRegDelta.SchemaVersionID
case c.CompositeDAGDelta != nil:
return c.CompositeDAGDelta.SchemaVersionID
case c.CounterDeltaFloat != nil:
return c.CounterDeltaFloat.SchemaVersionID
case c.CounterDeltaInt != nil:
return c.CounterDeltaInt.SchemaVersionID
case c.CounterDelta != nil:
return c.CounterDelta.SchemaVersionID
}
return ""
}
Expand Down
Loading

0 comments on commit 8d0c756

Please sign in to comment.