From 8eb33798b37c5cb3c98bb1ba12f6501b0cdac9ed Mon Sep 17 00:00:00 2001 From: Dmytro Shteflyuk Date: Wed, 25 Sep 2024 14:03:13 -0400 Subject: [PATCH] Ensure all IO errors are handled in the StaticCertManager test --- internal/server/cert_test.go | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/internal/server/cert_test.go b/internal/server/cert_test.go index d6c02ff..96ed045 100644 --- a/internal/server/cert_test.go +++ b/internal/server/cert_test.go @@ -3,6 +3,7 @@ package server import ( "crypto/tls" "os" + "path" "testing" "github.com/stretchr/testify/require" @@ -42,7 +43,8 @@ func TestCertificateLoadingRaceCondition(t *testing.T) { manager := NewStaticCertManager(certPath, keyPath) go func() { - manager.GetCertificate(&tls.ClientHelloInfo{}) + _, err2 := manager.GetCertificate(&tls.ClientHelloInfo{}) + require.NoError(t, err2) }() cert, err := manager.GetCertificate(&tls.ClientHelloInfo{}) require.NoError(t, err) @@ -58,8 +60,8 @@ func TestCachesLoadedCertificate(t *testing.T) { require.NoError(t, err) require.NotNil(t, cert1) - os.Remove(certPath) - os.Remove(keyPath) + require.Nil(t, os.Remove(certPath)) + require.Nil(t, os.Remove(keyPath)) cert2, err := manager.GetCertificate(&tls.ClientHelloInfo{}) require.Equal(t, cert1, cert2) @@ -85,21 +87,19 @@ func TestErrorWhenKeyFormatIsInvalid(t *testing.T) { func prepareTestCertificateFiles(t *testing.T) (string, string, error) { t.Helper() - certFile, err := os.CreateTemp("", "example-cert-*.pem") + dir := t.TempDir() + certFile := path.Join(dir, "example-cert.pem") + keyFile := path.Join(dir, "example-key.pem") + + err := os.WriteFile(certFile, []byte(certPem), 0644) if err != nil { return "", "", err } - defer certFile.Close() - certFile.Write([]byte(certPem)) - t.Cleanup(func() { os.Remove(certFile.Name()) }) - keyFile, err := os.CreateTemp("", "example-key-*.pem") + err = os.WriteFile(keyFile, []byte(keyPem), 0644) if err != nil { return "", "", err } - defer keyFile.Close() - keyFile.Write([]byte(keyPem)) - t.Cleanup(func() { os.Remove(keyFile.Name()) }) - return certFile.Name(), keyFile.Name(), nil + return certFile, keyFile, nil }