diff --git a/README.md b/README.md index 28f730e..129f18a 100644 --- a/README.md +++ b/README.md @@ -29,6 +29,7 @@ go get maragu.dev/clir package main import ( + "flag" "fmt" "log/slog" "math/rand" @@ -38,6 +39,7 @@ import ( "time" "maragu.dev/clir" + "maragu.dev/clir/middleware" ) func main() { @@ -53,19 +55,20 @@ func main() { // Add logging middleware to all routes. r.Use(log(l)) + var v *bool + r.Use(middleware.Flags(func(fs *flag.FlagSet) { + v = fs.Bool("v", false, "verbose") + })) + // Add a root route which calls printHello. r.Route("", printHello()) - // Scope some middleware to just the routes within the scope. - r.Scope(func(r *clir.Router) { - r.Use(ping(c)) - - r.Route("get", get(c)) - }) + // Add a named route which calls get. + r.Route("get", get(c)) // Branch with subcommands r.Branch("post", func(r *clir.Router) { - r.Use(ping(c)) + r.Use(ping(c, v)) r.Route("stdin", postFromStdin(c)) r.Route("random", postFromRandom(c)) @@ -142,9 +145,12 @@ func log(l *slog.Logger) clir.Middleware { } // ping a URL to check the network. -func ping(c *http.Client) clir.Middleware { +func ping(c *http.Client, v *bool) clir.Middleware { return func(next clir.Runner) clir.Runner { return clir.RunnerFunc(func(ctx clir.Context) error { + if *v { + ctx.Println("Pinging!") + } if _, err := c.Get("https://example.com"); err != nil { return err } diff --git a/internal/examples/app/main.go b/internal/examples/app/main.go index 1a4f38d..4534171 100644 --- a/internal/examples/app/main.go +++ b/internal/examples/app/main.go @@ -1,6 +1,7 @@ package main import ( + "flag" "fmt" "log/slog" "math/rand" @@ -10,6 +11,7 @@ import ( "time" "maragu.dev/clir" + "maragu.dev/clir/middleware" ) func main() { @@ -25,19 +27,20 @@ func main() { // Add logging middleware to all routes. r.Use(log(l)) + var v *bool + r.Use(middleware.Flags(func(fs *flag.FlagSet) { + v = fs.Bool("v", false, "verbose") + })) + // Add a root route which calls printHello. r.Route("", printHello()) - // Scope some middleware to just the routes within the scope. - r.Scope(func(r *clir.Router) { - r.Use(ping(c)) - - r.Route("get", get(c)) - }) + // Add a named route which calls get. + r.Route("get", get(c)) // Branch with subcommands r.Branch("post", func(r *clir.Router) { - r.Use(ping(c)) + r.Use(ping(c, v)) r.Route("stdin", postFromStdin(c)) r.Route("random", postFromRandom(c)) @@ -114,9 +117,12 @@ func log(l *slog.Logger) clir.Middleware { } // ping a URL to check the network. -func ping(c *http.Client) clir.Middleware { +func ping(c *http.Client, v *bool) clir.Middleware { return func(next clir.Runner) clir.Runner { return clir.RunnerFunc(func(ctx clir.Context) error { + if *v { + ctx.Println("Pinging!") + } if _, err := c.Get("https://example.com"); err != nil { return err } diff --git a/middleware/middleware.go b/middleware/middleware.go new file mode 100644 index 0000000..b6fc971 --- /dev/null +++ b/middleware/middleware.go @@ -0,0 +1,24 @@ +// Package middleware provides useful middleware for a [clir.Router]. +package middleware + +import ( + "flag" + + "maragu.dev/clir" +) + +// Flags middleware allows you to set flags on a route. +func Flags(cb func(fs *flag.FlagSet)) clir.Middleware { + fs := flag.NewFlagSet("", flag.ContinueOnError) + cb(fs) + + return func(next clir.Runner) clir.Runner { + return clir.RunnerFunc(func(ctx clir.Context) error { + if err := fs.Parse(ctx.Args); err != nil { + return err + } + ctx.Args = fs.Args() + return next.Run(ctx) + }) + } +} diff --git a/middleware/middleware_test.go b/middleware/middleware_test.go new file mode 100644 index 0000000..ce8d661 --- /dev/null +++ b/middleware/middleware_test.go @@ -0,0 +1,92 @@ +package middleware_test + +import ( + "flag" + "os" + "testing" + + "maragu.dev/is" + + "maragu.dev/clir" + "maragu.dev/clir/middleware" +) + +func TestFlags(t *testing.T) { + t.Run("can set flags on a root route", func(t *testing.T) { + r := clir.NewRouter() + + var v *bool + r.Use(middleware.Flags(func(fs *flag.FlagSet) { + v = fs.Bool("v", false, "") + })) + + var called bool + r.RouteFunc("", func(ctx clir.Context) error { + called = true + return nil + }) + + err := r.Run(clir.Context{ + Args: []string{"-v"}, + }) + is.NotError(t, err) + is.True(t, called) + is.NotNil(t, v) + is.True(t, *v) + }) + + t.Run("can set flags on the root and subroutes", func(t *testing.T) { + r := clir.NewRouter() + + var v *bool + r.Use(middleware.Flags(func(fs *flag.FlagSet) { + v = fs.Bool("v", false, "") + })) + + var called bool + var fancy *bool + + r.Branch("dance", func(r *clir.Router) { + r.Use(middleware.Flags(func(fs *flag.FlagSet) { + fancy = fs.Bool("fancypants", false, "") + })) + + r.RouteFunc("", func(ctx clir.Context) error { + called = true + return nil + }) + }) + + err := r.Run(clir.Context{ + Args: []string{"-v", "dance", "-fancypants"}, + }) + is.NotError(t, err) + is.True(t, called) + is.NotNil(t, v) + is.True(t, *v) + is.NotNil(t, fancy) + is.True(t, *fancy) + }) +} + +func ExampleFlags() { + r := clir.NewRouter() + + var v *bool + r.Use(middleware.Flags(func(fs *flag.FlagSet) { + v = fs.Bool("v", false, "verbose output") + })) + + r.RouteFunc("", func(ctx clir.Context) error { + if *v { + ctx.Println("Hello!") + } + return nil + }) + + _ = r.Run(clir.Context{ + Args: []string{"-v"}, + Out: os.Stdout, + }) + // Output: Hello! +} diff --git a/router.go b/router.go index 4a99412..99858a2 100644 --- a/router.go +++ b/router.go @@ -1,6 +1,7 @@ package clir import ( + "fmt" "regexp" "strings" ) @@ -9,7 +10,6 @@ import ( type Router struct { middlewares []Middleware patterns []*regexp.Regexp - routers []*Router runners map[string]Runner } @@ -21,28 +21,38 @@ func NewRouter() *Router { // Run satisfies [Runner]. func (r *Router) Run(ctx Context) error { + // Apply middlewares first, because they can modify the context, including the Context.Args to match against. + var middlewareCtx Context + var runner Runner = RunnerFunc(func(ctx Context) error { + middlewareCtx = ctx + return nil + }) + // Apply middlewares in reverse order, so the first middleware is the outermost one, to be called first. + for i := len(r.middlewares) - 1; i >= 0; i-- { + runner = r.middlewares[i](runner) + } + if err := runner.Run(ctx); err != nil { + return fmt.Errorf("error while applying middleware: %w", err) + } + ctx = middlewareCtx + for _, pattern := range r.patterns { if (len(ctx.Args) == 0 && pattern.String() == "^$") || (len(ctx.Args) > 0 && pattern.MatchString(ctx.Args[0])) { - - runner := r.runners[pattern.String()] + runner = r.runners[pattern.String()] if len(ctx.Args) > 0 { ctx.Matches = pattern.FindStringSubmatch(ctx.Args[0]) ctx.Args = ctx.Args[1:] } - for i := len(r.middlewares) - 1; i >= 0; i-- { - runner = r.middlewares[i](runner) - } - return runner.Run(ctx) } } - for _, router := range r.routers { - if err := router.Run(ctx); err == nil { - return err - } - } + //for _, router := range r.routers { + // if err := router.Run(ctx); err == nil { + // return err + // } + //} return ErrorRouteNotFound } @@ -78,13 +88,13 @@ func (r *Router) Branch(pattern string, cb func(r *Router)) { } // Scope into a new [Router]. -// The middlewares from the parent router are copied to the new router, +// The middlewares from the parent router are used in the new router, // but new middlewares within the scope are only added to the new router, not the parent router. func (r *Router) Scope(cb func(r *Router)) { - newR := NewRouter() - newR.middlewares = append(newR.middlewares, r.middlewares...) - cb(newR) - r.routers = append(r.routers, newR) + panic("not implemented") + //newR := NewRouter() + //cb(newR) + //r.routers = append(r.routers, newR) } // Middleware for [Router.Use]. diff --git a/router_test.go b/router_test.go index 330e848..8271ecc 100644 --- a/router_test.go +++ b/router_test.go @@ -1,6 +1,7 @@ package clir_test import ( + "flag" "strings" "testing" @@ -166,9 +167,43 @@ func TestRouter_Use(t *testing.T) { r.Use(newMiddleware(t, "m1")) }) + + t.Run("can use middleware that parses flags", func(t *testing.T) { + r := clir.NewRouter() + + r.Use(func(next clir.Runner) clir.Runner { + return clir.RunnerFunc(func(ctx clir.Context) error { + fs := flag.NewFlagSet("test", flag.ContinueOnError) + v := fs.Bool("v", false, "") + err := fs.Parse(ctx.Args) + is.NotError(t, err) + is.True(t, *v) + + t.Log(fs.Args()) + ctx.Args = fs.Args() + + return next.Run(ctx) + }) + }) + + var called bool + r.RouteFunc("", func(ctx clir.Context) error { + called = true + return nil + }) + + err := r.Run(clir.Context{ + Args: []string{"-v"}, + }) + is.NotError(t, err) + is.True(t, called) + }) } +//nolint:staticcheck func TestRouter_Scope(t *testing.T) { + t.Skip("not implemented") + t.Run("can scope routes with a new middleware stack", func(t *testing.T) { r := clir.NewRouter()