Skip to content

Commit

Permalink
support uniqueidentifier in AE
Browse files Browse the repository at this point in the history
  • Loading branch information
shueybubbles committed Mar 1, 2024
1 parent be0e255 commit fd990b8
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 5 deletions.
8 changes: 5 additions & 3 deletions alwaysencrypted_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,9 @@ func TestAlwaysEncryptedE2E(t *testing.T) {
{"datetime2(7)", "DATETIME2", ColumnEncryptionDeterministic, civil.DateTimeOf(dt)},
{"nvarchar(max)", "NVARCHAR", ColumnEncryptionRandomized, NVarCharMax("nvarcharmaxval")},
{"int", "INT", ColumnEncryptionDeterministic, sql.NullInt32{Valid: false}},
{"int", "INT", ColumnEncryptionDeterministic, sql.NullInt64{Int64: 128, Valid: true}},
{"bigint", "BIGINT", ColumnEncryptionDeterministic, sql.NullInt64{Int64: 128, Valid: true}},
{"uniqueidentifier", "UNIQUEIDENTIFIER", ColumnEncryptionRandomized, UniqueIdentifier{0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF, 0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF}},
{"uniqueidentifier", "UNIQUEIDENTIFIER", ColumnEncryptionRandomized, NullUniqueIdentifier{Valid: false}},
}
for _, test := range providerTests {
// turn off key caching
Expand Down Expand Up @@ -231,8 +233,6 @@ func comparisonValueFromObject(object interface{}) string {
case time.Time:
return civil.DateTimeOf(v).String()
//return v.Format(time.RFC3339)
case fmt.Stringer:
return v.String()
case bool:
if v == true {
return "1"
Expand All @@ -244,6 +244,8 @@ func comparisonValueFromObject(object interface{}) string {
return "<nil>"
}
return comparisonValueFromObject(val)
case fmt.Stringer:
return v.String()
default:
return fmt.Sprintf("%v", v)
}
Expand Down
23 changes: 21 additions & 2 deletions mssql.go
Original file line number Diff line number Diff line change
Expand Up @@ -982,8 +982,13 @@ func (s *Stmt) makeParam(val driver.Value) (res param, err error) {
res.ti.Size = 0
return
}
// If the value has a non-nil value, call MakeParam on its Value
if valuer, ok := val.(driver.Valuer); ok {
switch valuer := val.(type) {
case UniqueIdentifier:
case NullUniqueIdentifier:
default:
break
case driver.Valuer:
// If the value has a non-nil value, call MakeParam on its Value
val, e := driver.DefaultParameterConverter.ConvertValue(valuer)
if e != nil {
err = e
Expand All @@ -994,6 +999,20 @@ func (s *Stmt) makeParam(val driver.Value) (res param, err error) {
}
}
switch val := val.(type) {
case UniqueIdentifier:
res.ti.TypeId = typeGuid
res.ti.Size = 16
guid, _ := val.Value()
res.buffer = guid.([]byte)
case NullUniqueIdentifier:
res.ti.TypeId = typeGuid
res.ti.Size = 16
if val.Valid {
guid, _ := val.Value()
res.buffer = guid.([]byte)
} else {
res.buffer = []byte{}
}
case int:
res.ti.TypeId = typeIntN
// Rather than guess if the caller intends to pass a 32bit int from a 64bit app based on the
Expand Down
2 changes: 2 additions & 0 deletions mssql_go19.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,8 @@ func convertInputParameter(val interface{}) (interface{}, error) {
// return nil
case float32:
return val, nil
case driver.Valuer:
return val, nil
default:
return driver.DefaultParameterConverter.ConvertValue(v)
}
Expand Down

0 comments on commit fd990b8

Please sign in to comment.