diff --git a/Makefile b/Makefile index 7005a53..6c44b44 100644 --- a/Makefile +++ b/Makefile @@ -18,7 +18,7 @@ gen: ## generate CURD code .PHONY: gen_error_code gen_error_code: ## generate error code - ${GO} generate github.com/TensoRaws/NuxBT-Backend/module/error/gen + ${GO} generate github.com/TensoRaws/NuxBT-Backend/module/code/gen .PHONY: test test: tidy ## go test diff --git a/conf/nuxbt.yml b/conf/nuxbt.yml index b3572b0..7505b1a 100644 --- a/conf/nuxbt.yml +++ b/conf/nuxbt.yml @@ -1,12 +1,17 @@ server: port: 8080 mode: prod - allowRegister: true - useInvitationCode: false requestLimit: 50 # 50 times per minute cros: - https://114514.com +register: + allowRegister: true + useInvitationCode: true + invitationCodeEligibilityTime: 30 # day, only that users who have registered over xx days can gen invitation code + invitationCodeExpirationTime: 7 # day, invitation code expiration time + invitationCodeLimit: 5 # invitation code limit, one user can gen xx invitation code + jwt: timeout: 600 # minute key: nuxbt diff --git a/internal/common/cache/user.go b/internal/common/cache/user.go new file mode 100644 index 0000000..cface5c --- /dev/null +++ b/internal/common/cache/user.go @@ -0,0 +1,153 @@ +package cache + +import ( + "fmt" + "time" + + "github.com/TensoRaws/NuxBT-Backend/module/cache" + "github.com/TensoRaws/NuxBT-Backend/module/config" + "github.com/TensoRaws/NuxBT-Backend/module/util" +) + +type UserInvitationMapValue struct { + CreatedAt int64 `json:"created_at"` + UsedBy int32 `json:"used_by"` + ExpiresAt int64 `json:"expires_at"` +} + +// GenerateInvitationCode 生成邀请码 +func GenerateInvitationCode(userID int32) (string, error) { + c := cache.Clients[cache.InvitationCode] + + expTime := time.Duration(config.RegisterConfig.InvitationCodeExpirationTime) * time.Hour * 24 + code := util.GetRandomString(24) + // 将生成的邀请码存储到 Redis + err := c.Set(code, userID, expTime).Err() + if err != nil { + return "", err + } + + toMapString := util.StructToString(UserInvitationMapValue{ + CreatedAt: time.Now().Unix(), // 存储邀请码的创建时间 + UsedBy: 0, // 初始状态为未使用 + ExpiresAt: time.Now().Add(expTime).Unix(), // 过期时间 + }) + + // 将邀请码信息存储到用户的哈希表中,方便查询 + err = c.HSet(fmt.Sprintf("user:%d:invitations", userID), code, toMapString).Err() + if err != nil { + return "", err + } + + // 更新哈希表键的过期时间,为 10 倍的邀请码过期时间,保证一段时间内可以查询到邀请码状态 + err = c.Expire(fmt.Sprintf("user:%d:invitations", userID), 10*expTime).Err() + if err != nil { + return "", err + } + + return code, nil +} + +type UserInvitation struct { + InvitationCode string `json:"invitation_code"` + UserInvitationMapValue +} + +// GetInvitationCodeListByUserID 获取用户近期的邀请码信息 +func GetInvitationCodeListByUserID(userID int32) ([]UserInvitation, error) { + c := cache.Clients[cache.InvitationCode] + + // 从 Redis 中获取用户的邀请码信息 + invitations, err := c.HGetAll(fmt.Sprintf("user:%d:invitations", userID)).Result() + if err != nil { + return nil, err + } + + var invitationList []UserInvitation + for code, info := range invitations { + var uim UserInvitationMapValue + err := util.StringToStruct(info, &uim) + if err != nil { + return nil, err + } + invitationList = append(invitationList, UserInvitation{ + InvitationCode: code, + UserInvitationMapValue: uim, + }) + } + + return invitationList, nil +} + +// GetValidInvitationCodeCountByUserID 获取用户有效的邀请码数量 +func GetValidInvitationCodeCountByUserID(userID int32) (int, error) { + c := cache.Clients[cache.InvitationCode] + + invitations, err := c.HGetAll(fmt.Sprintf("user:%d:invitations", userID)).Result() + if err != nil { + return 0, err + } + + count := 0 + for _, info := range invitations { + var uim UserInvitationMapValue + err := util.StringToStruct(info, &uim) + if err != nil { + return 0, err + } + + if uim.UsedBy == 0 && uim.ExpiresAt > time.Now().Unix() { + count++ + } + } + + return count, nil +} + +// ConsumeInvitationCode 注册成功后消费邀请码 +func ConsumeInvitationCode(code string, userID int32) error { + c := cache.Clients[cache.InvitationCode] + + inviterID, err := c.Get(code).Int() + if err != nil { + return err + } + + // 从 Redis 中获取邀请码信息,修改邀请码状态 + invitation, err := c.HGet(fmt.Sprintf("user:%d:invitations", inviterID), code).Result() + if err != nil { + return err + } + var uim UserInvitationMapValue + err = util.StringToStruct(invitation, &uim) + if err != nil { + return err + } + uim.UsedBy = userID + + // 更新邀请码状态 + err = c.HSet(fmt.Sprintf("user:%d:invitations", inviterID), code, util.StructToString(uim)).Err() + if err != nil { + return err + } + + // 删除邀请码 + err = c.Del(code).Err() + if err != nil { + return err + } + + return nil +} + +// GetInviterIDByInvitationCode 根据邀请码获取邀请者的 userID +func GetInviterIDByInvitationCode(code string) (int32, error) { + c := cache.Clients[cache.InvitationCode] + + userID, err := c.Get(code).Int() + if err != nil { + return 0, err + } + + return int32(userID), nil +} diff --git a/internal/common/dao/user.go b/internal/common/db/user.go similarity index 93% rename from internal/common/dao/user.go rename to internal/common/db/user.go index 5eefb25..0bdb390 100644 --- a/internal/common/dao/user.go +++ b/internal/common/db/user.go @@ -1,4 +1,4 @@ -package dao +package db import ( "github.com/TensoRaws/NuxBT-Backend/dal/model" @@ -14,8 +14,8 @@ func CreateUser(user *model.User) (err error) { // UpdateUserDataByUserID 根据 map 更新用户信息,map 中的 key 为字段名 func UpdateUserDataByUserID(userID int32, maps map[string]interface{}) (err error) { - u := query.User - _, err = u.Where(u.UserID.Eq(userID)).Updates(maps) + q := query.User + _, err = q.Where(q.UserID.Eq(userID)).Updates(maps) if err != nil { return err } diff --git a/internal/router/api/v1/api.go b/internal/router/api/v1/api.go index a85817c..16a27ec 100644 --- a/internal/router/api/v1/api.go +++ b/internal/router/api/v1/api.go @@ -63,6 +63,15 @@ func NewAPI() *gin.Engine { user.POST("profile/update", jwt.RequireAuth(cache.Clients[cache.JWTBlacklist], false), user_service.ProfileUpdate) + // 用户邀请码生成 + user.POST("invitation/gen", + jwt.RequireAuth(cache.Clients[cache.JWTBlacklist], false), + user_service.InvitationGen) + // 用户邀请码列表 + user.GET("invitation/me", + jwt.RequireAuth(cache.Clients[cache.JWTBlacklist], false), + middleware_cache.Response(cache.Clients[cache.RespCache], 5*time.Second), + user_service.InvitationMe) } } diff --git a/internal/service/user/invitation.go b/internal/service/user/invitation.go new file mode 100644 index 0000000..12dd145 --- /dev/null +++ b/internal/service/user/invitation.go @@ -0,0 +1,61 @@ +package user + +import ( + "fmt" + + "github.com/TensoRaws/NuxBT-Backend/internal/common/cache" + "github.com/TensoRaws/NuxBT-Backend/module/code" + "github.com/TensoRaws/NuxBT-Backend/module/config" + "github.com/TensoRaws/NuxBT-Backend/module/log" + "github.com/TensoRaws/NuxBT-Backend/module/resp" + "github.com/gin-gonic/gin" +) + +type InvitationGenResponse struct { + InvitationCode string `json:"invitation_code"` +} + +type InvitationMeResponse []cache.UserInvitation + +// InvitationGen 生成邀请码 (POST /invitation/gen) +func InvitationGen(c *gin.Context) { + userID, _ := resp.GetUserIDFromGinContext(c) + + count, err := cache.GetValidInvitationCodeCountByUserID(userID) + if err != nil { + return + } + log.Logger.Infof("User %d has %d valid invitation codes!", userID, count) + + if count >= config.RegisterConfig.InvitationCodeLimit { + resp.AbortWithMsg(c, code.UserErrorInvitationCodeHasReachedLimit, + fmt.Sprintf("You have generated %d invitation codes!", count)) + return + } + + codeGen, err := cache.GenerateInvitationCode(userID) + if err != nil { + resp.AbortWithMsg(c, code.UnknownError, err.Error()) + return + } + + resp.OKWithData(c, InvitationGenResponse{InvitationCode: codeGen}) + log.Logger.Infof("User %d generated invitation code_gen %s successfully!", userID, codeGen) +} + +// InvitationMe 获取邀请码列表 (GET /invitation/me) +func InvitationMe(c *gin.Context) { + userID, _ := resp.GetUserIDFromGinContext(c) + + codeList, err := cache.GetInvitationCodeListByUserID(userID) + if err != nil { + return + } + + if len(codeList) == 0 { + resp.OKWithData(c, InvitationMeResponse{}) + } else { + resp.OKWithData(c, InvitationMeResponse(codeList)) + } + log.Logger.Infof("User %d got invitation code list successfully!", userID) +} diff --git a/internal/service/user/login.go b/internal/service/user/login.go index b805a06..4b4ee66 100644 --- a/internal/service/user/login.go +++ b/internal/service/user/login.go @@ -1,7 +1,7 @@ package user import ( - "github.com/TensoRaws/NuxBT-Backend/internal/common/dao" + "github.com/TensoRaws/NuxBT-Backend/internal/common/db" "github.com/TensoRaws/NuxBT-Backend/internal/middleware/jwt" "github.com/TensoRaws/NuxBT-Backend/module/code" "github.com/TensoRaws/NuxBT-Backend/module/resp" @@ -28,7 +28,7 @@ func Login(c *gin.Context) { } // GORM 查询 - user, err := dao.GetUserByEmail(req.Email) + user, err := db.GetUserByEmail(req.Email) if err != nil { resp.AbortWithMsg(c, code.DatabaseErrorRecordNotFound, "User not found") return diff --git a/internal/service/user/profile.go b/internal/service/user/profile.go index 0cf07a3..8453235 100644 --- a/internal/service/user/profile.go +++ b/internal/service/user/profile.go @@ -1,7 +1,7 @@ package user import ( - "github.com/TensoRaws/NuxBT-Backend/internal/common/dao" + "github.com/TensoRaws/NuxBT-Backend/internal/common/db" "github.com/TensoRaws/NuxBT-Backend/module/code" "github.com/TensoRaws/NuxBT-Backend/module/log" "github.com/TensoRaws/NuxBT-Backend/module/resp" @@ -31,13 +31,13 @@ type ProfileOthersRequest struct { func ProfileMe(c *gin.Context) { userID, _ := resp.GetUserIDFromGinContext(c) - user, err := dao.GetUserByID(userID) + user, err := db.GetUserByID(userID) if err != nil { resp.AbortWithMsg(c, code.DatabaseErrorRecordNotFound, "User not found") return } - roles, err := dao.GetUserRolesByID(userID) + roles, err := db.GetUserRolesByID(userID) if err != nil { log.Logger.Info("Failed to get user roles: " + err.Error()) roles = []string{} @@ -73,13 +73,13 @@ func ProfileOthers(c *gin.Context) { userID, _ := resp.GetUserIDFromGinContext(c) // 获取信息 - user, err := dao.GetUserByID(req.UserID) + user, err := db.GetUserByID(req.UserID) if err != nil { resp.AbortWithMsg(c, code.DatabaseErrorRecordNotFound, "User not found") return } - roles, err := dao.GetUserRolesByID(req.UserID) + roles, err := db.GetUserRolesByID(req.UserID) if err != nil { log.Logger.Info("Failed to get user roles: " + err.Error()) roles = []string{} diff --git a/internal/service/user/profile_update.go b/internal/service/user/profile_update.go index 81c092d..003d4ca 100644 --- a/internal/service/user/profile_update.go +++ b/internal/service/user/profile_update.go @@ -1,7 +1,7 @@ package user import ( - "github.com/TensoRaws/NuxBT-Backend/internal/common/dao" + "github.com/TensoRaws/NuxBT-Backend/internal/common/db" "github.com/TensoRaws/NuxBT-Backend/module/code" "github.com/TensoRaws/NuxBT-Backend/module/log" "github.com/TensoRaws/NuxBT-Backend/module/resp" @@ -61,7 +61,7 @@ func ProfileUpdate(c *gin.Context) { updates["background"] = *req.Background } // 执行更新 - err := dao.UpdateUserDataByUserID(userID, updates) + err := db.UpdateUserDataByUserID(userID, updates) if err != nil { resp.AbortWithMsg(c, code.DatabaseErrorRecordUpdateFailed, err.Error()) return diff --git a/internal/service/user/register.go b/internal/service/user/register.go index c9aa631..c60ed47 100644 --- a/internal/service/user/register.go +++ b/internal/service/user/register.go @@ -4,7 +4,8 @@ import ( "time" "github.com/TensoRaws/NuxBT-Backend/dal/model" - "github.com/TensoRaws/NuxBT-Backend/internal/common/dao" + "github.com/TensoRaws/NuxBT-Backend/internal/common/cache" + "github.com/TensoRaws/NuxBT-Backend/internal/common/db" "github.com/TensoRaws/NuxBT-Backend/module/code" "github.com/TensoRaws/NuxBT-Backend/module/config" "github.com/TensoRaws/NuxBT-Backend/module/log" @@ -37,23 +38,38 @@ func Register(c *gin.Context) { return } + // 检查是否允许注册 + if !config.RegisterConfig.AllowRegister { + resp.Abort(c, code.UserErrorRegisterNotAllowed) + return + } + err := util.CheckUsername(req.Username) if err != nil { resp.AbortWithMsg(c, code.UserErrorInvalidUsername, err.Error()) return } + var inviterID int32 = 0 + // 无邀请码注册,检查是否允许无邀请码注册 if req.InvitationCode == nil || *req.InvitationCode == "" { - if config.ServerConfig.UseInvitationCode { + if config.RegisterConfig.UseInvitationCode { resp.AbortWithMsg(c, code.UserErrorInvalidInvitationCode, "invitation code is required") return } } else { - // TODO: 邀请码功能, 有邀请码注册,检查邀请码是否有效 - + // 邀请码功能, 有邀请码注册,检查邀请码是否有效 + inviterID, err = cache.GetInviterIDByInvitationCode(*req.InvitationCode) + if err != nil { + resp.AbortWithMsg(c, code.UserErrorInvalidInvitationCode, "invalid invitation code") + log.Logger.Error("invalid invitation code: " + err.Error()) + return + } log.Logger.Info("invitation code: " + *req.InvitationCode) } + + // 生成密码哈希 password, err := bcrypt.GenerateFromPassword([]byte(req.Password), bcrypt.DefaultCost) if err != nil { resp.AbortWithMsg(c, code.UnknownError, "failed to hash password") @@ -61,11 +77,12 @@ func Register(c *gin.Context) { return } // 注册 - err = dao.CreateUser(&model.User{ + err = db.CreateUser(&model.User{ Username: req.Username, Email: req.Email, Password: string(password), LastActive: time.Now(), + Inviter: inviterID, }) if err != nil { resp.AbortWithMsg(c, code.DatabaseErrorRecordCreateFailed, "failed to register "+err.Error()) @@ -73,13 +90,23 @@ func Register(c *gin.Context) { return } - user, err := dao.GetUserByEmail(req.Email) + // 获取用户注册的 userID + user, err := db.GetUserByEmail(req.Email) if err != nil { resp.AbortWithMsg(c, code.DatabaseErrorRecordNotFound, "failed to get user by email") log.Logger.Error("failed to get user by email: " + err.Error()) return } + // 消费邀请码 + if req.InvitationCode != nil && *req.InvitationCode != "" { + err = cache.ConsumeInvitationCode(*req.InvitationCode, user.UserID) + if err != nil { + resp.AbortWithMsg(c, code.UnknownError, "failed to consume invitation code") + return + } + } + resp.OKWithData(c, RegisterDataResponse{ Email: user.Email, UserID: user.UserID, diff --git a/internal/service/user/reset.go b/internal/service/user/reset.go index baa233c..68c2a2c 100644 --- a/internal/service/user/reset.go +++ b/internal/service/user/reset.go @@ -1,7 +1,7 @@ package user import ( - "github.com/TensoRaws/NuxBT-Backend/internal/common/dao" + "github.com/TensoRaws/NuxBT-Backend/internal/common/db" "github.com/TensoRaws/NuxBT-Backend/module/code" "github.com/TensoRaws/NuxBT-Backend/module/log" "github.com/TensoRaws/NuxBT-Backend/module/resp" @@ -31,7 +31,7 @@ func ResetPassword(c *gin.Context) { return } // 修改密码 - err = dao.UpdateUserDataByUserID(userID, map[string]interface{}{ + err = db.UpdateUserDataByUserID(userID, map[string]interface{}{ "password": password, }) if err != nil { diff --git a/module/cache/cache.go b/module/cache/cache.go index 35ca038..ea82eac 100644 --- a/module/cache/cache.go +++ b/module/cache/cache.go @@ -17,12 +17,14 @@ const ( IPLimit RDB = iota JWTBlacklist RespCache + InvitationCode ) var Clients = map[RDB]*Client{ - IPLimit: {}, - JWTBlacklist: {}, - RespCache: {}, + IPLimit: {}, + JWTBlacklist: {}, + RespCache: {}, + InvitationCode: {}, } type Client struct { diff --git a/module/cache/redis.go b/module/cache/redis.go index a8cb10a..648a07e 100644 --- a/module/cache/redis.go +++ b/module/cache/redis.go @@ -116,3 +116,23 @@ func (c Client) Set(key string, value interface{}, expiration time.Duration) *re func (c Client) Get(key string) *redis.StringCmd { return c.C.Get(c.Ctx, key) } + +func (c Client) HMSet(key string, fields map[string]interface{}) *redis.BoolCmd { + return c.C.HMSet(c.Ctx, key, fields) +} + +func (c Client) HMGet(key string, fields ...string) *redis.SliceCmd { + return c.C.HMGet(c.Ctx, key, fields...) +} + +func (c Client) HSet(key, field string, value interface{}) *redis.IntCmd { + return c.C.HSet(c.Ctx, key, field, value) +} + +func (c Client) HGet(key, field string) *redis.StringCmd { + return c.C.HGet(c.Ctx, key, field) +} + +func (c Client) HGetAll(key string) *redis.MapStringStringCmd { + return c.C.HGetAll(c.Ctx, key) +} diff --git a/module/code/code.go b/module/code/code.go index 1ee34fe..1660fc2 100644 --- a/module/code/code.go +++ b/module/code/code.go @@ -19,10 +19,12 @@ const ( DatabaseErrorRecordNotFound DatabaseErrorRecordUpdateFailed // UserError 用户侧错误 + UserErrorRegisterNotAllowed UserErrorInvalidUsername UserErrorInvalidPassword UserErrorInvalidEmail UserErrorInvalidInvitationCode + UserErrorInvitationCodeHasReachedLimit // gen code end // DO NOT EDIT ) diff --git a/module/code/code.ts b/module/code/code.ts index 27aaba4..e481589 100644 --- a/module/code/code.ts +++ b/module/code/code.ts @@ -11,8 +11,10 @@ export const enum ErrorCode { DatabaseErrorRecordCreateFailed, DatabaseErrorRecordNotFound, DatabaseErrorRecordUpdateFailed, + UserErrorRegisterNotAllowed, UserErrorInvalidUsername, UserErrorInvalidPassword, UserErrorInvalidEmail, - UserErrorInvalidInvitationCode + UserErrorInvalidInvitationCode, + UserErrorInvitationCodeHasReachedLimit } diff --git a/module/code/code_map.go b/module/code/code_map.go index ced9358..cf86af8 100644 --- a/module/code/code_map.go +++ b/module/code/code_map.go @@ -5,16 +5,18 @@ package code // codeToString use a map to store the string representation of Code var codeToString = map[Code]string{ - InternalError: "Internal error", - UnknownError: "Unknown error", - AuthErrorTokenHasBeenBlacklisted: "Auth error token has been blacklisted", - AuthErrorTokenIsInvalid: "Auth error token is invalid", - RequestErrorInvalidParams: "Request error invalid params", - DatabaseErrorRecordCreateFailed: "Database error record create failed", - DatabaseErrorRecordNotFound: "Database error record not found", - DatabaseErrorRecordUpdateFailed: "Database error record update failed", - UserErrorInvalidUsername: "User error invalid username", - UserErrorInvalidPassword: "User error invalid password", - UserErrorInvalidEmail: "User error invalid email", - UserErrorInvalidInvitationCode: "User error invalid invitation code", + InternalError: "Internal error", + UnknownError: "Unknown error", + AuthErrorTokenHasBeenBlacklisted: "Auth error token has been blacklisted", + AuthErrorTokenIsInvalid: "Auth error token is invalid", + RequestErrorInvalidParams: "Request error invalid params", + DatabaseErrorRecordCreateFailed: "Database error record create failed", + DatabaseErrorRecordNotFound: "Database error record not found", + DatabaseErrorRecordUpdateFailed: "Database error record update failed", + UserErrorRegisterNotAllowed: "User error register not allowed", + UserErrorInvalidUsername: "User error invalid username", + UserErrorInvalidPassword: "User error invalid password", + UserErrorInvalidEmail: "User error invalid email", + UserErrorInvalidInvitationCode: "User error invalid invitation code", + UserErrorInvitationCodeHasReachedLimit: "User error invitation code has reached limit", } diff --git a/module/config/config.go b/module/config/config.go index 8975f86..e22d4e6 100644 --- a/module/config/config.go +++ b/module/config/config.go @@ -48,41 +48,7 @@ func initialize() { if errors.As(err, &configFileNotFoundError) { // 配置文件未找到错误 fmt.Println("config file not found use default config") - config.SetDefault("server", map[string]interface{}{ - "port": 8080, - "mode": "prod", - "allowRegister": true, - "useInvitationCode": false, - "requestLimit": 50, - "cros": []string{}, - }) - - config.SetDefault("jwt", map[string]interface{}{ - "timeout": 60, - "key": "nuxbt", - }) - - config.SetDefault("log", map[string]interface{}{ - "level": "debug", - "mode": []string{"console", "file"}, - }) - - config.SetDefault("db", map[string]interface{}{ - "type": "mysql", - "host": "127.0.0.1", - "port": 5432, - "username": "root", - "password": "123456", - "database": "nuxbt", - "ssl": false, - }) - - config.SetDefault("redis", map[string]interface{}{ - "host": "127.0.0.1", - "port": 6379, - "password": "123456", - "poolSize": 10, - }) + configSetDefault() } } diff --git a/module/config/default.go b/module/config/default.go new file mode 100644 index 0000000..ef5616a --- /dev/null +++ b/module/config/default.go @@ -0,0 +1,45 @@ +package config + +func configSetDefault() { + config.SetDefault("server", map[string]interface{}{ + "port": 8080, + "mode": "prod", + "requestLimit": 50, + "cros": []string{}, + }) + + config.SetDefault("register", map[string]interface{}{ + "allowRegister": true, + "useInvitationCode": true, + "invitationCodeEligibilityTime": 30, + "invitationCodeExpirationTime": 7, + "invitationCodeLimit": 5, + }) + + config.SetDefault("jwt", map[string]interface{}{ + "timeout": 60, + "key": "nuxbt", + }) + + config.SetDefault("log", map[string]interface{}{ + "level": "debug", + "mode": []string{"console", "file"}, + }) + + config.SetDefault("db", map[string]interface{}{ + "type": "mysql", + "host": "127.0.0.1", + "port": 5432, + "username": "root", + "password": "123456", + "database": "nuxbt", + "ssl": false, + }) + + config.SetDefault("redis", map[string]interface{}{ + "host": "127.0.0.1", + "port": 6379, + "password": "123456", + "poolSize": 10, + }) +} diff --git a/module/config/global.go b/module/config/global.go index 7da93e0..c8ed4eb 100644 --- a/module/config/global.go +++ b/module/config/global.go @@ -5,13 +5,14 @@ import ( ) var ( - ServerConfig Server - JwtConfig Jwt - LogConfig Log - DBConfig DB - RedisConfig Redis - OSSConfig OSS - OSS_PREFIX string + ServerConfig Server + RegisterConfig Register + JwtConfig Jwt + LogConfig Log + DBConfig DB + RedisConfig Redis + OSSConfig OSS + OSS_PREFIX string ) func setConfig() { @@ -19,25 +20,36 @@ func setConfig() { if err != nil { log.Fatalf("unable to decode into server struct, %v", err) } + + err = config.UnmarshalKey("register", &RegisterConfig) + if err != nil { + log.Fatalf("unable to decode into register struct, %v", err) + } + err = config.UnmarshalKey("jwt", &JwtConfig) if err != nil { log.Fatalf("unable to decode into jwt struct, %v", err) } + err = config.UnmarshalKey("log", &LogConfig) if err != nil { log.Fatalf("unable to decode into log struct, %v", err) } + err = config.UnmarshalKey("db", &DBConfig) if err != nil { log.Fatalf("unable to decode into db struct, %v", err) } + err = config.UnmarshalKey("redis", &RedisConfig) if err != nil { log.Fatalf("unable to decode into redis struct, %v", err) } + err = config.UnmarshalKey("oss", &OSSConfig) if err != nil { log.Fatalf("unable to decode into oss struct, %v", err) } + OSS_PREFIX = GenerateOSSPrefix() } diff --git a/module/config/type.go b/module/config/type.go index e895a6b..a55054b 100644 --- a/module/config/type.go +++ b/module/config/type.go @@ -1,12 +1,18 @@ package config type Server struct { - Port int `yaml:"port"` - Mode string `yaml:"mode"` - AllowResgister bool `yaml:"allowResgister"` - UseInvitationCode bool `yaml:"useInvitationCode"` - RequestLimit int `yaml:"requestLimit"` - Cros []string `yaml:"cros"` + Port int `yaml:"port"` + Mode string `yaml:"mode"` + RequestLimit int `yaml:"requestLimit"` + Cros []string `yaml:"cros"` +} + +type Register struct { + AllowRegister bool `yaml:"allowRegister"` + UseInvitationCode bool `yaml:"useInvitationCode"` + InvitationCodeEligibilityTime int `yaml:"invitationCodeEligibilityTime"` + InvitationCodeExpirationTime int `yaml:"invitationCodeExpirationTime"` + InvitationCodeLimit int `yaml:"invitationCodeLimit"` } type Jwt struct { diff --git a/module/util/json.go b/module/util/json.go new file mode 100644 index 0000000..9d737d8 --- /dev/null +++ b/module/util/json.go @@ -0,0 +1,34 @@ +package util + +import "github.com/bytedance/sonic" + +// StructToString 结构体转字符串 +func StructToString(s interface{}) string { + // v, _ := json.Marshal(s) + v, _ := sonic.Marshal(s) + return string(v) +} + +// StringToStruct 字符串转结构体 +func StringToStruct(str string, s interface{}) error { + // return json.Unmarshal([]byte(str), s) + return sonic.Unmarshal([]byte(str), s) +} + +// StructToMap 结构体转 map[string]interface{} +func StructToMap(s interface{}) (map[string]interface{}, error) { + // 使用 sonic 将结构体序列化为 JSON + jsonBytes, err := sonic.Marshal(s) + if err != nil { + return nil, err + } + + // 将 JSON 反序列化为 map[string]interface{} + var result map[string]interface{} + err = sonic.Unmarshal(jsonBytes, &result) + if err != nil { + return nil, err + } + + return result, nil +} diff --git a/module/util/print.go b/module/util/print.go index 3d4e212..8b7feb4 100644 --- a/module/util/print.go +++ b/module/util/print.go @@ -1,7 +1,5 @@ package util -import "github.com/bytedance/sonic" - type Color string // 高亮颜色map @@ -35,16 +33,3 @@ func HighlightString(color Color, str string) string { } return colorMap[color] + str + colorMap["reset"] } - -// StructToString 结构体转字符串 -func StructToString(s interface{}) string { - // v, _ := json.Marshal(s) - v, _ := sonic.Marshal(s) - return string(v) -} - -// StringToStruct 字符串转结构体 -func StringToStruct(str string, s interface{}) error { - // return json.Unmarshal([]byte(str), s) - return sonic.Unmarshal([]byte(str), s) -}