Skip to content

Commit

Permalink
feat: redirect /accounts/me to the correct /accounts/{id} (#43)
Browse files Browse the repository at this point in the history
  • Loading branch information
kayra1 authored Jul 18, 2024
1 parent d16dd50 commit 38959c3
Show file tree
Hide file tree
Showing 3 changed files with 130 additions and 37 deletions.
33 changes: 28 additions & 5 deletions internal/api/handlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,9 @@ func NewGoCertRouter(env *Environment) http.Handler {
apiV1Router.HandleFunc("POST /certificate_requests/{id}/certificate/reject", RejectCertificate(env))
apiV1Router.HandleFunc("DELETE /certificate_requests/{id}/certificate", DeleteCertificate(env))

apiV1Router.HandleFunc("GET /accounts/{id}", GetUserAccount(env))
apiV1Router.HandleFunc("GET /accounts", GetUserAccounts(env))
apiV1Router.HandleFunc("POST /accounts", PostUserAccount(env))
apiV1Router.HandleFunc("GET /accounts/{id}", GetUserAccount(env))
apiV1Router.HandleFunc("DELETE /accounts/{id}", DeleteUserAccount(env))
apiV1Router.HandleFunc("POST /accounts/{id}/change_password", ChangeUserAccountPassword(env))

Expand Down Expand Up @@ -300,7 +300,17 @@ func GetUserAccounts(env *Environment) http.HandlerFunc {
func GetUserAccount(env *Environment) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
id := r.PathValue("id")
userAccount, err := env.DB.RetrieveUser(id)
var userAccount certdb.User
var err error
if id == "me" {
claims, headerErr := getClaimsFromAuthorizationHeader(r.Header.Get("Authorization"), env.JWTSecret)
if headerErr != nil {
logErrorAndWriteResponse(headerErr.Error(), http.StatusUnauthorized, w)
}
userAccount, err = env.DB.RetrieveUserByUsername(claims.Username)
} else {
userAccount, err = env.DB.RetrieveUser(id)
}
if err != nil {
if errors.Is(err, certdb.ErrIdNotFound) {
logErrorAndWriteResponse(err.Error(), http.StatusNotFound, w)
Expand Down Expand Up @@ -409,6 +419,17 @@ func DeleteUserAccount(env *Environment) http.HandlerFunc {
func ChangeUserAccountPassword(env *Environment) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
id := r.PathValue("id")
if id == "me" {
claims, err := getClaimsFromAuthorizationHeader(r.Header.Get("Authorization"), env.JWTSecret)
if err != nil {
logErrorAndWriteResponse(err.Error(), http.StatusUnauthorized, w)
}
userAccount, err := env.DB.RetrieveUserByUsername(claims.Username)
if err != nil {
logErrorAndWriteResponse(err.Error(), http.StatusUnauthorized, w)
}
id = strconv.Itoa(userAccount.ID)
}
var user certdb.User
if err := json.NewDecoder(r.Body).Decode(&user); err != nil {
logErrorAndWriteResponse("Invalid JSON format", http.StatusBadRequest, w)
Expand Down Expand Up @@ -471,7 +492,7 @@ func Login(env *Environment) http.HandlerFunc {
logErrorAndWriteResponse("The username or password is incorrect. Try again.", http.StatusUnauthorized, w)
return
}
jwt, err := generateJWT(userRequest.Username, env.JWTSecret, userAccount.Permissions)
jwt, err := generateJWT(userAccount.ID, userAccount.Username, env.JWTSecret, userAccount.Permissions)
if err != nil {
logErrorAndWriteResponse(err.Error(), http.StatusInternalServerError, w)
return
Expand All @@ -489,7 +510,7 @@ func logErrorAndWriteResponse(msg string, status int, w http.ResponseWriter) {
log.Println(errMsg)
w.WriteHeader(status)
if _, err := w.Write([]byte(errMsg)); err != nil {
logErrorAndWriteResponse(err.Error(), http.StatusInternalServerError, w)
log.Printf("error writing response: %s", err.Error())
}
}

Expand Down Expand Up @@ -554,8 +575,9 @@ func validatePassword(password string) bool {
}

// Helper function to generate a JWT
func generateJWT(username string, jwtSecret []byte, permissions int) (string, error) {
func generateJWT(id int, username string, jwtSecret []byte, permissions int) (string, error) {
token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwtGocertClaims{
ID: id,
Username: username,
Permissions: permissions,
StandardClaims: jwt.StandardClaims{
Expand All @@ -571,6 +593,7 @@ func generateJWT(username string, jwtSecret []byte, permissions int) (string, er
}

type jwtGocertClaims struct {
ID int `json:"id"`
Username string `json:"username"`
Permissions int `json:"permissions"`
jwt.StandardClaims
Expand Down
18 changes: 18 additions & 0 deletions internal/api/handlers_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -737,6 +737,24 @@ func TestAuthorization(t *testing.T) {
response: "",
status: http.StatusForbidden,
},
{
desc: "user can change self password with /me",
method: "POST",
path: "/api/v1/accounts/me/change_password",
data: `{"password":"BetterPW1!"}`,
auth: nonAdminToken,
response: "",
status: http.StatusOK,
},
{
desc: "user can login with new password",
method: "POST",
path: "/login",
data: `{"username":"testuser","password":"BetterPW1!"}`,
auth: nonAdminToken,
response: "",
status: http.StatusOK,
},
{
desc: "admin can't delete itself",
method: "DELETE",
Expand Down
116 changes: 84 additions & 32 deletions internal/api/middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,11 @@ import (
"github.com/prometheus/client_golang/prometheus/promhttp"
)

const (
USER_ACCOUNT = 0
ADMIN_ACCOUNT = 1
)

type middleware func(http.Handler) http.Handler

// The middlewareContext type helps middleware receive and pass along information through the middleware chain.
Expand Down Expand Up @@ -94,14 +99,6 @@ func loggingMiddleware(ctx *middlewareContext) middleware {
// authMiddleware intercepts requests that need authorization to check if the user's token exists and is
// permitted to use the endpoint
func authMiddleware(ctx *middlewareContext) middleware {
AdminOnlyPaths := []struct{ method, path string }{
{"POST", `accounts`},
{"GET", `accounts`},
{"GET", `accounts\/\d+$`},
{"DELETE", `accounts\/\d+$`},
{"POST", `accounts\/\d+\/change_password$`},
}

return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if !strings.HasPrefix(r.URL.Path, "/api/v1/") {
Expand All @@ -115,35 +112,23 @@ func authMiddleware(ctx *middlewareContext) middleware {
}
return
}
authHeader := r.Header.Get("Authorization")
if authHeader == "" {
logErrorAndWriteResponse("authorization header not found", http.StatusUnauthorized, w)
return
}
bearerToken := strings.Split(authHeader, " ")
if len(bearerToken) != 2 || bearerToken[0] != "Bearer" {
logErrorAndWriteResponse("authorization header couldn't be processed. The expected format is 'Bearer <token>'", http.StatusUnauthorized, w)
return
}
claims, err := getClaimsFromJWT(bearerToken[1], ctx.jwtSecret)
claims, err := getClaimsFromAuthorizationHeader(r.Header.Get("Authorization"), ctx.jwtSecret)
if err != nil {
logErrorAndWriteResponse(fmt.Sprintf("token is not valid: %s", err.Error()), http.StatusUnauthorized, w)
logErrorAndWriteResponse(fmt.Sprintf("auth failed: %s", err.Error()), http.StatusUnauthorized, w)
return
}
if claims.Permissions == 0 {
for _, v := range AdminOnlyPaths {
matched, err := regexp.Match(v.path, []byte(r.URL.Path))
if err != nil {
logErrorAndWriteResponse(fmt.Sprintf("ran into issue parsing path: %s", err.Error()), http.StatusInternalServerError, w)
return
}
if r.Method == v.method && matched {
logErrorAndWriteResponse("forbidden", http.StatusForbidden, w)
return
}
if claims.Permissions == USER_ACCOUNT {
requestAllowed, err := AllowRequest(claims, r.Method, r.URL.Path)
if err != nil {
logErrorAndWriteResponse(fmt.Sprintf("error processing path: %s", err.Error()), http.StatusInternalServerError, w)
return
}
if !requestAllowed {
logErrorAndWriteResponse("forbidden", http.StatusForbidden, w)
return
}
}
if claims.Permissions == 1 && r.Method == "DELETE" && strings.HasSuffix(r.URL.Path, "accounts/1") {
if r.Method == "DELETE" && strings.HasSuffix(r.URL.Path, "accounts/1") {
logErrorAndWriteResponse("can't delete admin account", http.StatusConflict, w)
return
}
Expand All @@ -152,6 +137,73 @@ func authMiddleware(ctx *middlewareContext) middleware {
}
}

func getClaimsFromAuthorizationHeader(header string, jwtSecret []byte) (*jwtGocertClaims, error) {
if header == "" {
return nil, fmt.Errorf("authorization header not found")
}
bearerToken := strings.Split(header, " ")
if len(bearerToken) != 2 || bearerToken[0] != "Bearer" {
return nil, fmt.Errorf("authorization header couldn't be processed. The expected format is 'Bearer <token>'")
}
claims, err := getClaimsFromJWT(bearerToken[1], jwtSecret)
if err != nil {
return nil, fmt.Errorf("token is not valid: %s", err)
}
return claims, nil
}

// AllowRequest looks at the user data to determine the following things:
// The first question is "Is this user trying to access a path that's restricted?"
//
// There are two types of restricted paths: admin only paths that only admins can access, and self authorized paths,
// which users are allowed to use only if they are taking an action on their own user ID. The second question is
// "If the path requires an ID, is the user attempting to access their own ID?"
//
// For all endpoints and permission permutations, there are only 2 cases when users are allowed to use endpoints:
// If the URL path is not restricted to admins
// If the URL path is restricted to self authorized endpoints, and the user is taking action with their own ID
// This function validates that the user the with the given claims is allowed to use the endpoints by passing the above checks.
func AllowRequest(claims *jwtGocertClaims, method, path string) (bool, error) {
restrictedPaths := []struct {
method, pathRegex string
SelfAuthorizedAllowed bool
}{
{"POST", `accounts$`, false},
{"GET", `accounts$`, false},
{"DELETE", `accounts\/(\d+)$`, false},
{"GET", `accounts\/(\d+)$`, true},
{"POST", `accounts\/(\d+)\/change_password$`, true},
}
for _, pr := range restrictedPaths {
regexChallenge, err := regexp.Compile(pr.pathRegex)
if err != nil {
return false, fmt.Errorf("regex couldn't compile: %s", err)
}
matches := regexChallenge.FindStringSubmatch(path)
restrictedPathMatchedToRequestedPath := len(matches) > 0 && method == pr.method
if !restrictedPathMatchedToRequestedPath {
continue
}
if !pr.SelfAuthorizedAllowed {
return false, nil
}
matchedID, err := strconv.Atoi(matches[1])
if err != nil {
return true, fmt.Errorf("error converting url id to string: %s", err)
}
var requestedIDMatchesTheClaimant bool
if matchedID == claims.ID {
requestedIDMatchesTheClaimant = true
}
IDRequiredForPath := len(matches) > 1
if IDRequiredForPath && !requestedIDMatchesTheClaimant {
return false, nil
}
return true, nil
}
return true, nil
}

func getClaimsFromJWT(bearerToken string, jwtSecret []byte) (*jwtGocertClaims, error) {
claims := jwtGocertClaims{}
token, err := jwt.ParseWithClaims(bearerToken, &claims, func(token *jwt.Token) (interface{}, error) {
Expand Down

0 comments on commit 38959c3

Please sign in to comment.