diff --git a/README.md b/README.md index d531dbf..219be21 100644 --- a/README.md +++ b/README.md @@ -28,7 +28,10 @@ func main() { })) // use redis cluster store - // redis.NewRedisClusterStore() + // manager.MapTokenStorage(redis.NewRedisClusterStore(&redis.ClusterOptions{ + // Addrs: []string{"127.0.0.1:6379"}, + // DB: 15, + // })) } ``` diff --git a/redis.go b/redis.go index 14e5984..6ae72bf 100644 --- a/redis.go +++ b/redis.go @@ -5,7 +5,7 @@ import ( "time" "github.com/go-redis/redis" - "github.com/json-iterator/go" + jsoniter "github.com/json-iterator/go" "gopkg.in/oauth2.v3" "gopkg.in/oauth2.v3/models" "gopkg.in/oauth2.v3/utils/uuid" @@ -59,6 +59,7 @@ func NewRedisClusterStoreWithCli(cli *redis.ClusterClient, keyNamespace ...strin type clienter interface { Get(key string) *redis.StringCmd + Exists(key ...string) *redis.IntCmd TxPipeline() redis.Pipeliner Del(keys ...string) *redis.IntCmd Close() error @@ -79,12 +80,104 @@ func (s *TokenStore) wrapperKey(key string) string { return fmt.Sprintf("%s%s", s.ns, key) } +func (s *TokenStore) checkError(result redis.Cmder) (bool, error) { + if err := result.Err(); err != nil { + if err == redis.Nil { + return true, nil + } + return false, err + } + return false, nil +} + +// remove +func (s *TokenStore) remove(key string) error { + result := s.cli.Del(s.wrapperKey(key)) + _, err := s.checkError(result) + return err +} + +func (s *TokenStore) removeToken(tokenString string, isRefresh bool) error { + basicID, err := s.getBasicID(tokenString) + if err != nil { + return err + } else if basicID == "" { + return nil + } + + err = s.remove(tokenString) + if err != nil { + return err + } + + token, err := s.getToken(basicID) + if err != nil { + return err + } else if token == nil { + return nil + } + + checkToken := token.GetRefresh() + if isRefresh { + checkToken = token.GetAccess() + } + iresult := s.cli.Exists(s.wrapperKey(checkToken)) + if err := iresult.Err(); err != nil && err != redis.Nil { + return err + } else if iresult.Val() == 0 { + return s.remove(basicID) + } + + return nil +} + +func (s *TokenStore) parseToken(result *redis.StringCmd) (oauth2.TokenInfo, error) { + if ok, err := s.checkError(result); err != nil { + return nil, err + } else if ok { + return nil, nil + } + + buf, err := result.Bytes() + if err != nil { + if err == redis.Nil { + return nil, nil + } + return nil, err + } + + var token models.Token + if err := jsonUnmarshal(buf, &token); err != nil { + return nil, err + } + return &token, nil +} + +func (s *TokenStore) getToken(key string) (oauth2.TokenInfo, error) { + result := s.cli.Get(s.wrapperKey(key)) + return s.parseToken(result) +} + +func (s *TokenStore) parseBasicID(result *redis.StringCmd) (string, error) { + if ok, err := s.checkError(result); err != nil { + return "", err + } else if ok { + return "", nil + } + return result.Val(), nil +} + +func (s *TokenStore) getBasicID(token string) (string, error) { + result := s.cli.Get(s.wrapperKey(token)) + return s.parseBasicID(result) +} + // Create Create and store the new token information -func (s *TokenStore) Create(info oauth2.TokenInfo) (err error) { +func (s *TokenStore) Create(info oauth2.TokenInfo) error { ct := time.Now() jv, err := jsonMarshal(info) if err != nil { - return + return err } pipe := s.cli.TxPipeline() @@ -107,99 +200,46 @@ func (s *TokenStore) Create(info oauth2.TokenInfo) (err error) { pipe.Set(s.wrapperKey(basicID), jv, rexp) } - if _, verr := pipe.Exec(); verr != nil { - err = verr + if _, err := pipe.Exec(); err != nil { + return err } - return -} - -// remove -func (s *TokenStore) remove(key string) (err error) { - _, verr := s.cli.Del(s.wrapperKey(key)).Result() - if verr != redis.Nil { - err = verr - } - return + return nil } // RemoveByCode Use the authorization code to delete the token information -func (s *TokenStore) RemoveByCode(code string) (err error) { - err = s.remove(code) - return +func (s *TokenStore) RemoveByCode(code string) error { + return s.remove(code) } // RemoveByAccess Use the access token to delete the token information -func (s *TokenStore) RemoveByAccess(access string) (err error) { - err = s.remove(access) - return +func (s *TokenStore) RemoveByAccess(access string) error { + return s.removeToken(access, false) } // RemoveByRefresh Use the refresh token to delete the token information -func (s *TokenStore) RemoveByRefresh(refresh string) (err error) { - err = s.remove(refresh) - return -} - -func (s *TokenStore) getData(key string) (ti oauth2.TokenInfo, err error) { - result := s.cli.Get(s.wrapperKey(key)) - if verr := result.Err(); verr != nil { - if verr == redis.Nil { - return - } - err = verr - return - } - - iv, err := result.Bytes() - if err != nil { - return - } - - var tm models.Token - if verr := jsonUnmarshal(iv, &tm); verr != nil { - err = verr - return - } - - ti = &tm - return -} - -func (s *TokenStore) getBasicID(token string) (basicID string, err error) { - tv, verr := s.cli.Get(s.wrapperKey(token)).Result() - if verr != nil { - if verr == redis.Nil { - return - } - err = verr - return - } - basicID = tv - return +func (s *TokenStore) RemoveByRefresh(refresh string) error { + return s.removeToken(refresh, false) } // GetByCode Use the authorization code for token information data -func (s *TokenStore) GetByCode(code string) (ti oauth2.TokenInfo, err error) { - ti, err = s.getData(code) - return +func (s *TokenStore) GetByCode(code string) (oauth2.TokenInfo, error) { + return s.getToken(code) } // GetByAccess Use the access token for token information data -func (s *TokenStore) GetByAccess(access string) (ti oauth2.TokenInfo, err error) { +func (s *TokenStore) GetByAccess(access string) (oauth2.TokenInfo, error) { basicID, err := s.getBasicID(access) if err != nil || basicID == "" { - return + return nil, err } - ti, err = s.getData(basicID) - return + return s.getToken(basicID) } // GetByRefresh Use the refresh token for token information data -func (s *TokenStore) GetByRefresh(refresh string) (ti oauth2.TokenInfo, err error) { +func (s *TokenStore) GetByRefresh(refresh string) (oauth2.TokenInfo, error) { basicID, err := s.getBasicID(refresh) if err != nil || basicID == "" { - return + return nil, err } - ti, err = s.getData(basicID) - return + return s.getToken(basicID) }