Skip to content

Commit

Permalink
Replace panic with returning errors from key decryption providers (#155)
Browse files Browse the repository at this point in the history
* Add context parameter to key provider interface

* update error handling for AE key providers
  • Loading branch information
shueybubbles authored Oct 18, 2023
1 parent e51fa15 commit 670fd58
Show file tree
Hide file tree
Showing 16 changed files with 396 additions and 172 deletions.
13 changes: 13 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,4 +1,17 @@
# Changelog
## 1.7.0

### Changed

* Changed always encrypted key provider error handling not to panic on failure

### Features

* Support DER certificates for server authentication (#152)

### Bug fixes

* Improved speed of CharsetToUTF8 (#154)

## 1.6.0

Expand Down
165 changes: 100 additions & 65 deletions aecmk/akv/keyprovider.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,101 +63,120 @@ func init() {

// DecryptColumnEncryptionKey decrypts the specified encrypted value of a column encryption key.
// The encrypted value is expected to be encrypted using the column master key with the specified key path and using the specified algorithm.
func (p *Provider) DecryptColumnEncryptionKey(masterKeyPath string, encryptionAlgorithm string, encryptedCek []byte) (decryptedKey []byte) {
func (p *Provider) DecryptColumnEncryptionKey(ctx context.Context, masterKeyPath string, encryptionAlgorithm string, encryptedCek []byte) (decryptedKey []byte, err error) {
decryptedKey = nil
keyData := p.getKeyData(masterKeyPath)
if keyData == nil {
keyData, err := p.getKeyData(ctx, masterKeyPath, aecmk.Decryption)
if err != nil {
return
}
keySize := keyData.publicKey.Size()
cekv := ae.LoadCEKV(encryptedCek)
if cekv.Version != 1 {
panic(fmt.Errorf("Invalid version byte in encrypted key"))
return nil, aecmk.NewError(aecmk.Decryption, "Invalid version byte in encrypted key", nil)
}
if keySize != len(cekv.Ciphertext) {
panic(fmt.Errorf("Encrypted key has wrong ciphertext length"))
return nil, aecmk.NewError(aecmk.Decryption, "Encrypted key has wrong ciphertext length", nil)
}
if keySize != len(cekv.SignedHash) {
panic(fmt.Errorf("Encrypted key signature length mismatch"))
return nil, aecmk.NewError(aecmk.Decryption, "Encrypted key signature length mismatch", nil)
}
if !cekv.VerifySignature(keyData.publicKey) {
panic(fmt.Errorf("Invalid signature hash"))
return nil, aecmk.NewError(aecmk.Decryption, "Invalid signature hash", nil)
}

client := p.getAKVClient(keyData.endpoint)
algorithm := getAlgorithm(encryptionAlgorithm)
client, err := p.getAKVClient(aecmk.Decryption, keyData.endpoint)
if err != nil {
return
}
algorithm, err := getAlgorithm(aecmk.Decryption, encryptionAlgorithm)
if err != nil {
return
}
parameters := azkeys.KeyOperationParameters{
Algorithm: &algorithm,
Value: cekv.Ciphertext,
}
r, err := client.UnwrapKey(context.Background(), keyData.name, keyData.version, parameters, nil)
if err != nil {
panic(fmt.Errorf("Unable to decrypt key %s: %w", masterKeyPath, err))
r, e := client.UnwrapKey(ctx, keyData.name, keyData.version, parameters, nil)
if e != nil {
err = aecmk.NewError(aecmk.Decryption, fmt.Sprintf("Unable to decrypt key %s", masterKeyPath), e)
} else {
decryptedKey = r.Result
}
decryptedKey = r.Result
return
}

// EncryptColumnEncryptionKey encrypts a column encryption key using the column master key with the specified key path and using the specified algorithm.
func (p *Provider) EncryptColumnEncryptionKey(masterKeyPath string, encryptionAlgorithm string, cek []byte) []byte {
keyData := p.getKeyData(masterKeyPath)
// just validate the algorith
_ = getAlgorithm(encryptionAlgorithm)
func (p *Provider) EncryptColumnEncryptionKey(ctx context.Context, masterKeyPath string, encryptionAlgorithm string, cek []byte) (buf []byte, err error) {
keyData, err := p.getKeyData(ctx, masterKeyPath, aecmk.Encryption)
if err != nil {
return
}
_, err = getAlgorithm(aecmk.Encryption, encryptionAlgorithm)
if err != nil {
return
}
keySize := keyData.publicKey.Size()
enc := unicode.UTF16(unicode.LittleEndian, unicode.IgnoreBOM).NewEncoder()
// Start with version byte == 1
buf := []byte{byte(1)}
tmp := []byte{byte(1)}
// EncryptedColumnEncryptionKey = version + keyPathLength + ciphertextLength + keyPath + ciphertext + signature
// version
keyPathBytes, err := enc.Bytes([]byte(strings.ToLower(masterKeyPath)))
if err != nil {
panic(fmt.Errorf("Unable to serialize key path %w", err))
err = aecmk.NewError(aecmk.Encryption, "Unable to serialize key path", err)
return
}
k := uint16(len(keyPathBytes))
// keyPathLength
buf = append(buf, byte(k), byte(k>>8))
tmp = append(tmp, byte(k), byte(k>>8))

cipherText, err := rsa.EncryptOAEP(sha1.New(), rand.Reader, keyData.publicKey, cek, []byte{})
if err != nil {
panic(fmt.Errorf("Unable to encrypt data %w", err))
err = aecmk.NewError(aecmk.Encryption, "Unable to encrypt data", err)
return
}
l := uint16(len(cipherText))
// ciphertextLength
buf = append(buf, byte(l), byte(l>>8))
tmp = append(tmp, byte(l), byte(l>>8))
// keypath
buf = append(buf, keyPathBytes...)
tmp = append(tmp, keyPathBytes...)
// ciphertext
buf = append(buf, cipherText...)
hash := sha256.Sum256(buf)
client := p.getAKVClient(keyData.endpoint)
tmp = append(tmp, cipherText...)
hash := sha256.Sum256(tmp)
client, err := p.getAKVClient(aecmk.Encryption, keyData.endpoint)
if err != nil {
return
}
signAlgorithm := azkeys.SignatureAlgorithmRS256
parameters := azkeys.SignParameters{
Algorithm: &signAlgorithm,
Value: hash[:],
}
r, err := client.Sign(context.Background(), keyData.name, keyData.version, parameters, nil)
r, err := client.Sign(ctx, keyData.name, keyData.version, parameters, nil)
if err != nil {
panic(err)
err = aecmk.NewError(aecmk.Encryption, "AKV failed to sign data", err)
return
}
if len(r.Result) != keySize {
panic("Signature length doesn't match certificate key size")
err = aecmk.NewError(aecmk.Encryption, "Signature length doesn't match certificate key size", nil)
} else {
// signature
buf = append(tmp, r.Result...)
}
// signature
buf = append(buf, r.Result...)
return buf
return
}

// SignColumnMasterKeyMetadata digitally signs the column master key metadata with the column master key
// referenced by the masterKeyPath parameter. The input values used to generate the signature should be the
// specified values of the masterKeyPath and allowEnclaveComputations parameters. May return an empty slice if not supported.
func (p *Provider) SignColumnMasterKeyMetadata(masterKeyPath string, allowEnclaveComputations bool) []byte {
return nil
func (p *Provider) SignColumnMasterKeyMetadata(ctx context.Context, masterKeyPath string, allowEnclaveComputations bool) ([]byte, error) {
return nil, nil
}

// VerifyColumnMasterKeyMetadata verifies the specified signature is valid for the column master key
// with the specified key path and the specified enclave behavior. Return nil if not supported.
func (p *Provider) VerifyColumnMasterKeyMetadata(masterKeyPath string, allowEnclaveComputations bool) *bool {
return nil
func (p *Provider) VerifyColumnMasterKeyMetadata(ctx context.Context, masterKeyPath string, allowEnclaveComputations bool) (*bool, error) {
return nil, nil
}

// KeyLifetime is an optional Duration. Keys fetched by this provider will be discarded after their lifetime expires.
Expand All @@ -167,51 +186,60 @@ func (p *Provider) KeyLifetime() *time.Duration {
return nil
}

func getAlgorithm(encryptionAlgorithm string) (algorithm azkeys.EncryptionAlgorithm) {
func getAlgorithm(op aecmk.Operation, encryptionAlgorithm string) (algorithm azkeys.EncryptionAlgorithm, err error) {
// support both RSA_OAEP and RSA-OAEP
if strings.EqualFold(encryptionAlgorithm, aecmk.KeyEncryptionAlgorithm) {
encryptionAlgorithm = string(azkeys.EncryptionAlgorithmRSAOAEP)
}
if !strings.EqualFold(encryptionAlgorithm, string(azkeys.EncryptionAlgorithmRSAOAEP)) {
panic(fmt.Errorf("Unsupported encryption algorithm %s", encryptionAlgorithm))
err = aecmk.NewError(op, fmt.Sprintf("Unsupported encryption algorithm %s", encryptionAlgorithm), nil)
} else {
algorithm = azkeys.EncryptionAlgorithmRSAOAEP
}
return azkeys.EncryptionAlgorithmRSAOAEP
return
}

// masterKeyPath is a full URL. The AKV client requires it broken down into endpoint, name, and version
// The URL has format '{endpoint}/{host}/keys/{name}/[{version}/]'
func (p *Provider) getKeyData(masterKeyPath string) *keyData {
func (p *Provider) getKeyData(ctx context.Context, masterKeyPath string, op aecmk.Operation) (k *keyData, err error) {
endpoint, keypath, allowed := p.allowedPathAndEndpoint(masterKeyPath)
if !(allowed) {
return nil
err = aecmk.KeyPathNotAllowed(masterKeyPath, op)
return
}
k := &keyData{
k = &keyData{
endpoint: endpoint,
name: keypath[0],
}
if len(keypath) > 1 {
k.version = keypath[1]
}
client := p.getAKVClient(endpoint)
r, err := client.GetKey(context.Background(), k.name, k.version, nil)
client, err := p.getAKVClient(op, endpoint)
if err != nil {
return
}
r, err := client.GetKey(ctx, k.name, k.version, nil)
if err != nil {
panic(fmt.Errorf("Unable to get key from AKV %w", err))
err = aecmk.NewError(op, "Unable to get key from AKV. Name:"+masterKeyPath, err)
}
if r.Key.Kty == nil || (*r.Key.Kty != azkeys.KeyTypeRSA && *r.Key.Kty != azkeys.KeyTypeRSAHSM) {
panic(fmt.Errorf("Key type not supported for Always Encrypted"))
err = aecmk.NewError(op, "Key type not supported for Always Encrypted", nil)
}
k.publicKey = &rsa.PublicKey{
N: new(big.Int).SetBytes(r.Key.N),
E: int(new(big.Int).SetBytes(r.Key.E).Int64()),
if err == nil {
k.publicKey = &rsa.PublicKey{
N: new(big.Int).SetBytes(r.Key.N),
E: int(new(big.Int).SetBytes(r.Key.E).Int64()),
}
}
return k
return
}

func (p *Provider) allowedPathAndEndpoint(masterKeyPath string) (endpoint string, keypath []string, allowed bool) {
allowed = len(p.AllowedLocations) == 0
url, err := url.Parse(masterKeyPath)
if err != nil {
panic(fmt.Errorf("Invalid URL for master key path %s: %w", masterKeyPath, err))
allowed = false
return
}
if !allowed {

Expand All @@ -226,7 +254,8 @@ func (p *Provider) allowedPathAndEndpoint(masterKeyPath string) (endpoint string
if allowed {
pathParts := strings.Split(strings.TrimLeft(url.Path, "/"), "/")
if len(pathParts) < 2 || len(pathParts) > 3 || pathParts[0] != "keys" {
panic(fmt.Errorf("Invalid URL for master key path %s", masterKeyPath))
allowed = false
return
}
keypath = pathParts[1:]
url.Path = ""
Expand All @@ -237,28 +266,34 @@ func (p *Provider) allowedPathAndEndpoint(masterKeyPath string) (endpoint string
return
}

func (p *Provider) getAKVClient(endpoint string) (client *azkeys.Client) {
client, err := azkeys.NewClient(endpoint, p.getCredential(endpoint), nil)
func (p *Provider) getAKVClient(op aecmk.Operation, endpoint string) (client *azkeys.Client, err error) {
credential, err := p.getCredential(op, endpoint)
if err == nil {
client, err = azkeys.NewClient(endpoint, credential, nil)
}
if err != nil {
panic(fmt.Errorf("Unable to create AKV client %w", err))
err = aecmk.NewError(op, "Unable to create AKV client", err)
}
return
}

func (p *Provider) getCredential(endpoint string) azcore.TokenCredential {
func (p *Provider) getCredential(op aecmk.Operation, endpoint string) (credential azcore.TokenCredential, err error) {
if len(p.credentials) == 0 {
credential, err := azidentity.NewDefaultAzureCredential(nil)
credential, err = azidentity.NewDefaultAzureCredential(nil)
if err != nil {
panic(fmt.Errorf("Unable to create a default credential: %w", err))
err = aecmk.NewError(op, "Unable to create a default credential", err)
} else {
p.credentials[wildcard] = credential
}
p.credentials[wildcard] = credential
return credential
return
}
if credential, ok := p.credentials[endpoint]; ok {
return credential
var ok bool
if credential, ok = p.credentials[endpoint]; ok {
return
}
if credential, ok := p.credentials[wildcard]; ok {
return credential
if credential, ok = p.credentials[wildcard]; ok {
return
}
panic(fmt.Errorf("No credential available for AKV path %s", endpoint))
err = aecmk.NewError(op, fmt.Sprintf("No credential available for AKV path %s", endpoint), nil)
return
}
15 changes: 10 additions & 5 deletions aecmk/akv/keyprovider_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
package akv

import (
"context"
"crypto/rand"
"net/url"
"testing"
Expand All @@ -26,9 +27,13 @@ func TestEncryptDecryptRoundTrip(t *testing.T) {
plainKey := make([]byte, 32)
_, _ = rand.Read(plainKey)
t.Log("Plainkey:", plainKey)
encryptedKey := p.EncryptColumnEncryptionKey(keyPath, aecmk.KeyEncryptionAlgorithm, plainKey)
t.Log("Encryptedkey:", encryptedKey)
assert.NotEqualValues(t, plainKey, encryptedKey, "encryptedKey is the same as plainKey")
decryptedKey := p.DecryptColumnEncryptionKey(keyPath, aecmk.KeyEncryptionAlgorithm, encryptedKey)
assert.Equalf(t, plainKey, decryptedKey, "decryptedkey doesn't match plainKey. %v : %v", decryptedKey, plainKey)
encryptedKey, err := p.EncryptColumnEncryptionKey(context.Background(), keyPath, aecmk.KeyEncryptionAlgorithm, plainKey)
if assert.NoError(t, err, "EncryptColumnEncryptionKey") {
t.Log("Encryptedkey:", encryptedKey)
assert.NotEqualValues(t, plainKey, encryptedKey, "encryptedKey is the same as plainKey")
decryptedKey, err := p.DecryptColumnEncryptionKey(context.Background(), keyPath, aecmk.KeyEncryptionAlgorithm, encryptedKey)
if assert.NoError(t, err, "DecryptColumnEncryptionKey") {
assert.Equalf(t, plainKey, decryptedKey, "decryptedkey doesn't match plainKey. %v : %v", decryptedKey, plainKey)
}
}
}
39 changes: 39 additions & 0 deletions aecmk/error.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
package aecmk

import "fmt"

// Operation specifies the action that returned an error
type Operation int

const (
Decryption Operation = iota
Encryption
Validation
)

// Error is the type of all errors returned by key encryption providers
type Error struct {
Operation Operation
err error
msg string
}

func (e *Error) Error() string {
return e.msg
}

func (e *Error) Unwrap() error {
return e.err
}

func NewError(operation Operation, msg string, err error) error {
return &Error{
Operation: operation,
msg: msg,
err: err,
}
}

func KeyPathNotAllowed(path string, operation Operation) error {
return NewError(operation, fmt.Sprintf("Key path not allowed: %s", path), nil)
}
Loading

0 comments on commit 670fd58

Please sign in to comment.