forked from microsoft/go-mssqldb
-
Notifications
You must be signed in to change notification settings - Fork 0
/
alwaysencrypted_test.go
331 lines (307 loc) · 13.3 KB
/
alwaysencrypted_test.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
package mssql
import (
"context"
"crypto/rand"
"database/sql"
"database/sql/driver"
"fmt"
"math/big"
"strings"
"testing"
"time"
"github.com/golang-sql/civil"
"github.com/parMaster/go-mssqldb/aecmk"
"github.com/stretchr/testify/assert"
)
// providerTest is a key store that TestAlwaysEncryptedE2E can run against.
// Implementations are added to providerTests through addProviderTest.
type providerTest interface {
	// Name returns the key store provider name. It is used as the subtest name,
	// as KEY_STORE_PROVIDER_NAME of the column master key, and as the name the
	// provider is registered under on the connector.
	Name() string
	// ProvisionMasterKey creates a master key in the key store and returns the
	// key path that identifies it.
	ProvisionMasterKey(t *testing.T) string
	// DeleteMasterKey removes the master key created by ProvisionMasterKey.
	DeleteMasterKey(t *testing.T)
	// GetProvider returns the ColumnEncryptionKeyProvider backed by this key store.
	GetProvider(t *testing.T) aecmk.ColumnEncryptionKeyProvider
}
// providerTests lists the key stores that TestAlwaysEncryptedE2E runs against.
// The capacity of 2 is only the initial allocation; append grows the slice if
// more entries are registered.
var providerTests = make([]providerTest, 0, 2)

// addProviderTest appends p to providerTests so that TestAlwaysEncryptedE2E
// exercises it.
func addProviderTest(p providerTest) {
	providerTests = append(providerTests, p)
}
// aeColumnInfo describes one encryptable column exercised by
// TestAlwaysEncryptedE2E, together with a sample value that is inserted into
// the column and validated after it is read back.
type aeColumnInfo struct {
	// queryPhrase is the column's type clause in CREATE TABLE, for example
	// "nchar(10) COLLATE Latin1_General_BIN2"; sqlDataType is the
	// DatabaseTypeName the driver is expected to report for that column.
	queryPhrase, sqlDataType string
	// encType chooses DETERMINISTIC or RANDOMIZED encryption for the column.
	encType ColumnEncryptionType
	// sampleValue is inserted into the column and compared with the value
	// scanned back out.
	sampleValue interface{}
}
type customValuer struct {
}
func (n customValuer) Value() (driver.Value, error) {
return nil, nil
}
// TestAlwaysEncryptedE2E round-trips a row through an Always Encrypted table
// once for every provider registered in providerTests.
//
// For each provider it provisions a column master key (CMK) and creates a
// column encryption key (CEK) wrapped by it. It then creates a table with one
// encrypted column per encryptableColumns entry plus two unencrypted columns,
// inserts the sample values, reads the row back and compares every column.
// Finally it runs testProviderErrorHandling against the same SELECT and INSERT
// statements. The test is skipped unless the connection parameters enable
// column encryption.
func TestAlwaysEncryptedE2E(t *testing.T) {
	params := testConnParams(t)
	if !params.ColumnEncryption {
		t.Skip("Test is not running with column encryption enabled")
	}
	// civil.DateTime has 9 digit precision while SQL only has 7, so we can't use
	// time.Now; every time-valued sample below uses this fixed 7-digit value.
	dt, err := time.Parse("2006-01-02T15:04:05.9999999", "2023-08-21T18:33:36.5315137")
	assert.NoError(t, err, "time.Parse")
	encryptableColumns := []aeColumnInfo{
		{"int", "INT", ColumnEncryptionDeterministic, int32(1)},
		{"nchar(10) COLLATE Latin1_General_BIN2", "NCHAR", ColumnEncryptionDeterministic, NChar("ncharval")},
		{"tinyint", "TINYINT", ColumnEncryptionRandomized, byte(2)},
		{"tinyint", "TINYINT", ColumnEncryptionDeterministic, sql.NullByte{Valid: false}},
		{"tinyint", "TINYINT", ColumnEncryptionDeterministic, sql.NullByte{Valid: true, Byte: 1}},
		{"smallint", "SMALLINT", ColumnEncryptionDeterministic, int16(-3)},
		{"smallint", "SMALLINT", ColumnEncryptionRandomized, sql.NullInt16{Valid: false}},
		{"smallint", "SMALLINT", ColumnEncryptionDeterministic, sql.NullInt16{Valid: true, Int16: 32000}},
		{"bigint", "BIGINT", ColumnEncryptionRandomized, int64(4)},
		// We can't use fractional float/real values due to rounding errors in the round trip
		{"real", "REAL", ColumnEncryptionDeterministic, float32(5)},
		{"float", "FLOAT", ColumnEncryptionRandomized, float64(6)},
		{"varbinary(10)", "VARBINARY", ColumnEncryptionDeterministic, []byte{1, 2, 3, 4}},
		// TODO: Varchar support requires proper selection of a collation and conversion
		// {"varchar(10) COLLATE Latin1_General_BIN2", "VARCHAR", ColumnEncryptionRandomized, VarChar("varcharval")},
		{"nvarchar(30)", "NVARCHAR", ColumnEncryptionRandomized, "nvarcharval"},
		{"bit", "BIT", ColumnEncryptionDeterministic, true},
		{"datetimeoffset(7)", "DATETIMEOFFSET", ColumnEncryptionRandomized, dt},
		{"datetime2(7)", "DATETIME2", ColumnEncryptionDeterministic, civil.DateTimeOf(dt)},
		{"nvarchar(max)", "NVARCHAR", ColumnEncryptionRandomized, NVarCharMax("nvarcharmaxval")},
		{"int", "INT", ColumnEncryptionDeterministic, sql.NullInt32{Valid: false}},
		{"int", "INT", ColumnEncryptionDeterministic, sql.NullInt32{Valid: true, Int32: -75000}},
		{"bigint", "BIGINT", ColumnEncryptionDeterministic, sql.NullInt64{Int64: 128, Valid: true}},
		{"bigint", "BIGINT", ColumnEncryptionRandomized, sql.NullInt64{Valid: false}},
		{"uniqueidentifier", "UNIQUEIDENTIFIER", ColumnEncryptionRandomized, UniqueIdentifier{0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF, 0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF}},
		{"uniqueidentifier", "UNIQUEIDENTIFIER", ColumnEncryptionRandomized, NullUniqueIdentifier{Valid: false}},
		{"datetimeoffset(7)", "DATETIMEOFFSET", ColumnEncryptionDeterministic, sql.NullTime{Valid: false}},
		// Uses dt rather than time.Now(): a full-nanosecond timestamp cannot
		// survive the 7-digit round trip (see the precision note above).
		{"datetimeoffset(7)", "DATETIMEOFFSET", ColumnEncryptionDeterministic, sql.NullTime{Valid: true, Time: dt}},
	}

	// Turn off key caching for the duration of this test, and restore the
	// package-level setting afterwards so other tests are not affected.
	savedKeyLifetime := aecmk.ColumnEncryptionKeyLifetime
	aecmk.ColumnEncryptionKeyLifetime = 0
	defer func() { aecmk.ColumnEncryptionKeyLifetime = savedKeyLifetime }()

	for _, test := range providerTests {
		t.Run(test.Name(), func(t *testing.T) {
			conn, _ := open(t)
			defer conn.Close()

			certPath := test.ProvisionMasterKey(t)
			defer test.DeleteMasterKey(t)
			s := fmt.Sprintf(createColumnMasterKey, certPath, test.Name(), certPath)
			if _, err := conn.Exec(s); err != nil {
				t.Fatalf("Unable to create CMK: %s", err.Error())
			}
			defer func() {
				_, err := conn.Exec(fmt.Sprintf(dropColumnMasterKey, certPath))
				assert.NoError(t, err, "dropColumnMasterKey")
			}()

			// Randomize the CEK and table names.
			r, err := rand.Int(rand.Reader, big.NewInt(1000))
			if err != nil {
				t.Fatalf("Unable to generate a random name suffix: %s", err.Error())
			}
			cekName := fmt.Sprintf("mssqlCek%d", r.Int64())
			tableName := fmt.Sprintf("mssqlAe%d", r.Int64())

			keyBytes := make([]byte, 32)
			if _, err := rand.Read(keyBytes); err != nil {
				t.Fatalf("Unable to generate CEK bytes: %s", err.Error())
			}
			encryptedCek, err := test.GetProvider(t).EncryptColumnEncryptionKey(context.Background(), certPath, KeyEncryptionAlgorithm, keyBytes)
			if err != nil {
				t.Fatalf("Encrypt: %s", err.Error())
			}
			createCek := fmt.Sprintf(createColumnEncryptionKey, cekName, certPath, encryptedCek)
			if _, err := conn.Exec(createCek); err != nil {
				t.Fatalf("Unable to create CEK: %s", err.Error())
			}
			defer func() {
				_, err := conn.Exec(fmt.Sprintf(dropColumnEncryptionKey, cekName))
				assert.NoError(t, err, "dropColumnEncryptionKey")
			}()

			_, _ = conn.Exec("DROP TABLE IF EXISTS " + tableName)

			// Build CREATE TABLE, INSERT and SELECT statements that cover every
			// encryptable column followed by the two unencrypted columns.
			query := new(strings.Builder)
			insert := new(strings.Builder)
			sel := new(strings.Builder)
			fmt.Fprintf(query, "CREATE TABLE [%s] (", tableName)
			fmt.Fprintf(insert, "INSERT INTO [%s] VALUES (", tableName)
			sel.WriteString("select top(1) ")
			insertArgs := make([]interface{}, 0, len(encryptableColumns)+2)
			for i, ec := range encryptableColumns {
				encType := "RANDOMIZED"
				if ec.encType == ColumnEncryptionDeterministic {
					encType = "DETERMINISTIC"
				}
				null := ""
				if _, ok := ec.sampleValue.(sql.NullInt32); ok {
					null = "NULL"
				}
				fmt.Fprintf(query, `col%d %s ENCRYPTED WITH (ENCRYPTION_TYPE = %s,
ALGORITHM = 'AEAD_AES_256_CBC_HMAC_SHA_256',
COLUMN_ENCRYPTION_KEY = [%s]) %s,
`, i, ec.queryPhrase, encType, cekName, null)
				fmt.Fprintf(insert, "@p%d,", i+1)
				fmt.Fprintf(sel, "col%d,", i)
				insertArgs = append(insertArgs, ec.sampleValue)
			}
			query.WriteString("unencryptedcolumn nvarchar(100),")
			query.WriteString("nullableCustomValuer int NULL")
			query.WriteString(")")
			insertArgs = append(insertArgs, "unencryptedvalue", customValuer{})
			fmt.Fprintf(insert, "@p%d,@p%d)", len(encryptableColumns)+1, len(encryptableColumns)+2)
			fmt.Fprintf(sel, "unencryptedcolumn, nullableCustomValuer from [%s]", tableName)

			// Every step below depends on the previous one, so stop at the first
			// failure instead of cascading into nil dereferences.
			if _, err := conn.Exec(query.String()); err != nil {
				t.Fatalf("Failed to create encrypted table: %s", err.Error())
			}
			defer func() { _, _ = conn.Exec("DROP TABLE IF EXISTS " + tableName) }()
			if _, err := conn.Exec(insert.String(), insertArgs...); err != nil {
				t.Fatalf("Failed to insert row in encrypted table: %s", err.Error())
			}

			rows, err := conn.Query(sel.String())
			if err != nil {
				t.Fatalf("Unable to query encrypted columns: %s", err.Error())
			}
			defer rows.Close()
			if !rows.Next() {
				t.Fatalf("rows.Next returned false: %v", rows.Err())
			}
			cols, err := rows.ColumnTypes()
			if err != nil {
				t.Fatalf("rows.ColumnTypes failed: %s", err.Error())
			}
			if len(cols) != len(encryptableColumns)+2 {
				t.Fatalf("got %d columns, want %d", len(cols), len(encryptableColumns)+2)
			}
			for i, ec := range encryptableColumns {
				assert.Equalf(t, ec.sqlDataType, cols[i].DatabaseTypeName(),
					"Got wrong type name for col%d.", i)
			}

			// Encrypted columns are scanned into *interface{} so each keeps the
			// Go type the driver produced for it.
			var unencryptedColumnValue string
			var nullint sql.NullInt32
			scanValues := make([]interface{}, 0, len(encryptableColumns)+2)
			for range encryptableColumns {
				scanValues = append(scanValues, new(interface{}))
			}
			scanValues = append(scanValues, &unencryptedColumnValue, &nullint)
			if err := rows.Scan(scanValues...); err != nil {
				t.Fatalf("Scan failed: %s", err.Error())
			}
			for i, ec := range encryptableColumns {
				// NULL renders as "<nil>" on both sides: comparisonValueFromObject
				// returns "<nil>" for a driver.Valuer whose Value is nil, and a NULL
				// column scans into a nil interface{} which formats as "<nil>".
				want := comparisonValueFromObject(ec.sampleValue)
				got := comparisonValueFromObject(*scanValues[i].(*interface{}))
				assert.Equalf(t, want, got, "Incorrect value for col%d. ", i)
			}
			assert.Equalf(t, "unencryptedvalue", unencryptedColumnValue, "Got wrong value for unencrypted column")
			assert.False(t, nullint.Valid, "custom valuer should have null value")

			// Move past the single row before checking rows.Err.
			_ = rows.Next()
			assert.NoError(t, rows.Err(), "rows.Err() has non-nil values")

			testProviderErrorHandling(t, test.Name(), test.GetProvider(t), sel.String(), insert.String(), insertArgs)
		})
	}
}
// testProviderErrorHandling checks that errors returned by a column encryption
// key provider reach the caller. It registers, under name on a fresh connector,
// a testKeyProvider that delegates to provider unless overridden, and then:
//   - makes CEK decryption fail with context.DeadlineExceeded and expects
//     reading the result of sel to fail with that error;
//   - makes CEK decryption fail with aecmk.KeyPathNotAllowed and expects
//     executing insert with insertArgs to fail with that error.
func testProviderErrorHandling(t *testing.T, name string, provider aecmk.ColumnEncryptionKeyProvider, sel string, insert string, insertArgs []interface{}) {
	t.Helper()
	testProvider := &testKeyProvider{fallback: provider}
	connector, _ := getTestConnector(t)
	connector.RegisterCekProvider(name, testProvider)
	conn := sql.OpenDB(connector)
	defer conn.Close()

	testProvider.decrypt = func(ctx context.Context, masterKeyPath string, encryptionAlgorithm string, cek []byte) ([]byte, error) {
		return nil, context.DeadlineExceeded
	}
	ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(1*time.Hour))
	defer cancel()
	rows, err := conn.QueryContext(ctx, sel)
	if assert.NoError(t, err, "Query should return no error") {
		// rows is nil when QueryContext returns an error, so Close is only
		// deferred once the query has succeeded.
		defer rows.Close()
		if rows.Next() {
			assert.Fail(t, "rows.Next should have failed")
		}
		assert.ErrorIs(t, rows.Err(), context.DeadlineExceeded)
	}

	var notAllowed error
	testProvider.decrypt = func(ctx context.Context, masterKeyPath string, encryptionAlgorithm string, cek []byte) ([]byte, error) {
		notAllowed = aecmk.KeyPathNotAllowed(masterKeyPath, aecmk.Decryption)
		return nil, notAllowed
	}
	_, err = conn.Exec(insert, insertArgs...)
	// errors.Is(nil, nil) reports true, so ErrorIs alone would pass if the
	// insert succeeded without ever consulting the provider.
	if assert.Error(t, err, "Insert should fail with key path not allowed") {
		assert.ErrorIs(t, err, notAllowed, "Insert should fail with key path not allowed")
	}
}
// comparisonValueFromObject renders object as a string so that a value written
// to a column can be compared with the value read back from it:
//   - []byte is converted directly to a string and string is returned as is;
//   - time.Time is converted with civil.DateTimeOf and formatted with its String method;
//   - bool becomes "1" for true and "0" for false;
//   - a driver.Valuer is unwrapped through Value() and rendered recursively; a
//     nil value renders as "<nil>" (any error from Value is ignored);
//   - a fmt.Stringer uses String(), and anything else (including nil) uses %v.
func comparisonValueFromObject(object interface{}) string {
	switch v := object.(type) {
	case []byte:
		return string(v)
	case string:
		return v
	case time.Time:
		return civil.DateTimeOf(v).String()
	case bool:
		if v {
			return "1"
		}
		return "0"
	case driver.Valuer:
		val, _ := v.Value()
		if val == nil {
			return "<nil>"
		}
		return comparisonValueFromObject(val)
	case fmt.Stringer:
		return v.String()
	default:
		return fmt.Sprintf("%v", v)
	}
}
// SQL templates used by the Always Encrypted tests; each is filled in with fmt.Sprintf.
const (
// createColumnMasterKey takes the CMK name, KEY_STORE_PROVIDER_NAME and KEY_PATH.
// TestAlwaysEncryptedE2E passes the key path as both the CMK name and KEY_PATH.
createColumnMasterKey = `CREATE COLUMN MASTER KEY [%s] WITH (KEY_STORE_PROVIDER_NAME= '%s', KEY_PATH='%s')`
// dropColumnMasterKey takes the CMK name.
dropColumnMasterKey = `DROP COLUMN MASTER KEY [%s]`
// createColumnEncryptionKey takes the CEK name, the CMK name and the encrypted
// CEK bytes, which %x renders as hex after the literal 0x prefix.
createColumnEncryptionKey = `CREATE COLUMN ENCRYPTION KEY [%s] WITH VALUES (COLUMN_MASTER_KEY = [%s], ALGORITHM = 'RSA_OAEP', ENCRYPTED_VALUE = 0x%x )`
// dropColumnEncryptionKey takes the CEK name.
dropColumnEncryptionKey = `DROP COLUMN ENCRYPTION KEY [%s]`
// createEncryptedTable takes the table name, then the CEK name for col1 and for col2.
// NOTE(review): not referenced by the code visible in this file — confirm other
// tests use it before removing.
createEncryptedTable = `CREATE TABLE %s
(col1 int
ENCRYPTED WITH (ENCRYPTION_TYPE = DETERMINISTIC,
ALGORITHM = 'AEAD_AES_256_CBC_HMAC_SHA_256',
COLUMN_ENCRYPTION_KEY = [%s]),
col2 nchar(10) COLLATE Latin1_General_BIN2
ENCRYPTED WITH (ENCRYPTION_TYPE = DETERMINISTIC,
ALGORITHM = 'AEAD_AES_256_CBC_HMAC_SHA_256',
COLUMN_ENCRYPTION_KEY = [%s])
)`
)
// testKeyProvider is a ColumnEncryptionKeyProvider whose behavior tests can
// override. Encryption, decryption and KeyLifetime use the corresponding field
// when it is set and otherwise delegate to fallback; the metadata signing and
// verification methods are stubs.
type testKeyProvider struct {
	// encrypt and decrypt, when non-nil, replace the fallback's
	// EncryptColumnEncryptionKey and DecryptColumnEncryptionKey respectively.
	encrypt, decrypt func(ctx context.Context, masterKeyPath string, encryptionAlgorithm string, cek []byte) ([]byte, error)
	// lifetime, when non-nil, is returned by KeyLifetime instead of the fallback's value.
	lifetime *time.Duration
	// fallback serves encryption, decryption and KeyLifetime when they are not overridden.
	fallback aecmk.ColumnEncryptionKeyProvider
}

// DecryptColumnEncryptionKey returns the result of the decrypt override when
// one is set, otherwise the fallback provider's result.
func (p *testKeyProvider) DecryptColumnEncryptionKey(ctx context.Context, masterKeyPath string, encryptionAlgorithm string, encryptedCek []byte) (decryptedKey []byte, err error) {
	if override := p.decrypt; override != nil {
		return override(ctx, masterKeyPath, encryptionAlgorithm, encryptedCek)
	}
	return p.fallback.DecryptColumnEncryptionKey(ctx, masterKeyPath, encryptionAlgorithm, encryptedCek)
}

// EncryptColumnEncryptionKey returns the result of the encrypt override when
// one is set, otherwise the fallback provider's result.
func (p *testKeyProvider) EncryptColumnEncryptionKey(ctx context.Context, masterKeyPath string, encryptionAlgorithm string, cek []byte) ([]byte, error) {
	if override := p.encrypt; override != nil {
		return override(ctx, masterKeyPath, encryptionAlgorithm, cek)
	}
	return p.fallback.EncryptColumnEncryptionKey(ctx, masterKeyPath, encryptionAlgorithm, cek)
}

// SignColumnMasterKeyMetadata is a stub: it never signs and always returns nil, nil.
func (p *testKeyProvider) SignColumnMasterKeyMetadata(ctx context.Context, masterKeyPath string, allowEnclaveComputations bool) ([]byte, error) {
	return nil, nil
}

// VerifyColumnMasterKeyMetadata is a stub: it always returns nil, nil, which
// is the "not supported" result for this method.
func (p *testKeyProvider) VerifyColumnMasterKeyMetadata(ctx context.Context, masterKeyPath string, allowEnclaveComputations bool) (*bool, error) {
	return nil, nil
}

// KeyLifetime returns the lifetime override when set, otherwise the fallback
// provider's KeyLifetime. A nil result means keys expire according to
// ColumnEncryptionKeyLifetime; a zero duration means keys are not cached.
func (p *testKeyProvider) KeyLifetime() *time.Duration {
	if override := p.lifetime; override != nil {
		return override
	}
	return p.fallback.KeyLifetime()
}