diff --git a/backend/go.mod b/backend/go.mod index b1fb7d436..6b0923594 100644 --- a/backend/go.mod +++ b/backend/go.mod @@ -1,6 +1,6 @@ module github.com/GenerateNU/sac/backend -go 1.21.1 +go 1.21.6 require ( github.com/go-playground/validator/v10 v10.17.0 @@ -13,6 +13,11 @@ require ( gorm.io/gorm v1.25.6 ) +require ( + github.com/awnumar/memcall v0.2.0 // indirect + github.com/awnumar/memguard v0.22.4 // indirect +) + require ( github.com/KyleBanks/depth v1.2.1 // indirect github.com/PuerkitoBio/purell v1.1.1 // indirect @@ -21,6 +26,7 @@ require ( github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect github.com/fsnotify/fsnotify v1.7.0 // indirect github.com/gabriel-vasile/mimetype v1.4.2 // indirect + github.com/garrettladley/mattress v0.2.0 github.com/go-openapi/jsonpointer v0.19.5 // indirect github.com/go-openapi/jsonreference v0.19.6 // indirect github.com/go-openapi/spec v0.20.4 // indirect diff --git a/backend/go.sum b/backend/go.sum index 6441c0fc3..496556c70 100644 --- a/backend/go.sum +++ b/backend/go.sum @@ -6,6 +6,10 @@ github.com/PuerkitoBio/urlesc v0.0.0-20170810143723-de5bf2ad4578 h1:d+Bc7a5rLufV github.com/PuerkitoBio/urlesc v0.0.0-20170810143723-de5bf2ad4578/go.mod h1:uGdkoq3SwY9Y+13GIhn11/XLaGBb4BfwItxLd5jeuXE= github.com/andybalholm/brotli v1.0.5 h1:8uQZIdzKmjc/iuPu7O2ioW48L81FgatrcpfFmiq/cCs= github.com/andybalholm/brotli v1.0.5/go.mod h1:fO7iG3H7G2nSZ7m0zPUDn85XEX2GTukHGRSepvi9Eig= +github.com/awnumar/memcall v0.2.0 h1:sRaogqExTOOkkNwO9pzJsL8jrOV29UuUW7teRMfbqtI= +github.com/awnumar/memcall v0.2.0/go.mod h1:S911igBPR9CThzd/hYQQmTc9SWNu3ZHIlCGaWsWsoJo= +github.com/awnumar/memguard v0.22.4 h1:1PLgKcgGPeExPHL8dCOWGVjIbQUBgJv9OL0F/yE1PqQ= +github.com/awnumar/memguard v0.22.4/go.mod h1:+APmZGThMBWjnMlKiSM1X7MVpbIVewen2MTkqWkA/zE= github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= @@ -17,6 +21,8 @@ github.com/fsnotify/fsnotify v1.7.0 h1:8JEhPFa5W2WU7YfeZzPNqzMP6Lwt7L2715Ggo0nos github.com/fsnotify/fsnotify v1.7.0/go.mod h1:40Bi/Hjc2AVfZrqy+aj+yEI+/bRxZnMJyTJwOpGvigM= github.com/gabriel-vasile/mimetype v1.4.2 h1:w5qFW6JKBz9Y393Y4q372O9A7cUSequkh1Q7OhCmWKU= github.com/gabriel-vasile/mimetype v1.4.2/go.mod h1:zApsH/mKG4w07erKIaJPFiX0Tsq9BFQgN3qGY5GnNgA= +github.com/garrettladley/mattress v0.2.0 h1:+XUdsv9NO2s4JL+8exvAFziw0b1kv/0WlQo2Dlxat+w= +github.com/garrettladley/mattress v0.2.0/go.mod h1:OWKIRc9wC3gtD3Ng/nUuNEiR1TJvRYLmn/KZYw9nl5Q= github.com/go-openapi/jsonpointer v0.19.3/go.mod h1:Pl9vOtqEWErmShwVjC8pYs9cog34VGT37dQOVbmoatg= github.com/go-openapi/jsonpointer v0.19.5 h1:gZr+CIYByUqjcgeLXnQu2gHYQC9o73G2XUeOFYEICuY= github.com/go-openapi/jsonpointer v0.19.5/go.mod h1:Pl9vOtqEWErmShwVjC8pYs9cog34VGT37dQOVbmoatg= @@ -153,6 +159,7 @@ golang.org/x/net v0.0.0-20180911220305-26e67e76b6c3/go.mod h1:mL1N/T3taQHkDXs73r golang.org/x/net v0.0.0-20210421230115-4e50805a0758/go.mod h1:72T/g9IO56b78aLF+1Kcs5dz7/ng1VjMUvfKvpfy+jM= golang.org/x/net v0.20.0 h1:aCL9BSgETF1k+blQaYUBx9hJ9LOGP3gAVemcZlf1Kpo= golang.org/x/net v0.20.0/go.mod h1:z8BVo6PvndSri0LbOE3hAn0apkU+1YvI6E70E9jsnvY= +golang.org/x/sys v0.0.0-20191001151750-bb3f8db39f24/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210420072515-93ed5bcd2bfe/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= diff --git a/backend/src/auth/tokens.go b/backend/src/auth/tokens.go index c065f7340..ca902803f 100644 --- a/backend/src/auth/tokens.go +++ b/backend/src/auth/tokens.go @@ -7,17 +7,18 @@ import ( "github.com/GenerateNU/sac/backend/src/errors" "github.com/GenerateNU/sac/backend/src/types" + m "github.com/garrettladley/mattress" "github.com/gofiber/fiber/v2" "github.com/golang-jwt/jwt" ) func CreateTokenPair(id string, role string, authSettings config.AuthSettings) (*string, *string, *errors.Error) { - accessToken, catErr := CreateAccessToken(id, role, authSettings.AccessTokenExpiry, authSettings.AccessToken) + accessToken, catErr := CreateAccessToken(id, role, authSettings.AccessTokenExpiry, authSettings.AccessKey) if catErr != nil { return nil, nil, catErr } - refreshToken, crtErr := CreateRefreshToken(id, authSettings.RefreshTokenExpiry, authSettings.RefreshToken) + refreshToken, crtErr := CreateRefreshToken(id, authSettings.RefreshTokenExpiry, authSettings.RefreshKey) if crtErr != nil { return nil, nil, crtErr } @@ -26,7 +27,7 @@ func CreateTokenPair(id string, role string, authSettings config.AuthSettings) ( } // CreateAccessToken creates a new access token for the user -func CreateAccessToken(id string, role string, accessExpiresAfter uint, accessTokenSecret string) (*string, *errors.Error) { +func CreateAccessToken(id string, role string, accessExpiresAfter uint, accessToken *m.Secret[string]) (*string, *errors.Error) { if id == "" || role == "" { return nil, &errors.FailedToCreateAccessToken } @@ -40,16 +41,16 @@ func CreateAccessToken(id string, role string, accessExpiresAfter uint, accessTo Role: role, }) - accessToken, err := SignToken(accessTokenClaims, accessTokenSecret) + returnedAccessToken, err := SignToken(accessTokenClaims, accessToken) if err != nil { return nil, err } - return accessToken, nil + return returnedAccessToken, nil } // CreateRefreshToken creates a new refresh token for the user -func CreateRefreshToken(id string, refreshExpiresAfter uint, refreshTokenSecret string) (*string, *errors.Error) { +func CreateRefreshToken(id string, refreshExpiresAfter uint, refreshKey *m.Secret[string]) (*string, *errors.Error) { if id == "" { return nil, &errors.FailedToCreateRefreshToken } @@ -60,20 +61,20 @@ func CreateRefreshToken(id string, refreshExpiresAfter uint, refreshTokenSecret ExpiresAt: time.Now().Add(time.Hour * 24 * time.Duration(refreshExpiresAfter)).Unix(), }) - refreshToken, err := SignToken(refreshTokenClaims, refreshTokenSecret) + returnedRefreshToken, err := SignToken(refreshTokenClaims, refreshKey) if err != nil { return nil, err } - return refreshToken, nil + return returnedRefreshToken, nil } -func SignToken(token *jwt.Token, secret string) (*string, *errors.Error) { - if token == nil || secret == "" { +func SignToken(token *jwt.Token, key *m.Secret[string]) (*string, *errors.Error) { + if token == nil || key.Expose() == "" { return nil, &errors.FailedToSignToken } - tokenString, err := token.SignedString([]byte(secret)) + tokenString, err := token.SignedString([]byte(key.Expose())) if err != nil { return nil, &errors.FailedToSignToken } @@ -101,9 +102,9 @@ func ExpireCookie(name string) *fiber.Cookie { } // RefreshAccessToken refreshes the access token -func RefreshAccessToken(refreshCookie string, role string, refreshTokenSecret string, accessExpiresAfter uint, accessTokenSecret string) (*string, *errors.Error) { +func RefreshAccessToken(refreshCookie string, role string, refreshKey *m.Secret[string], accessExpiresAfter uint, accessKey *m.Secret[string]) (*string, *errors.Error) { // Parse the refresh token - refreshToken, err := ParseRefreshToken(refreshCookie, refreshTokenSecret) + refreshToken, err := ParseRefreshToken(refreshCookie, refreshKey) if err != nil { return nil, &errors.FailedToParseRefreshToken } @@ -115,7 +116,7 @@ func RefreshAccessToken(refreshCookie string, role string, refreshTokenSecret st } // Create a new access token - accessToken, catErr := CreateAccessToken(claims.Issuer, role, accessExpiresAfter, accessTokenSecret) + accessToken, catErr := CreateAccessToken(claims.Issuer, role, accessExpiresAfter, accessKey) if catErr != nil { return nil, &errors.FailedToCreateAccessToken } @@ -124,22 +125,22 @@ func RefreshAccessToken(refreshCookie string, role string, refreshTokenSecret st } // ParseAccessToken parses the access token -func ParseAccessToken(cookie string, accessTokenSecret string) (*jwt.Token, error) { +func ParseAccessToken(cookie string, accessKey *m.Secret[string]) (*jwt.Token, error) { return jwt.ParseWithClaims(cookie, &types.CustomClaims{}, func(token *jwt.Token) (interface{}, error) { - return []byte(accessTokenSecret), nil + return []byte(accessKey.Expose()), nil }) } // ParseRefreshToken parses the refresh token -func ParseRefreshToken(cookie string, refreshTokenSecret string) (*jwt.Token, error) { +func ParseRefreshToken(cookie string, refreshKey *m.Secret[string]) (*jwt.Token, error) { return jwt.ParseWithClaims(cookie, &jwt.StandardClaims{}, func(token *jwt.Token) (interface{}, error) { - return []byte(refreshTokenSecret), nil + return []byte(refreshKey.Expose()), nil }) } // GetRoleFromToken gets the role from the custom claims -func GetRoleFromToken(tokenString string, accessTokenSecret string) (*string, error) { - token, err := ParseAccessToken(tokenString, accessTokenSecret) +func GetRoleFromToken(tokenString string, accessKey *m.Secret[string]) (*string, error) { + token, err := ParseAccessToken(tokenString, accessKey) if err != nil { return nil, err } @@ -153,8 +154,8 @@ func GetRoleFromToken(tokenString string, accessTokenSecret string) (*string, er } // ExtractClaims extracts the claims from the token -func ExtractAccessClaims(tokenString string, accessTokenSecret string) (*types.CustomClaims, *errors.Error) { - token, err := ParseAccessToken(tokenString, accessTokenSecret) +func ExtractAccessClaims(tokenString string, accessKey *m.Secret[string]) (*types.CustomClaims, *errors.Error) { + token, err := ParseAccessToken(tokenString, accessKey) if err != nil { return nil, &errors.FailedToParseAccessToken } @@ -168,8 +169,8 @@ func ExtractAccessClaims(tokenString string, accessTokenSecret string) (*types.C } // ExtractClaims extracts the claims from the token -func ExtractRefreshClaims(tokenString string, refreshTokenSecret string) (*jwt.StandardClaims, *errors.Error) { - token, err := ParseRefreshToken(tokenString, refreshTokenSecret) +func ExtractRefreshClaims(tokenString string, refreshKey *m.Secret[string]) (*jwt.StandardClaims, *errors.Error) { + token, err := ParseRefreshToken(tokenString, refreshKey) if err != nil { return nil, &errors.FailedToParseRefreshToken } diff --git a/backend/src/config/application.go b/backend/src/config/application.go new file mode 100644 index 000000000..dd2591408 --- /dev/null +++ b/backend/src/config/application.go @@ -0,0 +1,7 @@ +package config + +type ApplicationSettings struct { + Port uint16 `yaml:"port"` + Host string `yaml:"host"` + BaseUrl string `yaml:"baseurl"` +} diff --git a/backend/src/config/auth.go b/backend/src/config/auth.go new file mode 100644 index 000000000..babc0cc2e --- /dev/null +++ b/backend/src/config/auth.go @@ -0,0 +1,40 @@ +package config + +import ( + "errors" + + m "github.com/garrettladley/mattress" +) + +type AuthSettings struct { + AccessKey *m.Secret[string] + RefreshKey *m.Secret[string] + AccessTokenExpiry uint + RefreshTokenExpiry uint +} + +type intermediateAuthSettings struct { + AccessKey string `yaml:"accesskey"` + RefreshKey string `yaml:"refreshkey"` + AccessTokenExpiry uint `yaml:"accesstokenexpiry"` + RefreshTokenExpiry uint `yaml:"refreshtokenexpiry"` +} + +func (int *intermediateAuthSettings) into() (*AuthSettings, error) { + accessToken, err := m.NewSecret(int.AccessKey) + if err != nil { + return nil, errors.New("failed to create secret from access key") + } + + refreshToken, err := m.NewSecret(int.RefreshKey) + if err != nil { + return nil, errors.New("failed to create secret from refresh key") + } + + return &AuthSettings{ + AccessKey: accessToken, + RefreshKey: refreshToken, + AccessTokenExpiry: int.AccessTokenExpiry, + RefreshTokenExpiry: int.RefreshTokenExpiry, + }, nil +} diff --git a/backend/src/config/config.go b/backend/src/config/config.go index fdb435248..81bc6e7a9 100644 --- a/backend/src/config/config.go +++ b/backend/src/config/config.go @@ -1,74 +1,47 @@ package config import ( - "fmt" "os" - "strconv" "github.com/spf13/viper" ) type Settings struct { - Application ApplicationSettings `yaml:"application"` - Database DatabaseSettings `yaml:"database"` - SuperUser SuperUserSettings `yaml:"superuser"` - Auth AuthSettings `yaml:"authsecret"` + Application ApplicationSettings + Database DatabaseSettings + SuperUser SuperUserSettings + Auth AuthSettings } -type ProductionSettings struct { - Database ProductionDatabaseSettings `yaml:"database"` - Application ProductionApplicationSettings `yaml:"application"` +type intermediateSettings struct { + Application ApplicationSettings `yaml:"application"` + Database intermediateDatabaseSettings `yaml:"database"` + SuperUser intermediateSuperUserSettings `yaml:"superuser"` + Auth intermediateAuthSettings `yaml:"authsecret"` } -type ApplicationSettings struct { - Port uint16 `yaml:"port"` - Host string `yaml:"host"` - BaseUrl string `yaml:"baseurl"` -} - -type ProductionApplicationSettings struct { - Port uint16 `yaml:"port"` - Host string `yaml:"host"` -} - -type DatabaseSettings struct { - Username string `yaml:"username"` - Password string `yaml:"password"` - Port uint `yaml:"port"` - Host string `yaml:"host"` - DatabaseName string `yaml:"databasename"` - RequireSSL bool `yaml:"requiressl"` -} - -type ProductionDatabaseSettings struct { - RequireSSL bool `yaml:"requiressl"` -} - -func (s *DatabaseSettings) WithoutDb() string { - var sslMode string - if s.RequireSSL { - sslMode = "require" - } else { - sslMode = "disable" +func (int *intermediateSettings) into() (*Settings, error) { + databaseSettings, err := int.Database.into() + if err != nil { + return nil, err } - return fmt.Sprintf("host=%s port=%d user=%s password=%s sslmode=%s", - s.Host, s.Port, s.Username, s.Password, sslMode) -} - -func (s *DatabaseSettings) WithDb() string { - return fmt.Sprintf("%s dbname=%s", s.WithoutDb(), s.DatabaseName) -} + superUserSettings, err := int.SuperUser.into() + if err != nil { + return nil, err + } -type SuperUserSettings struct { - Password string `yaml:"password"` -} + authSettings, err := int.Auth.into() + if err != nil { + return nil, err + } -type AuthSettings struct { - AccessToken string `yaml:"accesstoken"` - RefreshToken string `yaml:"refreshtoken"` - AccessTokenExpiry uint `yaml:"accesstokenexpiry"` - RefreshTokenExpiry uint `yaml:"refreshtokenexpiry"` + return &Settings{ + Application: int.Application, + Database: *databaseSettings, + SuperUser: *superUserSettings, + Auth: *authSettings, + }, nil } type Environment string @@ -78,7 +51,7 @@ const ( EnvironmentProduction Environment = "production" ) -func GetConfiguration(path string) (Settings, error) { +func GetConfiguration(path string) (*Settings, error) { var environment Environment if env := os.Getenv("APP_ENVIRONMENT"); env != "" { environment = Environment(env) @@ -91,80 +64,8 @@ func GetConfiguration(path string) (Settings, error) { v.AddConfigPath(path) if environment == EnvironmentLocal { - var settings Settings - - v.SetConfigName(string(environment)) - - if err := v.ReadInConfig(); err != nil { - return settings, fmt.Errorf("failed to read %s configuration: %w", string(environment), err) - } - - if err := v.Unmarshal(&settings); err != nil { - return settings, fmt.Errorf("failed to unmarshal configuration: %w", err) - } - - return settings, nil + return readLocal(v) } else { - var prodSettings ProductionSettings - - v.SetConfigName(string(environment)) - - if err := v.ReadInConfig(); err != nil { - return Settings{}, fmt.Errorf("failed to read %s configuration: %w", string(environment), err) - } - - if err := v.Unmarshal(&prodSettings); err != nil { - return Settings{}, fmt.Errorf("failed to unmarshal configuration: %w", err) - } - - appPrefix := "APP_" - applicationPrefix := fmt.Sprintf("%sAPPLICATION__", appPrefix) - dbPrefix := fmt.Sprintf("%sDATABASE__", appPrefix) - superUserPrefix := fmt.Sprintf("%sSUPERUSER__", appPrefix) - authSecretPrefix := fmt.Sprintf("%sAUTHSECRET__", appPrefix) - - authAccessExpiry := os.Getenv(fmt.Sprintf("%sACCESS_TOKEN_EXPIRY", authSecretPrefix)) - authRefreshExpiry := os.Getenv(fmt.Sprintf("%sREFRESH_TOKEN_EXPIRY", authSecretPrefix)) - - authAccessExpiryInt, err := strconv.ParseUint(authAccessExpiry, 10, 16) - if err != nil { - return Settings{}, fmt.Errorf("failed to parse access token expiry: %w", err) - } - - authRefreshExpiryInt, err := strconv.ParseUint(authRefreshExpiry, 10, 16) - if err != nil { - return Settings{}, fmt.Errorf("failed to parse refresh token expiry: %w", err) - } - - portStr := os.Getenv(fmt.Sprintf("%sPORT", appPrefix)) - portInt, err := strconv.ParseUint(portStr, 10, 16) - if err != nil { - return Settings{}, fmt.Errorf("failed to parse port: %w", err) - } - - return Settings{ - Application: ApplicationSettings{ - Port: uint16(portInt), - Host: prodSettings.Application.Host, - BaseUrl: os.Getenv(fmt.Sprintf("%sBASE_URL", applicationPrefix)), - }, - Database: DatabaseSettings{ - Username: os.Getenv(fmt.Sprintf("%sUSERNAME", dbPrefix)), - Password: os.Getenv(fmt.Sprintf("%sPASSWORD", dbPrefix)), - Host: os.Getenv(fmt.Sprintf("%sHOST", dbPrefix)), - Port: uint(portInt), - DatabaseName: os.Getenv(fmt.Sprintf("%sDATABASE_NAME", dbPrefix)), - RequireSSL: prodSettings.Database.RequireSSL, - }, - SuperUser: SuperUserSettings{ - Password: os.Getenv(fmt.Sprintf("%sPASSWORD", superUserPrefix)), - }, - Auth: AuthSettings{ - AccessToken: os.Getenv(fmt.Sprintf("%sACCESS_TOKEN", authSecretPrefix)), - RefreshToken: os.Getenv(fmt.Sprintf("%sREFRESH_TOKEN", authSecretPrefix)), - AccessTokenExpiry: uint(authAccessExpiryInt), - RefreshTokenExpiry: uint(authRefreshExpiryInt), - }, - }, nil + return readProd(v) } } diff --git a/backend/src/config/database.go b/backend/src/config/database.go new file mode 100644 index 000000000..422d17b69 --- /dev/null +++ b/backend/src/config/database.go @@ -0,0 +1,58 @@ +package config + +import ( + "errors" + "fmt" + + m "github.com/garrettladley/mattress" +) + +type DatabaseSettings struct { + Username string + Password *m.Secret[string] + Port uint + Host string + DatabaseName string + RequireSSL bool +} + +func (int *intermediateDatabaseSettings) into() (*DatabaseSettings, error) { + password, err := m.NewSecret(int.Password) + if err != nil { + return nil, errors.New("failed to create secret from password") + } + + return &DatabaseSettings{ + Username: int.Username, + Password: password, + Port: int.Port, + Host: int.Host, + DatabaseName: int.DatabaseName, + RequireSSL: int.RequireSSL, + }, nil +} + +type intermediateDatabaseSettings struct { + Username string `yaml:"username"` + Password string `yaml:"password"` + Port uint `yaml:"port"` + Host string `yaml:"host"` + DatabaseName string `yaml:"databasename"` + RequireSSL bool `yaml:"requiressl"` +} + +func (s *DatabaseSettings) WithoutDb() string { + var sslMode string + if s.RequireSSL { + sslMode = "require" + } else { + sslMode = "disable" + } + + return fmt.Sprintf("host=%s port=%d user=%s password=%s sslmode=%s", + s.Host, s.Port, s.Username, s.Password.Expose(), sslMode) +} + +func (s *DatabaseSettings) WithDb() string { + return fmt.Sprintf("%s dbname=%s", s.WithoutDb(), s.DatabaseName) +} diff --git a/backend/src/config/local.go b/backend/src/config/local.go new file mode 100644 index 000000000..bac72a155 --- /dev/null +++ b/backend/src/config/local.go @@ -0,0 +1,30 @@ +package config + +import ( + "fmt" + + "github.com/spf13/viper" +) + +func readLocal(v *viper.Viper) (*Settings, error) { + var intermediateSettings intermediateSettings + + env := string(EnvironmentLocal) + + v.SetConfigName(env) + + if err := v.ReadInConfig(); err != nil { + return nil, fmt.Errorf("failed to read %s configuration: %w", env, err) + } + + if err := v.Unmarshal(&intermediateSettings); err != nil { + return nil, fmt.Errorf("failed to unmarshal configuration: %w", err) + } + + settings, err := intermediateSettings.into() + if err != nil { + return nil, fmt.Errorf("failed to convert intermediate settings into final settings: %w", err) + } + + return settings, nil +} diff --git a/backend/src/config/production.go b/backend/src/config/production.go new file mode 100644 index 000000000..469a43574 --- /dev/null +++ b/backend/src/config/production.go @@ -0,0 +1,111 @@ +package config + +import ( + "errors" + "fmt" + "os" + "strconv" + + m "github.com/garrettladley/mattress" + "github.com/spf13/viper" +) + +type ProductionSettings struct { + Database ProductionDatabaseSettings `yaml:"database"` + Application ProductionApplicationSettings `yaml:"application"` +} + +type ProductionDatabaseSettings struct { + RequireSSL bool `yaml:"requiressl"` +} + +type ProductionApplicationSettings struct { + Port uint16 `yaml:"port"` + Host string `yaml:"host"` +} + +func readProd(v *viper.Viper) (*Settings, error) { + var prodSettings ProductionSettings + + env := string(EnvironmentProduction) + + v.SetConfigName(env) + + if err := v.ReadInConfig(); err != nil { + return nil, fmt.Errorf("failed to read %s configuration: %w", env, err) + } + + if err := v.Unmarshal(&prodSettings); err != nil { + return nil, fmt.Errorf("failed to unmarshal configuration: %w", err) + } + + appPrefix := "APP_" + applicationPrefix := fmt.Sprintf("%sAPPLICATION__", appPrefix) + dbPrefix := fmt.Sprintf("%sDATABASE__", appPrefix) + superUserPrefix := fmt.Sprintf("%sSUPERUSER__", appPrefix) + authSecretPrefix := fmt.Sprintf("%sAUTHSECRET__", appPrefix) + + authAccessExpiry := os.Getenv(fmt.Sprintf("%sACCESS_TOKEN_EXPIRY", authSecretPrefix)) + authRefreshExpiry := os.Getenv(fmt.Sprintf("%sREFRESH_TOKEN_EXPIRY", authSecretPrefix)) + + authAccessExpiryInt, err := strconv.ParseUint(authAccessExpiry, 10, 16) + if err != nil { + return nil, fmt.Errorf("failed to parse access token expiry: %w", err) + } + + authRefreshExpiryInt, err := strconv.ParseUint(authRefreshExpiry, 10, 16) + if err != nil { + return nil, fmt.Errorf("failed to parse refresh token expiry: %w", err) + } + + portStr := os.Getenv(fmt.Sprintf("%sPORT", appPrefix)) + portInt, err := strconv.ParseUint(portStr, 10, 16) + if err != nil { + return nil, fmt.Errorf("failed to parse port: %w", err) + } + + dbPassword, err := m.NewSecret(os.Getenv(fmt.Sprintf("%sUSERNAME", dbPrefix))) + if err != nil { + return nil, errors.New("failed to create secret from database password") + } + + superPassword, err := m.NewSecret(os.Getenv(fmt.Sprintf("%sPASSWORD", superUserPrefix))) + if err != nil { + return nil, errors.New("failed to create secret from super user password") + } + + authAccessKey, err := m.NewSecret(os.Getenv(fmt.Sprintf("%sACCESS_TOKEN", authSecretPrefix))) + if err != nil { + return nil, errors.New("failed to create secret from access token") + } + + authRefreshKey, err := m.NewSecret(os.Getenv(fmt.Sprintf("%sREFRESH_TOKEN", authSecretPrefix))) + if err != nil { + return nil, errors.New("failed to create secret from refresh token") + } + + return &Settings{ + Application: ApplicationSettings{ + Port: uint16(portInt), + Host: prodSettings.Application.Host, + BaseUrl: os.Getenv(fmt.Sprintf("%sBASE_URL", applicationPrefix)), + }, + Database: DatabaseSettings{ + Username: os.Getenv(fmt.Sprintf("%sUSERNAME", dbPrefix)), + Password: dbPassword, + Host: os.Getenv(fmt.Sprintf("%sHOST", dbPrefix)), + Port: uint(portInt), + DatabaseName: os.Getenv(fmt.Sprintf("%sDATABASE_NAME", dbPrefix)), + RequireSSL: prodSettings.Database.RequireSSL, + }, + SuperUser: SuperUserSettings{ + Password: superPassword, + }, + Auth: AuthSettings{ + AccessKey: authAccessKey, + RefreshKey: authRefreshKey, + AccessTokenExpiry: uint(authAccessExpiryInt), + RefreshTokenExpiry: uint(authRefreshExpiryInt), + }, + }, nil +} diff --git a/backend/src/config/super_user.go b/backend/src/config/super_user.go new file mode 100644 index 000000000..2a75c88ef --- /dev/null +++ b/backend/src/config/super_user.go @@ -0,0 +1,26 @@ +package config + +import ( + "errors" + + m "github.com/garrettladley/mattress" +) + +type SuperUserSettings struct { + Password *m.Secret[string] +} + +type intermediateSuperUserSettings struct { + Password string `yaml:"password"` +} + +func (int *intermediateSuperUserSettings) into() (*SuperUserSettings, error) { + password, err := m.NewSecret(int.Password) + if err != nil { + return nil, errors.New("failed to create secret from password") + } + + return &SuperUserSettings{ + Password: password, + }, nil +} diff --git a/backend/src/controllers/auth.go b/backend/src/controllers/auth.go index 546f407e8..8ab0b7bd5 100644 --- a/backend/src/controllers/auth.go +++ b/backend/src/controllers/auth.go @@ -99,7 +99,7 @@ func (a *AuthController) Refresh(c *fiber.Ctx) error { refreshTokenValue := c.Cookies("refresh_token") // Extract id from refresh token - claims, err := auth.ExtractRefreshClaims(refreshTokenValue, a.AuthSettings.RefreshToken) + claims, err := auth.ExtractRefreshClaims(refreshTokenValue, a.AuthSettings.RefreshKey) if err != nil { return err.FiberError(c) } @@ -109,7 +109,7 @@ func (a *AuthController) Refresh(c *fiber.Ctx) error { return err.FiberError(c) } - accessToken, err := auth.RefreshAccessToken(refreshTokenValue, string(*role), a.AuthSettings.RefreshToken, a.AuthSettings.AccessTokenExpiry, a.AuthSettings.AccessToken) + accessToken, err := auth.RefreshAccessToken(refreshTokenValue, string(*role), a.AuthSettings.RefreshKey, a.AuthSettings.AccessTokenExpiry, a.AuthSettings.AccessKey) if err != nil { return err.FiberError(c) } diff --git a/backend/src/database/super.go b/backend/src/database/super.go index 22700693e..93c065b66 100644 --- a/backend/src/database/super.go +++ b/backend/src/database/super.go @@ -11,7 +11,7 @@ import ( var SuperUserUUID uuid.UUID func SuperUser(superUserSettings config.SuperUserSettings) (*models.User, *errors.Error) { - passwordHash, err := auth.ComputePasswordHash(superUserSettings.Password) + passwordHash, err := auth.ComputePasswordHash(superUserSettings.Password.Expose()) if err != nil { return nil, &errors.FailedToComputePasswordHash } diff --git a/backend/src/main.go b/backend/src/main.go index 2c2f73672..ca83de46c 100644 --- a/backend/src/main.go +++ b/backend/src/main.go @@ -23,7 +23,7 @@ func main() { panic(fmt.Sprintf("Error getting configuration: %s", err.Error())) } - db, err := database.ConfigureDB(config) + db, err := database.ConfigureDB(*config) if err != nil { panic(fmt.Sprintf("Error configuring database: %s", err.Error())) } @@ -33,7 +33,7 @@ func main() { panic(err) } - app := server.Init(db, config) + app := server.Init(db, *config) err = app.Listen(fmt.Sprintf("%s:%d", config.Application.Host, config.Application.Port)) if err != nil { diff --git a/backend/src/middleware/auth.go b/backend/src/middleware/auth.go index 3bba7c497..dc49e8f66 100644 --- a/backend/src/middleware/auth.go +++ b/backend/src/middleware/auth.go @@ -38,7 +38,7 @@ func (m *MiddlewareService) Authenticate(c *fiber.Ctx) error { return c.Next() } - token, err := auth.ParseAccessToken(c.Cookies("access_token"), m.AuthSettings.AccessToken) + token, err := auth.ParseAccessToken(c.Cookies("access_token"), m.AuthSettings.AccessKey) if err != nil { return errors.FailedToParseAccessToken.FiberError(c) } @@ -68,7 +68,7 @@ func (m *MiddlewareService) Authorize(requiredPermissions ...types.Permission) f return c.Next() } - role, err := auth.GetRoleFromToken(c.Cookies("access_token"), m.AuthSettings.AccessToken) + role, err := auth.GetRoleFromToken(c.Cookies("access_token"), m.AuthSettings.AccessKey) if err != nil { return errors.FailedToParseAccessToken.FiberError(c) } diff --git a/backend/src/middleware/club.go b/backend/src/middleware/club.go index c64fa7068..dc848b75d 100644 --- a/backend/src/middleware/club.go +++ b/backend/src/middleware/club.go @@ -17,7 +17,7 @@ func (m *MiddlewareService) ClubAuthorizeById(c *fiber.Ctx) error { return errors.FailedToValidateID.FiberError(c) } - token, tokenErr := auth.ParseAccessToken(c.Cookies("access_token"), m.AuthSettings.AccessToken) + token, tokenErr := auth.ParseAccessToken(c.Cookies("access_token"), m.AuthSettings.AccessKey) if tokenErr != nil { return errors.FailedToParseAccessToken.FiberError(c) } diff --git a/backend/src/middleware/user.go b/backend/src/middleware/user.go index 308372a4f..804a1295b 100644 --- a/backend/src/middleware/user.go +++ b/backend/src/middleware/user.go @@ -14,7 +14,7 @@ func (m *MiddlewareService) UserAuthorizeById(c *fiber.Ctx) error { return errors.FailedToValidateID.FiberError(c) } - token, tokenErr := auth.ParseAccessToken(c.Cookies("access_token"), m.AuthSettings.AccessToken) + token, tokenErr := auth.ParseAccessToken(c.Cookies("access_token"), m.AuthSettings.AccessKey) if tokenErr != nil { return err } diff --git a/backend/tests/api/helpers/app.go b/backend/tests/api/helpers/app.go index cb1689e72..68b4acd55 100644 --- a/backend/tests/api/helpers/app.go +++ b/backend/tests/api/helpers/app.go @@ -33,16 +33,16 @@ func spawnApp() (*TestApp, error) { configuration.Database.DatabaseName = generateRandomDBName() - connectionWithDB, err := configureDatabase(configuration) + connectionWithDB, err := configureDatabase(*configuration) if err != nil { return nil, err } return &TestApp{ - App: server.Init(connectionWithDB, configuration), + App: server.Init(connectionWithDB, *configuration), Address: fmt.Sprintf("http://%s", listener.Addr().String()), Conn: connectionWithDB, - Settings: configuration, + Settings: *configuration, }, nil } diff --git a/backend/tests/api/helpers/auth.go b/backend/tests/api/helpers/auth.go index 76fb3033a..eb65640e3 100644 --- a/backend/tests/api/helpers/auth.go +++ b/backend/tests/api/helpers/auth.go @@ -40,7 +40,7 @@ func (app *TestApp) authSuper() { Path: "/api/v1/auth/login", Body: &map[string]interface{}{ "email": email, - "password": password, + "password": password.Expose(), }, }) if err != nil { @@ -65,7 +65,7 @@ func (app *TestApp) authSuper() { app.TestUser = &TestUser{ UUID: database.SuperUserUUID, Email: email, - Password: password, + Password: password.Expose(), AccessToken: accessToken, RefreshToken: refreshToken, } diff --git a/backend/tests/auth_test.go b/backend/tests/auth_test.go index d7d49bb40..3bfc1f530 100644 --- a/backend/tests/auth_test.go +++ b/backend/tests/auth_test.go @@ -6,17 +6,28 @@ import ( "github.com/GenerateNU/sac/backend/src/auth" "github.com/GenerateNU/sac/backend/src/config" + m "github.com/garrettladley/mattress" "github.com/golang-jwt/jwt" "github.com/huandu/go-assert" ) -func AuthSettings() config.AuthSettings { - return config.AuthSettings{ - AccessToken: "g(r|##*?>\\Qp}h37e+,T2", +func AuthSettings() (*config.AuthSettings, error) { + accessKey, err := m.NewSecret("g(r|##*?>\\Qp}h37e+,T2") + if err != nil { + return nil, err + } + + refreshKey, err := m.NewSecret("amk*2!gG}1i\"8D9RwJS$p") + if err != nil { + return nil, err + } + + return &config.AuthSettings{ + AccessKey: accessKey, AccessTokenExpiry: 60, - RefreshToken: "amk*2!gG}1i\"8D9RwJS$p", + RefreshKey: refreshKey, RefreshTokenExpiry: 30, - } + }, nil } func TestCreateTokenPairSuccess(t *testing.T) { @@ -25,9 +36,12 @@ func TestCreateTokenPairSuccess(t *testing.T) { id := "user123" role := "admin" - accessToken, refreshToken, err := auth.CreateTokenPair(id, role, AuthSettings()) + authSettings, err := AuthSettings() + assert.NilError(err) + + accessToken, refreshToken, authErr := auth.CreateTokenPair(id, role, *authSettings) - assert.Assert(err == nil) + assert.Assert(authErr == nil) assert.Assert(accessToken != nil) assert.Assert(refreshToken != nil) @@ -39,9 +53,13 @@ func TestCreateTokenPairFailure(t *testing.T) { id := "user123" role := "" - accessToken, refreshToken, err := auth.CreateTokenPair(id, role, AuthSettings()) + authSettings, err := AuthSettings() + + assert.NilError(err) + + accessToken, refreshToken, authErr := auth.CreateTokenPair(id, role, *authSettings) - assert.Assert(err != nil) + assert.Assert(authErr != nil) assert.Assert(accessToken == nil) assert.Assert(refreshToken == nil) @@ -53,11 +71,13 @@ func TestCreateAccessTokenSuccess(t *testing.T) { id := "user123" role := "admin" - authSettings := AuthSettings() + authSettings, err := AuthSettings() - accessToken, err := auth.CreateAccessToken(id, role, authSettings.AccessTokenExpiry, authSettings.AccessToken) + assert.NilError(err) - assert.Assert(err == nil) + accessToken, authErr := auth.CreateAccessToken(id, role, authSettings.AccessTokenExpiry, authSettings.AccessKey) + + assert.Assert(authErr == nil) assert.Assert(accessToken != nil) } @@ -68,11 +88,13 @@ func TestCreateAccessTokenFailure(t *testing.T) { id := "user123" role := "" - authSettings := AuthSettings() + authSettings, err := AuthSettings() + + assert.NilError(err) - accessToken, err := auth.CreateAccessToken(id, role, authSettings.AccessTokenExpiry, authSettings.AccessToken) + accessToken, authErr := auth.CreateAccessToken(id, role, authSettings.AccessTokenExpiry, authSettings.AccessKey) - assert.Assert(err != nil) + assert.Assert(authErr != nil) assert.Assert(accessToken == nil) } @@ -82,11 +104,13 @@ func TestCreateRefreshTokenSuccess(t *testing.T) { id := "user123" - authSettings := AuthSettings() + authSettings, err := AuthSettings() - refreshToken, err := auth.CreateRefreshToken(id, authSettings.RefreshTokenExpiry, authSettings.RefreshToken) + assert.NilError(err) - assert.Assert(err == nil) + refreshToken, authErr := auth.CreateRefreshToken(id, authSettings.RefreshTokenExpiry, authSettings.RefreshKey) + + assert.Assert(authErr == nil) assert.Assert(refreshToken != nil) } @@ -96,11 +120,13 @@ func TestCreateRefreshTokenFailure(t *testing.T) { id := "" - authSettings := AuthSettings() + authSettings, err := AuthSettings() + + assert.NilError(err) - refreshToken, err := auth.CreateRefreshToken(id, authSettings.RefreshTokenExpiry, authSettings.RefreshToken) + refreshToken, authErr := auth.CreateRefreshToken(id, authSettings.RefreshTokenExpiry, authSettings.RefreshKey) - assert.Assert(err != nil) + assert.Assert(authErr != nil) assert.Assert(refreshToken == nil) } @@ -119,9 +145,13 @@ func TestSignTokenSuccess(t *testing.T) { "iss": "sac", } - signedToken, err := auth.SignToken(token, "secret") + key, err := m.NewSecret("secret") - assert.Assert(err == nil) + assert.NilError(err) + + signedToken, authErr := auth.SignToken(token, key) + + assert.Assert(authErr == nil) assert.Assert(signedToken != nil) } @@ -140,9 +170,13 @@ func TestSignTokenFailure(t *testing.T) { "iss": "sac", } - signedToken, err := auth.SignToken(token, "") + key, err := m.NewSecret("") + + assert.NilError(err) + + signedToken, authErr := auth.SignToken(token, key) - assert.Assert(err != nil) + assert.Assert(authErr != nil) assert.Assert(signedToken == nil) } diff --git a/config/local.yml b/config/local.yml index 57da31f0f..2ad38c2dd 100644 --- a/config/local.yml +++ b/config/local.yml @@ -12,7 +12,7 @@ database: superuser: password: password auth: - accesstoken: g(r|##*?>\Qp}h37e+,T2 + accesskey: g(r|##*?>\Qp}h37e+,T2 accesstokenexpiry: 60 # in minutes - refreshtoken: amk*2!gG}1i"8D9RwJS$p + refreshkey: amk*2!gG}1i"8D9RwJS$p refreshtokenexpiry: 30 # in days diff --git a/go.work b/go.work index 5a962e04b..9bb923f4e 100644 --- a/go.work +++ b/go.work @@ -1,4 +1,4 @@ -go 1.21.1 +go 1.21.6 use ( ./backend diff --git a/go.work.sum b/go.work.sum index a9dfabbc1..52c5208f1 100644 --- a/go.work.sum +++ b/go.work.sum @@ -1,6 +1,8 @@ github.com/coreos/go-semver v0.3.0/go.mod h1:nnelYz7RCh+5ahJtPPxZlU+153eP4D4r3EedlOD2RNk= github.com/dgrijalva/jwt-go v3.2.0+incompatible h1:7qlOGliEKZXTDg6OTjfoBKDXWrumCAMpl/TFQ4/5kLM= github.com/dgrijalva/jwt-go v3.2.0+incompatible/go.mod h1:E3ru+11k8xSBh+hMPgOLZmtrrCbhqsmaPHjLKYnJCaQ= +github.com/garrettladley/mattress v0.2.0 h1:+XUdsv9NO2s4JL+8exvAFziw0b1kv/0WlQo2Dlxat+w= +github.com/garrettladley/mattress v0.2.0/go.mod h1:OWKIRc9wC3gtD3Ng/nUuNEiR1TJvRYLmn/KZYw9nl5Q= github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q= github.com/golang/protobuf v1.5.3/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY= github.com/hashicorp/golang-lru v0.5.4/go.mod h1:iADmTwqILo4mZ8BN3D2Q6+9jd8WM5uGBxy+E8yxSoD4= @@ -11,6 +13,7 @@ github.com/mitchellh/go-homedir v1.1.0/go.mod h1:SfyaCUpYCn1Vlf4IUYiD9fPX4A5wJrk github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk= github.com/philhofer/fwd v1.1.2/go.mod h1:qkPdfjR2SIEbspLqpe1tO4n5yICnr2DY7mqEx2tUTP0= +github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/shurcooL/sanitized_anchor_name v1.0.0/go.mod h1:1NzhyTcUVG4SuEtjjoZeVRXNmyL/1OwPU0+IJeTBvfc= github.com/tinylib/msgp v1.1.8/go.mod h1:qkpG+2ldGg4xRFmx+jfTvZPxfGFhi64BcnL9vkCm/Tw=