Skip to content

Commit

Permalink
Support update user and with session
Browse files Browse the repository at this point in the history
  • Loading branch information
slhmy committed May 25, 2024
1 parent c02bc7f commit 3e56ec2
Show file tree
Hide file tree
Showing 7 changed files with 176 additions and 79 deletions.
33 changes: 27 additions & 6 deletions cmd/web_server/handler/user.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"net/http"

"github.com/gin-gonic/gin"
user_model "github.com/oj-lab/oj-lab-platform/models/user"
"github.com/oj-lab/oj-lab-platform/modules"
"github.com/oj-lab/oj-lab-platform/modules/middleware"
user_service "github.com/oj-lab/oj-lab-platform/services/user"
Expand All @@ -13,6 +14,7 @@ import (
func SetupUserRouter(baseRoute *gin.RouterGroup) {
g := baseRoute.Group("/user")
{
g.PUT("", updateUser)
g.GET("/health", func(ginCtx *gin.Context) {
ginCtx.String(http.StatusOK, "Hello, this is user service")
})
Expand Down Expand Up @@ -44,12 +46,12 @@ func login(ginCtx *gin.Context) {
return
}

lsId, err := user_service.StartLoginSession(ginCtx, body.Account, body.Password)
ls, err := user_service.StartLoginSession(ginCtx, body.Account, body.Password)
if err != nil {
modules.NewInternalError(fmt.Sprintf("failed to login: %v", err)).AppendToGin(ginCtx)
return
}
middleware.SetLoginSessionCookie(ginCtx, lsId.String())
middleware.SetLoginSessionKeyCookie(ginCtx, ls.Key)

ginCtx.Status(http.StatusOK)
}
Expand All @@ -63,12 +65,12 @@ func login(ginCtx *gin.Context) {
// @Success 200
// @Failure 401
func me(ginCtx *gin.Context) {
ls := middleware.GetLoginSession(ginCtx)
if ls == nil {
modules.NewUnauthorizedError("not logined").AppendToGin(ginCtx)
ls, err := middleware.GetLoginSessionFromGinCtx(ginCtx)
if err != nil {
modules.NewUnauthorizedError("cannot load login session from cookie").AppendToGin(ginCtx)
return
}
user, err := user_service.GetUser(ginCtx, ls.Account)
user, err := user_service.GetUser(ginCtx, ls.Key.Account)
if err != nil {
modules.NewInternalError(fmt.Sprintf("failed to get user: %v", err)).AppendToGin(ginCtx)
return
Expand All @@ -94,3 +96,22 @@ func checkUserExist(ginCtx *gin.Context) {
"exist": exist,
})
}

type updateUserBody struct {
User user_model.User `json:"user"`
}

func updateUser(ginCtx *gin.Context) {
body := &updateUserBody{}
err := ginCtx.BindJSON(body)
if err != nil {
modules.NewInvalidParamError("body", err.Error()).AppendToGin(ginCtx)
return
}

err = user_service.UpdateUser(ginCtx, body.User)
if err != nil {
modules.NewInternalError(fmt.Sprintf("failed to update user: %v", err)).AppendToGin(ginCtx)
return
}
}
10 changes: 5 additions & 5 deletions models/user/user.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,10 @@ type Role struct {
Users []*User `gorm:"many2many:user_roles" json:"users,omitempty"`
}

func (user User) GetRolesStringArray() []string {
roles := make([]string, len(user.Roles))
for i, role := range user.Roles {
roles[i] = role.Name
func (user User) GetRolesStringSet() map[string]struct{} {
roleSet := map[string]struct{}{}
for _, role := range user.Roles {
roleSet[role.Name] = struct{}{}
}
return roles
return roleSet
}
53 changes: 37 additions & 16 deletions modules/auth/login_session.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,41 +7,62 @@ import (
"github.com/google/uuid"
)

type LoginSessionKey struct {
Account string
Id uuid.UUID
}

type LoginSessionData struct {
RoleSet map[string]struct{}
}

type LoginSession struct {
Id uuid.UUID `json:"-"`
Account string `json:"account"`
Roles []string `json:"roles"`
Key LoginSessionKey
Data LoginSessionData
}

func (data LoginSessionData) HasRoles(roles []string) bool {
for _, role := range roles {
if _, ok := data.RoleSet[role]; !ok {
return false
}
}
return true
}

func (ls LoginSession) GetJsonString() (string, error) {
lsBytes, err := json.Marshal(ls)
func (data LoginSessionData) GetJsonString() (string, error) {
bytes, err := json.Marshal(data)
if err != nil {
return "", err
}
lsString := string(lsBytes)
dataString := string(bytes)

return lsString, nil
return dataString, nil
}

func GetLoginSessionFromJsonString(lsString string) (*LoginSession, error) {
ls := &LoginSession{}
err := json.Unmarshal([]byte(lsString), ls)
func getLoginSessionDataFromJsonString(dataString string) (*LoginSessionData, error) {
data := &LoginSessionData{}
err := json.Unmarshal([]byte(dataString), data)
if err != nil {
return nil, err
}
return ls, nil
return data, nil
}

func NewLoginSession(ls LoginSession) *LoginSession {
func NewLoginSession(account string, data LoginSessionData) *LoginSession {
return &LoginSession{
Id: uuid.New(),
Account: ls.Account,
Roles: ls.Roles,
LoginSessionKey{
Account: account,
Id: uuid.New(),
},
LoginSessionData{
RoleSet: data.RoleSet,
},
}
}

func (ls LoginSession) SaveToRedis(ctx context.Context) error {
err := SetLoginSession(ctx, ls)
err := SetLoginSession(ctx, ls.Key, ls.Data)
if err != nil {
return err
}
Expand Down
53 changes: 34 additions & 19 deletions modules/auth/redis.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,56 +5,71 @@ import (
"fmt"
"time"

"github.com/google/uuid"
redisAgent "github.com/oj-lab/oj-lab-platform/modules/agent/redis"
"github.com/oj-lab/oj-lab-platform/modules/log"
"github.com/redis/go-redis/v9"
)

const loginSessionKeyFormat = "LS_%s"
const loginSessionDuration = time.Second * 30
const loginSessionKeyFormat = "LS_%s_%s" // "LS_<account>_<uuid>"
const loginSessionDuration = time.Minute * 15

func SetLoginSession(ctx context.Context, ls LoginSession) error {
func getLoginSessionRedisKey(key LoginSessionKey) string {
return fmt.Sprintf(loginSessionKeyFormat, key.Account, key.Id.String())
}

func SetLoginSession(ctx context.Context, key LoginSessionKey, data LoginSessionData) error {
redisClient := redisAgent.GetDefaultRedisClient()
key := fmt.Sprintf(loginSessionKeyFormat, ls.Id.String())
value, err := ls.GetJsonString()

value, err := data.GetJsonString()
if err != nil {
return err
}

err = redisClient.Set(ctx, key, value, loginSessionDuration).Err()
// TODO: Watch Redis JSON SET usage, currently not support atomic SETEX
err = redisClient.Set(ctx, getLoginSessionRedisKey(key), value, loginSessionDuration).Err()
if err != nil {
return err
}

return nil
}

func GetLoginSession(ctx context.Context, id uuid.UUID) (*LoginSession, error) {
func GetLoginSession(ctx context.Context, key LoginSessionKey) (*LoginSession, error) {
redisClient := redisAgent.GetDefaultRedisClient()
lsIdString := id.String()
key := fmt.Sprintf(loginSessionKeyFormat, lsIdString)

val, err := redisClient.Get(ctx, key).Result()
val, err := redisClient.Get(ctx, getLoginSessionRedisKey(key)).Result()
if err != nil {
return nil, err
}

ls, err := GetLoginSessionFromJsonString(val)
data, err := getLoginSessionDataFromJsonString(val)
if err != nil {
return nil, err
}
ls.Id = id

return ls, nil
return &LoginSession{
Key: key,
Data: *data,
}, nil
}

func UpdateLoginSession(ctx context.Context, idString, sesionString string) error {
func UpdateLoginSessionByAccount(ctx context.Context, account string, data LoginSessionData) error {
redisClient := redisAgent.GetDefaultRedisClient()
key := fmt.Sprintf(loginSessionKeyFormat, idString)

err := redisClient.Set(ctx, key, sesionString, loginSessionDuration).Err()
redisKeys, err := redisClient.Keys(ctx, fmt.Sprintf(loginSessionKeyFormat, account, "*")).Result()
if err != nil {
return err
}

val, err := data.GetJsonString()
if err != nil {
return err
}
for _, redisKey := range redisKeys {
// TODO: KeepTTL only works in redis v6+
err = redisClient.Set(ctx, redisKey, val, redis.KeepTTL).Err()
if err != nil {
log.AppLogger().Errorf("failed to update login session: %v", err)
}
}

return nil
}
2 changes: 1 addition & 1 deletion modules/error.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,6 @@ func NewInvalidParamError(param string, hints ...string) *SeviceError {

return &SeviceError{
Code: 400,
Msg: fmt.Sprintf("invalid param: %s", param),
Msg: msg,
}
}
78 changes: 53 additions & 25 deletions modules/middleware/login_session.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package middleware

import (
"fmt"
"time"

"github.com/gin-gonic/gin"
Expand All @@ -10,46 +11,73 @@ import (
)

const (
loginSessionCookieMaxAge = time.Hour * 24 * 7
loginSessionIdCookieName = "LS_ID"
loginSessionGinCtxKey = "login_session"
loginSessionCookieMaxAge = time.Hour * 24 * 7
loginSessionKeyIdCookieName = "LS_KEY_ID"
loginSessionKeyAccountCookieName = "LS_KEY_ACCOUNT"
loginSessionGinCtxKey = "login_session"
)

func HandleRequireLogin(ginCtx *gin.Context) {
cookieValue, err := ginCtx.Cookie(loginSessionIdCookieName)
if err != nil {
modules.NewUnauthorizedError("login session not found").AppendToGin(ginCtx)
ginCtx.Abort()
return
}
lsId, err := uuid.Parse(cookieValue)
if err != nil {
modules.NewUnauthorizedError("invalid login session id").AppendToGin(ginCtx)
ginCtx.Abort()
return
func BuildHandleRequireLoginWithRoles(roles []string) gin.HandlerFunc {
return func(ginCtx *gin.Context) {
ls, err := GetLoginSessionFromGinCtx(ginCtx)
if err != nil {
modules.NewUnauthorizedError("cannot load login session from cookie").AppendToGin(ginCtx)
ginCtx.Abort()
return
}
ginCtx.Set(loginSessionGinCtxKey, ls)

if !ls.Data.HasRoles(roles) {
modules.NewUnauthorizedError(fmt.Sprintf("require roles: %v", roles)).AppendToGin(ginCtx)
ginCtx.Abort()
return
}

ginCtx.Next()
}
}

ls, err := auth.GetLoginSession(ginCtx, lsId)
func HandleRequireLogin(ginCtx *gin.Context) {
ls, err := GetLoginSessionFromGinCtx(ginCtx)
if err != nil {
modules.NewUnauthorizedError("invalid login session").AppendToGin(ginCtx)
modules.NewUnauthorizedError("cannot load login session from cookie").AppendToGin(ginCtx)
ginCtx.Abort()
return
}

ginCtx.Set(loginSessionGinCtxKey, ls)

ginCtx.Next()
}

func GetLoginSession(ginCtx *gin.Context) *auth.LoginSession {
ls, exist := ginCtx.Get(loginSessionGinCtxKey)
if !exist {
return nil
func GetLoginSessionFromGinCtx(ginCtx *gin.Context) (*auth.LoginSession, error) {
lsAccount, err := ginCtx.Cookie(loginSessionKeyAccountCookieName)
if err != nil {
return nil, err
}
lsIdString, err := ginCtx.Cookie(loginSessionKeyIdCookieName)
if err != nil {
return nil, err
}
lsId, err := uuid.Parse(lsIdString)
if err != nil {
return nil, err
}
key := auth.LoginSessionKey{
Account: lsAccount,
Id: lsId,
}

ls, err := auth.GetLoginSession(ginCtx, key)
if err != nil {
return nil, err
}
return ls.(*auth.LoginSession)

return ls, nil
}

func SetLoginSessionCookie(ginCtx *gin.Context, lsId string) {
ginCtx.SetCookie(loginSessionIdCookieName, lsId,
func SetLoginSessionKeyCookie(ginCtx *gin.Context, key auth.LoginSessionKey) {
ginCtx.SetCookie(loginSessionKeyAccountCookieName, key.Account,
int(loginSessionCookieMaxAge.Seconds()), "/", "", false, true)
ginCtx.SetCookie(loginSessionKeyIdCookieName, key.Id.String(),
int(loginSessionCookieMaxAge.Seconds()), "/", "", false, true)
}
Loading

0 comments on commit 3e56ec2

Please sign in to comment.