Skip to content

Commit

Permalink
Add Flags middleware (#9)
Browse files Browse the repository at this point in the history
This adds a new `middleware` package as well as `middleware.Flags`,
which allows you to supply a callback function receiving a
`flag.FlagSet` to define flags on.

In order to make this work, middleware is now applied before route
matching. Otherwise, middleware can't change the route matching, which
is necessary, because flags are part of the route before parsing.

I had to disable `Router.Scope` because I can't currently make it work
with the middleware changes, and I'm prioritizing the flags feature. See
#8.

Fixes #4
  • Loading branch information
markuswustenberg authored Oct 25, 2024
1 parent bd6db43 commit 39523fc
Show file tree
Hide file tree
Showing 6 changed files with 206 additions and 33 deletions.
22 changes: 14 additions & 8 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ go get maragu.dev/clir
package main

import (
"flag"
"fmt"
"log/slog"
"math/rand"
Expand All @@ -38,6 +39,7 @@ import (
"time"

"maragu.dev/clir"
"maragu.dev/clir/middleware"
)

func main() {
Expand All @@ -53,19 +55,20 @@ func main() {
// Add logging middleware to all routes.
r.Use(log(l))

var v *bool
r.Use(middleware.Flags(func(fs *flag.FlagSet) {
v = fs.Bool("v", false, "verbose")
}))

// Add a root route which calls printHello.
r.Route("", printHello())

// Scope some middleware to just the routes within the scope.
r.Scope(func(r *clir.Router) {
r.Use(ping(c))

r.Route("get", get(c))
})
// Add a named route which calls get.
r.Route("get", get(c))

// Branch with subcommands
r.Branch("post", func(r *clir.Router) {
r.Use(ping(c))
r.Use(ping(c, v))

r.Route("stdin", postFromStdin(c))
r.Route("random", postFromRandom(c))
Expand Down Expand Up @@ -142,9 +145,12 @@ func log(l *slog.Logger) clir.Middleware {
}

// ping a URL to check the network.
func ping(c *http.Client) clir.Middleware {
func ping(c *http.Client, v *bool) clir.Middleware {
return func(next clir.Runner) clir.Runner {
return clir.RunnerFunc(func(ctx clir.Context) error {
if *v {
ctx.Println("Pinging!")
}
if _, err := c.Get("https://example.com"); err != nil {
return err
}
Expand Down
22 changes: 14 additions & 8 deletions internal/examples/app/main.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package main

import (
"flag"
"fmt"
"log/slog"
"math/rand"
Expand All @@ -10,6 +11,7 @@ import (
"time"

"maragu.dev/clir"
"maragu.dev/clir/middleware"
)

func main() {
Expand All @@ -25,19 +27,20 @@ func main() {
// Add logging middleware to all routes.
r.Use(log(l))

var v *bool
r.Use(middleware.Flags(func(fs *flag.FlagSet) {
v = fs.Bool("v", false, "verbose")
}))

// Add a root route which calls printHello.
r.Route("", printHello())

// Scope some middleware to just the routes within the scope.
r.Scope(func(r *clir.Router) {
r.Use(ping(c))

r.Route("get", get(c))
})
// Add a named route which calls get.
r.Route("get", get(c))

// Branch with subcommands
r.Branch("post", func(r *clir.Router) {
r.Use(ping(c))
r.Use(ping(c, v))

r.Route("stdin", postFromStdin(c))
r.Route("random", postFromRandom(c))
Expand Down Expand Up @@ -114,9 +117,12 @@ func log(l *slog.Logger) clir.Middleware {
}

// ping a URL to check the network.
func ping(c *http.Client) clir.Middleware {
func ping(c *http.Client, v *bool) clir.Middleware {
return func(next clir.Runner) clir.Runner {
return clir.RunnerFunc(func(ctx clir.Context) error {
if *v {
ctx.Println("Pinging!")
}
if _, err := c.Get("https://example.com"); err != nil {
return err
}
Expand Down
24 changes: 24 additions & 0 deletions middleware/middleware.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
// Package middleware provides useful middleware for a [clir.Router].
package middleware

import (
"flag"

"maragu.dev/clir"
)

// Flags middleware allows you to set flags on a route.
func Flags(cb func(fs *flag.FlagSet)) clir.Middleware {
fs := flag.NewFlagSet("", flag.ContinueOnError)
cb(fs)

return func(next clir.Runner) clir.Runner {
return clir.RunnerFunc(func(ctx clir.Context) error {
if err := fs.Parse(ctx.Args); err != nil {
return err
}
ctx.Args = fs.Args()
return next.Run(ctx)
})
}
}
92 changes: 92 additions & 0 deletions middleware/middleware_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
package middleware_test

import (
"flag"
"os"
"testing"

"maragu.dev/is"

"maragu.dev/clir"
"maragu.dev/clir/middleware"
)

func TestFlags(t *testing.T) {
t.Run("can set flags on a root route", func(t *testing.T) {
r := clir.NewRouter()

var v *bool
r.Use(middleware.Flags(func(fs *flag.FlagSet) {
v = fs.Bool("v", false, "")
}))

var called bool
r.RouteFunc("", func(ctx clir.Context) error {
called = true
return nil
})

err := r.Run(clir.Context{
Args: []string{"-v"},
})
is.NotError(t, err)
is.True(t, called)
is.NotNil(t, v)
is.True(t, *v)
})

t.Run("can set flags on the root and subroutes", func(t *testing.T) {
r := clir.NewRouter()

var v *bool
r.Use(middleware.Flags(func(fs *flag.FlagSet) {
v = fs.Bool("v", false, "")
}))

var called bool
var fancy *bool

r.Branch("dance", func(r *clir.Router) {
r.Use(middleware.Flags(func(fs *flag.FlagSet) {
fancy = fs.Bool("fancypants", false, "")
}))

r.RouteFunc("", func(ctx clir.Context) error {
called = true
return nil
})
})

err := r.Run(clir.Context{
Args: []string{"-v", "dance", "-fancypants"},
})
is.NotError(t, err)
is.True(t, called)
is.NotNil(t, v)
is.True(t, *v)
is.NotNil(t, fancy)
is.True(t, *fancy)
})
}

func ExampleFlags() {
r := clir.NewRouter()

var v *bool
r.Use(middleware.Flags(func(fs *flag.FlagSet) {
v = fs.Bool("v", false, "verbose output")
}))

r.RouteFunc("", func(ctx clir.Context) error {
if *v {
ctx.Println("Hello!")
}
return nil
})

_ = r.Run(clir.Context{
Args: []string{"-v"},
Out: os.Stdout,
})
// Output: Hello!
}
44 changes: 27 additions & 17 deletions router.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package clir

import (
"fmt"
"regexp"
"strings"
)
Expand All @@ -9,7 +10,6 @@ import (
type Router struct {
middlewares []Middleware
patterns []*regexp.Regexp
routers []*Router
runners map[string]Runner
}

Expand All @@ -21,28 +21,38 @@ func NewRouter() *Router {

// Run satisfies [Runner].
func (r *Router) Run(ctx Context) error {
// Apply middlewares first, because they can modify the context, including the Context.Args to match against.
var middlewareCtx Context
var runner Runner = RunnerFunc(func(ctx Context) error {
middlewareCtx = ctx
return nil
})
// Apply middlewares in reverse order, so the first middleware is the outermost one, to be called first.
for i := len(r.middlewares) - 1; i >= 0; i-- {
runner = r.middlewares[i](runner)
}
if err := runner.Run(ctx); err != nil {
return fmt.Errorf("error while applying middleware: %w", err)
}
ctx = middlewareCtx

for _, pattern := range r.patterns {
if (len(ctx.Args) == 0 && pattern.String() == "^$") || (len(ctx.Args) > 0 && pattern.MatchString(ctx.Args[0])) {

runner := r.runners[pattern.String()]
runner = r.runners[pattern.String()]
if len(ctx.Args) > 0 {
ctx.Matches = pattern.FindStringSubmatch(ctx.Args[0])
ctx.Args = ctx.Args[1:]
}

for i := len(r.middlewares) - 1; i >= 0; i-- {
runner = r.middlewares[i](runner)
}

return runner.Run(ctx)
}
}

for _, router := range r.routers {
if err := router.Run(ctx); err == nil {
return err
}
}
//for _, router := range r.routers {
// if err := router.Run(ctx); err == nil {
// return err
// }
//}

return ErrorRouteNotFound
}
Expand Down Expand Up @@ -78,13 +88,13 @@ func (r *Router) Branch(pattern string, cb func(r *Router)) {
}

// Scope into a new [Router].
// The middlewares from the parent router are copied to the new router,
// The middlewares from the parent router are used in the new router,
// but new middlewares within the scope are only added to the new router, not the parent router.
func (r *Router) Scope(cb func(r *Router)) {
newR := NewRouter()
newR.middlewares = append(newR.middlewares, r.middlewares...)
cb(newR)
r.routers = append(r.routers, newR)
panic("not implemented")
//newR := NewRouter()
//cb(newR)
//r.routers = append(r.routers, newR)
}

// Middleware for [Router.Use].
Expand Down
35 changes: 35 additions & 0 deletions router_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package clir_test

import (
"flag"
"strings"
"testing"

Expand Down Expand Up @@ -166,9 +167,43 @@ func TestRouter_Use(t *testing.T) {

r.Use(newMiddleware(t, "m1"))
})

t.Run("can use middleware that parses flags", func(t *testing.T) {
r := clir.NewRouter()

r.Use(func(next clir.Runner) clir.Runner {
return clir.RunnerFunc(func(ctx clir.Context) error {
fs := flag.NewFlagSet("test", flag.ContinueOnError)
v := fs.Bool("v", false, "")
err := fs.Parse(ctx.Args)
is.NotError(t, err)
is.True(t, *v)

t.Log(fs.Args())
ctx.Args = fs.Args()

return next.Run(ctx)
})
})

var called bool
r.RouteFunc("", func(ctx clir.Context) error {
called = true
return nil
})

err := r.Run(clir.Context{
Args: []string{"-v"},
})
is.NotError(t, err)
is.True(t, called)
})
}

//nolint:staticcheck
func TestRouter_Scope(t *testing.T) {
t.Skip("not implemented")

t.Run("can scope routes with a new middleware stack", func(t *testing.T) {
r := clir.NewRouter()

Expand Down

0 comments on commit 39523fc

Please sign in to comment.