diff --git a/cmd/web_server/handler/user.go b/cmd/web_server/handler/user.go index b99793d..dd94996 100644 --- a/cmd/web_server/handler/user.go +++ b/cmd/web_server/handler/user.go @@ -5,6 +5,7 @@ import ( "net/http" "github.com/gin-gonic/gin" + user_model "github.com/oj-lab/oj-lab-platform/models/user" "github.com/oj-lab/oj-lab-platform/modules" "github.com/oj-lab/oj-lab-platform/modules/middleware" user_service "github.com/oj-lab/oj-lab-platform/services/user" @@ -13,6 +14,7 @@ import ( func SetupUserRouter(baseRoute *gin.RouterGroup) { g := baseRoute.Group("/user") { + g.PUT("", updateUser) g.GET("/health", func(ginCtx *gin.Context) { ginCtx.String(http.StatusOK, "Hello, this is user service") }) @@ -44,12 +46,12 @@ func login(ginCtx *gin.Context) { return } - lsId, err := user_service.StartLoginSession(ginCtx, body.Account, body.Password) + ls, err := user_service.StartLoginSession(ginCtx, body.Account, body.Password) if err != nil { modules.NewInternalError(fmt.Sprintf("failed to login: %v", err)).AppendToGin(ginCtx) return } - middleware.SetLoginSessionCookie(ginCtx, lsId.String()) + middleware.SetLoginSessionKeyCookie(ginCtx, ls.Key) ginCtx.Status(http.StatusOK) } @@ -63,12 +65,12 @@ func login(ginCtx *gin.Context) { // @Success 200 // @Failure 401 func me(ginCtx *gin.Context) { - ls := middleware.GetLoginSession(ginCtx) - if ls == nil { - modules.NewUnauthorizedError("not logined").AppendToGin(ginCtx) + ls, err := middleware.GetLoginSessionFromGinCtx(ginCtx) + if err != nil { + modules.NewUnauthorizedError("cannot load login session from cookie").AppendToGin(ginCtx) return } - user, err := user_service.GetUser(ginCtx, ls.Account) + user, err := user_service.GetUser(ginCtx, ls.Key.Account) if err != nil { modules.NewInternalError(fmt.Sprintf("failed to get user: %v", err)).AppendToGin(ginCtx) return @@ -94,3 +96,22 @@ func checkUserExist(ginCtx *gin.Context) { "exist": exist, }) } + +type updateUserBody struct { + User user_model.User `json:"user"` +} + +func updateUser(ginCtx *gin.Context) { + body := &updateUserBody{} + err := ginCtx.BindJSON(body) + if err != nil { + modules.NewInvalidParamError("body", err.Error()).AppendToGin(ginCtx) + return + } + + err = user_service.UpdateUser(ginCtx, body.User) + if err != nil { + modules.NewInternalError(fmt.Sprintf("failed to update user: %v", err)).AppendToGin(ginCtx) + return + } +} diff --git a/models/user/user.go b/models/user/user.go index 0abddbf..83544fa 100644 --- a/models/user/user.go +++ b/models/user/user.go @@ -21,10 +21,10 @@ type Role struct { Users []*User `gorm:"many2many:user_roles" json:"users,omitempty"` } -func (user User) GetRolesStringArray() []string { - roles := make([]string, len(user.Roles)) - for i, role := range user.Roles { - roles[i] = role.Name +func (user User) GetRolesStringSet() map[string]struct{} { + roleSet := map[string]struct{}{} + for _, role := range user.Roles { + roleSet[role.Name] = struct{}{} } - return roles + return roleSet } diff --git a/modules/auth/login_session.go b/modules/auth/login_session.go index 6880857..2f56e61 100644 --- a/modules/auth/login_session.go +++ b/modules/auth/login_session.go @@ -7,41 +7,62 @@ import ( "github.com/google/uuid" ) +type LoginSessionKey struct { + Account string + Id uuid.UUID +} + +type LoginSessionData struct { + RoleSet map[string]struct{} +} + type LoginSession struct { - Id uuid.UUID `json:"-"` - Account string `json:"account"` - Roles []string `json:"roles"` + Key LoginSessionKey + Data LoginSessionData +} + +func (data LoginSessionData) HasRoles(roles []string) bool { + for _, role := range roles { + if _, ok := data.RoleSet[role]; !ok { + return false + } + } + return true } -func (ls LoginSession) GetJsonString() (string, error) { - lsBytes, err := json.Marshal(ls) +func (data LoginSessionData) GetJsonString() (string, error) { + bytes, err := json.Marshal(data) if err != nil { return "", err } - lsString := string(lsBytes) + dataString := string(bytes) - return lsString, nil + return dataString, nil } -func GetLoginSessionFromJsonString(lsString string) (*LoginSession, error) { - ls := &LoginSession{} - err := json.Unmarshal([]byte(lsString), ls) +func getLoginSessionDataFromJsonString(dataString string) (*LoginSessionData, error) { + data := &LoginSessionData{} + err := json.Unmarshal([]byte(dataString), data) if err != nil { return nil, err } - return ls, nil + return data, nil } -func NewLoginSession(ls LoginSession) *LoginSession { +func NewLoginSession(account string, data LoginSessionData) *LoginSession { return &LoginSession{ - Id: uuid.New(), - Account: ls.Account, - Roles: ls.Roles, + LoginSessionKey{ + Account: account, + Id: uuid.New(), + }, + LoginSessionData{ + RoleSet: data.RoleSet, + }, } } func (ls LoginSession) SaveToRedis(ctx context.Context) error { - err := SetLoginSession(ctx, ls) + err := SetLoginSession(ctx, ls.Key, ls.Data) if err != nil { return err } diff --git a/modules/auth/redis.go b/modules/auth/redis.go index e2fc704..ede64ba 100644 --- a/modules/auth/redis.go +++ b/modules/auth/redis.go @@ -5,22 +5,27 @@ import ( "fmt" "time" - "github.com/google/uuid" redisAgent "github.com/oj-lab/oj-lab-platform/modules/agent/redis" + "github.com/oj-lab/oj-lab-platform/modules/log" + "github.com/redis/go-redis/v9" ) -const loginSessionKeyFormat = "LS_%s" -const loginSessionDuration = time.Second * 30 +const loginSessionKeyFormat = "LS_%s_%s" // "LS__" +const loginSessionDuration = time.Minute * 15 -func SetLoginSession(ctx context.Context, ls LoginSession) error { +func getLoginSessionRedisKey(key LoginSessionKey) string { + return fmt.Sprintf(loginSessionKeyFormat, key.Account, key.Id.String()) +} + +func SetLoginSession(ctx context.Context, key LoginSessionKey, data LoginSessionData) error { redisClient := redisAgent.GetDefaultRedisClient() - key := fmt.Sprintf(loginSessionKeyFormat, ls.Id.String()) - value, err := ls.GetJsonString() + + value, err := data.GetJsonString() if err != nil { return err } - - err = redisClient.Set(ctx, key, value, loginSessionDuration).Err() + // TODO: Watch Redis JSON SET usage, currently not support atomic SETEX + err = redisClient.Set(ctx, getLoginSessionRedisKey(key), value, loginSessionDuration).Err() if err != nil { return err } @@ -28,33 +33,43 @@ func SetLoginSession(ctx context.Context, ls LoginSession) error { return nil } -func GetLoginSession(ctx context.Context, id uuid.UUID) (*LoginSession, error) { +func GetLoginSession(ctx context.Context, key LoginSessionKey) (*LoginSession, error) { redisClient := redisAgent.GetDefaultRedisClient() - lsIdString := id.String() - key := fmt.Sprintf(loginSessionKeyFormat, lsIdString) - val, err := redisClient.Get(ctx, key).Result() + val, err := redisClient.Get(ctx, getLoginSessionRedisKey(key)).Result() if err != nil { return nil, err } - - ls, err := GetLoginSessionFromJsonString(val) + data, err := getLoginSessionDataFromJsonString(val) if err != nil { return nil, err } - ls.Id = id - return ls, nil + return &LoginSession{ + Key: key, + Data: *data, + }, nil } -func UpdateLoginSession(ctx context.Context, idString, sesionString string) error { +func UpdateLoginSessionByAccount(ctx context.Context, account string, data LoginSessionData) error { redisClient := redisAgent.GetDefaultRedisClient() - key := fmt.Sprintf(loginSessionKeyFormat, idString) - err := redisClient.Set(ctx, key, sesionString, loginSessionDuration).Err() + redisKeys, err := redisClient.Keys(ctx, fmt.Sprintf(loginSessionKeyFormat, account, "*")).Result() if err != nil { return err } + val, err := data.GetJsonString() + if err != nil { + return err + } + for _, redisKey := range redisKeys { + // TODO: KeepTTL only works in redis v6+ + err = redisClient.Set(ctx, redisKey, val, redis.KeepTTL).Err() + if err != nil { + log.AppLogger().Errorf("failed to update login session: %v", err) + } + } + return nil } diff --git a/modules/error.go b/modules/error.go index 305190e..884a003 100644 --- a/modules/error.go +++ b/modules/error.go @@ -59,6 +59,6 @@ func NewInvalidParamError(param string, hints ...string) *SeviceError { return &SeviceError{ Code: 400, - Msg: fmt.Sprintf("invalid param: %s", param), + Msg: msg, } } diff --git a/modules/middleware/login_session.go b/modules/middleware/login_session.go index 64c07b5..ce1a490 100644 --- a/modules/middleware/login_session.go +++ b/modules/middleware/login_session.go @@ -1,6 +1,7 @@ package middleware import ( + "fmt" "time" "github.com/gin-gonic/gin" @@ -10,46 +11,73 @@ import ( ) const ( - loginSessionCookieMaxAge = time.Hour * 24 * 7 - loginSessionIdCookieName = "LS_ID" - loginSessionGinCtxKey = "login_session" + loginSessionCookieMaxAge = time.Hour * 24 * 7 + loginSessionKeyIdCookieName = "LS_KEY_ID" + loginSessionKeyAccountCookieName = "LS_KEY_ACCOUNT" + loginSessionGinCtxKey = "login_session" ) -func HandleRequireLogin(ginCtx *gin.Context) { - cookieValue, err := ginCtx.Cookie(loginSessionIdCookieName) - if err != nil { - modules.NewUnauthorizedError("login session not found").AppendToGin(ginCtx) - ginCtx.Abort() - return - } - lsId, err := uuid.Parse(cookieValue) - if err != nil { - modules.NewUnauthorizedError("invalid login session id").AppendToGin(ginCtx) - ginCtx.Abort() - return +func BuildHandleRequireLoginWithRoles(roles []string) gin.HandlerFunc { + return func(ginCtx *gin.Context) { + ls, err := GetLoginSessionFromGinCtx(ginCtx) + if err != nil { + modules.NewUnauthorizedError("cannot load login session from cookie").AppendToGin(ginCtx) + ginCtx.Abort() + return + } + ginCtx.Set(loginSessionGinCtxKey, ls) + + if !ls.Data.HasRoles(roles) { + modules.NewUnauthorizedError(fmt.Sprintf("require roles: %v", roles)).AppendToGin(ginCtx) + ginCtx.Abort() + return + } + + ginCtx.Next() } +} - ls, err := auth.GetLoginSession(ginCtx, lsId) +func HandleRequireLogin(ginCtx *gin.Context) { + ls, err := GetLoginSessionFromGinCtx(ginCtx) if err != nil { - modules.NewUnauthorizedError("invalid login session").AppendToGin(ginCtx) + modules.NewUnauthorizedError("cannot load login session from cookie").AppendToGin(ginCtx) ginCtx.Abort() return } - ginCtx.Set(loginSessionGinCtxKey, ls) ginCtx.Next() } -func GetLoginSession(ginCtx *gin.Context) *auth.LoginSession { - ls, exist := ginCtx.Get(loginSessionGinCtxKey) - if !exist { - return nil +func GetLoginSessionFromGinCtx(ginCtx *gin.Context) (*auth.LoginSession, error) { + lsAccount, err := ginCtx.Cookie(loginSessionKeyAccountCookieName) + if err != nil { + return nil, err + } + lsIdString, err := ginCtx.Cookie(loginSessionKeyIdCookieName) + if err != nil { + return nil, err + } + lsId, err := uuid.Parse(lsIdString) + if err != nil { + return nil, err + } + key := auth.LoginSessionKey{ + Account: lsAccount, + Id: lsId, + } + + ls, err := auth.GetLoginSession(ginCtx, key) + if err != nil { + return nil, err } - return ls.(*auth.LoginSession) + + return ls, nil } -func SetLoginSessionCookie(ginCtx *gin.Context, lsId string) { - ginCtx.SetCookie(loginSessionIdCookieName, lsId, +func SetLoginSessionKeyCookie(ginCtx *gin.Context, key auth.LoginSessionKey) { + ginCtx.SetCookie(loginSessionKeyAccountCookieName, key.Account, + int(loginSessionCookieMaxAge.Seconds()), "/", "", false, true) + ginCtx.SetCookie(loginSessionKeyIdCookieName, key.Id.String(), int(loginSessionCookieMaxAge.Seconds()), "/", "", false, true) } diff --git a/services/user/user.go b/services/user/user.go index ed47f68..e28b42f 100644 --- a/services/user/user.go +++ b/services/user/user.go @@ -3,7 +3,6 @@ package user import ( "context" - "github.com/google/uuid" user_model "github.com/oj-lab/oj-lab-platform/models/user" gorm_agent "github.com/oj-lab/oj-lab-platform/modules/agent/gorm" "github.com/oj-lab/oj-lab-platform/modules/auth" @@ -20,6 +19,20 @@ func GetUser(ctx context.Context, account string) (*user_model.User, error) { return user, nil } +func UpdateUser(ctx context.Context, user user_model.User) error { + db := gorm_agent.GetDefaultDB() + err := user_model.UpdateUser(db, user) + if err != nil { + return err + } + + return auth.UpdateLoginSessionByAccount(ctx, + user.Account, + auth.LoginSessionData{ + RoleSet: user.GetRolesStringSet(), + }) +} + func CheckUserExist(ctx context.Context, account string) (bool, error) { getOptions := user_model.GetUserOptions{ Account: account, @@ -37,21 +50,20 @@ func CheckUserExist(ctx context.Context, account string) (bool, error) { return count > 0, nil } -func StartLoginSession(ctx context.Context, account, password string) (*uuid.UUID, error) { +func StartLoginSession(ctx context.Context, account, password string) (*auth.LoginSession, error) { db := gorm_agent.GetDefaultDB() user, err := user_model.GetUserByAccountPassword(db, account, password) if err != nil { return nil, err } - loginSession := auth.NewLoginSession(auth.LoginSession{ - Account: account, - Roles: user.GetRolesStringArray(), + ls := auth.NewLoginSession(account, auth.LoginSessionData{ + RoleSet: user.GetRolesStringSet(), }) - err = loginSession.SaveToRedis(ctx) + err = ls.SaveToRedis(ctx) if err != nil { return nil, err } - return &loginSession.Id, nil + return ls, nil }