Skip to content

Commit

Permalink
feat: configurable batch size
Browse files Browse the repository at this point in the history
Signed-off-by: Brian McGee <[email protected]>
  • Loading branch information
brianmcgee committed Oct 17, 2024
1 parent b9c0568 commit fd59969
Show file tree
Hide file tree
Showing 5 changed files with 75 additions and 28 deletions.
8 changes: 2 additions & 6 deletions cmd/format/format.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,6 @@ import (
bolt "go.etcd.io/bbolt"
)

const (
BatchSize = 1024
)

var ErrFailOnChange = errors.New("unexpected changes detected, --fail-on-change is enabled")

func Run(v *viper.Viper, statz *stats.Stats, cmd *cobra.Command, paths []string) error {
Expand Down Expand Up @@ -156,7 +152,7 @@ func Run(v *viper.Viper, statz *stats.Stats, cmd *cobra.Command, paths []string)
}

// create a composite formatter which will handle applying the correct formatters to each file we traverse
formatter, err := format.NewCompositeFormatter(cfg, statz, BatchSize)
formatter, err := format.NewCompositeFormatter(cfg, statz)
if err != nil {
return fmt.Errorf("failed to create composite formatter: %w", err)
}
Expand All @@ -175,7 +171,7 @@ func Run(v *viper.Viper, statz *stats.Stats, cmd *cobra.Command, paths []string)
}

// start traversing
files := make([]*walk.File, BatchSize)
files := make([]*walk.File, cfg.BatchSize)

for {
// read the next batch
Expand Down
35 changes: 35 additions & 0 deletions cmd/root_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,41 @@ func TestCpuProfile(t *testing.T) {
as.NoError(err)
}

// TestBatchSize exercises the --batch-size flag: values above the allowed
// maximum and negative values must be rejected, omitting the flag must
// succeed, and accepted values must be reported in the verbose log output.
func TestBatchSize(t *testing.T) {
	as := require.New(t)
	tempDir := test.TempExamples(t)

	// remember where we started so the working directory can be restored afterwards
	originalDir, err := os.Getwd()
	as.NoError(err)

	t.Cleanup(func() {
		// go back to the directory the test was started from
		as.NoError(os.Chdir(originalDir))
	})

	// one above the upper bound is rejected by config validation
	_, _, err = treefmt(t, "-C", tempDir, "--allow-missing-formatter", "--batch-size", "10241")
	as.ErrorIs(err, config.ErrInvalidBatchSize)

	// a negative value is rejected while parsing the flag itself
	_, _, err = treefmt(t, "-C", tempDir, "--allow-missing-formatter", "--batch-size", "-10241")
	as.ErrorContains(err, "invalid argument")

	// leaving the flag out is fine
	_, _, err = treefmt(t, "-C", tempDir, "--allow-missing-formatter")
	as.NoError(err)

	// each accepted value must show up in the verbose output
	for _, size := range []int{1, 4, 128} {
		out, _, err := treefmt(t, "-C", tempDir, "--allow-missing-formatter", "--batch-size", fmt.Sprint(size), "-v")
		as.NoError(err)
		as.Contains(string(out), fmt.Sprintf("INFO config: batch size = %d", size))
	}
}

func TestAllowMissingFormatter(t *testing.T) {
as := require.New(t)

Expand Down
22 changes: 22 additions & 0 deletions config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,18 @@ import (
"path/filepath"
"strings"

"github.com/charmbracelet/log"
"github.com/numtide/treefmt/walk"
"github.com/spf13/pflag"
"github.com/spf13/viper"
)

// ErrInvalidBatchSize is returned when the configured batch size is outside the range 1 to 10,240.
var ErrInvalidBatchSize = fmt.Errorf("batch size must be between 1 and 10,240")

// Config is used to represent the list of configured Formatters.
type Config struct {
AllowMissingFormatter bool `mapstructure:"allow-missing-formatter" toml:"allow-missing-formatter,omitempty"`
BatchSize int `mapstructure:"batch-size" toml:"batch-size,omitempty"`
CI bool `mapstructure:"ci" toml:"ci,omitempty"`
ClearCache bool `mapstructure:"clear-cache" toml:"-"` // not allowed in config
CPUProfile string `mapstructure:"cpu-profile" toml:"cpu-profile,omitempty"`
Expand Down Expand Up @@ -59,6 +63,9 @@ func SetFlags(fs *pflag.FlagSet) {
"allow-missing-formatter", false,
"Do not exit with error if a configured formatter is missing. (env $TREEFMT_ALLOW_MISSING_FORMATTER)",
)
fs.Uint("batch-size", 1024,
"The maximum number of files to pass to a formatter at once. (env $TREEFMT_BATCH_SIZE)",
)
fs.Bool(
"ci", false,
"Runs treefmt in a CI mode, enabling --no-cache, --fail-on-change and adjusting some other settings "+
Expand Down Expand Up @@ -235,6 +242,21 @@ func FromViper(v *viper.Viper) (*Config, error) {
}
}

// validate batch size
// TODO: decide what a reasonable upper limit for the batch size is.

// default if it isn't set (e.g. in tests when using Config directly)
if cfg.BatchSize == 0 {
cfg.BatchSize = 1024
}

if !(1 <= cfg.BatchSize && cfg.BatchSize <= 10240) {
return nil, ErrInvalidBatchSize
}

l := log.WithPrefix("config")
l.Infof("batch size = %d", cfg.BatchSize)

return cfg, nil
}

Expand Down
3 changes: 1 addition & 2 deletions format/composite.go
Original file line number Diff line number Diff line change
Expand Up @@ -344,7 +344,6 @@ func (c *CompositeFormatter) BustCache(db *bolt.DB) error {
func NewCompositeFormatter(
cfg *config.Config,
statz *stats.Stats,
batchSize int,
) (*CompositeFormatter, error) {
// compile global exclude globs
globalExcludes, err := compileGlobs(cfg.Excludes)
Expand Down Expand Up @@ -392,7 +391,7 @@ func NewCompositeFormatter(
return &CompositeFormatter{
cfg: cfg,
stats: statz,
batchSize: batchSize,
batchSize: cfg.BatchSize,
globalExcludes: globalExcludes,

log: log.WithPrefix("composite-formatter"),
Expand Down
35 changes: 15 additions & 20 deletions format/formatter_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,13 @@ import (
func TestInvalidFormatterName(t *testing.T) {
as := require.New(t)

const batchSize = 1024

cfg := &config.Config{}
cfg.OnUnmatched = "info"

statz := stats.New()

// simple "empty" config
_, err := format.NewCompositeFormatter(cfg, &statz, batchSize)
_, err := format.NewCompositeFormatter(cfg, &statz)
as.NoError(err)

// valid name using all the acceptable characters
Expand All @@ -35,7 +33,7 @@ func TestInvalidFormatterName(t *testing.T) {
},
}

_, err = format.NewCompositeFormatter(cfg, &statz, batchSize)
_, err = format.NewCompositeFormatter(cfg, &statz)
as.NoError(err)

// test with some bad examples
Expand All @@ -48,18 +46,15 @@ func TestInvalidFormatterName(t *testing.T) {
},
}

_, err = format.NewCompositeFormatter(cfg, &statz, batchSize)
_, err = format.NewCompositeFormatter(cfg, &statz)
as.ErrorIs(err, format.ErrInvalidName)
}
}

func TestFormatterHash(t *testing.T) {
as := require.New(t)

const batchSize = 1024

statz := stats.New()

tempDir := t.TempDir()

// symlink some formatters into temp dir, so we can mess with their mod times
Expand Down Expand Up @@ -93,7 +88,7 @@ func TestFormatterHash(t *testing.T) {
},
}

f, err := format.NewCompositeFormatter(cfg, &statz, batchSize)
f, err := format.NewCompositeFormatter(cfg, &statz)
as.NoError(err)

// hash for the first time
Expand Down Expand Up @@ -123,7 +118,7 @@ func TestFormatterHash(t *testing.T) {
})

t.Run("modify formatter options", func(_ *testing.T) {
f, err = format.NewCompositeFormatter(cfg, &statz, batchSize)
f, err = format.NewCompositeFormatter(cfg, &statz)
as.NoError(err)

h3, err := f.Hash()
Expand All @@ -133,7 +128,7 @@ func TestFormatterHash(t *testing.T) {
python := cfg.FormatterConfigs["python"]
python.Includes = []string{"*.py", "*.pyi"}

f, err = format.NewCompositeFormatter(cfg, &statz, batchSize)
f, err = format.NewCompositeFormatter(cfg, &statz)
as.NoError(err)

h4, err := f.Hash()
Expand All @@ -148,7 +143,7 @@ func TestFormatterHash(t *testing.T) {
// adjust python excludes
python.Excludes = []string{"*.pyi"}

f, err = format.NewCompositeFormatter(cfg, &statz, batchSize)
f, err = format.NewCompositeFormatter(cfg, &statz)
as.NoError(err)

h6, err := f.Hash()
Expand All @@ -163,7 +158,7 @@ func TestFormatterHash(t *testing.T) {
// adjust python options
python.Options = []string{"-w", "-s"}

f, err = format.NewCompositeFormatter(cfg, &statz, batchSize)
f, err = format.NewCompositeFormatter(cfg, &statz)
as.NoError(err)

h8, err := f.Hash()
Expand All @@ -178,7 +173,7 @@ func TestFormatterHash(t *testing.T) {
// adjust python priority
python.Priority = 100

f, err = format.NewCompositeFormatter(cfg, &statz, batchSize)
f, err = format.NewCompositeFormatter(cfg, &statz)
as.NoError(err)

h10, err := f.Hash()
Expand All @@ -198,7 +193,7 @@ func TestFormatterHash(t *testing.T) {
Includes: []string{"*.go"},
}

f, err = format.NewCompositeFormatter(cfg, &statz, batchSize)
f, err = format.NewCompositeFormatter(cfg, &statz)
as.NoError(err)

h3, err := f.Hash()
Expand All @@ -208,7 +203,7 @@ func TestFormatterHash(t *testing.T) {
// remove python formatter
delete(cfg.FormatterConfigs, "python")

f, err = format.NewCompositeFormatter(cfg, &statz, batchSize)
f, err = format.NewCompositeFormatter(cfg, &statz)
as.NoError(err)

h4, err := f.Hash()
Expand All @@ -223,7 +218,7 @@ func TestFormatterHash(t *testing.T) {
// remove elm formatter
delete(cfg.FormatterConfigs, "elm")

f, err = format.NewCompositeFormatter(cfg, &statz, batchSize)
f, err = format.NewCompositeFormatter(cfg, &statz)
as.NoError(err)

h6, err := f.Hash()
Expand All @@ -239,15 +234,15 @@ func TestFormatterHash(t *testing.T) {
t.Run("modify global excludes", func(_ *testing.T) {
cfg.Excludes = []string{"*.go"}

f, err = format.NewCompositeFormatter(cfg, &statz, batchSize)
f, err = format.NewCompositeFormatter(cfg, &statz)
as.NoError(err)

h3, err := f.Hash()
as.NoError(err)

cfg.Excludes = []string{"*.go", "*.hs"}

f, err = format.NewCompositeFormatter(cfg, &statz, batchSize)
f, err = format.NewCompositeFormatter(cfg, &statz)
as.NoError(err)

h4, err := f.Hash()
Expand All @@ -263,7 +258,7 @@ func TestFormatterHash(t *testing.T) {
cfg.Excludes = nil
cfg.Global.Excludes = []string{"*.go", "*.hs"}

f, err = format.NewCompositeFormatter(cfg, &statz, batchSize)
f, err = format.NewCompositeFormatter(cfg, &statz)
as.NoError(err)

h6, err := f.Hash()
Expand Down

0 comments on commit fd59969

Please sign in to comment.