Skip to content

Commit

Permalink
Wrap map into map mutex lock for read/write operation
Browse files Browse the repository at this point in the history
  • Loading branch information
vkuznet committed Aug 22, 2024
1 parent 1de226e commit fe88e01
Show file tree
Hide file tree
Showing 5 changed files with 31 additions and 1 deletion.
4 changes: 4 additions & 0 deletions auth/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"fmt"
"log"
"strings"
"sync"
)

// OAuthProviders contains maps of all participated providers
Expand Down Expand Up @@ -47,6 +48,7 @@ func (t *TokenInfo) String() string {
// Init initializes map of OAuth providers
func Init(providers []string, verbose int) {
OAuthProviders = make(map[string]Provider)
mapMutex := sync.RWMutex{}
for _, purl := range providers {
if verbose > 0 {
log.Println("initialize provider ", purl)
Expand All @@ -56,7 +58,9 @@ func Init(providers []string, verbose int) {
if err != nil {
log.Fatalf("fail to initialize %s error %v", p.URL, err)
}
mapMutex.Lock()
OAuthProviders[purl] = p
mapMutex.Unlock()
}
}

Expand Down
4 changes: 4 additions & 0 deletions auth/jwks.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ import (
"math/big"
"net/http"
"strings"
"sync"
"time"

"github.com/pascaldekloe/jwt"
Expand Down Expand Up @@ -213,6 +214,7 @@ func getPublicKey(exp, mod string) (*rsa.PublicKey, error) {
// github.com/pascaldekloe/jwt go package
func tokenClaims(provider Provider, token string) (map[string]interface{}, error) {
out := make(map[string]interface{})
mapMutex := sync.RWMutex{}
// First parse without checking signature, to get the Kid
claims, err := jwt.ParseWithoutCheck([]byte(token))
log.Println("ParseWithoutCheck returns %v", err)
Expand All @@ -238,6 +240,7 @@ func tokenClaims(provider Provider, token string) (map[string]interface{}, error
msg := "The token is not valid"
return out, errors.New(msg)
}
mapMutex.Lock()
for k, v := range claims.Set {
out[k] = v
}
Expand All @@ -246,5 +249,6 @@ func tokenClaims(provider Provider, token string) (map[string]interface{}, error
out["sub"] = claims.Subject
out["iss"] = claims.Issuer
out["aud"] = claims.Audiences
mapMutex.Unlock()
return out, nil
}
5 changes: 5 additions & 0 deletions oauth.go
Original file line number Diff line number Diff line change
Expand Up @@ -412,6 +412,7 @@ func oauthRequestHandler(w http.ResponseWriter, r *http.Request) {

status := http.StatusOK
userData := make(map[string]interface{})
mapMutex := sync.RWMutex{}
tstamp := int64(start.UnixNano() / 1000000) // use milliseconds for MONIT
sess := globalSessions.SessionStart(w, r)
oauthState := uuid.New().String()
Expand Down Expand Up @@ -559,12 +560,16 @@ func oauthRequestHandler(w http.ResponseWriter, r *http.Request) {
}
} else {
// in case of existing token CERN SSO or IAM we use token attributes as user data
mapMutex.Lock()
userData["email"] = attrs.Email
userData["name"] = attrs.UserName
userData["exp"] = attrs.Expiration
mapMutex.Unlock()
}
// set id in user data based on token ClientID. The id will be used by SetCMSHeadersXXX calls
mapMutex.Lock()
userData["id"] = attrs.ClientID
mapMutex.Unlock()

// set CMS headers
if Config.CMSHeaders {
Expand Down
10 changes: 10 additions & 0 deletions utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import (
"regexp"
"runtime"
"strings"
"sync"
"time"

"github.com/vkuznet/auth-proxy-server/cric"
Expand Down Expand Up @@ -351,6 +352,7 @@ func findCN(subject string) (string, error) {
// helper function to get user data from TLS request
func getUserData(r *http.Request) map[string]interface{} {
userData := make(map[string]interface{})
mapMutex := sync.RWMutex{}
if r.TLS == nil {
if Config.Verbose > 2 {
log.Printf("HTTP request does not support TLS, %+v", r)
Expand Down Expand Up @@ -393,6 +395,7 @@ func getUserData(r *http.Request) map[string]interface{} {
log.Printf("found user %+v error=%v elapsed time %v\n", rec, err, time.Since(start))
}
if err == nil {
mapMutex.Lock()
userData["issuer"] = strings.Split(cert.Issuer.String(), ",")[0]
userData["Subject"] = strings.Split(cert.Subject.String(), ",")[0]
userData["name"] = rec.Name
Expand All @@ -409,6 +412,7 @@ func getUserData(r *http.Request) map[string]interface{} {
userData["dn"] = dn
}
}
mapMutex.Unlock()
break
} else {
log.Println(err)
Expand Down Expand Up @@ -493,10 +497,13 @@ func PathMatched(rurl, path string, strict bool) bool {
// RedirectRules provides redirect rules map by reading Config.Ingress items
func RedirectRules(ingressRules []Ingress) (map[string]Ingress, []string) {
rmap := make(map[string]Ingress)
mapMutex := sync.RWMutex{}
var rules []string
for _, rec := range ingressRules {
rules = append(rules, rec.Path)
mapMutex.Lock()
rmap[rec.Path] = rec
mapMutex.Unlock()
}
// we should not sort rules, otherwise we break order of the rules which is important, e.g.
// /wmstats should point to /wmstats/index.html, while /wmstats/.* should go further
Expand All @@ -507,6 +514,7 @@ func RedirectRules(ingressRules []Ingress) (map[string]Ingress, []string) {
// RedirectRulesFromFiles provides redirect rules map by reading Config.IngressFiles
func RedirectRulesFromFiles(ingressFiles []string) (map[string]Ingress, []string) {
rmap := make(map[string]Ingress)
mapMutex := sync.RWMutex{}
var rules []string
for _, fname := range ingressFiles {
file, err := os.Open(fname)
Expand All @@ -525,7 +533,9 @@ func RedirectRulesFromFiles(ingressFiles []string) (map[string]Ingress, []string
}
for _, rec := range ingressRules {
rules = append(rules, rec.Path)
mapMutex.Lock()
rmap[rec.Path] = rec
mapMutex.Unlock()
}
}
// we should not sort rules, otherwise we break order of the rules which is important, e.g.
Expand Down
9 changes: 8 additions & 1 deletion x509.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"fmt"
"log"
"net/http"
"sync"
"sync/atomic"
"time"

Expand Down Expand Up @@ -58,6 +59,7 @@ func x509RequestHandler(w http.ResponseWriter, r *http.Request) {
status := http.StatusOK
tstamp := int64(start.UnixNano() / 1000000) // use milliseconds for MONIT
userData := getUserData(r)
mapMutex := sync.RWMutex{}
if Config.Verbose > 0 {
log.Println("userData", userData)
}
Expand All @@ -67,7 +69,9 @@ func x509RequestHandler(w http.ResponseWriter, r *http.Request) {
if Config.Verbose > 3 {
level = true
}
mapMutex.RLock()
CMSAuth.SetCMSHeaders(r, userData, cric.CricRecords, level)
mapMutex.RUnlock()
if Config.Verbose > 1 {
printHTTPRequest(r, "cms headers")
}
Expand All @@ -82,7 +86,10 @@ func x509RequestHandler(w http.ResponseWriter, r *http.Request) {

// add LogRequest after we set cms headers in HTTP request
defer logging.LogRequest(crw, r, start, "x509", &status, tstamp, 0)
if _, ok := userData["name"]; !ok {
mapMutex.RLock()
_, ok := userData["name"]
mapMutex.RUnlock()
if !ok {
log.Println("unauthorized access, user not found in CRIC DB")
status = http.StatusUnauthorized
w.WriteHeader(status)
Expand Down

0 comments on commit fe88e01

Please sign in to comment.