From cf3b3ee807310d123bbe7904e58a004e2868ab4d Mon Sep 17 00:00:00 2001 From: Tohru <65994850+Tohrusky@users.noreply.github.com> Date: Fri, 12 Jul 2024 21:41:39 +0800 Subject: [PATCH] feat: jwt blacklist (#3) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * 黑名单 --- go.sum | 2 - internal/middleware/cache/jwt_blacklist.go | 47 ++++++++++++++++++++++ internal/middleware/jwt/auth.go | 12 ------ internal/router/api/v1/api.go | 12 +++++- internal/service/user/logout.go | 20 +++++++++ module/cache/cache.go | 18 ++++++++- module/cache/redis.go | 24 +++++------ 7 files changed, 106 insertions(+), 29 deletions(-) create mode 100644 internal/middleware/cache/jwt_blacklist.go create mode 100644 internal/service/user/logout.go diff --git a/go.sum b/go.sum index 9180ab4..b433df7 100644 --- a/go.sum +++ b/go.sum @@ -161,8 +161,6 @@ github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINE github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U= github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= -github.com/redis/go-redis/v9 v9.5.3 h1:fOAp1/uJG+ZtcITgZOfYFmTKPE7n4Vclj1wZFgRciUU= -github.com/redis/go-redis/v9 v9.5.3/go.mod h1:hdY0cQFCN4fnSYT6TkisLufl/4W5UIXyv0b/CLO2V2M= github.com/redis/go-redis/v9 v9.5.4 h1:vOFYDKKVgrI5u++QvnMT7DksSMYg7Aw/Np4vLJLKLwY= github.com/redis/go-redis/v9 v9.5.4/go.mod h1:hdY0cQFCN4fnSYT6TkisLufl/4W5UIXyv0b/CLO2V2M= github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc= diff --git a/internal/middleware/cache/jwt_blacklist.go b/internal/middleware/cache/jwt_blacklist.go new file mode 100644 index 0000000..abf82ed --- /dev/null +++ b/internal/middleware/cache/jwt_blacklist.go @@ -0,0 +1,47 @@ +package cache + +import ( + "github.com/TensoRaws/NuxBT-Backend/internal/middleware/jwt" + "github.com/TensoRaws/NuxBT-Backend/module/cache" + "github.com/TensoRaws/NuxBT-Backend/module/log" + "github.com/TensoRaws/NuxBT-Backend/module/util" + "github.com/gin-gonic/gin" +) + +// JWTBlacklist 检查JWT是否在黑名单中 +func JWTBlacklist(redisClient *cache.Client, enableBlacklist bool) gin.HandlerFunc { + return func(c *gin.Context) { + // 从输入的 url 中查询 token 值 + token := c.Query("token") + if len(token) == 0 { + // 从输入的表单中查询 token 值 + token = c.PostForm("token") + } + + if len(token) == 0 { + util.AbortWithMsg(c, "JSON WEB TOKEN IS NULL") + return + } + + log.Logger.Info("Get token successfully") + + // 检查 Token 是否存在于 Redis 黑名单中 + exists := redisClient.Exists(token).Val() + if exists > 0 { + log.Logger.Info("Token has been blacklisted") + util.AbortWithMsg(c, "Token has been blacklisted") + return + } + + // 如果 Token 不在黑名单中,继续处理请求 + c.Next() + + // 如果启用拉黑模式,处理请求拉黑 Token + if enableBlacklist { + err := redisClient.Set(token, "", jwt.GetJWTTokenExpiredDuration()).Err() + if err != nil { + log.Logger.Error("Error adding token to blacklist: " + err.Error()) + } + } + } +} diff --git a/internal/middleware/jwt/auth.go b/internal/middleware/jwt/auth.go index d3a1ca5..1f08894 100644 --- a/internal/middleware/jwt/auth.go +++ b/internal/middleware/jwt/auth.go @@ -1,7 +1,6 @@ package jwt import ( - "github.com/TensoRaws/NuxBT-Backend/module/log" "github.com/TensoRaws/NuxBT-Backend/module/util" "github.com/gin-gonic/gin" ) @@ -12,17 +11,6 @@ func RequireAuth() gin.HandlerFunc { return func(c *gin.Context) { // 从输入的 url 中查询 token 值 token := c.Query("token") - if len(token) == 0 { - // 从输入的表单中查询 token 值 - token = c.PostForm("token") - } - - if len(token) == 0 { - util.AbortWithMsg(c, "JSON WEB TOKEN IS NULL") - return - } - - log.Logger.Info("Get token successfully") // auth = [[header][cliams][signature]] // 解析 token claims, err := ParseToken(token) diff --git a/internal/router/api/v1/api.go b/internal/router/api/v1/api.go index b343b7f..eb2143f 100644 --- a/internal/router/api/v1/api.go +++ b/internal/router/api/v1/api.go @@ -33,8 +33,18 @@ func NewAPI() *gin.Engine { user.POST("register", user_service.Register) // 用户登录 user.POST("login", user_service.Login) + // 用户登出 + user.POST("logout", + middleware_cache.JWTBlacklist(cache.Clients[cache.JWTBlacklist], true), + jwt.RequireAuth(), + user_service.Logout, + ) // 用户信息 - user.GET("profile/me", jwt.RequireAuth(), user_service.ProfileMe) + user.GET("profile/me", + middleware_cache.JWTBlacklist(cache.Clients[cache.JWTBlacklist], false), + jwt.RequireAuth(), + user_service.ProfileMe, + ) } } diff --git a/internal/service/user/logout.go b/internal/service/user/logout.go new file mode 100644 index 0000000..8256c8a --- /dev/null +++ b/internal/service/user/logout.go @@ -0,0 +1,20 @@ +package user + +import ( + "github.com/TensoRaws/NuxBT-Backend/module/log" + "github.com/TensoRaws/NuxBT-Backend/module/util" + "github.com/gin-gonic/gin" +) + +// Logout 用户登出 (POST /logout) +func Logout(c *gin.Context) { + user, err := util.GetUserIDFromGinContext(c) + if err != nil { + util.AbortWithMsg(c, "Please login first") + return + } + + util.OKWithMsg(c, "Logout success") + + log.Logger.Info("Logout success: " + util.StructToString(user)) +} diff --git a/module/cache/cache.go b/module/cache/cache.go index 83cea8f..5fa180e 100644 --- a/module/cache/cache.go +++ b/module/cache/cache.go @@ -2,16 +2,30 @@ package cache import ( "bytes" + "context" "sync" "github.com/gin-gonic/gin" + "github.com/redis/go-redis/v9" ) var once sync.Once +type RDB uint8 + +const ( + IPLimit RDB = iota + JWTBlacklist +) + var Clients = map[RDB]*Client{ - IPLimit: {}, - User: {}, + IPLimit: {}, + JWTBlacklist: {}, +} + +type Client struct { + C *redis.Client + Ctx context.Context } type responseWriter struct { diff --git a/module/cache/redis.go b/module/cache/redis.go index ebe0a5a..28fa4bf 100644 --- a/module/cache/redis.go +++ b/module/cache/redis.go @@ -9,18 +9,6 @@ import ( "github.com/redis/go-redis/v9" ) -type RDB uint8 - -const ( - IPLimit RDB = iota - User -) - -type Client struct { - C *redis.Client - Ctx context.Context -} - func NewRedisClients(clients map[RDB]*Client) { for k := range clients { r := redis.NewClient(&redis.Options{ @@ -112,3 +100,15 @@ func (c Client) ZRange(key string, start, stop int64) *redis.StringSliceCmd { func (c Client) ZAddNX(key string, members ...redis.Z) *redis.IntCmd { return c.C.ZAddNX(c.Ctx, key, members...) } + +func (c Client) SIsMember(key string, member interface{}) *redis.BoolCmd { + return c.C.SIsMember(c.Ctx, key, member) +} + +func (c Client) SAdd(key string, members ...interface{}) *redis.IntCmd { + return c.C.SAdd(c.Ctx, key, members...) +} + +func (c Client) Set(key string, value interface{}, expiration time.Duration) *redis.StatusCmd { + return c.C.Set(c.Ctx, key, value, expiration) +}