Skip to content

Commit

Permalink
Add Validating Directives
Browse files Browse the repository at this point in the history
Adding support for directives which are evaluated prior to executing
any resolvers. This allows validation to be performed on the request
and prevent it from executing any significant work by rejecting the
request early.

The most obvious case for this is authorization: based on the requested
fields, we can tell whether the request is valid given the current
user, and reject the entire request. If that were applied at resolution
time, the request would have partially resolved, only to return errors
for the specific fields which are not authorized.
  • Loading branch information
dackroyd committed Feb 18, 2023
1 parent 8d0aad8 commit d3994f5
Show file tree
Hide file tree
Showing 9 changed files with 383 additions and 32 deletions.
5 changes: 5 additions & 0 deletions directives/visitor.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,3 +23,8 @@ type Resolver interface {
type ResolverInterceptor interface {
Resolve(ctx context.Context, args interface{}, next Resolver) (output interface{}, err error)
}

// Validator directive which executes before anything is resolved, allowing the request to be rejected.
type Validator interface {
Validate(ctx context.Context, args interface{}) error
}
8 changes: 4 additions & 4 deletions example/directives/authorization/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -80,16 +80,16 @@ $ curl 'http://localhost:8080/query' \
return "hasRole"
}

func (h *HasRoleDirective) Resolve(ctx context.Context, args interface{}, next directives.Resolver) (output interface{}, err error) {
func (h *HasRoleDirective) Validate(ctx context.Context, _ interface{}) error {
u, ok := user.FromContext(ctx)
if !ok {
return nil, fmt.Errorf("user not provided in context")
return fmt.Errorf("user not provided in context")
}
role := strings.ToLower(h.Role)
if !u.HasRole(role) {
return nil, fmt.Errorf("access denied, %q role required", role)
return fmt.Errorf("access denied, %q role required", role)
}
return next.Resolve(ctx, args)
return nil
}
```

Expand Down
9 changes: 4 additions & 5 deletions example/directives/authorization/authorization.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ import (
"fmt"
"strings"

"github.com/graph-gophers/graphql-go/directives"
"github.com/graph-gophers/graphql-go/example/directives/authorization/user"
)

Expand Down Expand Up @@ -36,17 +35,17 @@ func (h *HasRoleDirective) ImplementsDirective() string {
return "hasRole"
}

func (h *HasRoleDirective) Resolve(ctx context.Context, args interface{}, next directives.Resolver) (output interface{}, err error) {
func (h *HasRoleDirective) Validate(ctx context.Context, _ interface{}) error {
u, ok := user.FromContext(ctx)
if !ok {
return nil, fmt.Errorf("user not provided in cotext")
return fmt.Errorf("user not provided in cotext")
}
role := strings.ToLower(h.Role)
if !u.HasRole(role) {
return nil, fmt.Errorf("access denied, %q role required", role)
return fmt.Errorf("access denied, %q role required", role)
}

return next.Resolve(ctx, args)
return nil
}

type Resolver struct{}
Expand Down
42 changes: 35 additions & 7 deletions example_directives_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"encoding/json"
"fmt"
"os"
"strings"

"github.com/graph-gophers/graphql-go"
"github.com/graph-gophers/graphql-go/directives"
Expand All @@ -22,11 +23,31 @@ func (h *HasRoleDirective) ImplementsDirective() string {
return "hasRole"
}

func (h *HasRoleDirective) Resolve(ctx context.Context, in interface{}, next directives.Resolver) (interface{}, error) {
func (h *HasRoleDirective) Validate(ctx context.Context, _ interface{}) error {
if ctx.Value(RoleKey) != h.Role {
return nil, fmt.Errorf("access deinied, role %q required", h.Role)
return fmt.Errorf("access denied, role %q required", h.Role)
}
return next.Resolve(ctx, in)
return nil
}

type UpperDirective struct{}

func (d *UpperDirective) ImplementsDirective() string {
return "upper"
}

func (d *UpperDirective) Resolve(ctx context.Context, args interface{}, next directives.Resolver) (interface{}, error) {
out, err := next.Resolve(ctx, args)
if err != nil {
return out, err
}

s, ok := out.(string)
if !ok {
return out, nil
}

return strings.ToUpper(s), nil
}

type authResolver struct{}
Expand All @@ -43,13 +64,14 @@ func ExampleDirectives() {
}
directive @hasRole(role: String!) on FIELD_DEFINITION
directive @upper on FIELD_DEFINITION
type Query {
greet(name: String!): String! @hasRole(role: "admin")
greet(name: String!): String! @hasRole(role: "admin") @upper
}
`
opts := []graphql.SchemaOpt{
graphql.Directives(&HasRoleDirective{}),
graphql.Directives(&HasRoleDirective{}, &UpperDirective{}),
// other options go here
}
schema := graphql.MustParseSchema(s, &authResolver{}, opts...)
Expand Down Expand Up @@ -86,7 +108,13 @@ func ExampleDirectives() {
// {
// "errors": [
// {
// "message": "access deinied, role \"admin\" required",
// "message": "access denied, role \"admin\" required",
// "locations": [
// {
// "line": 10,
// "column": 4
// }
// ],
// "path": [
// "greet"
// ]
Expand All @@ -97,7 +125,7 @@ func ExampleDirectives() {
// Admin user result:
// {
// "data": {
// "greet": "Hello, GraphQL!"
// "greet": "HELLO, GRAPHQL!"
// }
// }
}
163 changes: 163 additions & 0 deletions graphql_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"encoding/json"
"errors"
"fmt"
"strings"
"sync"
"testing"
"time"
Expand Down Expand Up @@ -456,6 +457,168 @@ func TestCustomDirective(t *testing.T) {
})
}

func TestCustomValidatingDirective(t *testing.T) {
t.Parallel()

gqltesting.RunTests(t, []*gqltesting.Test{
{
Schema: graphql.MustParseSchema(`
directive @hasRole(role: String!) on FIELD_DEFINITION
schema {
query: Query
}
type Query {
hello: String! @hasRole(role: "ADMIN")
}`,
&helloResolver{},
graphql.Directives(&HasRoleDirective{}),
),
Context: context.WithValue(context.Background(), RoleKey, "USER"),
Query: `
{
hello
}
`,
ExpectedResult: "null",
ExpectedErrors: []*gqlerrors.QueryError{
{Message: `access denied, role "ADMIN" required`, Locations: []gqlerrors.Location{{Line: 9, Column: 6}}, Path: []interface{}{"hello"}},
},
},
{
Schema: graphql.MustParseSchema(`
directive @hasRole(role: String!) on FIELD_DEFINITION
schema {
query: Query
}
type Query {
hello: String! @hasRole(role: "ADMIN")
}`,
&helloResolver{},
graphql.Directives(&HasRoleDirective{}),
),
Context: context.WithValue(context.Background(), RoleKey, "ADMIN"),
Query: `
{
hello
}
`,
ExpectedResult: `
{
"hello": "Hello world!"
}
`,
},
{
Schema: graphql.MustParseSchema(
`directive @hasRole(role: String!) on FIELD_DEFINITION
`+strings.ReplaceAll(
starwars.Schema,
"hero(episode: Episode = NEWHOPE): Character",
`hero(episode: Episode = NEWHOPE): Character @hasRole(role: "REBELLION")`,
),
&starwars.Resolver{},
graphql.Directives(&HasRoleDirective{}),
),
Context: context.WithValue(context.Background(), RoleKey, "EMPIRE"),
Query: `
query HeroesOfTheRebellion($episode: Episode!) {
hero(episode: $episode) {
id name
... on Human { starships { id name } }
... on Droid { primaryFunction }
}
}
`,
Variables: map[string]interface{}{"episode": "NEWHOPE"},
ExpectedResult: "null",
ExpectedErrors: []*gqlerrors.QueryError{
{Message: `access denied, role "REBELLION" required`, Locations: []gqlerrors.Location{{Line: 10, Column: 3}}, Path: []interface{}{"hero"}},
},
},
{
Schema: graphql.MustParseSchema(
`directive @hasRole(role: String!) on FIELD_DEFINITION
`+strings.ReplaceAll(
starwars.Schema,
"starships: [Starship]",
`starships: [Starship] @hasRole(role: "REBELLION")`,
),
&starwars.Resolver{},
graphql.Directives(&HasRoleDirective{}),
),
Context: context.WithValue(context.Background(), RoleKey, "EMPIRE"),
Query: `
query HeroesOfTheRebellion($episode: Episode!) {
hero(episode: $episode) {
id name
... on Human { starships { id name } }
... on Droid { primaryFunction }
}
}
`,
Variables: map[string]interface{}{"episode": "NEWHOPE"},
ExpectedResult: "null",
ExpectedErrors: []*gqlerrors.QueryError{
{Message: `access denied, role "REBELLION" required`, Locations: []gqlerrors.Location{{Line: 68, Column: 3}}, Path: []interface{}{"hero", "starships"}},
},
},
{
Schema: graphql.MustParseSchema(
`directive @restrictImperialUnits on FIELD_DEFINITION
`+strings.ReplaceAll(
starwars.Schema,
"height(unit: LengthUnit = METER): Float!",
`height(unit: LengthUnit = METER): Float! @restrictImperialUnits`,
),
&starwars.Resolver{},
graphql.Directives(&restrictImperialUnitsDirective{}),
),
Context: context.WithValue(context.Background(), RoleKey, "REBELLION"),
Query: `
query HeroesOfTheRebellion($episode: Episode!) {
hero(episode: $episode) {
id name
... on Human { height(unit: FOOT) }
}
}
`,
Variables: map[string]interface{}{"episode": "NEWHOPE"},
ExpectedResult: "null",
ExpectedErrors: []*gqlerrors.QueryError{
{Message: `rebel scum cannot request imperial units`, Locations: []gqlerrors.Location{{Line: 58, Column: 3}}, Path: []interface{}{"hero", "height"}},
},
},
})
}

type restrictImperialUnitsDirective struct{}

func (d *restrictImperialUnitsDirective) ImplementsDirective() string {
return "restrictImperialUnits"
}

func (d *restrictImperialUnitsDirective) Validate(ctx context.Context, args interface{}) error {
if ctx.Value(RoleKey) == "EMPIRE" {
return nil
}

v, ok := args.(struct {
Unit string
})
if ok && v.Unit == "FOOT" {
return fmt.Errorf("rebel scum cannot request imperial units")
}

return nil
}

func TestCustomDirectiveStructFieldResolver(t *testing.T) {
t.Parallel()

Expand Down
12 changes: 12 additions & 0 deletions internal/exec/exec.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,13 @@ func (r *Request) Execute(ctx context.Context, s *resolvable.Schema, op *types.O
default:
panic("unknown query operation")
}

if errs := validateSelections(ctx, sels, nil, s); errs != nil {
r.Errs = errs
out.Write([]byte("null"))
return
}

r.execSelections(ctx, sels, nil, s, resolver, &out, op.Type == query.Mutation)
}()

Expand All @@ -64,6 +71,11 @@ func (r *Request) Execute(ctx context.Context, s *resolvable.Schema, op *types.O
return out.Bytes(), r.Errs
}

type fieldToValidate struct {
field *selected.SchemaField
sels []selected.Selection
}

type fieldToExec struct {
field *selected.SchemaField
sels []selected.Selection
Expand Down
Loading

0 comments on commit d3994f5

Please sign in to comment.