diff --git a/.devcontainer/devcontainer.json b/.devcontainer/devcontainer.json index 459f05c..c00934a 100644 --- a/.devcontainer/devcontainer.json +++ b/.devcontainer/devcontainer.json @@ -15,9 +15,7 @@ "postCreateCommand": "bash scripts/postCreateCommand.sh", "features": { "ghcr.io/devcontainers/features/docker-in-docker:2": {}, - "ghcr.io/devcontainers/features/go:1": { - "version": "1.22" - } + "ghcr.io/devcontainers/features/go:1": {} }, "forwardPorts": [ 3000 @@ -38,4 +36,4 @@ "onAutoForward": "openPreview" } } -} \ No newline at end of file +} diff --git a/.golangci.yml b/.golangci.yml index ba73a58..8556941 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -1,15 +1,10 @@ run: - deadline: 6m - - skip-files: - - "zz_generated\\..+\\.go$" - - skip-dirs: - - vendor$ + timeout: 6m output: # colored-line-number|line-number|json|tab|checkstyle|code-climate, default is "colored-line-number" - format: colored-line-number + formats: + - format: colored-line-number linters-settings: errcheck: @@ -21,19 +16,15 @@ linters-settings: # default is false: such cases aren't reported by default. check-blank: false - # [deprecated] comma-separated list of pairs of the form pkg:regex - # the regex is used to ignore names within pkg. (default "fmt:.*"). - # see https://github.com/kisielk/errcheck#the-deprecated-method for details - ignore: fmt:.*,io/ioutil:^Read.* + # report about not checking of errors in assignments: `num, err := strconv.Atoi(numStr)`; + exclude-functions: + - fmt:.* + - io/ioutil:^Read.* govet: # report about shadowed variables check-shadowing: false - golint: - # minimal confidence for issues, default is 0.8 - min-confidence: 0.8 - gofmt: # simplify code: gofmt with `-s` option, true by default simplify: true @@ -47,10 +38,6 @@ linters-settings: # minimal code complexity to report, 30 by default (but we recommend 10-20) min-complexity: 10 - maligned: - # print struct with more effective memory layout or not, false by default - suggest-new: true - dupl: # tokens count to trigger issue, 150 by default threshold: 100 @@ -109,26 +96,26 @@ linters-settings: severity: warning confidence: 0.8 - linters: enable: - - megacheck - govet - gocyclo - gocritic + - gosimple + - staticcheck + - unused - goconst - goimports - - gofmt # We enable this as well as goimports for its simplify mode. + - gofmt # We enable this as well as goimports for its simplify mode. - prealloc - revive - unconvert - misspell - nakedret - - exportloopref + - copyloopvar - gosec disable: - - scopelint - errcheck presets: @@ -136,8 +123,12 @@ linters: - unused fast: false - issues: + exclude-files: + - "zz_generated\\..+\\.go$" + exclude-dirs: + - vendor$ + exclude: - "G103: Use of unsafe calls should be audited" @@ -164,31 +155,36 @@ issues: # rather than using a pointer. - text: "(hugeParam|rangeValCopy):" linters: - - gocritic + - gocritic # This "TestMain should call os.Exit to set exit code" warning is not clever # enough to notice that we call a helper method that calls os.Exit. - text: "SA3000:" linters: - - staticcheck + - staticcheck - text: "k8s.io/api/core/v1" linters: - - goimports + - goimports # This is a "potential hardcoded credentials" warning. It's triggered by # any variable with 'secret' in the same, and thus hits a lot of false # positives in Kubernetes land where a Secret is an object type. - text: "G101:" linters: - - gosec - - gas + - gosec + - gas # This is an 'errors unhandled' warning that duplicates errcheck. - text: "G104:" linters: - - gosec - - gas + - gosec + - gas + + - text: "G115:" + linters: + - gosec + - gas # Independently from option `exclude` we use default exclude patterns, # it can be disabled by this option. To list all @@ -208,4 +204,4 @@ issues: max-per-linter: 0 # Maximum count of issues with the same text. Set to 0 to disable. Default is 3. - max-same-issues: 0 \ No newline at end of file + max-same-issues: 0 diff --git a/adapters/adapter.go b/adapters/adapter.go index 3f35f8d..ceb1fef 100644 --- a/adapters/adapter.go +++ b/adapters/adapter.go @@ -14,7 +14,6 @@ func init() { gob.Register(&GothAccount{}) gob.Register(&GothUser{}) gob.Register(&GothSession{}) - gob.Register(&GothTeam{}) gob.Register(&GothVerificationToken{}) gob.Register(&GothCsrfToken{}) } @@ -103,8 +102,6 @@ type GothUser struct { Accounts []GothAccount `json:"accounts" gorm:"foreignKey:UserID;constraint:OnDelete:CASCADE"` // Sessions are the sessions of the user. Sessions []GothSession `json:"sessions" gorm:"foreignKey:UserID;constraint:OnDelete:CASCADE"` - // Teams are the teams the user is a member of. - Teams *[]GothTeam `json:"teams" gorm:"many2many:goth_team_users"` // CreatedAt is the creation time of the user. CreatedAt time.Time `json:"created_at"` // UpdatedAt is the update time of the user. @@ -113,17 +110,6 @@ type GothUser struct { DeletedAt gorm.DeletedAt `json:"deleted_at"` } -// TeamBySlug is returning the team with the given ID. -func (u GothUser) TeamBySlug(slug string) GothTeam { - for _, team := range *u.Teams { - if team.Slug == slug { - return team - } - } - - return GothTeam{} -} - // GothSession is a session for a user. type GothSession struct { // ID is the unique identifier of the session. @@ -190,48 +176,6 @@ type GothVerificationToken struct { DeletedAt gorm.DeletedAt `json:"deleted_at"` } -// GothTeam is a team in the application. -type GothTeam struct { - // ID is the unique identifier of the team. - ID uuid.UUID `json:"id" gorm:"primaryKey;unique;type:uuid;column:id;default:gen_random_uuid()"` - // Name is the name of the team. - Name string `json:"name" validate:"required,max=255"` - // Slug is the slug of the team. - Slug string `json:"slug" validate:"required,min=3,max=255"` - // Description is the description of the team. - Description string `json:"description" validate:"max=255"` - // Users are the users in the team. - Users []GothUser `json:"users" gorm:"many2many:goth_team_users"` - // Roles are the roles in the team. - Roles []GothRole `json:"roles" gorm:"foreignKey:TeamID;constraint:OnDelete:CASCADE"` - // CreatedAt is the creation time of the team. - CreatedAt time.Time `json:"created_at"` - // UpdatedAt is the update time of the team. - UpdatedAt time.Time `json:"updated_at"` - // DeletedAt is the deletion time of the team. - DeletedAt gorm.DeletedAt `json:"deleted_at"` -} - -// GothRole is a role in the application. -type GothRole struct { - // ID is the unique identifier of the role. - ID uuid.UUID `json:"id" gorm:"primaryKey;unique;type:uuid;column:id;default:gen_random_uuid()"` - // Name is the name of the role. - Name string `json:"name" validate:"required,min=3,max=255"` - // Description is the description of the role. - Description string `json:"description" validate:"max=255"` - // TeamID is the team ID of the role. - TeamID uuid.UUID `json:"team_id"` - // Team is the team of the role. - Team GothTeam `json:"team"` - // CreatedAt is the creation time of the role. - CreatedAt time.Time `json:"created_at"` - // UpdatedAt is the update time of the role. - UpdatedAt time.Time `json:"updated_at"` - // DeletedAt is the deletion time of the role. - DeletedAt gorm.DeletedAt `json:"deleted_at"` -} - // Adapter is an interface that defines the methods for interacting with the underlying data storage. type Adapter interface { // CreateUser creates a new user. @@ -262,6 +206,12 @@ type Adapter interface { CreateVerificationToken(ctx context.Context, verficationToken GothVerificationToken) (GothVerificationToken, error) // UseVerficationToken uses a verification token. UseVerficationToken(ctx context.Context, identifier string, token string) (GothVerificationToken, error) + // CreateCsrfToken creates a new CSRF token. + CreateCsrfToken(ctx context.Context, csrfToken GothCsrfToken) (GothCsrfToken, error) + // GetCsrfToken retrieves a CSRF token by token. + GetCsrfToken(ctx context.Context, token string) (GothCsrfToken, error) + // DeleteCsrfToken deletes a CSRF token by token. + DeleteCsrfToken(ctx context.Context, token string) error } var _ Adapter = (*UnimplementedAdapter)(nil) @@ -343,3 +293,18 @@ func (a *UnimplementedAdapter) CreateVerificationToken(_ context.Context, erfica func (a *UnimplementedAdapter) UseVerficationToken(_ context.Context, identifier string, token string) (GothVerificationToken, error) { return GothVerificationToken{}, ErrUnimplemented } + +// CreateCsrfToken creates a new CSRF token. +func (a *UnimplementedAdapter) CreateCsrfToken(_ context.Context, csrfToken GothCsrfToken) (GothCsrfToken, error) { + return GothCsrfToken{}, ErrUnimplemented +} + +// GetCsrfToken retrieves a CSRF token by token. +func (a *UnimplementedAdapter) GetCsrfToken(_ context.Context, token string) (GothCsrfToken, error) { + return GothCsrfToken{}, ErrUnimplemented +} + +// DeleteCsrfToken deletes a CSRF token by token. +func (a *UnimplementedAdapter) DeleteCsrfToken(_ context.Context, token string) error { + return ErrUnimplemented +} diff --git a/adapters/gorm/gorm.go b/adapters/gorm/gorm.go index 2dcfd62..328868a 100644 --- a/adapters/gorm/gorm.go +++ b/adapters/gorm/gorm.go @@ -19,8 +19,6 @@ func RunMigrations(db *gorm.DB) error { &adapters.GothUser{}, &adapters.GothSession{}, &adapters.GothVerificationToken{}, - &adapters.GothTeam{}, - &adapters.GothRole{}, ) } @@ -28,13 +26,12 @@ var _ adapters.Adapter = (*gormAdapter)(nil) type gormAdapter struct { db *gorm.DB - adapters.UnimplementedAdapter } -// New ... +// New is a helper function to create a new adapter. func New(db *gorm.DB) *gormAdapter { - return &gormAdapter{db, adapters.UnimplementedAdapter{}} + return &gormAdapter{db: db} } // CreateUser is a helper function to create a new user. @@ -50,11 +47,7 @@ func (a *gormAdapter) CreateUser(ctx context.Context, user adapters.GothUser) (a // GetSession is a helper function to retrieve a session by session token. func (a *gormAdapter) GetSession(ctx context.Context, sessionToken string) (adapters.GothSession, error) { var session adapters.GothSession - err := a.db.WithContext(ctx). - Preload(clause.Associations). - Preload("User.Teams"). - Preload("User.Teams.Roles"). - Where("session_token = ?", sessionToken).First(&session).Error + err := a.db.WithContext(ctx).Preload(clause.Associations).Where("session_token = ?", sessionToken).First(&session).Error if err != nil { return adapters.GothSession{}, goth.ErrMissingSession } @@ -65,7 +58,7 @@ func (a *gormAdapter) GetSession(ctx context.Context, sessionToken string) (adap // GetUser is a helper function to retrieve a user by ID. func (a *gormAdapter) GetUser(ctx context.Context, id uuid.UUID) (adapters.GothUser, error) { var user adapters.GothUser - err := a.db.WithContext(ctx).Preload("Accounts").Where("id = ?", id).First(&user).Error + err := a.db.WithContext(ctx).Preload(clause.Associations).Where("id = ?", id).First(&user).Error if err != nil { return adapters.GothUser{}, goth.ErrMissingUser } @@ -80,10 +73,11 @@ func (a *gormAdapter) CreateSession(ctx context.Context, userID uuid.UUID, expir SessionToken: uuid.NewString(), ExpiresAt: expires, CsrfToken: adapters.GothCsrfToken{ - Token: uuid.NewString(), // creates a token that is used to prevent CSRF attacks - ExpiresAt: time.Now().Add(24 * time.Hour), + Token: uuid.NewString(), // creates a token that is used to prevent CSRF attacks + ExpiresAt: time.Now().Add(24 * time.Hour), // expires in 24 hours }, } + err := a.db.Session(&gorm.Session{FullSaveAssociations: true}).WithContext(ctx).Create(&session).Error if err != nil { return adapters.GothSession{}, goth.ErrBadSession diff --git a/csrf/csrf.go b/csrf/csrf.go new file mode 100644 index 0000000..7979a67 --- /dev/null +++ b/csrf/csrf.go @@ -0,0 +1,264 @@ +package csrf + +import ( + "time" + + "github.com/google/uuid" + "github.com/valyala/fasthttp" + "github.com/zeiss/fiber-goth/adapters" + "github.com/zeiss/pkg/utilx" + + "github.com/gofiber/fiber/v2" +) + +var ( + // ErrMissingHeader is returned when the token is missing from the request. + ErrMissingHeader = fiber.NewError(fiber.StatusForbidden, "missing csrf token in header") + // ErrTokenNotFound is returned when the token is not found in the session. + ErrTokenNotFound = fiber.NewError(fiber.StatusForbidden, "csrf token not found in session") +) + +// HeaderName is the default header name used to extract the token. +const HeaderName = "X-Csrf-Token" + +// The contextKey type is unexported to prevent collisions with context keys defined in +// other packages. +type contextKey int + +const ( + csrfTokenKey contextKey = iota +) + +// Config defines the config for csrf middleware. +type Config struct { + // Next defines a function to skip this middleware when returned true. + Next func(c *fiber.Ctx) bool + + // Adapter is the adapter used to store the session. + // Adapter adapters.Adapter + Adapter adapters.Adapter + + // ErrorHandler is executed when an error is returned from fiber.Handler. + // + // Optional. Default: DefaultErrorHandler + ErrorHandler fiber.ErrorHandler + + // Extractor is the function used to extract the token from the request. + Extractor func(c *fiber.Ctx) (string, error) + + // Indicates if CSRF cookie is secure. + // Optional. Default value false. + CookieSecure bool + + // Decides whether cookie should last for only the browser sesison. + // Ignores Expiration if set to true + CookieSessionOnly bool + + // SingleUseToken indicates if the CSRF token be destroyed + // and a new one generated on each use. + // + // Optional. Default: false + SingleUseToken bool + + // CookieName is the name of the cookie used to store the session. + CookieName string + + // CookieSameSite is the SameSite attribute of the cookie. + CookieSameSite fasthttp.CookieSameSite + + // CookiePath is the path of the cookie. + CookiePath string + + // CookieDomain is the domain of the cookie. + CookieDomain string + + // CookieHTTPOnly is the HTTPOnly attribute of the cookie. + CookieHTTPOnly bool + + // TrustedOrigins is a list of origins that are allowed to set the cookie. + TrustedOrigins []string + + // IdleTimeout is the duration of time before the session expires. + IdleTimeout time.Duration + + // TokenGenerator is a function that generates a CSRF token. + TokenGenerator CsrfTokenGenerator +} + +// ConfigDefault is the default config. +var ConfigDefault = Config{ + IdleTimeout: 30 * time.Minute, + CookieName: "csrf_", + CookieSameSite: fasthttp.CookieSameSiteLaxMode, + ErrorHandler: defaultErrorHandler, + Extractor: FromHeader(HeaderName), + TokenGenerator: DefaultCsrfTokenGenerator, +} + +// CsrfTokenGenerator is a function that generates a CSRF token. +type CsrfTokenGenerator func() (string, error) + +// DefaultCsrfTokenGenerator generates a new CSRF token. +func DefaultCsrfTokenGenerator() (string, error) { + token, err := uuid.NewV7() + if err != nil { + return "", err + } + + return token.String(), nil +} + +// default ErrorHandler that process return error from fiber.Handler +func defaultErrorHandler(_ *fiber.Ctx, _ error) error { + return fiber.ErrForbidden +} + +// Helper function to set default values +// nolint:gocyclo +func configDefault(config ...Config) Config { + if len(config) < 1 { + return ConfigDefault + } + + // Override default config + cfg := config[0] + + if cfg.IdleTimeout <= 0 { + cfg.IdleTimeout = ConfigDefault.IdleTimeout + } + + if cfg.CookieName == "" { + cfg.CookieName = ConfigDefault.CookieName + } + + if cfg.CookieSameSite == 0 { + cfg.CookieSameSite = ConfigDefault.CookieSameSite + } + + if cfg.ErrorHandler == nil { + cfg.ErrorHandler = ConfigDefault.ErrorHandler + } + + if cfg.Extractor == nil { + cfg.Extractor = ConfigDefault.Extractor + } + + if cfg.TokenGenerator == nil { + cfg.TokenGenerator = ConfigDefault.TokenGenerator + } + + return cfg +} + +// Handler ... +type Handler struct { + config Config +} + +// New creates a new csrf middleware. +func New(config ...Config) fiber.Handler { + // Set default config + cfg := configDefault(config...) + + // handler := &Handler{ + // config: cfg, + // } + + var token string + + // Return new handler + return func(c *fiber.Ctx) error { + // Skip middleware if Next returns true + if cfg.Next != nil && cfg.Next(c) { + return c.Next() + } + + switch c.Method() { + case fiber.MethodGet, fiber.MethodHead, fiber.MethodOptions, fiber.MethodTrace: + // cookieToken := c.Cookies(cfg.CookieName) + default: + extractedToken, err := cfg.Extractor(c) + if err != nil { + return cfg.ErrorHandler(c, err) + } + + if utilx.Empty(extractedToken) { + return cfg.ErrorHandler(c, ErrTokenNotFound) + } + + raw := "" + + if utilx.Empty(raw) { + // expire the token + cookieValue := fasthttp.Cookie{} + cookieValue.SetKey(cfg.CookieName) + cookieValue.SetValueBytes([]byte("")) + cookieValue.SetHTTPOnly(cfg.CookieHTTPOnly) + cookieValue.SetSameSite(cfg.CookieSameSite) + cookieValue.SetExpire(time.Now().Add(-time.Hour)) + cookieValue.SetPath(cfg.CookiePath) + cookieValue.SetDomain(cfg.CookieDomain) + cookieValue.SetSecure(cfg.CookieSecure) + + // Set the cookie + c.Response().Header.SetCookie(&cookieValue) + + return cfg.ErrorHandler(c, ErrTokenNotFound) + } + } + + // Generate a new token + if utilx.Empty(token) { + // csrfToken, err := cfg.TokenGenerator() + // if err != nil { + // return cfg.ErrorHandler(c, err) + // } + } + + // Create the cookie + cookieValue := fasthttp.Cookie{} + cookieValue.SetKey(cfg.CookieName) + cookieValue.SetValueBytes([]byte(token)) + cookieValue.SetHTTPOnly(cfg.CookieHTTPOnly) + cookieValue.SetSameSite(cfg.CookieSameSite) + cookieValue.SetExpire(time.Now().Add(cfg.IdleTimeout)) + cookieValue.SetPath(cfg.CookiePath) + cookieValue.SetDomain(cfg.CookieDomain) + cookieValue.SetSecure(cfg.CookieSecure) + + // Set the cookie + c.Response().Header.SetCookie(&cookieValue) + + // Add the token to the context + c.Vary(fiber.HeaderCookie) + + // Add the token to the context + c.Locals(csrfTokenKey, token) + + // Continue stack + return c.Next() + } +} + +// CsrfTokenFromContext returns the csrf token from the context. +func CsrfTokenFromContext(c *fiber.Ctx) string { + token, ok := c.Locals(csrfTokenKey).(string) + if !ok { + return "" + } + + return token +} + +// FromHeader returns a function that extracts token from the request header. +func FromHeader(param string) func(c *fiber.Ctx) (string, error) { + return func(c *fiber.Ctx) (string, error) { + token := c.Get(param) + + if utilx.Empty(token) { + return "", ErrMissingHeader + } + + return token, nil + } +} diff --git a/examples/main.go b/examples/main.go index fc4347d..aeeb458 100644 --- a/examples/main.go +++ b/examples/main.go @@ -163,7 +163,19 @@ var helloTemplate = `
Hello World
` var indexTemplate = `{{range $key,$value:=.Providers}}

Log in with {{index $.ProvidersMap $value}}

-{{end}}` +{{end}} +
+
+ + + + + + + +
+
+` var userTemplate = `

logout

diff --git a/go.mod b/go.mod index e3885ab..6bc5ab9 100644 --- a/go.mod +++ b/go.mod @@ -1,6 +1,8 @@ module github.com/zeiss/fiber-goth -go 1.23 +go 1.22.1 + +toolchain go1.23.0 require ( github.com/gofiber/fiber/v2 v2.52.5 @@ -9,7 +11,7 @@ require ( github.com/katallaxie/pkg v0.6.6 github.com/spf13/cobra v1.8.1 github.com/valyala/fasthttp v1.56.0 - github.com/zeiss/pkg v0.1.17 + github.com/zeiss/pkg v0.1.13-0.20241019201052-9f5bf9d1a0df golang.org/x/crypto v0.28.0 golang.org/x/oauth2 v0.23.0 gorm.io/driver/postgres v1.5.9 diff --git a/go.sum b/go.sum index ef668d1..556904f 100644 --- a/go.sum +++ b/go.sum @@ -61,8 +61,8 @@ github.com/valyala/fasthttp v1.56.0 h1:bEZdJev/6LCBlpdORfrLu/WOZXXxvrUQSiyniuaoW github.com/valyala/fasthttp v1.56.0/go.mod h1:sReBt3XZVnudxuLOx4J/fMrJVorWRiWY2koQKgABiVI= github.com/valyala/tcplisten v1.0.0 h1:rBHj/Xf+E1tRGZyWIWwJDiRY0zc1Js+CV5DqwacVSA8= github.com/valyala/tcplisten v1.0.0/go.mod h1:T0xQ8SeCZGxckz9qRXTfG43PvQ/mcWh7FwZEA7Ioqkc= -github.com/zeiss/pkg v0.1.17 h1:rDvBtaRUSD1ypeu66R3UHMtEphPSBaZ52484BQtPEVI= -github.com/zeiss/pkg v0.1.17/go.mod h1:2k/MCcM0p8KiHJMdUG3Rnx90pE7UfzaGd0GIXm6V7/8= +github.com/zeiss/pkg v0.1.13-0.20241019201052-9f5bf9d1a0df h1:RpHj41NcJ5d4AO8mBz5qYduPk/K6729ioHqlVSn7fxE= +github.com/zeiss/pkg v0.1.13-0.20241019201052-9f5bf9d1a0df/go.mod h1:RAQyzmnyfiXtnHJGb1o8E/Bf1MCiA0FYPBC1RuFpINk= go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto= go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE= go.uber.org/multierr v1.11.0 h1:blXXJkSxSSfBVBlC76pxqeO+LN3aDfLQo+309xJstO0= diff --git a/providers/credentials/credentials.go b/providers/credentials/credentials.go index 29645ca..8e40cac 100644 --- a/providers/credentials/credentials.go +++ b/providers/credentials/credentials.go @@ -91,7 +91,7 @@ func HashPassword(password string) (string, error) { } // BeginAuth starts the authentication process. -func (e *credentialsProvider) BeginAuth(ctx context.Context, adapter adapters.Adapter, state string, _ providers.AuthParams) (providers.AuthIntent, error) { +func (e *credentialsProvider) BeginAuth(ctx context.Context, adapter adapters.Adapter, state string, params providers.AuthParams) (providers.AuthIntent, error) { return &authIntent{ authURL: "", }, nil