diff --git a/alwaysencrypted_test.go b/alwaysencrypted_test.go index 18c77997..9dd08b59 100644 --- a/alwaysencrypted_test.go +++ b/alwaysencrypted_test.go @@ -67,7 +67,9 @@ func TestAlwaysEncryptedE2E(t *testing.T) { {"datetime2(7)", "DATETIME2", ColumnEncryptionDeterministic, civil.DateTimeOf(dt)}, {"nvarchar(max)", "NVARCHAR", ColumnEncryptionRandomized, NVarCharMax("nvarcharmaxval")}, {"int", "INT", ColumnEncryptionDeterministic, sql.NullInt32{Valid: false}}, - {"int", "INT", ColumnEncryptionDeterministic, sql.NullInt64{Int64: 128, Valid: true}}, + {"bigint", "BIGINT", ColumnEncryptionDeterministic, sql.NullInt64{Int64: 128, Valid: true}}, + {"uniqueidentifier", "UNIQUEIDENTIFIER", ColumnEncryptionRandomized, UniqueIdentifier{0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF, 0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF}}, + {"uniqueidentifier", "UNIQUEIDENTIFIER", ColumnEncryptionRandomized, NullUniqueIdentifier{Valid: false}}, } for _, test := range providerTests { // turn off key caching @@ -231,8 +233,6 @@ func comparisonValueFromObject(object interface{}) string { case time.Time: return civil.DateTimeOf(v).String() //return v.Format(time.RFC3339) - case fmt.Stringer: - return v.String() case bool: if v == true { return "1" @@ -244,6 +244,8 @@ func comparisonValueFromObject(object interface{}) string { return "" } return comparisonValueFromObject(val) + case fmt.Stringer: + return v.String() default: return fmt.Sprintf("%v", v) } diff --git a/mssql.go b/mssql.go index 9238ff60..f86d5361 100644 --- a/mssql.go +++ b/mssql.go @@ -982,8 +982,13 @@ func (s *Stmt) makeParam(val driver.Value) (res param, err error) { res.ti.Size = 0 return } - // If the value has a non-nil value, call MakeParam on its Value - if valuer, ok := val.(driver.Valuer); ok { + switch valuer := val.(type) { + case UniqueIdentifier: + case NullUniqueIdentifier: + default: + break + case driver.Valuer: + // If the value has a non-nil value, call MakeParam on its Value val, e := driver.DefaultParameterConverter.ConvertValue(valuer) if e != nil { err = e @@ -994,6 +999,20 @@ func (s *Stmt) makeParam(val driver.Value) (res param, err error) { } } switch val := val.(type) { + case UniqueIdentifier: + res.ti.TypeId = typeGuid + res.ti.Size = 16 + guid, _ := val.Value() + res.buffer = guid.([]byte) + case NullUniqueIdentifier: + res.ti.TypeId = typeGuid + res.ti.Size = 16 + if val.Valid { + guid, _ := val.Value() + res.buffer = guid.([]byte) + } else { + res.buffer = []byte{} + } case int: res.ti.TypeId = typeIntN // Rather than guess if the caller intends to pass a 32bit int from a 64bit app based on the diff --git a/mssql_go19.go b/mssql_go19.go index b0285eef..6435f67e 100644 --- a/mssql_go19.go +++ b/mssql_go19.go @@ -75,6 +75,8 @@ func convertInputParameter(val interface{}) (interface{}, error) { // return nil case float32: return val, nil + case driver.Valuer: + return val, nil default: return driver.DefaultParameterConverter.ConvertValue(v) }