diff --git a/handler/middleware.go b/handler/middleware.go index d40ec3af..ad6d94a3 100644 --- a/handler/middleware.go +++ b/handler/middleware.go @@ -3,11 +3,28 @@ package handler import ( "errors" "fmt" + "net/http" + "strconv" "github.com/go-playground/validator/v10" "github.com/labstack/echo/v4" + "github.com/labstack/echo/v4/middleware" + "github.com/traPtitech/anke-to/model" ) +// Middleware Middlewareの構造体 +type Middleware struct { + model.IAdministrator + model.IRespondent + model.IQuestion + model.IQuestionnaire +} + +// NewMiddleware Middlewareのコンストラクタ +func NewMiddleware() *Middleware { + return &Middleware{} +} + const ( validatorKey = "validator" userIDKey = "userID" @@ -16,6 +33,13 @@ const ( questionIDKey = "questionID" ) +/* + 消せないアンケートの発生を防ぐための管理者 + +暫定的にハードコーディングで対応 +*/ +var adminUserIDs = []string{"ryoha", "xxarupakaxx", "kaitoyama", "cp20", "itzmeowww"} + // SetUserIDMiddleware X-Showcase-UserからユーザーIDを取得しセットする func SetUserIDMiddleware(next echo.HandlerFunc) echo.HandlerFunc { return func(c echo.Context) error { @@ -30,6 +54,204 @@ func SetUserIDMiddleware(next echo.HandlerFunc) echo.HandlerFunc { } } +// TraPMemberAuthenticate traP部員かの認証 +func TraPMemberAuthenticate(next echo.HandlerFunc) echo.HandlerFunc { + return func(c echo.Context) error { + userID, err := getUserID(c) + if err != nil { + c.Logger().Errorf("failed to get userID: %+v", err) + return echo.NewHTTPError(http.StatusInternalServerError, fmt.Errorf("failed to get userID: %w", err)) + } + + // トークンを持たないユーザはアクセスできない + if userID == "-" { + c.Logger().Info("not logged in") + return echo.NewHTTPError(http.StatusUnauthorized, "You are not logged in") + } + + return next(c) + } +} + +// TrapRateLimitMiddlewareFunc traP IDベースのリクエスト制限 +func TrapRateLimitMiddlewareFunc() echo.MiddlewareFunc { + config := middleware.RateLimiterConfig{ + Store: middleware.NewRateLimiterMemoryStore(5), + IdentifierExtractor: func(c echo.Context) (string, error) { + userID, err := getUserID(c) + if err != nil { + c.Logger().Errorf("failed to get userID: %+v", err) + return "", echo.NewHTTPError(http.StatusInternalServerError, fmt.Errorf("failed to get userID: %w", err)) + } + + return userID, nil + }, + } + + return middleware.RateLimiterWithConfig(config) +} + +// QuestionnaireAdministratorAuthenticate アンケートの管理者かどうかの認証 +func QuestionnaireAdministratorAuthenticate(next echo.HandlerFunc) echo.HandlerFunc { + return func(c echo.Context) error { + m := NewMiddleware() + + userID, err := getUserID(c) + if err != nil { + c.Logger().Errorf("failed to get userID: %+v", err) + return echo.NewHTTPError(http.StatusInternalServerError, fmt.Errorf("failed to get userID: %w", err)) + } + + strQuestionnaireID := c.Param("questionnaireID") + questionnaireID, err := strconv.Atoi(strQuestionnaireID) + if err != nil { + c.Logger().Infof("failed to convert questionnaireID to int: %+v", err) + return echo.NewHTTPError(http.StatusBadRequest, fmt.Errorf("invalid questionnaireID:%s(error: %w)", strQuestionnaireID, err)) + } + + for _, adminID := range adminUserIDs { + if userID == adminID { + c.Set(questionnaireIDKey, questionnaireID) + + return next(c) + } + } + isAdmin, err := m.CheckQuestionnaireAdmin(c.Request().Context(), userID, questionnaireID) + if err != nil { + c.Logger().Errorf("failed to check questionnaire admin: %+v", err) + return echo.NewHTTPError(http.StatusInternalServerError, fmt.Errorf("failed to check if you are administrator: %w", err)) + } + if !isAdmin { + return c.String(http.StatusForbidden, "You are not a administrator of this questionnaire.") + } + + c.Set(questionnaireIDKey, questionnaireID) + + return next(c) + } +} + +// ResponseReadAuthenticate 回答閲覧権限があるかの認証 +func ResponseReadAuthenticate(next echo.HandlerFunc) echo.HandlerFunc { + return func(c echo.Context) error { + m := NewMiddleware() + + userID, err := getUserID(c) + if err != nil { + c.Logger().Errorf("failed to get userID: %+v", err) + return echo.NewHTTPError(http.StatusInternalServerError, fmt.Errorf("failed to get userID: %w", err)) + } + + strResponseID := c.Param("responseID") + responseID, err := strconv.Atoi(strResponseID) + if err != nil { + c.Logger().Info("failed to convert responseID to int: %+v", err) + return echo.NewHTTPError(http.StatusBadRequest, fmt.Errorf("invalid responseID:%s(error: %w)", strResponseID, err)) + } + + // 回答者ならOK + respondent, err := m.GetRespondent(c.Request().Context(), responseID) + if errors.Is(err, model.ErrRecordNotFound) { + c.Logger().Infof("response not found: %+v", err) + return echo.NewHTTPError(http.StatusNotFound, fmt.Errorf("response not found:%d", responseID)) + } + if err != nil { + c.Logger().Errorf("failed to check if you are a respondent: %+v", err) + return echo.NewHTTPError(http.StatusInternalServerError, fmt.Errorf("failed to check if you are a respondent: %w", err)) + } + if respondent == nil { + c.Logger().Error("respondent is nil") + return echo.NewHTTPError(http.StatusInternalServerError) + } + if respondent.UserTraqid == userID { + return next(c) + } + + // 回答者以外は一時保存の回答は閲覧できない + if !respondent.SubmittedAt.Valid { + c.Logger().Info("not submitted") + + // Note: 一時保存の回答の存在もわかってはいけないので、Respondentが見つからない時と全く同じエラーを返す + return echo.NewHTTPError(http.StatusNotFound, fmt.Errorf("response not found:%d", responseID)) + } + + // アンケートごとの回答閲覧権限チェック + responseReadPrivilegeInfo, err := m.GetResponseReadPrivilegeInfoByResponseID(c.Request().Context(), userID, responseID) + if errors.Is(err, model.ErrRecordNotFound) { + c.Logger().Infof("response not found: %+v", err) + return echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("invalid responseID: %d", responseID)) + } else if err != nil { + c.Logger().Errorf("failed to get response read privilege info: %+v", err) + return echo.NewHTTPError(http.StatusInternalServerError, fmt.Errorf("failed to get response read privilege info: %w", err)) + } + + haveReadPrivilege, err := checkResponseReadPrivilege(responseReadPrivilegeInfo) + if err != nil { + c.Logger().Errorf("failed to check response read privilege: %+v", err) + return echo.NewHTTPError(http.StatusInternalServerError, fmt.Errorf("failed to check response read privilege: %w", err)) + } + if !haveReadPrivilege { + return c.String(http.StatusForbidden, "You do not have permission to view this response.") + } + + return next(c) + } +} + +// RespondentAuthenticate 回答者かどうかの認証 +func RespondentAuthenticate(next echo.HandlerFunc) echo.HandlerFunc { + return func(c echo.Context) error { + m := NewMiddleware() + + userID, err := getUserID(c) + if err != nil { + c.Logger().Errorf("failed to get userID: %+v", err) + return echo.NewHTTPError(http.StatusInternalServerError, fmt.Errorf("failed to get userID: %w", err)) + } + + strResponseID := c.Param("responseID") + responseID, err := strconv.Atoi(strResponseID) + if err != nil { + c.Logger().Infof("failed to convert responseID to int: %+v", err) + return echo.NewHTTPError(http.StatusBadRequest, fmt.Errorf("invalid responseID:%s(error: %w)", strResponseID, err)) + } + + respondent, err := m.GetRespondent(c.Request().Context(), responseID) + if errors.Is(err, model.ErrRecordNotFound) { + c.Logger().Infof("response not found: %+v", err) + return echo.NewHTTPError(http.StatusNotFound, fmt.Errorf("response not found:%d", responseID)) + } + if err != nil { + c.Logger().Errorf("failed to check if you are a respondent: %+v", err) + return echo.NewHTTPError(http.StatusInternalServerError, fmt.Errorf("failed to check if you are a respondent: %w", err)) + } + if respondent == nil { + c.Logger().Error("respondent is nil") + return echo.NewHTTPError(http.StatusInternalServerError) + } + if respondent.UserTraqid != userID { + return c.String(http.StatusForbidden, "You are not a respondent of this response.") + } + + c.Set(responseIDKey, responseID) + + return next(c) + } +} + +func checkResponseReadPrivilege(responseReadPrivilegeInfo *model.ResponseReadPrivilegeInfo) (bool, error) { + switch responseReadPrivilegeInfo.ResSharedTo { + case "administrators": + return responseReadPrivilegeInfo.IsAdministrator, nil + case "respondents": + return responseReadPrivilegeInfo.IsAdministrator || responseReadPrivilegeInfo.IsRespondent, nil + case "public": + return true, nil + } + + return false, errors.New("invalid resSharedTo") +} + // getValidator Validatorを設定する func getValidator(c echo.Context) (*validator.Validate, error) { rowValidate := c.Get(validatorKey) diff --git a/main.go b/main.go index fe08cd25..4f6c5754 100644 --- a/main.go +++ b/main.go @@ -66,7 +66,21 @@ func main() { e.Use(handler.SetUserIDMiddleware) e.Use(middleware.Logger()) e.Use(middleware.Recover()) + + mws := NewMiddlewareSwitcher() + mws.AddGroupConfig("", handler.TraPMemberAuthenticate) + + mws.AddRouteConfig("/questionnaires", http.MethodGet, handler.TrapRateLimitMiddlewareFunc()) + mws.AddRouteConfig("/questionnaires/:questionnaireID", http.MethodPatch, handler.QuestionnaireAdministratorAuthenticate) + mws.AddRouteConfig("/questionnaires/:questionnaireID", http.MethodDelete, handler.QuestionnaireAdministratorAuthenticate) + + mws.AddRouteConfig("/responses/:responseID", http.MethodGet, handler.ResponseReadAuthenticate) + mws.AddRouteConfig("/responses/:responseID", http.MethodPatch, handler.RespondentAuthenticate) + mws.AddRouteConfig("/responses/:responseID", http.MethodDelete, handler.RespondentAuthenticate) + openapi.RegisterHandlers(e, handler.Handler{}) + + e.Use(mws.ApplyMiddlewares) e.Logger.Fatal(e.Start(port)) // SetRouting(port) diff --git a/middleware.go b/middleware.go new file mode 100644 index 00000000..5a7ba3cf --- /dev/null +++ b/middleware.go @@ -0,0 +1,80 @@ +package main + +import ( + "strings" + + "github.com/labstack/echo/v4" +) + +type RouteConfig struct { + path string + method string + middlewares []echo.MiddlewareFunc + isGroup bool +} + +type MiddlewareSwitcher struct { + routeConfigs []RouteConfig +} + +func NewMiddlewareSwitcher() *MiddlewareSwitcher { + return &MiddlewareSwitcher{ + routeConfigs: []RouteConfig{}, + } +} + +func (m *MiddlewareSwitcher) AddGroupConfig(grouppath string, middlewares ...echo.MiddlewareFunc) { + m.routeConfigs = append(m.routeConfigs, RouteConfig{ + path: grouppath, + middlewares: middlewares, + isGroup: true, + }) +} + +func (m *MiddlewareSwitcher) AddRouteConfig(path string, method string, middlewares ...echo.MiddlewareFunc) { + m.routeConfigs = append(m.routeConfigs, RouteConfig{ + path: path, + method: method, + middlewares: middlewares, + isGroup: false, + }) +} + +func (m *MiddlewareSwitcher) IsWithinGroup(groupPath string, path string) bool { + if !strings.HasPrefix(path, groupPath) { + return false + } + return len(groupPath) == len(path) || path[len(groupPath)] == '/' +} + +func (m *MiddlewareSwitcher) FindMiddlewares(path string, method string) []echo.MiddlewareFunc { + var matchedMiddlewares []echo.MiddlewareFunc + + for _, config := range m.routeConfigs { + if config.isGroup && m.IsWithinGroup(config.path, path) { + matchedMiddlewares = append(matchedMiddlewares, config.middlewares...) + } + if !config.isGroup && config.path == path && config.method == method { + matchedMiddlewares = append(matchedMiddlewares, config.middlewares...) + } + } + + return matchedMiddlewares +} + +func (m *MiddlewareSwitcher) ApplyMiddlewares(next echo.HandlerFunc) echo.HandlerFunc { + return func(c echo.Context) error { + path := c.Path() + method := c.Request().Method + + middlewares := m.FindMiddlewares(path, method) + + for _, mw := range middlewares { + if err := mw(next)(c); err != nil { + return err + } + } + + return next(c) + } +}