Skip to content

Commit

Permalink
Fix: handle sql.NullTime parameters (#195)
Browse files Browse the repository at this point in the history
* handle sql.NullTime parameters

* Match SQL sizes for sql.Nullxxx integer types

* handle custom nullable Valuer implementations
  • Loading branch information
shueybubbles authored May 30, 2024
1 parent 3ed002a commit a1c982b
Show file tree
Hide file tree
Showing 2 changed files with 61 additions and 5 deletions.
30 changes: 25 additions & 5 deletions alwaysencrypted_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,13 @@ type aeColumnInfo struct {
sampleValue interface{}
}

type customValuer struct {
}

func (n customValuer) Value() (driver.Value, error) {
return nil, nil
}

func TestAlwaysEncryptedE2E(t *testing.T) {
params := testConnParams(t)
if !params.ColumnEncryption {
Expand All @@ -53,7 +60,11 @@ func TestAlwaysEncryptedE2E(t *testing.T) {
{"int", "INT", ColumnEncryptionDeterministic, int32(1)},
{"nchar(10) COLLATE Latin1_General_BIN2", "NCHAR", ColumnEncryptionDeterministic, NChar("ncharval")},
{"tinyint", "TINYINT", ColumnEncryptionRandomized, byte(2)},
{"tinyint", "TINYINT", ColumnEncryptionDeterministic, sql.NullByte{Valid: false}},
{"tinyint", "TINYINT", ColumnEncryptionDeterministic, sql.NullByte{Valid: true, Byte: 1}},
{"smallint", "SMALLINT", ColumnEncryptionDeterministic, int16(-3)},
{"smallint", "SMALLINT", ColumnEncryptionRandomized, sql.NullInt16{Valid: false}},
{"smallint", "SMALLINT", ColumnEncryptionDeterministic, sql.NullInt16{Valid: true, Int16: 32000}},
{"bigint", "BIGINT", ColumnEncryptionRandomized, int64(4)},
// We can't use fractional float/real values due to rounding errors in the round trip
{"real", "REAL", ColumnEncryptionDeterministic, float32(5)},
Expand All @@ -67,9 +78,13 @@ func TestAlwaysEncryptedE2E(t *testing.T) {
{"datetime2(7)", "DATETIME2", ColumnEncryptionDeterministic, civil.DateTimeOf(dt)},
{"nvarchar(max)", "NVARCHAR", ColumnEncryptionRandomized, NVarCharMax("nvarcharmaxval")},
{"int", "INT", ColumnEncryptionDeterministic, sql.NullInt32{Valid: false}},
{"int", "INT", ColumnEncryptionDeterministic, sql.NullInt32{Valid: true, Int32: -75000}},
{"bigint", "BIGINT", ColumnEncryptionDeterministic, sql.NullInt64{Int64: 128, Valid: true}},
{"bigint", "BIGINT", ColumnEncryptionRandomized, sql.NullInt64{Valid: false}},
{"uniqueidentifier", "UNIQUEIDENTIFIER", ColumnEncryptionRandomized, UniqueIdentifier{0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF, 0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF}},
{"uniqueidentifier", "UNIQUEIDENTIFIER", ColumnEncryptionRandomized, NullUniqueIdentifier{Valid: false}},
{"datetimeoffset(7)", "DATETIMEOFFSET", ColumnEncryptionDeterministic, sql.NullTime{Valid: false}},
{"datetimeoffset(7)", "DATETIMEOFFSET", ColumnEncryptionDeterministic, sql.NullTime{Valid: true, Time: time.Now()}},
}
for _, test := range providerTests {
// turn off key caching
Expand Down Expand Up @@ -108,7 +123,7 @@ func TestAlwaysEncryptedE2E(t *testing.T) {
_, _ = query.WriteString(fmt.Sprintf("CREATE TABLE [%s] (", tableName))
_, _ = insert.WriteString(fmt.Sprintf("INSERT INTO [%s] VALUES (", tableName))
_, _ = sel.WriteString("select top(1) ")
insertArgs := make([]interface{}, len(encryptableColumns)+1)
insertArgs := make([]interface{}, len(encryptableColumns)+2)
for i, ec := range encryptableColumns {
encType := "RANDOMIZED"
null := ""
Expand All @@ -128,11 +143,13 @@ func TestAlwaysEncryptedE2E(t *testing.T) {
insert.WriteString(fmt.Sprintf("@p%d,", i+1))
sel.WriteString(fmt.Sprintf("col%d,", i))
}
_, _ = query.WriteString("unencryptedcolumn nvarchar(100)")
_, _ = query.WriteString("unencryptedcolumn nvarchar(100),")
_, _ = query.WriteString("nullableCustomValuer int NULL")
_, _ = query.WriteString(")")
insertArgs[len(encryptableColumns)] = "unencryptedvalue"
insert.WriteString(fmt.Sprintf("@p%d)", len(encryptableColumns)+1))
sel.WriteString(fmt.Sprintf("unencryptedcolumn from [%s]", tableName))
insertArgs[len(encryptableColumns)+1] = customValuer{}
insert.WriteString(fmt.Sprintf("@p%d,@p%d)", len(encryptableColumns)+1, len(encryptableColumns)+2))
sel.WriteString(fmt.Sprintf("unencryptedcolumn, nullableCustomValuer from [%s]", tableName))
_, err = conn.Exec(query.String())
assert.NoError(t, err, "Failed to create encrypted table")
defer func() { _, _ = conn.Exec("DROP TABLE IF EXISTS " + tableName) }()
Expand All @@ -152,13 +169,15 @@ func TestAlwaysEncryptedE2E(t *testing.T) {
}

var unencryptedColumnValue string
scanValues := make([]interface{}, len(encryptableColumns)+1)
var nullint sql.NullInt32
scanValues := make([]interface{}, len(encryptableColumns)+2)
for v := range scanValues {
if v < len(encryptableColumns) {
scanValues[v] = new(interface{})
}
}
scanValues[len(encryptableColumns)] = &unencryptedColumnValue
scanValues[len(encryptableColumns)+1] = &nullint
err = rows.Scan(scanValues...)
defer rows.Close()
if err != nil {
Expand All @@ -182,6 +201,7 @@ func TestAlwaysEncryptedE2E(t *testing.T) {
assert.Equalf(t, expectedStrVal, strVal, "Incorrect value for col%d. ", i)
}
assert.Equalf(t, "unencryptedvalue", unencryptedColumnValue, "Got wrong value for unencrypted column")
assert.False(t, nullint.Valid, "custom valuer should have null value")
_ = rows.Next()
err = rows.Err()
assert.NoError(t, err, "rows.Err() has non-nil values")
Expand Down
36 changes: 36 additions & 0 deletions mssql.go
Original file line number Diff line number Diff line change
Expand Up @@ -983,6 +983,19 @@ func (s *Stmt) makeParam(val driver.Value) (res param, err error) {
return
}
switch valuer := val.(type) {
// sql.Nullxxx integer types return an int64. We want the original type, to match the SQL type size.
case sql.NullByte:
if valuer.Valid {
return s.makeParam(valuer.Byte)
}
case sql.NullInt16:
if valuer.Valid {
return s.makeParam(valuer.Int16)
}
case sql.NullInt32:
if valuer.Valid {
return s.makeParam(valuer.Int32)
}
case UniqueIdentifier:
case NullUniqueIdentifier:
default:
Expand Down Expand Up @@ -1052,9 +1065,20 @@ func (s *Stmt) makeParam(val driver.Value) (res param, err error) {
res.ti.Size = 8
res.buffer = []byte{}
case sql.NullInt32:
// only null values should be getting here
res.ti.TypeId = typeIntN
res.ti.Size = 4
res.buffer = []byte{}
case sql.NullInt16:
// only null values should be getting here
res.buffer = []byte{}
res.ti.Size = 2
res.ti.TypeId = typeIntN
case sql.NullByte:
// only null values should be getting here
res.buffer = []byte{}
res.ti.Size = 1
res.ti.TypeId = typeIntN
case byte:
res.ti.TypeId = typeIntN
res.buffer = []byte{val}
Expand Down Expand Up @@ -1110,6 +1134,18 @@ func (s *Stmt) makeParam(val driver.Value) (res param, err error) {
res.buffer = encodeDateTime(val)
res.ti.Size = len(res.buffer)
}
case sql.NullTime: // only null values reach here
res.buffer = []byte{}
res.ti.Size = 8
if s.c.sess.loginAck.TDSVersion >= verTDS73 {
res.ti.TypeId = typeDateTimeOffsetN
res.ti.Scale = 7
} else {
res.ti.TypeId = typeDateTimeN
}
case driver.Valuer:
// We have a custom Valuer implementation with a nil value
return s.makeParam(nil)
default:
return s.makeParamExtra(val)
}
Expand Down

0 comments on commit a1c982b

Please sign in to comment.