Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add options for skipping access checks and register #11

Open
wants to merge 17 commits into
base: master
Choose a base branch
from
60 changes: 51 additions & 9 deletions throttle.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,19 @@ type Options struct {
// If the throttle is disabled or not
// defaults to false
Disabled bool

// If this function returns true, the request will not be counted towards the access count.
// You can set it to provide your own conditions for a request to be counted based on the request or the response,
// for example to exclude success responses from the count.
SkipRegisterFunction func(resp http.ResponseWriter, req *http.Request) bool

// If this function returns true, the request will not be checked for access, the policy will be ignored.
// You can set it to provide your own conditions for a request or a response to be allowed, for example to skip
// throttling on an IP allowlist.
// Note: You can't delay processing here with something like c.Next() until after the request, because that will
// make the access check to happen after executing the controller handler. Because of this, be aware that resp might
// not contain what you want yet.
SkipAccessCheckFunction func(resp http.ResponseWriter, req *http.Request) bool
}

// KeyValueStorer is the required interface for the Store Option
Expand Down Expand Up @@ -224,6 +237,14 @@ func (o *Options) Identify(req *http.Request) string {
return o.IdentificationFunction(req)
}

func (o *Options) SkipRegister(resp http.ResponseWriter, req *http.Request) bool {
return o.SkipRegisterFunction(resp, req)
}

func (o *Options) SkipAccessCheck(resp http.ResponseWriter, req *http.Request) bool {
return o.SkipAccessCheckFunction(resp, req)
}

// A throttling Policy
// Takes two arguments, one required:
// First is a Quota (A Limit with an associated time). When the given Limit
Expand All @@ -242,17 +263,28 @@ func Policy(quota *Quota, options ...*Options) func(resp http.ResponseWriter, re
return func(resp http.ResponseWriter, req *http.Request) {
id := makeKey(o.KeyPrefix, quota.KeyId(), o.Identify(req))

// Already set rate limit headers in case the SkipRegister method calls some delay method like c.Next() and we
// might not be able to set the headers again in that case, because the response has already been written.
setRateLimitHeaders(resp, controller, id)

if o.SkipAccessCheck(resp, req) {
return
}

if controller.DeniesAccess(id) {
msg := newAccessMessage(o.StatusCode, o.Message)
setRateLimitHeaders(resp, controller, id)
resp.WriteHeader(msg.StatusCode)
resp.Write([]byte(msg.Message))
return
} else {
}

if !o.SkipRegister(resp, req) {
controller.RegisterAccess(id)

// Set the headers again because the rate limit values have been changed at this point due to calling
// RegisterAccess.
setRateLimitHeaders(resp, controller, id)
}

}
}

Expand All @@ -279,6 +311,14 @@ func defaultIdentify(req *http.Request) string {
return ip
}

func defaultSkipRegister(http.ResponseWriter, *http.Request) bool {
return false
}

func defaultSkipAccess(http.ResponseWriter, *http.Request) bool {
return false
}

// Make a key from various parts for use in the key value store
func makeKey(parts ...string) string {
return strings.Join(parts, "_")
Expand All @@ -287,12 +327,14 @@ func makeKey(parts ...string) string {
// Creates new default options and assigns any given options
func newOptions(options []*Options) *Options {
o := Options{
StatusCode: defaultStatusCode,
Message: defaultMessage,
IdentificationFunction: defaultIdentify,
KeyPrefix: defaultKeyPrefix,
Store: nil,
Disabled: defaultDisabled,
StatusCode: defaultStatusCode,
Message: defaultMessage,
IdentificationFunction: defaultIdentify,
KeyPrefix: defaultKeyPrefix,
Store: nil,
Disabled: defaultDisabled,
SkipRegisterFunction: defaultSkipRegister,
SkipAccessCheckFunction: defaultSkipAccess,
}

// when all defaults, return it
Expand Down