diff --git a/cli/collection_create.go b/cli/collection_create.go index eecdfef2d8..002847d6ec 100644 --- a/cli/collection_create.go +++ b/cli/collection_create.go @@ -17,8 +17,6 @@ import ( "github.com/spf13/cobra" "github.com/sourcenetwork/defradb/client" - "github.com/sourcenetwork/defradb/datastore" - "github.com/sourcenetwork/defradb/internal/db" "github.com/sourcenetwork/defradb/internal/encryption" ) @@ -89,8 +87,7 @@ Example: create from stdin: return cmd.Usage() } - txn, _ := db.TryGetContextTxn(cmd.Context()) - setContextDocEncryption(cmd, shouldEncryptDoc, encryptedFields, txn) + setContextDocEncryption(cmd, shouldEncryptDoc, encryptedFields) if client.IsJSONArray(docData) { docs, err := client.NewDocsFromJSON(docData, col.Definition()) @@ -116,14 +113,11 @@ Example: create from stdin: } // setContextDocEncryption sets doc encryption for the current command context. -func setContextDocEncryption(cmd *cobra.Command, shouldEncryptDoc bool, encryptFields []string, txn datastore.Txn) { +func setContextDocEncryption(cmd *cobra.Command, shouldEncryptDoc bool, encryptFields []string) { if !shouldEncryptDoc && len(encryptFields) == 0 { return } ctx := cmd.Context() - if txn != nil { - ctx = encryption.ContextWithStore(ctx, txn) - } ctx = encryption.SetContextConfigFromParams(ctx, shouldEncryptDoc, encryptFields) cmd.SetContext(ctx) } diff --git a/client/db.go b/client/db.go index 50ee1f82dc..b8f5e91e35 100644 --- a/client/db.go +++ b/client/db.go @@ -52,6 +52,11 @@ type DB interface { // It sits within the rootstore returned by [Root]. Blockstore() datastore.Blockstore + // Encstore returns the store, that contains all known encryption keys for documents and their fields. + // + // It sits within the rootstore returned by [Root]. + Encstore() datastore.Blockstore + // Peerstore returns the peerstore where known host information is stored. // // It sits within the rootstore returned by [Root]. 
diff --git a/client/mocks/db.go b/client/mocks/db.go index b14aec5d05..8923e63d78 100644 --- a/client/mocks/db.go +++ b/client/mocks/db.go @@ -479,6 +479,53 @@ func (_c *DB_DeleteReplicator_Call) RunAndReturn(run func(context.Context, clien return _c } +// Encstore provides a mock function with given fields: +func (_m *DB) Encstore() datastore.Blockstore { + ret := _m.Called() + + if len(ret) == 0 { + panic("no return value specified for Encstore") + } + + var r0 datastore.Blockstore + if rf, ok := ret.Get(0).(func() datastore.Blockstore); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(datastore.Blockstore) + } + } + + return r0 +} + +// DB_Encstore_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Encstore' +type DB_Encstore_Call struct { + *mock.Call +} + +// Encstore is a helper method to define mock.On call +func (_e *DB_Expecter) Encstore() *DB_Encstore_Call { + return &DB_Encstore_Call{Call: _e.mock.On("Encstore")} +} + +func (_c *DB_Encstore_Call) Run(run func()) *DB_Encstore_Call { + _c.Call.Run(func(args mock.Arguments) { + run() + }) + return _c +} + +func (_c *DB_Encstore_Call) Return(_a0 datastore.Blockstore) *DB_Encstore_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *DB_Encstore_Call) RunAndReturn(run func() datastore.Blockstore) *DB_Encstore_Call { + _c.Call.Return(run) + return _c +} + // Events provides a mock function with given fields: func (_m *DB) Events() *event.Bus { ret := _m.Called() diff --git a/crypto/aes.go b/crypto/aes.go new file mode 100644 index 0000000000..9fa2bd8deb --- /dev/null +++ b/crypto/aes.go @@ -0,0 +1,94 @@ +// Copyright 2024 Democratized Data Foundation +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt. 
+// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +package crypto + +import ( + "crypto/aes" + "crypto/cipher" +) + +// EncryptAES encrypts data using AES-GCM with a provided key and additional data. +// It generates a nonce internally and optionally prepends it to the cipherText. +// +// Parameters: +// - plainText: The data to be encrypted +// - key: The AES encryption key +// - additionalData: Additional authenticated data (AAD) to be used in the encryption +// - prependNonce: If true, the nonce is prepended to the returned cipherText +// +// Returns: +// - cipherText: The encrypted data, with the nonce prepended if prependNonce is true +// - nonce: The generated nonce +// - error: Any error encountered during the encryption process +func EncryptAES(plainText, key, additionalData []byte, prependNonce bool) ([]byte, []byte, error) { + block, err := aes.NewCipher(key) + if err != nil { + return nil, nil, err + } + + nonce, err := generateNonceFunc() + if err != nil { + return nil, nil, err + } + + aesGCM, err := cipher.NewGCM(block) + if err != nil { + return nil, nil, err + } + + var cipherText []byte + if prependNonce { + cipherText = aesGCM.Seal(nonce, nonce, plainText, additionalData) + } else { + cipherText = aesGCM.Seal(nil, nonce, plainText, additionalData) + } + + return cipherText, nonce, nil +} + +// DecryptAES decrypts AES-GCM encrypted data with a provided key and additional data. +// If no separate nonce is provided, it assumes the nonce is prepended to the cipherText. +// +// Parameters: +// - nonce: The nonce used for decryption. 
If empty, it's assumed to be prepended to cipherText +// - cipherText: The data to be decrypted +// - key: The AES decryption key +// - additionalData: Additional authenticated data (AAD) used during encryption +// +// Returns: +// - plainText: The decrypted data +// - error: Any error encountered during the decryption process, including authentication failures +func DecryptAES(nonce, cipherText, key, additionalData []byte) ([]byte, error) { + if len(nonce) == 0 { + if len(cipherText) < AESNonceSize { + return nil, ErrCipherTextTooShort + } + nonce = cipherText[:AESNonceSize] + cipherText = cipherText[AESNonceSize:] + } + + block, err := aes.NewCipher(key) + if err != nil { + return nil, err + } + + aesGCM, err := cipher.NewGCM(block) + if err != nil { + return nil, err + } + + plainText, err := aesGCM.Open(nil, nonce, cipherText, additionalData) + if err != nil { + return nil, err + } + + return plainText, nil +} diff --git a/crypto/aes_test.go b/crypto/aes_test.go new file mode 100644 index 0000000000..7218ca24b2 --- /dev/null +++ b/crypto/aes_test.go @@ -0,0 +1,175 @@ +// Copyright 2024 Democratized Data Foundation +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. 
+ +package crypto + +import ( + "bytes" + "crypto/rand" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestEncryptAES(t *testing.T) { + validKey := make([]byte, 32) // AES-256 + _, err := rand.Read(validKey) + require.NoError(t, err) + validPlaintext := []byte("Hello, World!") + validAAD := []byte("Additional Authenticated Data") + + tests := []struct { + name string + plainText []byte + key []byte + additionalData []byte + prependNonce bool + expectError bool + errorContains string + }{ + { + name: "Valid encryption with prepended nonce", + plainText: validPlaintext, + key: validKey, + additionalData: validAAD, + prependNonce: true, + expectError: false, + }, + { + name: "Valid encryption without prepended nonce", + plainText: validPlaintext, + key: validKey, + additionalData: validAAD, + prependNonce: false, + expectError: false, + }, + { + name: "Invalid key size", + plainText: validPlaintext, + key: make([]byte, 31), // Invalid key size + additionalData: validAAD, + prependNonce: true, + expectError: true, + errorContains: "invalid key size", + }, + { + name: "Nil plaintext", + plainText: nil, + key: validKey, + additionalData: validAAD, + prependNonce: true, + expectError: false, // AES-GCM can encrypt nil/empty plaintext + }, + { + name: "Nil additional data", + plainText: validPlaintext, + key: validKey, + additionalData: nil, + prependNonce: true, + expectError: false, // Nil AAD is valid + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cipherText, nonce, err := EncryptAES(tt.plainText, tt.key, tt.additionalData, tt.prependNonce) + + if tt.expectError { + require.Error(t, err) + require.Contains(t, err.Error(), tt.errorContains) + } else { + require.NoError(t, err) + if tt.prependNonce { + require.Greater(t, len(cipherText), len(nonce), "Ciphertext length not greater than nonce length") + } else { + require.Equal(t, AESNonceSize, len(nonce), "Nonce length != AESNonceSize") + } + } + }) + } +} + +func 
TestDecryptAES(t *testing.T) { + validKey := make([]byte, 32) // AES-256 + _, err := rand.Read(validKey) + require.NoError(t, err) + validPlaintext := []byte("Hello, World!") + validAAD := []byte("Additional Authenticated Data") + validCiphertext, validNonce, _ := EncryptAES(validPlaintext, validKey, validAAD, true) + + tests := []struct { + name string + nonce []byte + cipherText []byte + key []byte + additionalData []byte + expectError bool + errorContains string + }{ + { + name: "Valid decryption", + nonce: nil, // Should be extracted from cipherText + cipherText: validCiphertext, + key: validKey, + additionalData: validAAD, + expectError: false, + }, + { + name: "Invalid key size", + nonce: validNonce, + cipherText: validCiphertext[AESNonceSize:], + key: make([]byte, 31), // Invalid key size + additionalData: validAAD, + expectError: true, + errorContains: "invalid key size", + }, + { + name: "Ciphertext too short", + nonce: nil, + cipherText: make([]byte, AESNonceSize-1), // Too short to contain nonce + key: validKey, + additionalData: validAAD, + expectError: true, + errorContains: errCipherTextTooShort, + }, + { + name: "Invalid additional data", + nonce: validNonce, + cipherText: validCiphertext[AESNonceSize:], + key: validKey, + additionalData: []byte("Wrong AAD"), + expectError: true, + errorContains: "message authentication failed", + }, + { + name: "Tampered ciphertext", + nonce: validNonce, + cipherText: append([]byte{0}, validCiphertext[AESNonceSize+1:]...), + key: validKey, + additionalData: validAAD, + expectError: true, + errorContains: "message authentication failed", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + plainText, err := DecryptAES(tt.nonce, tt.cipherText, tt.key, tt.additionalData) + + if tt.expectError { + require.Error(t, err) + require.Contains(t, err.Error(), tt.errorContains) + } else { + require.NoError(t, err) + require.True(t, bytes.Equal(plainText, validPlaintext), "Decrypted plaintext does not 
match original") + } + }) + } +} diff --git a/crypto/ecies.go b/crypto/ecies.go new file mode 100644 index 0000000000..f025e87823 --- /dev/null +++ b/crypto/ecies.go @@ -0,0 +1,274 @@ +// Copyright 2024 Democratized Data Foundation +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +package crypto + +import ( + "crypto/ecdh" + "crypto/hmac" + "crypto/rand" + "crypto/sha256" + + "golang.org/x/crypto/hkdf" +) + +const X25519PublicKeySize = 32 +const HMACSize = 32 +const AESKeySize = 32 + +const minCipherTextSize = 16 + +// GenerateX25519 generates a new X25519 private key. +func GenerateX25519() (*ecdh.PrivateKey, error) { + return ecdh.X25519().GenerateKey(rand.Reader) +} + +// X25519PublicKeyFromBytes creates a new X25519 public key from the given bytes. +func X25519PublicKeyFromBytes(publicKeyBytes []byte) (*ecdh.PublicKey, error) { + return ecdh.X25519().NewPublicKey(publicKeyBytes) +} + +type ECIESOption func(*eciesOptions) + +type eciesOptions struct { + associatedData []byte + privateKey *ecdh.PrivateKey + publicKeyBytes []byte + noPubKeyPrepended bool +} + +// WithAAD sets the associated data to use for authentication. +func WithAAD(aad []byte) ECIESOption { + return func(o *eciesOptions) { + o.associatedData = aad + } +} + +// WithPrivKey sets the private key to use for encryption. +// +// If not set, a new ephemeral key will be generated. +// This option has no effect on decryption. +func WithPrivKey(privKey *ecdh.PrivateKey) ECIESOption { + return func(o *eciesOptions) { + o.privateKey = privKey + } +} + +// WithPubKeyBytes sets the public key bytes to use for decryption. +// +// If not set, the cipherText is assumed to have the public key X25519 prepended. 
+// This option has no effect on encryption.
+func WithPubKeyBytes(pubKeyBytes []byte) ECIESOption {
+	return func(o *eciesOptions) {
+		o.publicKeyBytes = pubKeyBytes
+	}
+}
+
+// WithPubKeyPrepended sets whether the public key is prepended to the cipherText.
+//
+// Upon encryption, if set to true (default value), the public key is prepended to the cipherText.
+// Otherwise it's not and in this case a private key should be provided with the WithPrivKey option.
+//
+// Upon decryption, if set to true (default value), the public key is expected to be prepended to the cipherText.
+// Otherwise it's not and in this case the public key bytes should be provided with the WithPubKeyBytes option.
+func WithPubKeyPrepended(prepended bool) ECIESOption {
+	return func(o *eciesOptions) {
+		o.noPubKeyPrepended = !prepended
+	}
+}
+
+// EncryptECIES encrypts plaintext using a custom Elliptic Curve Integrated Encryption Scheme (ECIES)
+// with X25519 for key agreement, HKDF for key derivation, AES for encryption, and HMAC for authentication.
+//
+// The function:
+// - Uses or generates an ephemeral X25519 key pair
+// - Performs ECDH with the provided public key
+// - Derives encryption and HMAC keys using HKDF
+// - Encrypts the plaintext using a custom AES encryption function
+// - Computes an HMAC over the ciphertext
+//
+// The default output format is: [ephemeral public key | encrypted data (including nonce) | HMAC]
+// This can be modified using options.
+// +// Parameters: +// - plainText: The message to encrypt +// - publicKey: The recipient's X25519 public key +// - opts: Optional ECIESOption functions to customize the encryption process +// +// Available options: +// - WithAAD(aad []byte): Sets the associated data for additional authentication +// - WithPrivKey(privKey *ecdh.PrivateKey): Uses the provided private key instead of generating a new one +// - WithPubKeyPrepended(prepended bool): Controls whether the public key is prepended to the ciphertext +// +// Returns: +// - Byte slice containing the encrypted message and necessary metadata for decryption +// - Error if any step of the encryption process fails +// +// Example usage: +// +// cipherText, err := EncryptECIES(plainText, recipientPublicKey, +// WithAAD(additionalData), +// WithPrivKey(senderPrivateKey), +// WithPubKeyPrepended(false)) +func EncryptECIES(plainText []byte, publicKey *ecdh.PublicKey, opts ...ECIESOption) ([]byte, error) { + options := &eciesOptions{} + for _, opt := range opts { + opt(options) + } + + ourPrivateKey := options.privateKey + if ourPrivateKey == nil { + if options.noPubKeyPrepended { + return nil, ErrNoPublicKeyForDecryption + } + var err error + ourPrivateKey, err = GenerateX25519() + if err != nil { + return nil, NewErrFailedToGenerateEphemeralKey(err) + } + } + ourPublicKey := ourPrivateKey.PublicKey() + + sharedSecret, err := ourPrivateKey.ECDH(publicKey) + if err != nil { + return nil, NewErrFailedECDHOperation(err) + } + + kdf := hkdf.New(sha256.New, sharedSecret, nil, nil) + aesKey := make([]byte, AESKeySize) + hmacKey := make([]byte, HMACSize) + if _, err := kdf.Read(aesKey); err != nil { + return nil, NewErrFailedKDFOperationForAESKey(err) + } + if _, err := kdf.Read(hmacKey); err != nil { + return nil, NewErrFailedKDFOperationForHMACKey(err) + } + + cipherText, _, err := EncryptAES(plainText, aesKey, makeAAD(ourPublicKey.Bytes(), options.associatedData), true) + if err != nil { + return nil, 
NewErrFailedToEncrypt(err) + } + + mac := hmac.New(sha256.New, hmacKey) + mac.Write(cipherText) + macSum := mac.Sum(nil) + + var result []byte + if options.noPubKeyPrepended { + result = cipherText + } else { + result = append(ourPublicKey.Bytes(), cipherText...) + } + result = append(result, macSum...) + + return result, nil +} + +// DecryptECIES decrypts ciphertext encrypted with EncryptECIES using the provided private key. +// +// The function: +// - Extracts or uses the provided ephemeral public key +// - Performs ECDH with the provided private key +// - Derives decryption and HMAC keys using HKDF +// - Verifies the HMAC +// - Decrypts the message using a custom AES decryption function +// +// The default expected input format is: [ephemeral public key | encrypted data (including nonce) | HMAC] +// This can be modified using options. +// +// Parameters: +// - cipherText: The encrypted message, including all necessary metadata +// - privateKey: The recipient's X25519 private key +// - opts: Optional ECIESOption functions to customize the decryption process +// +// Available options: +// - WithAAD(aad []byte): Sets the associated data used during encryption for additional authentication +// - WithPubKeyBytes(pubKeyBytes []byte): Provides the public key bytes if not prepended to the ciphertext +// - WithPubKeyPrepended(prepended bool): Indicates whether the public key is prepended to the ciphertext +// +// Returns: +// - Byte slice containing the decrypted plaintext +// - Error if any step of the decryption process fails, including authentication failure +// +// Example usage: +// +// plainText, err := DecryptECIES(cipherText, recipientPrivateKey, +// WithAAD(additionalData), +// WithPubKeyBytes(senderPublicKeyBytes), +// WithPubKeyPrepended(false)) +func DecryptECIES(cipherText []byte, ourPrivateKey *ecdh.PrivateKey, opts ...ECIESOption) ([]byte, error) { + options := &eciesOptions{} + for _, opt := range opts { + opt(options) + } + + minLength := 
X25519PublicKeySize + AESNonceSize + HMACSize + minCipherTextSize + if options.noPubKeyPrepended { + minLength -= X25519PublicKeySize + } + + if len(cipherText) < minLength { + return nil, ErrCipherTextTooShort + } + + publicKeyBytes := options.publicKeyBytes + if options.publicKeyBytes == nil { + if options.noPubKeyPrepended { + return nil, ErrNoPublicKeyForDecryption + } + publicKeyBytes = cipherText[:X25519PublicKeySize] + cipherText = cipherText[X25519PublicKeySize:] + } + publicKey, err := ecdh.X25519().NewPublicKey(publicKeyBytes) + if err != nil { + return nil, NewErrFailedToParseEphemeralPublicKey(err) + } + + sharedSecret, err := ourPrivateKey.ECDH(publicKey) + if err != nil { + return nil, NewErrFailedECDHOperation(err) + } + + kdf := hkdf.New(sha256.New, sharedSecret, nil, nil) + aesKey := make([]byte, AESKeySize) + hmacKey := make([]byte, HMACSize) + if _, err := kdf.Read(aesKey); err != nil { + return nil, NewErrFailedKDFOperationForAESKey(err) + } + if _, err := kdf.Read(hmacKey); err != nil { + return nil, NewErrFailedKDFOperationForHMACKey(err) + } + + macSum := cipherText[len(cipherText)-HMACSize:] + cipherTextWithNonce := cipherText[:len(cipherText)-HMACSize] + + mac := hmac.New(sha256.New, hmacKey) + mac.Write(cipherTextWithNonce) + expectedMAC := mac.Sum(nil) + if !hmac.Equal(macSum, expectedMAC) { + return nil, ErrVerificationWithHMACFailed + } + + plainText, err := DecryptAES(nil, cipherTextWithNonce, aesKey, makeAAD(publicKeyBytes, options.associatedData)) + if err != nil { + return nil, NewErrFailedToDecrypt(err) + } + + return plainText, nil +} + +// makeAAD concatenates the ephemeral public key and associated data for use as additional authenticated data. 
+func makeAAD(ephemeralPublicBytes, associatedData []byte) []byte { + l := len(ephemeralPublicBytes) + len(associatedData) + aad := make([]byte, l) + copy(aad, ephemeralPublicBytes) + copy(aad[len(ephemeralPublicBytes):], associatedData) + return aad +} diff --git a/crypto/ecies_test.go b/crypto/ecies_test.go new file mode 100644 index 0000000000..f4ed463c26 --- /dev/null +++ b/crypto/ecies_test.go @@ -0,0 +1,201 @@ +// Copyright 2024 Democratized Data Foundation +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +package crypto + +import ( + "crypto/ecdh" + "strings" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestEncryptECIES_Errors(t *testing.T) { + validAssociatedData := []byte("associated data") + validPrivateKey, _ := GenerateX25519() + + tests := []struct { + name string + plainText []byte + publicKey *ecdh.PublicKey + opts []ECIESOption + expectError string + }{ + { + name: "Invalid public key", + plainText: []byte("test data"), + publicKey: &ecdh.PublicKey{}, + opts: []ECIESOption{WithAAD(validAssociatedData)}, + expectError: errFailedECDHOperation, + }, + { + name: "No public key prepended and no private key provided", + plainText: []byte("test data"), + publicKey: validPrivateKey.PublicKey(), + opts: []ECIESOption{WithPubKeyPrepended(false)}, + expectError: errNoPublicKeyForDecryption, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, err := EncryptECIES(tt.plainText, tt.publicKey, tt.opts...) 
+ if err == nil { + t.Errorf("Expected an error, but got nil") + } else if !strings.Contains(err.Error(), tt.expectError) { + t.Errorf("Expected error containing '%s', got '%v'", tt.expectError, err) + } + }) + } +} + +func TestDecryptECIES_Errors(t *testing.T) { + validPrivateKey, _ := GenerateX25519() + aad := []byte("associated data") + validCipherText, _ := EncryptECIES([]byte("test data test data"), validPrivateKey.PublicKey(), WithAAD(aad)) + + tests := []struct { + name string + cipherText []byte + privateKey *ecdh.PrivateKey + opts []ECIESOption + expectError string + }{ + { + name: "Ciphertext too short", + cipherText: []byte("short"), + privateKey: validPrivateKey, + opts: []ECIESOption{WithAAD(aad)}, + expectError: errCipherTextTooShort, + }, + { + name: "Invalid private key", + cipherText: validCipherText, + privateKey: &ecdh.PrivateKey{}, + opts: []ECIESOption{WithAAD(aad)}, + expectError: errFailedECDHOperation, + }, + { + name: "Tampered ciphertext", + cipherText: append(validCipherText, byte(0)), + privateKey: validPrivateKey, + opts: []ECIESOption{WithAAD(aad)}, + expectError: errVerificationWithHMACFailed, + }, + { + name: "Wrong associated data", + cipherText: validCipherText, + privateKey: validPrivateKey, + opts: []ECIESOption{WithAAD([]byte("wrong data"))}, + expectError: errFailedToDecrypt, + }, + { + name: "No public key prepended and no public key bytes provided", + cipherText: validCipherText[X25519PublicKeySize:], + privateKey: validPrivateKey, + opts: []ECIESOption{WithAAD(aad), WithPubKeyPrepended(false)}, + expectError: errNoPublicKeyForDecryption, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, err := DecryptECIES(tt.cipherText, tt.privateKey, tt.opts...) 
+ if err == nil || !strings.Contains(err.Error(), tt.expectError) { + t.Errorf("Expected error containing '%s', got %v", tt.expectError, err) + } + }) + } +} + +func TestEncryptDecryptECIES_DefaultOptions_Succeeds(t *testing.T) { + plainText := []byte("Hello, World!") + recipientPrivateKey := mustGenerateX25519(t) + + cipherText, err := EncryptECIES(plainText, recipientPrivateKey.PublicKey()) + require.NoError(t, err) + + decryptedText, err := DecryptECIES(cipherText, recipientPrivateKey) + require.NoError(t, err) + + assert.Equal(t, plainText, decryptedText) +} + +func TestEncryptDecryptECIES_WithAAD_Succeeds(t *testing.T) { + plainText := []byte("Secret message") + aad := []byte("extra authentication data") + recipientPrivateKey := mustGenerateX25519(t) + + cipherText, err := EncryptECIES(plainText, recipientPrivateKey.PublicKey(), WithAAD(aad)) + require.NoError(t, err) + + decryptedText, err := DecryptECIES(cipherText, recipientPrivateKey, WithAAD(aad)) + require.NoError(t, err) + + assert.Equal(t, plainText, decryptedText) +} + +func TestEncryptDecryptECIES_WithCustomPrivateKey_Succeeds(t *testing.T) { + plainText := []byte("Custom key message") + recipientPrivateKey := mustGenerateX25519(t) + senderPrivateKey := mustGenerateX25519(t) + + cipherText, err := EncryptECIES(plainText, recipientPrivateKey.PublicKey(), WithPrivKey(senderPrivateKey)) + require.NoError(t, err) + + require.Equal(t, senderPrivateKey.PublicKey().Bytes(), cipherText[:X25519PublicKeySize]) + + decryptedText, err := DecryptECIES(cipherText, recipientPrivateKey) + require.NoError(t, err) + + assert.Equal(t, plainText, decryptedText) +} + +func TestEncryptDecryptECIES_WithoutPublicKeyPrepended_Succeeds(t *testing.T) { + plainText := []byte("No prepended key") + recipientPrivateKey := mustGenerateX25519(t) + senderPrivateKey := mustGenerateX25519(t) + + cipherText, err := EncryptECIES(plainText, recipientPrivateKey.PublicKey(), + WithPubKeyPrepended(false), + WithPrivKey(senderPrivateKey)) + 
require.NoError(t, err) + + // In a real scenario, the public key would be transmitted separately + senderPublicKeyBytes := senderPrivateKey.PublicKey().Bytes() + + decryptedText, err := DecryptECIES(cipherText, recipientPrivateKey, + WithPubKeyPrepended(false), + WithPubKeyBytes(senderPublicKeyBytes)) + require.NoError(t, err) + + assert.Equal(t, plainText, decryptedText) +} + +func TestEncryptDecryptECIES_DifferentAAD_FailsToDecrypt(t *testing.T) { + plainText := []byte("AAD test message") + encryptAAD := []byte("encryption AAD") + decryptAAD := []byte("decryption AAD") + recipientPrivateKey := mustGenerateX25519(t) + + cipherText, err := EncryptECIES(plainText, recipientPrivateKey.PublicKey(), WithAAD(encryptAAD)) + require.NoError(t, err) + + _, err = DecryptECIES(cipherText, recipientPrivateKey, WithAAD(decryptAAD)) + assert.Error(t, err, "Decryption should fail with different AAD") +} + +func mustGenerateX25519(t *testing.T) *ecdh.PrivateKey { + key, err := GenerateX25519() + require.NoError(t, err) + return key +} diff --git a/crypto/errors.go b/crypto/errors.go new file mode 100644 index 0000000000..a6128f9860 --- /dev/null +++ b/crypto/errors.go @@ -0,0 +1,62 @@ +// Copyright 2024 Democratized Data Foundation +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. 
+ +package crypto + +import ( + "github.com/sourcenetwork/defradb/errors" +) + +const ( + errFailedToGenerateEphemeralKey string = "failed to generate ephemeral key" + errFailedECDHOperation string = "failed ECDH operation" + errFailedKDFOperationForAESKey string = "failed KDF operation for AES key" + errFailedKDFOperationForHMACKey string = "failed KDF operation for HMAC key" + errFailedToEncrypt string = "failed to encrypt" + errCipherTextTooShort string = "cipherText too short" + errFailedToParseEphemeralPublicKey string = "failed to parse ephemeral public key" + errVerificationWithHMACFailed string = "verification with HMAC failed" + errFailedToDecrypt string = "failed to decrypt" + errNoPublicKeyForDecryption string = "no public key provided for decryption" +) + +var ( + ErrCipherTextTooShort = errors.New(errCipherTextTooShort) + ErrVerificationWithHMACFailed = errors.New(errVerificationWithHMACFailed) + ErrNoPublicKeyForDecryption = errors.New(errNoPublicKeyForDecryption) +) + +func NewErrFailedToGenerateEphemeralKey(inner error) error { + return errors.Wrap(errFailedToGenerateEphemeralKey, inner) +} + +func NewErrFailedECDHOperation(inner error) error { + return errors.Wrap(errFailedECDHOperation, inner) +} + +func NewErrFailedKDFOperationForAESKey(inner error) error { + return errors.Wrap(errFailedKDFOperationForAESKey, inner) +} + +func NewErrFailedKDFOperationForHMACKey(inner error) error { + return errors.Wrap(errFailedKDFOperationForHMACKey, inner) +} + +func NewErrFailedToEncrypt(inner error) error { + return errors.Wrap(errFailedToEncrypt, inner) +} + +func NewErrFailedToParseEphemeralPublicKey(inner error) error { + return errors.Wrap(errFailedToParseEphemeralPublicKey, inner) +} + +func NewErrFailedToDecrypt(inner error) error { + return errors.Wrap(errFailedToDecrypt, inner) +} diff --git a/internal/encryption/nonce.go b/crypto/nonce.go similarity index 86% rename from internal/encryption/nonce.go rename to crypto/nonce.go index 
67a5467a4e..9c8f00b31f 100644 --- a/internal/encryption/nonce.go +++ b/crypto/nonce.go @@ -8,7 +8,7 @@ // by the Apache License, Version 2.0, included in the file // licenses/APL.txt. -package encryption +package crypto import ( "crypto/rand" @@ -18,12 +18,12 @@ import ( "strings" ) -const nonceLength = 12 +const AESNonceSize = 12 var generateNonceFunc = generateNonce func generateNonce() ([]byte, error) { - nonce := make([]byte, nonceLength) + nonce := make([]byte, AESNonceSize) if _, err := io.ReadFull(rand.Reader, nonce); err != nil { return nil, err } @@ -35,11 +35,11 @@ func generateNonce() ([]byte, error) { func generateTestNonce() ([]byte, error) { nonce := []byte("deterministic nonce for testing") - if len(nonce) < nonceLength { + if len(nonce) < AESNonceSize { return nil, errors.New("nonce length is longer than available deterministic nonce") } - return nonce[:nonceLength], nil + return nonce[:AESNonceSize], nil } func init() { @@ -48,6 +48,5 @@ func init() { // TODO: We should try to find a better way to detect this https://github.com/sourcenetwork/defradb/issues/2801 if strings.HasSuffix(arg, ".test") || strings.Contains(arg, "/defradb/tests/") { generateNonceFunc = generateTestNonce - generateEncryptionKeyFunc = generateTestEncryptionKey } } diff --git a/datastore/mocks/blockstore.go b/datastore/mocks/blockstore.go new file mode 100644 index 0000000000..6dab79de7c --- /dev/null +++ b/datastore/mocks/blockstore.go @@ -0,0 +1,493 @@ +// Code generated by mockery. DO NOT EDIT. 
+ +package mocks + +import ( + blocks "github.com/ipfs/go-block-format" + cid "github.com/ipfs/go-cid" + + context "context" + + datastore "github.com/sourcenetwork/defradb/datastore" + + mock "github.com/stretchr/testify/mock" +) + +// Blockstore is an autogenerated mock type for the Blockstore type +type Blockstore struct { + mock.Mock +} + +type Blockstore_Expecter struct { + mock *mock.Mock +} + +func (_m *Blockstore) EXPECT() *Blockstore_Expecter { + return &Blockstore_Expecter{mock: &_m.Mock} +} + +// AllKeysChan provides a mock function with given fields: ctx +func (_m *Blockstore) AllKeysChan(ctx context.Context) (<-chan cid.Cid, error) { + ret := _m.Called(ctx) + + if len(ret) == 0 { + panic("no return value specified for AllKeysChan") + } + + var r0 <-chan cid.Cid + var r1 error + if rf, ok := ret.Get(0).(func(context.Context) (<-chan cid.Cid, error)); ok { + return rf(ctx) + } + if rf, ok := ret.Get(0).(func(context.Context) <-chan cid.Cid); ok { + r0 = rf(ctx) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(<-chan cid.Cid) + } + } + + if rf, ok := ret.Get(1).(func(context.Context) error); ok { + r1 = rf(ctx) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// Blockstore_AllKeysChan_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'AllKeysChan' +type Blockstore_AllKeysChan_Call struct { + *mock.Call +} + +// AllKeysChan is a helper method to define mock.On call +// - ctx context.Context +func (_e *Blockstore_Expecter) AllKeysChan(ctx interface{}) *Blockstore_AllKeysChan_Call { + return &Blockstore_AllKeysChan_Call{Call: _e.mock.On("AllKeysChan", ctx)} +} + +func (_c *Blockstore_AllKeysChan_Call) Run(run func(ctx context.Context)) *Blockstore_AllKeysChan_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context)) + }) + return _c +} + +func (_c *Blockstore_AllKeysChan_Call) Return(_a0 <-chan cid.Cid, _a1 error) *Blockstore_AllKeysChan_Call { + _c.Call.Return(_a0, _a1) + 
return _c +} + +func (_c *Blockstore_AllKeysChan_Call) RunAndReturn(run func(context.Context) (<-chan cid.Cid, error)) *Blockstore_AllKeysChan_Call { + _c.Call.Return(run) + return _c +} + +// AsIPLDStorage provides a mock function with given fields: +func (_m *Blockstore) AsIPLDStorage() datastore.IPLDStorage { + ret := _m.Called() + + if len(ret) == 0 { + panic("no return value specified for AsIPLDStorage") + } + + var r0 datastore.IPLDStorage + if rf, ok := ret.Get(0).(func() datastore.IPLDStorage); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(datastore.IPLDStorage) + } + } + + return r0 +} + +// Blockstore_AsIPLDStorage_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'AsIPLDStorage' +type Blockstore_AsIPLDStorage_Call struct { + *mock.Call +} + +// AsIPLDStorage is a helper method to define mock.On call +func (_e *Blockstore_Expecter) AsIPLDStorage() *Blockstore_AsIPLDStorage_Call { + return &Blockstore_AsIPLDStorage_Call{Call: _e.mock.On("AsIPLDStorage")} +} + +func (_c *Blockstore_AsIPLDStorage_Call) Run(run func()) *Blockstore_AsIPLDStorage_Call { + _c.Call.Run(func(args mock.Arguments) { + run() + }) + return _c +} + +func (_c *Blockstore_AsIPLDStorage_Call) Return(_a0 datastore.IPLDStorage) *Blockstore_AsIPLDStorage_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *Blockstore_AsIPLDStorage_Call) RunAndReturn(run func() datastore.IPLDStorage) *Blockstore_AsIPLDStorage_Call { + _c.Call.Return(run) + return _c +} + +// DeleteBlock provides a mock function with given fields: _a0, _a1 +func (_m *Blockstore) DeleteBlock(_a0 context.Context, _a1 cid.Cid) error { + ret := _m.Called(_a0, _a1) + + if len(ret) == 0 { + panic("no return value specified for DeleteBlock") + } + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, cid.Cid) error); ok { + r0 = rf(_a0, _a1) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// Blockstore_DeleteBlock_Call is a *mock.Call that shadows 
Run/Return methods with type explicit version for method 'DeleteBlock' +type Blockstore_DeleteBlock_Call struct { + *mock.Call +} + +// DeleteBlock is a helper method to define mock.On call +// - _a0 context.Context +// - _a1 cid.Cid +func (_e *Blockstore_Expecter) DeleteBlock(_a0 interface{}, _a1 interface{}) *Blockstore_DeleteBlock_Call { + return &Blockstore_DeleteBlock_Call{Call: _e.mock.On("DeleteBlock", _a0, _a1)} +} + +func (_c *Blockstore_DeleteBlock_Call) Run(run func(_a0 context.Context, _a1 cid.Cid)) *Blockstore_DeleteBlock_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(cid.Cid)) + }) + return _c +} + +func (_c *Blockstore_DeleteBlock_Call) Return(_a0 error) *Blockstore_DeleteBlock_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *Blockstore_DeleteBlock_Call) RunAndReturn(run func(context.Context, cid.Cid) error) *Blockstore_DeleteBlock_Call { + _c.Call.Return(run) + return _c +} + +// Get provides a mock function with given fields: _a0, _a1 +func (_m *Blockstore) Get(_a0 context.Context, _a1 cid.Cid) (blocks.Block, error) { + ret := _m.Called(_a0, _a1) + + if len(ret) == 0 { + panic("no return value specified for Get") + } + + var r0 blocks.Block + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, cid.Cid) (blocks.Block, error)); ok { + return rf(_a0, _a1) + } + if rf, ok := ret.Get(0).(func(context.Context, cid.Cid) blocks.Block); ok { + r0 = rf(_a0, _a1) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(blocks.Block) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, cid.Cid) error); ok { + r1 = rf(_a0, _a1) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// Blockstore_Get_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Get' +type Blockstore_Get_Call struct { + *mock.Call +} + +// Get is a helper method to define mock.On call +// - _a0 context.Context +// - _a1 cid.Cid +func (_e *Blockstore_Expecter) Get(_a0 interface{}, 
_a1 interface{}) *Blockstore_Get_Call { + return &Blockstore_Get_Call{Call: _e.mock.On("Get", _a0, _a1)} +} + +func (_c *Blockstore_Get_Call) Run(run func(_a0 context.Context, _a1 cid.Cid)) *Blockstore_Get_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(cid.Cid)) + }) + return _c +} + +func (_c *Blockstore_Get_Call) Return(_a0 blocks.Block, _a1 error) *Blockstore_Get_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *Blockstore_Get_Call) RunAndReturn(run func(context.Context, cid.Cid) (blocks.Block, error)) *Blockstore_Get_Call { + _c.Call.Return(run) + return _c +} + +// GetSize provides a mock function with given fields: _a0, _a1 +func (_m *Blockstore) GetSize(_a0 context.Context, _a1 cid.Cid) (int, error) { + ret := _m.Called(_a0, _a1) + + if len(ret) == 0 { + panic("no return value specified for GetSize") + } + + var r0 int + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, cid.Cid) (int, error)); ok { + return rf(_a0, _a1) + } + if rf, ok := ret.Get(0).(func(context.Context, cid.Cid) int); ok { + r0 = rf(_a0, _a1) + } else { + r0 = ret.Get(0).(int) + } + + if rf, ok := ret.Get(1).(func(context.Context, cid.Cid) error); ok { + r1 = rf(_a0, _a1) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// Blockstore_GetSize_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetSize' +type Blockstore_GetSize_Call struct { + *mock.Call +} + +// GetSize is a helper method to define mock.On call +// - _a0 context.Context +// - _a1 cid.Cid +func (_e *Blockstore_Expecter) GetSize(_a0 interface{}, _a1 interface{}) *Blockstore_GetSize_Call { + return &Blockstore_GetSize_Call{Call: _e.mock.On("GetSize", _a0, _a1)} +} + +func (_c *Blockstore_GetSize_Call) Run(run func(_a0 context.Context, _a1 cid.Cid)) *Blockstore_GetSize_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(cid.Cid)) + }) + return _c +} + +func (_c 
*Blockstore_GetSize_Call) Return(_a0 int, _a1 error) *Blockstore_GetSize_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *Blockstore_GetSize_Call) RunAndReturn(run func(context.Context, cid.Cid) (int, error)) *Blockstore_GetSize_Call { + _c.Call.Return(run) + return _c +} + +// Has provides a mock function with given fields: _a0, _a1 +func (_m *Blockstore) Has(_a0 context.Context, _a1 cid.Cid) (bool, error) { + ret := _m.Called(_a0, _a1) + + if len(ret) == 0 { + panic("no return value specified for Has") + } + + var r0 bool + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, cid.Cid) (bool, error)); ok { + return rf(_a0, _a1) + } + if rf, ok := ret.Get(0).(func(context.Context, cid.Cid) bool); ok { + r0 = rf(_a0, _a1) + } else { + r0 = ret.Get(0).(bool) + } + + if rf, ok := ret.Get(1).(func(context.Context, cid.Cid) error); ok { + r1 = rf(_a0, _a1) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// Blockstore_Has_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Has' +type Blockstore_Has_Call struct { + *mock.Call +} + +// Has is a helper method to define mock.On call +// - _a0 context.Context +// - _a1 cid.Cid +func (_e *Blockstore_Expecter) Has(_a0 interface{}, _a1 interface{}) *Blockstore_Has_Call { + return &Blockstore_Has_Call{Call: _e.mock.On("Has", _a0, _a1)} +} + +func (_c *Blockstore_Has_Call) Run(run func(_a0 context.Context, _a1 cid.Cid)) *Blockstore_Has_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(cid.Cid)) + }) + return _c +} + +func (_c *Blockstore_Has_Call) Return(_a0 bool, _a1 error) *Blockstore_Has_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *Blockstore_Has_Call) RunAndReturn(run func(context.Context, cid.Cid) (bool, error)) *Blockstore_Has_Call { + _c.Call.Return(run) + return _c +} + +// HashOnRead provides a mock function with given fields: enabled +func (_m *Blockstore) HashOnRead(enabled bool) { + 
_m.Called(enabled) +} + +// Blockstore_HashOnRead_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'HashOnRead' +type Blockstore_HashOnRead_Call struct { + *mock.Call +} + +// HashOnRead is a helper method to define mock.On call +// - enabled bool +func (_e *Blockstore_Expecter) HashOnRead(enabled interface{}) *Blockstore_HashOnRead_Call { + return &Blockstore_HashOnRead_Call{Call: _e.mock.On("HashOnRead", enabled)} +} + +func (_c *Blockstore_HashOnRead_Call) Run(run func(enabled bool)) *Blockstore_HashOnRead_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(bool)) + }) + return _c +} + +func (_c *Blockstore_HashOnRead_Call) Return() *Blockstore_HashOnRead_Call { + _c.Call.Return() + return _c +} + +func (_c *Blockstore_HashOnRead_Call) RunAndReturn(run func(bool)) *Blockstore_HashOnRead_Call { + _c.Call.Return(run) + return _c +} + +// Put provides a mock function with given fields: _a0, _a1 +func (_m *Blockstore) Put(_a0 context.Context, _a1 blocks.Block) error { + ret := _m.Called(_a0, _a1) + + if len(ret) == 0 { + panic("no return value specified for Put") + } + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, blocks.Block) error); ok { + r0 = rf(_a0, _a1) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// Blockstore_Put_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Put' +type Blockstore_Put_Call struct { + *mock.Call +} + +// Put is a helper method to define mock.On call +// - _a0 context.Context +// - _a1 blocks.Block +func (_e *Blockstore_Expecter) Put(_a0 interface{}, _a1 interface{}) *Blockstore_Put_Call { + return &Blockstore_Put_Call{Call: _e.mock.On("Put", _a0, _a1)} +} + +func (_c *Blockstore_Put_Call) Run(run func(_a0 context.Context, _a1 blocks.Block)) *Blockstore_Put_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(blocks.Block)) + }) + return _c +} + +func (_c *Blockstore_Put_Call) 
Return(_a0 error) *Blockstore_Put_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *Blockstore_Put_Call) RunAndReturn(run func(context.Context, blocks.Block) error) *Blockstore_Put_Call { + _c.Call.Return(run) + return _c +} + +// PutMany provides a mock function with given fields: _a0, _a1 +func (_m *Blockstore) PutMany(_a0 context.Context, _a1 []blocks.Block) error { + ret := _m.Called(_a0, _a1) + + if len(ret) == 0 { + panic("no return value specified for PutMany") + } + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, []blocks.Block) error); ok { + r0 = rf(_a0, _a1) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// Blockstore_PutMany_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'PutMany' +type Blockstore_PutMany_Call struct { + *mock.Call +} + +// PutMany is a helper method to define mock.On call +// - _a0 context.Context +// - _a1 []blocks.Block +func (_e *Blockstore_Expecter) PutMany(_a0 interface{}, _a1 interface{}) *Blockstore_PutMany_Call { + return &Blockstore_PutMany_Call{Call: _e.mock.On("PutMany", _a0, _a1)} +} + +func (_c *Blockstore_PutMany_Call) Run(run func(_a0 context.Context, _a1 []blocks.Block)) *Blockstore_PutMany_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].([]blocks.Block)) + }) + return _c +} + +func (_c *Blockstore_PutMany_Call) Return(_a0 error) *Blockstore_PutMany_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *Blockstore_PutMany_Call) RunAndReturn(run func(context.Context, []blocks.Block) error) *Blockstore_PutMany_Call { + _c.Call.Return(run) + return _c +} + +// NewBlockstore creates a new instance of Blockstore. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. 
+func NewBlockstore(t interface { + mock.TestingT + Cleanup(func()) +}) *Blockstore { + mock := &Blockstore{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/datastore/mocks/txn.go b/datastore/mocks/txn.go index 41606260ea..ea923d5de4 100644 --- a/datastore/mocks/txn.go +++ b/datastore/mocks/txn.go @@ -196,19 +196,19 @@ func (_c *Txn_Discard_Call) RunAndReturn(run func(context.Context)) *Txn_Discard } // Encstore provides a mock function with given fields: -func (_m *Txn) Encstore() datastore.DSReaderWriter { +func (_m *Txn) Encstore() datastore.Blockstore { ret := _m.Called() if len(ret) == 0 { panic("no return value specified for Encstore") } - var r0 datastore.DSReaderWriter - if rf, ok := ret.Get(0).(func() datastore.DSReaderWriter); ok { + var r0 datastore.Blockstore + if rf, ok := ret.Get(0).(func() datastore.Blockstore); ok { r0 = rf() } else { if ret.Get(0) != nil { - r0 = ret.Get(0).(datastore.DSReaderWriter) + r0 = ret.Get(0).(datastore.Blockstore) } } @@ -232,12 +232,12 @@ func (_c *Txn_Encstore_Call) Run(run func()) *Txn_Encstore_Call { return _c } -func (_c *Txn_Encstore_Call) Return(_a0 datastore.DSReaderWriter) *Txn_Encstore_Call { +func (_c *Txn_Encstore_Call) Return(_a0 datastore.Blockstore) *Txn_Encstore_Call { _c.Call.Return(_a0) return _c } -func (_c *Txn_Encstore_Call) RunAndReturn(run func() datastore.DSReaderWriter) *Txn_Encstore_Call { +func (_c *Txn_Encstore_Call) RunAndReturn(run func() datastore.Blockstore) *Txn_Encstore_Call { _c.Call.Return(run) return _c } diff --git a/datastore/mocks/utils.go b/datastore/mocks/utils.go index af3c49fd0c..d6c69684be 100644 --- a/datastore/mocks/utils.go +++ b/datastore/mocks/utils.go @@ -24,6 +24,7 @@ type MultiStoreTxn struct { MockRootstore *DSReaderWriter MockDatastore *DSReaderWriter MockHeadstore *DSReaderWriter + MockEncstore *Blockstore MockDAGstore *DAGStore MockSystemstore *DSReaderWriter } @@ -36,6 +37,14 @@ func prepareDataStore(t 
*testing.T) *DSReaderWriter { return dataStore } +func prepareEncStore(t *testing.T) *Blockstore { + encStore := NewBlockstore(t) + encStore.EXPECT().Get(mock.Anything, mock.Anything).Return(nil, ds.ErrNotFound).Maybe() + encStore.EXPECT().Put(mock.Anything, mock.Anything).Return(nil).Maybe() + encStore.EXPECT().Has(mock.Anything, mock.Anything).Return(true, nil).Maybe() + return encStore +} + func prepareRootstore(t *testing.T) *DSReaderWriter { return NewDSReaderWriter(t) } @@ -75,6 +84,7 @@ func NewTxnWithMultistore(t *testing.T) *MultiStoreTxn { t: t, MockRootstore: prepareRootstore(t), MockDatastore: prepareDataStore(t), + MockEncstore: prepareEncStore(t), MockHeadstore: prepareHeadStore(t), MockDAGstore: prepareDAGStore(t), MockSystemstore: prepareSystemStore(t), @@ -82,6 +92,7 @@ func NewTxnWithMultistore(t *testing.T) *MultiStoreTxn { txn.EXPECT().Rootstore().Return(result.MockRootstore).Maybe() txn.EXPECT().Datastore().Return(result.MockDatastore).Maybe() + txn.EXPECT().Encstore().Return(result.MockEncstore).Maybe() txn.EXPECT().Headstore().Return(result.MockHeadstore).Maybe() txn.EXPECT().Blockstore().Return(result.MockDAGstore).Maybe() txn.EXPECT().Systemstore().Return(result.MockSystemstore).Maybe() diff --git a/datastore/multi.go b/datastore/multi.go index f863924d5d..cbbf80e23f 100644 --- a/datastore/multi.go +++ b/datastore/multi.go @@ -29,12 +29,11 @@ var ( type multistore struct { root DSReaderWriter data DSReaderWriter - enc DSReaderWriter + enc Blockstore head DSReaderWriter peer DSBatching system DSReaderWriter - // block DSReaderWriter - dag Blockstore + dag Blockstore } var _ MultiStore = (*multistore)(nil) @@ -45,7 +44,7 @@ func MultiStoreFrom(rootstore ds.Datastore) MultiStore { ms := &multistore{ root: rootRW, data: prefix(rootRW, dataStoreKey), - enc: prefix(rootRW, encStoreKey), + enc: newBlockstore(prefix(rootRW, encStoreKey)), head: prefix(rootRW, headStoreKey), peer: namespace.Wrap(rootstore, peerStoreKey), system: prefix(rootRW, 
systemStoreKey), @@ -61,7 +60,7 @@ func (ms multistore) Datastore() DSReaderWriter { } // Encstore implements MultiStore. -func (ms multistore) Encstore() DSReaderWriter { +func (ms multistore) Encstore() Blockstore { return ms.enc } diff --git a/datastore/store.go b/datastore/store.go index 516bfe0b65..641cd10b1a 100644 --- a/datastore/store.go +++ b/datastore/store.go @@ -40,7 +40,7 @@ type MultiStore interface { // Encstore is a wrapped root DSReaderWriter under the /enc namespace // This store is used for storing symmetric encryption keys for doc encryption. // The store keys are comprised of docID + field name. - Encstore() DSReaderWriter + Encstore() Blockstore // Headstore is a wrapped root DSReaderWriter under the /head namespace Headstore() DSReaderWriter diff --git a/docs/data_format_changes/i2891-no-change-tests-updated.md b/docs/data_format_changes/i2891-no-change-tests-updated.md new file mode 100644 index 0000000000..8d22b94c15 --- /dev/null +++ b/docs/data_format_changes/i2891-no-change-tests-updated.md @@ -0,0 +1,3 @@ +# Doc encryption key exchange + +For the key exchange mechanism we changed slightly the structure of DAG block to hold an additional information. 
diff --git a/go.mod b/go.mod index 4be484b96e..4788e9667c 100644 --- a/go.mod +++ b/go.mod @@ -44,6 +44,7 @@ require ( github.com/multiformats/go-multihash v0.2.3 github.com/pelletier/go-toml v1.9.5 github.com/pkg/errors v0.9.1 + github.com/planetscale/vtprotobuf v0.6.1-0.20240319094008-0393e58bdf10 github.com/sourcenetwork/acp_core v0.0.0-20240607160510-47a5306b2ad2 github.com/sourcenetwork/badger/v4 v4.2.1-0.20231113215945-a63444ca5276 github.com/sourcenetwork/corelog v0.0.8 @@ -62,6 +63,7 @@ require ( go.opentelemetry.io/otel/metric v1.30.0 go.opentelemetry.io/otel/sdk/metric v1.30.0 go.uber.org/zap v1.27.0 + golang.org/x/crypto v0.26.0 golang.org/x/exp v0.0.0-20240808152545-0cdaa3abc0fa google.golang.org/grpc v1.66.2 google.golang.org/protobuf v1.34.2 @@ -86,7 +88,7 @@ require ( cosmossdk.io/x/feegrant v0.1.0 // indirect cosmossdk.io/x/tx v0.13.4 // indirect cosmossdk.io/x/upgrade v0.1.1 // indirect - filippo.io/edwards25519 v1.0.0 // indirect + filippo.io/edwards25519 v1.1.0 // indirect github.com/99designs/go-keychain v0.0.0-20191008050251-8e49817e8af4 // indirect github.com/99designs/keyring v1.2.2 // indirect github.com/DataDog/datadog-go v3.2.0+incompatible // indirect @@ -359,7 +361,6 @@ require ( go.uber.org/fx v1.22.2 // indirect go.uber.org/mock v0.4.0 // indirect go.uber.org/multierr v1.11.0 // indirect - golang.org/x/crypto v0.26.0 // indirect golang.org/x/mod v0.20.0 // indirect golang.org/x/net v0.28.0 // indirect golang.org/x/oauth2 v0.21.0 // indirect diff --git a/go.sum b/go.sum index eaf23755d5..374d490ecf 100644 --- a/go.sum +++ b/go.sum @@ -219,8 +219,8 @@ dmitri.shuralyov.com/gpu/mtl v0.0.0-20190408044501-666a987793e9/go.mod h1:H6x//7 dmitri.shuralyov.com/html/belt v0.0.0-20180602232347-f7d459c86be0/go.mod h1:JLBrvjyP0v+ecvNYvCpyZgu5/xkfAUhi6wJj28eUfSU= dmitri.shuralyov.com/service/change v0.0.0-20181023043359-a85b471d5412/go.mod h1:a1inKt/atXimZ4Mv927x+r7UpyzRUf4emIoiiSC2TN4= dmitri.shuralyov.com/state 
v0.0.0-20180228185332-28bcc343414c/go.mod h1:0PRwlb0D6DFvNNtx+9ybjezNCa8XF0xaYcETyp6rHWU= -filippo.io/edwards25519 v1.0.0 h1:0wAIcmJUqRdI8IJ/3eGi5/HwXZWPujYXXlkrQogz0Ek= -filippo.io/edwards25519 v1.0.0/go.mod h1:N1IkdkCkiLB6tki+MYJoSx2JTY9NUlxZE7eHn5EwJns= +filippo.io/edwards25519 v1.1.0 h1:FNf4tywRC1HmFuKW5xopWpigGjJKiJSV0Cqo0cJWDaA= +filippo.io/edwards25519 v1.1.0/go.mod h1:BxyFTGdWcka3PhytdK4V28tE5sGfRvvvRV7EaN4VDT4= git.apache.org/thrift.git v0.0.0-20180902110319-2566ecd5d999/go.mod h1:fPE2ZNJGynbRyZ4dJvy6G277gSllfV2HJqblrnkyeyg= github.com/99designs/go-keychain v0.0.0-20191008050251-8e49817e8af4 h1:/vQbFIOMbk2FiG/kXiLl8BRyzTWDw7gX/Hz7Dd5eDMs= github.com/99designs/go-keychain v0.0.0-20191008050251-8e49817e8af4/go.mod h1:hN7oaIRCjzsZ2dE+yG5k+rsdt3qcwykqK6HVGcKwsw4= @@ -1265,6 +1265,8 @@ github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINE github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pkg/profile v1.2.1/go.mod h1:hJw3o1OdXxsrSjjVksARp5W95eeEaEfptyVZyv6JUPA= +github.com/planetscale/vtprotobuf v0.6.1-0.20240319094008-0393e58bdf10 h1:GFCKgmp0tecUJ0sJuv4pzYCqS9+RGSn52M3FUwPs+uo= +github.com/planetscale/vtprotobuf v0.6.1-0.20240319094008-0393e58bdf10/go.mod h1:t/avpk3KcrXxUnYOhZhMXJlSEyie6gQbtLq5NM3loB8= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U= github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= diff --git a/http/client.go b/http/client.go index e98eac7d07..777cf4a733 100644 --- a/http/client.go +++ b/http/client.go @@ -489,6 +489,10 @@ func (c *Client) Blockstore() datastore.Blockstore { panic("client side database") } +func (c *Client) Encstore() datastore.Blockstore { + 
panic("client side database") +} + func (c *Client) Peerstore() datastore.DSBatching { panic("client side database") } diff --git a/http/client_tx.go b/http/client_tx.go index 5b99f5aaad..daacb4128e 100644 --- a/http/client_tx.go +++ b/http/client_tx.go @@ -91,7 +91,7 @@ func (c *Transaction) Datastore() datastore.DSReaderWriter { panic("client side transaction") } -func (c *Transaction) Encstore() datastore.DSReaderWriter { +func (c *Transaction) Encstore() datastore.Blockstore { panic("client side transaction") } diff --git a/internal/core/block/block.go b/internal/core/block/block.go index d2caa610f7..1ec62fe939 100644 --- a/internal/core/block/block.go +++ b/internal/core/block/block.go @@ -29,12 +29,15 @@ import ( // Schema is the IPLD schema type that represents a `Block`. var ( - Schema schema.Type - SchemaPrototype ipld.NodePrototype + Schema schema.Type + SchemaPrototype ipld.NodePrototype + EncryptionSchema schema.Type + EncryptionSchemaPrototype ipld.NodePrototype ) func init() { Schema, SchemaPrototype = mustSetSchema( + "Block", &Block{}, &DAGLink{}, &crdt.CRDT{}, @@ -42,6 +45,11 @@ func init() { &crdt.CompositeDAGDelta{}, &crdt.CounterDelta{}, ) + + EncryptionSchema, EncryptionSchemaPrototype = mustSetSchema( + "Encryption", + &Encryption{}, + ) } type schemaDefinition interface { @@ -49,7 +57,7 @@ type schemaDefinition interface { IPLDSchemaBytes() []byte } -func mustSetSchema(schemas ...schemaDefinition) (schema.Type, ipld.NodePrototype) { +func mustSetSchema(schemaName string, schemas ...schemaDefinition) (schema.Type, ipld.NodePrototype) { schemaBytes := make([][]byte, 0, len(schemas)) for _, s := range schemas { schemaBytes = append(schemaBytes, s.IPLDSchemaBytes()) @@ -59,12 +67,12 @@ func mustSetSchema(schemas ...schemaDefinition) (schema.Type, ipld.NodePrototype if err != nil { panic(err) } - blockSchemaType := ts.TypeByName("Block") + blockSchemaType := ts.TypeByName(schemaName) // Calling bindnode.Prototype here ensure that [Block] and all 
the types it contains // are compatible with the IPLD schema defined by blockSchemaType. // If [Block] and `blockSchematype` do not match, this will panic. - proto := bindnode.Prototype(&Block{}, blockSchemaType) + proto := bindnode.Prototype(schemas[0], blockSchemaType) return blockSchemaType, proto.Representation() } @@ -97,27 +105,78 @@ func NewDAGLink(name string, link cidlink.Link) DAGLink { } } +// Encryption contains the encryption information for the block's delta. +type Encryption struct { + // DocID is the ID of the document that is encrypted with the associated encryption key. + DocID []byte + // FieldName is the name of the field that is encrypted with the associated encryption key. + // It is set if encryption is applied to a field instead of the whole doc. + // It needs to be a pointer so that it can be translated from and to `optional` in the IPLD schema. + FieldName *string + // Encryption key. + Key []byte +} + // Block is a block that contains a CRDT delta and links to other blocks. type Block struct { // Delta is the CRDT delta that is stored in the block. Delta crdt.CRDT // Links are the links to other blocks in the DAG. Links []DAGLink - // IsEncrypted is a flag that indicates if the block's delta is encrypted. - // It needs to be a pointer so that it can be translated from and to `optional Bool` in the IPLD schema. - IsEncrypted *bool + // Encryption contains the encryption information for the block's delta. + // It needs to be a pointer so that it can be translated from and to `optional` in the IPLD schema. + Encryption *cidlink.Link +} + +// IsEncrypted returns true if the block is encrypted. +func (block *Block) IsEncrypted() bool { + return block.Encryption != nil +} + +// Clone returns a shallow copy of the block with cloned delta. +func (block *Block) Clone() *Block { + return &Block{ + Delta: block.Delta.Clone(), + Links: block.Links, + Encryption: block.Encryption, + } +} + +// GetHeadLinks returns the CIDs of the previous blocks. 
There can be more than 1 with multiple heads. +func (block *Block) GetHeadLinks() []cid.Cid { + var heads []cid.Cid + for _, link := range block.Links { + if link.Name == core.HEAD { + heads = append(heads, link.Cid) + } + } + return heads } // IPLDSchemaBytes returns the IPLD schema representation for the block. // // This needs to match the [Block] struct or [mustSetSchema] will panic on init. -func (b Block) IPLDSchemaBytes() []byte { +func (block *Block) IPLDSchemaBytes() []byte { return []byte(` - type Block struct { - delta CRDT - links [ DAGLink ] - isEncrypted optional Bool - }`) + type Block struct { + delta CRDT + links [DAGLink] + encryption optional Link + } + `) +} + +// IPLDSchemaBytes returns the IPLD schema representation for the encryption block. +// +// This needs to match the [Encryption] struct or [mustSetSchema] will panic on init. +func (enc *Encryption) IPLDSchemaBytes() []byte { + return []byte(` + type Encryption struct { + docID Bytes + fieldName optional String + key Bytes + } + `) } // New creates a new block with the given delta and links. @@ -153,6 +212,16 @@ func New(delta core.Delta, links []DAGLink, heads ...cid.Cid) *Block { } } +// GetFromBytes returns a block from encoded bytes. +func GetEncryptionBlockFromBytes(b []byte) (*Encryption, error) { + enc := &Encryption{} + err := enc.Unmarshal(b) + if err != nil { + return nil, err + } + return enc, nil +} + // GetFromBytes returns a block from encoded bytes. func GetFromBytes(b []byte) (*Block, error) { block := &Block{} @@ -172,8 +241,17 @@ func GetFromNode(node ipld.Node) (*Block, error) { return block, nil } +// GetFromNode returns a block from a node. +func GetEncryptionBlockFromNode(node ipld.Node) (*Encryption, error) { + encBlock, ok := bindnode.Unwrap(node).(*Encryption) + if !ok { + return nil, NewErrNodeToBlock(node) + } + return encBlock, nil +} + // Marshal encodes the delta using CBOR encoding. 
-func (block *Block) Marshal() (data []byte, err error) { +func (block *Block) Marshal() ([]byte, error) { b, err := ipld.Marshal(dagcbor.Encode, block, Schema) if err != nil { return nil, NewErrEncodingBlock(err) @@ -183,12 +261,25 @@ func (block *Block) Marshal() (data []byte, err error) { // Unmarshal decodes the delta from CBOR encoding. func (block *Block) Unmarshal(b []byte) error { - _, err := ipld.Unmarshal( - b, - dagcbor.Decode, - block, - Schema, - ) + _, err := ipld.Unmarshal(b, dagcbor.Decode, block, Schema) + if err != nil { + return NewErrUnmarshallingBlock(err) + } + return nil +} + +// Marshal encodes the delta using CBOR encoding. +func (enc *Encryption) Marshal() ([]byte, error) { + b, err := ipld.Marshal(dagcbor.Encode, enc, EncryptionSchema) + if err != nil { + return nil, NewErrEncodingBlock(err) + } + return b, nil +} + +// Unmarshal decodes the delta from CBOR encoding. +func (enc *Encryption) Unmarshal(b []byte) error { + _, err := ipld.Unmarshal(b, dagcbor.Decode, enc, EncryptionSchema) if err != nil { return NewErrUnmarshallingBlock(err) } @@ -196,10 +287,15 @@ func (block *Block) Unmarshal(b []byte) error { } // GenerateNode generates an IPLD node from the block in its representation form. -func (block *Block) GenerateNode() (node ipld.Node) { +func (block *Block) GenerateNode() ipld.Node { return bindnode.Wrap(block, Schema).Representation() } +// GenerateNode generates an IPLD node from the encryption block in its representation form. +func (enc *Encryption) GenerateNode() ipld.Node { + return bindnode.Wrap(enc, EncryptionSchema).Representation() +} + // GetLinkByName returns the link by name. It will return false if the link does not exist. 
func (block *Block) GetLinkByName(name string) (cidlink.Link, bool) { for _, link := range block.Links { diff --git a/internal/core/block/block_test.go b/internal/core/block/block_test.go index 5b68cf9067..d7fe2d1bf0 100644 --- a/internal/core/block/block_test.go +++ b/internal/core/block/block_test.go @@ -180,13 +180,23 @@ func TestBlockDeltaPriority(t *testing.T) { require.Equal(t, uint64(2), block.Delta.GetPriority()) } -func TestBlockMarshal_IsEncryptedNotSet_ShouldNotContainIsEcryptedField(t *testing.T) { +func TestBlockMarshal_IfEncryptedNotSet_ShouldNotContainIsEncryptedField(t *testing.T) { lsys := cidlink.DefaultLinkSystem() store := memstore.Store{} lsys.SetReadStorage(&store) lsys.SetWriteStorage(&store) - fieldBlock := Block{ + encBlock := Encryption{ + DocID: []byte("docID"), + Key: []byte("keyID"), + } + + encBlockLink, err := lsys.Store(ipld.LinkContext{}, GetLinkPrototype(), encBlock.GenerateNode()) + require.NoError(t, err) + + link := encBlockLink.(cidlink.Link) + + block := Block{ Delta: crdt.CRDT{ LWWRegDelta: &crdt.LWWRegDelta{ DocID: []byte("docID"), @@ -196,11 +206,27 @@ func TestBlockMarshal_IsEncryptedNotSet_ShouldNotContainIsEcryptedField(t *testi Data: []byte("John"), }, }, + Encryption: &link, } - b, err := fieldBlock.Marshal() + blockLink, err := lsys.Store(ipld.LinkContext{}, GetLinkPrototype(), block.GenerateNode()) require.NoError(t, err) - require.NotContains(t, string(b), "isEncrypted") + + nd, err := lsys.Load(ipld.LinkContext{}, blockLink, SchemaPrototype) + require.NoError(t, err) + + loadedBlock, err := GetFromNode(nd) + require.NoError(t, err) + + require.NotNil(t, loadedBlock.Encryption) + + nd, err = lsys.Load(ipld.LinkContext{}, loadedBlock.Encryption, EncryptionSchemaPrototype) + require.NoError(t, err) + + loadedEncBlock, err := GetEncryptionBlockFromNode(nd) + require.NoError(t, err) + + require.Equal(t, encBlock, *loadedEncBlock) } func TestBlockMarshal_IsEncryptedNotSetWithLinkSystem_ShouldLoadWithNoError(t *testing.T) 
{ @@ -228,3 +254,58 @@ func TestBlockMarshal_IsEncryptedNotSetWithLinkSystem_ShouldLoadWithNoError(t *t _, err = GetFromNode(nd) require.NoError(t, err) } + +func TestBlockUnmarshal_ValidInput_Succeed(t *testing.T) { + validBlock := Block{ + Delta: crdt.CRDT{ + LWWRegDelta: &crdt.LWWRegDelta{ + DocID: []byte("docID"), + FieldName: "name", + Priority: 1, + SchemaVersionID: "schemaVersionID", + Data: []byte("John"), + }, + }, + } + + marshaledData, err := validBlock.Marshal() + require.NoError(t, err) + + var unmarshaledBlock Block + err = unmarshaledBlock.Unmarshal(marshaledData) + require.NoError(t, err) + + require.Equal(t, validBlock, unmarshaledBlock) +} + +func TestBlockUnmarshal_InvalidCBOR_Error(t *testing.T) { + invalidData := []byte("invalid CBOR data") + var block Block + err := block.Unmarshal(invalidData) + require.Error(t, err) +} + +func TestEncryptionBlockUnmarshal_InvalidCBOR_Error(t *testing.T) { + invalidData := []byte("invalid CBOR data") + var encBlock Encryption + err := encBlock.Unmarshal(invalidData) + require.Error(t, err) +} + +func TestEncryptionBlockUnmarshal_ValidInput_Succeed(t *testing.T) { + fieldName := "fieldName" + encBlock := Encryption{ + DocID: []byte("docID"), + Key: []byte("keyID"), + FieldName: &fieldName, + } + + marshaledData, err := encBlock.Marshal() + require.NoError(t, err) + + var unmarshaledBlock Encryption + err = unmarshaledBlock.Unmarshal(marshaledData) + require.NoError(t, err) + + require.Equal(t, encBlock, unmarshaledBlock) +} diff --git a/internal/core/block/errors.go b/internal/core/block/errors.go index 9b6b0e8a95..ced4c4d6a1 100644 --- a/internal/core/block/errors.go +++ b/internal/core/block/errors.go @@ -17,10 +17,12 @@ import ( ) const ( - errNodeToBlock string = "failed to convert node to block" - errEncodingBlock string = "failed to encode block" - errUnmarshallingBlock string = "failed to unmarshal block" - errGeneratingLink string = "failed to generate link" + errNodeToBlock string = "failed to convert 
node to block" + errEncodingBlock string = "failed to encode block" + errUnmarshallingBlock string = "failed to unmarshal block" + errGeneratingLink string = "failed to generate link" + errInvalidBlockEncryptionType string = "invalid block encryption type" + errInvalidBlockEncryptionKeyID string = "invalid block encryption key id" ) // Errors returnable from this package. @@ -28,10 +30,12 @@ const ( // This list is incomplete and undefined errors may also be returned. // Errors returned from this package may be tested against these errors with errors.Is. var ( - ErrNodeToBlock = errors.New(errNodeToBlock) - ErrEncodingBlock = errors.New(errEncodingBlock) - ErrUnmarshallingBlock = errors.New(errUnmarshallingBlock) - ErrGeneratingLink = errors.New(errGeneratingLink) + ErrNodeToBlock = errors.New(errNodeToBlock) + ErrEncodingBlock = errors.New(errEncodingBlock) + ErrUnmarshallingBlock = errors.New(errUnmarshallingBlock) + ErrGeneratingLink = errors.New(errGeneratingLink) + ErrInvalidBlockEncryptionType = errors.New(errInvalidBlockEncryptionType) + ErrInvalidBlockEncryptionKeyID = errors.New(errInvalidBlockEncryptionKeyID) ) // NewErrFailedToGetPriority returns an error indicating that the priority could not be retrieved. diff --git a/internal/core/crdt/composite.go b/internal/core/crdt/composite.go index 58372cfb49..c730badcb6 100644 --- a/internal/core/crdt/composite.go +++ b/internal/core/crdt/composite.go @@ -106,13 +106,13 @@ func (c CompositeDAG) Merge(ctx context.Context, delta core.Delta) error { if err != nil { return err } - return c.deleteWithPrefix(ctx, c.key.WithValueFlag().WithFieldId("")) + return c.deleteWithPrefix(ctx, c.key.WithValueFlag().WithFieldID("")) } // We cannot rely on the dagDelta.Status here as it may have been deleted locally, this is not // reflected in `dagDelta.Status` if sourced via P2P. Updates synced via P2P should not undelete - // the local reperesentation of the document. 
- versionKey := c.key.WithValueFlag().WithFieldId(core.DATASTORE_DOC_VERSION_FIELD_ID) + // the local representation of the document. + versionKey := c.key.WithValueFlag().WithFieldID(core.DATASTORE_DOC_VERSION_FIELD_ID) objectMarker, err := c.store.Get(ctx, c.key.ToPrimaryDataStoreKey().ToDS()) hasObjectMarker := !errors.Is(err, ds.ErrNotFound) if err != nil && hasObjectMarker { diff --git a/internal/core/crdt/lwwreg_test.go b/internal/core/crdt/lwwreg_test.go index 136d5cd09d..5b56df7636 100644 --- a/internal/core/crdt/lwwreg_test.go +++ b/internal/core/crdt/lwwreg_test.go @@ -16,6 +16,7 @@ import ( "testing" ds "github.com/ipfs/go-datastore" + "github.com/stretchr/testify/require" "github.com/sourcenetwork/defradb/datastore" "github.com/sourcenetwork/defradb/internal/core" @@ -31,11 +32,12 @@ func setupLWWRegister() LWWRegister { return NewLWWRegister(store, core.CollectionSchemaVersionKey{}, key, "") } -func setupLoadedLWWRegister(ctx context.Context) LWWRegister { +func setupLoadedLWWRegister(t *testing.T, ctx context.Context) LWWRegister { lww := setupLWWRegister() addDelta := lww.Set([]byte("test")) addDelta.SetPriority(1) - lww.Merge(ctx, addDelta) + err := lww.Merge(ctx, addDelta) + require.NoError(t, err) return lww } @@ -71,12 +73,13 @@ func TestLWWRegisterInitialMerge(t *testing.T) { } } -func TestLWWReisterFollowupMerge(t *testing.T) { +func TestLWWRegisterFollowupMerge(t *testing.T) { ctx := context.Background() - lww := setupLoadedLWWRegister(ctx) + lww := setupLoadedLWWRegister(t, ctx) addDelta := lww.Set([]byte("test2")) addDelta.SetPriority(2) - lww.Merge(ctx, addDelta) + err := lww.Merge(ctx, addDelta) + require.NoError(t, err) val, err := lww.Value(ctx) if err != nil { @@ -90,10 +93,11 @@ func TestLWWReisterFollowupMerge(t *testing.T) { func TestLWWRegisterOldMerge(t *testing.T) { ctx := context.Background() - lww := setupLoadedLWWRegister(ctx) + lww := setupLoadedLWWRegister(t, ctx) addDelta := lww.Set([]byte("test-1")) addDelta.SetPriority(0) 
- lww.Merge(ctx, addDelta) + err := lww.Merge(ctx, addDelta) + require.NoError(t, err) val, err := lww.Value(ctx) if err != nil { @@ -106,9 +110,7 @@ func TestLWWRegisterOldMerge(t *testing.T) { } func TestLWWRegisterDeltaInit(t *testing.T) { - delta := &LWWRegDelta{ - Data: []byte("test"), - } + delta := &LWWRegDelta{} var _ core.Delta = delta // checks if LWWRegDelta implements core.Delta (also checked in the implementation code, but w.e) } diff --git a/internal/core/key.go b/internal/core/key.go index b913a75f54..0e7942411d 100644 --- a/internal/core/key.go +++ b/internal/core/key.go @@ -127,7 +127,7 @@ var _ Key = (*PrimaryDataStoreKey)(nil) type HeadStoreKey struct { DocID string - FieldId string //can be 'C' + FieldID string //can be 'C' Cid cid.Cid } @@ -285,7 +285,7 @@ func NewHeadStoreKey(key string) (HeadStoreKey, error) { return HeadStoreKey{ // elements[0] is empty (key has leading '/') DocID: elements[1], - FieldId: elements[2], + FieldID: elements[2], Cid: cid, }, nil } @@ -471,16 +471,16 @@ func (k DataStoreKey) WithInstanceInfo(key DataStoreKey) DataStoreKey { return newKey } -func (k DataStoreKey) WithFieldId(fieldId string) DataStoreKey { +func (k DataStoreKey) WithFieldID(fieldID string) DataStoreKey { newKey := k - newKey.FieldID = fieldId + newKey.FieldID = fieldID return newKey } func (k DataStoreKey) ToHeadStoreKey() HeadStoreKey { return HeadStoreKey{ DocID: k.DocID, - FieldId: k.FieldID, + FieldID: k.FieldID, } } @@ -496,9 +496,9 @@ func (k HeadStoreKey) WithCid(c cid.Cid) HeadStoreKey { return newKey } -func (k HeadStoreKey) WithFieldId(fieldId string) HeadStoreKey { +func (k HeadStoreKey) WithFieldID(fieldID string) HeadStoreKey { newKey := k - newKey.FieldId = fieldId + newKey.FieldID = fieldID return newKey } @@ -858,8 +858,8 @@ func (k HeadStoreKey) ToString() string { if k.DocID != "" { result = result + "/" + k.DocID } - if k.FieldId != "" { - result = result + "/" + k.FieldId + if k.FieldID != "" { + result = result + "/" + 
k.FieldID } if k.Cid.Defined() { result = result + "/" + k.Cid.String() @@ -927,34 +927,3 @@ func bytesPrefixEnd(b []byte) []byte { // maximal byte string (i.e. already \xff...). return b } - -// EncStoreDocKey is a key for the encryption store. -type EncStoreDocKey struct { - DocID string - FieldName string -} - -var _ Key = (*EncStoreDocKey)(nil) - -// NewEncStoreDocKey creates a new EncStoreDocKey from a docID and fieldID. -func NewEncStoreDocKey(docID string, fieldName string) EncStoreDocKey { - return EncStoreDocKey{ - DocID: docID, - FieldName: fieldName, - } -} - -func (k EncStoreDocKey) ToString() string { - if k.FieldName == "" { - return k.DocID - } - return fmt.Sprintf("%s/%s", k.DocID, k.FieldName) -} - -func (k EncStoreDocKey) Bytes() []byte { - return []byte(k.ToString()) -} - -func (k EncStoreDocKey) ToDS() ds.Key { - return ds.NewKey(k.ToString()) -} diff --git a/internal/db/base/collection_keys.go b/internal/db/base/collection_keys.go index e23707285c..8878d50b13 100644 --- a/internal/db/base/collection_keys.go +++ b/internal/db/base/collection_keys.go @@ -45,7 +45,7 @@ func MakePrimaryIndexKeyForCRDT( case client.COMPOSITE: return MakeDataStoreKeyWithCollectionDescription(c.Description). WithInstanceInfo(key). - WithFieldId(core.COMPOSITE_NAMESPACE), + WithFieldID(core.COMPOSITE_NAMESPACE), nil case client.LWW_REGISTER, client.PN_COUNTER, client.P_COUNTER: field, ok := c.GetFieldByName(fieldName) @@ -55,7 +55,7 @@ func MakePrimaryIndexKeyForCRDT( return MakeDataStoreKeyWithCollectionDescription(c.Description). WithInstanceInfo(key). 
- WithFieldId(fmt.Sprint(field.ID)), + WithFieldID(fmt.Sprint(field.ID)), nil } return core.DataStoreKey{}, ErrInvalidCrdtType diff --git a/internal/db/collection.go b/internal/db/collection.go index e6205fecd9..6165218f78 100644 --- a/internal/db/collection.go +++ b/internal/db/collection.go @@ -915,7 +915,7 @@ func (c *collection) saveCompositeToMerkleCRDT( status client.DocumentStatus, ) (cidlink.Link, []byte, error) { txn := mustGetContextTxn(ctx) - dsKey = dsKey.WithFieldId(core.COMPOSITE_NAMESPACE) + dsKey = dsKey.WithFieldID(core.COMPOSITE_NAMESPACE) merkleCRDT := merklecrdt.NewMerkleCompositeDAG( txn, core.NewCollectionSchemaVersionKey(c.Schema().VersionID, c.ID()), diff --git a/internal/db/collection_delete.go b/internal/db/collection_delete.go index 082a53caf2..9ccca92ed5 100644 --- a/internal/db/collection_delete.go +++ b/internal/db/collection_delete.go @@ -144,7 +144,7 @@ func (c *collection) applyDelete( dsKey := primaryKey.ToDataStoreKey() headset := clock.NewHeadSet( txn.Headstore(), - dsKey.WithFieldId(core.COMPOSITE_NAMESPACE).ToHeadStoreKey(), + dsKey.WithFieldID(core.COMPOSITE_NAMESPACE).ToHeadStoreKey(), ) cids, _, err := headset.List(ctx) if err != nil { diff --git a/internal/db/config.go b/internal/db/config.go index 8ce725ebd0..3d69e833c4 100644 --- a/internal/db/config.go +++ b/internal/db/config.go @@ -19,12 +19,16 @@ const ( updateEventBufferSize = 100 ) +type dbOptions struct { + maxTxnRetries immutable.Option[int] +} + // Option is a funtion that sets a config value on the db. -type Option func(*db) +type Option func(*dbOptions) // WithMaxRetries sets the maximum number of retries per transaction. 
func WithMaxRetries(num int) Option { - return func(db *db) { - db.maxTxnRetries = immutable.Some(num) + return func(opts *dbOptions) { + opts.maxTxnRetries = immutable.Some(num) } } diff --git a/internal/db/config_test.go b/internal/db/config_test.go index 405e192598..a52d494a21 100644 --- a/internal/db/config_test.go +++ b/internal/db/config_test.go @@ -17,8 +17,8 @@ import ( ) func TestWithMaxRetries(t *testing.T) { - d := &db{} - WithMaxRetries(10)(d) + d := dbOptions{} + WithMaxRetries(10)(&d) assert.True(t, d.maxTxnRetries.HasValue()) assert.Equal(t, 10, d.maxTxnRetries.Value()) } diff --git a/internal/db/context.go b/internal/db/context.go index 8ad51c86ce..a2fa50507f 100644 --- a/internal/db/context.go +++ b/internal/db/context.go @@ -17,7 +17,6 @@ import ( acpIdentity "github.com/sourcenetwork/defradb/acp/identity" "github.com/sourcenetwork/defradb/datastore" - "github.com/sourcenetwork/defradb/internal/encryption" ) // txnContextKey is the key type for transaction context values. @@ -58,12 +57,10 @@ func ensureContextTxn(ctx context.Context, db transactionDB, readOnly bool) (con if ok { return SetContextTxn(ctx, &explicitTxn{txn}), &explicitTxn{txn}, nil } - // implicit transaction txn, err := db.NewTxn(ctx, readOnly) if err != nil { return nil, txn, err } - ctx = encryption.ContextWithStore(ctx, txn) return SetContextTxn(ctx, txn), txn, nil } diff --git a/internal/db/db.go b/internal/db/db.go index 81ec48e199..d88c5920bc 100644 --- a/internal/db/db.go +++ b/internal/db/db.go @@ -117,8 +117,13 @@ func newDB( } // apply options + var opts dbOptions for _, opt := range options { - opt(db) + opt(&opts) + } + + if opts.maxTxnRetries.HasValue() { + db.maxTxnRetries = opts.maxTxnRetries } if lens != nil { @@ -161,6 +166,11 @@ func (db *db) Blockstore() datastore.Blockstore { return db.multistore.Blockstore() } +// Encstore returns the internal enc store which contains encryption key for documents and their fields. 
+func (db *db) Encstore() datastore.Blockstore { + return db.multistore.Encstore() +} + // Peerstore returns the internal DAG store which contains IPLD blocks. func (db *db) Peerstore() datastore.DSBatching { return db.multistore.Peerstore() diff --git a/internal/db/errors.go b/internal/db/errors.go index 2da8c9c734..612d5ddb40 100644 --- a/internal/db/errors.go +++ b/internal/db/errors.go @@ -101,6 +101,7 @@ const ( errReplicatorCollections string = "failed to get collections for replicator" errReplicatorNotFound string = "replicator not found" errCanNotEncryptBuiltinField string = "can not encrypt build-in field" + errFailedToHandleEncKeysReceivedEvent string = "failed to handle encryption-keys-received event" errSelfReferenceWithoutSelf string = "must specify 'Self' kind for self referencing relations" errColNotMaterialized string = "non-materialized collections are not supported" errMaterializedViewAndACPNotSupported string = "materialized views do not support ACP" diff --git a/internal/db/fetcher/dag.go b/internal/db/fetcher/dag.go index cec1121827..3d3a6dd85e 100644 --- a/internal/db/fetcher/dag.go +++ b/internal/db/fetcher/dag.go @@ -92,7 +92,7 @@ func (hf *HeadFetcher) FetchNext() (*cid.Cid, error) { return nil, err } - if hf.fieldId.HasValue() && hf.fieldId.Value() != headStoreKey.FieldId { + if hf.fieldId.HasValue() && hf.fieldId.Value() != headStoreKey.FieldID { // FieldIds do not match, continue to next row return hf.FetchNext() } diff --git a/internal/db/fetcher/versioned.go b/internal/db/fetcher/versioned.go index 0ff58c4eeb..80b71cdd88 100644 --- a/internal/db/fetcher/versioned.go +++ b/internal/db/fetcher/versioned.go @@ -415,8 +415,7 @@ func (vf *VersionedFetcher) processBlock( vf.mCRDTs[crdtIndex] = mcrdt } - err = mcrdt.Clock().ProcessBlock(vf.ctx, block, blockLink, false) - return err + return mcrdt.Clock().ProcessBlock(vf.ctx, block, blockLink) } func (vf *VersionedFetcher) getDAGBlock(c cid.Cid) (*coreblock.Block, error) { diff --git 
a/internal/db/indexed_docs_test.go b/internal/db/indexed_docs_test.go index fad45aa11f..9f4ea3fe72 100644 --- a/internal/db/indexed_docs_test.go +++ b/internal/db/indexed_docs_test.go @@ -309,7 +309,7 @@ func TestNonUnique_IfDocWithDescendingOrderIsAdded_ShouldBeIndexed(t *testing.T) assert.Len(t, data, 0) } -func TestNonUnique_IfFailsToStoredIndexedDoc_Error(t *testing.T) { +func TestNonUnique_IfFailsToStoreIndexedDoc_Error(t *testing.T) { f := newIndexTestFixture(t) defer f.db.Close() f.createUserCollectionIndexOnName() @@ -698,7 +698,7 @@ func TestNonUniqueCreate_IfDatastoreFailsToStoreIndex_ReturnError(t *testing.T) fieldKeyString := core.DataStoreKey{ CollectionRootID: f.users.Description().RootID, }.WithDocID(doc.ID().String()). - WithFieldId("1"). + WithFieldID("1"). WithValueFlag(). ToString() diff --git a/internal/db/merge.go b/internal/db/merge.go index e588cb60a4..58c89cfc4e 100644 --- a/internal/db/merge.go +++ b/internal/db/merge.go @@ -16,6 +16,7 @@ import ( "sync" "github.com/ipfs/go-cid" + ipld "github.com/ipfs/go-ipld-format" "github.com/ipld/go-ipld-prime/linking" cidlink "github.com/ipld/go-ipld-prime/linking/cid" @@ -28,6 +29,7 @@ import ( "github.com/sourcenetwork/defradb/internal/core" coreblock "github.com/sourcenetwork/defradb/internal/core/block" "github.com/sourcenetwork/defradb/internal/db/base" + "github.com/sourcenetwork/defradb/internal/encryption" "github.com/sourcenetwork/defradb/internal/merkle/clock" merklecrdt "github.com/sourcenetwork/defradb/internal/merkle/crdt" ) @@ -44,21 +46,18 @@ func (db *db) executeMerge(ctx context.Context, dagMerge event.Merge) error { return err } - ls := cidlink.DefaultLinkSystem() - ls.SetReadStorage(txn.Blockstore().AsIPLDStorage()) - docID, err := client.NewDocIDFromString(dagMerge.DocID) if err != nil { return err } dsKey := base.MakeDataStoreKeyWithCollectionAndDocID(col.Description(), docID.String()) - mp, err := db.newMergeProcessor(txn, ls, col, dsKey) + mp, err := db.newMergeProcessor(txn, 
col, dsKey) if err != nil { return err } - mt, err := getHeadsAsMergeTarget(ctx, txn, dsKey) + mt, err := getHeadsAsMergeTarget(ctx, txn, dsKey.WithFieldID(core.COMPOSITE_NAMESPACE)) if err != nil { return err } @@ -130,26 +129,40 @@ func (m *mergeQueue) done(docID string) { type mergeProcessor struct { txn datastore.Txn - lsys linking.LinkSystem + blockLS linking.LinkSystem + encBlockLS linking.LinkSystem mCRDTs map[string]merklecrdt.MerkleCRDT col *collection dsKey core.DataStoreKey + // composites is a list of composites that need to be merged. composites *list.List + // missingEncryptionBlocks is a list of blocks that we failed to fetch + missingEncryptionBlocks map[cidlink.Link]struct{} + // availableEncryptionBlocks is a list of blocks that we have successfully fetched + availableEncryptionBlocks map[cidlink.Link]*coreblock.Encryption } func (db *db) newMergeProcessor( txn datastore.Txn, - lsys linking.LinkSystem, col *collection, dsKey core.DataStoreKey, ) (*mergeProcessor, error) { + blockLS := cidlink.DefaultLinkSystem() + blockLS.SetReadStorage(txn.Blockstore().AsIPLDStorage()) + + encBlockLS := cidlink.DefaultLinkSystem() + encBlockLS.SetReadStorage(txn.Encstore().AsIPLDStorage()) + return &mergeProcessor{ - txn: txn, - lsys: lsys, - mCRDTs: make(map[string]merklecrdt.MerkleCRDT), - col: col, - dsKey: dsKey, - composites: list.New(), + txn: txn, + blockLS: blockLS, + encBlockLS: encBlockLS, + mCRDTs: make(map[string]merklecrdt.MerkleCRDT), + col: col, + dsKey: dsKey, + composites: list.New(), + missingEncryptionBlocks: make(map[cidlink.Link]struct{}), + availableEncryptionBlocks: make(map[cidlink.Link]*coreblock.Encryption), }, nil } @@ -165,7 +178,7 @@ func newMergeTarget() mergeTarget { } // loadComposites retrieves and stores into the merge processor the composite blocks for the given -// document until it reaches a block that has already been merged or until we reach the genesis block. 
+// CID until it reaches a block that has already been merged or until we reach the genesis block. func (mp *mergeProcessor) loadComposites( ctx context.Context, blockCid cid.Cid, @@ -176,7 +189,7 @@ func (mp *mergeProcessor) loadComposites( return nil } - nd, err := mp.lsys.Load(linking.LinkContext{Ctx: ctx}, cidlink.Link{Cid: blockCid}, coreblock.SchemaPrototype) + nd, err := mp.blockLS.Load(linking.LinkContext{Ctx: ctx}, cidlink.Link{Cid: blockCid}, coreblock.SchemaPrototype) if err != nil { return err } @@ -191,12 +204,10 @@ func (mp *mergeProcessor) loadComposites( // In this case, we also need to walk back the merge target's DAG until we reach a common block. if block.Delta.GetPriority() >= mt.headHeight { mp.composites.PushFront(block) - for _, link := range block.Links { - if link.Name == core.HEAD { - err := mp.loadComposites(ctx, link.Cid, mt) - if err != nil { - return err - } + for _, prevCid := range block.GetHeadLinks() { + err := mp.loadComposites(ctx, prevCid, mt) + if err != nil { + return err } } } else { @@ -204,7 +215,7 @@ func (mp *mergeProcessor) loadComposites( for _, b := range mt.heads { for _, link := range b.Links { if link.Name == core.HEAD { - nd, err := mp.lsys.Load(linking.LinkContext{Ctx: ctx}, link.Link, coreblock.SchemaPrototype) + nd, err := mp.blockLS.Load(linking.LinkContext{Ctx: ctx}, link.Link, coreblock.SchemaPrototype) if err != nil { return err } @@ -227,15 +238,50 @@ func (mp *mergeProcessor) loadComposites( func (mp *mergeProcessor) mergeComposites(ctx context.Context) error { for e := mp.composites.Front(); e != nil; e = e.Next() { block := e.Value.(*coreblock.Block) - var onlyHeads bool - if block.IsEncrypted != nil && *block.IsEncrypted { - onlyHeads = true - } link, err := block.GenerateLink() if err != nil { return err } - err = mp.processBlock(ctx, block, link, onlyHeads) + err = mp.processBlock(ctx, block, link) + if err != nil { + return err + } + } + + return mp.tryFetchMissingBlocksAndMerge(ctx) +} + +func (mp 
*mergeProcessor) tryFetchMissingBlocksAndMerge(ctx context.Context) error { + for len(mp.missingEncryptionBlocks) > 0 { + links := make([]cidlink.Link, 0, len(mp.missingEncryptionBlocks)) + for link := range mp.missingEncryptionBlocks { + links = append(links, link) + } + msg, results := encryption.NewRequestKeysMessage(links) + mp.col.db.events.Publish(msg) + + res := <-results.Get() + if res.Error != nil { + return res.Error + } + + clear(mp.missingEncryptionBlocks) + + for i := range res.Items { + _, link, err := cid.CidFromBytes(res.Items[i].Link) + if err != nil { + return err + } + var encBlock coreblock.Encryption + err = encBlock.Unmarshal(res.Items[i].Block) + if err != nil { + return err + } + + mp.availableEncryptionBlocks[cidlink.Link{Cid: link}] = &encBlock + } + + err := mp.mergeComposites(ctx) if err != nil { return err } @@ -243,36 +289,109 @@ func (mp *mergeProcessor) mergeComposites(ctx context.Context) error { return nil } +func (mp *mergeProcessor) loadEncryptionBlock( + ctx context.Context, + encLink cidlink.Link, +) (*coreblock.Encryption, error) { + nd, err := mp.encBlockLS.Load(linking.LinkContext{Ctx: ctx}, encLink, coreblock.EncryptionSchemaPrototype) + if err != nil { + if errors.Is(err, ipld.ErrNotFound{}) { + mp.missingEncryptionBlocks[encLink] = struct{}{} + return nil, nil + } + return nil, err + } + + return coreblock.GetEncryptionBlockFromNode(nd) +} + +func (mp *mergeProcessor) tryGetEncryptionBlock( + ctx context.Context, + encLink cidlink.Link, +) (*coreblock.Encryption, error) { + if encBlock, ok := mp.availableEncryptionBlocks[encLink]; ok { + return encBlock, nil + } + if _, ok := mp.missingEncryptionBlocks[encLink]; ok { + return nil, nil + } + + encBlock, err := mp.loadEncryptionBlock(ctx, encLink) + if err != nil { + return nil, err + } + + if encBlock != nil { + mp.availableEncryptionBlocks[encLink] = encBlock + } + + return encBlock, nil +} + +// processEncryptedBlock decrypts the block if it is encrypted and returns the 
decrypted block. +// If the block is encrypted and we were not able to decrypt it, it returns false as the second return value +// which indicates that the we can't read the block. +// If we were able to decrypt the block, we return the decrypted block and true as the second return value. +func (mp *mergeProcessor) processEncryptedBlock( + ctx context.Context, + dagBlock *coreblock.Block, +) (*coreblock.Block, bool, error) { + if dagBlock.IsEncrypted() { + encBlock, err := mp.tryGetEncryptionBlock(ctx, *dagBlock.Encryption) + if err != nil { + return nil, false, err + } + + if encBlock == nil { + return dagBlock, false, nil + } + + plainTextBlock, err := decryptBlock(ctx, dagBlock, encBlock) + if err != nil { + return nil, false, err + } + if plainTextBlock != nil { + return plainTextBlock, true, nil + } + } + return dagBlock, true, nil +} + // processBlock merges the block and its children to the datastore and sets the head accordingly. -// If onlyHeads is true, it will skip merging and update only the heads. func (mp *mergeProcessor) processBlock( ctx context.Context, - block *coreblock.Block, + dagBlock *coreblock.Block, blockLink cidlink.Link, - onlyHeads bool, ) error { - crdt, err := mp.initCRDTForType(block.Delta.GetFieldName()) + block, canRead, err := mp.processEncryptedBlock(ctx, dagBlock) if err != nil { return err } - // If the CRDT is nil, it means the field is not part - // of the schema and we can safely ignore it. - if crdt == nil { - return nil - } + if canRead { + crdt, err := mp.initCRDTForType(dagBlock.Delta.GetFieldName()) + if err != nil { + return err + } - err = crdt.Clock().ProcessBlock(ctx, block, blockLink, onlyHeads) - if err != nil { - return err + // If the CRDT is nil, it means the field is not part + // of the schema and we can safely ignore it. 
+ if crdt == nil { + return nil + } + + err = crdt.Clock().ProcessBlock(ctx, block, blockLink) + if err != nil { + return err + } } - for _, link := range block.Links { + for _, link := range dagBlock.Links { if link.Name == core.HEAD { continue } - nd, err := mp.lsys.Load(linking.LinkContext{Ctx: ctx}, link.Link, coreblock.SchemaPrototype) + nd, err := mp.blockLS.Load(linking.LinkContext{Ctx: ctx}, link.Link, coreblock.SchemaPrototype) if err != nil { return err } @@ -282,7 +401,7 @@ func (mp *mergeProcessor) processBlock( return err } - if err := mp.processBlock(ctx, childBlock, link.Link, onlyHeads); err != nil { + if err := mp.processBlock(ctx, childBlock, link.Link); err != nil { return err } } @@ -290,9 +409,31 @@ func (mp *mergeProcessor) processBlock( return nil } -func (mp *mergeProcessor) initCRDTForType( - field string, -) (merklecrdt.MerkleCRDT, error) { +func decryptBlock( + ctx context.Context, + block *coreblock.Block, + encBlock *coreblock.Encryption, +) (*coreblock.Block, error) { + _, encryptor := encryption.EnsureContextWithEncryptor(ctx) + + if block.Delta.IsComposite() { + // for composite blocks there is nothing to decrypt + return block, nil + } + + bytes, err := encryptor.Decrypt(block.Delta.GetData(), encBlock.Key) + if err != nil { + return nil, err + } + if len(bytes) == 0 { + return nil, nil + } + newBlock := block.Clone() + newBlock.Delta.SetData(bytes) + return newBlock, nil +} + +func (mp *mergeProcessor) initCRDTForType(field string) (merklecrdt.MerkleCRDT, error) { mcrdt, exists := mp.mCRDTs[field] if exists { return mcrdt, nil @@ -307,7 +448,7 @@ func (mp *mergeProcessor) initCRDTForType( mcrdt = merklecrdt.NewMerkleCompositeDAG( mp.txn, schemaVersionKey, - mp.dsKey.WithFieldId(core.COMPOSITE_NAMESPACE), + mp.dsKey.WithFieldID(core.COMPOSITE_NAMESPACE), "", ) mp.mCRDTs[field] = mcrdt @@ -325,7 +466,7 @@ func (mp *mergeProcessor) initCRDTForType( schemaVersionKey, fd.Typ, fd.Kind, - mp.dsKey.WithFieldId(fd.ID.String()), + 
mp.dsKey.WithFieldID(fd.ID.String()), field, ) if err != nil { @@ -357,24 +498,15 @@ func getCollectionFromRootSchema(ctx context.Context, db *db, rootSchema string) // getHeadsAsMergeTarget retrieves the heads of the composite DAG for the given document // and returns them as a merge target. func getHeadsAsMergeTarget(ctx context.Context, txn datastore.Txn, dsKey core.DataStoreKey) (mergeTarget, error) { - headset := clock.NewHeadSet( - txn.Headstore(), - dsKey.WithFieldId(core.COMPOSITE_NAMESPACE).ToHeadStoreKey(), - ) + cids, err := getHeads(ctx, txn, dsKey) - cids, _, err := headset.List(ctx) if err != nil { return mergeTarget{}, err } mt := newMergeTarget() for _, cid := range cids { - b, err := txn.Blockstore().Get(ctx, cid) - if err != nil { - return mergeTarget{}, err - } - - block, err := coreblock.GetFromBytes(b.RawData()) + block, err := loadBlockFromBlockStore(ctx, txn, cid) if err != nil { return mergeTarget{}, err } @@ -386,6 +518,33 @@ func getHeadsAsMergeTarget(ctx context.Context, txn datastore.Txn, dsKey core.Da return mt, nil } +// getHeads retrieves the heads associated with the given datastore key. +func getHeads(ctx context.Context, txn datastore.Txn, dsKey core.DataStoreKey) ([]cid.Cid, error) { + headset := clock.NewHeadSet(txn.Headstore(), dsKey.ToHeadStoreKey()) + + cids, _, err := headset.List(ctx) + if err != nil { + return nil, err + } + + return cids, nil +} + +// loadBlockFromBlockStore loads a block from the blockstore. 
+func loadBlockFromBlockStore(ctx context.Context, txn datastore.Txn, cid cid.Cid) (*coreblock.Block, error) { + b, err := txn.Blockstore().Get(ctx, cid) + if err != nil { + return nil, err + } + + block, err := coreblock.GetFromBytes(b.RawData()) + if err != nil { + return nil, err + } + + return block, nil +} + func syncIndexedDoc( ctx context.Context, docID client.DocID, @@ -406,10 +565,10 @@ func syncIndexedDoc( return err } - if isDeletedDoc { - return col.deleteIndexedDoc(ctx, oldDoc) - } else if isNewDoc { + if isNewDoc { return col.indexNewDoc(ctx, doc) + } else if isDeletedDoc { + return col.deleteIndexedDoc(ctx, oldDoc) } else { return col.updateDocIndex(ctx, oldDoc, doc) } diff --git a/internal/db/p2p_replicator.go b/internal/db/p2p_replicator.go index 409419ea3f..b66ab4f2cf 100644 --- a/internal/db/p2p_replicator.go +++ b/internal/db/p2p_replicator.go @@ -161,7 +161,7 @@ func (db *db) getDocsHeads( docID := core.DataStoreKeyFromDocID(docIDResult.ID) headset := clock.NewHeadSet( txn.Headstore(), - docID.WithFieldId(core.COMPOSITE_NAMESPACE).ToHeadStoreKey(), + docID.WithFieldID(core.COMPOSITE_NAMESPACE).ToHeadStoreKey(), ) cids, _, err := headset.List(ctx) if err != nil { diff --git a/internal/encryption/aes.go b/internal/encryption/aes.go deleted file mode 100644 index e3a7feb563..0000000000 --- a/internal/encryption/aes.go +++ /dev/null @@ -1,79 +0,0 @@ -// Copyright 2024 Democratized Data Foundation -// -// Use of this software is governed by the Business Source License -// included in the file licenses/BSL.txt. -// -// As of the Change Date specified in that file, in accordance with -// the Business Source License, use of this software will be governed -// by the Apache License, Version 2.0, included in the file -// licenses/APL.txt. - -package encryption - -import ( - "crypto/aes" - "crypto/cipher" - "encoding/base64" - "fmt" -) - -// EncryptAES encrypts data using AES-GCM with a provided key. 
-func EncryptAES(plainText, key []byte) ([]byte, error) { - block, err := aes.NewCipher(key) - if err != nil { - return nil, err - } - - nonce, err := generateNonceFunc() - if err != nil { - return nil, err - } - - aesGCM, err := cipher.NewGCM(block) - if err != nil { - return nil, err - } - - cipherText := aesGCM.Seal(nonce, nonce, plainText, nil) - - buf := make([]byte, base64.StdEncoding.EncodedLen(len(cipherText))) - base64.StdEncoding.Encode(buf, cipherText) - - return buf, nil -} - -// DecryptAES decrypts AES-GCM encrypted data with a provided key. -func DecryptAES(cipherTextBase64, key []byte) ([]byte, error) { - cipherText := make([]byte, base64.StdEncoding.DecodedLen(len(cipherTextBase64))) - n, err := base64.StdEncoding.Decode(cipherText, []byte(cipherTextBase64)) - - if err != nil { - return nil, err - } - - cipherText = cipherText[:n] - - block, err := aes.NewCipher(key) - if err != nil { - return nil, err - } - - if len(cipherText) < nonceLength { - return nil, fmt.Errorf("cipherText too short") - } - - nonce := cipherText[:nonceLength] - cipherText = cipherText[nonceLength:] - - aesGCM, err := cipher.NewGCM(block) - if err != nil { - return nil, err - } - - plainText, err := aesGCM.Open(nil, nonce, cipherText, nil) - if err != nil { - return nil, err - } - - return plainText, nil -} diff --git a/internal/encryption/context.go b/internal/encryption/context.go index 96e90a7e0c..422bd97697 100644 --- a/internal/encryption/context.go +++ b/internal/encryption/context.go @@ -14,8 +14,6 @@ import ( "context" "github.com/sourcenetwork/immutable" - - "github.com/sourcenetwork/defradb/datastore" ) // docEncContextKey is the key type for document encryption context values. @@ -24,37 +22,34 @@ type docEncContextKey struct{} // configContextKey is the key type for encryption context values. type configContextKey struct{} -// TryGetContextDocEnc returns a document encryption and a bool indicating if -// it was retrieved from the given context. 
-func TryGetContextEncryptor(ctx context.Context) (*DocEncryptor, bool) { +// GetEncryptorFromContext returns a document encryptor from the given context. +// It returns nil if no encryptor exists in the context. +func GetEncryptorFromContext(ctx context.Context) *DocEncryptor { enc, ok := ctx.Value(docEncContextKey{}).(*DocEncryptor) if ok { setConfig(ctx, enc) } - return enc, ok + return enc } func setConfig(ctx context.Context, enc *DocEncryptor) { enc.SetConfig(GetContextConfig(ctx)) + enc.ctx = ctx } -func ensureContextWithDocEnc(ctx context.Context) (context.Context, *DocEncryptor) { - enc, ok := TryGetContextEncryptor(ctx) - if !ok { +// EnsureContextWithEncryptor returns a context with a document encryptor and the +// document encryptor itself. If the context already has an encryptor, it +// returns the context and encryptor as is. Otherwise, it creates a new +// document encryptor and stores it in the context. +func EnsureContextWithEncryptor(ctx context.Context) (context.Context, *DocEncryptor) { + enc := GetEncryptorFromContext(ctx) + if enc == nil { enc = newDocEncryptor(ctx) ctx = context.WithValue(ctx, docEncContextKey{}, enc) } return ctx, enc } -// ContextWithStore sets the store on the doc encryptor in the context. -// If the doc encryptor is not present, it will be created. -func ContextWithStore(ctx context.Context, txn datastore.Txn) context.Context { - ctx, encryptor := ensureContextWithDocEnc(ctx) - encryptor.SetStore(txn.Encstore()) - return ctx -} - // GetContextConfig returns the doc encryption config from the given context. func GetContextConfig(ctx context.Context) immutable.Option[DocEncConfig] { encConfig, ok := ctx.Value(configContextKey{}).(DocEncConfig) @@ -66,6 +61,7 @@ func GetContextConfig(ctx context.Context) immutable.Option[DocEncConfig] { // SetContextConfig returns a new context with the doc encryption config set. 
func SetContextConfig(ctx context.Context, encConfig DocEncConfig) context.Context { + ctx, _ = EnsureContextWithEncryptor(ctx) return context.WithValue(ctx, configContextKey{}, encConfig) } diff --git a/internal/encryption/encryptor.go b/internal/encryption/encryptor.go index 9a6cb8f6f0..fdd2efec3d 100644 --- a/internal/encryption/encryptor.go +++ b/internal/encryption/encryptor.go @@ -13,15 +13,13 @@ package encryption import ( "context" "crypto/rand" - "errors" "io" - - ds "github.com/ipfs/go-datastore" + "os" + "strings" "github.com/sourcenetwork/immutable" - "github.com/sourcenetwork/defradb/datastore" - "github.com/sourcenetwork/defradb/internal/core" + "github.com/sourcenetwork/defradb/crypto" ) var generateEncryptionKeyFunc = generateEncryptionKey @@ -31,7 +29,7 @@ const keyLength = 32 // 32 bytes for AES-256 const testEncryptionKey = "examplekey1234567890examplekey12" // generateEncryptionKey generates a random AES key. -func generateEncryptionKey(_, _ string) ([]byte, error) { +func generateEncryptionKey(_ string, _ immutable.Option[string]) ([]byte, error) { key := make([]byte, keyLength) if _, err := io.ReadFull(rand.Reader, key); err != nil { return nil, err @@ -42,21 +40,28 @@ func generateEncryptionKey(_, _ string) ([]byte, error) { // generateTestEncryptionKey generates a deterministic encryption key for testing. // While testing, we also want to make sure different keys are generated for different docs and fields // and that's why we use the docID and fieldName to generate the key. -func generateTestEncryptionKey(docID, fieldName string) ([]byte, error) { - return []byte(fieldName + docID + testEncryptionKey)[0:keyLength], nil +func generateTestEncryptionKey(docID string, fieldName immutable.Option[string]) ([]byte, error) { + return []byte(fieldName.Value() + docID + testEncryptionKey)[0:keyLength], nil } // DocEncryptor is a document encryptor that encrypts and decrypts individual document fields. 
// It acts based on the configuration [DocEncConfig] provided and data stored in the provided store.
-// It uses [core.EncStoreDocKey] to store and retrieve encryption keys.
+// DocEncryptor is session-bound, i.e. once a user requests to create (or update) a document or a node
+// receives an UpdateEvent on a document (or any other event), a new DocEncryptor is created and stored
+// in the context, so that the same DocEncryptor can be used by other objects down the call chain.
 type DocEncryptor struct {
-	conf  immutable.Option[DocEncConfig]
-	ctx   context.Context
-	store datastore.DSReaderWriter
+	conf          immutable.Option[DocEncConfig]
+	ctx           context.Context
+	generatedKeys map[genK][]byte
+}
+
+type genK struct {
+	docID     string
+	fieldName immutable.Option[string]
 }
 
 func newDocEncryptor(ctx context.Context) *DocEncryptor {
-	return &DocEncryptor{ctx: ctx}
+	return &DocEncryptor{ctx: ctx, generatedKeys: make(map[genK][]byte)}
 }
 
 // SetConfig sets the configuration for the document encryptor.
@@ -64,141 +69,124 @@ func (d *DocEncryptor) SetConfig(conf immutable.Option[DocEncConfig]) {
 	d.conf = conf
 }
 
-// SetStore sets the store for the document encryptor.
-func (d *DocEncryptor) SetStore(store datastore.DSReaderWriter) { - d.store = store -} - -func shouldEncryptIndividualField(conf immutable.Option[DocEncConfig], fieldName string) bool { - if !conf.HasValue() || fieldName == "" { +func shouldEncryptIndividualField(conf immutable.Option[DocEncConfig], fieldName immutable.Option[string]) bool { + if !conf.HasValue() || !fieldName.HasValue() { return false } for _, field := range conf.Value().EncryptedFields { - if field == fieldName { + if field == fieldName.Value() { return true } } return false } -func shouldEncryptField(conf immutable.Option[DocEncConfig], fieldName string) bool { +func shouldEncryptDocField(conf immutable.Option[DocEncConfig], fieldName immutable.Option[string]) bool { if !conf.HasValue() { return false } if conf.Value().IsDocEncrypted { return true } - if fieldName == "" { + if !fieldName.HasValue() { return false } for _, field := range conf.Value().EncryptedFields { - if field == fieldName { + if field == fieldName.Value() { return true } } return false } -// Encrypt encrypts the given plainText that is associated with the given docID and fieldName. -// If the current configuration is set to encrypt the given key individually, it will encrypt it with a new key. -// Otherwise, it will use document-level encryption key. -func (d *DocEncryptor) Encrypt(docID, fieldName string, plainText []byte) ([]byte, error) { - encryptionKey, err := d.fetchEncryptionKey(docID, fieldName) - if err != nil { - return nil, err +// Encrypt encrypts the given plainText with the encryption key that is associated with the given docID, +// fieldName and key id. 
+func (d *DocEncryptor) Encrypt( + plainText, encryptionKey []byte, +) ([]byte, error) { + var cipherText []byte + var err error + if len(plainText) > 0 { + cipherText, _, err = crypto.EncryptAES(plainText, encryptionKey, nil, true) } - if len(encryptionKey) == 0 { - if !shouldEncryptIndividualField(d.conf, fieldName) { - fieldName = "" - } - - if !shouldEncryptField(d.conf, fieldName) { - return plainText, nil - } - - encryptionKey, err = generateEncryptionKeyFunc(docID, fieldName) - if err != nil { - return nil, err - } - - storeKey := core.NewEncStoreDocKey(docID, fieldName) - err = d.store.Put(d.ctx, storeKey.ToDS(), encryptionKey) - if err != nil { - return nil, err - } - } - return EncryptAES(plainText, encryptionKey) + return cipherText, err } // Decrypt decrypts the given cipherText that is associated with the given docID and fieldName. // If the corresponding encryption key is not found, it returns nil. -func (d *DocEncryptor) Decrypt(docID, fieldName string, cipherText []byte) ([]byte, error) { - encKey, err := d.fetchEncryptionKey(docID, fieldName) - if err != nil { - return nil, err - } +func (d *DocEncryptor) Decrypt( + cipherText, encKey []byte, +) ([]byte, error) { if len(encKey) == 0 { return nil, nil } - return DecryptAES(cipherText, encKey) + return crypto.DecryptAES(nil, cipherText, encKey, nil) +} + +// getGeneratedKeyFor returns the generated key for the given docID and fieldName. +func (d *DocEncryptor) getGeneratedKeyFor( + docID string, + fieldName immutable.Option[string], +) []byte { + return d.generatedKeys[genK{docID, fieldName}] +} + +// GetOrGenerateEncryptionKey returns the generated encryption key for the given docID, (optional) fieldName. +// If the key is not generated before, it generates a new key and stores it. 
+func (d *DocEncryptor) GetOrGenerateEncryptionKey( + docID string, + fieldName immutable.Option[string], +) ([]byte, error) { + encryptionKey := d.getGeneratedKeyFor(docID, fieldName) + if len(encryptionKey) > 0 { + return encryptionKey, nil + } + + return d.generateEncryptionKey(docID, fieldName) } -// fetchEncryptionKey fetches the encryption key for the given docID and fieldName. -// If the key is not found, it returns an empty key. -func (d *DocEncryptor) fetchEncryptionKey(docID string, fieldName string) ([]byte, error) { - if d.store == nil { - return nil, ErrNoStorageProvided +// generateEncryptionKey generates a new encryption key for the given docID and fieldName. +func (d *DocEncryptor) generateEncryptionKey( + docID string, + fieldName immutable.Option[string], +) ([]byte, error) { + if !shouldEncryptIndividualField(d.conf, fieldName) { + fieldName = immutable.None[string]() + } + + if !shouldEncryptDocField(d.conf, fieldName) { + return nil, nil } - // first we try to find field-level key - storeKey := core.NewEncStoreDocKey(docID, fieldName) - encryptionKey, err := d.store.Get(d.ctx, storeKey.ToDS()) - isNotFound := errors.Is(err, ds.ErrNotFound) + + encryptionKey, err := generateEncryptionKeyFunc(docID, fieldName) if err != nil { - if !isNotFound { - return nil, err - } - // if previous fetch was for doc-level, there is nothing else to look for - if fieldName == "" { - return nil, nil - } - if shouldEncryptIndividualField(d.conf, fieldName) { - return nil, nil - } - // try to find doc-level key - storeKey.FieldName = "" - encryptionKey, err = d.store.Get(d.ctx, storeKey.ToDS()) - isNotFound = errors.Is(err, ds.ErrNotFound) - if err != nil && !isNotFound { - return nil, err - } + return nil, err } + + d.generatedKeys[genK{docID, fieldName}] = encryptionKey + return encryptionKey, nil } -// EncryptDoc encrypts the given plainText that is associated with the given docID and fieldName with -// encryptor in the context. 
-// If the current configuration is set to encrypt the given key individually, it will encrypt it with a new key. -// Otherwise, it will use document-level encryption key. -func EncryptDoc(ctx context.Context, docID string, fieldName string, plainText []byte) ([]byte, error) { - enc, ok := TryGetContextEncryptor(ctx) - if !ok { - return nil, nil - } - return enc.Encrypt(docID, fieldName, plainText) +// ShouldEncryptDocField returns true if the given field should be encrypted based on the context config. +func ShouldEncryptDocField(ctx context.Context, fieldName immutable.Option[string]) bool { + return shouldEncryptDocField(GetContextConfig(ctx), fieldName) } -// DecryptDoc decrypts the given cipherText that is associated with the given docID and fieldName with -// encryptor in the context. -func DecryptDoc(ctx context.Context, docID string, fieldName string, cipherText []byte) ([]byte, error) { - enc, ok := TryGetContextEncryptor(ctx) - if !ok { - return nil, nil - } - return enc.Decrypt(docID, fieldName, cipherText) +// ShouldEncryptIndividualField returns true if the given field should be encrypted individually based on +// the context config. +func ShouldEncryptIndividualField(ctx context.Context, fieldName immutable.Option[string]) bool { + return shouldEncryptIndividualField(GetContextConfig(ctx), fieldName) } -// ShouldEncryptField returns true if the given field should be encrypted based on the context config. -func ShouldEncryptField(ctx context.Context, fieldName string) bool { - return shouldEncryptField(GetContextConfig(ctx), fieldName) +func init() { + arg := os.Args[0] + // If the binary is a test binary, use a deterministic nonce. 
+ // TODO: We should try to find a better way to detect this https://github.com/sourcenetwork/defradb/issues/2801 + if strings.HasSuffix(arg, ".test") || + strings.Contains(arg, "/defradb/tests/") || + strings.Contains(arg, "/__debug_bin") { + generateEncryptionKeyFunc = generateTestEncryptionKey + } } diff --git a/internal/encryption/encryptor_test.go b/internal/encryption/encryptor_test.go index 76888ed4f1..3c34ed819d 100644 --- a/internal/encryption/encryptor_test.go +++ b/internal/encryption/encryptor_test.go @@ -12,211 +12,192 @@ package encryption import ( "context" - "errors" "testing" - ds "github.com/ipfs/go-datastore" "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/mock" "github.com/sourcenetwork/immutable" - - "github.com/sourcenetwork/defradb/datastore/mocks" - "github.com/sourcenetwork/defradb/internal/core" ) -var testErr = errors.New("test error") - -const docID = "bae-c9fb0fa4-1195-589c-aa54-e68333fb90b3" - -const fieldName = "name" - -func getPlainText() []byte { - return []byte("test") +func TestContext_NoEncryptor_ReturnsNil(t *testing.T) { + ctx := context.Background() + enc := GetEncryptorFromContext(ctx) + assert.Nil(t, enc) } -func getEncKey(fieldName string) []byte { - key, _ := generateTestEncryptionKey(docID, fieldName) - return key -} +func TestContext_WithEncryptor_ReturnsEncryptor(t *testing.T) { + ctx := context.Background() + enc := newDocEncryptor(ctx) + ctx = context.WithValue(ctx, docEncContextKey{}, enc) -func getCipherText(t *testing.T, fieldName string) []byte { - cipherText, err := EncryptAES(getPlainText(), getEncKey(fieldName)) - assert.NoError(t, err) - return cipherText + retrievedEnc := GetEncryptorFromContext(ctx) + assert.NotNil(t, retrievedEnc) + assert.Equal(t, enc, retrievedEnc) } -func newDefaultEncryptor(t *testing.T) (*DocEncryptor, *mocks.DSReaderWriter) { - return newEncryptorWithConfig(t, DocEncConfig{IsDocEncrypted: true}) -} - -func newEncryptorWithConfig(t *testing.T, conf DocEncConfig) 
(*DocEncryptor, *mocks.DSReaderWriter) { - enc := newDocEncryptor(context.Background()) - st := mocks.NewDSReaderWriter(t) - enc.SetConfig(immutable.Some(conf)) - enc.SetStore(st) - return enc, st -} - -func TestEncryptorEncrypt_IfStorageReturnsError_Error(t *testing.T) { - enc, st := newDefaultEncryptor(t) - - st.EXPECT().Get(mock.Anything, mock.Anything).Return(nil, testErr) +func TestContext_EnsureEncryptor_CreatesNew(t *testing.T) { + ctx := context.Background() + newCtx, enc := EnsureContextWithEncryptor(ctx) - _, err := enc.Encrypt(docID, fieldName, []byte("test")) + assert.NotNil(t, enc) + assert.NotEqual(t, ctx, newCtx) - assert.ErrorIs(t, err, testErr) + retrievedEnc := GetEncryptorFromContext(newCtx) + assert.Equal(t, enc, retrievedEnc) } -func TestEncryptorEncrypt_IfStorageReturnsErrorOnSecondCall_Error(t *testing.T) { - enc, st := newDefaultEncryptor(t) +func TestContext_EnsureEncryptor_ReturnsExisting(t *testing.T) { + ctx := context.Background() + enc := newDocEncryptor(ctx) + ctx = context.WithValue(ctx, docEncContextKey{}, enc) - st.EXPECT().Get(mock.Anything, mock.Anything).Return(nil, ds.ErrNotFound).Once() - st.EXPECT().Get(mock.Anything, mock.Anything).Return(nil, testErr) - - _, err := enc.Encrypt(docID, fieldName, []byte("test")) - - assert.ErrorIs(t, err, testErr) + newCtx, retrievedEnc := EnsureContextWithEncryptor(ctx) + assert.Equal(t, ctx, newCtx) + assert.Equal(t, enc, retrievedEnc) } -func TestEncryptorEncrypt_WithEmptyFieldNameIfNoKeyFoundInStorage_ShouldGenerateKeyStoreItAndReturnCipherText(t *testing.T) { - enc, st := newDefaultEncryptor(t) - - storeKey := core.NewEncStoreDocKey(docID, "") - - st.EXPECT().Get(mock.Anything, storeKey.ToDS()).Return(nil, ds.ErrNotFound) - st.EXPECT().Put(mock.Anything, storeKey.ToDS(), getEncKey("")).Return(nil) +func TestConfig_GetFromContext_NoConfig(t *testing.T) { + ctx := context.Background() + config := GetContextConfig(ctx) + assert.False(t, config.HasValue()) +} - cipherText, err := 
enc.Encrypt(docID, "", getPlainText()) +func TestConfig_GetFromContext_ReturnCurrentConfig(t *testing.T) { + ctx := context.Background() + expectedConfig := DocEncConfig{IsDocEncrypted: true, EncryptedFields: []string{"field1", "field2"}} + ctx = context.WithValue(ctx, configContextKey{}, expectedConfig) - assert.NoError(t, err) - assert.Equal(t, getCipherText(t, ""), cipherText) + config := GetContextConfig(ctx) + assert.True(t, config.HasValue()) + assert.Equal(t, expectedConfig, config.Value()) } -func TestEncryptorEncrypt_IfNoFieldEncRequestedAndNoKeyInStorage_GenerateKeyStoreItAndReturnCipherText(t *testing.T) { - enc, st := newDefaultEncryptor(t) +func TestConfig_SetContextConfig_StoreConfig(t *testing.T) { + ctx := context.Background() + config := DocEncConfig{IsDocEncrypted: true, EncryptedFields: []string{"field1", "field2"}} - docStoreKey := core.NewEncStoreDocKey(docID, "").ToDS() - fieldStoreKey := core.NewEncStoreDocKey(docID, fieldName).ToDS() + newCtx := SetContextConfig(ctx, config) + retrievedConfig := GetContextConfig(newCtx) - st.EXPECT().Get(mock.Anything, fieldStoreKey).Return(nil, ds.ErrNotFound) - st.EXPECT().Get(mock.Anything, docStoreKey).Return(nil, ds.ErrNotFound) - st.EXPECT().Put(mock.Anything, docStoreKey, getEncKey("")).Return(nil) + assert.True(t, retrievedConfig.HasValue()) + assert.Equal(t, config, retrievedConfig.Value()) +} - cipherText, err := enc.Encrypt(docID, fieldName, getPlainText()) +func TestConfig_SetFromParamsWithDocEncryption_StoreConfig(t *testing.T) { + ctx := context.Background() + newCtx := SetContextConfigFromParams(ctx, true, []string{"field1", "field2"}) - assert.NoError(t, err) - assert.Equal(t, getCipherText(t, ""), cipherText) + config := GetContextConfig(newCtx) + assert.True(t, config.HasValue()) + assert.True(t, config.Value().IsDocEncrypted) + assert.Equal(t, []string{"field1", "field2"}, config.Value().EncryptedFields) } -func 
TestEncryptorEncrypt_IfNoKeyWithFieldFoundInStorage_ShouldGenerateKeyStoreItAndReturnCipherText(t *testing.T) { - enc, st := newEncryptorWithConfig(t, DocEncConfig{EncryptedFields: []string{fieldName}}) +func TestConfig_SetFromParamsWithFields_StoreConfig(t *testing.T) { + ctx := context.Background() + newCtx := SetContextConfigFromParams(ctx, false, []string{"field1", "field2"}) - storeKey := core.NewEncStoreDocKey(docID, fieldName) - - st.EXPECT().Get(mock.Anything, storeKey.ToDS()).Return(nil, ds.ErrNotFound) - st.EXPECT().Put(mock.Anything, storeKey.ToDS(), getEncKey(fieldName)).Return(nil) + config := GetContextConfig(newCtx) + assert.True(t, config.HasValue()) + assert.False(t, config.Value().IsDocEncrypted) + assert.Equal(t, []string{"field1", "field2"}, config.Value().EncryptedFields) +} - cipherText, err := enc.Encrypt(docID, fieldName, getPlainText()) +func TestConfig_SetFromParamsWithNoEncryptionSetting_NoConfig(t *testing.T) { + ctx := context.Background() + newCtx := SetContextConfigFromParams(ctx, false, nil) - assert.NoError(t, err) - assert.Equal(t, getCipherText(t, fieldName), cipherText) + config := GetContextConfig(newCtx) + assert.False(t, config.HasValue()) } -func TestEncryptorEncrypt_IfKeyWithFieldFoundInStorage_ShouldUseItToReturnCipherText(t *testing.T) { - enc, st := newEncryptorWithConfig(t, DocEncConfig{EncryptedFields: []string{fieldName}}) - - storeKey := core.NewEncStoreDocKey(docID, fieldName) - st.EXPECT().Get(mock.Anything, storeKey.ToDS()).Return(getEncKey(fieldName), nil) +func TestEncryptor_EncryptDecrypt_SuccessfulRoundTrip(t *testing.T) { + ctx := context.Background() + enc := newDocEncryptor(ctx) + enc.SetConfig(immutable.Some(DocEncConfig{EncryptedFields: []string{"field1"}})) - cipherText, err := enc.Encrypt(docID, fieldName, getPlainText()) + plainText := []byte("Hello, World!") + docID := "doc1" + fieldName := immutable.Some("field1") + key, err := enc.GetOrGenerateEncryptionKey(docID, fieldName) assert.NoError(t, err) - 
assert.Equal(t, getCipherText(t, fieldName), cipherText) -} - -func TestEncryptorEncrypt_IfKeyFoundInStorage_ShouldUseItToReturnCipherText(t *testing.T) { - enc, st := newDefaultEncryptor(t) - - st.EXPECT().Get(mock.Anything, mock.Anything).Return(getEncKey(""), nil) + assert.NotNil(t, key) - cipherText, err := enc.Encrypt(docID, "", getPlainText()) + cipherText, err := enc.Encrypt(plainText, key) + assert.NoError(t, err) + assert.NotEqual(t, plainText, cipherText) + decryptedText, err := enc.Decrypt(cipherText, key) assert.NoError(t, err) - assert.Equal(t, getCipherText(t, ""), cipherText) + assert.Equal(t, plainText, decryptedText) } -func TestEncryptorEncrypt_IfStorageFailsToStoreEncryptionKey_ReturnError(t *testing.T) { - enc, st := newDefaultEncryptor(t) - - st.EXPECT().Get(mock.Anything, mock.Anything).Return(nil, ds.ErrNotFound) +func TestEncryptor_GetOrGenerateKey_ReturnsExistingKey(t *testing.T) { + ctx := context.Background() + enc := newDocEncryptor(ctx) + enc.SetConfig(immutable.Some(DocEncConfig{EncryptedFields: []string{"field1"}})) - st.EXPECT().Put(mock.Anything, mock.Anything, mock.Anything).Return(testErr) + docID := "doc1" + fieldName := immutable.Some("field1") - _, err := enc.Encrypt(docID, fieldName, getPlainText()) + key1, err := enc.GetOrGenerateEncryptionKey(docID, fieldName) + assert.NoError(t, err) + assert.NotNil(t, key1) - assert.ErrorIs(t, err, testErr) + key2, err := enc.GetOrGenerateEncryptionKey(docID, fieldName) + assert.NoError(t, err) + assert.Equal(t, key1, key2) } -func TestEncryptorEncrypt_IfKeyGenerationIsNotEnabled_ShouldReturnPlainText(t *testing.T) { - enc, st := newDefaultEncryptor(t) - enc.SetConfig(immutable.None[DocEncConfig]()) - - st.EXPECT().Get(mock.Anything, mock.Anything).Return(nil, ds.ErrNotFound) +func TestEncryptor_GenerateKey_DifferentKeysForDifferentFields(t *testing.T) { + ctx := context.Background() + enc := newDocEncryptor(ctx) + enc.SetConfig(immutable.Some(DocEncConfig{EncryptedFields: 
[]string{"field1", "field2"}})) - cipherText, err := enc.Encrypt(docID, fieldName, getPlainText()) + docID := "doc1" + fieldName1 := immutable.Some("field1") + fieldName2 := immutable.Some("field2") + key1, err := enc.GetOrGenerateEncryptionKey(docID, fieldName1) assert.NoError(t, err) - assert.Equal(t, getPlainText(), cipherText) -} - -func TestEncryptorEncrypt_IfNoStorageProvided_Error(t *testing.T) { - enc, _ := newDefaultEncryptor(t) - enc.SetStore(nil) + assert.NotNil(t, key1) - _, err := enc.Encrypt(docID, fieldName, getPlainText()) + key2, err := enc.GetOrGenerateEncryptionKey(docID, fieldName2) + assert.NoError(t, err) + assert.NotNil(t, key2) - assert.ErrorIs(t, err, ErrNoStorageProvided) + assert.NotEqual(t, key1, key2) } -func TestEncryptorDecrypt_IfNoStorageProvided_Error(t *testing.T) { - enc, _ := newDefaultEncryptor(t) - enc.SetStore(nil) +func TestShouldEncryptField_WithDocEncryption_True(t *testing.T) { + config := DocEncConfig{IsDocEncrypted: true} + ctx := SetContextConfig(context.Background(), config) - _, err := enc.Decrypt(docID, fieldName, getPlainText()) - - assert.ErrorIs(t, err, ErrNoStorageProvided) + assert.True(t, ShouldEncryptDocField(ctx, immutable.Some("field1"))) + assert.True(t, ShouldEncryptDocField(ctx, immutable.Some("field2"))) } -func TestEncryptorDecrypt_IfStorageReturnsError_Error(t *testing.T) { - enc, st := newDefaultEncryptor(t) - - st.EXPECT().Get(mock.Anything, mock.Anything).Return(nil, testErr) +func TestShouldEncryptField_WithFieldEncryption_TrueForMatchingField(t *testing.T) { + config := DocEncConfig{EncryptedFields: []string{"field1"}} + ctx := SetContextConfig(context.Background(), config) - _, err := enc.Decrypt(docID, fieldName, []byte("test")) - - assert.ErrorIs(t, err, testErr) + assert.True(t, ShouldEncryptDocField(ctx, immutable.Some("field1"))) + assert.False(t, ShouldEncryptDocField(ctx, immutable.Some("field2"))) } -func TestEncryptorDecrypt_IfKeyFoundInStorage_ShouldUseItToReturnPlainText(t *testing.T) 
{ - enc, st := newDefaultEncryptor(t) - - st.EXPECT().Get(mock.Anything, mock.Anything).Return(getEncKey(""), nil) - - plainText, err := enc.Decrypt(docID, fieldName, getCipherText(t, "")) +func TestShouldEncryptIndividualField_WithDocEncryption_False(t *testing.T) { + config := DocEncConfig{IsDocEncrypted: true} + ctx := SetContextConfig(context.Background(), config) - assert.NoError(t, err) - assert.Equal(t, getPlainText(), plainText) + assert.False(t, ShouldEncryptIndividualField(ctx, immutable.Some("field1"))) + assert.False(t, ShouldEncryptIndividualField(ctx, immutable.Some("field2"))) } -func TestEncryptDoc_IfContextHasNoEncryptor_ReturnNil(t *testing.T) { - data, err := EncryptDoc(context.Background(), docID, fieldName, getPlainText()) - assert.Nil(t, data, "data should be nil") - assert.NoError(t, err, "error should be nil") -} +func TestShouldEncryptIndividualField_WithFieldEncryption_TrueForMatchingField(t *testing.T) { + config := DocEncConfig{EncryptedFields: []string{"field1"}} + ctx := SetContextConfig(context.Background(), config) -func TestDecryptDoc_IfContextHasNoEncryptor_ReturnNil(t *testing.T) { - data, err := DecryptDoc(context.Background(), docID, fieldName, getCipherText(t, fieldName)) - assert.Nil(t, data, "data should be nil") - assert.NoError(t, err, "error should be nil") + assert.True(t, ShouldEncryptIndividualField(ctx, immutable.Some("field1"))) + assert.False(t, ShouldEncryptIndividualField(ctx, immutable.Some("field2"))) } diff --git a/internal/encryption/errors.go b/internal/encryption/errors.go index 6a443ad834..a068c20fae 100644 --- a/internal/encryption/errors.go +++ b/internal/encryption/errors.go @@ -15,9 +15,11 @@ import ( ) const ( - errNoStorageProvided string = "no storage provided" + errNoStorageProvided string = "no storage provided" + errContextHasNoEncryptor string = "context has no encryptor" ) var ( - ErrNoStorageProvided = errors.New(errNoStorageProvided) + ErrNoStorageProvided = errors.New(errNoStorageProvided) + 
ErrContextHasNoEncryptor = errors.New(errContextHasNoEncryptor)
 )
diff --git a/internal/encryption/event.go b/internal/encryption/event.go
new file mode 100644
index 0000000000..16c1442b64
--- /dev/null
+++ b/internal/encryption/event.go
@@ -0,0 +1,75 @@
+// Copyright 2024 Democratized Data Foundation
+//
+// Use of this software is governed by the Business Source License
+// included in the file licenses/BSL.txt.
+//
+// As of the Change Date specified in that file, in accordance with
+// the Business Source License, use of this software will be governed
+// by the Apache License, Version 2.0, included in the file
+// licenses/APL.txt.
+
+package encryption
+
+import (
+	cidlink "github.com/ipld/go-ipld-prime/linking/cid"
+
+	"github.com/sourcenetwork/defradb/event"
+)
+
+const RequestKeysEventName = event.Name("enc-keys-request")
+
+// RequestKeysEvent represents a request of a node to fetch an encryption key for a specific
+// docID/field.
+//
+// It must only contain public elements not protected by ACP.
+type RequestKeysEvent struct {
+	// Keys is a list of the keys that are being requested.
+	Keys []cidlink.Link
+
+	Resp chan<- Result
+}
+
+// RequestedKeyEventData represents the data that was retrieved for a specific key.
+type RequestedKeyEventData struct {
+	// Key is the encryption key that was retrieved.
+	Key []byte
+}
+
+// Item represents a single retrieved encryption key: the serialized link and the raw key block.
+type Item struct {
+	Link  []byte
+	Block []byte
+}
+
+type Result struct {
+	Items []Item
+	Error error
+}
+
+type Results struct {
+	output chan Result
+}
+
+func (r *Results) Get() <-chan Result {
+	return r.output
+}
+
+// NewResults creates a new Results object and a channel that can be used to send results to it.
+// The Results object can be used to wait on the results, and the channel can be used to send results.
+func NewResults() (*Results, chan<- Result) {
+	ch := make(chan Result, 1)
+	return &Results{
+		output: ch,
+	}, ch
+}
+
+// NewRequestKeysMessage creates a new event message for a request of a node to fetch an encryption key
+// for a specific docID/field.
+// It returns the message and the results that can be waited on.
+func NewRequestKeysMessage(keys []cidlink.Link) (event.Message, *Results) {
+	res, ch := NewResults()
+	return event.NewMessage(RequestKeysEventName, RequestKeysEvent{
+		Keys: keys,
+		Resp: ch,
+	}), res
+}
diff --git a/internal/kms/enc_store.go b/internal/kms/enc_store.go
new file mode 100644
index 0000000000..bd60592f26
--- /dev/null
+++ b/internal/kms/enc_store.go
@@ -0,0 +1,66 @@
+// Copyright 2024 Democratized Data Foundation
+//
+// Use of this software is governed by the Business Source License
+// included in the file licenses/BSL.txt.
+//
+// As of the Change Date specified in that file, in accordance with
+// the Business Source License, use of this software will be governed
+// by the Apache License, Version 2.0, included in the file
+// licenses/APL.txt.
+ +package kms + +import ( + "context" + + "github.com/ipfs/go-cid" + "github.com/ipld/go-ipld-prime/linking" + cidlink "github.com/ipld/go-ipld-prime/linking/cid" + + "github.com/sourcenetwork/defradb/datastore" + coreblock "github.com/sourcenetwork/defradb/internal/core/block" +) + +type ipldEncStorage struct { + encstore datastore.Blockstore +} + +func newIPLDEncryptionStorage(encstore datastore.Blockstore) *ipldEncStorage { + return &ipldEncStorage{encstore: encstore} +} + +func (s *ipldEncStorage) get(ctx context.Context, cidBytes []byte) (*coreblock.Encryption, error) { + lsys := cidlink.DefaultLinkSystem() + lsys.SetReadStorage(s.encstore.AsIPLDStorage()) + + _, blockCid, err := cid.CidFromBytes(cidBytes) + if err != nil { + return nil, err + } + + nd, err := lsys.Load(linking.LinkContext{Ctx: ctx}, cidlink.Link{Cid: blockCid}, + coreblock.EncryptionSchemaPrototype) + if err != nil { + return nil, err + } + + return coreblock.GetEncryptionBlockFromNode(nd) +} + +func (s *ipldEncStorage) put(ctx context.Context, blockBytes []byte) ([]byte, error) { + lsys := cidlink.DefaultLinkSystem() + lsys.SetWriteStorage(s.encstore.AsIPLDStorage()) + + var encBlock coreblock.Encryption + err := encBlock.Unmarshal(blockBytes) + if err != nil { + return nil, err + } + + link, err := lsys.Store(linking.LinkContext{Ctx: ctx}, coreblock.GetLinkPrototype(), encBlock.GenerateNode()) + if err != nil { + return nil, err + } + + return []byte(link.String()), nil +} diff --git a/internal/kms/errors.go b/internal/kms/errors.go new file mode 100644 index 0000000000..603d3c3232 --- /dev/null +++ b/internal/kms/errors.go @@ -0,0 +1,27 @@ +// Copyright 2024 Democratized Data Foundation +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt. 
+// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +package kms + +import ( + "github.com/sourcenetwork/defradb/errors" +) + +const ( + errUnknownKMSType string = "unknown KMS type" +) + +var ( + ErrUnknownKMSType = errors.New(errUnknownKMSType) +) + +func NewErrUnknownKMSType(t ServiceType) error { + return errors.New(errUnknownKMSType, errors.NewKV("Type", t)) +} diff --git a/internal/kms/pubsub.go b/internal/kms/pubsub.go new file mode 100644 index 0000000000..ca67603a7c --- /dev/null +++ b/internal/kms/pubsub.go @@ -0,0 +1,339 @@ +// Copyright 2024 Democratized Data Foundation +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. 
+ +package kms + +import ( + "bytes" + "context" + "crypto/ecdh" + "encoding/base64" + + cidlink "github.com/ipld/go-ipld-prime/linking/cid" + libpeer "github.com/libp2p/go-libp2p/core/peer" + rpc "github.com/sourcenetwork/go-libp2p-pubsub-rpc" + grpcpeer "google.golang.org/grpc/peer" + "google.golang.org/protobuf/proto" + + "github.com/sourcenetwork/defradb/crypto" + "github.com/sourcenetwork/defradb/datastore" + "github.com/sourcenetwork/defradb/errors" + "github.com/sourcenetwork/defradb/event" + "github.com/sourcenetwork/defradb/internal/encryption" + pb "github.com/sourcenetwork/defradb/net/pb" +) + +const pubsubTopic = "encryption" + +type PubSubServer interface { + AddPubSubTopic(string, rpc.MessageHandler) error + SendPubSubMessage(context.Context, string, []byte) (<-chan rpc.Response, error) +} + +type pubSubService struct { + ctx context.Context + peerID libpeer.ID + pubsub PubSubServer + keyRequestedSub *event.Subscription + eventBus *event.Bus + encStore *ipldEncStorage +} + +var _ Service = (*pubSubService)(nil) + +func (s *pubSubService) GetKeys(ctx context.Context, cids ...cidlink.Link) (*encryption.Results, error) { + res, ch := encryption.NewResults() + + err := s.requestEncryptionKeyFromPeers(ctx, cids, ch) + if err != nil { + return nil, err + } + + return res, nil +} + +// NewPubSubService creates a new instance of the KMS service that is connected to the given PubSubServer, +// event bus and encryption storage. +// +// The service will subscribe to the "encryption" topic on the PubSubServer and to the +// "enc-keys-request" event on the event bus. 
+func NewPubSubService( + ctx context.Context, + peerID libpeer.ID, + pubsub PubSubServer, + eventBus *event.Bus, + encstore datastore.Blockstore, +) (*pubSubService, error) { + s := &pubSubService{ + ctx: ctx, + peerID: peerID, + pubsub: pubsub, + eventBus: eventBus, + encStore: newIPLDEncryptionStorage(encstore), + } + err := pubsub.AddPubSubTopic(pubsubTopic, s.handleRequestFromPeer) + if err != nil { + return nil, err + } + s.keyRequestedSub, err = eventBus.Subscribe(encryption.RequestKeysEventName) + if err != nil { + return nil, err + } + go s.handleKeyRequestedEvent() + return s, nil +} + +func (s *pubSubService) handleKeyRequestedEvent() { + for { + msg, isOpen := <-s.keyRequestedSub.Message() + if !isOpen { + return + } + + if keyReqEvent, ok := msg.Data.(encryption.RequestKeysEvent); ok { + go func() { + results, err := s.GetKeys(s.ctx, keyReqEvent.Keys...) + if err != nil { + log.ErrorContextE(s.ctx, "Failed to get encryption keys", err) + } + + defer close(keyReqEvent.Resp) + + select { + case <-s.ctx.Done(): + return + case encResult := <-results.Get(): + for _, encItem := range encResult.Items { + _, err = s.encStore.put(s.ctx, encItem.Block) + if err != nil { + log.ErrorContextE(s.ctx, "Failed to save encryption key", err) + return + } + } + + keyReqEvent.Resp <- encResult + } + }() + } else { + log.ErrorContext(s.ctx, "Failed to cast event data to RequestKeysEvent") + } + } +} + +// handleEncryptionMessage handles incoming FetchEncryptionKeyRequest messages from the pubsub network. 
+func (s *pubSubService) handleRequestFromPeer(peerID libpeer.ID, topic string, msg []byte) ([]byte, error) { + req := new(pb.FetchEncryptionKeyRequest) + if err := proto.Unmarshal(msg, req); err != nil { + log.ErrorContextE(s.ctx, "Failed to unmarshal pubsub message %s", err) + return nil, err + } + + ctx := grpcpeer.NewContext(s.ctx, newGRPCPeer(peerID)) + res, err := s.tryGenEncryptionKeyLocally(ctx, req) + if err != nil { + log.ErrorContextE(s.ctx, "failed attempt to get encryption key", err) + return nil, errors.Wrap("failed attempt to get encryption key", err) + } + return res.MarshalVT() +} + +func (s *pubSubService) prepareFetchEncryptionKeyRequest( + cids []cidlink.Link, + ephemeralPublicKey []byte, +) (*pb.FetchEncryptionKeyRequest, error) { + req := &pb.FetchEncryptionKeyRequest{ + EphemeralPublicKey: ephemeralPublicKey, + } + + req.Links = make([][]byte, len(cids)) + for i, cid := range cids { + req.Links[i] = cid.Bytes() + } + + return req, nil +} + +// requestEncryptionKeyFromPeers publishes the given FetchEncryptionKeyRequest object on the PubSub network +func (s *pubSubService) requestEncryptionKeyFromPeers( + ctx context.Context, + cids []cidlink.Link, + result chan<- encryption.Result, +) error { + ephPrivKey, err := crypto.GenerateX25519() + if err != nil { + return err + } + + ephPubKeyBytes := ephPrivKey.PublicKey().Bytes() + req, err := s.prepareFetchEncryptionKeyRequest(cids, ephPubKeyBytes) + if err != nil { + return err + } + + data, err := req.MarshalVT() + if err != nil { + return errors.Wrap("failed to marshal pubsub message", err) + } + + respChan, err := s.pubsub.SendPubSubMessage(ctx, pubsubTopic, data) + if err != nil { + return errors.Wrap("failed publishing to encryption thread", err) + } + + go func() { + s.handleFetchEncryptionKeyResponse(<-respChan, req, ephPrivKey, result) + }() + + return nil +} + +// handleFetchEncryptionKeyResponse handles incoming FetchEncryptionKeyResponse messages +func (s *pubSubService) 
handleFetchEncryptionKeyResponse( + resp rpc.Response, + req *pb.FetchEncryptionKeyRequest, + privateKey *ecdh.PrivateKey, + result chan<- encryption.Result, +) { + defer close(result) + + var keyResp pb.FetchEncryptionKeyReply + if err := proto.Unmarshal(resp.Data, &keyResp); err != nil { + log.ErrorContextE(s.ctx, "Failed to unmarshal encryption key response", err) + result <- encryption.Result{Error: err} + return + } + + resultEncItems := make([]encryption.Item, 0, len(keyResp.Blocks)) + for i, block := range keyResp.Blocks { + decryptedData, err := crypto.DecryptECIES( + block, + privateKey, + crypto.WithAAD(makeAssociatedData(req, resp.From)), + crypto.WithPubKeyBytes(keyResp.EphemeralPublicKey), + crypto.WithPubKeyPrepended(false), + ) + + if err != nil { + log.ErrorContextE(s.ctx, "Failed to decrypt encryption key", err) + result <- encryption.Result{Error: err} + return + } + + resultEncItems = append(resultEncItems, encryption.Item{ + Link: keyResp.Links[i], + Block: decryptedData, + }) + } + + result <- encryption.Result{ + Items: resultEncItems, + } +} + +// makeAssociatedData creates the associated data for the encryption key request +func makeAssociatedData(req *pb.FetchEncryptionKeyRequest, peerID libpeer.ID) []byte { + return encodeToBase64(bytes.Join([][]byte{ + req.EphemeralPublicKey, + []byte(peerID), + }, []byte{})) +} + +func (s *pubSubService) tryGenEncryptionKeyLocally( + ctx context.Context, + req *pb.FetchEncryptionKeyRequest, +) (*pb.FetchEncryptionKeyReply, error) { + blocks, err := s.getEncryptionKeysLocally(ctx, req) + if err != nil || len(blocks) == 0 { + return nil, err + } + + reqEphPubKey, err := crypto.X25519PublicKeyFromBytes(req.EphemeralPublicKey) + if err != nil { + return nil, errors.Wrap("failed to unmarshal ephemeral public key", err) + } + + privKey, err := crypto.GenerateX25519() + if err != nil { + return nil, err + } + + res := &pb.FetchEncryptionKeyReply{ + Links: req.Links, + EphemeralPublicKey: 
privKey.PublicKey().Bytes(), + } + + res.Blocks = make([][]byte, 0, len(blocks)) + + for _, block := range blocks { + encryptedBlock, err := crypto.EncryptECIES( + block, + reqEphPubKey, + crypto.WithAAD(makeAssociatedData(req, s.peerID)), + crypto.WithPrivKey(privKey), + crypto.WithPubKeyPrepended(false), + ) + if err != nil { + return nil, errors.Wrap("failed to encrypt key for requester", err) + } + + res.Blocks = append(res.Blocks, encryptedBlock) + } + + return res, nil +} + +// getEncryptionKeys retrieves the encryption keys for the given targets. +// It returns the encryption keys and the targets for which the keys were found. +func (s *pubSubService) getEncryptionKeysLocally( + ctx context.Context, + req *pb.FetchEncryptionKeyRequest, +) ([][]byte, error) { + blocks := make([][]byte, 0, len(req.Links)) + for _, link := range req.Links { + encBlock, err := s.encStore.get(ctx, link) + if err != nil { + return nil, err + } + // TODO: we should test it somehow. For this this one peer should have some keys and + // another one should have the others. https://github.com/sourcenetwork/defradb/issues/2895 + if encBlock == nil { + continue + } + + encBlockBytes, err := encBlock.Marshal() + if err != nil { + return nil, err + } + + blocks = append(blocks, encBlockBytes) + } + return blocks, nil +} + +func encodeToBase64(data []byte) []byte { + encoded := make([]byte, base64.StdEncoding.EncodedLen(len(data))) + base64.StdEncoding.Encode(encoded, data) + return encoded +} + +func newGRPCPeer(peerID libpeer.ID) *grpcpeer.Peer { + return &grpcpeer.Peer{ + Addr: addr{peerID}, + } +} + +// addr implements net.Addr and holds a libp2p peer ID. +type addr struct{ id libpeer.ID } + +// Network returns the name of the network that this address belongs to (libp2p). +func (a addr) Network() string { return "libp2p" } + +// String returns the peer ID of this address in string form (B58-encoded). 
+func (a addr) String() string { return a.id.String() } diff --git a/internal/kms/service.go b/internal/kms/service.go new file mode 100644 index 0000000000..97985c9a43 --- /dev/null +++ b/internal/kms/service.go @@ -0,0 +1,40 @@ +// Copyright 2024 Democratized Data Foundation +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +package kms + +import ( + "context" + + cidlink "github.com/ipld/go-ipld-prime/linking/cid" + "github.com/sourcenetwork/corelog" + + "github.com/sourcenetwork/defradb/internal/encryption" +) + +var ( + log = corelog.NewLogger("kms") +) + +type ServiceType string + +const ( + // PubSubServiceType is the type of KMS that uses PubSub mechanism to exchange keys + // between peers. + PubSubServiceType ServiceType = "pubsub" +) + +// Service is interface for key management service (KMS) +type Service interface { + // GetKeys retrieves the encryption blocks containing encryption keys for the given links. + // Blocks are fetched asynchronously, so the method returns an [encryption.Results] object + // that can be used to wait for the results. + GetKeys(ctx context.Context, cids ...cidlink.Link) (*encryption.Results, error) +} diff --git a/internal/lens/fetcher.go b/internal/lens/fetcher.go index 357bbe9677..bbe0c45a0d 100644 --- a/internal/lens/fetcher.go +++ b/internal/lens/fetcher.go @@ -307,7 +307,7 @@ func (f *lensedFetcher) updateDataStore(ctx context.Context, original map[string // in which case we have to skip them for now. 
continue } - fieldKey := datastoreKeyBase.WithFieldId(fieldDesc.ID.String()) + fieldKey := datastoreKeyBase.WithFieldID(fieldDesc.ID.String()) bytes, err := cbor.Marshal(value) if err != nil { @@ -320,7 +320,7 @@ func (f *lensedFetcher) updateDataStore(ctx context.Context, original map[string } } - versionKey := datastoreKeyBase.WithFieldId(core.DATASTORE_DOC_VERSION_FIELD_ID) + versionKey := datastoreKeyBase.WithFieldID(core.DATASTORE_DOC_VERSION_FIELD_ID) err := f.txn.Datastore().Put(ctx, versionKey.ToDS(), []byte(f.targetVersionID)) if err != nil { return err diff --git a/internal/merkle/clock/clock.go b/internal/merkle/clock/clock.go index b5b5f2631c..b5b55e1374 100644 --- a/internal/merkle/clock/clock.go +++ b/internal/merkle/clock/clock.go @@ -21,6 +21,7 @@ import ( cidlink "github.com/ipld/go-ipld-prime/linking/cid" "github.com/sourcenetwork/corelog" + "github.com/sourcenetwork/immutable" "github.com/sourcenetwork/defradb/datastore" "github.com/sourcenetwork/defradb/internal/core" @@ -36,6 +37,7 @@ var ( type MerkleClock struct { headstore datastore.DSReaderWriter blockstore datastore.Blockstore + encstore datastore.Blockstore headset *heads crdt core.ReplicatedData } @@ -44,12 +46,14 @@ type MerkleClock struct { func NewMerkleClock( headstore datastore.DSReaderWriter, blockstore datastore.Blockstore, + encstore datastore.Blockstore, namespace core.HeadStoreKey, crdt core.ReplicatedData, ) *MerkleClock { return &MerkleClock{ headstore: headstore, blockstore: blockstore, + encstore: encstore, headset: NewHeadSet(headstore, namespace), crdt: crdt, } @@ -59,10 +63,23 @@ func (mc *MerkleClock) putBlock( ctx context.Context, block *coreblock.Block, ) (cidlink.Link, error) { - nd := block.GenerateNode() lsys := cidlink.DefaultLinkSystem() lsys.SetWriteStorage(mc.blockstore.AsIPLDStorage()) - link, err := lsys.Store(linking.LinkContext{Ctx: ctx}, coreblock.GetLinkPrototype(), nd) + link, err := lsys.Store(linking.LinkContext{Ctx: ctx}, coreblock.GetLinkPrototype(), 
block.GenerateNode()) + if err != nil { + return cidlink.Link{}, NewErrWritingBlock(err) + } + + return link.(cidlink.Link), nil +} + +func (mc *MerkleClock) putEncBlock( + ctx context.Context, + encBlock *coreblock.Encryption, +) (cidlink.Link, error) { + lsys := cidlink.DefaultLinkSystem() + lsys.SetWriteStorage(mc.encstore.AsIPLDStorage()) + link, err := lsys.Store(linking.LinkContext{Ctx: ctx}, coreblock.GetLinkPrototype(), encBlock.GenerateNode()) if err != nil { return cidlink.Link{}, NewErrWritingBlock(err) } @@ -86,21 +103,22 @@ func (mc *MerkleClock) AddDelta( delta.SetPriority(height) block := coreblock.New(delta, links, heads...) - isEncrypted, err := mc.checkIfBlockEncryptionEnabled(ctx, block.Delta.GetFieldName(), heads) + fieldName := immutable.None[string]() + if block.Delta.GetFieldName() != "" { + fieldName = immutable.Some(block.Delta.GetFieldName()) + } + encBlock, encLink, err := mc.determineBlockEncryption(ctx, string(block.Delta.GetDocID()), fieldName, heads) if err != nil { return cidlink.Link{}, nil, err } dagBlock := block - if isEncrypted { - if !block.Delta.IsComposite() { - dagBlock, err = encryptBlock(ctx, block) - if err != nil { - return cidlink.Link{}, nil, err - } - } else { - dagBlock.IsEncrypted = &isEncrypted + if encBlock != nil { + dagBlock, err = encryptBlock(ctx, block, encBlock) + if err != nil { + return cidlink.Link{}, nil, err } + dagBlock.Encryption = &encLink } link, err := mc.putBlock(ctx, dagBlock) @@ -109,12 +127,7 @@ func (mc *MerkleClock) AddDelta( } // merge the delta and update the state - err = mc.ProcessBlock( - ctx, - block, - link, - false, - ) + err = mc.ProcessBlock(ctx, block, link) if err != nil { return cidlink.Link{}, nil, err } @@ -127,57 +140,95 @@ func (mc *MerkleClock) AddDelta( return link, b, err } -func (mc *MerkleClock) checkIfBlockEncryptionEnabled( +func (mc *MerkleClock) determineBlockEncryption( ctx context.Context, - fieldName string, + docID string, + fieldName immutable.Option[string], 
heads []cid.Cid, -) (bool, error) { - if encryption.ShouldEncryptField(ctx, fieldName) { - return true, nil +) (*coreblock.Encryption, cidlink.Link, error) { + // if new encryption was requested by the user + if encryption.ShouldEncryptDocField(ctx, fieldName) { + encBlock := &coreblock.Encryption{DocID: []byte(docID)} + if encryption.ShouldEncryptIndividualField(ctx, fieldName) { + f := fieldName.Value() + encBlock.FieldName = &f + } + encryptor := encryption.GetEncryptorFromContext(ctx) + if encryptor != nil { + encKey, err := encryptor.GetOrGenerateEncryptionKey(docID, fieldName) + if err != nil { + return nil, cidlink.Link{}, err + } + if len(encKey) > 0 { + encBlock.Key = encKey + } + + link, err := mc.putEncBlock(ctx, encBlock) + if err != nil { + return nil, cidlink.Link{}, err + } + return encBlock, link, nil + } } + // otherwise we use the same encryption as the previous block for _, headCid := range heads { - bytes, err := mc.blockstore.AsIPLDStorage().Get(ctx, headCid.KeyString()) + prevBlockBytes, err := mc.blockstore.AsIPLDStorage().Get(ctx, headCid.KeyString()) if err != nil { - return false, NewErrCouldNotFindBlock(headCid, err) + return nil, cidlink.Link{}, NewErrCouldNotFindBlock(headCid, err) } - prevBlock, err := coreblock.GetFromBytes(bytes) + prevBlock, err := coreblock.GetFromBytes(prevBlockBytes) if err != nil { - return false, err + return nil, cidlink.Link{}, err } - if prevBlock.IsEncrypted != nil && *prevBlock.IsEncrypted { - return true, nil + if prevBlock.Encryption != nil { + prevBlockEncBytes, err := mc.encstore.AsIPLDStorage().Get(ctx, prevBlock.Encryption.Cid.KeyString()) + if err != nil { + return nil, cidlink.Link{}, NewErrCouldNotFindBlock(headCid, err) + } + prevEncBlock, err := coreblock.GetEncryptionBlockFromBytes(prevBlockEncBytes) + if err != nil { + return nil, cidlink.Link{}, err + } + return &coreblock.Encryption{ + DocID: prevEncBlock.DocID, + FieldName: prevEncBlock.FieldName, + Key: prevEncBlock.Key, + }, 
*prevBlock.Encryption, nil } } - return false, nil + return nil, cidlink.Link{}, nil } -func encryptBlock(ctx context.Context, block *coreblock.Block) (*coreblock.Block, error) { +func encryptBlock( + ctx context.Context, + block *coreblock.Block, + encBlock *coreblock.Encryption, +) (*coreblock.Block, error) { + if block.Delta.IsComposite() { + return block, nil + } + clonedCRDT := block.Delta.Clone() - bytes, err := encryption.EncryptDoc(ctx, string(clonedCRDT.GetDocID()), - clonedCRDT.GetFieldName(), clonedCRDT.GetData()) + _, encryptor := encryption.EnsureContextWithEncryptor(ctx) + bytes, err := encryptor.Encrypt(clonedCRDT.GetData(), encBlock.Key) if err != nil { return nil, err } clonedCRDT.SetData(bytes) - isEncrypted := true - return &coreblock.Block{Delta: clonedCRDT, Links: block.Links, IsEncrypted: &isEncrypted}, nil + return &coreblock.Block{Delta: clonedCRDT, Links: block.Links}, nil } // ProcessBlock merges the delta CRDT and updates the state accordingly. -// If onlyHeads is true, it will skip merging and update only the heads. 
func (mc *MerkleClock) ProcessBlock( ctx context.Context, block *coreblock.Block, blockLink cidlink.Link, - onlyHeads bool, ) error { - if !onlyHeads { - err := mc.crdt.Merge(ctx, block.Delta.GetDelta()) - if err != nil { - return NewErrMergingDelta(blockLink.Cid, err) - } + err := mc.crdt.Merge(ctx, block.Delta.GetDelta()) + if err != nil { + return NewErrMergingDelta(blockLink.Cid, err) } return mc.updateHeads(ctx, block, blockLink) diff --git a/internal/merkle/clock/clock_test.go b/internal/merkle/clock/clock_test.go index 7effc02fef..fe008971e4 100644 --- a/internal/merkle/clock/clock_test.go +++ b/internal/merkle/clock/clock_test.go @@ -37,7 +37,8 @@ func newTestMerkleClock() *MerkleClock { return NewMerkleClock( multistore.Headstore(), multistore.Blockstore(), - core.HeadStoreKey{DocID: request.DocIDArgName, FieldId: "1"}, + multistore.Encstore(), + core.HeadStoreKey{DocID: request.DocIDArgName, FieldID: "1"}, reg, ) } @@ -46,7 +47,7 @@ func TestNewMerkleClock(t *testing.T) { s := newDS() multistore := datastore.MultiStoreFrom(s) reg := crdt.NewLWWRegister(multistore.Rootstore(), core.CollectionSchemaVersionKey{}, core.DataStoreKey{}, "") - clk := NewMerkleClock(multistore.Headstore(), multistore.Blockstore(), core.HeadStoreKey{}, reg) + clk := NewMerkleClock(multistore.Headstore(), multistore.Blockstore(), multistore.Encstore(), core.HeadStoreKey{}, reg) if clk.headstore != multistore.Headstore() { t.Error("MerkleClock store not correctly set") diff --git a/internal/merkle/clock/errors.go b/internal/merkle/clock/errors.go index 9903f777a9..a20ce30731 100644 --- a/internal/merkle/clock/errors.go +++ b/internal/merkle/clock/errors.go @@ -26,6 +26,7 @@ const ( errReplacingHead = "error replacing head" errCouldNotFindBlock = "error checking for known block " errFailedToGetNextQResult = "failed to get next query result" + errCouldNotGetEncKey = "could not get encryption key" ) var ( @@ -39,6 +40,7 @@ var ( ErrCouldNotFindBlock = errors.New(errCouldNotFindBlock) 
ErrFailedToGetNextQResult = errors.New(errFailedToGetNextQResult) ErrDecodingHeight = errors.New("error decoding height") + ErrCouldNotGetEncKey = errors.New(errCouldNotGetEncKey) ) func NewErrCreatingBlock(inner error) error { diff --git a/internal/merkle/clock/heads_test.go b/internal/merkle/clock/heads_test.go index 94680569a8..0eb7acdd0e 100644 --- a/internal/merkle/clock/heads_test.go +++ b/internal/merkle/clock/heads_test.go @@ -45,7 +45,7 @@ func newHeadSet() *heads { return NewHeadSet( datastore.AsDSReaderWriter(s), - core.HeadStoreKey{}.WithDocID("myDocID").WithFieldId("1"), + core.HeadStoreKey{}.WithDocID("myDocID").WithFieldID("1"), ) } diff --git a/internal/merkle/crdt/composite.go b/internal/merkle/crdt/composite.go index 26ab4134e5..f8211b9f0a 100644 --- a/internal/merkle/crdt/composite.go +++ b/internal/merkle/crdt/composite.go @@ -44,7 +44,8 @@ func NewMerkleCompositeDAG( fieldName, ) - clock := clock.NewMerkleClock(store.Headstore(), store.Blockstore(), key.ToHeadStoreKey(), compositeDag) + clock := clock.NewMerkleClock(store.Headstore(), store.Blockstore(), store.Encstore(), + key.ToHeadStoreKey(), compositeDag) base := &baseMerkleCRDT{clock: clock, crdt: compositeDag} return &MerkleCompositeDAG{ diff --git a/internal/merkle/crdt/counter.go b/internal/merkle/crdt/counter.go index 1ff6874b08..21b26785b6 100644 --- a/internal/merkle/crdt/counter.go +++ b/internal/merkle/crdt/counter.go @@ -39,7 +39,7 @@ func NewMerkleCounter( kind client.ScalarKind, ) *MerkleCounter { register := crdt.NewCounter(store.Datastore(), schemaVersionKey, key, fieldName, allowDecrement, kind) - clk := clock.NewMerkleClock(store.Headstore(), store.Blockstore(), key.ToHeadStoreKey(), register) + clk := clock.NewMerkleClock(store.Headstore(), store.Blockstore(), store.Encstore(), key.ToHeadStoreKey(), register) base := &baseMerkleCRDT{clock: clk, crdt: register} return &MerkleCounter{ baseMerkleCRDT: base, diff --git a/internal/merkle/crdt/lwwreg.go 
b/internal/merkle/crdt/lwwreg.go index 11e73089bf..00c70dc4a9 100644 --- a/internal/merkle/crdt/lwwreg.go +++ b/internal/merkle/crdt/lwwreg.go @@ -37,7 +37,7 @@ func NewMerkleLWWRegister( fieldName string, ) *MerkleLWWRegister { register := corecrdt.NewLWWRegister(store.Datastore(), schemaVersionKey, key, fieldName) - clk := clock.NewMerkleClock(store.Headstore(), store.Blockstore(), key.ToHeadStoreKey(), register) + clk := clock.NewMerkleClock(store.Headstore(), store.Blockstore(), store.Encstore(), key.ToHeadStoreKey(), register) base := &baseMerkleCRDT{clock: clk, crdt: register} return &MerkleLWWRegister{ baseMerkleCRDT: base, diff --git a/internal/merkle/crdt/merklecrdt.go b/internal/merkle/crdt/merklecrdt.go index c7733be778..457ba0f200 100644 --- a/internal/merkle/crdt/merklecrdt.go +++ b/internal/merkle/crdt/merklecrdt.go @@ -27,6 +27,7 @@ import ( type Stores interface { Datastore() datastore.DSReaderWriter Blockstore() datastore.Blockstore + Encstore() datastore.Blockstore Headstore() datastore.DSReaderWriter } @@ -48,9 +49,7 @@ type MerkleClock interface { links ...coreblock.DAGLink, ) (cidlink.Link, []byte, error) // ProcessBlock processes a block and updates the CRDT state. - // The bool argument indicates whether only heads need to be updated. It is needed in case - // merge should be skipped for example if the block is encrypted. 
- ProcessBlock(context.Context, *coreblock.Block, cidlink.Link, bool) error + ProcessBlock(ctx context.Context, block *coreblock.Block, cid cidlink.Link) error } // baseMerkleCRDT handles the MerkleCRDT overhead functions that aren't CRDT specific like the mutations and state diff --git a/internal/merkle/crdt/merklecrdt_test.go b/internal/merkle/crdt/merklecrdt_test.go index 29482b28bf..74f4814ca3 100644 --- a/internal/merkle/crdt/merklecrdt_test.go +++ b/internal/merkle/crdt/merklecrdt_test.go @@ -32,7 +32,7 @@ func newTestBaseMerkleCRDT() (*baseMerkleCRDT, datastore.DSReaderWriter) { multistore := datastore.MultiStoreFrom(s) reg := crdt.NewLWWRegister(multistore.Datastore(), core.CollectionSchemaVersionKey{}, core.DataStoreKey{}, "") - clk := clock.NewMerkleClock(multistore.Headstore(), multistore.Blockstore(), core.HeadStoreKey{}, reg) + clk := clock.NewMerkleClock(multistore.Headstore(), multistore.Blockstore(), multistore.Encstore(), core.HeadStoreKey{}, reg) return &baseMerkleCRDT{clock: clk, crdt: reg}, multistore.Rootstore() } diff --git a/internal/planner/commit.go b/internal/planner/commit.go index 1e6a1f7b92..76825afe15 100644 --- a/internal/planner/commit.go +++ b/internal/planner/commit.go @@ -73,7 +73,7 @@ func (n *dagScanNode) Init() error { if n.commitSelect.FieldID.HasValue() { field := n.commitSelect.FieldID.Value() - dsKey = dsKey.WithFieldId(field) + dsKey = dsKey.WithFieldID(field) } n.spans = core.NewSpans(core.NewSpan(dsKey, dsKey.PrefixEnd())) @@ -104,16 +104,16 @@ func (n *dagScanNode) Spans(spans core.Spans) { } copy(headSetSpans.Value, spans.Value) - var fieldId string + var fieldID string if n.commitSelect.FieldID.HasValue() { - fieldId = n.commitSelect.FieldID.Value() + fieldID = n.commitSelect.FieldID.Value() } else { - fieldId = core.COMPOSITE_NAMESPACE + fieldID = core.COMPOSITE_NAMESPACE } for i, span := range headSetSpans.Value { - if span.Start().FieldID != fieldId { - headSetSpans.Value[i] = 
core.NewSpan(span.Start().WithFieldId(fieldId), core.DataStoreKey{}) + if span.Start().FieldID != fieldID { + headSetSpans.Value[i] = core.NewSpan(span.Start().WithFieldID(fieldID), core.DataStoreKey{}) } } diff --git a/net/client.go b/net/client.go index 77eb28d4d6..9d11a968d4 100644 --- a/net/client.go +++ b/net/client.go @@ -37,7 +37,7 @@ func (s *server) pushLog(evt event.Update, pid peer.ID) error { Cid: evt.Cid.Bytes(), SchemaRoot: []byte(evt.SchemaRoot), Creator: s.peer.host.ID().String(), - Log: &pb.Document_Log{ + Log: &pb.Log{ Block: evt.Block, }, } diff --git a/net/dialer_test.go b/net/dialer_test.go index 64060f2660..4ed8bcf68b 100644 --- a/net/dialer_test.go +++ b/net/dialer_test.go @@ -27,6 +27,7 @@ func TestDial_WithConnectedPeer_NoError(t *testing.T) { n1, err := NewPeer( ctx, db1.Blockstore(), + db1.Encstore(), db1.Events(), WithListenAddresses("/ip4/127.0.0.1/tcp/0"), ) @@ -35,6 +36,7 @@ func TestDial_WithConnectedPeer_NoError(t *testing.T) { n2, err := NewPeer( ctx, db2.Blockstore(), + db1.Encstore(), db2.Events(), WithListenAddresses("/ip4/127.0.0.1/tcp/0"), ) @@ -57,6 +59,7 @@ func TestDial_WithConnectedPeerAndSecondConnection_NoError(t *testing.T) { n1, err := NewPeer( ctx, db1.Blockstore(), + db1.Encstore(), db1.Events(), WithListenAddresses("/ip4/127.0.0.1/tcp/0"), ) @@ -65,6 +68,7 @@ func TestDial_WithConnectedPeerAndSecondConnection_NoError(t *testing.T) { n2, err := NewPeer( ctx, db2.Blockstore(), + db1.Encstore(), db2.Events(), WithListenAddresses("/ip4/127.0.0.1/tcp/0"), ) @@ -90,6 +94,7 @@ func TestDial_WithConnectedPeerAndSecondConnectionWithConnectionShutdown_Closing n1, err := NewPeer( ctx, db1.Blockstore(), + db1.Encstore(), db1.Events(), WithListenAddresses("/ip4/127.0.0.1/tcp/0"), ) @@ -98,6 +103,7 @@ func TestDial_WithConnectedPeerAndSecondConnectionWithConnectionShutdown_Closing n2, err := NewPeer( ctx, db2.Blockstore(), + db1.Encstore(), db2.Events(), WithListenAddresses("/ip4/127.0.0.1/tcp/0"), ) diff --git a/net/errors.go 
b/net/errors.go index 615f1088ef..3a21c8e5c1 100644 --- a/net/errors.go +++ b/net/errors.go @@ -22,6 +22,9 @@ const ( errPublishingToDocIDTopic = "can't publish log %s for docID %s" errPublishingToSchemaTopic = "can't publish log %s for schema %s" errCheckingForExistingBlock = "failed to check for existing block" + errRequestingEncryptionKeys = "failed to request encryption keys with %v" + errTopicAlreadyExist = "topic with name \"%s\" already exists" + errTopicDoesNotExist = "topic with name \"%s\" does not exists" ) var ( @@ -49,6 +52,10 @@ func NewErrPublishingToSchemaTopic(inner error, cid, docID string, kv ...errors. return errors.Wrap(fmt.Sprintf(errPublishingToSchemaTopic, cid, docID), inner, kv...) } -func NewErrCheckingForExistingBlock(inner error, cid string) error { - return errors.Wrap(errCheckingForExistingBlock, inner, errors.NewKV("cid", cid)) +func NewErrTopicAlreadyExist(topic string) error { + return errors.New(fmt.Sprintf(errTopicAlreadyExist, topic)) +} + +func NewErrTopicDoesNotExist(topic string) error { + return errors.New(fmt.Sprintf(errTopicDoesNotExist, topic)) } diff --git a/net/pb/Makefile b/net/pb/Makefile index 233665c334..30b0e92dfa 100644 --- a/net/pb/Makefile +++ b/net/pb/Makefile @@ -1,8 +1,13 @@ PB = $(wildcard *.proto) GO = $(PB:.proto=.pb.go) +PROTOC_GEN_GO := $(shell which protoc-gen-go) +PROTOC_GEN_GO_GRPC := $(shell which protoc-gen-go-grpc) +PROTOC_GEN_GO_VTPROTO := $(shell which protoc-gen-go-vtproto) + all: $(GO) +.PHONY: deps deps: go install google.golang.org/protobuf/cmd/protoc-gen-go@latest go install google.golang.org/grpc/cmd/protoc-gen-go-grpc@latest @@ -10,14 +15,14 @@ deps: %.pb.go: %.proto protoc \ - --go_out=. --plugin protoc-gen-go="${GOBIN}/protoc-gen-go" \ - --go-grpc_out=. --plugin protoc-gen-go-grpc="${GOBIN}/protoc-gen-go-grpc" \ - --go-vtproto_out=. --plugin protoc-gen-go-vtproto="${GOBIN}/protoc-gen-go-vtproto" \ + -I. \ + --go_out=. --plugin protoc-gen-go="$(PROTOC_GEN_GO)" \ + --go-grpc_out=. 
--plugin protoc-gen-go-grpc="$(PROTOC_GEN_GO_GRPC)" \ + --go-vtproto_out=. --plugin protoc-gen-go-vtproto="$(PROTOC_GEN_GO_VTPROTO)" \ --go-vtproto_opt=features=marshal+unmarshal+size \ - $< + $< # This line specifies the input file +.PHONY: clean clean: rm -f *.pb.go rm -f *pb_test.go - -.PHONY: clean \ No newline at end of file diff --git a/net/pb/net.pb.go b/net/pb/net.pb.go index a9b5a2162d..dbac6829d0 100644 --- a/net/pb/net.pb.go +++ b/net/pb/net.pb.go @@ -1,7 +1,7 @@ // Code generated by protoc-gen-go. DO NOT EDIT. // versions: -// protoc-gen-go v1.31.0 -// protoc v4.25.1 +// protoc-gen-go v1.34.2 +// protoc v5.27.1 // source: net.proto package net_pb @@ -21,19 +21,17 @@ const ( ) // Log represents a thread log. -type Document struct { +type Log struct { state protoimpl.MessageState sizeCache protoimpl.SizeCache unknownFields protoimpl.UnknownFields - // ID of the document. - DocID []byte `protobuf:"bytes,1,opt,name=docID,proto3" json:"docID,omitempty"` - // head of the log. - Head []byte `protobuf:"bytes,4,opt,name=head,proto3" json:"head,omitempty"` + // block is the top-level node's raw data as an ipld.Block. 
+ Block []byte `protobuf:"bytes,1,opt,name=block,proto3" json:"block,omitempty"` } -func (x *Document) Reset() { - *x = Document{} +func (x *Log) Reset() { + *x = Log{} if protoimpl.UnsafeEnabled { mi := &file_net_proto_msgTypes[0] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) @@ -41,13 +39,13 @@ func (x *Document) Reset() { } } -func (x *Document) String() string { +func (x *Log) String() string { return protoimpl.X.MessageStringOf(x) } -func (*Document) ProtoMessage() {} +func (*Log) ProtoMessage() {} -func (x *Document) ProtoReflect() protoreflect.Message { +func (x *Log) ProtoReflect() protoreflect.Message { mi := &file_net_proto_msgTypes[0] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) @@ -59,21 +57,14 @@ func (x *Document) ProtoReflect() protoreflect.Message { return mi.MessageOf(x) } -// Deprecated: Use Document.ProtoReflect.Descriptor instead. -func (*Document) Descriptor() ([]byte, []int) { +// Deprecated: Use Log.ProtoReflect.Descriptor instead. +func (*Log) Descriptor() ([]byte, []int) { return file_net_proto_rawDescGZIP(), []int{0} } -func (x *Document) GetDocID() []byte { - if x != nil { - return x.DocID - } - return nil -} - -func (x *Document) GetHead() []byte { +func (x *Log) GetBlock() []byte { if x != nil { - return x.Head + return x.Block } return nil } @@ -353,14 +344,21 @@ func (x *PushLogRequest) GetBody() *PushLogRequest_Body { return nil } -type GetHeadLogRequest struct { +// FetchEncryptionKeyRequest is a request to receive a doc encryption key +// from a peer that holds it. +type FetchEncryptionKeyRequest struct { state protoimpl.MessageState sizeCache protoimpl.SizeCache unknownFields protoimpl.UnknownFields + + // links is the list of cid links of the blocks containing encryption keys. 
+ Links [][]byte `protobuf:"bytes,1,rep,name=links,proto3" json:"links,omitempty"` + // ephemeralPublicKey is an ephemeral public of the requesting peer for deriving shared secret + EphemeralPublicKey []byte `protobuf:"bytes,2,opt,name=ephemeralPublicKey,proto3" json:"ephemeralPublicKey,omitempty"` } -func (x *GetHeadLogRequest) Reset() { - *x = GetHeadLogRequest{} +func (x *FetchEncryptionKeyRequest) Reset() { + *x = FetchEncryptionKeyRequest{} if protoimpl.UnsafeEnabled { mi := &file_net_proto_msgTypes[8] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) @@ -368,13 +366,13 @@ func (x *GetHeadLogRequest) Reset() { } } -func (x *GetHeadLogRequest) String() string { +func (x *FetchEncryptionKeyRequest) String() string { return protoimpl.X.MessageStringOf(x) } -func (*GetHeadLogRequest) ProtoMessage() {} +func (*FetchEncryptionKeyRequest) ProtoMessage() {} -func (x *GetHeadLogRequest) ProtoReflect() protoreflect.Message { +func (x *FetchEncryptionKeyRequest) ProtoReflect() protoreflect.Message { mi := &file_net_proto_msgTypes[8] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) @@ -386,19 +384,43 @@ func (x *GetHeadLogRequest) ProtoReflect() protoreflect.Message { return mi.MessageOf(x) } -// Deprecated: Use GetHeadLogRequest.ProtoReflect.Descriptor instead. -func (*GetHeadLogRequest) Descriptor() ([]byte, []int) { +// Deprecated: Use FetchEncryptionKeyRequest.ProtoReflect.Descriptor instead. +func (*FetchEncryptionKeyRequest) Descriptor() ([]byte, []int) { return file_net_proto_rawDescGZIP(), []int{8} } -type PushLogReply struct { +func (x *FetchEncryptionKeyRequest) GetLinks() [][]byte { + if x != nil { + return x.Links + } + return nil +} + +func (x *FetchEncryptionKeyRequest) GetEphemeralPublicKey() []byte { + if x != nil { + return x.EphemeralPublicKey + } + return nil +} + +// FetchEncryptionKeyReply is a response to FetchEncryptionKeyRequest request +// by a peer that holds the requested doc encryption key. 
+type FetchEncryptionKeyReply struct { state protoimpl.MessageState sizeCache protoimpl.SizeCache unknownFields protoimpl.UnknownFields + + // links is the list of cid links of the blocks containing encryption keys. + Links [][]byte `protobuf:"bytes,1,rep,name=links,proto3" json:"links,omitempty"` + // blocks is the list of blocks containing encryption keys. The order of blocks should match the order of links. + // Every block is encrypted and contains a nonce. + Blocks [][]byte `protobuf:"bytes,2,rep,name=blocks,proto3" json:"blocks,omitempty"` + // ephemeralPublicKey is an ephemeral public of the responding peer for deriving shared secret + EphemeralPublicKey []byte `protobuf:"bytes,3,opt,name=ephemeralPublicKey,proto3" json:"ephemeralPublicKey,omitempty"` } -func (x *PushLogReply) Reset() { - *x = PushLogReply{} +func (x *FetchEncryptionKeyReply) Reset() { + *x = FetchEncryptionKeyReply{} if protoimpl.UnsafeEnabled { mi := &file_net_proto_msgTypes[9] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) @@ -406,13 +428,13 @@ func (x *PushLogReply) Reset() { } } -func (x *PushLogReply) String() string { +func (x *FetchEncryptionKeyReply) String() string { return protoimpl.X.MessageStringOf(x) } -func (*PushLogReply) ProtoMessage() {} +func (*FetchEncryptionKeyReply) ProtoMessage() {} -func (x *PushLogReply) ProtoReflect() protoreflect.Message { +func (x *FetchEncryptionKeyReply) ProtoReflect() protoreflect.Message { mi := &file_net_proto_msgTypes[9] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) @@ -424,19 +446,40 @@ func (x *PushLogReply) ProtoReflect() protoreflect.Message { return mi.MessageOf(x) } -// Deprecated: Use PushLogReply.ProtoReflect.Descriptor instead. -func (*PushLogReply) Descriptor() ([]byte, []int) { +// Deprecated: Use FetchEncryptionKeyReply.ProtoReflect.Descriptor instead. 
+func (*FetchEncryptionKeyReply) Descriptor() ([]byte, []int) { return file_net_proto_rawDescGZIP(), []int{9} } -type GetHeadLogReply struct { +func (x *FetchEncryptionKeyReply) GetLinks() [][]byte { + if x != nil { + return x.Links + } + return nil +} + +func (x *FetchEncryptionKeyReply) GetBlocks() [][]byte { + if x != nil { + return x.Blocks + } + return nil +} + +func (x *FetchEncryptionKeyReply) GetEphemeralPublicKey() []byte { + if x != nil { + return x.EphemeralPublicKey + } + return nil +} + +type GetHeadLogRequest struct { state protoimpl.MessageState sizeCache protoimpl.SizeCache unknownFields protoimpl.UnknownFields } -func (x *GetHeadLogReply) Reset() { - *x = GetHeadLogReply{} +func (x *GetHeadLogRequest) Reset() { + *x = GetHeadLogRequest{} if protoimpl.UnsafeEnabled { mi := &file_net_proto_msgTypes[10] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) @@ -444,13 +487,13 @@ func (x *GetHeadLogReply) Reset() { } } -func (x *GetHeadLogReply) String() string { +func (x *GetHeadLogRequest) String() string { return protoimpl.X.MessageStringOf(x) } -func (*GetHeadLogReply) ProtoMessage() {} +func (*GetHeadLogRequest) ProtoMessage() {} -func (x *GetHeadLogReply) ProtoReflect() protoreflect.Message { +func (x *GetHeadLogRequest) ProtoReflect() protoreflect.Message { mi := &file_net_proto_msgTypes[10] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) @@ -462,23 +505,19 @@ func (x *GetHeadLogReply) ProtoReflect() protoreflect.Message { return mi.MessageOf(x) } -// Deprecated: Use GetHeadLogReply.ProtoReflect.Descriptor instead. -func (*GetHeadLogReply) Descriptor() ([]byte, []int) { +// Deprecated: Use GetHeadLogRequest.ProtoReflect.Descriptor instead. +func (*GetHeadLogRequest) Descriptor() ([]byte, []int) { return file_net_proto_rawDescGZIP(), []int{10} } -// Record is a thread record containing link data. 
-type Document_Log struct { +type PushLogReply struct { state protoimpl.MessageState sizeCache protoimpl.SizeCache unknownFields protoimpl.UnknownFields - - // block is the top-level node's raw data as an ipld.Block. - Block []byte `protobuf:"bytes,1,opt,name=block,proto3" json:"block,omitempty"` } -func (x *Document_Log) Reset() { - *x = Document_Log{} +func (x *PushLogReply) Reset() { + *x = PushLogReply{} if protoimpl.UnsafeEnabled { mi := &file_net_proto_msgTypes[11] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) @@ -486,13 +525,13 @@ func (x *Document_Log) Reset() { } } -func (x *Document_Log) String() string { +func (x *PushLogReply) String() string { return protoimpl.X.MessageStringOf(x) } -func (*Document_Log) ProtoMessage() {} +func (*PushLogReply) ProtoMessage() {} -func (x *Document_Log) ProtoReflect() protoreflect.Message { +func (x *PushLogReply) ProtoReflect() protoreflect.Message { mi := &file_net_proto_msgTypes[11] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) @@ -504,16 +543,47 @@ func (x *Document_Log) ProtoReflect() protoreflect.Message { return mi.MessageOf(x) } -// Deprecated: Use Document_Log.ProtoReflect.Descriptor instead. -func (*Document_Log) Descriptor() ([]byte, []int) { - return file_net_proto_rawDescGZIP(), []int{0, 0} +// Deprecated: Use PushLogReply.ProtoReflect.Descriptor instead. 
+func (*PushLogReply) Descriptor() ([]byte, []int) { + return file_net_proto_rawDescGZIP(), []int{11} } -func (x *Document_Log) GetBlock() []byte { - if x != nil { - return x.Block +type GetHeadLogReply struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields +} + +func (x *GetHeadLogReply) Reset() { + *x = GetHeadLogReply{} + if protoimpl.UnsafeEnabled { + mi := &file_net_proto_msgTypes[12] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) } - return nil +} + +func (x *GetHeadLogReply) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*GetHeadLogReply) ProtoMessage() {} + +func (x *GetHeadLogReply) ProtoReflect() protoreflect.Message { + mi := &file_net_proto_msgTypes[12] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use GetHeadLogReply.ProtoReflect.Descriptor instead. +func (*GetHeadLogReply) Descriptor() ([]byte, []int) { + return file_net_proto_rawDescGZIP(), []int{12} } type PushLogRequest_Body struct { @@ -530,13 +600,13 @@ type PushLogRequest_Body struct { // creator is the PeerID of the peer that created the log. Creator string `protobuf:"bytes,4,opt,name=creator,proto3" json:"creator,omitempty"` // log hold the block that represent version of the document. 
- Log *Document_Log `protobuf:"bytes,6,opt,name=log,proto3" json:"log,omitempty"` + Log *Log `protobuf:"bytes,6,opt,name=log,proto3" json:"log,omitempty"` } func (x *PushLogRequest_Body) Reset() { *x = PushLogRequest_Body{} if protoimpl.UnsafeEnabled { - mi := &file_net_proto_msgTypes[12] + mi := &file_net_proto_msgTypes[13] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -549,7 +619,7 @@ func (x *PushLogRequest_Body) String() string { func (*PushLogRequest_Body) ProtoMessage() {} func (x *PushLogRequest_Body) ProtoReflect() protoreflect.Message { - mi := &file_net_proto_msgTypes[12] + mi := &file_net_proto_msgTypes[13] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -593,7 +663,7 @@ func (x *PushLogRequest_Body) GetCreator() string { return "" } -func (x *PushLogRequest_Body) GetLog() *Document_Log { +func (x *PushLogRequest_Body) GetLog() *Log { if x != nil { return x.Log } @@ -604,59 +674,68 @@ var File_net_proto protoreflect.FileDescriptor var file_net_proto_rawDesc = []byte{ 0x0a, 0x09, 0x6e, 0x65, 0x74, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x12, 0x06, 0x6e, 0x65, 0x74, - 0x2e, 0x70, 0x62, 0x22, 0x51, 0x0a, 0x08, 0x44, 0x6f, 0x63, 0x75, 0x6d, 0x65, 0x6e, 0x74, 0x12, - 0x14, 0x0a, 0x05, 0x64, 0x6f, 0x63, 0x49, 0x44, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x05, - 0x64, 0x6f, 0x63, 0x49, 0x44, 0x12, 0x12, 0x0a, 0x04, 0x68, 0x65, 0x61, 0x64, 0x18, 0x04, 0x20, - 0x01, 0x28, 0x0c, 0x52, 0x04, 0x68, 0x65, 0x61, 0x64, 0x1a, 0x1b, 0x0a, 0x03, 0x4c, 0x6f, 0x67, - 0x12, 0x14, 0x0a, 0x05, 0x62, 0x6c, 0x6f, 0x63, 0x6b, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0c, 0x52, - 0x05, 0x62, 0x6c, 0x6f, 0x63, 0x6b, 0x22, 0x14, 0x0a, 0x12, 0x47, 0x65, 0x74, 0x44, 0x6f, 0x63, - 0x47, 0x72, 0x61, 0x70, 0x68, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x22, 0x12, 0x0a, 0x10, - 0x47, 0x65, 0x74, 0x44, 0x6f, 0x63, 0x47, 0x72, 0x61, 0x70, 0x68, 0x52, 0x65, 0x70, 0x6c, 0x79, - 
0x22, 0x15, 0x0a, 0x13, 0x50, 0x75, 0x73, 0x68, 0x44, 0x6f, 0x63, 0x47, 0x72, 0x61, 0x70, 0x68, - 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x22, 0x13, 0x0a, 0x11, 0x50, 0x75, 0x73, 0x68, 0x44, - 0x6f, 0x63, 0x47, 0x72, 0x61, 0x70, 0x68, 0x52, 0x65, 0x70, 0x6c, 0x79, 0x22, 0x0f, 0x0a, 0x0d, - 0x47, 0x65, 0x74, 0x4c, 0x6f, 0x67, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x22, 0x0d, 0x0a, - 0x0b, 0x47, 0x65, 0x74, 0x4c, 0x6f, 0x67, 0x52, 0x65, 0x70, 0x6c, 0x79, 0x22, 0xd4, 0x01, 0x0a, - 0x0e, 0x50, 0x75, 0x73, 0x68, 0x4c, 0x6f, 0x67, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, - 0x2f, 0x0a, 0x04, 0x62, 0x6f, 0x64, 0x79, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1b, 0x2e, - 0x6e, 0x65, 0x74, 0x2e, 0x70, 0x62, 0x2e, 0x50, 0x75, 0x73, 0x68, 0x4c, 0x6f, 0x67, 0x52, 0x65, - 0x71, 0x75, 0x65, 0x73, 0x74, 0x2e, 0x42, 0x6f, 0x64, 0x79, 0x52, 0x04, 0x62, 0x6f, 0x64, 0x79, - 0x1a, 0x90, 0x01, 0x0a, 0x04, 0x42, 0x6f, 0x64, 0x79, 0x12, 0x14, 0x0a, 0x05, 0x64, 0x6f, 0x63, - 0x49, 0x44, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x05, 0x64, 0x6f, 0x63, 0x49, 0x44, 0x12, - 0x10, 0x0a, 0x03, 0x63, 0x69, 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x03, 0x63, 0x69, - 0x64, 0x12, 0x1e, 0x0a, 0x0a, 0x73, 0x63, 0x68, 0x65, 0x6d, 0x61, 0x52, 0x6f, 0x6f, 0x74, 0x18, - 0x03, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x0a, 0x73, 0x63, 0x68, 0x65, 0x6d, 0x61, 0x52, 0x6f, 0x6f, - 0x74, 0x12, 0x18, 0x0a, 0x07, 0x63, 0x72, 0x65, 0x61, 0x74, 0x6f, 0x72, 0x18, 0x04, 0x20, 0x01, - 0x28, 0x09, 0x52, 0x07, 0x63, 0x72, 0x65, 0x61, 0x74, 0x6f, 0x72, 0x12, 0x26, 0x0a, 0x03, 0x6c, - 0x6f, 0x67, 0x18, 0x06, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x14, 0x2e, 0x6e, 0x65, 0x74, 0x2e, 0x70, - 0x62, 0x2e, 0x44, 0x6f, 0x63, 0x75, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x4c, 0x6f, 0x67, 0x52, 0x03, - 0x6c, 0x6f, 0x67, 0x22, 0x13, 0x0a, 0x11, 0x47, 0x65, 0x74, 0x48, 0x65, 0x61, 0x64, 0x4c, 0x6f, - 0x67, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x22, 0x0e, 0x0a, 0x0c, 0x50, 0x75, 0x73, 0x68, - 0x4c, 0x6f, 0x67, 0x52, 0x65, 0x70, 
0x6c, 0x79, 0x22, 0x11, 0x0a, 0x0f, 0x47, 0x65, 0x74, 0x48, - 0x65, 0x61, 0x64, 0x4c, 0x6f, 0x67, 0x52, 0x65, 0x70, 0x6c, 0x79, 0x32, 0xd1, 0x02, 0x0a, 0x07, - 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x12, 0x45, 0x0a, 0x0b, 0x47, 0x65, 0x74, 0x44, 0x6f, - 0x63, 0x47, 0x72, 0x61, 0x70, 0x68, 0x12, 0x1a, 0x2e, 0x6e, 0x65, 0x74, 0x2e, 0x70, 0x62, 0x2e, - 0x47, 0x65, 0x74, 0x44, 0x6f, 0x63, 0x47, 0x72, 0x61, 0x70, 0x68, 0x52, 0x65, 0x71, 0x75, 0x65, - 0x73, 0x74, 0x1a, 0x18, 0x2e, 0x6e, 0x65, 0x74, 0x2e, 0x70, 0x62, 0x2e, 0x47, 0x65, 0x74, 0x44, - 0x6f, 0x63, 0x47, 0x72, 0x61, 0x70, 0x68, 0x52, 0x65, 0x70, 0x6c, 0x79, 0x22, 0x00, 0x12, 0x48, - 0x0a, 0x0c, 0x50, 0x75, 0x73, 0x68, 0x44, 0x6f, 0x63, 0x47, 0x72, 0x61, 0x70, 0x68, 0x12, 0x1b, - 0x2e, 0x6e, 0x65, 0x74, 0x2e, 0x70, 0x62, 0x2e, 0x50, 0x75, 0x73, 0x68, 0x44, 0x6f, 0x63, 0x47, - 0x72, 0x61, 0x70, 0x68, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x19, 0x2e, 0x6e, 0x65, + 0x2e, 0x70, 0x62, 0x22, 0x1b, 0x0a, 0x03, 0x4c, 0x6f, 0x67, 0x12, 0x14, 0x0a, 0x05, 0x62, 0x6c, + 0x6f, 0x63, 0x6b, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x05, 0x62, 0x6c, 0x6f, 0x63, 0x6b, + 0x22, 0x14, 0x0a, 0x12, 0x47, 0x65, 0x74, 0x44, 0x6f, 0x63, 0x47, 0x72, 0x61, 0x70, 0x68, 0x52, + 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x22, 0x12, 0x0a, 0x10, 0x47, 0x65, 0x74, 0x44, 0x6f, 0x63, + 0x47, 0x72, 0x61, 0x70, 0x68, 0x52, 0x65, 0x70, 0x6c, 0x79, 0x22, 0x15, 0x0a, 0x13, 0x50, 0x75, + 0x73, 0x68, 0x44, 0x6f, 0x63, 0x47, 0x72, 0x61, 0x70, 0x68, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, + 0x74, 0x22, 0x13, 0x0a, 0x11, 0x50, 0x75, 0x73, 0x68, 0x44, 0x6f, 0x63, 0x47, 0x72, 0x61, 0x70, + 0x68, 0x52, 0x65, 0x70, 0x6c, 0x79, 0x22, 0x0f, 0x0a, 0x0d, 0x47, 0x65, 0x74, 0x4c, 0x6f, 0x67, + 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x22, 0x0d, 0x0a, 0x0b, 0x47, 0x65, 0x74, 0x4c, 0x6f, + 0x67, 0x52, 0x65, 0x70, 0x6c, 0x79, 0x22, 0xcb, 0x01, 0x0a, 0x0e, 0x50, 0x75, 0x73, 0x68, 0x4c, + 0x6f, 0x67, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x2f, 0x0a, 
0x04, 0x62, 0x6f, 0x64, + 0x79, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1b, 0x2e, 0x6e, 0x65, 0x74, 0x2e, 0x70, 0x62, + 0x2e, 0x50, 0x75, 0x73, 0x68, 0x4c, 0x6f, 0x67, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x2e, + 0x42, 0x6f, 0x64, 0x79, 0x52, 0x04, 0x62, 0x6f, 0x64, 0x79, 0x1a, 0x87, 0x01, 0x0a, 0x04, 0x42, + 0x6f, 0x64, 0x79, 0x12, 0x14, 0x0a, 0x05, 0x64, 0x6f, 0x63, 0x49, 0x44, 0x18, 0x01, 0x20, 0x01, + 0x28, 0x0c, 0x52, 0x05, 0x64, 0x6f, 0x63, 0x49, 0x44, 0x12, 0x10, 0x0a, 0x03, 0x63, 0x69, 0x64, + 0x18, 0x02, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x03, 0x63, 0x69, 0x64, 0x12, 0x1e, 0x0a, 0x0a, 0x73, + 0x63, 0x68, 0x65, 0x6d, 0x61, 0x52, 0x6f, 0x6f, 0x74, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0c, 0x52, + 0x0a, 0x73, 0x63, 0x68, 0x65, 0x6d, 0x61, 0x52, 0x6f, 0x6f, 0x74, 0x12, 0x18, 0x0a, 0x07, 0x63, + 0x72, 0x65, 0x61, 0x74, 0x6f, 0x72, 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, 0x07, 0x63, 0x72, + 0x65, 0x61, 0x74, 0x6f, 0x72, 0x12, 0x1d, 0x0a, 0x03, 0x6c, 0x6f, 0x67, 0x18, 0x06, 0x20, 0x01, + 0x28, 0x0b, 0x32, 0x0b, 0x2e, 0x6e, 0x65, 0x74, 0x2e, 0x70, 0x62, 0x2e, 0x4c, 0x6f, 0x67, 0x52, + 0x03, 0x6c, 0x6f, 0x67, 0x22, 0x61, 0x0a, 0x19, 0x46, 0x65, 0x74, 0x63, 0x68, 0x45, 0x6e, 0x63, + 0x72, 0x79, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x4b, 0x65, 0x79, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, + 0x74, 0x12, 0x14, 0x0a, 0x05, 0x6c, 0x69, 0x6e, 0x6b, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x0c, + 0x52, 0x05, 0x6c, 0x69, 0x6e, 0x6b, 0x73, 0x12, 0x2e, 0x0a, 0x12, 0x65, 0x70, 0x68, 0x65, 0x6d, + 0x65, 0x72, 0x61, 0x6c, 0x50, 0x75, 0x62, 0x6c, 0x69, 0x63, 0x4b, 0x65, 0x79, 0x18, 0x02, 0x20, + 0x01, 0x28, 0x0c, 0x52, 0x12, 0x65, 0x70, 0x68, 0x65, 0x6d, 0x65, 0x72, 0x61, 0x6c, 0x50, 0x75, + 0x62, 0x6c, 0x69, 0x63, 0x4b, 0x65, 0x79, 0x22, 0x77, 0x0a, 0x17, 0x46, 0x65, 0x74, 0x63, 0x68, + 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x4b, 0x65, 0x79, 0x52, 0x65, 0x70, + 0x6c, 0x79, 0x12, 0x14, 0x0a, 0x05, 0x6c, 0x69, 0x6e, 0x6b, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, + 0x0c, 0x52, 
0x05, 0x6c, 0x69, 0x6e, 0x6b, 0x73, 0x12, 0x16, 0x0a, 0x06, 0x62, 0x6c, 0x6f, 0x63, + 0x6b, 0x73, 0x18, 0x02, 0x20, 0x03, 0x28, 0x0c, 0x52, 0x06, 0x62, 0x6c, 0x6f, 0x63, 0x6b, 0x73, + 0x12, 0x2e, 0x0a, 0x12, 0x65, 0x70, 0x68, 0x65, 0x6d, 0x65, 0x72, 0x61, 0x6c, 0x50, 0x75, 0x62, + 0x6c, 0x69, 0x63, 0x4b, 0x65, 0x79, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x12, 0x65, 0x70, + 0x68, 0x65, 0x6d, 0x65, 0x72, 0x61, 0x6c, 0x50, 0x75, 0x62, 0x6c, 0x69, 0x63, 0x4b, 0x65, 0x79, + 0x22, 0x13, 0x0a, 0x11, 0x47, 0x65, 0x74, 0x48, 0x65, 0x61, 0x64, 0x4c, 0x6f, 0x67, 0x52, 0x65, + 0x71, 0x75, 0x65, 0x73, 0x74, 0x22, 0x0e, 0x0a, 0x0c, 0x50, 0x75, 0x73, 0x68, 0x4c, 0x6f, 0x67, + 0x52, 0x65, 0x70, 0x6c, 0x79, 0x22, 0x11, 0x0a, 0x0f, 0x47, 0x65, 0x74, 0x48, 0x65, 0x61, 0x64, + 0x4c, 0x6f, 0x67, 0x52, 0x65, 0x70, 0x6c, 0x79, 0x32, 0xd1, 0x02, 0x0a, 0x07, 0x53, 0x65, 0x72, + 0x76, 0x69, 0x63, 0x65, 0x12, 0x45, 0x0a, 0x0b, 0x47, 0x65, 0x74, 0x44, 0x6f, 0x63, 0x47, 0x72, + 0x61, 0x70, 0x68, 0x12, 0x1a, 0x2e, 0x6e, 0x65, 0x74, 0x2e, 0x70, 0x62, 0x2e, 0x47, 0x65, 0x74, + 0x44, 0x6f, 0x63, 0x47, 0x72, 0x61, 0x70, 0x68, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, + 0x18, 0x2e, 0x6e, 0x65, 0x74, 0x2e, 0x70, 0x62, 0x2e, 0x47, 0x65, 0x74, 0x44, 0x6f, 0x63, 0x47, + 0x72, 0x61, 0x70, 0x68, 0x52, 0x65, 0x70, 0x6c, 0x79, 0x22, 0x00, 0x12, 0x48, 0x0a, 0x0c, 0x50, + 0x75, 0x73, 0x68, 0x44, 0x6f, 0x63, 0x47, 0x72, 0x61, 0x70, 0x68, 0x12, 0x1b, 0x2e, 0x6e, 0x65, 0x74, 0x2e, 0x70, 0x62, 0x2e, 0x50, 0x75, 0x73, 0x68, 0x44, 0x6f, 0x63, 0x47, 0x72, 0x61, 0x70, - 0x68, 0x52, 0x65, 0x70, 0x6c, 0x79, 0x22, 0x00, 0x12, 0x36, 0x0a, 0x06, 0x47, 0x65, 0x74, 0x4c, - 0x6f, 0x67, 0x12, 0x15, 0x2e, 0x6e, 0x65, 0x74, 0x2e, 0x70, 0x62, 0x2e, 0x47, 0x65, 0x74, 0x4c, - 0x6f, 0x67, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x13, 0x2e, 0x6e, 0x65, 0x74, 0x2e, - 0x70, 0x62, 0x2e, 0x47, 0x65, 0x74, 0x4c, 0x6f, 0x67, 0x52, 0x65, 0x70, 0x6c, 0x79, 0x22, 0x00, - 0x12, 0x39, 0x0a, 0x07, 0x50, 0x75, 0x73, 0x68, 0x4c, 
0x6f, 0x67, 0x12, 0x16, 0x2e, 0x6e, 0x65, - 0x74, 0x2e, 0x70, 0x62, 0x2e, 0x50, 0x75, 0x73, 0x68, 0x4c, 0x6f, 0x67, 0x52, 0x65, 0x71, 0x75, - 0x65, 0x73, 0x74, 0x1a, 0x14, 0x2e, 0x6e, 0x65, 0x74, 0x2e, 0x70, 0x62, 0x2e, 0x50, 0x75, 0x73, - 0x68, 0x4c, 0x6f, 0x67, 0x52, 0x65, 0x70, 0x6c, 0x79, 0x22, 0x00, 0x12, 0x42, 0x0a, 0x0a, 0x47, - 0x65, 0x74, 0x48, 0x65, 0x61, 0x64, 0x4c, 0x6f, 0x67, 0x12, 0x19, 0x2e, 0x6e, 0x65, 0x74, 0x2e, - 0x70, 0x62, 0x2e, 0x47, 0x65, 0x74, 0x48, 0x65, 0x61, 0x64, 0x4c, 0x6f, 0x67, 0x52, 0x65, 0x71, - 0x75, 0x65, 0x73, 0x74, 0x1a, 0x17, 0x2e, 0x6e, 0x65, 0x74, 0x2e, 0x70, 0x62, 0x2e, 0x47, 0x65, - 0x74, 0x48, 0x65, 0x61, 0x64, 0x4c, 0x6f, 0x67, 0x52, 0x65, 0x70, 0x6c, 0x79, 0x22, 0x00, 0x42, - 0x0a, 0x5a, 0x08, 0x2f, 0x3b, 0x6e, 0x65, 0x74, 0x5f, 0x70, 0x62, 0x62, 0x06, 0x70, 0x72, 0x6f, - 0x74, 0x6f, 0x33, + 0x68, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x19, 0x2e, 0x6e, 0x65, 0x74, 0x2e, 0x70, + 0x62, 0x2e, 0x50, 0x75, 0x73, 0x68, 0x44, 0x6f, 0x63, 0x47, 0x72, 0x61, 0x70, 0x68, 0x52, 0x65, + 0x70, 0x6c, 0x79, 0x22, 0x00, 0x12, 0x36, 0x0a, 0x06, 0x47, 0x65, 0x74, 0x4c, 0x6f, 0x67, 0x12, + 0x15, 0x2e, 0x6e, 0x65, 0x74, 0x2e, 0x70, 0x62, 0x2e, 0x47, 0x65, 0x74, 0x4c, 0x6f, 0x67, 0x52, + 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x13, 0x2e, 0x6e, 0x65, 0x74, 0x2e, 0x70, 0x62, 0x2e, + 0x47, 0x65, 0x74, 0x4c, 0x6f, 0x67, 0x52, 0x65, 0x70, 0x6c, 0x79, 0x22, 0x00, 0x12, 0x39, 0x0a, + 0x07, 0x50, 0x75, 0x73, 0x68, 0x4c, 0x6f, 0x67, 0x12, 0x16, 0x2e, 0x6e, 0x65, 0x74, 0x2e, 0x70, + 0x62, 0x2e, 0x50, 0x75, 0x73, 0x68, 0x4c, 0x6f, 0x67, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, + 0x1a, 0x14, 0x2e, 0x6e, 0x65, 0x74, 0x2e, 0x70, 0x62, 0x2e, 0x50, 0x75, 0x73, 0x68, 0x4c, 0x6f, + 0x67, 0x52, 0x65, 0x70, 0x6c, 0x79, 0x22, 0x00, 0x12, 0x42, 0x0a, 0x0a, 0x47, 0x65, 0x74, 0x48, + 0x65, 0x61, 0x64, 0x4c, 0x6f, 0x67, 0x12, 0x19, 0x2e, 0x6e, 0x65, 0x74, 0x2e, 0x70, 0x62, 0x2e, + 0x47, 0x65, 0x74, 0x48, 0x65, 0x61, 0x64, 0x4c, 0x6f, 0x67, 0x52, 0x65, 
0x71, 0x75, 0x65, 0x73, + 0x74, 0x1a, 0x17, 0x2e, 0x6e, 0x65, 0x74, 0x2e, 0x70, 0x62, 0x2e, 0x47, 0x65, 0x74, 0x48, 0x65, + 0x61, 0x64, 0x4c, 0x6f, 0x67, 0x52, 0x65, 0x70, 0x6c, 0x79, 0x22, 0x00, 0x42, 0x0a, 0x5a, 0x08, + 0x2f, 0x3b, 0x6e, 0x65, 0x74, 0x5f, 0x70, 0x62, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, } var ( @@ -671,35 +750,36 @@ func file_net_proto_rawDescGZIP() []byte { return file_net_proto_rawDescData } -var file_net_proto_msgTypes = make([]protoimpl.MessageInfo, 13) -var file_net_proto_goTypes = []interface{}{ - (*Document)(nil), // 0: net.pb.Document - (*GetDocGraphRequest)(nil), // 1: net.pb.GetDocGraphRequest - (*GetDocGraphReply)(nil), // 2: net.pb.GetDocGraphReply - (*PushDocGraphRequest)(nil), // 3: net.pb.PushDocGraphRequest - (*PushDocGraphReply)(nil), // 4: net.pb.PushDocGraphReply - (*GetLogRequest)(nil), // 5: net.pb.GetLogRequest - (*GetLogReply)(nil), // 6: net.pb.GetLogReply - (*PushLogRequest)(nil), // 7: net.pb.PushLogRequest - (*GetHeadLogRequest)(nil), // 8: net.pb.GetHeadLogRequest - (*PushLogReply)(nil), // 9: net.pb.PushLogReply - (*GetHeadLogReply)(nil), // 10: net.pb.GetHeadLogReply - (*Document_Log)(nil), // 11: net.pb.Document.Log - (*PushLogRequest_Body)(nil), // 12: net.pb.PushLogRequest.Body +var file_net_proto_msgTypes = make([]protoimpl.MessageInfo, 14) +var file_net_proto_goTypes = []any{ + (*Log)(nil), // 0: net.pb.Log + (*GetDocGraphRequest)(nil), // 1: net.pb.GetDocGraphRequest + (*GetDocGraphReply)(nil), // 2: net.pb.GetDocGraphReply + (*PushDocGraphRequest)(nil), // 3: net.pb.PushDocGraphRequest + (*PushDocGraphReply)(nil), // 4: net.pb.PushDocGraphReply + (*GetLogRequest)(nil), // 5: net.pb.GetLogRequest + (*GetLogReply)(nil), // 6: net.pb.GetLogReply + (*PushLogRequest)(nil), // 7: net.pb.PushLogRequest + (*FetchEncryptionKeyRequest)(nil), // 8: net.pb.FetchEncryptionKeyRequest + (*FetchEncryptionKeyReply)(nil), // 9: net.pb.FetchEncryptionKeyReply + (*GetHeadLogRequest)(nil), // 10: 
net.pb.GetHeadLogRequest + (*PushLogReply)(nil), // 11: net.pb.PushLogReply + (*GetHeadLogReply)(nil), // 12: net.pb.GetHeadLogReply + (*PushLogRequest_Body)(nil), // 13: net.pb.PushLogRequest.Body } var file_net_proto_depIdxs = []int32{ - 12, // 0: net.pb.PushLogRequest.body:type_name -> net.pb.PushLogRequest.Body - 11, // 1: net.pb.PushLogRequest.Body.log:type_name -> net.pb.Document.Log + 13, // 0: net.pb.PushLogRequest.body:type_name -> net.pb.PushLogRequest.Body + 0, // 1: net.pb.PushLogRequest.Body.log:type_name -> net.pb.Log 1, // 2: net.pb.Service.GetDocGraph:input_type -> net.pb.GetDocGraphRequest 3, // 3: net.pb.Service.PushDocGraph:input_type -> net.pb.PushDocGraphRequest 5, // 4: net.pb.Service.GetLog:input_type -> net.pb.GetLogRequest 7, // 5: net.pb.Service.PushLog:input_type -> net.pb.PushLogRequest - 8, // 6: net.pb.Service.GetHeadLog:input_type -> net.pb.GetHeadLogRequest + 10, // 6: net.pb.Service.GetHeadLog:input_type -> net.pb.GetHeadLogRequest 2, // 7: net.pb.Service.GetDocGraph:output_type -> net.pb.GetDocGraphReply 4, // 8: net.pb.Service.PushDocGraph:output_type -> net.pb.PushDocGraphReply 6, // 9: net.pb.Service.GetLog:output_type -> net.pb.GetLogReply - 9, // 10: net.pb.Service.PushLog:output_type -> net.pb.PushLogReply - 10, // 11: net.pb.Service.GetHeadLog:output_type -> net.pb.GetHeadLogReply + 11, // 10: net.pb.Service.PushLog:output_type -> net.pb.PushLogReply + 12, // 11: net.pb.Service.GetHeadLog:output_type -> net.pb.GetHeadLogReply 7, // [7:12] is the sub-list for method output_type 2, // [2:7] is the sub-list for method input_type 2, // [2:2] is the sub-list for extension type_name @@ -713,8 +793,8 @@ func file_net_proto_init() { return } if !protoimpl.UnsafeEnabled { - file_net_proto_msgTypes[0].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*Document); i { + file_net_proto_msgTypes[0].Exporter = func(v any, i int) any { + switch v := v.(*Log); i { case 0: return &v.state case 1: @@ -725,7 +805,7 @@ func 
file_net_proto_init() { return nil } } - file_net_proto_msgTypes[1].Exporter = func(v interface{}, i int) interface{} { + file_net_proto_msgTypes[1].Exporter = func(v any, i int) any { switch v := v.(*GetDocGraphRequest); i { case 0: return &v.state @@ -737,7 +817,7 @@ func file_net_proto_init() { return nil } } - file_net_proto_msgTypes[2].Exporter = func(v interface{}, i int) interface{} { + file_net_proto_msgTypes[2].Exporter = func(v any, i int) any { switch v := v.(*GetDocGraphReply); i { case 0: return &v.state @@ -749,7 +829,7 @@ func file_net_proto_init() { return nil } } - file_net_proto_msgTypes[3].Exporter = func(v interface{}, i int) interface{} { + file_net_proto_msgTypes[3].Exporter = func(v any, i int) any { switch v := v.(*PushDocGraphRequest); i { case 0: return &v.state @@ -761,7 +841,7 @@ func file_net_proto_init() { return nil } } - file_net_proto_msgTypes[4].Exporter = func(v interface{}, i int) interface{} { + file_net_proto_msgTypes[4].Exporter = func(v any, i int) any { switch v := v.(*PushDocGraphReply); i { case 0: return &v.state @@ -773,7 +853,7 @@ func file_net_proto_init() { return nil } } - file_net_proto_msgTypes[5].Exporter = func(v interface{}, i int) interface{} { + file_net_proto_msgTypes[5].Exporter = func(v any, i int) any { switch v := v.(*GetLogRequest); i { case 0: return &v.state @@ -785,7 +865,7 @@ func file_net_proto_init() { return nil } } - file_net_proto_msgTypes[6].Exporter = func(v interface{}, i int) interface{} { + file_net_proto_msgTypes[6].Exporter = func(v any, i int) any { switch v := v.(*GetLogReply); i { case 0: return &v.state @@ -797,7 +877,7 @@ func file_net_proto_init() { return nil } } - file_net_proto_msgTypes[7].Exporter = func(v interface{}, i int) interface{} { + file_net_proto_msgTypes[7].Exporter = func(v any, i int) any { switch v := v.(*PushLogRequest); i { case 0: return &v.state @@ -809,8 +889,8 @@ func file_net_proto_init() { return nil } } - file_net_proto_msgTypes[8].Exporter = func(v 
interface{}, i int) interface{} { - switch v := v.(*GetHeadLogRequest); i { + file_net_proto_msgTypes[8].Exporter = func(v any, i int) any { + switch v := v.(*FetchEncryptionKeyRequest); i { case 0: return &v.state case 1: @@ -821,8 +901,8 @@ func file_net_proto_init() { return nil } } - file_net_proto_msgTypes[9].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*PushLogReply); i { + file_net_proto_msgTypes[9].Exporter = func(v any, i int) any { + switch v := v.(*FetchEncryptionKeyReply); i { case 0: return &v.state case 1: @@ -833,8 +913,20 @@ func file_net_proto_init() { return nil } } - file_net_proto_msgTypes[10].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*GetHeadLogReply); i { + file_net_proto_msgTypes[10].Exporter = func(v any, i int) any { + switch v := v.(*GetHeadLogRequest); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_net_proto_msgTypes[11].Exporter = func(v any, i int) any { + switch v := v.(*PushLogReply); i { case 0: return &v.state case 1: @@ -845,8 +937,8 @@ func file_net_proto_init() { return nil } } - file_net_proto_msgTypes[11].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*Document_Log); i { + file_net_proto_msgTypes[12].Exporter = func(v any, i int) any { + switch v := v.(*GetHeadLogReply); i { case 0: return &v.state case 1: @@ -857,7 +949,7 @@ func file_net_proto_init() { return nil } } - file_net_proto_msgTypes[12].Exporter = func(v interface{}, i int) interface{} { + file_net_proto_msgTypes[13].Exporter = func(v any, i int) any { switch v := v.(*PushLogRequest_Body); i { case 0: return &v.state @@ -876,7 +968,7 @@ func file_net_proto_init() { GoPackagePath: reflect.TypeOf(x{}).PkgPath(), RawDescriptor: file_net_proto_rawDesc, NumEnums: 0, - NumMessages: 13, + NumMessages: 14, NumExtensions: 0, NumServices: 1, }, diff --git a/net/pb/net.proto b/net/pb/net.proto index 
5b0ee35dfb..8dc8fe8a46 100644 --- a/net/pb/net.proto +++ b/net/pb/net.proto @@ -4,17 +4,9 @@ package net.pb; option go_package = "/;net_pb"; // Log represents a thread log. -message Document { - // ID of the document. - bytes docID = 1; - // head of the log. - bytes head = 4; - - // Record is a thread record containing link data. - message Log { - // block is the top-level node's raw data as an ipld.Block. - bytes block = 1; - } +message Log { + // block is the top-level node's raw data as an ipld.Block. + bytes block = 1; } message GetDocGraphRequest {} @@ -42,10 +34,31 @@ message PushLogRequest { // creator is the PeerID of the peer that created the log. string creator = 4; // log hold the block that represent version of the document. - Document.Log log = 6; + Log log = 6; } } +// FetchEncryptionKeyRequest is a request to receive a doc encryption key +// from a peer that holds it. +message FetchEncryptionKeyRequest { + // links is the list of cid links of the blocks containing encryption keys. + repeated bytes links = 1; + // ephemeralPublicKey is an ephemeral public of the requesting peer for deriving shared secret + bytes ephemeralPublicKey = 2; +} + +// FetchEncryptionKeyReply is a response to FetchEncryptionKeyRequest request +// by a peer that holds the requested doc encryption key. +message FetchEncryptionKeyReply { + // links is the list of cid links of the blocks containing encryption keys. + repeated bytes links = 1; + // blocks is the list of blocks containing encryption keys. The order of blocks should match the order of links. + // Every block is encrypted and contains a nonce. 
+ repeated bytes blocks = 2; + // ephemeralPublicKey is an ephemeral public of the responding peer for deriving shared secret + bytes ephemeralPublicKey = 3; +} + message GetHeadLogRequest {} message PushLogReply {} diff --git a/net/pb/net_grpc.pb.go b/net/pb/net_grpc.pb.go index 75ae790ab6..84564d6bec 100644 --- a/net/pb/net_grpc.pb.go +++ b/net/pb/net_grpc.pb.go @@ -1,7 +1,7 @@ // Code generated by protoc-gen-go-grpc. DO NOT EDIT. // versions: -// - protoc-gen-go-grpc v1.3.0 -// - protoc v4.25.1 +// - protoc-gen-go-grpc v1.4.0 +// - protoc v5.27.1 // source: net.proto package net_pb @@ -15,8 +15,8 @@ import ( // This is a compile-time assertion to ensure that this generated file // is compatible with the grpc package it is being compiled against. -// Requires gRPC-Go v1.32.0 or later. -const _ = grpc.SupportPackageIsVersion7 +// Requires gRPC-Go v1.62.0 or later. +const _ = grpc.SupportPackageIsVersion8 const ( Service_GetDocGraph_FullMethodName = "/net.pb.Service/GetDocGraph" @@ -29,6 +29,8 @@ const ( // ServiceClient is the client API for Service service. // // For semantics around ctx use and closing/ending streaming RPCs, please refer to https://pkg.go.dev/google.golang.org/grpc/?tab=doc#ClientConn.NewStream. +// +// Service is the peer-to-peer network API for document sync type ServiceClient interface { // GetDocGraph from this peer. GetDocGraph(ctx context.Context, in *GetDocGraphRequest, opts ...grpc.CallOption) (*GetDocGraphReply, error) @@ -51,8 +53,9 @@ func NewServiceClient(cc grpc.ClientConnInterface) ServiceClient { } func (c *serviceClient) GetDocGraph(ctx context.Context, in *GetDocGraphRequest, opts ...grpc.CallOption) (*GetDocGraphReply, error) { + cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) out := new(GetDocGraphReply) - err := c.cc.Invoke(ctx, Service_GetDocGraph_FullMethodName, in, out, opts...) + err := c.cc.Invoke(ctx, Service_GetDocGraph_FullMethodName, in, out, cOpts...) 
if err != nil { return nil, err } @@ -60,8 +63,9 @@ func (c *serviceClient) GetDocGraph(ctx context.Context, in *GetDocGraphRequest, } func (c *serviceClient) PushDocGraph(ctx context.Context, in *PushDocGraphRequest, opts ...grpc.CallOption) (*PushDocGraphReply, error) { + cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) out := new(PushDocGraphReply) - err := c.cc.Invoke(ctx, Service_PushDocGraph_FullMethodName, in, out, opts...) + err := c.cc.Invoke(ctx, Service_PushDocGraph_FullMethodName, in, out, cOpts...) if err != nil { return nil, err } @@ -69,8 +73,9 @@ func (c *serviceClient) PushDocGraph(ctx context.Context, in *PushDocGraphReques } func (c *serviceClient) GetLog(ctx context.Context, in *GetLogRequest, opts ...grpc.CallOption) (*GetLogReply, error) { + cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) out := new(GetLogReply) - err := c.cc.Invoke(ctx, Service_GetLog_FullMethodName, in, out, opts...) + err := c.cc.Invoke(ctx, Service_GetLog_FullMethodName, in, out, cOpts...) if err != nil { return nil, err } @@ -78,8 +83,9 @@ func (c *serviceClient) GetLog(ctx context.Context, in *GetLogRequest, opts ...g } func (c *serviceClient) PushLog(ctx context.Context, in *PushLogRequest, opts ...grpc.CallOption) (*PushLogReply, error) { + cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) out := new(PushLogReply) - err := c.cc.Invoke(ctx, Service_PushLog_FullMethodName, in, out, opts...) + err := c.cc.Invoke(ctx, Service_PushLog_FullMethodName, in, out, cOpts...) if err != nil { return nil, err } @@ -87,8 +93,9 @@ func (c *serviceClient) PushLog(ctx context.Context, in *PushLogRequest, opts .. } func (c *serviceClient) GetHeadLog(ctx context.Context, in *GetHeadLogRequest, opts ...grpc.CallOption) (*GetHeadLogReply, error) { + cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) out := new(GetHeadLogReply) - err := c.cc.Invoke(ctx, Service_GetHeadLog_FullMethodName, in, out, opts...) 
+ err := c.cc.Invoke(ctx, Service_GetHeadLog_FullMethodName, in, out, cOpts...) if err != nil { return nil, err } @@ -98,6 +105,8 @@ func (c *serviceClient) GetHeadLog(ctx context.Context, in *GetHeadLogRequest, o // ServiceServer is the server API for Service service. // All implementations must embed UnimplementedServiceServer // for forward compatibility +// +// Service is the peer-to-peer network API for document sync type ServiceServer interface { // GetDocGraph from this peer. GetDocGraph(context.Context, *GetDocGraphRequest) (*GetDocGraphReply, error) diff --git a/net/pb/net_vtproto.pb.go b/net/pb/net_vtproto.pb.go index 2bae8f83f3..bf1c93e8e8 100644 --- a/net/pb/net_vtproto.pb.go +++ b/net/pb/net_vtproto.pb.go @@ -1,14 +1,14 @@ // Code generated by protoc-gen-go-vtproto. DO NOT EDIT. -// protoc-gen-go-vtproto version: v0.5.0 +// protoc-gen-go-vtproto version: v0.6.0 // source: net.proto package net_pb import ( fmt "fmt" + protohelpers "github.com/planetscale/vtprotobuf/protohelpers" protoimpl "google.golang.org/protobuf/runtime/protoimpl" io "io" - bits "math/bits" ) const ( @@ -18,7 +18,7 @@ const ( _ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20) ) -func (m *Document_Log) MarshalVT() (dAtA []byte, err error) { +func (m *Log) MarshalVT() (dAtA []byte, err error) { if m == nil { return nil, nil } @@ -31,12 +31,12 @@ func (m *Document_Log) MarshalVT() (dAtA []byte, err error) { return dAtA[:n], nil } -func (m *Document_Log) MarshalToVT(dAtA []byte) (int, error) { +func (m *Log) MarshalToVT(dAtA []byte) (int, error) { size := m.SizeVT() return m.MarshalToSizedBufferVT(dAtA[:size]) } -func (m *Document_Log) MarshalToSizedBufferVT(dAtA []byte) (int, error) { +func (m *Log) MarshalToSizedBufferVT(dAtA []byte) (int, error) { if m == nil { return 0, nil } @@ -51,54 +51,7 @@ func (m *Document_Log) MarshalToSizedBufferVT(dAtA []byte) (int, error) { if len(m.Block) > 0 { i -= len(m.Block) copy(dAtA[i:], m.Block) - i = encodeVarint(dAtA, i, 
uint64(len(m.Block))) - i-- - dAtA[i] = 0xa - } - return len(dAtA) - i, nil -} - -func (m *Document) MarshalVT() (dAtA []byte, err error) { - if m == nil { - return nil, nil - } - size := m.SizeVT() - dAtA = make([]byte, size) - n, err := m.MarshalToSizedBufferVT(dAtA[:size]) - if err != nil { - return nil, err - } - return dAtA[:n], nil -} - -func (m *Document) MarshalToVT(dAtA []byte) (int, error) { - size := m.SizeVT() - return m.MarshalToSizedBufferVT(dAtA[:size]) -} - -func (m *Document) MarshalToSizedBufferVT(dAtA []byte) (int, error) { - if m == nil { - return 0, nil - } - i := len(dAtA) - _ = i - var l int - _ = l - if m.unknownFields != nil { - i -= len(m.unknownFields) - copy(dAtA[i:], m.unknownFields) - } - if len(m.Head) > 0 { - i -= len(m.Head) - copy(dAtA[i:], m.Head) - i = encodeVarint(dAtA, i, uint64(len(m.Head))) - i-- - dAtA[i] = 0x22 - } - if len(m.DocID) > 0 { - i -= len(m.DocID) - copy(dAtA[i:], m.DocID) - i = encodeVarint(dAtA, i, uint64(len(m.DocID))) + i = protohelpers.EncodeVarint(dAtA, i, uint64(len(m.Block))) i-- dAtA[i] = 0xa } @@ -339,35 +292,35 @@ func (m *PushLogRequest_Body) MarshalToSizedBufferVT(dAtA []byte) (int, error) { return 0, err } i -= size - i = encodeVarint(dAtA, i, uint64(size)) + i = protohelpers.EncodeVarint(dAtA, i, uint64(size)) i-- dAtA[i] = 0x32 } if len(m.Creator) > 0 { i -= len(m.Creator) copy(dAtA[i:], m.Creator) - i = encodeVarint(dAtA, i, uint64(len(m.Creator))) + i = protohelpers.EncodeVarint(dAtA, i, uint64(len(m.Creator))) i-- dAtA[i] = 0x22 } if len(m.SchemaRoot) > 0 { i -= len(m.SchemaRoot) copy(dAtA[i:], m.SchemaRoot) - i = encodeVarint(dAtA, i, uint64(len(m.SchemaRoot))) + i = protohelpers.EncodeVarint(dAtA, i, uint64(len(m.SchemaRoot))) i-- dAtA[i] = 0x1a } if len(m.Cid) > 0 { i -= len(m.Cid) copy(dAtA[i:], m.Cid) - i = encodeVarint(dAtA, i, uint64(len(m.Cid))) + i = protohelpers.EncodeVarint(dAtA, i, uint64(len(m.Cid))) i-- dAtA[i] = 0x12 } if len(m.DocID) > 0 { i -= len(m.DocID) copy(dAtA[i:], 
m.DocID) - i = encodeVarint(dAtA, i, uint64(len(m.DocID))) + i = protohelpers.EncodeVarint(dAtA, i, uint64(len(m.DocID))) i-- dAtA[i] = 0xa } @@ -410,13 +363,120 @@ func (m *PushLogRequest) MarshalToSizedBufferVT(dAtA []byte) (int, error) { return 0, err } i -= size - i = encodeVarint(dAtA, i, uint64(size)) + i = protohelpers.EncodeVarint(dAtA, i, uint64(size)) i-- dAtA[i] = 0xa } return len(dAtA) - i, nil } +func (m *FetchEncryptionKeyRequest) MarshalVT() (dAtA []byte, err error) { + if m == nil { + return nil, nil + } + size := m.SizeVT() + dAtA = make([]byte, size) + n, err := m.MarshalToSizedBufferVT(dAtA[:size]) + if err != nil { + return nil, err + } + return dAtA[:n], nil +} + +func (m *FetchEncryptionKeyRequest) MarshalToVT(dAtA []byte) (int, error) { + size := m.SizeVT() + return m.MarshalToSizedBufferVT(dAtA[:size]) +} + +func (m *FetchEncryptionKeyRequest) MarshalToSizedBufferVT(dAtA []byte) (int, error) { + if m == nil { + return 0, nil + } + i := len(dAtA) + _ = i + var l int + _ = l + if m.unknownFields != nil { + i -= len(m.unknownFields) + copy(dAtA[i:], m.unknownFields) + } + if len(m.EphemeralPublicKey) > 0 { + i -= len(m.EphemeralPublicKey) + copy(dAtA[i:], m.EphemeralPublicKey) + i = protohelpers.EncodeVarint(dAtA, i, uint64(len(m.EphemeralPublicKey))) + i-- + dAtA[i] = 0x12 + } + if len(m.Links) > 0 { + for iNdEx := len(m.Links) - 1; iNdEx >= 0; iNdEx-- { + i -= len(m.Links[iNdEx]) + copy(dAtA[i:], m.Links[iNdEx]) + i = protohelpers.EncodeVarint(dAtA, i, uint64(len(m.Links[iNdEx]))) + i-- + dAtA[i] = 0xa + } + } + return len(dAtA) - i, nil +} + +func (m *FetchEncryptionKeyReply) MarshalVT() (dAtA []byte, err error) { + if m == nil { + return nil, nil + } + size := m.SizeVT() + dAtA = make([]byte, size) + n, err := m.MarshalToSizedBufferVT(dAtA[:size]) + if err != nil { + return nil, err + } + return dAtA[:n], nil +} + +func (m *FetchEncryptionKeyReply) MarshalToVT(dAtA []byte) (int, error) { + size := m.SizeVT() + return 
m.MarshalToSizedBufferVT(dAtA[:size]) +} + +func (m *FetchEncryptionKeyReply) MarshalToSizedBufferVT(dAtA []byte) (int, error) { + if m == nil { + return 0, nil + } + i := len(dAtA) + _ = i + var l int + _ = l + if m.unknownFields != nil { + i -= len(m.unknownFields) + copy(dAtA[i:], m.unknownFields) + } + if len(m.EphemeralPublicKey) > 0 { + i -= len(m.EphemeralPublicKey) + copy(dAtA[i:], m.EphemeralPublicKey) + i = protohelpers.EncodeVarint(dAtA, i, uint64(len(m.EphemeralPublicKey))) + i-- + dAtA[i] = 0x1a + } + if len(m.Blocks) > 0 { + for iNdEx := len(m.Blocks) - 1; iNdEx >= 0; iNdEx-- { + i -= len(m.Blocks[iNdEx]) + copy(dAtA[i:], m.Blocks[iNdEx]) + i = protohelpers.EncodeVarint(dAtA, i, uint64(len(m.Blocks[iNdEx]))) + i-- + dAtA[i] = 0x12 + } + } + if len(m.Links) > 0 { + for iNdEx := len(m.Links) - 1; iNdEx >= 0; iNdEx-- { + i -= len(m.Links[iNdEx]) + copy(dAtA[i:], m.Links[iNdEx]) + i = protohelpers.EncodeVarint(dAtA, i, uint64(len(m.Links[iNdEx]))) + i-- + dAtA[i] = 0xa + } + } + return len(dAtA) - i, nil +} + func (m *GetHeadLogRequest) MarshalVT() (dAtA []byte, err error) { if m == nil { return nil, nil @@ -516,18 +576,7 @@ func (m *GetHeadLogReply) MarshalToSizedBufferVT(dAtA []byte) (int, error) { return len(dAtA) - i, nil } -func encodeVarint(dAtA []byte, offset int, v uint64) int { - offset -= sov(v) - base := offset - for v >= 1<<7 { - dAtA[offset] = uint8(v&0x7f | 0x80) - v >>= 7 - offset++ - } - dAtA[offset] = uint8(v) - return base -} -func (m *Document_Log) SizeVT() (n int) { +func (m *Log) SizeVT() (n int) { if m == nil { return 0 } @@ -535,25 +584,7 @@ func (m *Document_Log) SizeVT() (n int) { _ = l l = len(m.Block) if l > 0 { - n += 1 + l + sov(uint64(l)) - } - n += len(m.unknownFields) - return n -} - -func (m *Document) SizeVT() (n int) { - if m == nil { - return 0 - } - var l int - _ = l - l = len(m.DocID) - if l > 0 { - n += 1 + l + sov(uint64(l)) - } - l = len(m.Head) - if l > 0 { - n += 1 + l + sov(uint64(l)) + n += 1 + l + 
protohelpers.SizeOfVarint(uint64(l)) } n += len(m.unknownFields) return n @@ -627,23 +658,23 @@ func (m *PushLogRequest_Body) SizeVT() (n int) { _ = l l = len(m.DocID) if l > 0 { - n += 1 + l + sov(uint64(l)) + n += 1 + l + protohelpers.SizeOfVarint(uint64(l)) } l = len(m.Cid) if l > 0 { - n += 1 + l + sov(uint64(l)) + n += 1 + l + protohelpers.SizeOfVarint(uint64(l)) } l = len(m.SchemaRoot) if l > 0 { - n += 1 + l + sov(uint64(l)) + n += 1 + l + protohelpers.SizeOfVarint(uint64(l)) } l = len(m.Creator) if l > 0 { - n += 1 + l + sov(uint64(l)) + n += 1 + l + protohelpers.SizeOfVarint(uint64(l)) } if m.Log != nil { l = m.Log.SizeVT() - n += 1 + l + sov(uint64(l)) + n += 1 + l + protohelpers.SizeOfVarint(uint64(l)) } n += len(m.unknownFields) return n @@ -657,7 +688,53 @@ func (m *PushLogRequest) SizeVT() (n int) { _ = l if m.Body != nil { l = m.Body.SizeVT() - n += 1 + l + sov(uint64(l)) + n += 1 + l + protohelpers.SizeOfVarint(uint64(l)) + } + n += len(m.unknownFields) + return n +} + +func (m *FetchEncryptionKeyRequest) SizeVT() (n int) { + if m == nil { + return 0 + } + var l int + _ = l + if len(m.Links) > 0 { + for _, b := range m.Links { + l = len(b) + n += 1 + l + protohelpers.SizeOfVarint(uint64(l)) + } + } + l = len(m.EphemeralPublicKey) + if l > 0 { + n += 1 + l + protohelpers.SizeOfVarint(uint64(l)) + } + n += len(m.unknownFields) + return n +} + +func (m *FetchEncryptionKeyReply) SizeVT() (n int) { + if m == nil { + return 0 + } + var l int + _ = l + if len(m.Links) > 0 { + for _, b := range m.Links { + l = len(b) + n += 1 + l + protohelpers.SizeOfVarint(uint64(l)) + } + } + if len(m.Blocks) > 0 { + for _, b := range m.Blocks { + l = len(b) + n += 1 + l + protohelpers.SizeOfVarint(uint64(l)) + } + } + l = len(m.EphemeralPublicKey) + if l > 0 { + n += 1 + l + protohelpers.SizeOfVarint(uint64(l)) } n += len(m.unknownFields) return n @@ -693,13 +770,7 @@ func (m *GetHeadLogReply) SizeVT() (n int) { return n } -func sov(x uint64) (n int) { - return 
(bits.Len64(x|1) + 6) / 7 -} -func soz(x uint64) (n int) { - return sov(uint64((x << 1) ^ uint64((int64(x) >> 63)))) -} -func (m *Document_Log) UnmarshalVT(dAtA []byte) error { +func (m *Log) UnmarshalVT(dAtA []byte) error { l := len(dAtA) iNdEx := 0 for iNdEx < l { @@ -707,7 +778,7 @@ func (m *Document_Log) UnmarshalVT(dAtA []byte) error { var wire uint64 for shift := uint(0); ; shift += 7 { if shift >= 64 { - return ErrIntOverflow + return protohelpers.ErrIntOverflow } if iNdEx >= l { return io.ErrUnexpectedEOF @@ -722,10 +793,10 @@ func (m *Document_Log) UnmarshalVT(dAtA []byte) error { fieldNum := int32(wire >> 3) wireType := int(wire & 0x7) if wireType == 4 { - return fmt.Errorf("proto: Document_Log: wiretype end group for non-group") + return fmt.Errorf("proto: Log: wiretype end group for non-group") } if fieldNum <= 0 { - return fmt.Errorf("proto: Document_Log: illegal tag %d (wire type %d)", fieldNum, wire) + return fmt.Errorf("proto: Log: illegal tag %d (wire type %d)", fieldNum, wire) } switch fieldNum { case 1: @@ -735,7 +806,7 @@ func (m *Document_Log) UnmarshalVT(dAtA []byte) error { var byteLen int for shift := uint(0); ; shift += 7 { if shift >= 64 { - return ErrIntOverflow + return protohelpers.ErrIntOverflow } if iNdEx >= l { return io.ErrUnexpectedEOF @@ -748,11 +819,11 @@ func (m *Document_Log) UnmarshalVT(dAtA []byte) error { } } if byteLen < 0 { - return ErrInvalidLength + return protohelpers.ErrInvalidLength } postIndex := iNdEx + byteLen if postIndex < 0 { - return ErrInvalidLength + return protohelpers.ErrInvalidLength } if postIndex > l { return io.ErrUnexpectedEOF @@ -764,131 +835,12 @@ func (m *Document_Log) UnmarshalVT(dAtA []byte) error { iNdEx = postIndex default: iNdEx = preIndex - skippy, err := skip(dAtA[iNdEx:]) + skippy, err := protohelpers.Skip(dAtA[iNdEx:]) if err != nil { return err } if (skippy < 0) || (iNdEx+skippy) < 0 { - return ErrInvalidLength - } - if (iNdEx + skippy) > l { - return io.ErrUnexpectedEOF - } - 
m.unknownFields = append(m.unknownFields, dAtA[iNdEx:iNdEx+skippy]...) - iNdEx += skippy - } - } - - if iNdEx > l { - return io.ErrUnexpectedEOF - } - return nil -} -func (m *Document) UnmarshalVT(dAtA []byte) error { - l := len(dAtA) - iNdEx := 0 - for iNdEx < l { - preIndex := iNdEx - var wire uint64 - for shift := uint(0); ; shift += 7 { - if shift >= 64 { - return ErrIntOverflow - } - if iNdEx >= l { - return io.ErrUnexpectedEOF - } - b := dAtA[iNdEx] - iNdEx++ - wire |= uint64(b&0x7F) << shift - if b < 0x80 { - break - } - } - fieldNum := int32(wire >> 3) - wireType := int(wire & 0x7) - if wireType == 4 { - return fmt.Errorf("proto: Document: wiretype end group for non-group") - } - if fieldNum <= 0 { - return fmt.Errorf("proto: Document: illegal tag %d (wire type %d)", fieldNum, wire) - } - switch fieldNum { - case 1: - if wireType != 2 { - return fmt.Errorf("proto: wrong wireType = %d for field DocID", wireType) - } - var byteLen int - for shift := uint(0); ; shift += 7 { - if shift >= 64 { - return ErrIntOverflow - } - if iNdEx >= l { - return io.ErrUnexpectedEOF - } - b := dAtA[iNdEx] - iNdEx++ - byteLen |= int(b&0x7F) << shift - if b < 0x80 { - break - } - } - if byteLen < 0 { - return ErrInvalidLength - } - postIndex := iNdEx + byteLen - if postIndex < 0 { - return ErrInvalidLength - } - if postIndex > l { - return io.ErrUnexpectedEOF - } - m.DocID = append(m.DocID[:0], dAtA[iNdEx:postIndex]...) 
- if m.DocID == nil { - m.DocID = []byte{} - } - iNdEx = postIndex - case 4: - if wireType != 2 { - return fmt.Errorf("proto: wrong wireType = %d for field Head", wireType) - } - var byteLen int - for shift := uint(0); ; shift += 7 { - if shift >= 64 { - return ErrIntOverflow - } - if iNdEx >= l { - return io.ErrUnexpectedEOF - } - b := dAtA[iNdEx] - iNdEx++ - byteLen |= int(b&0x7F) << shift - if b < 0x80 { - break - } - } - if byteLen < 0 { - return ErrInvalidLength - } - postIndex := iNdEx + byteLen - if postIndex < 0 { - return ErrInvalidLength - } - if postIndex > l { - return io.ErrUnexpectedEOF - } - m.Head = append(m.Head[:0], dAtA[iNdEx:postIndex]...) - if m.Head == nil { - m.Head = []byte{} - } - iNdEx = postIndex - default: - iNdEx = preIndex - skippy, err := skip(dAtA[iNdEx:]) - if err != nil { - return err - } - if (skippy < 0) || (iNdEx+skippy) < 0 { - return ErrInvalidLength + return protohelpers.ErrInvalidLength } if (iNdEx + skippy) > l { return io.ErrUnexpectedEOF @@ -911,7 +863,7 @@ func (m *GetDocGraphRequest) UnmarshalVT(dAtA []byte) error { var wire uint64 for shift := uint(0); ; shift += 7 { if shift >= 64 { - return ErrIntOverflow + return protohelpers.ErrIntOverflow } if iNdEx >= l { return io.ErrUnexpectedEOF @@ -934,12 +886,12 @@ func (m *GetDocGraphRequest) UnmarshalVT(dAtA []byte) error { switch fieldNum { default: iNdEx = preIndex - skippy, err := skip(dAtA[iNdEx:]) + skippy, err := protohelpers.Skip(dAtA[iNdEx:]) if err != nil { return err } if (skippy < 0) || (iNdEx+skippy) < 0 { - return ErrInvalidLength + return protohelpers.ErrInvalidLength } if (iNdEx + skippy) > l { return io.ErrUnexpectedEOF @@ -962,7 +914,7 @@ func (m *GetDocGraphReply) UnmarshalVT(dAtA []byte) error { var wire uint64 for shift := uint(0); ; shift += 7 { if shift >= 64 { - return ErrIntOverflow + return protohelpers.ErrIntOverflow } if iNdEx >= l { return io.ErrUnexpectedEOF @@ -985,12 +937,12 @@ func (m *GetDocGraphReply) UnmarshalVT(dAtA []byte) error { 
switch fieldNum { default: iNdEx = preIndex - skippy, err := skip(dAtA[iNdEx:]) + skippy, err := protohelpers.Skip(dAtA[iNdEx:]) if err != nil { return err } if (skippy < 0) || (iNdEx+skippy) < 0 { - return ErrInvalidLength + return protohelpers.ErrInvalidLength } if (iNdEx + skippy) > l { return io.ErrUnexpectedEOF @@ -1013,7 +965,7 @@ func (m *PushDocGraphRequest) UnmarshalVT(dAtA []byte) error { var wire uint64 for shift := uint(0); ; shift += 7 { if shift >= 64 { - return ErrIntOverflow + return protohelpers.ErrIntOverflow } if iNdEx >= l { return io.ErrUnexpectedEOF @@ -1036,12 +988,12 @@ func (m *PushDocGraphRequest) UnmarshalVT(dAtA []byte) error { switch fieldNum { default: iNdEx = preIndex - skippy, err := skip(dAtA[iNdEx:]) + skippy, err := protohelpers.Skip(dAtA[iNdEx:]) if err != nil { return err } if (skippy < 0) || (iNdEx+skippy) < 0 { - return ErrInvalidLength + return protohelpers.ErrInvalidLength } if (iNdEx + skippy) > l { return io.ErrUnexpectedEOF @@ -1064,7 +1016,7 @@ func (m *PushDocGraphReply) UnmarshalVT(dAtA []byte) error { var wire uint64 for shift := uint(0); ; shift += 7 { if shift >= 64 { - return ErrIntOverflow + return protohelpers.ErrIntOverflow } if iNdEx >= l { return io.ErrUnexpectedEOF @@ -1087,12 +1039,12 @@ func (m *PushDocGraphReply) UnmarshalVT(dAtA []byte) error { switch fieldNum { default: iNdEx = preIndex - skippy, err := skip(dAtA[iNdEx:]) + skippy, err := protohelpers.Skip(dAtA[iNdEx:]) if err != nil { return err } if (skippy < 0) || (iNdEx+skippy) < 0 { - return ErrInvalidLength + return protohelpers.ErrInvalidLength } if (iNdEx + skippy) > l { return io.ErrUnexpectedEOF @@ -1115,7 +1067,7 @@ func (m *GetLogRequest) UnmarshalVT(dAtA []byte) error { var wire uint64 for shift := uint(0); ; shift += 7 { if shift >= 64 { - return ErrIntOverflow + return protohelpers.ErrIntOverflow } if iNdEx >= l { return io.ErrUnexpectedEOF @@ -1138,12 +1090,12 @@ func (m *GetLogRequest) UnmarshalVT(dAtA []byte) error { switch fieldNum { 
default: iNdEx = preIndex - skippy, err := skip(dAtA[iNdEx:]) + skippy, err := protohelpers.Skip(dAtA[iNdEx:]) if err != nil { return err } if (skippy < 0) || (iNdEx+skippy) < 0 { - return ErrInvalidLength + return protohelpers.ErrInvalidLength } if (iNdEx + skippy) > l { return io.ErrUnexpectedEOF @@ -1166,7 +1118,7 @@ func (m *GetLogReply) UnmarshalVT(dAtA []byte) error { var wire uint64 for shift := uint(0); ; shift += 7 { if shift >= 64 { - return ErrIntOverflow + return protohelpers.ErrIntOverflow } if iNdEx >= l { return io.ErrUnexpectedEOF @@ -1189,12 +1141,12 @@ func (m *GetLogReply) UnmarshalVT(dAtA []byte) error { switch fieldNum { default: iNdEx = preIndex - skippy, err := skip(dAtA[iNdEx:]) + skippy, err := protohelpers.Skip(dAtA[iNdEx:]) if err != nil { return err } if (skippy < 0) || (iNdEx+skippy) < 0 { - return ErrInvalidLength + return protohelpers.ErrInvalidLength } if (iNdEx + skippy) > l { return io.ErrUnexpectedEOF @@ -1217,7 +1169,7 @@ func (m *PushLogRequest_Body) UnmarshalVT(dAtA []byte) error { var wire uint64 for shift := uint(0); ; shift += 7 { if shift >= 64 { - return ErrIntOverflow + return protohelpers.ErrIntOverflow } if iNdEx >= l { return io.ErrUnexpectedEOF @@ -1245,7 +1197,7 @@ func (m *PushLogRequest_Body) UnmarshalVT(dAtA []byte) error { var byteLen int for shift := uint(0); ; shift += 7 { if shift >= 64 { - return ErrIntOverflow + return protohelpers.ErrIntOverflow } if iNdEx >= l { return io.ErrUnexpectedEOF @@ -1258,11 +1210,11 @@ func (m *PushLogRequest_Body) UnmarshalVT(dAtA []byte) error { } } if byteLen < 0 { - return ErrInvalidLength + return protohelpers.ErrInvalidLength } postIndex := iNdEx + byteLen if postIndex < 0 { - return ErrInvalidLength + return protohelpers.ErrInvalidLength } if postIndex > l { return io.ErrUnexpectedEOF @@ -1279,7 +1231,7 @@ func (m *PushLogRequest_Body) UnmarshalVT(dAtA []byte) error { var byteLen int for shift := uint(0); ; shift += 7 { if shift >= 64 { - return ErrIntOverflow + return 
protohelpers.ErrIntOverflow } if iNdEx >= l { return io.ErrUnexpectedEOF @@ -1292,11 +1244,11 @@ func (m *PushLogRequest_Body) UnmarshalVT(dAtA []byte) error { } } if byteLen < 0 { - return ErrInvalidLength + return protohelpers.ErrInvalidLength } postIndex := iNdEx + byteLen if postIndex < 0 { - return ErrInvalidLength + return protohelpers.ErrInvalidLength } if postIndex > l { return io.ErrUnexpectedEOF @@ -1313,7 +1265,7 @@ func (m *PushLogRequest_Body) UnmarshalVT(dAtA []byte) error { var byteLen int for shift := uint(0); ; shift += 7 { if shift >= 64 { - return ErrIntOverflow + return protohelpers.ErrIntOverflow } if iNdEx >= l { return io.ErrUnexpectedEOF @@ -1326,11 +1278,11 @@ func (m *PushLogRequest_Body) UnmarshalVT(dAtA []byte) error { } } if byteLen < 0 { - return ErrInvalidLength + return protohelpers.ErrInvalidLength } postIndex := iNdEx + byteLen if postIndex < 0 { - return ErrInvalidLength + return protohelpers.ErrInvalidLength } if postIndex > l { return io.ErrUnexpectedEOF @@ -1347,7 +1299,7 @@ func (m *PushLogRequest_Body) UnmarshalVT(dAtA []byte) error { var stringLen uint64 for shift := uint(0); ; shift += 7 { if shift >= 64 { - return ErrIntOverflow + return protohelpers.ErrIntOverflow } if iNdEx >= l { return io.ErrUnexpectedEOF @@ -1361,11 +1313,11 @@ func (m *PushLogRequest_Body) UnmarshalVT(dAtA []byte) error { } intStringLen := int(stringLen) if intStringLen < 0 { - return ErrInvalidLength + return protohelpers.ErrInvalidLength } postIndex := iNdEx + intStringLen if postIndex < 0 { - return ErrInvalidLength + return protohelpers.ErrInvalidLength } if postIndex > l { return io.ErrUnexpectedEOF @@ -1379,7 +1331,7 @@ func (m *PushLogRequest_Body) UnmarshalVT(dAtA []byte) error { var msglen int for shift := uint(0); ; shift += 7 { if shift >= 64 { - return ErrIntOverflow + return protohelpers.ErrIntOverflow } if iNdEx >= l { return io.ErrUnexpectedEOF @@ -1392,17 +1344,17 @@ func (m *PushLogRequest_Body) UnmarshalVT(dAtA []byte) error { } } 
if msglen < 0 { - return ErrInvalidLength + return protohelpers.ErrInvalidLength } postIndex := iNdEx + msglen if postIndex < 0 { - return ErrInvalidLength + return protohelpers.ErrInvalidLength } if postIndex > l { return io.ErrUnexpectedEOF } if m.Log == nil { - m.Log = &Document_Log{} + m.Log = &Log{} } if err := m.Log.UnmarshalVT(dAtA[iNdEx:postIndex]); err != nil { return err @@ -1410,12 +1362,12 @@ func (m *PushLogRequest_Body) UnmarshalVT(dAtA []byte) error { iNdEx = postIndex default: iNdEx = preIndex - skippy, err := skip(dAtA[iNdEx:]) + skippy, err := protohelpers.Skip(dAtA[iNdEx:]) if err != nil { return err } if (skippy < 0) || (iNdEx+skippy) < 0 { - return ErrInvalidLength + return protohelpers.ErrInvalidLength } if (iNdEx + skippy) > l { return io.ErrUnexpectedEOF @@ -1438,7 +1390,7 @@ func (m *PushLogRequest) UnmarshalVT(dAtA []byte) error { var wire uint64 for shift := uint(0); ; shift += 7 { if shift >= 64 { - return ErrIntOverflow + return protohelpers.ErrIntOverflow } if iNdEx >= l { return io.ErrUnexpectedEOF @@ -1466,7 +1418,7 @@ func (m *PushLogRequest) UnmarshalVT(dAtA []byte) error { var msglen int for shift := uint(0); ; shift += 7 { if shift >= 64 { - return ErrIntOverflow + return protohelpers.ErrIntOverflow } if iNdEx >= l { return io.ErrUnexpectedEOF @@ -1479,11 +1431,11 @@ func (m *PushLogRequest) UnmarshalVT(dAtA []byte) error { } } if msglen < 0 { - return ErrInvalidLength + return protohelpers.ErrInvalidLength } postIndex := iNdEx + msglen if postIndex < 0 { - return ErrInvalidLength + return protohelpers.ErrInvalidLength } if postIndex > l { return io.ErrUnexpectedEOF @@ -1497,12 +1449,12 @@ func (m *PushLogRequest) UnmarshalVT(dAtA []byte) error { iNdEx = postIndex default: iNdEx = preIndex - skippy, err := skip(dAtA[iNdEx:]) + skippy, err := protohelpers.Skip(dAtA[iNdEx:]) if err != nil { return err } if (skippy < 0) || (iNdEx+skippy) < 0 { - return ErrInvalidLength + return protohelpers.ErrInvalidLength } if (iNdEx + skippy) > l 
{ return io.ErrUnexpectedEOF @@ -1517,7 +1469,7 @@ func (m *PushLogRequest) UnmarshalVT(dAtA []byte) error { } return nil } -func (m *GetHeadLogRequest) UnmarshalVT(dAtA []byte) error { +func (m *FetchEncryptionKeyRequest) UnmarshalVT(dAtA []byte) error { l := len(dAtA) iNdEx := 0 for iNdEx < l { @@ -1525,7 +1477,7 @@ func (m *GetHeadLogRequest) UnmarshalVT(dAtA []byte) error { var wire uint64 for shift := uint(0); ; shift += 7 { if shift >= 64 { - return ErrIntOverflow + return protohelpers.ErrIntOverflow } if iNdEx >= l { return io.ErrUnexpectedEOF @@ -1540,20 +1492,86 @@ func (m *GetHeadLogRequest) UnmarshalVT(dAtA []byte) error { fieldNum := int32(wire >> 3) wireType := int(wire & 0x7) if wireType == 4 { - return fmt.Errorf("proto: GetHeadLogRequest: wiretype end group for non-group") + return fmt.Errorf("proto: FetchEncryptionKeyRequest: wiretype end group for non-group") } if fieldNum <= 0 { - return fmt.Errorf("proto: GetHeadLogRequest: illegal tag %d (wire type %d)", fieldNum, wire) + return fmt.Errorf("proto: FetchEncryptionKeyRequest: illegal tag %d (wire type %d)", fieldNum, wire) } switch fieldNum { + case 1: + if wireType != 2 { + return fmt.Errorf("proto: wrong wireType = %d for field Links", wireType) + } + var byteLen int + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return protohelpers.ErrIntOverflow + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + byteLen |= int(b&0x7F) << shift + if b < 0x80 { + break + } + } + if byteLen < 0 { + return protohelpers.ErrInvalidLength + } + postIndex := iNdEx + byteLen + if postIndex < 0 { + return protohelpers.ErrInvalidLength + } + if postIndex > l { + return io.ErrUnexpectedEOF + } + m.Links = append(m.Links, make([]byte, postIndex-iNdEx)) + copy(m.Links[len(m.Links)-1], dAtA[iNdEx:postIndex]) + iNdEx = postIndex + case 2: + if wireType != 2 { + return fmt.Errorf("proto: wrong wireType = %d for field EphemeralPublicKey", wireType) + } + var byteLen int + for 
shift := uint(0); ; shift += 7 { + if shift >= 64 { + return protohelpers.ErrIntOverflow + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + byteLen |= int(b&0x7F) << shift + if b < 0x80 { + break + } + } + if byteLen < 0 { + return protohelpers.ErrInvalidLength + } + postIndex := iNdEx + byteLen + if postIndex < 0 { + return protohelpers.ErrInvalidLength + } + if postIndex > l { + return io.ErrUnexpectedEOF + } + m.EphemeralPublicKey = append(m.EphemeralPublicKey[:0], dAtA[iNdEx:postIndex]...) + if m.EphemeralPublicKey == nil { + m.EphemeralPublicKey = []byte{} + } + iNdEx = postIndex default: iNdEx = preIndex - skippy, err := skip(dAtA[iNdEx:]) + skippy, err := protohelpers.Skip(dAtA[iNdEx:]) if err != nil { return err } if (skippy < 0) || (iNdEx+skippy) < 0 { - return ErrInvalidLength + return protohelpers.ErrInvalidLength } if (iNdEx + skippy) > l { return io.ErrUnexpectedEOF @@ -1568,7 +1586,7 @@ func (m *GetHeadLogRequest) UnmarshalVT(dAtA []byte) error { } return nil } -func (m *PushLogReply) UnmarshalVT(dAtA []byte) error { +func (m *FetchEncryptionKeyReply) UnmarshalVT(dAtA []byte) error { l := len(dAtA) iNdEx := 0 for iNdEx < l { @@ -1576,7 +1594,7 @@ func (m *PushLogReply) UnmarshalVT(dAtA []byte) error { var wire uint64 for shift := uint(0); ; shift += 7 { if shift >= 64 { - return ErrIntOverflow + return protohelpers.ErrIntOverflow } if iNdEx >= l { return io.ErrUnexpectedEOF @@ -1591,20 +1609,118 @@ func (m *PushLogReply) UnmarshalVT(dAtA []byte) error { fieldNum := int32(wire >> 3) wireType := int(wire & 0x7) if wireType == 4 { - return fmt.Errorf("proto: PushLogReply: wiretype end group for non-group") + return fmt.Errorf("proto: FetchEncryptionKeyReply: wiretype end group for non-group") } if fieldNum <= 0 { - return fmt.Errorf("proto: PushLogReply: illegal tag %d (wire type %d)", fieldNum, wire) + return fmt.Errorf("proto: FetchEncryptionKeyReply: illegal tag %d (wire type %d)", fieldNum, wire) } switch fieldNum { 
+ case 1: + if wireType != 2 { + return fmt.Errorf("proto: wrong wireType = %d for field Links", wireType) + } + var byteLen int + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return protohelpers.ErrIntOverflow + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + byteLen |= int(b&0x7F) << shift + if b < 0x80 { + break + } + } + if byteLen < 0 { + return protohelpers.ErrInvalidLength + } + postIndex := iNdEx + byteLen + if postIndex < 0 { + return protohelpers.ErrInvalidLength + } + if postIndex > l { + return io.ErrUnexpectedEOF + } + m.Links = append(m.Links, make([]byte, postIndex-iNdEx)) + copy(m.Links[len(m.Links)-1], dAtA[iNdEx:postIndex]) + iNdEx = postIndex + case 2: + if wireType != 2 { + return fmt.Errorf("proto: wrong wireType = %d for field Blocks", wireType) + } + var byteLen int + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return protohelpers.ErrIntOverflow + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + byteLen |= int(b&0x7F) << shift + if b < 0x80 { + break + } + } + if byteLen < 0 { + return protohelpers.ErrInvalidLength + } + postIndex := iNdEx + byteLen + if postIndex < 0 { + return protohelpers.ErrInvalidLength + } + if postIndex > l { + return io.ErrUnexpectedEOF + } + m.Blocks = append(m.Blocks, make([]byte, postIndex-iNdEx)) + copy(m.Blocks[len(m.Blocks)-1], dAtA[iNdEx:postIndex]) + iNdEx = postIndex + case 3: + if wireType != 2 { + return fmt.Errorf("proto: wrong wireType = %d for field EphemeralPublicKey", wireType) + } + var byteLen int + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return protohelpers.ErrIntOverflow + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + byteLen |= int(b&0x7F) << shift + if b < 0x80 { + break + } + } + if byteLen < 0 { + return protohelpers.ErrInvalidLength + } + postIndex := iNdEx + byteLen + if postIndex < 0 { + return protohelpers.ErrInvalidLength + } + if 
postIndex > l { + return io.ErrUnexpectedEOF + } + m.EphemeralPublicKey = append(m.EphemeralPublicKey[:0], dAtA[iNdEx:postIndex]...) + if m.EphemeralPublicKey == nil { + m.EphemeralPublicKey = []byte{} + } + iNdEx = postIndex default: iNdEx = preIndex - skippy, err := skip(dAtA[iNdEx:]) + skippy, err := protohelpers.Skip(dAtA[iNdEx:]) if err != nil { return err } if (skippy < 0) || (iNdEx+skippy) < 0 { - return ErrInvalidLength + return protohelpers.ErrInvalidLength } if (iNdEx + skippy) > l { return io.ErrUnexpectedEOF @@ -1619,7 +1735,7 @@ func (m *PushLogReply) UnmarshalVT(dAtA []byte) error { } return nil } -func (m *GetHeadLogReply) UnmarshalVT(dAtA []byte) error { +func (m *GetHeadLogRequest) UnmarshalVT(dAtA []byte) error { l := len(dAtA) iNdEx := 0 for iNdEx < l { @@ -1627,7 +1743,7 @@ func (m *GetHeadLogReply) UnmarshalVT(dAtA []byte) error { var wire uint64 for shift := uint(0); ; shift += 7 { if shift >= 64 { - return ErrIntOverflow + return protohelpers.ErrIntOverflow } if iNdEx >= l { return io.ErrUnexpectedEOF @@ -1642,20 +1758,20 @@ func (m *GetHeadLogReply) UnmarshalVT(dAtA []byte) error { fieldNum := int32(wire >> 3) wireType := int(wire & 0x7) if wireType == 4 { - return fmt.Errorf("proto: GetHeadLogReply: wiretype end group for non-group") + return fmt.Errorf("proto: GetHeadLogRequest: wiretype end group for non-group") } if fieldNum <= 0 { - return fmt.Errorf("proto: GetHeadLogReply: illegal tag %d (wire type %d)", fieldNum, wire) + return fmt.Errorf("proto: GetHeadLogRequest: illegal tag %d (wire type %d)", fieldNum, wire) } switch fieldNum { default: iNdEx = preIndex - skippy, err := skip(dAtA[iNdEx:]) + skippy, err := protohelpers.Skip(dAtA[iNdEx:]) if err != nil { return err } if (skippy < 0) || (iNdEx+skippy) < 0 { - return ErrInvalidLength + return protohelpers.ErrInvalidLength } if (iNdEx + skippy) > l { return io.ErrUnexpectedEOF @@ -1670,88 +1786,105 @@ func (m *GetHeadLogReply) UnmarshalVT(dAtA []byte) error { } return nil } - -func 
skip(dAtA []byte) (n int, err error) { +func (m *PushLogReply) UnmarshalVT(dAtA []byte) error { l := len(dAtA) iNdEx := 0 - depth := 0 for iNdEx < l { + preIndex := iNdEx var wire uint64 for shift := uint(0); ; shift += 7 { if shift >= 64 { - return 0, ErrIntOverflow + return protohelpers.ErrIntOverflow } if iNdEx >= l { - return 0, io.ErrUnexpectedEOF + return io.ErrUnexpectedEOF } b := dAtA[iNdEx] iNdEx++ - wire |= (uint64(b) & 0x7F) << shift + wire |= uint64(b&0x7F) << shift if b < 0x80 { break } } + fieldNum := int32(wire >> 3) wireType := int(wire & 0x7) - switch wireType { - case 0: - for shift := uint(0); ; shift += 7 { - if shift >= 64 { - return 0, ErrIntOverflow - } - if iNdEx >= l { - return 0, io.ErrUnexpectedEOF - } - iNdEx++ - if dAtA[iNdEx-1] < 0x80 { - break - } + if wireType == 4 { + return fmt.Errorf("proto: PushLogReply: wiretype end group for non-group") + } + if fieldNum <= 0 { + return fmt.Errorf("proto: PushLogReply: illegal tag %d (wire type %d)", fieldNum, wire) + } + switch fieldNum { + default: + iNdEx = preIndex + skippy, err := protohelpers.Skip(dAtA[iNdEx:]) + if err != nil { + return err } - case 1: - iNdEx += 8 - case 2: - var length int - for shift := uint(0); ; shift += 7 { - if shift >= 64 { - return 0, ErrIntOverflow - } - if iNdEx >= l { - return 0, io.ErrUnexpectedEOF - } - b := dAtA[iNdEx] - iNdEx++ - length |= (int(b) & 0x7F) << shift - if b < 0x80 { - break - } + if (skippy < 0) || (iNdEx+skippy) < 0 { + return protohelpers.ErrInvalidLength } - if length < 0 { - return 0, ErrInvalidLength + if (iNdEx + skippy) > l { + return io.ErrUnexpectedEOF } - iNdEx += length - case 3: - depth++ - case 4: - if depth == 0 { - return 0, ErrUnexpectedEndOfGroup + m.unknownFields = append(m.unknownFields, dAtA[iNdEx:iNdEx+skippy]...) 
+ iNdEx += skippy + } + } + + if iNdEx > l { + return io.ErrUnexpectedEOF + } + return nil +} +func (m *GetHeadLogReply) UnmarshalVT(dAtA []byte) error { + l := len(dAtA) + iNdEx := 0 + for iNdEx < l { + preIndex := iNdEx + var wire uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return protohelpers.ErrIntOverflow + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + wire |= uint64(b&0x7F) << shift + if b < 0x80 { + break } - depth-- - case 5: - iNdEx += 4 - default: - return 0, fmt.Errorf("proto: illegal wireType %d", wireType) } - if iNdEx < 0 { - return 0, ErrInvalidLength + fieldNum := int32(wire >> 3) + wireType := int(wire & 0x7) + if wireType == 4 { + return fmt.Errorf("proto: GetHeadLogReply: wiretype end group for non-group") } - if depth == 0 { - return iNdEx, nil + if fieldNum <= 0 { + return fmt.Errorf("proto: GetHeadLogReply: illegal tag %d (wire type %d)", fieldNum, wire) + } + switch fieldNum { + default: + iNdEx = preIndex + skippy, err := protohelpers.Skip(dAtA[iNdEx:]) + if err != nil { + return err + } + if (skippy < 0) || (iNdEx+skippy) < 0 { + return protohelpers.ErrInvalidLength + } + if (iNdEx + skippy) > l { + return io.ErrUnexpectedEOF + } + m.unknownFields = append(m.unknownFields, dAtA[iNdEx:iNdEx+skippy]...) 
+ iNdEx += skippy } } - return 0, io.ErrUnexpectedEOF -} -var ( - ErrInvalidLength = fmt.Errorf("proto: negative length found during unmarshaling") - ErrIntOverflow = fmt.Errorf("proto: integer overflow") - ErrUnexpectedEndOfGroup = fmt.Errorf("proto: unexpected end of group") -) + if iNdEx > l { + return io.ErrUnexpectedEOF + } + return nil +} diff --git a/net/peer.go b/net/peer.go index 301c080edb..24976ed388 100644 --- a/net/peer.go +++ b/net/peer.go @@ -21,7 +21,6 @@ import ( "github.com/ipfs/boxo/bitswap/network" "github.com/ipfs/boxo/blockservice" "github.com/ipfs/boxo/bootstrap" - exchange "github.com/ipfs/boxo/exchange" blocks "github.com/ipfs/go-block-format" "github.com/ipfs/go-cid" gostream "github.com/libp2p/go-libp2p-gostream" @@ -46,6 +45,7 @@ import ( // to the underlying DefraDB instance. type Peer struct { blockstore datastore.Blockstore + encstore datastore.Blockstore bus *event.Bus updateSub *event.Subscription @@ -61,7 +61,6 @@ type Peer struct { p2pRPC *grpc.Server // rpc server over the P2P network // peer DAG service - exch exchange.Interface bserv blockservice.BlockService bootCloser io.Closer @@ -71,6 +70,7 @@ type Peer struct { func NewPeer( ctx context.Context, blockstore datastore.Blockstore, + encstore datastore.Blockstore, bus *event.Bus, opts ...NodeOpt, ) (p *Peer, err error) { @@ -83,7 +83,7 @@ func NewPeer( } }() - if blockstore == nil { + if blockstore == nil || encstore == nil { return nil, ErrNilDB } @@ -120,12 +120,12 @@ func NewPeer( host: h, dht: ddht, blockstore: blockstore, + encstore: encstore, ctx: ctx, cancel: cancel, bus: bus, p2pRPC: grpc.NewServer(options.GRPCServerOptions...), bserv: blockservice.New(blockstore, bswap), - exch: bswap, } if options.EnablePubSub { @@ -151,7 +151,7 @@ func NewPeer( return nil, err } - p2plistener, err := gostream.Listen(h, corenet.Protocol) + p2pListener, err := gostream.Listen(h, corenet.Protocol) if err != nil { return nil, err } @@ -164,7 +164,7 @@ func NewPeer( // register the P2P 
gRPC server go func() { pb.RegisterServiceServer(p.p2pRPC, p.server) - if err := p.p2pRPC.Serve(p2plistener); err != nil && + if err := p.p2pRPC.Serve(p2pListener); err != nil && !errors.Is(err, grpc.ErrServerStopped) { log.ErrorE("Fatal P2P RPC server error", err) } @@ -270,7 +270,7 @@ func (p *Peer) RegisterNewDocument( schemaRoot string, ) error { // register topic - err := p.server.addPubSubTopic(docID.String(), !p.server.hasPubSubTopic(schemaRoot)) + err := p.server.addPubSubTopic(docID.String(), !p.server.hasPubSubTopic(schemaRoot), nil) if err != nil { log.ErrorE( "Failed to create new pubsub topic", @@ -287,7 +287,7 @@ func (p *Peer) RegisterNewDocument( Cid: c.Bytes(), SchemaRoot: []byte(schemaRoot), Creator: p.host.ID().String(), - Log: &pb.Document_Log{ + Log: &pb.Log{ Block: rawBlock, }, }, @@ -325,7 +325,7 @@ func (p *Peer) handleDocUpdateLog(evt event.Update) error { Cid: evt.Cid.Bytes(), SchemaRoot: []byte(evt.SchemaRoot), Creator: p.host.ID().String(), - Log: &pb.Document_Log{ + Log: &pb.Log{ Block: evt.Block, }, } @@ -408,3 +408,7 @@ func (p *Peer) PeerInfo() peer.AddrInfo { Addrs: p.host.Network().ListenAddresses(), } } + +func (p *Peer) Server() *server { + return p.server +} diff --git a/net/peer_test.go b/net/peer_test.go index 5322d32f6e..10af3a3ab4 100644 --- a/net/peer_test.go +++ b/net/peer_test.go @@ -73,6 +73,7 @@ func newTestPeer(ctx context.Context, t *testing.T) (client.DB, *Peer) { n, err := NewPeer( ctx, db.Blockstore(), + db.Encstore(), db.Events(), WithListenAddresses(randomMultiaddr), ) @@ -87,14 +88,14 @@ func TestNewPeer_NoError(t *testing.T) { db, err := db.NewDB(ctx, store, acp.NoACP, nil) require.NoError(t, err) defer db.Close() - p, err := NewPeer(ctx, db.Blockstore(), db.Events()) + p, err := NewPeer(ctx, db.Blockstore(), db.Encstore(), db.Events()) require.NoError(t, err) p.Close() } func TestNewPeer_NoDB_NilDBError(t *testing.T) { ctx := context.Background() - _, err := NewPeer(ctx, nil, nil, nil) + _, err := NewPeer(ctx, 
nil, nil, nil, nil) require.ErrorIs(t, err, ErrNilDB) } @@ -113,6 +114,7 @@ func TestStart_WithKnownPeer_NoError(t *testing.T) { n1, err := NewPeer( ctx, db1.Blockstore(), + db1.Encstore(), db1.Events(), WithListenAddresses("/ip4/127.0.0.1/tcp/0"), ) @@ -121,6 +123,7 @@ func TestStart_WithKnownPeer_NoError(t *testing.T) { n2, err := NewPeer( ctx, db2.Blockstore(), + db1.Encstore(), db2.Events(), WithListenAddresses("/ip4/127.0.0.1/tcp/0"), ) @@ -385,6 +388,7 @@ func TestNewPeer_WithEnableRelay_NoError(t *testing.T) { n, err := NewPeer( context.Background(), db.Blockstore(), + db.Encstore(), db.Events(), WithEnableRelay(true), ) @@ -402,6 +406,7 @@ func TestNewPeer_NoPubSub_NoError(t *testing.T) { n, err := NewPeer( context.Background(), db.Blockstore(), + db.Encstore(), db.Events(), WithEnablePubSub(false), ) @@ -420,6 +425,7 @@ func TestNewPeer_WithEnablePubSub_NoError(t *testing.T) { n, err := NewPeer( ctx, db.Blockstore(), + db.Encstore(), db.Events(), WithEnablePubSub(true), ) @@ -439,6 +445,7 @@ func TestNodeClose_NoError(t *testing.T) { n, err := NewPeer( context.Background(), db.Blockstore(), + db.Encstore(), db.Events(), ) require.NoError(t, err) @@ -455,6 +462,7 @@ func TestListenAddrs_WithListenAddresses_NoError(t *testing.T) { n, err := NewPeer( context.Background(), db.Blockstore(), + db.Encstore(), db.Events(), WithListenAddresses("/ip4/127.0.0.1/tcp/0"), ) @@ -473,6 +481,7 @@ func TestPeer_WithBootstrapPeers_NoError(t *testing.T) { n, err := NewPeer( context.Background(), db.Blockstore(), + db.Encstore(), db.Events(), WithBootstrapPeers("/ip4/127.0.0.1/tcp/6666/p2p/QmaCpDMGvV2BGHeYERUEnRQAwe3N8SzbUtfsmvsqQLuvuJ"), ) diff --git a/net/server.go b/net/server.go index 2f129d19cf..42ff15f5fb 100644 --- a/net/server.go +++ b/net/server.go @@ -34,7 +34,7 @@ import ( pb "github.com/sourcenetwork/defradb/net/pb" ) -// Server is the request/response instance for all P2P RPC communication. 
+// server is the request/response instance for all P2P RPC communication. // Implements gRPC server. See net/pb/net.proto for corresponding service definitions. // // Specifically, server handles the push/get request/response aspects of the RPC service @@ -144,9 +144,9 @@ func (s *server) PushLog(ctx context.Context, req *pb.PushLogRequest) (*pb.PushL corelog.Any("DocID", docID.String())) // Once processed, subscribe to the DocID topic on the pubsub network unless we already - // suscribe to the collection. + // subscribed to the collection. if !s.hasPubSubTopic(string(req.Body.SchemaRoot)) { - err = s.addPubSubTopic(docID.String(), true) + err = s.addPubSubTopic(docID.String(), true, nil) if err != nil { return nil, err } @@ -172,7 +172,9 @@ func (s *server) GetHeadLog( } // addPubSubTopic subscribes to a topic on the pubsub network -func (s *server) addPubSubTopic(topic string, subscribe bool) error { +// A custom message handler can be provided to handle incoming messages. If not provided, +// the default message handler will be used. +func (s *server) addPubSubTopic(topic string, subscribe bool, handler rpc.MessageHandler) error { if s.peer.ps == nil { return nil } @@ -200,8 +202,12 @@ func (s *server) addPubSubTopic(topic string, subscribe bool) error { return err } + if handler == nil { + handler = s.pubSubMessageHandler + } + t.SetEventHandler(s.pubSubEventHandler) - t.SetMessageHandler(s.pubSubMessageHandler) + t.SetMessageHandler(handler) s.topics[topic] = pubsubTopic{ Topic: t, subscribed: subscribe, @@ -209,6 +215,10 @@ func (s *server) addPubSubTopic(topic string, subscribe bool) error { return nil } +func (s *server) AddPubSubTopic(topicName string, handler rpc.MessageHandler) error { + return s.addPubSubTopic(topicName, true, handler) +} + // hasPubSubTopic checks if we are subscribed to a topic. 
func (s *server) hasPubSubTopic(topic string) bool { s.mu.Lock() @@ -269,7 +279,7 @@ func (s *server) publishLog(ctx context.Context, topic string, req *pb.PushLogRe t, ok := s.topics[topic] s.mu.Unlock() if !ok { - err := s.addPubSubTopic(topic, false) + err := s.addPubSubTopic(topic, false, nil) if err != nil { return errors.Wrap(fmt.Sprintf("failed to created single use topic %s", topic), err) } @@ -278,7 +288,7 @@ func (s *server) publishLog(ctx context.Context, topic string, req *pb.PushLogRe data, err := req.MarshalVT() if err != nil { - return errors.Wrap("failed marshling pubsub message", err) + return errors.Wrap("failed to marshal pubsub message", err) } _, err = t.Publish(ctx, data, rpc.WithIgnoreResponse(true)) @@ -347,7 +357,7 @@ func peerIDFromContext(ctx context.Context) (libpeer.ID, error) { func (s *server) updatePubSubTopics(evt event.P2PTopic) { for _, topic := range evt.ToAdd { - err := s.addPubSubTopic(topic, true) + err := s.addPubSubTopic(topic, true, nil) if err != nil { log.ErrorE("Failed to add pubsub topic.", err) } @@ -409,3 +419,17 @@ func (s *server) updateReplicators(evt event.Replicator) { } s.peer.bus.Publish(event.NewMessage(event.ReplicatorCompletedName, nil)) } + +func (s *server) SendPubSubMessage( + ctx context.Context, + topic string, + data []byte, +) (<-chan rpc.Response, error) { + s.mu.Lock() + t, ok := s.topics[topic] + s.mu.Unlock() + if !ok { + return nil, NewErrTopicDoesNotExist(topic) + } + return t.Publish(ctx, data) +} diff --git a/net/server_test.go b/net/server_test.go index 0e23e3b019..11a13604b1 100644 --- a/net/server_test.go +++ b/net/server_test.go @@ -75,7 +75,7 @@ func TestGetHeadLog(t *testing.T) { } func getHead(ctx context.Context, db client.DB, docID client.DocID) (cid.Cid, error) { - prefix := core.DataStoreKeyFromDocID(docID).ToHeadStoreKey().WithFieldId(core.COMPOSITE_NAMESPACE).ToString() + prefix := core.DataStoreKeyFromDocID(docID).ToHeadStoreKey().WithFieldID(core.COMPOSITE_NAMESPACE).ToString() 
results, err := db.Headstore().Query(ctx, query.Query{Prefix: prefix}) if err != nil { return cid.Undef, err @@ -132,7 +132,7 @@ func TestPushLog(t *testing.T) { Cid: headCID.Bytes(), SchemaRoot: []byte(col.SchemaRoot()), Creator: p.PeerID().String(), - Log: &net_pb.Document_Log{ + Log: &net_pb.Log{ Block: b, }, }, diff --git a/node/node.go b/node/node.go index 5660d0d77c..d5e62bc1bb 100644 --- a/node/node.go +++ b/node/node.go @@ -17,10 +17,12 @@ import ( gohttp "net/http" "github.com/sourcenetwork/corelog" + "github.com/sourcenetwork/immutable" "github.com/sourcenetwork/defradb/client" "github.com/sourcenetwork/defradb/http" "github.com/sourcenetwork/defradb/internal/db" + "github.com/sourcenetwork/defradb/internal/kms" "github.com/sourcenetwork/defradb/net" ) @@ -42,6 +44,7 @@ type Options struct { disableP2P bool disableAPI bool enableDevelopment bool + kmsType immutable.Option[kms.ServiceType] } // DefaultOptions returns options with default settings. @@ -66,6 +69,12 @@ func WithDisableAPI(disable bool) NodeOpt { } } +func WithKMS(kms kms.ServiceType) NodeOpt { + return func(o *Options) { + o.kmsType = immutable.Some(kms) + } +} + // WithEnableDevelopment sets the enable development mode flag. func WithEnableDevelopment(enable bool) NodeOpt { return func(o *Options) { @@ -75,9 +84,10 @@ func WithEnableDevelopment(enable bool) NodeOpt { // Node is a DefraDB instance with optional sub-systems. type Node struct { - DB client.DB - Peer *net.Peer - Server *http.Server + DB client.DB + Peer *net.Peer + Server *http.Server + kmsService kms.Service options *Options dbOpts []db.Option @@ -141,10 +151,25 @@ func (n *Node) Start(ctx context.Context) error { if !n.options.disableP2P { // setup net node - n.Peer, err = net.NewPeer(ctx, n.DB.Blockstore(), n.DB.Events(), n.netOpts...) + n.Peer, err = net.NewPeer(ctx, n.DB.Blockstore(), n.DB.Encstore(), n.DB.Events(), n.netOpts...) 
if err != nil { return err } + if n.options.kmsType.HasValue() { + switch n.options.kmsType.Value() { + case kms.PubSubServiceType: + n.kmsService, err = kms.NewPubSubService( + ctx, + n.Peer.PeerID(), + n.Peer.Server(), + n.DB.Events(), + n.DB.Encstore(), + ) + } + if err != nil { + return err + } + } } if !n.options.disableAPI { diff --git a/tests/bench/query/planner/utils.go b/tests/bench/query/planner/utils.go index 43669fa53f..b9e077867b 100644 --- a/tests/bench/query/planner/utils.go +++ b/tests/bench/query/planner/utils.go @@ -135,7 +135,7 @@ type dummyTxn struct{} func (*dummyTxn) Rootstore() datastore.DSReaderWriter { return nil } func (*dummyTxn) Datastore() datastore.DSReaderWriter { return nil } -func (*dummyTxn) Encstore() datastore.DSReaderWriter { return nil } +func (*dummyTxn) Encstore() datastore.Blockstore { return nil } func (*dummyTxn) Headstore() datastore.DSReaderWriter { return nil } func (*dummyTxn) Peerstore() datastore.DSBatching { return nil } func (*dummyTxn) Blockstore() datastore.Blockstore { return nil } diff --git a/tests/clients/cli/wrapper.go b/tests/clients/cli/wrapper.go index fbfc0e5e6a..7a2f28fd4a 100644 --- a/tests/clients/cli/wrapper.go +++ b/tests/clients/cli/wrapper.go @@ -539,6 +539,10 @@ func (w *Wrapper) Rootstore() datastore.Rootstore { return w.node.DB.Rootstore() } +func (w *Wrapper) Encstore() datastore.Blockstore { + return w.node.DB.Encstore() +} + func (w *Wrapper) Blockstore() datastore.Blockstore { return w.node.DB.Blockstore() } diff --git a/tests/clients/cli/wrapper_tx.go b/tests/clients/cli/wrapper_tx.go index 46aefd000d..e3bf41d818 100644 --- a/tests/clients/cli/wrapper_tx.go +++ b/tests/clients/cli/wrapper_tx.go @@ -75,7 +75,7 @@ func (w *Transaction) Datastore() datastore.DSReaderWriter { return w.tx.Datastore() } -func (w *Transaction) Encstore() datastore.DSReaderWriter { +func (w *Transaction) Encstore() datastore.Blockstore { return w.tx.Encstore() } diff --git a/tests/clients/http/wrapper.go 
b/tests/clients/http/wrapper.go index f931732f09..2b84bfc701 100644 --- a/tests/clients/http/wrapper.go +++ b/tests/clients/http/wrapper.go @@ -208,6 +208,10 @@ func (w *Wrapper) Rootstore() datastore.Rootstore { return w.node.DB.Rootstore() } +func (w *Wrapper) Encstore() datastore.Blockstore { + return w.node.DB.Encstore() +} + func (w *Wrapper) Blockstore() datastore.Blockstore { return w.node.DB.Blockstore() } diff --git a/tests/clients/http/wrapper_tx.go b/tests/clients/http/wrapper_tx.go index e4b838a2e9..baf841871a 100644 --- a/tests/clients/http/wrapper_tx.go +++ b/tests/clients/http/wrapper_tx.go @@ -69,7 +69,7 @@ func (w *TxWrapper) Datastore() datastore.DSReaderWriter { return w.server.Datastore() } -func (w *TxWrapper) Encstore() datastore.DSReaderWriter { +func (w *TxWrapper) Encstore() datastore.Blockstore { return w.server.Encstore() } diff --git a/tests/integration/acp.go b/tests/integration/acp.go index 44ac023bce..a6efd64110 100644 --- a/tests/integration/acp.go +++ b/tests/integration/acp.go @@ -57,6 +57,20 @@ var ( acpType ACPType ) +// KMSType is the type of KMS to use. +type KMSType string + +const ( + // NoneKMSType is the none KMS type. It is used to indicate that no KMS should be used. + NoneKMSType KMSType = "none" + // PubSubKMSType is the PubSub KMS type. + PubSubKMSType KMSType = "pubsub" +) + +func getKMSTypes() []KMSType { + return []KMSType{PubSubKMSType} +} + func init() { acpType = ACPType(os.Getenv(acpTypeEnvName)) if acpType == "" { diff --git a/tests/integration/assert_stack.go b/tests/integration/assert_stack.go new file mode 100644 index 0000000000..a341c96a31 --- /dev/null +++ b/tests/integration/assert_stack.go @@ -0,0 +1,57 @@ +// Copyright 2024 Democratized Data Foundation +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt. 
+// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +package tests + +import ( + "strconv" + "strings" +) + +// assertStack keeps track of the current assertion path. +// GraphQL response can be traversed by a key of a map and/or an index of an array. +// So whenever we have a mismatch in a large response, we can use this stack to find the exact path. +// Example output: "commits[2].links[1].cid" +type assertStack struct { + stack []string + isMap []bool +} + +func (a *assertStack) pushMap(key string) { + a.stack = append(a.stack, key) + a.isMap = append(a.isMap, true) +} + +func (a *assertStack) pushArray(index int) { + a.stack = append(a.stack, strconv.Itoa(index)) + a.isMap = append(a.isMap, false) +} + +func (a *assertStack) pop() { + a.stack = a.stack[:len(a.stack)-1] + a.isMap = a.isMap[:len(a.isMap)-1] +} + +func (a *assertStack) String() string { + var b strings.Builder + for i, key := range a.stack { + if a.isMap[i] { + if i > 0 { + b.WriteString(".") + } + b.WriteString(key) + } else { + b.WriteString("[") + b.WriteString(key) + b.WriteString("]") + } + } + return b.String() +} diff --git a/tests/integration/db.go b/tests/integration/db.go index 06737318d7..b9c1b3791d 100644 --- a/tests/integration/db.go +++ b/tests/integration/db.go @@ -21,6 +21,7 @@ import ( "github.com/sourcenetwork/defradb/client" "github.com/sourcenetwork/defradb/crypto" + "github.com/sourcenetwork/defradb/internal/kms" "github.com/sourcenetwork/defradb/node" changeDetector "github.com/sourcenetwork/defradb/tests/change_detector" ) @@ -109,10 +110,7 @@ func NewBadgerFileDB(ctx context.Context, t testing.TB) (client.DB, error) { return node.DB, err } -// setupNode returns the database implementation for the current -// testing state. 
The database type on the test state is used to -// select the datastore implementation to use. -func setupNode(s *state) (*node.Node, string, error) { +func getDefaultNodeOpts() []node.Option { opts := []node.Option{ node.WithLensPoolSize(lensPoolSize), // The test framework sets this up elsewhere when required so that it may be wrapped @@ -127,7 +125,7 @@ func setupNode(s *state) (*node.Node, string, error) { if badgerEncryption && encryptionKey == nil { key, err := crypto.GenerateAES256() if err != nil { - return nil, "", err + return nil } encryptionKey = key } @@ -136,6 +134,15 @@ func setupNode(s *state) (*node.Node, string, error) { opts = append(opts, node.WithBadgerEncryptionKey(encryptionKey)) } + return opts +} + +// setupNode returns the database implementation for the current +// testing state. The database type on the test state is used to +// select the datastore implementation to use. +func setupNode(s *state, opts ...node.Option) (*node.Node, string, error) { + opts = append(getDefaultNodeOpts(), opts...) + switch acpType { case LocalACPType: opts = append(opts, node.WithACPType(node.LocalACPType)) @@ -185,6 +192,10 @@ func setupNode(s *state) (*node.Node, string, error) { return nil, "", fmt.Errorf("invalid database type: %v", s.dbt) } + if s.kms == PubSubKMSType { + opts = append(opts, node.WithKMS(kms.PubSubServiceType)) + } + node, err := node.New(s.ctx, opts...) 
if err != nil { return nil, "", err diff --git a/tests/integration/encryption/commit_test.go b/tests/integration/encryption/commit_test.go index 86ea5d88df..da493e097f 100644 --- a/tests/integration/encryption/commit_test.go +++ b/tests/integration/encryption/commit_test.go @@ -48,7 +48,7 @@ func TestDocEncryption_WithEncryptionOnLWWCRDT_ShouldStoreCommitsDeltaEncrypted( Results: map[string]any{ "commits": []map[string]any{ { - "cid": "bafyreibdjepzhhiez4o27srv33xcd52yr336tpzqtkv36rdf3h3oue2l5m", + "cid": "bafyreidkuvcdxxkyoeapnmttu6l2vk43qnm3zuzpxegbifpj6w24jrvrxq", "collectionID": int64(1), "delta": encrypt(testUtils.CBORValue(21), john21DocID, ""), "docID": john21DocID, @@ -58,7 +58,7 @@ func TestDocEncryption_WithEncryptionOnLWWCRDT_ShouldStoreCommitsDeltaEncrypted( "links": []map[string]any{}, }, { - "cid": "bafyreihkiua7jpwkye3xlex6s5hh2azckcaljfi2h3iscgub5sikacyrbu", + "cid": "bafyreihdlv4fvvptetghxzyerxt4jc4zgprecybhoijrfjuyxqe55qw3x4", "collectionID": int64(1), "delta": encrypt(testUtils.CBORValue("John"), john21DocID, ""), "docID": john21DocID, @@ -68,7 +68,7 @@ func TestDocEncryption_WithEncryptionOnLWWCRDT_ShouldStoreCommitsDeltaEncrypted( "links": []map[string]any{}, }, { - "cid": "bafyreidxdhzhwjrv5s4x6cho5drz6xq2tc7oymzupf4p4gfk6eelsnc7ke", + "cid": "bafyreie5jegw4c2hg56bbiv6cgxmfz336jruukjakbjuyapockfnn6b5le", "collectionID": int64(1), "delta": nil, "docID": john21DocID, @@ -77,11 +77,11 @@ func TestDocEncryption_WithEncryptionOnLWWCRDT_ShouldStoreCommitsDeltaEncrypted( "height": int64(1), "links": []map[string]any{ { - "cid": "bafyreibdjepzhhiez4o27srv33xcd52yr336tpzqtkv36rdf3h3oue2l5m", + "cid": "bafyreidkuvcdxxkyoeapnmttu6l2vk43qnm3zuzpxegbifpj6w24jrvrxq", "name": "age", }, { - "cid": "bafyreihkiua7jpwkye3xlex6s5hh2azckcaljfi2h3iscgub5sikacyrbu", + "cid": "bafyreihdlv4fvvptetghxzyerxt4jc4zgprecybhoijrfjuyxqe55qw3x4", "name": "name", }, }, diff --git a/tests/integration/encryption/peer_sec_index_test.go 
b/tests/integration/encryption/peer_sec_index_test.go new file mode 100644 index 0000000000..e6fd3548cf --- /dev/null +++ b/tests/integration/encryption/peer_sec_index_test.go @@ -0,0 +1,161 @@ +// Copyright 2024 Democratized Data Foundation +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +package encryption + +import ( + "testing" + + "github.com/sourcenetwork/immutable" + + testUtils "github.com/sourcenetwork/defradb/tests/integration" +) + +func TestDocEncryptionPeer_IfEncryptedDocHasIndexedField_ShouldIndexAfterDecryption(t *testing.T) { + test := testUtils.TestCase{ + KMS: testUtils.KMS{Activated: true}, + Actions: []any{ + testUtils.RandomNetworkingConfig(), + testUtils.RandomNetworkingConfig(), + testUtils.SchemaUpdate{ + Schema: ` + type User { + name: String + age: Int @index + } + `, + }, + testUtils.ConnectPeers{ + SourceNodeID: 1, + TargetNodeID: 0, + }, + testUtils.SubscribeToCollection{ + NodeID: 1, + CollectionIDs: []int{0}, + }, + testUtils.CreateDoc{ + NodeID: immutable.Some(0), + Doc: `{ + "name": "Shahzad", + "age": 25 + }`, + }, + testUtils.CreateDoc{ + NodeID: immutable.Some(0), + Doc: islam33Doc, + IsDocEncrypted: true, + }, + testUtils.CreateDoc{ + NodeID: immutable.Some(0), + Doc: `{ + "name": "Andy", + "age": 21 + }`, + }, + testUtils.CreateDoc{ + NodeID: immutable.Some(0), + Doc: john21Doc, + IsDocEncrypted: true, + }, + testUtils.WaitForSync{}, + testUtils.Request{ + Request: ` + query @explain(type: execute) { + User(filter: {age: {_eq: 21}}) { + age + } + }`, + Asserter: testUtils.NewExplainAsserter().WithIndexFetches(2), + }, + testUtils.Request{ + Request: ` + query { + User(filter: {age: {_eq: 21}}) { + name + } + }`, + Results: 
map[string]any{ + "User": []map[string]any{ + { + "name": "Andy", + }, + { + "name": "John", + }, + }, + }, + }, + }, + } + + testUtils.ExecuteTestCase(t, test) +} + +func TestDocEncryptionPeer_IfDocDocHasEncryptedIndexedField_ShouldIndexAfterDecryption(t *testing.T) { + test := testUtils.TestCase{ + KMS: testUtils.KMS{Activated: true}, + Actions: []any{ + testUtils.RandomNetworkingConfig(), + testUtils.RandomNetworkingConfig(), + testUtils.SchemaUpdate{ + Schema: ` + type User { + name: String + age: Int @index + } + `, + }, + testUtils.ConnectPeers{ + SourceNodeID: 1, + TargetNodeID: 0, + }, + testUtils.SubscribeToCollection{ + NodeID: 1, + CollectionIDs: []int{0}, + }, + testUtils.CreateDoc{ + NodeID: immutable.Some(0), + Doc: `{ + "name": "Shahzad", + "age": 25 + }`, + }, + testUtils.CreateDoc{ + NodeID: immutable.Some(0), + Doc: islam33Doc, + EncryptedFields: []string{"age"}, + }, + testUtils.CreateDoc{ + NodeID: immutable.Some(0), + Doc: `{ + "name": "Andy", + "age": 21 + }`, + }, + testUtils.CreateDoc{ + NodeID: immutable.Some(0), + Doc: john21Doc, + EncryptedFields: []string{"age"}, + }, + testUtils.WaitForSync{}, + testUtils.Request{ + Request: ` + query @explain(type: execute) { + User(filter: {age: {_eq: 21}}) { + age + } + }`, + Asserter: testUtils.NewExplainAsserter().WithIndexFetches(2), + }, + }, + } + + testUtils.ExecuteTestCase(t, test) +} diff --git a/tests/integration/encryption/peer_share_test.go b/tests/integration/encryption/peer_share_test.go new file mode 100644 index 0000000000..c04d204a84 --- /dev/null +++ b/tests/integration/encryption/peer_share_test.go @@ -0,0 +1,530 @@ +// Copyright 2024 Democratized Data Foundation +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt. 
+// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +package encryption + +import ( + "testing" + + "github.com/sourcenetwork/immutable" + + testUtils "github.com/sourcenetwork/defradb/tests/integration" +) + +func TestDocEncryptionPeer_IfDocIsPublic_ShouldFetchKeyAndDecrypt(t *testing.T) { + test := testUtils.TestCase{ + KMS: testUtils.KMS{Activated: true}, + Actions: []any{ + testUtils.RandomNetworkingConfig(), + testUtils.RandomNetworkingConfig(), + testUtils.SchemaUpdate{ + Schema: ` + type User { + name: String + age: Int + } + `, + }, + testUtils.ConnectPeers{ + SourceNodeID: 1, + TargetNodeID: 0, + }, + testUtils.SubscribeToCollection{ + NodeID: 1, + CollectionIDs: []int{0}, + }, + testUtils.CreateDoc{ + NodeID: immutable.Some(0), + Doc: john21Doc, + IsDocEncrypted: true, + }, + testUtils.WaitForSync{}, + testUtils.Request{ + NodeID: immutable.Some(1), + Request: `query { + User { + age + } + }`, + Results: map[string]any{ + "User": []map[string]any{ + { + "age": int64(21), + }, + }, + }, + }, + }, + } + + testUtils.ExecuteTestCase(t, test) +} + +func TestDocEncryptionPeer_IfPublicDocHasEncryptedField_ShouldFetchKeyAndDecrypt(t *testing.T) { + test := testUtils.TestCase{ + KMS: testUtils.KMS{Activated: true}, + Actions: []any{ + testUtils.RandomNetworkingConfig(), + testUtils.RandomNetworkingConfig(), + testUtils.SchemaUpdate{ + Schema: ` + type User { + name: String + age: Int + } + `, + }, + testUtils.ConnectPeers{ + SourceNodeID: 1, + TargetNodeID: 0, + }, + testUtils.SubscribeToCollection{ + NodeID: 1, + CollectionIDs: []int{0}, + }, + testUtils.CreateDoc{ + NodeID: immutable.Some(0), + Doc: john21Doc, + EncryptedFields: []string{"age"}, + }, + testUtils.WaitForSync{}, + testUtils.Request{ + NodeID: immutable.Some(1), + Request: `query { + User { + name + age + } + }`, + Results: 
map[string]any{ + "User": []map[string]any{ + { + "name": "John", + "age": int64(21), + }, + }, + }, + }, + }, + } + + testUtils.ExecuteTestCase(t, test) +} + +func TestDocEncryptionPeer_IfEncryptedPublicDocHasEncryptedField_ShouldFetchKeysAndDecrypt(t *testing.T) { + test := testUtils.TestCase{ + KMS: testUtils.KMS{Activated: true}, + Actions: []any{ + testUtils.RandomNetworkingConfig(), + testUtils.RandomNetworkingConfig(), + testUtils.SchemaUpdate{ + Schema: ` + type User { + name: String + age: Int + } + `, + }, + testUtils.ConnectPeers{ + SourceNodeID: 1, + TargetNodeID: 0, + }, + testUtils.SubscribeToCollection{ + NodeID: 1, + CollectionIDs: []int{0}, + }, + testUtils.CreateDoc{ + NodeID: immutable.Some(0), + Doc: john21Doc, + IsDocEncrypted: true, + EncryptedFields: []string{"age"}, + }, + testUtils.WaitForSync{}, + testUtils.Request{ + NodeID: immutable.Some(1), + Request: `query { + User { + name + age + } + }`, + Results: map[string]any{ + "User": []map[string]any{ + { + "name": "John", + "age": int64(21), + }, + }, + }, + }, + }, + } + + testUtils.ExecuteTestCase(t, test) +} + +func TestDocEncryptionPeer_IfAllFieldsOfEncryptedPublicDocAreIndividuallyEncrypted_ShouldFetchKeysAndDecrypt(t *testing.T) { + test := testUtils.TestCase{ + KMS: testUtils.KMS{Activated: true}, + Actions: []any{ + testUtils.RandomNetworkingConfig(), + testUtils.RandomNetworkingConfig(), + testUtils.SchemaUpdate{ + Schema: ` + type User { + name: String + age: Int + } + `, + }, + testUtils.ConnectPeers{ + SourceNodeID: 1, + TargetNodeID: 0, + }, + testUtils.SubscribeToCollection{ + NodeID: 1, + CollectionIDs: []int{0}, + }, + testUtils.CreateDoc{ + NodeID: immutable.Some(0), + Doc: john21Doc, + IsDocEncrypted: true, + EncryptedFields: []string{"name", "age"}, + }, + testUtils.WaitForSync{}, + testUtils.Request{ + NodeID: immutable.Some(1), + Request: `query { + User { + name + age + } + }`, + Results: map[string]any{ + "User": []map[string]any{ + { + "name": "John", + "age": 
int64(21), + }, + }, + }, + }, + }, + } + + testUtils.ExecuteTestCase(t, test) +} + +func TestDocEncryptionPeer_IfAllFieldsOfPublicDocAreIndividuallyEncrypted_ShouldFetchKeysAndDecrypt(t *testing.T) { + test := testUtils.TestCase{ + KMS: testUtils.KMS{Activated: true}, + Actions: []any{ + testUtils.RandomNetworkingConfig(), + testUtils.RandomNetworkingConfig(), + testUtils.SchemaUpdate{ + Schema: ` + type User { + name: String + age: Int + } + `, + }, + testUtils.ConnectPeers{ + SourceNodeID: 1, + TargetNodeID: 0, + }, + testUtils.SubscribeToCollection{ + NodeID: 1, + CollectionIDs: []int{0}, + }, + testUtils.CreateDoc{ + NodeID: immutable.Some(0), + Doc: john21Doc, + EncryptedFields: []string{"name", "age"}, + }, + testUtils.WaitForSync{}, + testUtils.Request{ + NodeID: immutable.Some(1), + Request: `query { + User { + name + age + } + }`, + Results: map[string]any{ + "User": []map[string]any{ + { + "name": "John", + "age": int64(21), + }, + }, + }, + }, + }, + } + + testUtils.ExecuteTestCase(t, test) +} + +func TestDocEncryptionPeer_WithUpdatesOnEncryptedDeltaBasedCRDTField_ShouldDecryptAndCorrectlyMerge(t *testing.T) { + test := testUtils.TestCase{ + KMS: testUtils.KMS{Activated: true}, + Actions: []any{ + testUtils.RandomNetworkingConfig(), + testUtils.RandomNetworkingConfig(), + testUtils.SchemaUpdate{ + Schema: ` + type User { + name: String + age: Int @crdt(type: "pcounter") + } + `, + }, + testUtils.ConnectPeers{ + SourceNodeID: 1, + TargetNodeID: 0, + }, + testUtils.CreateDoc{ + NodeID: immutable.Some(0), + Doc: john21Doc, + EncryptedFields: []string{"age"}, + }, + testUtils.UpdateDoc{ + NodeID: immutable.Some(0), + Doc: `{"age": 3}`, + }, + testUtils.SubscribeToCollection{ + NodeID: 1, + CollectionIDs: []int{0}, + }, + testUtils.UpdateDoc{ + NodeID: immutable.Some(0), + Doc: `{"age": 2}`, + }, + testUtils.WaitForSync{}, + testUtils.Request{ + NodeID: immutable.Some(1), + Request: `query { + User { + name + age + } + }`, + Results: map[string]any{ + 
"User": []map[string]any{ + { + "name": "John", + "age": int64(26), + }, + }, + }, + }, + }, + } + + testUtils.ExecuteTestCase(t, test) +} + +func TestDocEncryptionPeer_WithUpdatesOnDeltaBasedCRDTFieldOfEncryptedDoc_ShouldDecryptAndCorrectlyMerge(t *testing.T) { + test := testUtils.TestCase{ + KMS: testUtils.KMS{Activated: true}, + Actions: []any{ + testUtils.RandomNetworkingConfig(), + testUtils.RandomNetworkingConfig(), + testUtils.SchemaUpdate{ + Schema: ` + type User { + name: String + age: Int @crdt(type: "pcounter") + } + `, + }, + testUtils.ConnectPeers{ + SourceNodeID: 1, + TargetNodeID: 0, + }, + testUtils.CreateDoc{ + NodeID: immutable.Some(0), + Doc: john21Doc, + IsDocEncrypted: true, + }, + testUtils.UpdateDoc{ + NodeID: immutable.Some(0), + Doc: `{"age": 3}`, + }, + testUtils.SubscribeToCollection{ + NodeID: 1, + CollectionIDs: []int{0}, + }, + testUtils.UpdateDoc{ + NodeID: immutable.Some(0), + Doc: `{"age": 2}`, + }, + testUtils.WaitForSync{}, + testUtils.Request{ + NodeID: immutable.Some(1), + Request: `query { + User { + name + age + } + }`, + Results: map[string]any{ + "User": []map[string]any{ + { + "name": "John", + "age": int64(26), + }, + }, + }, + }, + }, + } + + testUtils.ExecuteTestCase(t, test) +} + +func TestDocEncryptionPeer_WithUpdatesThatSetsEmptyString_ShouldDecryptAndCorrectlyMerge(t *testing.T) { + test := testUtils.TestCase{ + KMS: testUtils.KMS{Activated: true}, + Actions: []any{ + testUtils.RandomNetworkingConfig(), + testUtils.RandomNetworkingConfig(), + testUtils.SchemaUpdate{ + Schema: ` + type User { + name: String + age: Int + } + `, + }, + testUtils.ConnectPeers{ + SourceNodeID: 1, + TargetNodeID: 0, + }, + testUtils.CreateDoc{ + NodeID: immutable.Some(0), + Doc: john21Doc, + IsDocEncrypted: true, + }, + testUtils.SubscribeToCollection{ + NodeID: 1, + CollectionIDs: []int{0}, + }, + testUtils.UpdateDoc{ + NodeID: immutable.Some(0), + Doc: `{"name": ""}`, + }, + testUtils.WaitForSync{}, + testUtils.Request{ + NodeID: 
immutable.Some(1), + Request: `query { + User { + name + } + }`, + Results: map[string]any{ + "User": []map[string]any{ + {"name": ""}, + }, + }, + }, + testUtils.UpdateDoc{ + NodeID: immutable.Some(0), + Doc: `{"name": "John"}`, + }, + testUtils.WaitForSync{}, + testUtils.Request{ + NodeID: immutable.Some(1), + Request: `query { + User { + name + } + }`, + Results: map[string]any{ + "User": []map[string]any{ + {"name": "John"}, + }, + }, + }, + }, + } + + testUtils.ExecuteTestCase(t, test) +} + +func TestDocEncryptionPeer_WithUpdatesThatSetsStringToNull_ShouldDecryptAndCorrectlyMerge(t *testing.T) { + test := testUtils.TestCase{ + KMS: testUtils.KMS{Activated: true}, + Actions: []any{ + testUtils.RandomNetworkingConfig(), + testUtils.RandomNetworkingConfig(), + testUtils.SchemaUpdate{ + Schema: ` + type User { + name: String + age: Int + } + `, + }, + testUtils.ConnectPeers{ + SourceNodeID: 1, + TargetNodeID: 0, + }, + testUtils.CreateDoc{ + NodeID: immutable.Some(0), + Doc: john21Doc, + IsDocEncrypted: true, + }, + testUtils.SubscribeToCollection{ + NodeID: 1, + CollectionIDs: []int{0}, + }, + testUtils.UpdateDoc{ + NodeID: immutable.Some(0), + Doc: `{"name": null}`, + }, + testUtils.WaitForSync{}, + testUtils.Request{ + NodeID: immutable.Some(1), + Request: `query { + User { + name + } + }`, + Results: map[string]any{ + "User": []map[string]any{ + {"name": nil}, + }, + }, + }, + testUtils.UpdateDoc{ + NodeID: immutable.Some(0), + Doc: `{"name": "John"}`, + }, + testUtils.WaitForSync{}, + testUtils.Request{ + NodeID: immutable.Some(1), + Request: `query { + User { + name + } + }`, + Results: map[string]any{ + "User": []map[string]any{ + {"name": "John"}, + }, + }, + }, + }, + } + + testUtils.ExecuteTestCase(t, test) +} diff --git a/tests/integration/encryption/peer_test.go b/tests/integration/encryption/peer_test.go index 7a94c22e13..9f5b875586 100644 --- a/tests/integration/encryption/peer_test.go +++ b/tests/integration/encryption/peer_test.go @@ -18,45 +18,9 
@@ import ( testUtils "github.com/sourcenetwork/defradb/tests/integration" ) -func TestDocEncryptionPeer_IfPeerHasNoKey_ShouldNotFetch(t *testing.T) { - test := testUtils.TestCase{ - Actions: []any{ - testUtils.RandomNetworkingConfig(), - testUtils.RandomNetworkingConfig(), - updateUserCollectionSchema(), - testUtils.ConnectPeers{ - SourceNodeID: 1, - TargetNodeID: 0, - }, - testUtils.SubscribeToCollection{ - NodeID: 1, - CollectionIDs: []int{0}, - }, - testUtils.CreateDoc{ - NodeID: immutable.Some(0), - Doc: john21Doc, - IsDocEncrypted: true, - }, - testUtils.WaitForSync{}, - testUtils.Request{ - NodeID: immutable.Some(1), - Request: `query { - Users { - age - } - }`, - Results: map[string]any{ - "Users": []map[string]any{}, - }, - }, - }, - } - - testUtils.ExecuteTestCase(t, test) -} - func TestDocEncryptionPeer_UponSync_ShouldSyncEncryptedDAG(t *testing.T) { test := testUtils.TestCase{ + KMS: testUtils.KMS{Activated: true}, Actions: []any{ testUtils.RandomNetworkingConfig(), testUtils.RandomNetworkingConfig(), @@ -97,7 +61,7 @@ func TestDocEncryptionPeer_UponSync_ShouldSyncEncryptedDAG(t *testing.T) { Results: map[string]any{ "commits": []map[string]any{ { - "cid": "bafyreibdjepzhhiez4o27srv33xcd52yr336tpzqtkv36rdf3h3oue2l5m", + "cid": "bafyreidkuvcdxxkyoeapnmttu6l2vk43qnm3zuzpxegbifpj6w24jrvrxq", "collectionID": int64(1), "delta": encrypt(testUtils.CBORValue(21), john21DocID, ""), "docID": john21DocID, @@ -107,7 +71,7 @@ func TestDocEncryptionPeer_UponSync_ShouldSyncEncryptedDAG(t *testing.T) { "links": []map[string]any{}, }, { - "cid": "bafyreihkiua7jpwkye3xlex6s5hh2azckcaljfi2h3iscgub5sikacyrbu", + "cid": "bafyreihdlv4fvvptetghxzyerxt4jc4zgprecybhoijrfjuyxqe55qw3x4", "collectionID": int64(1), "delta": encrypt(testUtils.CBORValue("John"), john21DocID, ""), "docID": john21DocID, @@ -117,7 +81,7 @@ func TestDocEncryptionPeer_UponSync_ShouldSyncEncryptedDAG(t *testing.T) { "links": []map[string]any{}, }, { - "cid": 
"bafyreidxdhzhwjrv5s4x6cho5drz6xq2tc7oymzupf4p4gfk6eelsnc7ke", + "cid": "bafyreie5jegw4c2hg56bbiv6cgxmfz336jruukjakbjuyapockfnn6b5le", "collectionID": int64(1), "delta": nil, "docID": john21DocID, @@ -126,11 +90,11 @@ func TestDocEncryptionPeer_UponSync_ShouldSyncEncryptedDAG(t *testing.T) { "height": int64(1), "links": []map[string]any{ { - "cid": "bafyreibdjepzhhiez4o27srv33xcd52yr336tpzqtkv36rdf3h3oue2l5m", + "cid": "bafyreidkuvcdxxkyoeapnmttu6l2vk43qnm3zuzpxegbifpj6w24jrvrxq", "name": "age", }, { - "cid": "bafyreihkiua7jpwkye3xlex6s5hh2azckcaljfi2h3iscgub5sikacyrbu", + "cid": "bafyreihdlv4fvvptetghxzyerxt4jc4zgprecybhoijrfjuyxqe55qw3x4", "name": "name", }, }, @@ -143,3 +107,53 @@ func TestDocEncryptionPeer_UponSync_ShouldSyncEncryptedDAG(t *testing.T) { testUtils.ExecuteTestCase(t, test) } + +func TestDocEncryptionPeer_IfPeerDidNotReceiveKey_ShouldNotFetch(t *testing.T) { + test := testUtils.TestCase{ + KMS: testUtils.KMS{Activated: true}, + Actions: []any{ + testUtils.RandomNetworkingConfig(), + testUtils.RandomNetworkingConfig(), + updateUserCollectionSchema(), + testUtils.ConnectPeers{ + SourceNodeID: 1, + TargetNodeID: 0, + }, + testUtils.SubscribeToCollection{ + NodeID: 1, + CollectionIDs: []int{0}, + }, + testUtils.CreateDoc{ + NodeID: immutable.Some(0), + Doc: john21Doc, + IsDocEncrypted: true, + }, + testUtils.WaitForSync{}, + // Do not wait for the key sync and request the document as soon as the dag has synced + // The document will be returned if the key-sync has taken place already, if not, the set will + // be empty. 
+ testUtils.Request{ + NodeID: immutable.Some(1), + Request: `query { + Users { + age + } + }`, + Results: map[string]any{ + "Users": testUtils.AnyOf{ + // The key-sync has not yet completed + []map[string]any{}, + // The key-sync has completed + []map[string]any{ + { + "age": int64(21), + }, + }, + }, + }, + }, + }, + } + + testUtils.ExecuteTestCase(t, test) +} diff --git a/tests/integration/encryption/utils.go b/tests/integration/encryption/utils.go index fd9c1d17c0..685a2567fb 100644 --- a/tests/integration/encryption/utils.go +++ b/tests/integration/encryption/utils.go @@ -11,7 +11,7 @@ package encryption import ( - "github.com/sourcenetwork/defradb/internal/encryption" + "github.com/sourcenetwork/defradb/crypto" testUtils "github.com/sourcenetwork/defradb/tests/integration" ) @@ -50,6 +50,6 @@ func updateUserCollectionSchema() testUtils.SchemaUpdate { func encrypt(plaintext []byte, docID, fieldName string) []byte { const keyLength = 32 const testEncKey = "examplekey1234567890examplekey12" - val, _ := encryption.EncryptAES(plaintext, []byte(fieldName + docID + testEncKey)[0:keyLength]) + val, _, _ := crypto.EncryptAES(plaintext, []byte(fieldName + docID + testEncKey)[0:keyLength], nil, true) return val } diff --git a/tests/integration/events.go b/tests/integration/events.go index 87b157a662..bf004b99aa 100644 --- a/tests/integration/events.go +++ b/tests/integration/events.go @@ -349,6 +349,10 @@ func getEventsForCreateDoc(s *state, action CreateDoc) map[string]struct{} { return expect } +func waitForSync(s *state) { + waitForMergeEvents(s) +} + // getEventsForUpdateWithFilter returns a map of docIDs that should be // published to the local event bus after a UpdateWithFilter action. // diff --git a/tests/integration/state.go b/tests/integration/state.go index e594285318..9e65458531 100644 --- a/tests/integration/state.go +++ b/tests/integration/state.go @@ -116,6 +116,8 @@ type state struct { // The TestCase currently being executed. 
testCase TestCase + kms KMSType + // The type of database currently being tested. dbt DatabaseType @@ -191,6 +193,7 @@ func newState( ctx context.Context, t testing.TB, testCase TestCase, + kms KMSType, dbt DatabaseType, clientType ClientType, collectionNames []string, @@ -199,6 +202,7 @@ func newState( ctx: ctx, t: t, testCase: testCase, + kms: kms, dbt: dbt, clientType: clientType, txns: []datastore.Txn{}, diff --git a/tests/integration/test_case.go b/tests/integration/test_case.go index 389de2af35..9b0bce913b 100644 --- a/tests/integration/test_case.go +++ b/tests/integration/test_case.go @@ -59,6 +59,18 @@ type TestCase struct { // This is to only be used in the very rare cases where we really do want behavioural // differences between view types, or we need to temporarily document a bug. SupportedViewTypes immutable.Option[[]ViewType] + + // Configuration for KMS to be used in the test + KMS KMS +} + +// KMS contains the configuration for KMS to be used in the test +type KMS struct { + // Activated indicates if the KMS should be used in the test + Activated bool + // ExcludedTypes specifies the KMS types that should be excluded from the test. + // If none are specified all types will be used. 
+ ExcludedTypes []KMSType } // SetupComplete is a flag to explicitly notify the change detector at which point diff --git a/tests/integration/utils.go b/tests/integration/utils.go index fd6758929d..62e27e0b73 100644 --- a/tests/integration/utils.go +++ b/tests/integration/utils.go @@ -14,14 +14,15 @@ import ( "context" "encoding/json" "fmt" + "log/slog" "os" "reflect" + "slices" "strconv" "strings" "testing" "time" - "github.com/bxcodec/faker/support/slice" "github.com/fxamacker/cbor/v2" "github.com/sourcenetwork/corelog" "github.com/sourcenetwork/immutable" @@ -38,6 +39,7 @@ import ( "github.com/sourcenetwork/defradb/internal/request/graphql" "github.com/sourcenetwork/defradb/internal/request/graphql/schema/types" "github.com/sourcenetwork/defradb/net" + "github.com/sourcenetwork/defradb/node" changeDetector "github.com/sourcenetwork/defradb/tests/change_detector" "github.com/sourcenetwork/defradb/tests/clients" "github.com/sourcenetwork/defradb/tests/gen" @@ -185,6 +187,17 @@ func ExecuteTestCase( databases = append(databases, defraIMType) } + var kmsList []KMSType + if testCase.KMS.Activated { + kmsList = getKMSTypes() + for _, excluded := range testCase.KMS.ExcludedTypes { + kmsList = slices.DeleteFunc(kmsList, func(t KMSType) bool { return t == excluded }) + } + } + if len(kmsList) == 0 { + kmsList = []KMSType{NoneKMSType} + } + // Assert that these are not empty to protect against accidental mis-configurations, // otherwise an empty set would silently pass all the tests. 
require.NotEmpty(t, databases) @@ -195,7 +208,9 @@ func ExecuteTestCase( ctx := context.Background() for _, ct := range clients { for _, dbt := range databases { - executeTestCase(ctx, t, collectionNames, testCase, dbt, ct) + for _, kms := range kmsList { + executeTestCase(ctx, t, collectionNames, testCase, kms, dbt, ct) + } } } } @@ -205,12 +220,11 @@ func executeTestCase( t testing.TB, collectionNames []string, testCase TestCase, + kms KMSType, dbt DatabaseType, clientType ClientType, ) { - log.InfoContext( - ctx, - testCase.Description, + logAttrs := []slog.Attr{ corelog.Any("database", dbt), corelog.Any("client", clientType), corelog.Any("mutationType", mutationType), @@ -222,11 +236,17 @@ func executeTestCase( corelog.String("changeDetector.SourceBranch", changeDetector.SourceBranch), corelog.String("changeDetector.TargetBranch", changeDetector.TargetBranch), corelog.String("changeDetector.Repository", changeDetector.Repository), - ) + } + + if kms != NoneKMSType { + logAttrs = append(logAttrs, corelog.Any("KMS", kms)) + } + + log.InfoContext(ctx, testCase.Description, logAttrs...) 
startActionIndex, endActionIndex := getActionRange(t, testCase) - s := newState(ctx, t, testCase, dbt, clientType, collectionNames) + s := newState(ctx, t, testCase, kms, dbt, clientType, collectionNames) setStartingNodes(s) // It is very important that the databases are always closed, otherwise resources will leak @@ -366,7 +386,7 @@ func performAction( assertClientIntrospectionResults(s, action) case WaitForSync: - waitForMergeEvents(s) + waitForSync(s) case Benchmark: benchmarkAction(s, actionIndex, action) @@ -403,7 +423,7 @@ func generateDocs(s *state, action GenerateDocs) { collections := getNodeCollections(action.NodeID, s.collections) defs := make([]client.CollectionDefinition, 0, len(collections[0])) for _, col := range collections[0] { - if len(action.ForCollections) == 0 || slice.Contains(action.ForCollections, col.Name().Value()) { + if len(action.ForCollections) == 0 || slices.Contains(action.ForCollections, col.Name().Value()) { defs = append(defs, col.Definition()) } } @@ -730,7 +750,7 @@ func restartNodes( nodeOpts := s.nodeConfigs[i] nodeOpts = append(nodeOpts, net.WithListenAddresses(addresses...)) - node.Peer, err = net.NewPeer(s.ctx, node.DB.Blockstore(), node.DB.Events(), nodeOpts...) + node.Peer, err = net.NewPeer(s.ctx, node.DB.Blockstore(), node.DB.Encstore(), node.DB.Events(), nodeOpts...) require.NoError(s.t, err) c, err := setupClient(s, node) @@ -802,20 +822,21 @@ func configureNode( return } - node, path, err := setupNode(s) //disable change dector, or allow it? - require.NoError(s.t, err) - privateKey, err := crypto.GenerateEd25519() require.NoError(s.t, err) - nodeOpts := action() - nodeOpts = append(nodeOpts, net.WithPrivateKey(privateKey)) + netNodeOpts := action() + netNodeOpts = append(netNodeOpts, net.WithPrivateKey(privateKey)) - node.Peer, err = net.NewPeer(s.ctx, node.DB.Blockstore(), node.DB.Events(), nodeOpts...) 
+ nodeOpts := []node.Option{node.WithDisableP2P(false)} + for _, opt := range netNodeOpts { + nodeOpts = append(nodeOpts, opt) + } + node, path, err := setupNode(s, nodeOpts...) //disable change dector, or allow it? require.NoError(s.t, err) s.nodeAddresses = append(s.nodeAddresses, node.Peer.PeerInfo()) - s.nodeConfigs = append(s.nodeConfigs, nodeOpts) + s.nodeConfigs = append(s.nodeConfigs, netNodeOpts) c, err := setupClient(s, node) require.NoError(s.t, err) @@ -1767,7 +1788,6 @@ func executeRequest( result := node.ExecRequest(ctx, action.Request, options...) - anyOfByFieldKey := map[docFieldKey][]any{} expectedErrorRaised = assertRequestResults( s, &result.GQL, @@ -1775,7 +1795,6 @@ func executeRequest( action.ExpectedError, action.Asserter, nodeID, - anyOfByFieldKey, ) } @@ -1825,9 +1844,7 @@ func executeSubscriptionRequest( r, action.ExpectedError, nil, - // anyof is not yet supported by subscription requests 0, - map[docFieldKey][]any{}, ) assertExpectedErrorRaised(s.t, s.testCase.Description, action.ExpectedError, expectedErrorRaised) @@ -1884,12 +1901,6 @@ func AssertErrors( return false } -// docFieldKey is an internal key type that wraps docIndex and fieldName -type docFieldKey struct { - docIndex int - fieldName string -} - func assertRequestResults( s *state, result *client.GQLResult, @@ -1897,7 +1908,6 @@ func assertRequestResults( expectedError string, asserter ResultAsserter, nodeID int, - anyOfByField map[docFieldKey][]any, ) bool { // we skip assertion benchmark because you don't specify expected result for benchmark. 
if AssertErrors(s.t, s.testCase.Description, result.Errors, expectedError) || s.isBench { @@ -1926,31 +1936,37 @@ func assertRequestResults( keys[key] = struct{}{} } + stack := &assertStack{} for key := range keys { + stack.pushMap(key) expect, ok := expectedResults[key] require.True(s.t, ok, "expected key not found: %s", key) actual, ok := resultantData[key] require.True(s.t, ok, "result key not found: %s", key) - expectDocs, ok := expect.([]map[string]any) - if ok { + switch exp := expect.(type) { + case []map[string]any: actualDocs := ConvertToArrayOfMaps(s.t, actual) assertRequestResultDocs( s, nodeID, - expectDocs, + exp, actualDocs, - anyOfByField) - } else { + stack, + ) + case AnyOf: + assertResultsAnyOf(s.t, s.clientType, exp, actual) + default: assertResultsEqual( s.t, s.clientType, expect, actual, - fmt.Sprintf("node: %v, key: %v", nodeID, key), + fmt.Sprintf("node: %v, path: %s", nodeID, stack), ) } + stack.pop() } return false @@ -1961,13 +1977,14 @@ func assertRequestResultDocs( nodeID int, expectedResults []map[string]any, actualResults []map[string]any, - anyOfByField map[docFieldKey][]any, + stack *assertStack, ) bool { // compare results require.Equal(s.t, len(expectedResults), len(actualResults), s.testCase.Description+" \n(number of results don't match)") for actualDocIndex, actualDoc := range actualResults { + stack.pushArray(actualDocIndex) expectedDoc := expectedResults[actualDocIndex] require.Equal( @@ -1982,14 +1999,10 @@ func assertRequestResultDocs( ) for field, actualValue := range actualDoc { + stack.pushMap(field) switch expectedValue := expectedDoc[field].(type) { case AnyOf: assertResultsAnyOf(s.t, s.clientType, expectedValue, actualValue) - - dfk := docFieldKey{actualDocIndex, field} - valueSet := anyOfByField[dfk] - valueSet = append(valueSet, actualValue) - anyOfByField[dfk] = valueSet case DocIndex: expectedDocID := s.docIDs[expectedValue.CollectionIndex][expectedValue.Index].String() assertResultsEqual( @@ -1997,7 +2010,7 @@ 
func assertRequestResultDocs( s.clientType, expectedDocID, actualValue, - fmt.Sprintf("node: %v, doc: %v", nodeID, actualDocIndex), + fmt.Sprintf("node: %v, path: %s", nodeID, stack), ) case []map[string]any: actualValueMap := ConvertToArrayOfMaps(s.t, actualValue) @@ -2007,7 +2020,7 @@ func assertRequestResultDocs( nodeID, expectedValue, actualValueMap, - anyOfByField, + stack, ) default: @@ -2016,10 +2029,12 @@ func assertRequestResultDocs( s.clientType, expectedValue, actualValue, - fmt.Sprintf("node: %v, doc: %v", nodeID, actualDocIndex), + fmt.Sprintf("node: %v, path: %s", nodeID, stack), ) } + stack.pop() } + stack.pop() } return false diff --git a/tools/configs/mockery.yaml b/tools/configs/mockery.yaml index 451ae55771..504dbd1be1 100644 --- a/tools/configs/mockery.yaml +++ b/tools/configs/mockery.yaml @@ -32,6 +32,7 @@ packages: # Packages and their interfaces to generate mocks for. DSReaderWriter: RootStore: Txn: + Blockstore: github.com/sourcenetwork/defradb/client: config: