From 01310f9217c858e42d628e46e4e110cfe740ec58 Mon Sep 17 00:00:00 2001 From: yoyofx Date: Wed, 19 Aug 2020 17:43:48 +0800 Subject: [PATCH 1/2] jwt --- Test/jwt_test.go | 25 ++++++ Utils/jwt/claims.go | 134 ++++++++++++++++++++++++++++++++ Utils/jwt/ecdsa.go | 148 ++++++++++++++++++++++++++++++++++++ Utils/jwt/errors.go | 59 ++++++++++++++ Utils/jwt/hmac.go | 95 +++++++++++++++++++++++ Utils/jwt/jwt.go | 51 +++++++++++++ Utils/jwt/map_claims.go | 94 +++++++++++++++++++++++ Utils/jwt/parser.go | 148 ++++++++++++++++++++++++++++++++++++ Utils/jwt/rsa.go | 101 ++++++++++++++++++++++++ Utils/jwt/signing_method.go | 35 +++++++++ Utils/jwt/token.go | 108 ++++++++++++++++++++++++++ go.sum | 2 - 12 files changed, 998 insertions(+), 2 deletions(-) create mode 100644 Test/jwt_test.go create mode 100644 Utils/jwt/claims.go create mode 100644 Utils/jwt/ecdsa.go create mode 100644 Utils/jwt/errors.go create mode 100644 Utils/jwt/hmac.go create mode 100644 Utils/jwt/jwt.go create mode 100644 Utils/jwt/map_claims.go create mode 100644 Utils/jwt/parser.go create mode 100644 Utils/jwt/rsa.go create mode 100644 Utils/jwt/signing_method.go create mode 100644 Utils/jwt/token.go diff --git a/Test/jwt_test.go b/Test/jwt_test.go new file mode 100644 index 00000000..be776d37 --- /dev/null +++ b/Test/jwt_test.go @@ -0,0 +1,25 @@ +package Test + +import ( + "fmt" + "github.com/stretchr/testify/assert" + "github.com/yoyofx/yoyogo/Utils/jwt" + "testing" +) + +func TestCreateToken(t *testing.T) { + SecretKey := []byte("AllYourBase") + token, _ := jwt.CreateToken(SecretKey, "YDQ", 2222) + fmt.Println(token) + + claims, err := jwt.ParseToken(token, SecretKey) + if nil != err { + fmt.Println(" err :", err) + } + fmt.Println("claims:", claims) + fmt.Println("claims uid:", claims.(jwt.MapClaims)["uid"]) + + assert.Equal(t, err, nil) + assert.Equal(t, int(claims.(jwt.MapClaims)["uid"].(float64)), 2222) + assert.Equal(t, claims.(jwt.MapClaims)["iss"], "YDQ") +} diff --git a/Utils/jwt/claims.go b/Utils/jwt/claims.go new file mode 100644 index 00000000..f0228f02 --- /dev/null +++ b/Utils/jwt/claims.go @@ -0,0 +1,134 @@ +package jwt + +import ( + "crypto/subtle" + "fmt" + "time" +) + +// For a type to be a Claims object, it must just have a Valid method that determines +// if the token is invalid for any supported reason +type Claims interface { + Valid() error +} + +// Structured version of Claims Section, as referenced at +// https://tools.ietf.org/html/rfc7519#section-4.1 +// See examples for how to use this with your own claim types +type StandardClaims struct { + Audience string `json:"aud,omitempty"` + ExpiresAt int64 `json:"exp,omitempty"` + Id string `json:"jti,omitempty"` + IssuedAt int64 `json:"iat,omitempty"` + Issuer string `json:"iss,omitempty"` + NotBefore int64 `json:"nbf,omitempty"` + Subject string `json:"sub,omitempty"` +} + +// Validates time based claims "exp, iat, nbf". +// There is no accounting for clock skew. +// As well, if any of the above claims are not in the token, it will still +// be considered a valid claim. +func (c StandardClaims) Valid() error { + vErr := new(ValidationError) + now := TimeFunc().Unix() + + // The claims below are optional, by default, so if they are set to the + // default value in Go, let's not fail the verification for them. + if c.VerifyExpiresAt(now, false) == false { + delta := time.Unix(now, 0).Sub(time.Unix(c.ExpiresAt, 0)) + vErr.Inner = fmt.Errorf("token is expired by %v", delta) + vErr.Errors |= ValidationErrorExpired + } + + if c.VerifyIssuedAt(now, false) == false { + vErr.Inner = fmt.Errorf("Token used before issued") + vErr.Errors |= ValidationErrorIssuedAt + } + + if c.VerifyNotBefore(now, false) == false { + vErr.Inner = fmt.Errorf("token is not valid yet") + vErr.Errors |= ValidationErrorNotValidYet + } + + if vErr.valid() { + return nil + } + + return vErr +} + +// Compares the aud claim against cmp. +// If required is false, this method will return true if the value matches or is unset +func (c *StandardClaims) VerifyAudience(cmp string, req bool) bool { + return verifyAud(c.Audience, cmp, req) +} + +// Compares the exp claim against cmp. +// If required is false, this method will return true if the value matches or is unset +func (c *StandardClaims) VerifyExpiresAt(cmp int64, req bool) bool { + return verifyExp(c.ExpiresAt, cmp, req) +} + +// Compares the iat claim against cmp. +// If required is false, this method will return true if the value matches or is unset +func (c *StandardClaims) VerifyIssuedAt(cmp int64, req bool) bool { + return verifyIat(c.IssuedAt, cmp, req) +} + +// Compares the iss claim against cmp. +// If required is false, this method will return true if the value matches or is unset +func (c *StandardClaims) VerifyIssuer(cmp string, req bool) bool { + return verifyIss(c.Issuer, cmp, req) +} + +// Compares the nbf claim against cmp. +// If required is false, this method will return true if the value matches or is unset +func (c *StandardClaims) VerifyNotBefore(cmp int64, req bool) bool { + return verifyNbf(c.NotBefore, cmp, req) +} + +// ----- helpers + +func verifyAud(aud string, cmp string, required bool) bool { + if aud == "" { + return !required + } + if subtle.ConstantTimeCompare([]byte(aud), []byte(cmp)) != 0 { + return true + } else { + return false + } +} + +func verifyExp(exp int64, now int64, required bool) bool { + if exp == 0 { + return !required + } + return now <= exp +} + +func verifyIat(iat int64, now int64, required bool) bool { + if iat == 0 { + return !required + } + return now >= iat +} + +func verifyIss(iss string, cmp string, required bool) bool { + if iss == "" { + return !required + } + if subtle.ConstantTimeCompare([]byte(iss), []byte(cmp)) != 0 { + return true + } else { + return false + } +} + +func verifyNbf(nbf int64, now int64, required bool) bool { + if nbf == 0 { + return !required + } + return now >= nbf +} diff --git a/Utils/jwt/ecdsa.go b/Utils/jwt/ecdsa.go new file mode 100644 index 00000000..f9773812 --- /dev/null +++ b/Utils/jwt/ecdsa.go @@ -0,0 +1,148 @@ +package jwt + +import ( + "crypto" + "crypto/ecdsa" + "crypto/rand" + "errors" + "math/big" +) + +var ( + // Sadly this is missing from crypto/ecdsa compared to crypto/rsa + ErrECDSAVerification = errors.New("crypto/ecdsa: verification error") +) + +// Implements the ECDSA family of signing methods signing methods +// Expects *ecdsa.PrivateKey for signing and *ecdsa.PublicKey for verification +type SigningMethodECDSA struct { + Name string + Hash crypto.Hash + KeySize int + CurveBits int +} + +// Specific instances for EC256 and company +var ( + SigningMethodES256 *SigningMethodECDSA + SigningMethodES384 *SigningMethodECDSA + SigningMethodES512 *SigningMethodECDSA +) + +func init() { + // ES256 + SigningMethodES256 = &SigningMethodECDSA{"ES256", crypto.SHA256, 32, 256} + RegisterSigningMethod(SigningMethodES256.Alg(), func() SigningMethod { + return SigningMethodES256 + }) + + // ES384 + SigningMethodES384 = &SigningMethodECDSA{"ES384", crypto.SHA384, 48, 384} + RegisterSigningMethod(SigningMethodES384.Alg(), func() SigningMethod { + return SigningMethodES384 + }) + + // ES512 + SigningMethodES512 = &SigningMethodECDSA{"ES512", crypto.SHA512, 66, 521} + RegisterSigningMethod(SigningMethodES512.Alg(), func() SigningMethod { + return SigningMethodES512 + }) +} + +func (m *SigningMethodECDSA) Alg() string { + return m.Name +} + +// Implements the Verify method from SigningMethod +// For this verify method, key must be an ecdsa.PublicKey struct +func (m *SigningMethodECDSA) Verify(signingString, signature string, key interface{}) error { + var err error + + // Decode the signature + var sig []byte + if sig, err = DecodeSegment(signature); err != nil { + return err + } + + // Get the key + var ecdsaKey *ecdsa.PublicKey + switch k := key.(type) { + case *ecdsa.PublicKey: + ecdsaKey = k + default: + return ErrInvalidKeyType + } + + if len(sig) != 2*m.KeySize { + return ErrECDSAVerification + } + + r := big.NewInt(0).SetBytes(sig[:m.KeySize]) + s := big.NewInt(0).SetBytes(sig[m.KeySize:]) + + // Create hasher + if !m.Hash.Available() { + return ErrHashUnavailable + } + hasher := m.Hash.New() + hasher.Write([]byte(signingString)) + + // Verify the signature + if verifystatus := ecdsa.Verify(ecdsaKey, hasher.Sum(nil), r, s); verifystatus == true { + return nil + } else { + return ErrECDSAVerification + } +} + +// Implements the Sign method from SigningMethod +// For this signing method, key must be an ecdsa.PrivateKey struct +func (m *SigningMethodECDSA) Sign(signingString string, key interface{}) (string, error) { + // Get the key + var ecdsaKey *ecdsa.PrivateKey + switch k := key.(type) { + case *ecdsa.PrivateKey: + ecdsaKey = k + default: + return "", ErrInvalidKeyType + } + + // Create the hasher + if !m.Hash.Available() { + return "", ErrHashUnavailable + } + + hasher := m.Hash.New() + hasher.Write([]byte(signingString)) + + // Sign the string and return r, s + if r, s, err := ecdsa.Sign(rand.Reader, ecdsaKey, hasher.Sum(nil)); err == nil { + curveBits := ecdsaKey.Curve.Params().BitSize + + if m.CurveBits != curveBits { + return "", ErrInvalidKey + } + + keyBytes := curveBits / 8 + if curveBits%8 > 0 { + keyBytes += 1 + } + + // We serialize the outpus (r and s) into big-endian byte arrays and pad + // them with zeros on the left to make sure the sizes work out. Both arrays + // must be keyBytes long, and the output must be 2*keyBytes long. + rBytes := r.Bytes() + rBytesPadded := make([]byte, keyBytes) + copy(rBytesPadded[keyBytes-len(rBytes):], rBytes) + + sBytes := s.Bytes() + sBytesPadded := make([]byte, keyBytes) + copy(sBytesPadded[keyBytes-len(sBytes):], sBytes) + + out := append(rBytesPadded, sBytesPadded...) + + return EncodeSegment(out), nil + } else { + return "", err + } +} diff --git a/Utils/jwt/errors.go b/Utils/jwt/errors.go new file mode 100644 index 00000000..1c93024a --- /dev/null +++ b/Utils/jwt/errors.go @@ -0,0 +1,59 @@ +package jwt + +import ( + "errors" +) + +// Error constants +var ( + ErrInvalidKey = errors.New("key is invalid") + ErrInvalidKeyType = errors.New("key is of invalid type") + ErrHashUnavailable = errors.New("the requested hash function is unavailable") +) + +// The errors that might occur when parsing and validating a token +const ( + ValidationErrorMalformed uint32 = 1 << iota // Token is malformed + ValidationErrorUnverifiable // Token could not be verified because of signing problems + ValidationErrorSignatureInvalid // Signature validation failed + + // Standard Claim validation errors + ValidationErrorAudience // AUD validation failed + ValidationErrorExpired // EXP validation failed + ValidationErrorIssuedAt // IAT validation failed + ValidationErrorIssuer // ISS validation failed + ValidationErrorNotValidYet // NBF validation failed + ValidationErrorId // JTI validation failed + ValidationErrorClaimsInvalid // Generic claims validation error +) + +// Helper for constructing a ValidationError with a string error message +func NewValidationError(errorText string, errorFlags uint32) *ValidationError { + return &ValidationError{ + text: errorText, + Errors: errorFlags, + } +} + +// The error from Parse if token is not valid +type ValidationError struct { + Inner error // stores the error returned by external dependencies, i.e.: KeyFunc + Errors uint32 // bitfield. see ValidationError... constants + text string // errors that do not have a valid error just have text +} + +// Validation error is an error type +func (e ValidationError) Error() string { + if e.Inner != nil { + return e.Inner.Error() + } else if e.text != "" { + return e.text + } else { + return "token is invalid" + } +} + +// No errors +func (e *ValidationError) valid() bool { + return e.Errors == 0 +} diff --git a/Utils/jwt/hmac.go b/Utils/jwt/hmac.go new file mode 100644 index 00000000..addbe5d4 --- /dev/null +++ b/Utils/jwt/hmac.go @@ -0,0 +1,95 @@ +package jwt + +import ( + "crypto" + "crypto/hmac" + "errors" +) + +// Implements the HMAC-SHA family of signing methods signing methods +// Expects key type of []byte for both signing and validation +type SigningMethodHMAC struct { + Name string + Hash crypto.Hash +} + +// Specific instances for HS256 and company +var ( + SigningMethodHS256 *SigningMethodHMAC + SigningMethodHS384 *SigningMethodHMAC + SigningMethodHS512 *SigningMethodHMAC + ErrSignatureInvalid = errors.New("signature is invalid") +) + +func init() { + // HS256 + SigningMethodHS256 = &SigningMethodHMAC{"HS256", crypto.SHA256} + RegisterSigningMethod(SigningMethodHS256.Alg(), func() SigningMethod { + return SigningMethodHS256 + }) + + // HS384 + SigningMethodHS384 = &SigningMethodHMAC{"HS384", crypto.SHA384} + RegisterSigningMethod(SigningMethodHS384.Alg(), func() SigningMethod { + return SigningMethodHS384 + }) + + // HS512 + SigningMethodHS512 = &SigningMethodHMAC{"HS512", crypto.SHA512} + RegisterSigningMethod(SigningMethodHS512.Alg(), func() SigningMethod { + return SigningMethodHS512 + }) +} + +func (m *SigningMethodHMAC) Alg() string { + return m.Name +} + +// Verify the signature of HSXXX tokens. Returns nil if the signature is valid. +func (m *SigningMethodHMAC) Verify(signingString, signature string, key interface{}) error { + // Verify the key is the right type + keyBytes, ok := key.([]byte) + if !ok { + return ErrInvalidKeyType + } + + // Decode signature, for comparison + sig, err := DecodeSegment(signature) + if err != nil { + return err + } + + // Can we use the specified hashing method? + if !m.Hash.Available() { + return ErrHashUnavailable + } + + // This signing method is symmetric, so we validate the signature + // by reproducing the signature from the signing string and key, then + // comparing that against the provided signature. + hasher := hmac.New(m.Hash.New, keyBytes) + hasher.Write([]byte(signingString)) + if !hmac.Equal(sig, hasher.Sum(nil)) { + return ErrSignatureInvalid + } + + // No validation errors. Signature is good. + return nil +} + +// Implements the Sign method from SigningMethod for this signing method. +// Key must be []byte +func (m *SigningMethodHMAC) Sign(signingString string, key interface{}) (string, error) { + if keyBytes, ok := key.([]byte); ok { + if !m.Hash.Available() { + return "", ErrHashUnavailable + } + + hasher := hmac.New(m.Hash.New, keyBytes) + hasher.Write([]byte(signingString)) + + return EncodeSegment(hasher.Sum(nil)), nil + } + + return "", ErrInvalidKeyType +} diff --git a/Utils/jwt/jwt.go b/Utils/jwt/jwt.go new file mode 100644 index 00000000..30f2365c --- /dev/null +++ b/Utils/jwt/jwt.go @@ -0,0 +1,51 @@ +package jwt + +import ( + "time" +) + +type jwtCustomClaims struct { + StandardClaims + + // addition + Uid uint `json:"uid"` + Admin bool `json:"admin"` +} + +/** + * 生成 token + * SecretKey 是一个 const 常量 + */ +func CreateToken(SecretKey []byte, userName string, Uid uint) (tokenString string, err error) { + claims := &jwtCustomClaims{ + StandardClaims{ + ExpiresAt: int64(time.Now().Add(time.Hour * 72).Unix()), + Issuer: userName, + }, + Uid, + false, + } + token := NewWithClaims(SigningMethodHS256, claims) + tokenString, err = token.SignedString(SecretKey) + return +} + +/** + * 生成自定义Claims token + * SecretKey []byte("Your Secret Key") + * customClaims + */ +func CreateCustomToken(SecretKey []byte, customClaims Claims) (tokenString string, err error) { + token := NewWithClaims(SigningMethodHS256, customClaims) + tokenString, err = token.SignedString(SecretKey) + return +} + +func ParseToken(tokenSrt string, SecretKey []byte) (claims Claims, err error) { + var token *Token + token, err = Parse(tokenSrt, func(*Token) (interface{}, error) { + return SecretKey, nil + }) + claims = token.Claims + return +} diff --git a/Utils/jwt/map_claims.go b/Utils/jwt/map_claims.go new file mode 100644 index 00000000..291213c4 --- /dev/null +++ b/Utils/jwt/map_claims.go @@ -0,0 +1,94 @@ +package jwt + +import ( + "encoding/json" + "errors" + // "fmt" +) + +// Claims type that uses the map[string]interface{} for JSON decoding +// This is the default claims type if you don't supply one +type MapClaims map[string]interface{} + +// Compares the aud claim against cmp. +// If required is false, this method will return true if the value matches or is unset +func (m MapClaims) VerifyAudience(cmp string, req bool) bool { + aud, _ := m["aud"].(string) + return verifyAud(aud, cmp, req) +} + +// Compares the exp claim against cmp. +// If required is false, this method will return true if the value matches or is unset +func (m MapClaims) VerifyExpiresAt(cmp int64, req bool) bool { + switch exp := m["exp"].(type) { + case float64: + return verifyExp(int64(exp), cmp, req) + case json.Number: + v, _ := exp.Int64() + return verifyExp(v, cmp, req) + } + return req == false +} + +// Compares the iat claim against cmp. +// If required is false, this method will return true if the value matches or is unset +func (m MapClaims) VerifyIssuedAt(cmp int64, req bool) bool { + switch iat := m["iat"].(type) { + case float64: + return verifyIat(int64(iat), cmp, req) + case json.Number: + v, _ := iat.Int64() + return verifyIat(v, cmp, req) + } + return req == false +} + +// Compares the iss claim against cmp. +// If required is false, this method will return true if the value matches or is unset +func (m MapClaims) VerifyIssuer(cmp string, req bool) bool { + iss, _ := m["iss"].(string) + return verifyIss(iss, cmp, req) +} + +// Compares the nbf claim against cmp. +// If required is false, this method will return true if the value matches or is unset +func (m MapClaims) VerifyNotBefore(cmp int64, req bool) bool { + switch nbf := m["nbf"].(type) { + case float64: + return verifyNbf(int64(nbf), cmp, req) + case json.Number: + v, _ := nbf.Int64() + return verifyNbf(v, cmp, req) + } + return req == false +} + +// Validates time based claims "exp, iat, nbf". +// There is no accounting for clock skew. +// As well, if any of the above claims are not in the token, it will still +// be considered a valid claim. +func (m MapClaims) Valid() error { + vErr := new(ValidationError) + now := TimeFunc().Unix() + + if m.VerifyExpiresAt(now, false) == false { + vErr.Inner = errors.New("Token is expired") + vErr.Errors |= ValidationErrorExpired + } + + if m.VerifyIssuedAt(now, false) == false { + vErr.Inner = errors.New("Token used before issued") + vErr.Errors |= ValidationErrorIssuedAt + } + + if m.VerifyNotBefore(now, false) == false { + vErr.Inner = errors.New("Token is not valid yet") + vErr.Errors |= ValidationErrorNotValidYet + } + + if vErr.valid() { + return nil + } + + return vErr +} diff --git a/Utils/jwt/parser.go b/Utils/jwt/parser.go new file mode 100644 index 00000000..d6901d9a --- /dev/null +++ b/Utils/jwt/parser.go @@ -0,0 +1,148 @@ +package jwt + +import ( + "bytes" + "encoding/json" + "fmt" + "strings" +) + +type Parser struct { + ValidMethods []string // If populated, only these methods will be considered valid + UseJSONNumber bool // Use JSON Number format in JSON decoder + SkipClaimsValidation bool // Skip claims validation during token parsing +} + +// Parse, validate, and return a token. +// keyFunc will receive the parsed token and should return the key for validating. +// If everything is kosher, err will be nil +func (p *Parser) Parse(tokenString string, keyFunc Keyfunc) (*Token, error) { + return p.ParseWithClaims(tokenString, MapClaims{}, keyFunc) +} + +func (p *Parser) ParseWithClaims(tokenString string, claims Claims, keyFunc Keyfunc) (*Token, error) { + token, parts, err := p.ParseUnverified(tokenString, claims) + if err != nil { + return token, err + } + + // Verify signing method is in the required set + if p.ValidMethods != nil { + var signingMethodValid = false + var alg = token.Method.Alg() + for _, m := range p.ValidMethods { + if m == alg { + signingMethodValid = true + break + } + } + if !signingMethodValid { + // signing method is not in the listed set + return token, NewValidationError(fmt.Sprintf("signing method %v is invalid", alg), ValidationErrorSignatureInvalid) + } + } + + // Lookup key + var key interface{} + if keyFunc == nil { + // keyFunc was not provided. short circuiting validation + return token, NewValidationError("no Keyfunc was provided.", ValidationErrorUnverifiable) + } + if key, err = keyFunc(token); err != nil { + // keyFunc returned an error + if ve, ok := err.(*ValidationError); ok { + return token, ve + } + return token, &ValidationError{Inner: err, Errors: ValidationErrorUnverifiable} + } + + vErr := &ValidationError{} + + // Validate Claims + if !p.SkipClaimsValidation { + if err := token.Claims.Valid(); err != nil { + + // If the Claims Valid returned an error, check if it is a validation error, + // If it was another error type, create a ValidationError with a generic ClaimsInvalid flag set + if e, ok := err.(*ValidationError); !ok { + vErr = &ValidationError{Inner: err, Errors: ValidationErrorClaimsInvalid} + } else { + vErr = e + } + } + } + + // Perform validation + token.Signature = parts[2] + if err = token.Method.Verify(strings.Join(parts[0:2], "."), token.Signature, key); err != nil { + vErr.Inner = err + vErr.Errors |= ValidationErrorSignatureInvalid + } + + if vErr.valid() { + token.Valid = true + return token, nil + } + + return token, vErr +} + +// WARNING: Don't use this method unless you know what you're doing +// +// This method parses the token but doesn't validate the signature. It's only +// ever useful in cases where you know the signature is valid (because it has +// been checked previously in the stack) and you want to extract values from +// it. +func (p *Parser) ParseUnverified(tokenString string, claims Claims) (token *Token, parts []string, err error) { + parts = strings.Split(tokenString, ".") + if len(parts) != 3 { + return nil, parts, NewValidationError("token contains an invalid number of segments", ValidationErrorMalformed) + } + + token = &Token{Raw: tokenString} + + // parse Header + var headerBytes []byte + if headerBytes, err = DecodeSegment(parts[0]); err != nil { + if strings.HasPrefix(strings.ToLower(tokenString), "bearer ") { + return token, parts, NewValidationError("tokenstring should not contain 'bearer '", ValidationErrorMalformed) + } + return token, parts, &ValidationError{Inner: err, Errors: ValidationErrorMalformed} + } + if err = json.Unmarshal(headerBytes, &token.Header); err != nil { + return token, parts, &ValidationError{Inner: err, Errors: ValidationErrorMalformed} + } + + // parse Claims + var claimBytes []byte + token.Claims = claims + + if claimBytes, err = DecodeSegment(parts[1]); err != nil { + return token, parts, &ValidationError{Inner: err, Errors: ValidationErrorMalformed} + } + dec := json.NewDecoder(bytes.NewBuffer(claimBytes)) + if p.UseJSONNumber { + dec.UseNumber() + } + // JSON Decode. Special case for map type to avoid weird pointer behavior + if c, ok := token.Claims.(MapClaims); ok { + err = dec.Decode(&c) + } else { + err = dec.Decode(&claims) + } + // Handle decode error + if err != nil { + return token, parts, &ValidationError{Inner: err, Errors: ValidationErrorMalformed} + } + + // Lookup signature method + if method, ok := token.Header["alg"].(string); ok { + if token.Method = GetSigningMethod(method); token.Method == nil { + return token, parts, NewValidationError("signing method (alg) is unavailable.", ValidationErrorUnverifiable) + } + } else { + return token, parts, NewValidationError("signing method (alg) is unspecified.", ValidationErrorUnverifiable) + } + + return token, parts, nil +} diff --git a/Utils/jwt/rsa.go b/Utils/jwt/rsa.go new file mode 100644 index 00000000..e4caf1ca --- /dev/null +++ b/Utils/jwt/rsa.go @@ -0,0 +1,101 @@ +package jwt + +import ( + "crypto" + "crypto/rand" + "crypto/rsa" +) + +// Implements the RSA family of signing methods signing methods +// Expects *rsa.PrivateKey for signing and *rsa.PublicKey for validation +type SigningMethodRSA struct { + Name string + Hash crypto.Hash +} + +// Specific instances for RS256 and company +var ( + SigningMethodRS256 *SigningMethodRSA + SigningMethodRS384 *SigningMethodRSA + SigningMethodRS512 *SigningMethodRSA +) + +func init() { + // RS256 + SigningMethodRS256 = &SigningMethodRSA{"RS256", crypto.SHA256} + RegisterSigningMethod(SigningMethodRS256.Alg(), func() SigningMethod { + return SigningMethodRS256 + }) + + // RS384 + SigningMethodRS384 = &SigningMethodRSA{"RS384", crypto.SHA384} + RegisterSigningMethod(SigningMethodRS384.Alg(), func() SigningMethod { + return SigningMethodRS384 + }) + + // RS512 + SigningMethodRS512 = &SigningMethodRSA{"RS512", crypto.SHA512} + RegisterSigningMethod(SigningMethodRS512.Alg(), func() SigningMethod { + return SigningMethodRS512 + }) +} + +func (m *SigningMethodRSA) Alg() string { + return m.Name +} + +// Implements the Verify method from SigningMethod +// For this signing method, must be an *rsa.PublicKey structure. +func (m *SigningMethodRSA) Verify(signingString, signature string, key interface{}) error { + var err error + + // Decode the signature + var sig []byte + if sig, err = DecodeSegment(signature); err != nil { + return err + } + + var rsaKey *rsa.PublicKey + var ok bool + + if rsaKey, ok = key.(*rsa.PublicKey); !ok { + return ErrInvalidKeyType + } + + // Create hasher + if !m.Hash.Available() { + return ErrHashUnavailable + } + hasher := m.Hash.New() + hasher.Write([]byte(signingString)) + + // Verify the signature + return rsa.VerifyPKCS1v15(rsaKey, m.Hash, hasher.Sum(nil), sig) +} + +// Implements the Sign method from SigningMethod +// For this signing method, must be an *rsa.PrivateKey structure. +func (m *SigningMethodRSA) Sign(signingString string, key interface{}) (string, error) { + var rsaKey *rsa.PrivateKey + var ok bool + + // Validate type of key + if rsaKey, ok = key.(*rsa.PrivateKey); !ok { + return "", ErrInvalidKey + } + + // Create the hasher + if !m.Hash.Available() { + return "", ErrHashUnavailable + } + + hasher := m.Hash.New() + hasher.Write([]byte(signingString)) + + // Sign the string and return the encoded bytes + if sigBytes, err := rsa.SignPKCS1v15(rand.Reader, rsaKey, m.Hash, hasher.Sum(nil)); err == nil { + return EncodeSegment(sigBytes), nil + } else { + return "", err + } +} diff --git a/Utils/jwt/signing_method.go b/Utils/jwt/signing_method.go new file mode 100644 index 00000000..ed1f212b --- /dev/null +++ b/Utils/jwt/signing_method.go @@ -0,0 +1,35 @@ +package jwt + +import ( + "sync" +) + +var signingMethods = map[string]func() SigningMethod{} +var signingMethodLock = new(sync.RWMutex) + +// Implement SigningMethod to add new methods for signing or verifying tokens. +type SigningMethod interface { + Verify(signingString, signature string, key interface{}) error // Returns nil if signature is valid + Sign(signingString string, key interface{}) (string, error) // Returns encoded signature or error + Alg() string // returns the alg identifier for this method (example: 'HS256') +} + +// Register the "alg" name and a factory function for signing method. +// This is typically done during init() in the method's implementation +func RegisterSigningMethod(alg string, f func() SigningMethod) { + signingMethodLock.Lock() + defer signingMethodLock.Unlock() + + signingMethods[alg] = f +} + +// Get a signing method from an "alg" string +func GetSigningMethod(alg string) (method SigningMethod) { + signingMethodLock.RLock() + defer signingMethodLock.RUnlock() + + if methodF, ok := signingMethods[alg]; ok { + method = methodF() + } + return +} diff --git a/Utils/jwt/token.go b/Utils/jwt/token.go new file mode 100644 index 00000000..d637e086 --- /dev/null +++ b/Utils/jwt/token.go @@ -0,0 +1,108 @@ +package jwt + +import ( + "encoding/base64" + "encoding/json" + "strings" + "time" +) + +// TimeFunc provides the current time when parsing token to validate "exp" claim (expiration time). +// You can override it to use another time value. This is useful for testing or if your +// server uses a different time zone than your tokens. +var TimeFunc = time.Now + +// Parse methods use this callback function to supply +// the key for verification. The function receives the parsed, +// but unverified Token. This allows you to use properties in the +// Header of the token (such as `kid`) to identify which key to use. +type Keyfunc func(*Token) (interface{}, error) + +// A JWT Token. Different fields will be used depending on whether you're +// creating or parsing/verifying a token. +type Token struct { + Raw string // The raw token. Populated when you Parse a token + Method SigningMethod // The signing method used or to be used + Header map[string]interface{} // The first segment of the token + Claims Claims // The second segment of the token + Signature string // The third segment of the token. Populated when you Parse a token + Valid bool // Is the token valid? Populated when you Parse/Verify a token +} + +// Create a new Token. Takes a signing method +func New(method SigningMethod) *Token { + return NewWithClaims(method, MapClaims{}) +} + +func NewWithClaims(method SigningMethod, claims Claims) *Token { + return &Token{ + Header: map[string]interface{}{ + "typ": "JWT", + "alg": method.Alg(), + }, + Claims: claims, + Method: method, + } +} + +// Get the complete, signed token +func (t *Token) SignedString(key interface{}) (string, error) { + var sig, sstr string + var err error + if sstr, err = t.SigningString(); err != nil { + return "", err + } + if sig, err = t.Method.Sign(sstr, key); err != nil { + return "", err + } + return strings.Join([]string{sstr, sig}, "."), nil +} + +// Generate the signing string. This is the +// most expensive part of the whole deal. Unless you +// need this for something special, just go straight for +// the SignedString. +func (t *Token) SigningString() (string, error) { + var err error + parts := make([]string, 2) + for i, _ := range parts { + var jsonValue []byte + if i == 0 { + if jsonValue, err = json.Marshal(t.Header); err != nil { + return "", err + } + } else { + if jsonValue, err = json.Marshal(t.Claims); err != nil { + return "", err + } + } + + parts[i] = EncodeSegment(jsonValue) + } + return strings.Join(parts, "."), nil +} + +// Parse, validate, and return a token. +// keyFunc will receive the parsed token and should return the key for validating. +// If everything is kosher, err will be nil +func Parse(tokenString string, keyFunc Keyfunc) (*Token, error) { + return new(Parser).Parse(tokenString, keyFunc) +} + +func ParseWithClaims(tokenString string, claims Claims, keyFunc Keyfunc) (*Token, error) { + return new(Parser).ParseWithClaims(tokenString, claims, keyFunc) +} + +// Encode JWT specific base64url encoding with padding stripped +func EncodeSegment(seg []byte) string { + return strings.TrimRight(base64.URLEncoding.EncodeToString(seg), "=") +} + +// Decode JWT specific base64url encoding with padding stripped +func DecodeSegment(seg string) ([]byte, error) { + if l := len(seg) % 4; l > 0 { + seg += strings.Repeat("=", 4-l) + } + + return base64.URLEncoding.DecodeString(seg) +} diff --git a/go.sum b/go.sum index ea145e64..1389bb2c 100644 --- a/go.sum +++ b/go.sum @@ -188,8 +188,6 @@ github.com/spf13/jwalterweatherman v1.0.0 h1:XHEdyB+EcvlqZamSM4ZOMGlc93t6AcsBEu9 github.com/spf13/jwalterweatherman v1.0.0/go.mod h1:cQK4TGJAtQXfYWX+Ddv3mKDzgVb68N+wFjFa4jdeBTo= github.com/spf13/pflag v1.0.3 h1:zPAT6CGy6wXeQ7NtTnaTerfKOsV6V6F8agHXFiazDkg= github.com/spf13/pflag v1.0.3/go.mod h1:DYY7MBk1bdzusC3SYhjObp+wFpr4gzcvqqNjLnInEg4= -github.com/spf13/viper v1.7.0 h1:xVKxvI7ouOI5I+U9s2eeiUfMaWBVoXA3AWskkrqK0VM= -github.com/spf13/viper v1.7.0/go.mod h1:8WkrPz2fc9jxqZNCJI/76HCieCp4Q8HaLFoCha5qpdg= github.com/spf13/viper v1.7.1 h1:pM5oEahlgWv/WnHXpgbKz7iLIxRf65tye2Ci+XFK5sk= github.com/spf13/viper v1.7.1/go.mod h1:8WkrPz2fc9jxqZNCJI/76HCieCp4Q8HaLFoCha5qpdg= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= From a06cd918d3b34eb423ea498617064ae95be70aff Mon Sep 17 00:00:00 2001 From: yoyofx Date: Thu, 20 Aug 2020 17:56:42 +0800 Subject: [PATCH 2/2] =?UTF-8?q?1.=20=E4=BF=AE=E5=A4=8Dpprof=20endpoint?= =?UTF-8?q?=E4=B8=8D=E8=83=BD=E8=AE=BF=E9=97=AE=E7=9A=84=E9=97=AE=E9=A2=98?= =?UTF-8?q?=202.=20=E6=B7=BB=E5=8A=A0RequestId=E4=B8=AD=E9=97=B4=E4=BB=B6,?= =?UTF-8?q?=E5=9C=A8Header=E4=B8=AD=E5=8A=A0=E5=85=A5x-request-id=E5=AD=97?= =?UTF-8?q?=E6=AE=B5.=203.=20=E4=B8=BA=E4=B8=AD=E9=97=B4=E4=BB=B6=E6=B7=BB?= =?UTF-8?q?=E5=8A=A0SetConfiguration=E5=87=BD=E6=95=B0=E7=94=A8=E4=BA=8E?= =?UTF-8?q?=E8=AF=BB=E5=8F=96=E9=85=8D=E7=BD=AE=E4=BF=A1=E6=81=AF.=204.=20?= =?UTF-8?q?=E6=B7=BB=E5=8A=A0JWT=E9=AA=8C=E8=AF=81=E4=B8=AD=E9=97=B4?= =?UTF-8?q?=E4=BB=B6=E5=92=8C/auth/token=E8=8E=B7=E5=8F=96jwt=20string?= =?UTF-8?q?=E7=9A=84endpoint.?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- Abstractions/HostBuilder.go | 2 +- Abstractions/xlog/XLogger.go | 12 ++- Examples/SimpleWeb/config_dev.yml | 11 +++ Examples/SimpleWeb/config_prod.yml | 6 ++ .../SimpleWeb/contollers/usercontroller.go | 7 +- Examples/SimpleWeb/main.go | 13 +-- Test/jwt_test.go | 5 +- Utils/StringHelper.go | 22 ++++- Utils/jwt/jwt.go | 13 ++- Version.go | 2 +- WebFramework/Context/HttpContext.go | 13 ++- WebFramework/Context/httpcontext_output.go | 7 +- WebFramework/Context/responseWriter.go | 46 +++++----- WebFramework/Endpoints/health_endpoint.go | 2 +- WebFramework/Endpoints/jwt_endpoint.go | 64 ++++++++++++++ WebFramework/Endpoints/pprof_endpoint.go | 38 ++++---- WebFramework/Middleware.go | 12 +-- WebFramework/Middleware/BaseMiddleware.go | 16 ++++ WebFramework/Middleware/JWTMiddleware.go | 86 +++++++++++++++++++ WebFramework/Middleware/RecoveryMiddleware.go | 2 +- .../Middleware/RequestIDMiddleware.go | 24 ++++++ WebFramework/Mvc/MvcRouterHandler.go | 2 +- WebFramework/Router/DefaultRouterBuilder.go | 10 +++ WebFramework/Router/IRouterBuilder.go | 5 ++ WebFramework/WebApplicationBuilder.go | 17 ++-- go.mod | 1 + go.sum | 2 + 27 files changed, 351 insertions(+), 89 deletions(-) create mode 100644 WebFramework/Endpoints/jwt_endpoint.go create mode 100644 WebFramework/Middleware/BaseMiddleware.go create mode 100644 WebFramework/Middleware/JWTMiddleware.go create mode 100644 WebFramework/Middleware/RequestIDMiddleware.go diff --git a/Abstractions/HostBuilder.go b/Abstractions/HostBuilder.go index 46d8cd3f..62d13741 100644 --- a/Abstractions/HostBuilder.go +++ b/Abstractions/HostBuilder.go @@ -133,7 +133,7 @@ func (host *HostBuilder) Build() IServiceHost { host.Context.ApplicationServicesDef = services applicationBuilder.SetHostBuildContext(host.Context) host.Context.HostServices = services.Build() //serviceProvider - host.Context.RequestDelegate = applicationBuilder.Build() // ServeHTTP(w http.ResponseWriter, r *http.Request) + host.Context.RequestDelegate = applicationBuilder.Build() // ServeHTTP(w http.IResponseWriter, r *http.Request) host.Context.ApplicationServices = services.Build() //serviceProvider if host.lifeConfigure != nil { diff --git a/Abstractions/xlog/XLogger.go b/Abstractions/xlog/XLogger.go index 194cc794..1a9b2916 100644 --- a/Abstractions/xlog/XLogger.go +++ b/Abstractions/xlog/XLogger.go @@ -69,9 +69,7 @@ func GetXLogger(class string) ILogger { func (log *XLogger) log(level LogLevel, format string, a ...interface{}) { hostName, _ := os.Hostname() message := format - if len(a[0].([]interface{})) > 0 { - message = fmt.Sprintf(format, a...) - } + message = fmt.Sprintf(format, a...) start := time.Now() info := LogInfo{ @@ -85,17 +83,17 @@ func (log *XLogger) log(level LogLevel, format string, a ...interface{}) { } func (log *XLogger) Debug(format string, a ...interface{}) { - log.log(DEBUG, format, a) + log.log(DEBUG, format, a...) } func (log *XLogger) Info(format string, a ...interface{}) { - log.log(INFO, format, a) + log.log(INFO, format, a...) } func (log *XLogger) Warning(format string, a ...interface{}) { - log.log(WARNING, format, a) + log.log(WARNING, format, a...) } func (log *XLogger) Error(format string, a ...interface{}) { - log.log(ERROR, format, a) + log.log(ERROR, format, a...) } diff --git a/Examples/SimpleWeb/config_dev.yml b/Examples/SimpleWeb/config_dev.yml index 268acf0b..e2d661ca 100644 --- a/Examples/SimpleWeb/config_dev.yml +++ b/Examples/SimpleWeb/config_dev.yml @@ -8,3 +8,14 @@ application: static: patten: "/" webroot: "./Static" + jwt: + header: "Authorization" + secret: "12391JdeOW^%$#@" + prefix: "Bearer" + expires: 3 + enable: true + skip_path: [ + "/info", + "/v1/user/GetInfo" + ] + diff --git a/Examples/SimpleWeb/config_prod.yml b/Examples/SimpleWeb/config_prod.yml index 0185ecb0..98eff047 100644 --- a/Examples/SimpleWeb/config_prod.yml +++ b/Examples/SimpleWeb/config_prod.yml @@ -8,3 +8,9 @@ application: static: patten: "/" webroot: "./Static" + jwt: + header: "Authorization" + secret: "12391JdeOW^%$#@" + prefix: "Bearer" + expires: 3 + enable: true \ No newline at end of file diff --git a/Examples/SimpleWeb/contollers/usercontroller.go b/Examples/SimpleWeb/contollers/usercontroller.go index 01773175..50f08bc2 100644 --- a/Examples/SimpleWeb/contollers/usercontroller.go +++ b/Examples/SimpleWeb/contollers/usercontroller.go @@ -34,9 +34,12 @@ func (controller UserController) GetUserName(ctx *Context.HttpContext, request * return ActionResult.Json{Data: result} } -func (controller UserController) PostUserInfo(request *RegisterRequest) ActionResult.IActionResult { +func (controller UserController) PostUserInfo(ctx *Context.HttpContext, request *RegisterRequest) ActionResult.IActionResult { - return ActionResult.Json{Data: Mvc.ApiResult{Success: true, Message: "ok", Data: request}} + return ActionResult.Json{Data: Mvc.ApiResult{Success: true, Message: "ok", Data: Context.H{ + "user": ctx.GetUser(), + "request": request, + }}} } func (controller UserController) GetHtmlHello() ActionResult.IActionResult { diff --git a/Examples/SimpleWeb/main.go b/Examples/SimpleWeb/main.go index 4e796257..7363e82e 100644 --- a/Examples/SimpleWeb/main.go +++ b/Examples/SimpleWeb/main.go @@ -10,6 +10,7 @@ import ( "github.com/yoyofx/yoyogo/WebFramework" "github.com/yoyofx/yoyogo/WebFramework/Context" "github.com/yoyofx/yoyogo/WebFramework/Endpoints" + "github.com/yoyofx/yoyogo/WebFramework/Middleware" "github.com/yoyofx/yoyogo/WebFramework/Mvc" "github.com/yoyofx/yoyogo/WebFramework/Router" ) @@ -19,7 +20,7 @@ func SimpleDemo() { Endpoints.UsePrometheus(router) router.GET("/info", func(ctx *Context.HttpContext) { - ctx.JSON(200, Context.M{"info": "ok"}) + ctx.JSON(200, Context.H{"info": "ok"}) }) }).Build().Run() } @@ -37,6 +38,7 @@ func CreateCustomBuilder() *Abstractions.HostBuilder { return YoyoGo.NewWebHostBuilder(). UseConfiguration(configuration). Configure(func(app *YoyoGo.WebApplicationBuilder) { + app.UseMiddleware(Middleware.NewRequestID()) app.UseStaticAssets() app.UseEndpoints(registerEndpointRouterConfig) app.UseMvc(func(builder *Mvc.ControllerBuilder) { @@ -58,7 +60,8 @@ func registerEndpointRouterConfig(router Router.IRouterBuilder) { Endpoints.UseHealth(router) Endpoints.UseViz(router) Endpoints.UsePrometheus(router) - //Endpoints.UsePprof(router) + Endpoints.UsePprof(router) + Endpoints.UseJwt(router) router.GET("/error", func(ctx *Context.HttpContext) { panic("http get error") @@ -87,13 +90,13 @@ type UserInfo struct { //HttpGet request: /info or /v1/api/info //bind UserInfo for id,q1,username func GetInfo(ctx *Context.HttpContext) { - ctx.JSON(200, Context.M{"info": "ok"}) + ctx.JSON(200, Context.H{"info": "ok"}) } func GetInfoByIOC(ctx *Context.HttpContext) { var userAction models.IUserAction _ = ctx.RequiredServices.GetService(&userAction) - ctx.JSON(200, Context.M{"info": "ok " + userAction.Login("zhang")}) + ctx.JSON(200, Context.H{"info": "ok " + userAction.Login("zhang")}) } //HttpPost request: /info/:id ?q1=abc&username=123 @@ -106,7 +109,7 @@ func PostInfo(ctx *Context.HttpContext) { strResult := fmt.Sprintf("Name:%s , Q1:%s , bind: %s , routeData id:%s", pd_name, qs_q1, userInfo, id) - ctx.JSON(200, Context.M{"info": "hello world", "result": strResult}) + ctx.JSON(200, Context.H{"info": "hello world", "result": strResult}) } func getApplicationLifeEvent(life *Abstractions.ApplicationLife) { diff --git a/Test/jwt_test.go b/Test/jwt_test.go index be776d37..53ba592d 100644 --- a/Test/jwt_test.go +++ b/Test/jwt_test.go @@ -5,11 +5,12 @@ import ( "github.com/stretchr/testify/assert" "github.com/yoyofx/yoyogo/Utils/jwt" "testing" + "time" ) func TestCreateToken(t *testing.T) { - SecretKey := []byte("AllYourBase") - token, _ := jwt.CreateToken(SecretKey, "YDQ", 2222) + SecretKey := []byte("12391JdeOW^%$#@") + token, _ := jwt.CreateToken(SecretKey, "YDQ", 2222, int64(time.Now().Add(time.Hour*72).Unix())) fmt.Println(token) claims, err := jwt.ParseToken(token, SecretKey) diff --git a/Utils/StringHelper.go b/Utils/StringHelper.go index a4676ad3..4935a0d5 100644 --- a/Utils/StringHelper.go +++ b/Utils/StringHelper.go @@ -1,6 +1,9 @@ package Utils -import "unicode" +import ( + "reflect" + "unicode" +) func PadLeft(s string, pad string, plength int) string { for i := len(s); i < plength; i++ { @@ -22,3 +25,20 @@ func LowercaseFirst(str string) string { } return "" } + +func Contains(obj interface{}, target interface{}) bool { + targetValue := reflect.ValueOf(target) + switch reflect.TypeOf(target).Kind() { + case reflect.Slice, reflect.Array: + for i := 0; i < targetValue.Len(); i++ { + if targetValue.Index(i).Interface() == obj { + return true + } + } + case reflect.Map: + if targetValue.MapIndex(reflect.ValueOf(obj)).IsValid() { + return true + } + } + return false +} diff --git a/Utils/jwt/jwt.go b/Utils/jwt/jwt.go index 30f2365c..7f70efb1 100644 --- a/Utils/jwt/jwt.go +++ b/Utils/jwt/jwt.go @@ -1,9 +1,5 @@ package jwt -import ( - "time" -) - type jwtCustomClaims struct { StandardClaims @@ -16,18 +12,19 @@ type jwtCustomClaims struct { * 生成 token * SecretKey 是一个 const 常量 */ -func CreateToken(SecretKey []byte, userName string, Uid uint) (tokenString string, err error) { +func CreateToken(SecretKey []byte, userName string, Uid uint, expiresAt int64) (string, int64) { claims := &jwtCustomClaims{ StandardClaims{ - ExpiresAt: int64(time.Now().Add(time.Hour * 72).Unix()), + ExpiresAt: expiresAt, Issuer: userName, }, Uid, false, } + token := NewWithClaims(SigningMethodHS256, claims) - tokenString, err = token.SignedString(SecretKey) - return + tokenString, _ := token.SignedString(SecretKey) + return tokenString, claims.ExpiresAt } /** diff --git a/Version.go b/Version.go index 1a910c40..914d80e3 100644 --- a/Version.go +++ b/Version.go @@ -2,7 +2,7 @@ package YoyoGo const ( //Application Version, such as v1.x.x pre-release - Version = "v1.5.2.release" + Version = "v1.5.1.2.release" //Application logo Logo = "IF8gICAgIF8gICAgICAgICAgICAgICAgICAgIF9fXyAgICAgICAgICAKKCApICAgKCApICAgICAgICAgICAgICAgICAgKCAgX2BcICAgICAgICAKYFxgXF8vJy8nXyAgICBfICAgXyAgICBfICAgfCAoIChfKSAgIF8gICAKICBgXCAvJy8nX2BcICggKSAoICkgLydfYFwgfCB8X19fICAvJ19gXCAKICAgfCB8KCAoXykgKXwgKF8pIHwoIChfKSApfCAoXywgKSggKF8pICkKICAgKF8pYFxfX18vJ2BcX18sIHxgXF9fXy8nKF9fX18vJ2BcX19fLycKICAgICAgICAgICAgICggKV98IHwgICAgICAgICAgICAgICAgICAgICAKICAgICAgICAgICAgIGBcX19fLycgICAgICAgICAgICBMaWdodCBhbmQgZmFzdC4gIA==" ) diff --git a/WebFramework/Context/HttpContext.go b/WebFramework/Context/HttpContext.go index 88fa2107..d2a19ba6 100644 --- a/WebFramework/Context/HttpContext.go +++ b/WebFramework/Context/HttpContext.go @@ -18,7 +18,7 @@ const ( ) -type M = map[string]string +type H = map[string]interface{} type HttpContext struct { Input Input @@ -38,7 +38,7 @@ func NewContext(w http.ResponseWriter, r *http.Request, sp DependencyInjection.I func (ctx *HttpContext) init(w http.ResponseWriter, r *http.Request, sp DependencyInjection.IServiceProvider) { ctx.storeMutex = new(sync.RWMutex) ctx.Input = NewInput(r, 20<<32) - ctx.Output = Output{Response: &responseWriter{w, 0, 0, nil}} + ctx.Output = Output{Response: &CResponseWriter{w, 0, 0, nil}} ctx.RequiredServices = sp ctx.storeMutex.Lock() ctx.store = nil @@ -63,6 +63,15 @@ func (ctx *HttpContext) GetItem(key string) interface{} { return v } +// Get JWT UserInfo +func (ctx *HttpContext) GetUser() map[string]interface{} { + v := ctx.GetItem("userinfo") + if v != nil { + return v.(map[string]interface{}) + } + return nil +} + func (ctx *HttpContext) Bind(i interface{}) (err error) { req := ctx.Input.Request contentType := req.Header.Get(HeaderContentType) diff --git a/WebFramework/Context/httpcontext_output.go b/WebFramework/Context/httpcontext_output.go index 5d43365f..3e3af400 100644 --- a/WebFramework/Context/httpcontext_output.go +++ b/WebFramework/Context/httpcontext_output.go @@ -3,7 +3,7 @@ package Context import "net/http" type Output struct { - Response *responseWriter + Response IResponseWriter } //Set Cookie value @@ -22,7 +22,7 @@ func (output Output) Status() int { return output.Response.Status() } -func (output Output) GetWriter() *responseWriter { +func (output Output) GetWriter() IResponseWriter { return output.Response } @@ -31,11 +31,10 @@ func (output Output) SetStatus(status int) { } func (output Output) SetStatusCode(status int) { - output.Response.SetStatus(status) + output.Response.WriteHeader(status) } func (output Output) SetStatusCodeNow() { - output.Response.WriteHeaderNow() } // Write Byte[] Response. diff --git a/WebFramework/Context/responseWriter.go b/WebFramework/Context/responseWriter.go index df977664..2b345403 100644 --- a/WebFramework/Context/responseWriter.go +++ b/WebFramework/Context/responseWriter.go @@ -7,29 +7,29 @@ import ( "net/http" ) -// ResponseWriter is a wrapper around http.ResponseWriter that provides extra information about +// IResponseWriter is a wrapper around http.ResponseWriter that provides extra information about // the response. It is recommended that middleware handlers use this construct to wrap a responsewriter // if the functionality calls for it. -type ResponseWriter interface { +type IResponseWriter interface { http.ResponseWriter http.Flusher // Status returns the status code of the response or 0 if the response has // not been written Status() int - // Written returns whether or not the ResponseWriter has been written. + // Written returns whether or not the IResponseWriter has been written. Written() bool // Size returns the size of the response body. Size() int - // Before allows for a function to be called before the ResponseWriter has been written to. This is + // Before allows for a function to be called before the IResponseWriter has been written to. This is // useful for setting headers or any other operations that must happen before a response has been written. - Before(func(ResponseWriter)) + Before(func(IResponseWriter)) } -type beforeFunc func(ResponseWriter) +type beforeFunc func(IResponseWriter) -// NewResponseWriter creates a ResponseWriter that wraps an http.ResponseWriter -func NewResponseWriter(rw http.ResponseWriter) ResponseWriter { - nrw := &responseWriter{ +// NewResponseWriter creates a IResponseWriter that wraps an http.ResponseWriter +func NewResponseWriter(rw http.ResponseWriter) IResponseWriter { + nrw := &CResponseWriter{ ResponseWriter: rw, } @@ -40,31 +40,31 @@ func NewResponseWriter(rw http.ResponseWriter) ResponseWriter { return nrw } -type responseWriter struct { +type CResponseWriter struct { http.ResponseWriter status int size int beforeFuncs []beforeFunc } -func (rw *responseWriter) SetStatus(code int) { +func (rw *CResponseWriter) SetStatus(code int) { rw.status = code } -func (rw *responseWriter) WriteHeader(s int) { +func (rw *CResponseWriter) WriteHeader(s int) { rw.status = s rw.callBefore() rw.ResponseWriter.WriteHeader(s) } -func (w *responseWriter) WriteHeaderNow() { +func (w *CResponseWriter) WriteHeaderNow() { if !w.Written() { w.size = 0 w.ResponseWriter.WriteHeader(w.status) } } -func (rw *responseWriter) Write(b []byte) (int, error) { +func (rw *CResponseWriter) Write(b []byte) (int, error) { //if !rw.Written() { // // The status will be StatusOK if WriteHeader has not been called yet // rw.WriteHeader(http.StatusOK) @@ -74,37 +74,37 @@ func (rw *responseWriter) Write(b []byte) (int, error) { return size, err } -func (rw *responseWriter) Status() int { +func (rw *CResponseWriter) Status() int { return rw.status } -func (rw *responseWriter) Size() int { +func (rw *CResponseWriter) Size() int { return rw.size } -func (rw *responseWriter) Written() bool { +func (rw *CResponseWriter) Written() bool { return rw.status != 0 } -func (rw *responseWriter) Before(before func(ResponseWriter)) { +func (rw *CResponseWriter) Before(before func(IResponseWriter)) { rw.beforeFuncs = append(rw.beforeFuncs, before) } -func (rw *responseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { +func (rw *CResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { hijacker, ok := rw.ResponseWriter.(http.Hijacker) if !ok { - return nil, nil, errors.New("the ResponseWriter doesn't support the Hijacker interface") + return nil, nil, errors.New("the IResponseWriter doesn't support the Hijacker interface") } return hijacker.Hijack() } -func (rw *responseWriter) callBefore() { +func (rw *CResponseWriter) callBefore() { for i := len(rw.beforeFuncs) - 1; i >= 0; i-- { rw.beforeFuncs[i](rw) } } -func (rw *responseWriter) Flush() { +func (rw *CResponseWriter) Flush() { flusher, ok := rw.ResponseWriter.(http.Flusher) if ok { if !rw.Written() { @@ -116,7 +116,7 @@ func (rw *responseWriter) Flush() { } type responseWriterCloseNotifer struct { - *responseWriter + *CResponseWriter } func (rw *responseWriterCloseNotifer) CloseNotify() <-chan bool { diff --git a/WebFramework/Endpoints/health_endpoint.go b/WebFramework/Endpoints/health_endpoint.go index cf3d872b..cded2fd6 100644 --- a/WebFramework/Endpoints/health_endpoint.go +++ b/WebFramework/Endpoints/health_endpoint.go @@ -10,7 +10,7 @@ func UseHealth(router Router.IRouterBuilder) { xlog.GetXLogger("Endpoint").Debug("loaded health endpoint.") router.GET("/actuator/health", func(ctx *Context.HttpContext) { - ctx.JSON(200, Context.M{ + ctx.JSON(200, Context.H{ "status": "UP", }) }) diff --git a/WebFramework/Endpoints/jwt_endpoint.go b/WebFramework/Endpoints/jwt_endpoint.go new file mode 100644 index 00000000..269b4a5e --- /dev/null +++ b/WebFramework/Endpoints/jwt_endpoint.go @@ -0,0 +1,64 @@ +package Endpoints + +import ( + "github.com/yoyofx/yoyogo/Abstractions/xlog" + "github.com/yoyofx/yoyogo/Utils/jwt" + "github.com/yoyofx/yoyogo/WebFramework/Context" + "github.com/yoyofx/yoyogo/WebFramework/Middleware" + "github.com/yoyofx/yoyogo/WebFramework/Router" + "strconv" + "time" +) + +func UseJwt(router Router.IRouterBuilder) { + xlog.GetXLogger("Endpoint").Debug("loaded jwt endpoint.") + config := router.GetConfiguration() + var secretKey string + var expires int64 + var hasSecret, hasExpires bool + if config != nil { + secretKey, hasSecret = config.Get("application.server.jwt.secret").(string) + expires, hasExpires = config.Get("application.server.jwt.expires").(int64) + } + if !hasSecret { + secretKey = "12391JdeOW^%$#@" + } + if !hasExpires { + expires = 3 + } + if config != nil { + router.POST("/auth/token", func(ctx *Context.HttpContext) { + name := ctx.Input.Param("name") + id := ctx.Input.Param("id") + if name == "" || id == "" { + request := &Middleware.JwtRequest{} + err := ctx.Bind(request) + if err == nil { + id = request.Id + name = request.Name + } + } + if name == "" || id == "" { + xlog.GetXLogger("Jwt Endpoint").Debug("Create Token: name: %s , id: %v , token: %s") + ctx.JSON(200, Context.H{ + "token": "", + "expires": 0, + "success": false, + }) + return + } + + uid, _ := strconv.Atoi(id) + token, expires := jwt.CreateToken([]byte(secretKey), name, uint(uid), int64(time.Now().Add(time.Hour*time.Duration(expires)).Unix())) + xlog.GetXLogger("Jwt Endpoint").Debug("Create Token: ( name: %s , id: %s , token: %s )", name, id, token) + ctx.JSON(200, Context.H{ + "token": token, + "expires": expires, + "success": true, + }) + }) + } else { + xlog.GetXLogger("Jwt Endpoint").Error("config load error.") + } + +} diff --git a/WebFramework/Endpoints/pprof_endpoint.go b/WebFramework/Endpoints/pprof_endpoint.go index 2c4a53e0..442b6505 100644 --- a/WebFramework/Endpoints/pprof_endpoint.go +++ b/WebFramework/Endpoints/pprof_endpoint.go @@ -9,33 +9,33 @@ import ( "net/http/pprof" ) -type pprofHandler struct { - Path string - HandlerFunc YoyoGo.HandlerFunc -} - -var debupApi = []pprofHandler{ - {"/debug/pprof/", WarpHandlerFunc(pprof.Index)}, - {"/debug/pprof/cmdline", WarpHandlerFunc(pprof.Cmdline)}, - {"/debug/pprof/profile", WarpHandlerFunc(pprof.Profile)}, - {"/debug/pprof/symbol", WarpHandlerFunc(pprof.Symbol)}, - {"/debug/pprof/trace", WarpHandlerFunc(pprof.Trace)}, -} - -func WarpHandlerFunc(h func(w http.ResponseWriter, r *http.Request)) YoyoGo.HandlerFunc { +func pprofHandler(h http.HandlerFunc) YoyoGo.HandlerFunc { + handler := http.HandlerFunc(h) return func(c *Context.HttpContext) { + c.Output.SetStatus(200) if c.Input.Path() == "/debug/pprof/" { c.Output.Header(Context.HeaderContentType, Context.MIMETextHTML) } - c.Output.SetStatus(200) - h(c.Output.GetWriter(), c.Input.GetReader()) + handler.ServeHTTP(c.Output.GetWriter(), c.Input.GetReader()) } } func UsePprof(router Router.IRouterBuilder) { xlog.GetXLogger("Endpoint").Debug("loaded pprof endpoint.") - for _, item := range debupApi { - router.GET(item.Path, item.HandlerFunc) - } + router.Group("/debug/pprof", func(prefixRouter *Router.RouterGroup) { + prefixRouter.GET("/", pprofHandler(pprof.Index)) + prefixRouter.GET("/cmdline", pprofHandler(pprof.Cmdline)) + prefixRouter.GET("/profile", pprofHandler(pprof.Profile)) + prefixRouter.POST("/symbol", pprofHandler(pprof.Symbol)) + prefixRouter.GET("/symbol", pprofHandler(pprof.Symbol)) + prefixRouter.GET("/trace", pprofHandler(pprof.Trace)) + prefixRouter.GET("/allocs", pprofHandler(pprof.Handler("allocs").ServeHTTP)) + prefixRouter.GET("/block", pprofHandler(pprof.Handler("block").ServeHTTP)) + prefixRouter.GET("/goroutine", pprofHandler(pprof.Handler("goroutine").ServeHTTP)) + prefixRouter.GET("/heap", pprofHandler(pprof.Handler("heap").ServeHTTP)) + prefixRouter.GET("/mutex", pprofHandler(pprof.Handler("mutex").ServeHTTP)) + prefixRouter.GET("/threadcreate", pprofHandler(pprof.Handler("threadcreate").ServeHTTP)) + }) + } diff --git a/WebFramework/Middleware.go b/WebFramework/Middleware.go index e725cf18..da29ec5b 100644 --- a/WebFramework/Middleware.go +++ b/WebFramework/Middleware.go @@ -5,7 +5,7 @@ import ( "net/http" ) -type Handler interface { +type MiddlewareHandler interface { Inovke(ctx *Context.HttpContext, next func(ctx *Context.HttpContext)) } @@ -18,13 +18,13 @@ func (h MiddlewareHandlerFunc) Inovke(ctx *Context.HttpContext, next func(ctx *C } type middleware struct { - handler Handler + handler MiddlewareHandler // nextfn stores the next.ServeHTTP to reduce memory allocate nextfn func(ctx *Context.HttpContext) } -func newMiddleware(handler Handler, next *middleware) middleware { +func newMiddleware(handler MiddlewareHandler, next *middleware) middleware { return middleware{ handler: handler, nextfn: next.Invoke, @@ -35,14 +35,14 @@ func (m middleware) Invoke(ctx *Context.HttpContext) { m.handler.Inovke(ctx, m.nextfn) } -func wrap(handler http.Handler) Handler { +func wrap(handler http.Handler) MiddlewareHandler { return MiddlewareHandlerFunc(func(ctx *Context.HttpContext, next func(ctx *Context.HttpContext)) { handler.ServeHTTP(ctx.Output.GetWriter(), ctx.Input.GetReader()) next(ctx) }) } -func wrapFunc(handlerFunc http.HandlerFunc) Handler { +func wrapFunc(handlerFunc http.HandlerFunc) MiddlewareHandler { return MiddlewareHandlerFunc(func(ctx *Context.HttpContext, next func(ctx *Context.HttpContext)) { handlerFunc(ctx.Output.GetWriter(), ctx.Input.GetReader()) next(ctx) @@ -60,7 +60,7 @@ func voidMiddleware() middleware { ) } -func build(handlers []Handler) middleware { +func build(handlers []MiddlewareHandler) middleware { var next middleware switch { diff --git a/WebFramework/Middleware/BaseMiddleware.go b/WebFramework/Middleware/BaseMiddleware.go new file mode 100644 index 00000000..4df2de73 --- /dev/null +++ b/WebFramework/Middleware/BaseMiddleware.go @@ -0,0 +1,16 @@ +package Middleware + +import "github.com/yoyofx/yoyogo/Abstractions" + +type IConfigurationMiddleware interface { + SetConfiguration(config Abstractions.IConfiguration) +} + +type BaseMiddleware struct { + // Configuration + config Abstractions.IConfiguration +} + +func (mdw *BaseMiddleware) SetConfiguration(config Abstractions.IConfiguration) { + mdw.config = config +} diff --git a/WebFramework/Middleware/JWTMiddleware.go b/WebFramework/Middleware/JWTMiddleware.go new file mode 100644 index 00000000..7e4983f2 --- /dev/null +++ b/WebFramework/Middleware/JWTMiddleware.go @@ -0,0 +1,86 @@ +package Middleware + +import ( + "github.com/yoyofx/yoyogo/Abstractions" + "github.com/yoyofx/yoyogo/Utils" + "github.com/yoyofx/yoyogo/Utils/jwt" + "github.com/yoyofx/yoyogo/WebFramework/Context" + "net/http" +) + +type JwtMiddleware struct { + *BaseMiddleware + + Enable bool + SecretKey string + Prefix string + Header string + SkipPath []interface{} +} + +type JwtRequest struct { + Id string `json:"id"` + Name string `json:"name"` +} + +func NewJwt() *JwtMiddleware { + return &JwtMiddleware{BaseMiddleware: &BaseMiddleware{}} +} + +func (jwtmdw *JwtMiddleware) SetConfiguration(config Abstractions.IConfiguration) { + var hasEnable, hasSecret, hasPrefix, hasHeader bool + if config != nil { + jwtmdw.Enable, hasEnable = config.Get("application.server.jwt.enable").(bool) + jwtmdw.SecretKey, hasSecret = config.Get("application.server.jwt.secret").(string) + jwtmdw.Prefix, hasPrefix = config.Get("application.server.jwt.prefix").(string) + jwtmdw.Header, hasHeader = config.Get("application.server.jwt.header").(string) + jwtmdw.SkipPath, _ = config.Get("application.server.jwt.skip_path").([]interface{}) + } + + if !hasEnable { + jwtmdw.Enable = false + } + + if !hasSecret { + jwtmdw.SecretKey = "12391JdeOW^%$#@" + } + if !hasPrefix { + jwtmdw.Prefix = "Bearer" + } + if !hasHeader { + jwtmdw.Header = "Authorization" + } + + jwtmdw.SkipPath = append(jwtmdw.SkipPath, "/auth/token") + +} + +func (jwtmdw *JwtMiddleware) Inovke(ctx *Context.HttpContext, next func(ctx *Context.HttpContext)) { + + if !jwtmdw.Enable || Utils.Contains(ctx.Input.Path(), jwtmdw.SkipPath) { + next(ctx) + return + } + auth := ctx.Input.Header(jwtmdw.Header) + if auth == "" { + ctx.Output.SetStatus(http.StatusUnauthorized) + return + } + token := auth[len(jwtmdw.Prefix)+1:] + info, err := jwt.ParseToken(token, []byte(jwtmdw.SecretKey)) + + if err != nil { + ctx.Output.SetStatus(http.StatusUnauthorized) + ctx.Output.Error(http.StatusUnauthorized, "Unauthorized") + return + } else { + mapClaims := info.(jwt.MapClaims) + userInfo := make(map[string]interface{}) + for k, v := range mapClaims { + userInfo[k] = v + } + ctx.SetItem("userinfo", userInfo) + next(ctx) + } + +} diff --git a/WebFramework/Middleware/RecoveryMiddleware.go b/WebFramework/Middleware/RecoveryMiddleware.go index e22575d2..a0e2e9c5 100644 --- a/WebFramework/Middleware/RecoveryMiddleware.go +++ b/WebFramework/Middleware/RecoveryMiddleware.go @@ -176,7 +176,7 @@ func (rec *Recovery) Inovke(ctx *Context.HttpContext, next func(ctx *Context.Htt } infos := &PanicInformation{RecoveredPanic: err, Request: ctx.Input.Request} - // PrintStack will write stack trace info to the ResponseWriter if set to true! + // PrintStack will write stack trace info to the IResponseWriter if set to true! if rec.LogStack { infos.Stack = stack var msg string diff --git a/WebFramework/Middleware/RequestIDMiddleware.go b/WebFramework/Middleware/RequestIDMiddleware.go new file mode 100644 index 00000000..502deef2 --- /dev/null +++ b/WebFramework/Middleware/RequestIDMiddleware.go @@ -0,0 +1,24 @@ +package Middleware + +import ( + "github.com/google/uuid" + "github.com/yoyofx/yoyogo/WebFramework/Context" +) + +const headerXRequestID = "X-Request-ID" + +type RequestIDMiddleware struct { +} + +func NewRequestID() *RequestIDMiddleware { + return &RequestIDMiddleware{} +} + +func (router *RequestIDMiddleware) Inovke(ctx *Context.HttpContext, next func(ctx *Context.HttpContext)) { + requestId := ctx.Input.Header(headerXRequestID) + if requestId == "" { + requestId = uuid.New().String() + } + ctx.Output.Header(headerXRequestID, requestId) + next(ctx) +} diff --git a/WebFramework/Mvc/MvcRouterHandler.go b/WebFramework/Mvc/MvcRouterHandler.go index a927c6bc..c2a3f934 100644 --- a/WebFramework/Mvc/MvcRouterHandler.go +++ b/WebFramework/Mvc/MvcRouterHandler.go @@ -89,7 +89,7 @@ func (handler *RouterHandler) Invoke(ctx *Context.HttpContext, pathComponents [] filter.OnActionExecuted(actionFilterContext) } } else { - ctx.JSON(http.StatusUnauthorized, Context.M{"Message": "Unauthorized"}) + ctx.JSON(http.StatusUnauthorized, Context.H{"Message": "Unauthorized"}) } response := &RouterHandlerResponse{Result: actionResult} diff --git a/WebFramework/Router/DefaultRouterBuilder.go b/WebFramework/Router/DefaultRouterBuilder.go index 5c280a75..6099f9cf 100644 --- a/WebFramework/Router/DefaultRouterBuilder.go +++ b/WebFramework/Router/DefaultRouterBuilder.go @@ -1,6 +1,7 @@ package Router import ( + "github.com/yoyofx/yoyogo/Abstractions" "github.com/yoyofx/yoyogo/WebFramework/Context" "github.com/yoyofx/yoyogo/WebFramework/Mvc" "net/url" @@ -10,6 +11,7 @@ import ( type DefaultRouterBuilder struct { mvcControllerBuilder *Mvc.ControllerBuilder endPointRouterHandler *EndPointRouterHandler + configuration Abstractions.IConfiguration } func NewRouterBuilder() IRouterBuilder { @@ -23,6 +25,14 @@ func NewRouterBuilder() IRouterBuilder { return defaultRouterHandler } +func (router *DefaultRouterBuilder) SetConfiguration(config Abstractions.IConfiguration) { + router.configuration = config +} + +func (router *DefaultRouterBuilder) GetConfiguration() Abstractions.IConfiguration { + return router.configuration +} + func (router *DefaultRouterBuilder) UseMvc(used bool) { if used { router.mvcControllerBuilder = Mvc.NewControllerBuilder() diff --git a/WebFramework/Router/IRouterBuilder.go b/WebFramework/Router/IRouterBuilder.go index eddcf329..c33a9120 100644 --- a/WebFramework/Router/IRouterBuilder.go +++ b/WebFramework/Router/IRouterBuilder.go @@ -1,6 +1,7 @@ package Router import ( + "github.com/yoyofx/yoyogo/Abstractions" "github.com/yoyofx/yoyogo/WebFramework/Context" "github.com/yoyofx/yoyogo/WebFramework/Mvc" "net/url" @@ -48,4 +49,8 @@ type IRouterBuilder interface { Any(path string, handler func(ctx *Context.HttpContext)) Group(name string, routerBuilderFunc func(router *RouterGroup)) + + SetConfiguration(config Abstractions.IConfiguration) + + GetConfiguration() Abstractions.IConfiguration } diff --git a/WebFramework/WebApplicationBuilder.go b/WebFramework/WebApplicationBuilder.go index 0b9de868..36959885 100644 --- a/WebFramework/WebApplicationBuilder.go +++ b/WebFramework/WebApplicationBuilder.go @@ -17,7 +17,7 @@ type WebApplicationBuilder struct { hostContext *Abstractions.HostBuildContext // host build 's context routerBuilder Router.IRouterBuilder // route builder of interface middleware middleware - handlers []Handler + handlers []MiddlewareHandler routeConfigures []func(Router.IRouterBuilder) // endpoints router configure functions mvcConfigures []func(builder *Mvc.ControllerBuilder) // mvc router configure functions } @@ -42,7 +42,7 @@ func CreateBlankWebBuilder() *WebHostBuilder { } // create application builder when combo all handlers to middleware -func New(handlers ...Handler) *WebApplicationBuilder { +func New(handlers ...MiddlewareHandler) *WebApplicationBuilder { return &WebApplicationBuilder{ handlers: handlers, //middleware: build(handlers), @@ -55,7 +55,8 @@ func NewWebApplicationBuilder() *WebApplicationBuilder { recovery := Middleware.NewRecovery() logger := Middleware.NewLogger() router := Middleware.NewRouter(routerBuilder) - self := New(logger, recovery, router) + jwt := Middleware.NewJwt() + self := New(logger, recovery, jwt, router) self.routerBuilder = routerBuilder return self } @@ -75,6 +76,7 @@ func (self *WebApplicationBuilder) UseEndpoints(configure func(Router.IRouterBui } func (this *WebApplicationBuilder) buildEndPoints() { + this.routerBuilder.SetConfiguration(this.hostContext.Configuration) for _, configure := range this.routeConfigures { configure(this.routerBuilder) } @@ -101,6 +103,11 @@ func (this *WebApplicationBuilder) Build() interface{} { panic("hostContext is nil! please set.") } //this.hostContext.HostingEnvironment + for _, handler := range this.handlers { + if configurationMdw, ok := handler.(Middleware.IConfigurationMiddleware); ok { + configurationMdw.SetConfiguration(this.hostContext.Configuration) + } + } this.middleware = build(this.handlers) this.buildEndPoints() this.buildMvc(this.hostContext.ApplicationServicesDef) @@ -112,7 +119,7 @@ func (this *WebApplicationBuilder) SetHostBuildContext(context *Abstractions.Hos } // apply middleware in builder -func (app *WebApplicationBuilder) UseMiddleware(handler Handler) { +func (app *WebApplicationBuilder) UseMiddleware(handler MiddlewareHandler) { if handler == nil { panic("handler cannot be nil") } @@ -144,7 +151,7 @@ func (app *WebApplicationBuilder) UseFunc(handlerFunc MiddlewareHandlerFunc) { } /* -Middleware of Server Handler , request port. +Middleware of Server MiddlewareHandler , request port. */ func (app *WebApplicationBuilder) ServeHTTP(w http.ResponseWriter, r *http.Request) { app.middleware.Invoke(Context.NewContext(w, r, app.hostContext.ApplicationServices)) diff --git a/go.mod b/go.mod index e623d40e..e46fd05b 100644 --- a/go.mod +++ b/go.mod @@ -5,6 +5,7 @@ go 1.14 require ( github.com/defval/inject/v2 v2.2.2 github.com/golang/protobuf v1.4.2 + github.com/google/uuid v1.1.1 github.com/magiconair/properties v1.8.1 github.com/pkg/errors v0.8.1 github.com/prometheus/client_golang v0.9.3 diff --git a/go.sum b/go.sum index 1389bb2c..b360f1f5 100644 --- a/go.sum +++ b/go.sum @@ -82,6 +82,8 @@ github.com/google/martian v2.1.0+incompatible/go.mod h1:9I4somxYTbIHy5NJKHRl3wXi github.com/google/pprof v0.0.0-20181206194817-3ea8567a2e57/go.mod h1:zfwlbNMJ+OItoe0UupaVj+oy1omPYYDuagoSzA8v9mc= github.com/google/pprof v0.0.0-20190515194954-54271f7e092f/go.mod h1:zfwlbNMJ+OItoe0UupaVj+oy1omPYYDuagoSzA8v9mc= github.com/google/renameio v0.1.0/go.mod h1:KWCgfxg9yswjAJkECMjeO8J8rahYeXnNhOm40UhjYkI= +github.com/google/uuid v1.1.1 h1:Gkbcsh/GbpXz7lPftLA3P6TYMwjCLYm83jiFQZF/3gY= +github.com/google/uuid v1.1.1/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/googleapis/gax-go/v2 v2.0.4/go.mod h1:0Wqv26UfaUD9n4G6kQubkQ+KchISgw+vpHVxEJEs9eg= github.com/googleapis/gax-go/v2 v2.0.5/go.mod h1:DWXyrwAJ9X0FpwwEdw+IPEYBICEFu5mhpdKc/us6bOk= github.com/gopherjs/gopherjs v0.0.0-20181017120253-0766667cb4d1 h1:EGx4pi6eqNxGaHF6qqu48+N2wcFQ5qg5FXgOdqsJ5d8=