Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support update user and with session #78

Merged
merged 1 commit into from
May 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 27 additions & 6 deletions cmd/web_server/handler/user.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"net/http"

"github.com/gin-gonic/gin"
user_model "github.com/oj-lab/oj-lab-platform/models/user"
"github.com/oj-lab/oj-lab-platform/modules"
"github.com/oj-lab/oj-lab-platform/modules/middleware"
user_service "github.com/oj-lab/oj-lab-platform/services/user"
Expand All @@ -13,6 +14,7 @@ import (
func SetupUserRouter(baseRoute *gin.RouterGroup) {
g := baseRoute.Group("/user")
{
g.PUT("", updateUser)
g.GET("/health", func(ginCtx *gin.Context) {
ginCtx.String(http.StatusOK, "Hello, this is user service")
})
Expand Down Expand Up @@ -44,12 +46,12 @@ func login(ginCtx *gin.Context) {
return
}

lsId, err := user_service.StartLoginSession(ginCtx, body.Account, body.Password)
ls, err := user_service.StartLoginSession(ginCtx, body.Account, body.Password)
if err != nil {
modules.NewInternalError(fmt.Sprintf("failed to login: %v", err)).AppendToGin(ginCtx)
return
}
middleware.SetLoginSessionCookie(ginCtx, lsId.String())
middleware.SetLoginSessionKeyCookie(ginCtx, ls.Key)

ginCtx.Status(http.StatusOK)
}
Expand All @@ -63,12 +65,12 @@ func login(ginCtx *gin.Context) {
// @Success 200
// @Failure 401
func me(ginCtx *gin.Context) {
ls := middleware.GetLoginSession(ginCtx)
if ls == nil {
modules.NewUnauthorizedError("not logined").AppendToGin(ginCtx)
ls, err := middleware.GetLoginSessionFromGinCtx(ginCtx)
if err != nil {
modules.NewUnauthorizedError("cannot load login session from cookie").AppendToGin(ginCtx)
return
}
user, err := user_service.GetUser(ginCtx, ls.Account)
user, err := user_service.GetUser(ginCtx, ls.Key.Account)
if err != nil {
modules.NewInternalError(fmt.Sprintf("failed to get user: %v", err)).AppendToGin(ginCtx)
return
Expand All @@ -94,3 +96,22 @@ func checkUserExist(ginCtx *gin.Context) {
"exist": exist,
})
}

type updateUserBody struct {
User user_model.User `json:"user"`
}

func updateUser(ginCtx *gin.Context) {
body := &updateUserBody{}
err := ginCtx.BindJSON(body)
if err != nil {
modules.NewInvalidParamError("body", err.Error()).AppendToGin(ginCtx)
return
}

err = user_service.UpdateUser(ginCtx, body.User)
if err != nil {
modules.NewInternalError(fmt.Sprintf("failed to update user: %v", err)).AppendToGin(ginCtx)
return
}
}
10 changes: 5 additions & 5 deletions models/user/user.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,10 @@ type Role struct {
Users []*User `gorm:"many2many:user_roles" json:"users,omitempty"`
}

func (user User) GetRolesStringArray() []string {
roles := make([]string, len(user.Roles))
for i, role := range user.Roles {
roles[i] = role.Name
func (user User) GetRolesStringSet() map[string]struct{} {
roleSet := map[string]struct{}{}
for _, role := range user.Roles {
roleSet[role.Name] = struct{}{}
}
return roles
return roleSet
}
53 changes: 37 additions & 16 deletions modules/auth/login_session.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,41 +7,62 @@ import (
"github.com/google/uuid"
)

type LoginSessionKey struct {
Account string
Id uuid.UUID
}

type LoginSessionData struct {
RoleSet map[string]struct{}
}

type LoginSession struct {
Id uuid.UUID `json:"-"`
Account string `json:"account"`
Roles []string `json:"roles"`
Key LoginSessionKey
Data LoginSessionData
}

func (data LoginSessionData) HasRoles(roles []string) bool {
for _, role := range roles {
if _, ok := data.RoleSet[role]; !ok {
return false
}
}
return true
}

func (ls LoginSession) GetJsonString() (string, error) {
lsBytes, err := json.Marshal(ls)
func (data LoginSessionData) GetJsonString() (string, error) {
bytes, err := json.Marshal(data)
if err != nil {
return "", err
}
lsString := string(lsBytes)
dataString := string(bytes)

return lsString, nil
return dataString, nil
}

func GetLoginSessionFromJsonString(lsString string) (*LoginSession, error) {
ls := &LoginSession{}
err := json.Unmarshal([]byte(lsString), ls)
func getLoginSessionDataFromJsonString(dataString string) (*LoginSessionData, error) {
data := &LoginSessionData{}
err := json.Unmarshal([]byte(dataString), data)
if err != nil {
return nil, err
}
return ls, nil
return data, nil
}

func NewLoginSession(ls LoginSession) *LoginSession {
func NewLoginSession(account string, data LoginSessionData) *LoginSession {
return &LoginSession{
Id: uuid.New(),
Account: ls.Account,
Roles: ls.Roles,
LoginSessionKey{
Account: account,
Id: uuid.New(),
},
LoginSessionData{
RoleSet: data.RoleSet,
},
}
}

func (ls LoginSession) SaveToRedis(ctx context.Context) error {
err := SetLoginSession(ctx, ls)
err := SetLoginSession(ctx, ls.Key, ls.Data)
if err != nil {
return err
}
Expand Down
53 changes: 34 additions & 19 deletions modules/auth/redis.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,56 +5,71 @@ import (
"fmt"
"time"

"github.com/google/uuid"
redisAgent "github.com/oj-lab/oj-lab-platform/modules/agent/redis"
"github.com/oj-lab/oj-lab-platform/modules/log"
"github.com/redis/go-redis/v9"
)

const loginSessionKeyFormat = "LS_%s"
const loginSessionDuration = time.Second * 30
const loginSessionKeyFormat = "LS_%s_%s" // "LS_<account>_<uuid>"
const loginSessionDuration = time.Minute * 15

func SetLoginSession(ctx context.Context, ls LoginSession) error {
func getLoginSessionRedisKey(key LoginSessionKey) string {
return fmt.Sprintf(loginSessionKeyFormat, key.Account, key.Id.String())
}

func SetLoginSession(ctx context.Context, key LoginSessionKey, data LoginSessionData) error {
redisClient := redisAgent.GetDefaultRedisClient()
key := fmt.Sprintf(loginSessionKeyFormat, ls.Id.String())
value, err := ls.GetJsonString()

value, err := data.GetJsonString()
if err != nil {
return err
}

err = redisClient.Set(ctx, key, value, loginSessionDuration).Err()
// TODO: Watch Redis JSON SET usage, currently not support atomic SETEX
err = redisClient.Set(ctx, getLoginSessionRedisKey(key), value, loginSessionDuration).Err()
if err != nil {
return err
}

return nil
}

func GetLoginSession(ctx context.Context, id uuid.UUID) (*LoginSession, error) {
func GetLoginSession(ctx context.Context, key LoginSessionKey) (*LoginSession, error) {
redisClient := redisAgent.GetDefaultRedisClient()
lsIdString := id.String()
key := fmt.Sprintf(loginSessionKeyFormat, lsIdString)

val, err := redisClient.Get(ctx, key).Result()
val, err := redisClient.Get(ctx, getLoginSessionRedisKey(key)).Result()
if err != nil {
return nil, err
}

ls, err := GetLoginSessionFromJsonString(val)
data, err := getLoginSessionDataFromJsonString(val)
if err != nil {
return nil, err
}
ls.Id = id

return ls, nil
return &LoginSession{
Key: key,
Data: *data,
}, nil
}

func UpdateLoginSession(ctx context.Context, idString, sesionString string) error {
func UpdateLoginSessionByAccount(ctx context.Context, account string, data LoginSessionData) error {
redisClient := redisAgent.GetDefaultRedisClient()
key := fmt.Sprintf(loginSessionKeyFormat, idString)

err := redisClient.Set(ctx, key, sesionString, loginSessionDuration).Err()
redisKeys, err := redisClient.Keys(ctx, fmt.Sprintf(loginSessionKeyFormat, account, "*")).Result()
if err != nil {
return err
}

val, err := data.GetJsonString()
if err != nil {
return err
}
for _, redisKey := range redisKeys {
// TODO: KeepTTL only works in redis v6+
err = redisClient.Set(ctx, redisKey, val, redis.KeepTTL).Err()
if err != nil {
log.AppLogger().Errorf("failed to update login session: %v", err)
}
}

return nil
}
2 changes: 1 addition & 1 deletion modules/error.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,6 @@ func NewInvalidParamError(param string, hints ...string) *SeviceError {

return &SeviceError{
Code: 400,
Msg: fmt.Sprintf("invalid param: %s", param),
Msg: msg,
}
}
78 changes: 53 additions & 25 deletions modules/middleware/login_session.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package middleware

import (
"fmt"
"time"

"github.com/gin-gonic/gin"
Expand All @@ -10,46 +11,73 @@ import (
)

const (
loginSessionCookieMaxAge = time.Hour * 24 * 7
loginSessionIdCookieName = "LS_ID"
loginSessionGinCtxKey = "login_session"
loginSessionCookieMaxAge = time.Hour * 24 * 7
loginSessionKeyIdCookieName = "LS_KEY_ID"
loginSessionKeyAccountCookieName = "LS_KEY_ACCOUNT"
loginSessionGinCtxKey = "login_session"
)

func HandleRequireLogin(ginCtx *gin.Context) {
cookieValue, err := ginCtx.Cookie(loginSessionIdCookieName)
if err != nil {
modules.NewUnauthorizedError("login session not found").AppendToGin(ginCtx)
ginCtx.Abort()
return
}
lsId, err := uuid.Parse(cookieValue)
if err != nil {
modules.NewUnauthorizedError("invalid login session id").AppendToGin(ginCtx)
ginCtx.Abort()
return
func BuildHandleRequireLoginWithRoles(roles []string) gin.HandlerFunc {
return func(ginCtx *gin.Context) {
ls, err := GetLoginSessionFromGinCtx(ginCtx)
if err != nil {
modules.NewUnauthorizedError("cannot load login session from cookie").AppendToGin(ginCtx)
ginCtx.Abort()
return
}
ginCtx.Set(loginSessionGinCtxKey, ls)

if !ls.Data.HasRoles(roles) {
modules.NewUnauthorizedError(fmt.Sprintf("require roles: %v", roles)).AppendToGin(ginCtx)
ginCtx.Abort()
return
}

ginCtx.Next()
}
}

ls, err := auth.GetLoginSession(ginCtx, lsId)
func HandleRequireLogin(ginCtx *gin.Context) {
ls, err := GetLoginSessionFromGinCtx(ginCtx)
if err != nil {
modules.NewUnauthorizedError("invalid login session").AppendToGin(ginCtx)
modules.NewUnauthorizedError("cannot load login session from cookie").AppendToGin(ginCtx)
ginCtx.Abort()
return
}

ginCtx.Set(loginSessionGinCtxKey, ls)

ginCtx.Next()
}

func GetLoginSession(ginCtx *gin.Context) *auth.LoginSession {
ls, exist := ginCtx.Get(loginSessionGinCtxKey)
if !exist {
return nil
func GetLoginSessionFromGinCtx(ginCtx *gin.Context) (*auth.LoginSession, error) {
lsAccount, err := ginCtx.Cookie(loginSessionKeyAccountCookieName)
if err != nil {
return nil, err
}
lsIdString, err := ginCtx.Cookie(loginSessionKeyIdCookieName)
if err != nil {
return nil, err
}
lsId, err := uuid.Parse(lsIdString)
if err != nil {
return nil, err
}
key := auth.LoginSessionKey{
Account: lsAccount,
Id: lsId,
}

ls, err := auth.GetLoginSession(ginCtx, key)
if err != nil {
return nil, err
}
return ls.(*auth.LoginSession)

return ls, nil
}

func SetLoginSessionCookie(ginCtx *gin.Context, lsId string) {
ginCtx.SetCookie(loginSessionIdCookieName, lsId,
func SetLoginSessionKeyCookie(ginCtx *gin.Context, key auth.LoginSessionKey) {
ginCtx.SetCookie(loginSessionKeyAccountCookieName, key.Account,
int(loginSessionCookieMaxAge.Seconds()), "/", "", false, true)
ginCtx.SetCookie(loginSessionKeyIdCookieName, key.Id.String(),
int(loginSessionCookieMaxAge.Seconds()), "/", "", false, true)
}
Loading
Loading