From 362ed969f02ebd8f99ff10d60886973b4e010945 Mon Sep 17 00:00:00 2001 From: Eraxyso <130852025+Eraxyso@users.noreply.github.com> Date: Thu, 19 Dec 2024 07:46:49 +0000 Subject: [PATCH] fix: update the wire configuration and move middleware from the handler to the controller --- {handler => controller}/middleware.go | 88 +++++++++++++++------------ handler/handler.go | 3 + handler/questionnaire.go | 12 ++-- handler/response.go | 6 +- main.go | 21 +++---- wire.go | 39 +----------- wire_gen.go | 10 +-- 7 files changed, 77 insertions(+), 102 deletions(-) rename {handler => controller}/middleware.go (82%) diff --git a/handler/middleware.go b/controller/middleware.go similarity index 82% rename from handler/middleware.go rename to controller/middleware.go index b7e0167f..7f10d008 100644 --- a/handler/middleware.go +++ b/controller/middleware.go @@ -1,4 +1,4 @@ -package handler +package controller import ( "errors" @@ -21,8 +21,18 @@ type Middleware struct { } // NewMiddleware Middlewareのコンストラクタ -func NewMiddleware() *Middleware { - return &Middleware{} +func NewMiddleware( + administrator model.IAdministrator, + respondent model.IRespondent, + question model.IQuestion, + questionnaire model.IQuestionnaire, +) *Middleware { + return &Middleware{ + IAdministrator: administrator, + IRespondent: respondent, + IQuestion: question, + IQuestionnaire: questionnaire, + } } const ( @@ -41,7 +51,7 @@ const ( var adminUserIDs = []string{"ryoha", "xxarupakaxx", "kaitoyama", "cp20", "itzmeowww"} // SetUserIDMiddleware X-Showcase-UserからユーザーIDを取得しセットする -func (*Middleware) SetUserIDMiddleware(next echo.HandlerFunc) echo.HandlerFunc { +func (m Middleware) SetUserIDMiddleware(next echo.HandlerFunc) echo.HandlerFunc { return func(c echo.Context) error { userID := c.Request().Header.Get("X-Showcase-User") if userID == "" { @@ -55,9 +65,9 @@ func (*Middleware) SetUserIDMiddleware(next echo.HandlerFunc) echo.HandlerFunc { } // TraPMemberAuthenticate traP部員かの認証 -func (*Middleware) TraPMemberAuthenticate(next echo.HandlerFunc) echo.HandlerFunc { +func (m Middleware) TraPMemberAuthenticate(next echo.HandlerFunc) echo.HandlerFunc { return func(c echo.Context) error { - userID, err := getUserID(c) + userID, err := m.GetUserID(c) if err != nil { c.Logger().Errorf("failed to get userID: %+v", err) return echo.NewHTTPError(http.StatusInternalServerError, fmt.Errorf("failed to get userID: %w", err)) @@ -74,11 +84,11 @@ func (*Middleware) TraPMemberAuthenticate(next echo.HandlerFunc) echo.HandlerFun } // TrapRateLimitMiddlewareFunc traP IDベースのリクエスト制限 -func (*Middleware) TrapRateLimitMiddlewareFunc() echo.MiddlewareFunc { +func (m Middleware) TrapRateLimitMiddlewareFunc() echo.MiddlewareFunc { config := middleware.RateLimiterConfig{ Store: middleware.NewRateLimiterMemoryStore(5), IdentifierExtractor: func(c echo.Context) (string, error) { - userID, err := getUserID(c) + userID, err := m.GetUserID(c) if err != nil { c.Logger().Errorf("failed to get userID: %+v", err) return "", echo.NewHTTPError(http.StatusInternalServerError, fmt.Errorf("failed to get userID: %w", err)) @@ -92,10 +102,10 @@ func (*Middleware) TrapRateLimitMiddlewareFunc() echo.MiddlewareFunc { } // QuestionnaireReadAuthenticate アンケートの閲覧権限があるかの認証 -func (m *Middleware) QuestionnaireReadAuthenticate(next echo.HandlerFunc) echo.HandlerFunc { +func (m Middleware) QuestionnaireReadAuthenticate(next echo.HandlerFunc) echo.HandlerFunc { return func(c echo.Context) error { - userID, err := getUserID(c) + userID, err := m.GetUserID(c) if err != nil { c.Logger().Errorf("failed to get userID: %+v", err) return echo.NewHTTPError(http.StatusInternalServerError, fmt.Errorf("failed to get userID: %w", err)) @@ -116,7 +126,7 @@ func (m *Middleware) QuestionnaireReadAuthenticate(next echo.HandlerFunc) echo.H return next(c) } } - isAdmin, err := m.CheckQuestionnaireAdmin(c.Request().Context(), userID, questionnaireID) + isAdmin, err := m.IAdministrator.CheckQuestionnaireAdmin(c.Request().Context(), userID, questionnaireID) if err != nil { c.Logger().Errorf("failed to check questionnaire admin: %+v", err) return echo.NewHTTPError(http.StatusInternalServerError, fmt.Errorf("failed to check if you are administrator: %w", err)) @@ -127,7 +137,7 @@ func (m *Middleware) QuestionnaireReadAuthenticate(next echo.HandlerFunc) echo.H } // 公開されたらOK - questionnaire, _, _, _, _, _, err := m.GetQuestionnaireInfo(c.Request().Context(), questionnaireID) + questionnaire, _, _, _, _, _, err := m.IQuestionnaire.GetQuestionnaireInfo(c.Request().Context(), questionnaireID) if errors.Is(err, model.ErrRecordNotFound) { c.Logger().Infof("questionnaire not found: %+v", err) return echo.NewHTTPError(http.StatusNotFound, fmt.Errorf("questionnaire not found:%d", questionnaireID)) @@ -147,10 +157,10 @@ func (m *Middleware) QuestionnaireReadAuthenticate(next echo.HandlerFunc) echo.H } // QuestionnaireAdministratorAuthenticate アンケートの管理者かどうかの認証 -func (m *Middleware) QuestionnaireAdministratorAuthenticate(next echo.HandlerFunc) echo.HandlerFunc { +func (m Middleware) QuestionnaireAdministratorAuthenticate(next echo.HandlerFunc) echo.HandlerFunc { return func(c echo.Context) error { - userID, err := getUserID(c) + userID, err := m.GetUserID(c) if err != nil { c.Logger().Errorf("failed to get userID: %+v", err) return echo.NewHTTPError(http.StatusInternalServerError, fmt.Errorf("failed to get userID: %w", err)) @@ -170,7 +180,7 @@ func (m *Middleware) QuestionnaireAdministratorAuthenticate(next echo.HandlerFun return next(c) } } - isAdmin, err := m.CheckQuestionnaireAdmin(c.Request().Context(), userID, questionnaireID) + isAdmin, err := m.IAdministrator.CheckQuestionnaireAdmin(c.Request().Context(), userID, questionnaireID) if err != nil { c.Logger().Errorf("failed to check questionnaire admin: %+v", err) return echo.NewHTTPError(http.StatusInternalServerError, fmt.Errorf("failed to check if you are administrator: %w", err)) @@ -186,10 +196,10 @@ func (m *Middleware) QuestionnaireAdministratorAuthenticate(next echo.HandlerFun } // ResponseReadAuthenticate 回答閲覧権限があるかの認証 -func (m *Middleware) ResponseReadAuthenticate(next echo.HandlerFunc) echo.HandlerFunc { +func (m Middleware) ResponseReadAuthenticate(next echo.HandlerFunc) echo.HandlerFunc { return func(c echo.Context) error { - userID, err := getUserID(c) + userID, err := m.GetUserID(c) if err != nil { c.Logger().Errorf("failed to get userID: %+v", err) return echo.NewHTTPError(http.StatusInternalServerError, fmt.Errorf("failed to get userID: %w", err)) @@ -203,7 +213,7 @@ func (m *Middleware) ResponseReadAuthenticate(next echo.HandlerFunc) echo.Handle } // 回答者ならOK - respondent, err := m.GetRespondent(c.Request().Context(), responseID) + respondent, err := m.IRespondent.GetRespondent(c.Request().Context(), responseID) if errors.Is(err, model.ErrRecordNotFound) { c.Logger().Infof("response not found: %+v", err) return echo.NewHTTPError(http.StatusNotFound, fmt.Errorf("response not found:%d", responseID)) @@ -229,7 +239,7 @@ func (m *Middleware) ResponseReadAuthenticate(next echo.HandlerFunc) echo.Handle } // アンケートごとの回答閲覧権限チェック - responseReadPrivilegeInfo, err := m.GetResponseReadPrivilegeInfoByResponseID(c.Request().Context(), userID, responseID) + responseReadPrivilegeInfo, err := m.IQuestionnaire.GetResponseReadPrivilegeInfoByResponseID(c.Request().Context(), userID, responseID) if errors.Is(err, model.ErrRecordNotFound) { c.Logger().Infof("response not found: %+v", err) return echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("invalid responseID: %d", responseID)) @@ -252,10 +262,10 @@ func (m *Middleware) ResponseReadAuthenticate(next echo.HandlerFunc) echo.Handle } // RespondentAuthenticate 回答者かどうかの認証 -func (m *Middleware) RespondentAuthenticate(next echo.HandlerFunc) echo.HandlerFunc { +func (m Middleware) RespondentAuthenticate(next echo.HandlerFunc) echo.HandlerFunc { return func(c echo.Context) error { - userID, err := getUserID(c) + userID, err := m.GetUserID(c) if err != nil { c.Logger().Errorf("failed to get userID: %+v", err) return echo.NewHTTPError(http.StatusInternalServerError, fmt.Errorf("failed to get userID: %w", err)) @@ -268,7 +278,7 @@ func (m *Middleware) RespondentAuthenticate(next echo.HandlerFunc) echo.HandlerF return echo.NewHTTPError(http.StatusBadRequest, fmt.Errorf("invalid responseID:%s(error: %w)", strResponseID, err)) } - respondent, err := m.GetRespondent(c.Request().Context(), responseID) + respondent, err := m.IRespondent.GetRespondent(c.Request().Context(), responseID) if errors.Is(err, model.ErrRecordNotFound) { c.Logger().Infof("response not found: %+v", err) return echo.NewHTTPError(http.StatusNotFound, fmt.Errorf("response not found:%d", responseID)) @@ -291,21 +301,8 @@ func (m *Middleware) RespondentAuthenticate(next echo.HandlerFunc) echo.HandlerF } } -func checkResponseReadPrivilege(responseReadPrivilegeInfo *model.ResponseReadPrivilegeInfo) (bool, error) { - switch responseReadPrivilegeInfo.ResSharedTo { - case "administrators": - return responseReadPrivilegeInfo.IsAdministrator, nil - case "respondents": - return responseReadPrivilegeInfo.IsAdministrator || responseReadPrivilegeInfo.IsRespondent, nil - case "public": - return true, nil - } - - return false, errors.New("invalid resSharedTo") -} - -// getValidator Validatorを設定する -func getValidator(c echo.Context) (*validator.Validate, error) { +// GetValidator Validatorを設定する +func (m Middleware) GetValidator(c echo.Context) (*validator.Validate, error) { rowValidate := c.Get(validatorKey) validate, ok := rowValidate.(*validator.Validate) if !ok { @@ -315,8 +312,8 @@ func getValidator(c echo.Context) (*validator.Validate, error) { return validate, nil } -// getUserID ユーザーIDを取得する -func getUserID(c echo.Context) (string, error) { +// GetUserID ユーザーIDを取得する +func (m Middleware) GetUserID(c echo.Context) (string, error) { rowUserID := c.Get(userIDKey) userID, ok := rowUserID.(string) if !ok { @@ -325,3 +322,16 @@ func getUserID(c echo.Context) (string, error) { return userID, nil } + +func checkResponseReadPrivilege(responseReadPrivilegeInfo *model.ResponseReadPrivilegeInfo) (bool, error) { + switch responseReadPrivilegeInfo.ResSharedTo { + case "administrators": + return responseReadPrivilegeInfo.IsAdministrator, nil + case "respondents": + return responseReadPrivilegeInfo.IsAdministrator || responseReadPrivilegeInfo.IsRespondent, nil + case "public": + return true, nil + } + + return false, errors.New("invalid resSharedTo") +} diff --git a/handler/handler.go b/handler/handler.go index aa6ddd82..d8fa23dc 100644 --- a/handler/handler.go +++ b/handler/handler.go @@ -5,13 +5,16 @@ import "github.com/traPtitech/anke-to/controller" type Handler struct { Questionnaire *controller.Questionnaire Response *controller.Response + Middleware *controller.Middleware } func NewHandler(questionnaire *controller.Questionnaire, response *controller.Response, + middleware *controller.Middleware, ) *Handler { return &Handler{ Questionnaire: questionnaire, Response: response, + Middleware: middleware, } } diff --git a/handler/questionnaire.go b/handler/questionnaire.go index 1719d207..c1fcf241 100644 --- a/handler/questionnaire.go +++ b/handler/questionnaire.go @@ -13,7 +13,7 @@ import ( // (GET /questionnaires) func (h Handler) GetQuestionnaires(ctx echo.Context, params openapi.GetQuestionnairesParams) error { res := openapi.QuestionnaireList{} - userID, err := getUserID(ctx) + userID, err := h.Middleware.GetUserID(ctx) if err != nil { ctx.Logger().Errorf("failed to get userID: %+v", err) return echo.NewHTTPError(http.StatusInternalServerError, fmt.Errorf("failed to get userID: %w", err)) @@ -35,7 +35,7 @@ func (h Handler) PostQuestionnaire(ctx echo.Context) error { ctx.Logger().Errorf("failed to bind request body: %+v", err) return echo.NewHTTPError(http.StatusBadRequest, fmt.Errorf("failed to bind request body: %w", err)) } - validate, err := getValidator(ctx) + validate, err := h.Middleware.GetValidator(ctx) if err != nil { ctx.Logger().Errorf("failed to get validator: %+v", err) return echo.NewHTTPError(http.StatusInternalServerError, fmt.Errorf("failed to get validator: %w", err)) @@ -48,7 +48,7 @@ func (h Handler) PostQuestionnaire(ctx echo.Context) error { } res := openapi.QuestionnaireDetail{} - userID, err := getUserID(ctx) + userID, err := h.Middleware.GetUserID(ctx) if err != nil { ctx.Logger().Errorf("failed to get userID: %+v", err) return echo.NewHTTPError(http.StatusInternalServerError, fmt.Errorf("failed to get userID: %w", err)) @@ -136,7 +136,7 @@ func (h Handler) EditQuestionnaireMyRemindStatus(ctx echo.Context, questionnaire // (GET /questionnaires/{questionnaireID}/responses) func (h Handler) GetQuestionnaireResponses(ctx echo.Context, questionnaireID openapi.QuestionnaireIDInPath, params openapi.GetQuestionnaireResponsesParams) error { - userID, err := getUserID(ctx) + userID, err := h.Middleware.GetUserID(ctx) if err != nil { ctx.Logger().Errorf("failed to get userID: %+v", err) return echo.NewHTTPError(http.StatusInternalServerError, fmt.Errorf("failed to get userID: %w", err)) @@ -153,7 +153,7 @@ func (h Handler) GetQuestionnaireResponses(ctx echo.Context, questionnaireID ope // (POST /questionnaires/{questionnaireID}/responses) func (h Handler) PostQuestionnaireResponse(ctx echo.Context, questionnaireID openapi.QuestionnaireIDInPath) error { res := openapi.Response{} - userID, err := getUserID(ctx) + userID, err := h.Middleware.GetUserID(ctx) if err != nil { ctx.Logger().Errorf("failed to get userID: %+v", err) return echo.NewHTTPError(http.StatusInternalServerError, fmt.Errorf("failed to get userID: %w", err)) @@ -164,7 +164,7 @@ func (h Handler) PostQuestionnaireResponse(ctx echo.Context, questionnaireID ope ctx.Logger().Errorf("failed to bind request body: %+v", err) return echo.NewHTTPError(http.StatusBadRequest, fmt.Errorf("failed to bind request body: %w", err)) } - validate, err := getValidator(ctx) + validate, err := h.Middleware.GetValidator(ctx) if err != nil { ctx.Logger().Errorf("failed to get validator: %+v", err) return echo.NewHTTPError(http.StatusInternalServerError, fmt.Errorf("failed to get validator: %w", err)) diff --git a/handler/response.go b/handler/response.go index 5bf97c57..46862837 100644 --- a/handler/response.go +++ b/handler/response.go @@ -11,7 +11,7 @@ import ( // (GET /responses/myResponses) func (h Handler) GetMyResponses(ctx echo.Context, params openapi.GetMyResponsesParams) error { res := openapi.ResponsesWithQuestionnaireInfo{} - userID, err := getUserID(ctx) + userID, err := h.Middleware.GetUserID(ctx) if err != nil { ctx.Logger().Errorf("failed to get userID: %+v", err) return echo.NewHTTPError(http.StatusInternalServerError, fmt.Errorf("failed to get userID: %w", err)) @@ -27,7 +27,7 @@ func (h Handler) GetMyResponses(ctx echo.Context, params openapi.GetMyResponsesP // (DELETE /responses/{responseID}) func (h Handler) DeleteResponse(ctx echo.Context, responseID openapi.ResponseIDInPath) error { - userID, err := getUserID(ctx) + userID, err := h.Middleware.GetUserID(ctx) if err != nil { ctx.Logger().Errorf("failed to get userID: %+v", err) return echo.NewHTTPError(http.StatusInternalServerError, fmt.Errorf("failed to get userID: %w", err)) @@ -62,7 +62,7 @@ func (h Handler) EditResponse(ctx echo.Context, responseID openapi.ResponseIDInP return echo.NewHTTPError(http.StatusBadRequest, fmt.Errorf("failed to bind Responses: %w", err)) } - validate, err := getValidator(ctx) + validate, err := h.Middleware.GetValidator(ctx) if err != nil { ctx.Logger().Errorf("failed to get validator: %+v", err) return echo.NewHTTPError(http.StatusInternalServerError, fmt.Errorf("failed to get validator: %w", err)) diff --git a/main.go b/main.go index 2aad8ef1..6d9d7935 100644 --- a/main.go +++ b/main.go @@ -66,24 +66,23 @@ func main() { } api := InjectAPIServer() e.Use(oapiMiddleware.OapiRequestValidator(swagger)) - e.Use(api.SetUserIDMiddleware) + e.Use(api.Middleware.SetUserIDMiddleware) e.Use(middleware.Logger()) e.Use(middleware.Recover()) mws := NewMiddlewareSwitcher() - mws.AddGroupConfig("", api.TraPMemberAuthenticate) + mws.AddGroupConfig("", api.Middleware.TraPMemberAuthenticate) - mws.AddRouteConfig("/questionnaires", http.MethodGet, api.TrapRateLimitMiddlewareFunc()) - mws.AddRouteConfig("/questionnaires/:questionnaireID", http.MethodGet, api.QuestionnaireReadAuthenticate) - mws.AddRouteConfig("/questionnaires/:questionnaireID", http.MethodPatch, api.QuestionnaireAdministratorAuthenticate) - mws.AddRouteConfig("/questionnaires/:questionnaireID", http.MethodDelete, api.QuestionnaireAdministratorAuthenticate) + mws.AddRouteConfig("/questionnaires", http.MethodGet, api.Middleware.TrapRateLimitMiddlewareFunc()) + mws.AddRouteConfig("/questionnaires/:questionnaireID", http.MethodGet, api.Middleware.QuestionnaireReadAuthenticate) + mws.AddRouteConfig("/questionnaires/:questionnaireID", http.MethodPatch, api.Middleware.QuestionnaireAdministratorAuthenticate) + mws.AddRouteConfig("/questionnaires/:questionnaireID", http.MethodDelete, api.Middleware.QuestionnaireAdministratorAuthenticate) - mws.AddRouteConfig("/responses/:responseID", http.MethodGet, api.ResponseReadAuthenticate) - mws.AddRouteConfig("/responses/:responseID", http.MethodPatch, api.RespondentAuthenticate) - mws.AddRouteConfig("/responses/:responseID", http.MethodDelete, api.RespondentAuthenticate) + mws.AddRouteConfig("/responses/:responseID", http.MethodGet, api.Middleware.ResponseReadAuthenticate) + mws.AddRouteConfig("/responses/:responseID", http.MethodPatch, api.Middleware.RespondentAuthenticate) + mws.AddRouteConfig("/responses/:responseID", http.MethodDelete, api.Middleware.RespondentAuthenticate) - handlerApi := InjectHandler() - openapi.RegisterHandlers(e, handlerApi) + openapi.RegisterHandlers(e, api) e.Use(mws.ApplyMiddlewares) e.Logger.Fatal(e.Start(port)) diff --git a/wire.go b/wire.go index c3e060c6..6ffd397b 100644 --- a/wire.go +++ b/wire.go @@ -26,11 +26,12 @@ var ( webhookBind = wire.Bind(new(traq.IWebhook), new(*traq.Webhook)) ) -func InjectHandler() *handler.Handler { +func InjectAPIServer() *handler.Handler { wire.Build( handler.NewHandler, controller.NewResponse, controller.NewQuestionnaire, + controller.NewMiddleware, model.NewAdministrator, model.NewOption, model.NewQuestionnaire, @@ -54,39 +55,5 @@ func InjectHandler() *handler.Handler { transactionBind, webhookBind, ) - - return nil -} - -func InjectAPIServer() *handler.Middleware { - wire.Build( - // handler.NewHandler, - handler.NewMiddleware, - // controller.NewResponse, - // controller.NewQuestionnaire, - // model.NewAdministrator, - // model.NewOption, - // model.NewQuestionnaire, - // model.NewQuestion, - // model.NewRespondent, - // model.NewResponse, - // model.NewScaleLabel, - // model.NewTarget, - // model.NewValidation, - // model.NewTransaction, - // traq.NewWebhook, - // administratorBind, - // optionBind, - // questionnaireBind, - // questionBind, - // respondentBind, - // responseBind, - // scaleLabelBind, - // targetBind, - // validationBind, - // transactionBind, - // webhookBind, - ) - - return nil + return &handler.Handler{} } diff --git a/wire_gen.go b/wire_gen.go index b25ae41a..7a311489 100644 --- a/wire_gen.go +++ b/wire_gen.go @@ -20,7 +20,7 @@ import ( // Injectors from wire.go: -func InjectHandler() *handler.Handler { +func InjectAPIServer() *handler.Handler { questionnaire := model.NewQuestionnaire() target := model.NewTarget() administrator := model.NewAdministrator() @@ -34,15 +34,11 @@ func InjectHandler() *handler.Handler { respondent := model.NewRespondent() response := model.NewResponse() controllerResponse := controller.NewResponse(questionnaire, respondent, response, target, question, validation, scaleLabel) - handlerHandler := handler.NewHandler(controllerQuestionnaire, controllerResponse) + middleware := controller.NewMiddleware(administrator, respondent, question, questionnaire) + handlerHandler := handler.NewHandler(controllerQuestionnaire, controllerResponse, middleware) return handlerHandler } -func InjectAPIServer() *handler.Middleware { - middleware := handler.NewMiddleware() - return middleware -} - // wire.go: var (