Skip to content

Commit

Permalink
Better CLI completion (#171)
Browse files Browse the repository at this point in the history
  • Loading branch information
nt0xa authored Jul 14, 2024
1 parent 8c54703 commit 5b3afea
Show file tree
Hide file tree
Showing 19 changed files with 265 additions and 148 deletions.
2 changes: 1 addition & 1 deletion cmd/client/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ func (c Config) ValidateWithContext(ctx context.Context) error {
}

func (c *Config) Server() *Server {
srv, ok := cfg.Servers[cfg.Context.Server]
srv, ok := c.Servers[c.Context.Server]
if !ok {
return nil
}
Expand Down
101 changes: 70 additions & 31 deletions cmd/client/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,40 +14,45 @@ import (
"github.com/gookit/color"
"github.com/spf13/cobra"
"github.com/spf13/viper"
"golang.org/x/exp/maps"

"github.com/nt0xa/sonar/internal/actions"
"github.com/nt0xa/sonar/internal/cmd"
"github.com/nt0xa/sonar/internal/modules/api/apiclient"
"github.com/nt0xa/sonar/internal/templates"
)

var (
cfg Config
cfgFile string
jsonOutput bool
)

func init() {
validation.ErrorTag = "err"
cobra.OnInitialize(initConfig)
}

func main() {
var (
cfg Config
cfgFile string
jsonOutput bool
)

c := cmd.New(
nil,
cmd.AllowFileAccess(true),
cmd.PreExec(func(root *cobra.Command) {
addConfigFlag(root)
addJSONFlag(root)
addContextCommand(root)
}),
cmd.InitActions(func() (actions.Actions, error) {
srv := cfg.Server()
if srv == nil {
return nil, errors.New("server must be set")
cmd.PreExec(func(acts *actions.Actions, root *cobra.Command) {
root.PersistentPreRunE = func(cmd *cobra.Command, args []string) error {
if err := initConfig(cfgFile, &cfg); err != nil {
return err
}

if err := initActions(acts, cfg); err != nil {
return err
}

return nil
}
client := apiclient.New(srv.URL, srv.Token, srv.Insecure, srv.Proxy)
return client, nil

// Flags, commands...
root.PersistentFlags().StringVar(&cfgFile, "config", "", "config file")
jsonFlag(root, &jsonOutput)
contextCmd(root, &cfg)
}),
)

Expand Down Expand Up @@ -78,25 +83,23 @@ func main() {
}
}

func addConfigFlag(root *cobra.Command) {
root.PersistentFlags().StringVar(&cfgFile, "config", "", "config file")
}

func addJSONFlag(root *cobra.Command) {
func jsonFlag(root *cobra.Command, jsonOutput *bool) {
for _, cmd := range root.Commands() {
if cmd.HasSubCommands() {
addJSONFlag(cmd)
jsonFlag(cmd, jsonOutput)
}

if cmd.Name() == "help" || cmd.Name() == "completion" {
if cmd.Name() == "help" ||
cmd.Name() == "completion" ||
strings.HasPrefix(cmd.Name(), "_") {
continue
}

cmd.Flags().BoolVar(&jsonOutput, "json", false, "JSON output")
cmd.Flags().BoolVar(jsonOutput, "json", false, "JSON output")
}
}

func addContextCommand(root *cobra.Command) {
func contextCmd(root *cobra.Command, cfg *Config) {
var server string

cmd := &cobra.Command{
Expand Down Expand Up @@ -124,10 +127,18 @@ func addContextCommand(root *cobra.Command) {
cmd.Flags().StringVarP(&server, "server", "s", "", "Server name from list of servers")
viper.BindPFlag("context.server", cmd.Flags().Lookup("server"))

cmd.RegisterFlagCompletionFunc("server", func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) {
var cfg Config
if err := initConfig(findFlagValue("config", os.Args), &cfg); err != nil {
return nil, cobra.ShellCompDirectiveDefault
}
return maps.Keys(cfg.Servers), cobra.ShellCompDirectiveNoFileComp
})

root.AddCommand(cmd)
}

func initConfig() {
func initConfig(cfgFile string, cfg *Config) error {
if cfgFile != "" {
// Use config file from the flag.
viper.SetConfigFile(cfgFile)
Expand All @@ -143,7 +154,35 @@ func initConfig() {
viper.SetEnvKeyReplacer(strings.NewReplacer(".", "_"))
viper.AutomaticEnv()

cobra.CheckErr(viper.ReadInConfig())
cobra.CheckErr(viper.Unmarshal(&cfg))
cobra.CheckErr(cfg.ValidateWithContext(context.Background()))
if err := viper.ReadInConfig(); err != nil {
return err
}

if err := viper.Unmarshal(cfg); err != nil {
return err
}

if err := cfg.ValidateWithContext(context.Background()); err != nil {
return err
}

return nil
}

func initActions(acts *actions.Actions, cfg Config) error {
srv := cfg.Server()
if srv == nil {
return errors.New("server must be set")
}
*acts = apiclient.New(srv.URL, srv.Token, srv.Insecure, srv.Proxy)
return nil
}

func findFlagValue(f string, args []string) string {
for i := 1; i < len(args); i++ {
if args[i-1] == "--"+f {
return args[i]
}
}
return ""
}
9 changes: 0 additions & 9 deletions config.toml

This file was deleted.

80 changes: 80 additions & 0 deletions internal/actions/completion.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
package actions

import (
"slices"
"strings"

"github.com/spf13/cobra"
)

type completionFunc func(
cmd *cobra.Command,
args []string,
toComplete string,
) ([]string, cobra.ShellCompDirective)

func completePayloadName(acts *Actions) completionFunc {
return func(
cmd *cobra.Command,
_ []string,
toComplete string,
) ([]string, cobra.ShellCompDirective) {
payloads, err := (*acts).PayloadsList(cmd.Context(), PayloadsListParams{
Name: toComplete,
})
if err != nil {
return nil, cobra.ShellCompDirectiveError
}

names := make([]string, len(payloads))

for i, p := range payloads {
names[i] = p.Name
}

return names, cobra.ShellCompDirectiveNoFileComp
}
}

func completeOne(list []string) completionFunc {
return func(
_ *cobra.Command,
_ []string,
_ string,
) ([]string, cobra.ShellCompDirective) {
return list, cobra.ShellCompDirectiveNoFileComp
}
}

func completeMany(completions []string) completionFunc {
return func(
_ *cobra.Command,
_ []string,
toComplete string,
) ([]string, cobra.ShellCompDirective) {
// aaa,bbb,c -> [aaa, bbb, c]
parts := strings.Split(toComplete, ",")

// [aaa, bb, c] -> prefix = "aaa,bbb,"
prefix := strings.Join(parts[:len(parts)-1], ",")
if prefix != "" {
prefix += ","
}

// lastPart = "c"
lastPart := parts[len(parts)-1]

// Filter completions based on the current input
var result []string
for _, comp := range completions {
if slices.Contains(parts, comp) {
continue
}
if strings.HasPrefix(comp, lastPart) {
result = append(result, prefix+comp)
}
}

return result, cobra.ShellCompDirectiveNoSpace
}
}
21 changes: 16 additions & 5 deletions internal/actions/dns_records.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,9 +70,9 @@ func (r DNSRecordsCreateResult) ResultID() string {
return DNSRecordsCreateResultID
}

func DNSRecordsCreateCommand(p *DNSRecordsCreateParams, local bool) (*cobra.Command, PrepareCommandFunc) {
func DNSRecordsCreateCommand(acts *Actions, p *DNSRecordsCreateParams, local bool) (*cobra.Command, PrepareCommandFunc) {
cmd := &cobra.Command{
Use: "new VALUES",
Use: "new VALUES...",
Short: "Create new DNS records",
Args: atLeastOneArg("VALUES"),
}
Expand All @@ -85,6 +85,11 @@ func DNSRecordsCreateCommand(p *DNSRecordsCreateParams, local bool) (*cobra.Comm
cmd.Flags().StringVarP(&p.Strategy, "strategy", "s", models.DNSStrategyAll,
fmt.Sprintf("Strategy for multiple records (one of %s)", quoteAndJoin(models.DNSStrategiesAll)))

_ = cmd.MarkFlagRequired("name")
_ = cmd.RegisterFlagCompletionFunc("payload", completePayloadName(acts))
_ = cmd.RegisterFlagCompletionFunc("type", completeOne(models.DNSTypesAll))
_ = cmd.RegisterFlagCompletionFunc("strategy", completeOne(models.DNSStrategiesAll))

return cmd, func(cmd *cobra.Command, args []string) errors.Error {
p.Values = args
return nil
Expand Down Expand Up @@ -115,7 +120,7 @@ func (r DNSRecordsDeleteResult) ResultID() string {
return DNSRecordsDeleteResultID
}

func DNSRecordsDeleteCommand(p *DNSRecordsDeleteParams, local bool) (*cobra.Command, PrepareCommandFunc) {
func DNSRecordsDeleteCommand(acts *Actions, p *DNSRecordsDeleteParams, local bool) (*cobra.Command, PrepareCommandFunc) {
cmd := &cobra.Command{
Use: "del INDEX",
Short: "Delete DNS record",
Expand All @@ -125,6 +130,8 @@ func DNSRecordsDeleteCommand(p *DNSRecordsDeleteParams, local bool) (*cobra.Comm

cmd.Flags().StringVarP(&p.PayloadName, "payload", "p", "", "Payload name")

_ = cmd.RegisterFlagCompletionFunc("payload", completePayloadName(acts))

return cmd, func(cmd *cobra.Command, args []string) errors.Error {
i, err := strconv.ParseInt(args[0], 10, 64)
if err != nil {
Expand Down Expand Up @@ -156,7 +163,7 @@ func (r DNSRecordsClearResult) ResultID() string {
return DNSRecordsClearResultID
}

func DNSRecordsClearCommand(p *DNSRecordsClearParams, local bool) (*cobra.Command, PrepareCommandFunc) {
func DNSRecordsClearCommand(acts *Actions, p *DNSRecordsClearParams, local bool) (*cobra.Command, PrepareCommandFunc) {
cmd := &cobra.Command{
Use: "clr",
Short: "Delete multiple DNS records",
Expand All @@ -166,6 +173,8 @@ func DNSRecordsClearCommand(p *DNSRecordsClearParams, local bool) (*cobra.Comman
cmd.Flags().StringVarP(&p.PayloadName, "payload", "p", "", "Payload name")
cmd.Flags().StringVarP(&p.Name, "name", "n", "", "Subdomain")

_ = cmd.RegisterFlagCompletionFunc("payload", completePayloadName(acts))

return cmd, nil
}

Expand All @@ -189,13 +198,15 @@ func (r DNSRecordsListResult) ResultID() string {
return DNSRecordsListResultID
}

func DNSRecordsListCommand(p *DNSRecordsListParams, local bool) (*cobra.Command, PrepareCommandFunc) {
func DNSRecordsListCommand(acts *Actions, p *DNSRecordsListParams, local bool) (*cobra.Command, PrepareCommandFunc) {
cmd := &cobra.Command{
Use: "list",
Short: "List DNS records",
}

cmd.Flags().StringVarP(&p.PayloadName, "payload", "p", "", "Payload name")

_ = cmd.RegisterFlagCompletionFunc("payload", completePayloadName(acts))

return cmd, nil
}
8 changes: 6 additions & 2 deletions internal/actions/events.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ func (r EventsListResult) ResultID() string {
return EventsListResultID
}

func EventsListCommand(p *EventsListParams, local bool) (*cobra.Command, PrepareCommandFunc) {
func EventsListCommand(acts *Actions, p *EventsListParams, local bool) (*cobra.Command, PrepareCommandFunc) {
cmd := &cobra.Command{
Use: "list",
Short: "List payload events",
Expand All @@ -73,6 +73,8 @@ func EventsListCommand(p *EventsListParams, local bool) (*cobra.Command, Prepare
cmd.Flags().Int64VarP(&p.Before, "before", "b", 0, "Before ID")
cmd.Flags().BoolVarP(&p.Reverse, "reverse", "r", false, "List events in reversed order")

_ = cmd.RegisterFlagCompletionFunc("payload", completePayloadName(acts))

return cmd, nil
}

Expand Down Expand Up @@ -100,7 +102,7 @@ func (r EventsGetResult) ResultID() string {
return EventsGetResultID
}

func EventsGetCommand(p *EventsGetParams, local bool) (*cobra.Command, PrepareCommandFunc) {
func EventsGetCommand(acts *Actions, p *EventsGetParams, local bool) (*cobra.Command, PrepareCommandFunc) {
cmd := &cobra.Command{
Use: "get INDEX",
Short: "Get payload event by INDEX",
Expand All @@ -109,6 +111,8 @@ func EventsGetCommand(p *EventsGetParams, local bool) (*cobra.Command, PrepareCo

cmd.Flags().StringVarP(&p.PayloadName, "payload", "p", "", "Payload name")

_ = cmd.RegisterFlagCompletionFunc("payload", completePayloadName(acts))

return cmd, func(cmd *cobra.Command, args []string) errors.Error {
i, err := strconv.ParseInt(args[0], 10, 64)
if err != nil {
Expand Down
Loading

0 comments on commit 5b3afea

Please sign in to comment.