-
Notifications
You must be signed in to change notification settings - Fork 0
/
jwt.go
116 lines (99 loc) · 2.59 KB
/
jwt.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
package jwt
import (
"errors"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/ssm"
"github.com/dgrijalva/jwt-go"
"github.com/lestrrat/go-jwx/jwk"
log "github.com/sirupsen/logrus"
"io/ioutil"
"os"
"strings"
"time"
)
// Package-level state shared by NewAuth's reload goroutine, loadConfiguration,
// getKey and StopReloadingJWKS.
//
// NOTE(review): both variables are written and read from different goroutines
// without any synchronization, which is a data race under `go test -race`.
// Consider a sync.RWMutex or sync/atomic in a follow-up that updates every
// accessor together.
var (
	// jwksSet is the most recently fetched JSON Web Key Set. It is nil until a
	// fetch succeeds and is reset to nil by StopReloadingJWKS.
	jwksSet *jwk.Set

	// stop is set to true by StopReloadingJWKS to ask the reload loop to exit.
	stop bool
)
// NewAuth loads the JSON Web Key Set used by DecodeToken. It calls
// jwksFetcher(path) once synchronously, then starts one background goroutine
// that calls it again every sleepDuration until StopReloadingJWKS sets stop.
//
// A failed fetch is logged by loadConfiguration and keeps any previously
// loaded key set. If the first fetch fails, the background goroutine simply
// retries on the next cycle. NewAuth returns as soon as the initial fetch has
// been attempted. It never runs the reload loop on the caller's goroutine.
//
// Typical fetchers are FetchJwksConfigurationFromFS and
// FetchJwksConfigurationFromSSM.
//
// NOTE(review): stop and jwksSet are plain package variables shared with the
// reload goroutine without synchronization; see the -race note on their
// declaration.
func NewAuth(jwksFetcher func(string) (*jwk.Set, error), path string, sleepDuration time.Duration) {
	loadConfiguration(jwksFetcher, path, sleepDuration)

	go func() {
		for {
			time.Sleep(sleepDuration)
			// Checked on every cycle (including while still retrying a failed
			// initial fetch) so StopReloadingJWKS always ends this goroutine.
			if stop {
				return
			}
			loadConfiguration(jwksFetcher, path, sleepDuration)
		}
	}()
}
// StopReloadingJWKS asks the reload loop started by NewAuth to exit by setting
// stop, and clears the loaded key set by resetting jwksSet to nil.
//
// NOTE(review): a fetch already in progress when this runs can still assign
// jwksSet afterwards, because loadConfiguration stores every successful result.
func StopReloadingJWKS() {
	stop, jwksSet = true, nil
}
// loadConfiguration fetches a key set with jwksFetcher(path).
//
// On failure the error is logged and jwksSet is left unchanged, so any
// previously loaded keys stay in use. On success it logs
// "Reinitialized jwt-auth" and replaces jwksSet with the fetched set.
//
// sleepDuration is currently unused by this function.
func loadConfiguration(jwksFetcher func(string) (*jwk.Set, error), path string, sleepDuration time.Duration) {
	fetched, fetchErr := jwksFetcher(path)
	if fetchErr != nil {
		log.Error(fetchErr)
		return
	}
	log.Info("Reinitialized jwt-auth")
	jwksSet = fetched
}
// FetchJwksConfigurationFromSSM reads the SSM parameter named ssmPath, with
// decryption enabled, in the eu-central-1 region, and parses its value as a
// JSON Web Key Set. Its signature matches the jwksFetcher argument of NewAuth.
//
// A new AWS session is created on every call, with session.SharedConfigEnable
// set. Session and GetParameter errors are returned unchanged so callers can
// still type-assert AWS SDK error types.
func FetchJwksConfigurationFromSSM(ssmPath string) (*jwk.Set, error) {
	sess, err := session.NewSessionWithOptions(session.Options{
		Config:            aws.Config{Region: aws.String("eu-central-1")},
		SharedConfigState: session.SharedConfigEnable,
	})
	if err != nil {
		return nil, err
	}

	param, err := ssm.New(sess).GetParameter(&ssm.GetParameterInput{
		Name:           aws.String(ssmPath),
		WithDecryption: aws.Bool(true),
	})
	if err != nil {
		return nil, err
	}
	// SDK output fields are pointers; report a missing value as an error
	// instead of panicking on a nil dereference.
	if param == nil || param.Parameter == nil || param.Parameter.Value == nil {
		return nil, errors.New("ssm parameter has no value: " + ssmPath)
	}
	return jwk.Parse([]byte(*param.Parameter.Value))
}
// FetchJwksConfigurationFromFS reads the file at jwksURL and parses its
// contents as a JSON Web Key Set. Despite its name, jwksURL is used as a local
// filesystem path passed to os.Open. The signature matches the jwksFetcher
// argument of NewAuth.
//
// Open and read errors are returned rather than logged here, so the caller
// (e.g. loadConfiguration) reports the real cause exactly once instead of a
// misleading JWK parse error on empty input.
func FetchJwksConfigurationFromFS(jwksURL string) (*jwk.Set, error) {
	jsonFile, err := os.Open(jwksURL)
	if err != nil {
		return nil, err
	}
	// Read-only file: a Close error carries no data-loss risk.
	defer jsonFile.Close()

	byteValue, err := ioutil.ReadAll(jsonFile)
	if err != nil {
		return nil, err
	}
	return jwk.Parse(byteValue)
}
// getKey is the jwt-go key function used by DecodeToken. It reads the token's
// "kid" header, looks it up in the currently loaded jwksSet, and returns the
// materialized key when exactly one key matches.
//
// It returns an error, rather than panicking, in these cases:
//   - the header has no string "kid";
//   - no key set is loaded (before the first successful fetch, or after
//     StopReloadingJWKS);
//   - the kid does not match exactly one key.
func getKey(token *jwt.Token) (interface{}, error) {
	keyID, ok := token.Header["kid"].(string)
	if !ok {
		return nil, errors.New("expecting JWT header to have string kid")
	}

	// Read the shared pointer once so a concurrent StopReloadingJWKS cannot
	// nil it between the check and the lookup.
	set := jwksSet
	if set == nil {
		return nil, errors.New("no JWKS loaded")
	}

	if keys := set.LookupKeyID(keyID); len(keys) == 1 {
		return keys[0].Materialize()
	}
	return nil, errors.New("unable to find key")
}
// DecodeToken parses an Authorization header value of the form
// "Bearer <token>". It verifies the token's signature with the JWKS key whose
// kid matches the token header (see getKey) and decodes its claims into claims.
//
// The scheme must appear at the start of the header and is matched
// case-insensitively, because HTTP auth-schemes are case-insensitive
// (RFC 7235 §2.1). Whitespace around the header and the token is ignored.
// A missing scheme or empty token yields an error. Any jwt-go parse or
// verification error is logged and returned.
func DecodeToken(bearerToken string, claims jwt.Claims) (*jwt.Token, error) {
	const prefix = "Bearer "
	errNoToken := errors.New("error getting token from authorization header")

	header := strings.TrimSpace(bearerToken)
	if len(header) < len(prefix) || !strings.EqualFold(header[:len(prefix)], prefix) {
		return nil, errNoToken
	}
	tokenString := strings.TrimSpace(header[len(prefix):])
	if tokenString == "" {
		return nil, errNoToken
	}

	token, err := jwt.ParseWithClaims(tokenString, claims, getKey)
	if err != nil {
		log.Error("Error decoding token: ", err)
		return nil, err
	}
	return token, nil
}