Skip to content

Commit

Permalink
黑名单
Browse files Browse the repository at this point in the history
  • Loading branch information
Tohrusky committed Jul 12, 2024
1 parent 94b719c commit cb62d5b
Show file tree
Hide file tree
Showing 6 changed files with 106 additions and 27 deletions.
47 changes: 47 additions & 0 deletions internal/middleware/cache/jwt_blacklist.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
package cache

import (
"github.com/TensoRaws/NuxBT-Backend/internal/middleware/jwt"
"github.com/TensoRaws/NuxBT-Backend/module/cache"
"github.com/TensoRaws/NuxBT-Backend/module/log"
"github.com/TensoRaws/NuxBT-Backend/module/util"
"github.com/gin-gonic/gin"
)

// JWTBlacklist 检查JWT是否在黑名单中
func JWTBlacklist(redisClient *cache.Client, enableBlacklist bool) gin.HandlerFunc {
return func(c *gin.Context) {
// 从输入的 url 中查询 token 值
token := c.Query("token")
if len(token) == 0 {
// 从输入的表单中查询 token 值
token = c.PostForm("token")
}

if len(token) == 0 {
util.AbortWithMsg(c, "JSON WEB TOKEN IS NULL")
return
}

log.Logger.Info("Get token successfully")

// 检查 Token 是否存在于 Redis 黑名单中
exists := redisClient.Exists(token).Val()
if exists > 0 {
log.Logger.Info("Token has been blacklisted")
util.AbortWithMsg(c, "Token has been blacklisted")
return
}

// 如果 Token 不在黑名单中,继续处理请求
c.Next()

// 如果启用拉黑模式,处理请求拉黑 Token
if enableBlacklist {
err := redisClient.Set(token, "", jwt.GetJWTTokenExpiredDuration()).Err()
if err != nil {
log.Logger.Error("Error adding token to blacklist: " + err.Error())
}
}
}
}
12 changes: 0 additions & 12 deletions internal/middleware/jwt/auth.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package jwt

import (
"github.com/TensoRaws/NuxBT-Backend/module/log"
"github.com/TensoRaws/NuxBT-Backend/module/util"
"github.com/gin-gonic/gin"
)
Expand All @@ -12,17 +11,6 @@ func RequireAuth() gin.HandlerFunc {
return func(c *gin.Context) {
// 从输入的 url 中查询 token 值
token := c.Query("token")
if len(token) == 0 {
// 从输入的表单中查询 token 值
token = c.PostForm("token")
}

if len(token) == 0 {
util.AbortWithMsg(c, "JSON WEB TOKEN IS NULL")
return
}

log.Logger.Info("Get token successfully")
// auth = [[header][cliams][signature]]
// 解析 token
claims, err := ParseToken(token)
Expand Down
12 changes: 11 additions & 1 deletion internal/router/api/v1/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,18 @@ func NewAPI() *gin.Engine {
user.POST("register", user_service.Register)
// 用户登录
user.POST("login", user_service.Login)
// 用户登出
user.POST("logout",
middleware_cache.JWTBlacklist(cache.Clients[cache.JWTBlacklist], true),
jwt.RequireAuth(),
user_service.Logout,
)
// 用户信息
user.GET("profile/me", jwt.RequireAuth(), user_service.ProfileMe)
user.GET("profile/me",
middleware_cache.JWTBlacklist(cache.Clients[cache.JWTBlacklist], false),
jwt.RequireAuth(),
user_service.ProfileMe,
)
}
}

Expand Down
20 changes: 20 additions & 0 deletions internal/service/user/logout.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
package user

import (
"github.com/TensoRaws/NuxBT-Backend/module/log"
"github.com/TensoRaws/NuxBT-Backend/module/util"
"github.com/gin-gonic/gin"
)

// Logout 用户登出 (POST /logout)
func Logout(c *gin.Context) {
user, err := util.GetUserIDFromGinContext(c)
if err != nil {
util.AbortWithMsg(c, "Please login first")
return
}

util.OKWithMsg(c, "Logout success")

log.Logger.Info("Logout success: " + util.StructToString(user))
}
18 changes: 16 additions & 2 deletions module/cache/cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,30 @@ package cache

import (
"bytes"
"context"
"sync"

"github.com/gin-gonic/gin"
"github.com/redis/go-redis/v9"
)

var once sync.Once

type RDB uint8

const (
IPLimit RDB = iota
JWTBlacklist
)

var Clients = map[RDB]*Client{
IPLimit: {},
User: {},
IPLimit: {},
JWTBlacklist: {},
}

type Client struct {
C *redis.Client
Ctx context.Context
}

type responseWriter struct {
Expand Down
24 changes: 12 additions & 12 deletions module/cache/redis.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,18 +9,6 @@ import (
"github.com/redis/go-redis/v9"
)

type RDB uint8

const (
IPLimit RDB = iota
User
)

type Client struct {
C *redis.Client
Ctx context.Context
}

func NewRedisClients(clients map[RDB]*Client) {
for k := range clients {
r := redis.NewClient(&redis.Options{
Expand Down Expand Up @@ -112,3 +100,15 @@ func (c Client) ZRange(key string, start, stop int64) *redis.StringSliceCmd {
func (c Client) ZAddNX(key string, members ...redis.Z) *redis.IntCmd {
return c.C.ZAddNX(c.Ctx, key, members...)
}

func (c Client) SIsMember(key string, member interface{}) *redis.BoolCmd {
return c.C.SIsMember(c.Ctx, key, member)
}

func (c Client) SAdd(key string, members ...interface{}) *redis.IntCmd {
return c.C.SAdd(c.Ctx, key, members...)
}

func (c Client) Set(key string, value interface{}, expiration time.Duration) *redis.StatusCmd {
return c.C.Set(c.Ctx, key, value, expiration)
}

0 comments on commit cb62d5b

Please sign in to comment.