Skip to content

Commit

Permalink
Add TDS8 testcase
Browse files Browse the repository at this point in the history
  • Loading branch information
apoorvdeshmukh committed Aug 22, 2023
1 parent ad2cf66 commit 1e53bf7
Show file tree
Hide file tree
Showing 3 changed files with 69 additions and 42 deletions.
16 changes: 1 addition & 15 deletions msdsn/conn_str.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package msdsn
import (
"crypto/tls"
"crypto/x509"
"encoding/pem"
"errors"
"fmt"
"io/ioutil"
Expand Down Expand Up @@ -94,19 +93,6 @@ type Config struct {
ColumnEncryption bool
}

// GetPEMCertificate returns PEM formatted certificate
func GetPEMCertificate(certificate string) ([]byte, error) {
cerData, ok := ioutil.ReadFile(certificate)
if ok != nil {
return nil, fmt.Errorf("cannot read certificate %q: %w", certificate, ok)
}
pemData := pem.EncodeToMemory(&pem.Block{
Type: "CERTIFICATE",
Bytes: cerData,
})
return pemData, nil
}

// Build a tls.Config object from the supplied certificate.
func SetupTLS(certificate string, insecureSkipVerify bool, hostInCertificate string, minTLSVersion string) (*tls.Config, error) {
config := tls.Config{
Expand All @@ -124,7 +110,7 @@ func SetupTLS(certificate string, insecureSkipVerify bool, hostInCertificate str
if len(certificate) == 0 {
return &config, nil
}
pem, err := GetPEMCertificate(certificate)
pem, err := ioutil.ReadFile(certificate)
if err != nil {
return nil, fmt.Errorf("cannot read certificate %q: %w", certificate, err)
}
Expand Down
51 changes: 24 additions & 27 deletions tds.go
Original file line number Diff line number Diff line change
Expand Up @@ -1052,7 +1052,7 @@ func interpretPreloginResponse(p msdsn.Config, fe *featureExtFedAuth, fields map

func prepareLogin(ctx context.Context, c *Connector, p msdsn.Config, logger ContextLogger, auth integratedauth.IntegratedAuthenticator, fe *featureExtFedAuth, packetSize uint32) (l *login, err error) {
var TDSVersion uint32
if(p.Encryption == msdsn.EncryptionStrict) {
if p.Encryption == msdsn.EncryptionStrict {
TDSVersion = verTDS80
} else {
TDSVersion = verTDS74
Expand Down Expand Up @@ -1129,6 +1129,27 @@ func prepareLogin(ctx context.Context, c *Connector, p msdsn.Config, logger Cont
return l, nil
}

func getTLSConn(conn *timeoutConn, p msdsn.Config) (tlsConn *tls.Conn, err error) {
var config *tls.Config
if pc := p.TLSConfig; pc != nil {
config = pc
}
if config == nil {
config, err = msdsn.SetupTLS("", false, p.Host, "")
if err != nil {
return nil, err
}
}
//Set ALPN Sequence
config.NextProtos = []string{"tds/8.0"}
tlsConn = tls.Client(conn.c, config)
err = tlsConn.Handshake()
if err != nil {
return nil, fmt.Errorf("TLS Handshake failed: %v", err)
}
return tlsConn, nil
}

func connect(ctx context.Context, c *Connector, logger ContextLogger, p msdsn.Config) (res *tdsSession, err error) {

// if instance is specified use instance resolution service
Expand Down Expand Up @@ -1173,33 +1194,9 @@ initiate_connection:
outbuf := newTdsBuffer(packetSize, toconn)

if p.Encryption == msdsn.EncryptionStrict {
var config *tls.Config
if pc := p.TLSConfig; pc != nil {
config = pc
if config.DynamicRecordSizingDisabled == false {
config = config.Clone()

// fix for https://github.com/microsoft/go-mssqldb/issues/166
// Go implementation of TLS payload size heuristic algorithm splits single TDS package to multiple TCP segments,
// while SQL Server seems to expect one TCP segment per encrypted TDS package.
// Setting DynamicRecordSizingDisabled to true disables that algorithm and uses 16384 bytes per TLS package
config.DynamicRecordSizingDisabled = true
}
}
if config == nil {
config, err = msdsn.SetupTLS("", false, p.Host, "")
if err != nil {
return nil, err
}
}
//Set ALPN Sequence
config.NextProtos = []string{"tds/8.0"}

tlsConn := tls.Client(toconn.c, config)
err = tlsConn.Handshake()
outbuf.transport = tlsConn
outbuf.transport, err = getTLSConn(toconn, p)
if err != nil {
return nil, fmt.Errorf("TLS Handshake failed: %v", err)
return nil, err
}
}
sess := tdsSession{
Expand Down
44 changes: 44 additions & 0 deletions tds_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -662,6 +662,50 @@ func TestSecureConnection(t *testing.T) {
}
}

func TestTDS8Connection(t *testing.T) {
checkConnStr(t)
tl := testLogger{t: t}
defer tl.StopLogging()
SetLogger(&tl)

dsn := makeConnStr(t)
if !strings.HasSuffix(strings.Split(dsn.Host, ":")[0], ".database.windows.net") {
t.Skip()
}
dsnParams := dsn.Query()
dsnParams.Set("encrypt", "strict")
dsnParams.Set("TrustServerCertificate", "false")
dsnParams.Set("tlsmin", "1.2")
dsn.RawQuery = dsnParams.Encode()

conn, err := sql.Open("mssql", dsn.String())
if err != nil {
t.Fatal("Open connection failed:", err.Error())
}
defer conn.Close()
stmt, err := conn.Prepare("SELECT protocol_type, CONVERT(varbinary(9),protocol_version),client_net_address from sys.dm_exec_connections where session_id=@@SPID")
if err != nil {
t.Fatal("Prepare failed:", err.Error())
}
defer stmt.Close()
row := stmt.QueryRow()
var protocolName string
var tdsver []byte
var clientAddress string
err = row.Scan(&protocolName, &tdsver, &clientAddress)
if err != nil {
t.Fatal("Scan failed:", err.Error())
}
assertEqual(t, "TSQL", protocolName)
assertEqual(t, "0x08000000", hex.EncodeToString(tdsver))
}

func assertEqual(t *testing.T, expected interface{}, actual interface{}) {
if expected != actual {
t.Fatalf("Expected %v, got %v", expected, actual)
}
}

func TestBadCredentials(t *testing.T) {
params := testConnParams(t)
params.Password = "padpwd"
Expand Down

0 comments on commit 1e53bf7

Please sign in to comment.