diff --git a/msdsn/conn_str.go b/msdsn/conn_str.go index f1079a4c..c957657d 100644 --- a/msdsn/conn_str.go +++ b/msdsn/conn_str.go @@ -3,7 +3,6 @@ package msdsn import ( "crypto/tls" "crypto/x509" - "encoding/pem" "errors" "fmt" "io/ioutil" @@ -94,19 +93,6 @@ type Config struct { ColumnEncryption bool } -// GetPEMCertificate returns PEM formatted certificate -func GetPEMCertificate(certificate string) ([]byte, error) { - cerData, ok := ioutil.ReadFile(certificate) - if ok != nil { - return nil, fmt.Errorf("cannot read certificate %q: %w", certificate, ok) - } - pemData := pem.EncodeToMemory(&pem.Block{ - Type: "CERTIFICATE", - Bytes: cerData, - }) - return pemData, nil -} - // Build a tls.Config object from the supplied certificate. func SetupTLS(certificate string, insecureSkipVerify bool, hostInCertificate string, minTLSVersion string) (*tls.Config, error) { config := tls.Config{ @@ -124,7 +110,7 @@ func SetupTLS(certificate string, insecureSkipVerify bool, hostInCertificate str if len(certificate) == 0 { return &config, nil } - pem, err := GetPEMCertificate(certificate) + pem, err := ioutil.ReadFile(certificate) if err != nil { return nil, fmt.Errorf("cannot read certificate %q: %w", certificate, err) } diff --git a/tds.go b/tds.go index 3af87ae8..705aec78 100644 --- a/tds.go +++ b/tds.go @@ -1052,7 +1052,7 @@ func interpretPreloginResponse(p msdsn.Config, fe *featureExtFedAuth, fields map func prepareLogin(ctx context.Context, c *Connector, p msdsn.Config, logger ContextLogger, auth integratedauth.IntegratedAuthenticator, fe *featureExtFedAuth, packetSize uint32) (l *login, err error) { var TDSVersion uint32 - if(p.Encryption == msdsn.EncryptionStrict) { + if p.Encryption == msdsn.EncryptionStrict { TDSVersion = verTDS80 } else { TDSVersion = verTDS74 @@ -1129,6 +1129,27 @@ func prepareLogin(ctx context.Context, c *Connector, p msdsn.Config, logger Cont return l, nil } +func getTLSConn(conn *timeoutConn, p msdsn.Config) (tlsConn *tls.Conn, err error) { + var config *tls.Config + if pc := p.TLSConfig; pc != nil { + config = pc + } + if config == nil { + config, err = msdsn.SetupTLS("", false, p.Host, "") + if err != nil { + return nil, err + } + } + //Set ALPN Sequence + config.NextProtos = []string{"tds/8.0"} + tlsConn = tls.Client(conn.c, config) + err = tlsConn.Handshake() + if err != nil { + return nil, fmt.Errorf("TLS Handshake failed: %v", err) + } + return tlsConn, nil +} + func connect(ctx context.Context, c *Connector, logger ContextLogger, p msdsn.Config) (res *tdsSession, err error) { // if instance is specified use instance resolution service @@ -1173,33 +1194,9 @@ initiate_connection: outbuf := newTdsBuffer(packetSize, toconn) if p.Encryption == msdsn.EncryptionStrict { - var config *tls.Config - if pc := p.TLSConfig; pc != nil { - config = pc - if config.DynamicRecordSizingDisabled == false { - config = config.Clone() - - // fix for https://github.com/microsoft/go-mssqldb/issues/166 - // Go implementation of TLS payload size heuristic algorithm splits single TDS package to multiple TCP segments, - // while SQL Server seems to expect one TCP segment per encrypted TDS package. - // Setting DynamicRecordSizingDisabled to true disables that algorithm and uses 16384 bytes per TLS package - config.DynamicRecordSizingDisabled = true - } - } - if config == nil { - config, err = msdsn.SetupTLS("", false, p.Host, "") - if err != nil { - return nil, err - } - } - //Set ALPN Sequence - config.NextProtos = []string{"tds/8.0"} - - tlsConn := tls.Client(toconn.c, config) - err = tlsConn.Handshake() - outbuf.transport = tlsConn + outbuf.transport, err = getTLSConn(toconn, p) if err != nil { - return nil, fmt.Errorf("TLS Handshake failed: %v", err) + return nil, err } } sess := tdsSession{ diff --git a/tds_test.go b/tds_test.go index daabd714..6acdc81e 100644 --- a/tds_test.go +++ b/tds_test.go @@ -662,6 +662,50 @@ func TestSecureConnection(t *testing.T) { } } +func TestTDS8Connection(t *testing.T) { + checkConnStr(t) + tl := testLogger{t: t} + defer tl.StopLogging() + SetLogger(&tl) + + dsn := makeConnStr(t) + if !strings.HasSuffix(strings.Split(dsn.Host, ":")[0], ".database.windows.net") { + t.Skip() + } + dsnParams := dsn.Query() + dsnParams.Set("encrypt", "strict") + dsnParams.Set("TrustServerCertificate", "false") + dsnParams.Set("tlsmin", "1.2") + dsn.RawQuery = dsnParams.Encode() + + conn, err := sql.Open("mssql", dsn.String()) + if err != nil { + t.Fatal("Open connection failed:", err.Error()) + } + defer conn.Close() + stmt, err := conn.Prepare("SELECT protocol_type, CONVERT(varbinary(9),protocol_version),client_net_address from sys.dm_exec_connections where session_id=@@SPID") + if err != nil { + t.Fatal("Prepare failed:", err.Error()) + } + defer stmt.Close() + row := stmt.QueryRow() + var protocolName string + var tdsver []byte + var clientAddress string + err = row.Scan(&protocolName, &tdsver, &clientAddress) + if err != nil { + t.Fatal("Scan failed:", err.Error()) + } + assertEqual(t, "TSQL", protocolName) + assertEqual(t, "0x08000000", hex.EncodeToString(tdsver)) +} + +func assertEqual(t *testing.T, expected interface{}, actual interface{}) { + if expected != actual { + t.Fatalf("Expected %v, got %v", expected, actual) + } +} + func TestBadCredentials(t *testing.T) { params := testConnParams(t) params.Password = "padpwd"