Skip to content

Commit

Permalink
Refactor file writing logic (#778)
Browse files Browse the repository at this point in the history
  • Loading branch information
HomayoonAlimohammadi authored Nov 8, 2024
1 parent 297bed5 commit 528d459
Show file tree
Hide file tree
Showing 14 changed files with 143 additions and 17 deletions.
3 changes: 1 addition & 2 deletions src/k8s/cmd/k8s/k8s_bootstrap_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package k8s
import (
"bytes"
_ "embed"
"os"
"path/filepath"
"testing"

Expand Down Expand Up @@ -109,7 +108,7 @@ var testCases = []testCase{
func mustAddConfigToTestDir(t *testing.T, configPath string, data string) {
t.Helper()
// Create the cluster bootstrap config file
err := os.WriteFile(configPath, []byte(data), 0o644)
err := utils.WriteFile(configPath, []byte(data), 0o644)
if err != nil {
t.Fatal(err)
}
Expand Down
5 changes: 2 additions & 3 deletions src/k8s/cmd/k8s/k8s_x_capi.go
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
package k8s

import (
"os"

apiv1 "github.com/canonical/k8s-snap-api/api/v1"
cmdutil "github.com/canonical/k8s/cmd/util"
"github.com/canonical/k8s/pkg/utils"
"github.com/spf13/cobra"
)

Expand Down Expand Up @@ -48,7 +47,7 @@ func newXCAPICmd(env cmdutil.ExecutionEnvironment) *cobra.Command {
return
}

if err := os.WriteFile(env.Snap.NodeTokenFile(), []byte(token), 0o600); err != nil {
if err := utils.WriteFile(env.Snap.NodeTokenFile(), []byte(token), 0o600); err != nil {
cmd.PrintErrf("Error: Failed to write the node token to file.\n\nThe error was: %v\n", err)
env.Exit(1)
return
Expand Down
2 changes: 1 addition & 1 deletion src/k8s/cmd/k8sd/k8sd_cluster_recover.go
Original file line number Diff line number Diff line change
Expand Up @@ -346,7 +346,7 @@ func yamlEditorGuide(
newContent = removeEmptyLines(newContent)

if applyChanges {
err = os.WriteFile(path, newContent, os.FileMode(0o644))
err = utils.WriteFile(path, newContent, os.FileMode(0o644))
if err != nil {
return nil, fmt.Errorf("could not write file: %s, error: %w", path, err)
}
Expand Down
4 changes: 3 additions & 1 deletion src/k8s/pkg/docgen/json_struct.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ import (
"os"
"reflect"
"strings"

"github.com/canonical/k8s/pkg/utils"
)

type JsonTag struct {
Expand Down Expand Up @@ -55,7 +57,7 @@ func MarkdownFromJsonStructToFile(i any, outFilePath string, projectDir string)
return err
}

err = os.WriteFile(outFilePath, []byte(content), 0o644)
err = utils.WriteFile(outFilePath, []byte(content), 0o644)
if err != nil {
return fmt.Errorf("failed to write markdown documentation to %s: %w", outFilePath, err)
}
Expand Down
3 changes: 2 additions & 1 deletion src/k8s/pkg/k8sd/setup/certificates.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (

"github.com/canonical/k8s/pkg/k8sd/pki"
"github.com/canonical/k8s/pkg/snap"
"github.com/canonical/k8s/pkg/utils"
)

// ensureFile creates fname with the specified contents, mode and owner bits.
Expand Down Expand Up @@ -39,7 +40,7 @@ func ensureFile(fname string, contents string, uid, gid int, mode fs.FileMode) (
var contentChanged bool

if contents != string(origContent) {
if err := os.WriteFile(fname, []byte(contents), mode); err != nil {
if err := utils.WriteFile(fname, []byte(contents), mode); err != nil {
return false, fmt.Errorf("failed to write: %w", err)
}
contentChanged = true
Expand Down
2 changes: 1 addition & 1 deletion src/k8s/pkg/k8sd/setup/containerd.go
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ func Containerd(snap snap.Snap, extraContainerdConfig map[string]any, extraArgs
return fmt.Errorf("failed to render containerd config.toml: %w", err)
}

if err := os.WriteFile(filepath.Join(snap.ContainerdConfigDir(), "config.toml"), b, 0o600); err != nil {
if err := utils.WriteFile(filepath.Join(snap.ContainerdConfigDir(), "config.toml"), b, 0o600); err != nil {
return fmt.Errorf("failed to write config.toml: %w", err)
}

Expand Down
2 changes: 1 addition & 1 deletion src/k8s/pkg/k8sd/setup/containerd_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ func TestContainerd(t *testing.T) {

dir := t.TempDir()

g.Expect(os.WriteFile(filepath.Join(dir, "mockcni"), []byte("echo hi"), 0o600)).To(Succeed())
g.Expect(utils.WriteFile(filepath.Join(dir, "mockcni"), []byte("echo hi"), 0o600)).To(Succeed())

s := &mock.Snap{
Mock: mock.Mock{
Expand Down
2 changes: 1 addition & 1 deletion src/k8s/pkg/k8sd/setup/k8s_dqlite.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ func K8sDqlite(snap snap.Snap, address string, cluster []string, extraArgs map[s
return fmt.Errorf("failed to create init.yaml file for address=%s cluster=%v: %w", address, cluster, err)
}

if err := os.WriteFile(filepath.Join(snap.K8sDqliteStateDir(), "init.yaml"), b, 0o600); err != nil {
if err := utils.WriteFile(filepath.Join(snap.K8sDqliteStateDir(), "init.yaml"), b, 0o600); err != nil {
return fmt.Errorf("failed to write init.yaml: %w", err)
}

Expand Down
3 changes: 2 additions & 1 deletion src/k8s/pkg/k8sd/setup/util_extra_files.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"strings"

"github.com/canonical/k8s/pkg/snap"
"github.com/canonical/k8s/pkg/utils"
)

// ExtraNodeConfigFiles writes the file contents to the specified filenames in the snap.ExtraFilesDir directory.
Expand All @@ -20,7 +21,7 @@ func ExtraNodeConfigFiles(snap snap.Snap, files map[string]string) error {

filePath := filepath.Join(snap.ServiceExtraConfigDir(), filename)
// Write the content to the file
if err := os.WriteFile(filePath, []byte(content), 0o400); err != nil {
if err := utils.WriteFile(filePath, []byte(content), 0o400); err != nil {
return fmt.Errorf("failed to write to file %s: %w", filePath, err)
}

Expand Down
4 changes: 3 additions & 1 deletion src/k8s/pkg/proxy/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ import (
"fmt"
"os"
"sort"

"github.com/canonical/k8s/pkg/utils"
)

// Configuration is the format of the apiserver proxy endpoints config file.
Expand Down Expand Up @@ -33,7 +35,7 @@ func WriteEndpointsConfig(endpoints []string, file string) error {
return fmt.Errorf("failed to marshal configuration: %w", err)
}

if err := os.WriteFile(file, b, 0o600); err != nil {
if err := utils.WriteFile(file, b, 0o600); err != nil {
return fmt.Errorf("failed to write configuration file %s: %w", file, err)
}
return nil
Expand Down
2 changes: 1 addition & 1 deletion src/k8s/pkg/snap/util/arguments.go
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ func UpdateServiceArguments(snap snap.Snap, serviceName string, updateMap map[st
// sort arguments so that output is consistent
sort.Strings(newArguments)

if err := os.WriteFile(argumentsFile, []byte(strings.Join(newArguments, "\n")+"\n"), 0o600); err != nil {
if err := utils.WriteFile(argumentsFile, []byte(strings.Join(newArguments, "\n")+"\n"), 0o600); err != nil {
return false, fmt.Errorf("failed to write arguments for service %s: %w", serviceName, err)
}
return changed, nil
Expand Down
4 changes: 2 additions & 2 deletions src/k8s/pkg/snap/util/arguments_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,12 @@ package snaputil_test

import (
"fmt"
"os"
"path/filepath"
"testing"

"github.com/canonical/k8s/pkg/snap/mock"
snaputil "github.com/canonical/k8s/pkg/snap/util"
"github.com/canonical/k8s/pkg/utils"
. "github.com/onsi/gomega"
)

Expand All @@ -32,7 +32,7 @@ func TestGetServiceArgument(t *testing.T) {
--key=value-of-service-two
`,
} {
g.Expect(os.WriteFile(filepath.Join(dir, svc), []byte(args), 0o600)).To(Succeed())
g.Expect(utils.WriteFile(filepath.Join(dir, svc), []byte(args), 0o600)).To(Succeed())
}

for _, tc := range []struct {
Expand Down
32 changes: 32 additions & 0 deletions src/k8s/pkg/utils/file.go
Original file line number Diff line number Diff line change
Expand Up @@ -258,3 +258,35 @@ func CreateTarball(tarballPath string, rootDir string, walkDir string, excludeFi

return nil
}

// WriteFile writes data to a file with the given name and permissions.
// The file is written to a temporary file in the same directory as the target file
// and then renamed to the target file to avoid partial writes in case of a crash.
func WriteFile(name string, data []byte, perm fs.FileMode) error {
dir := filepath.Dir(name)
tmpFile, err := os.CreateTemp(dir, "tmp-*")
if err != nil {
return fmt.Errorf("failed to create temp file: %w", err)
}
defer os.Remove(tmpFile.Name())

if _, err := tmpFile.Write(data); err != nil {
tmpFile.Close()
return fmt.Errorf("failed to write to temp file: %w", err)
}

if err := tmpFile.Chmod(perm); err != nil {
tmpFile.Close()
return fmt.Errorf("failed to set permissions on temp file: %w", err)
}

if err := tmpFile.Close(); err != nil {
return fmt.Errorf("failed to close temp file: %w", err)
}

if err := os.Rename(tmpFile.Name(), name); err != nil {
return fmt.Errorf("failed to rename temp file to target file: %w", err)
}

return nil
}
92 changes: 91 additions & 1 deletion src/k8s/pkg/utils/file_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"fmt"
"os"
"path/filepath"
"sync"
"testing"

"github.com/canonical/k8s/pkg/utils"
Expand Down Expand Up @@ -88,7 +89,7 @@ func TestParseArgumentFile(t *testing.T) {
g := NewWithT(t)

filePath := filepath.Join(t.TempDir(), tc.name)
err := os.WriteFile(filePath, []byte(tc.content), 0o755)
err := utils.WriteFile(filePath, []byte(tc.content), 0o755)
if err != nil {
t.Fatalf("Failed to setup testfile: %v", err)
}
Expand Down Expand Up @@ -182,3 +183,92 @@ func TestGetMountPropagationType(t *testing.T) {
g.Expect(err).ToNot(HaveOccurred())
g.Expect(mountType).To(Equal(utils.MountPropagationShared))
}

func TestWriteFile(t *testing.T) {
t.Run("PartialWrites", func(t *testing.T) {
g := NewWithT(t)

name := filepath.Join(t.TempDir(), "testfile")

const (
numWriters = 200
numIterations = 200
)

var wg sync.WaitGroup
wg.Add(numWriters)

expContent := "key: value"
expPerm := os.FileMode(0o644)

for i := 0; i < numWriters; i++ {
go func(writerID int) {
defer wg.Done()

for j := 0; j < numIterations; j++ {
g.Expect(utils.WriteFile(name, []byte(expContent), expPerm)).To(Succeed())

content, err := os.ReadFile(name)
g.Expect(err).ToNot(HaveOccurred())
g.Expect(string(content)).To(Equal(expContent))

fileInfo, err := os.Stat(name)
g.Expect(err).ToNot(HaveOccurred())
g.Expect(fileInfo.Mode().Perm()).To(Equal(expPerm))
}
}(i)
}

wg.Wait()
})

tcs := []struct {
name string
expContent []byte
expPerm os.FileMode
}{
{
name: "test1",
expContent: []byte("key: value"),
expPerm: os.FileMode(0o644),
},
{
name: "test2",
expContent: []byte(""),
expPerm: os.FileMode(0o600),
},
{
name: "test3",
expContent: []byte("key: value"),
expPerm: os.FileMode(0o755),
},
{
name: "test4",
expContent: []byte("key: value"),
expPerm: os.FileMode(0o777),
},
{
name: "test5",
expContent: []byte("key: value"),
expPerm: os.FileMode(0o400),
},
}

for _, tc := range tcs {
t.Run(tc.name, func(t *testing.T) {
g := NewWithT(t)

name := filepath.Join(t.TempDir(), tc.name)

g.Expect(utils.WriteFile(name, tc.expContent, tc.expPerm)).To(Succeed())

content, err := os.ReadFile(name)
g.Expect(err).ToNot(HaveOccurred())
g.Expect(string(content)).To(Equal(string(tc.expContent)))

fileInfo, err := os.Stat(name)
g.Expect(err).ToNot(HaveOccurred())
g.Expect(fileInfo.Mode().Perm()).To(Equal(tc.expPerm))
})
}
}

0 comments on commit 528d459

Please sign in to comment.