Skip to content

Commit

Permalink
preserve type information for Valuer parameters
Browse files Browse the repository at this point in the history
  • Loading branch information
shueybubbles committed Mar 1, 2024
1 parent fe7c3d4 commit be0e255
Show file tree
Hide file tree
Showing 4 changed files with 68 additions and 2 deletions.
11 changes: 9 additions & 2 deletions alwaysencrypted_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"crypto/rand"
"database/sql"
"database/sql/driver"
"fmt"
"math/big"
"strings"
Expand Down Expand Up @@ -65,8 +66,8 @@ func TestAlwaysEncryptedE2E(t *testing.T) {
{"datetimeoffset(7)", "DATETIMEOFFSET", ColumnEncryptionRandomized, dt},
{"datetime2(7)", "DATETIME2", ColumnEncryptionDeterministic, civil.DateTimeOf(dt)},
{"nvarchar(max)", "NVARCHAR", ColumnEncryptionRandomized, NVarCharMax("nvarcharmaxval")},
// TODO: The driver throws away type information about Valuer implementations and sends nil as nvarchar(1). Fix that.
// {"int", "INT", ColumnEncryptionDeterministic, sql.NullInt32{Valid: false}},
{"int", "INT", ColumnEncryptionDeterministic, sql.NullInt32{Valid: false}},
{"int", "INT", ColumnEncryptionDeterministic, sql.NullInt64{Int64: 128, Valid: true}},
}
for _, test := range providerTests {
// turn off key caching
Expand Down Expand Up @@ -237,6 +238,12 @@ func comparisonValueFromObject(object interface{}) string {
return "1"
}
return "0"
case driver.Valuer:
val, _ := v.Value()
if val == nil {
return "<nil>"
}
return comparisonValueFromObject(val)
default:
return fmt.Sprintf("%v", v)
}
Expand Down
15 changes: 15 additions & 0 deletions mssql.go
Original file line number Diff line number Diff line change
Expand Up @@ -982,6 +982,17 @@ func (s *Stmt) makeParam(val driver.Value) (res param, err error) {
res.ti.Size = 0
return
}
// If the value has a non-nil value, call MakeParam on its Value
if valuer, ok := val.(driver.Valuer); ok {
val, e := driver.DefaultParameterConverter.ConvertValue(valuer)
if e != nil {
err = e
return
}
if val != nil {
return s.makeParam(val)
}
}
switch val := val.(type) {
case int:
res.ti.TypeId = typeIntN
Expand Down Expand Up @@ -1021,6 +1032,10 @@ func (s *Stmt) makeParam(val driver.Value) (res param, err error) {
res.ti.TypeId = typeIntN
res.ti.Size = 8
res.buffer = []byte{}
case sql.NullInt32:
res.ti.TypeId = typeIntN
res.ti.Size = 4
res.buffer = []byte{}
case byte:
res.ti.TypeId = typeIntN
res.buffer = []byte{val}
Expand Down
31 changes: 31 additions & 0 deletions queries_go19_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"fmt"
"reflect"
"regexp"
"strings"
"testing"
"time"

Expand All @@ -31,6 +32,36 @@ func TestOutputParam(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()

t.Run("varchar(max) to sql.NullString", func(t *testing.T) {
sqltextcreate := `CREATE PROCEDURE [GetTask]
@strparam varchar(max) = NULL OUTPUT
AS
SELECT @strparam = REPLICATE('a', 8000)
RETURN 0`
sqltextdrop := `drop procedure GetTask`
sqltextrun := `GetTask`
_, _ = db.ExecContext(ctx, sqltextdrop)
_, err = db.ExecContext(ctx, sqltextcreate)
if err != nil {
t.Fatal(err)
}
defer db.ExecContext(ctx, sqltextdrop)
if err != nil {
t.Error(err)
}

nullstr := sql.NullString{}
_, err := db.ExecContext(ctx, sqltextrun,
sql.Named("strparam", sql.Out{Dest: &nullstr}),
)
if err != nil {
t.Error(err)
}
defer db.ExecContext(ctx, sqltextdrop)
if nullstr.String != strings.Repeat("a", 8000) {
t.Error("Got incorrect NullString of length:", len(nullstr.String))
}
})
t.Run("sp with rows", func(t *testing.T) {
sqltextcreate := `
CREATE PROCEDURE spwithrows
Expand Down
13 changes: 13 additions & 0 deletions queries_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,19 @@ func TestSelect(t *testing.T) {
}
})
})
t.Run("scan into sql.NullString", func(t *testing.T) {
row := conn.QueryRow("SELECT REPLICATE('a', 8000)")
var out sql.NullString
err := row.Scan(&out)
if err != nil {
t.Error("Scan to NullString failed", err.Error())
return
}

if out.String != strings.Repeat("a", 8000) {
t.Error("got back a string with count:", len(out.String))
}
})
}

func TestSelectDateTimeOffset(t *testing.T) {
Expand Down

0 comments on commit be0e255

Please sign in to comment.