From 6717fbd750ab3b76e91ebaa1f69cfedd628bb718 Mon Sep 17 00:00:00 2001 From: elraphty Date: Mon, 4 Mar 2024 18:05:03 +0100 Subject: [PATCH 1/5] connection auth change --- auth/auth.go | 24 ++++++++++++++++++++++++ config/config.go | 2 ++ db/db.go | 6 +----- db/interface.go | 2 +- handlers/auth.go | 17 ++++++++++++----- mocks/Database.go | 24 +++++++++++++----------- routes/connection_codes.go | 7 ++++++- 7 files changed, 59 insertions(+), 23 deletions(-) diff --git a/auth/auth.go b/auth/auth.go index 905baf0e9..fab2bffc4 100644 --- a/auth/auth.go +++ b/auth/auth.go @@ -146,6 +146,30 @@ func PubKeyContextSuperAdmin(next http.Handler) http.Handler { }) } +// ConnectionContext parses token for connection code +func ConnectionCodeContext(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + token := r.URL.Query().Get("token") + if token == "" { + token = r.Header.Get("x-jwt") + } + + if token == "" { + fmt.Println("[auth] no token") + http.Error(w, http.StatusText(401), 401) + return + } + + if token != config.Connection_Auth { + fmt.Println("Not a super admin : auth") + http.Error(w, http.StatusText(401), 401) + return + } + ctx := context.WithValue(r.Context(), ContextKey, token) + next.ServeHTTP(w, r.WithContext(ctx)) + }) +} + func AdminCheck(pubkey string) bool { for _, val := range config.SuperAdmins { if val == pubkey { diff --git a/config/config.go b/config/config.go index 36059cc8b..f3cee5b81 100644 --- a/config/config.go +++ b/config/config.go @@ -33,6 +33,7 @@ var S3FolderName string var S3Url string var AdminCheck string var AdminDevFreePass = "FREE_PASS" +var Connection_Auth string var S3Client *s3.S3 @@ -50,6 +51,7 @@ func InitConfig() { S3FolderName = os.Getenv("S3_FOLDER_NAME") S3Url = os.Getenv("S3_URL") AdminCheck = os.Getenv("ADMIN_CHECK") + Connection_Auth = os.Getenv("CONNECTION_AUTH") // Add to super admins SuperAdmins = StripSuperAdmins(AdminStrings) diff --git a/db/db.go b/db/db.go index a3f18221b..727b09698 100644 --- a/db/db.go +++ b/db/db.go @@ -1495,11 +1495,7 @@ func (db database) GetPeopleListShort(count uint32) *[]PersonInShort { return &p } -func (db database) CreateConnectionCode(c ConnectionCodes) (ConnectionCodes, error) { - if c.DateCreated == nil { - now := time.Now() - c.DateCreated = &now - } +func (db database) CreateConnectionCode(c []ConnectionCodes) ([]ConnectionCodes, error) { db.db.Create(&c) return c, nil } diff --git a/db/interface.go b/db/interface.go index bdf993e0f..ddcf1d0c1 100644 --- a/db/interface.go +++ b/db/interface.go @@ -79,7 +79,7 @@ type Database interface { CountBounties() uint64 GetPeopleListShort(count uint32) *[]PersonInShort GetConnectionCode() ConnectionCodesShort - CreateConnectionCode(c ConnectionCodes) (ConnectionCodes, error) + CreateConnectionCode(c []ConnectionCodes) ([]ConnectionCodes, error) GetLnUser(lnKey string) int64 CreateLnUser(lnKey string) (Person, error) GetBountiesLeaderboard() []LeaderData diff --git a/handlers/auth.go b/handlers/auth.go index dd60960ef..5d3a5380c 100644 --- a/handlers/auth.go +++ b/handlers/auth.go @@ -54,16 +54,23 @@ func (ah *authHandler) GetIsAdmin(w http.ResponseWriter, r *http.Request) { } func (ah *authHandler) CreateConnectionCode(w http.ResponseWriter, r *http.Request) { - code := db.ConnectionCodes{} + codeArr := []db.ConnectionCodes{} + codeStrArr := []string{} now := time.Now() body, err := io.ReadAll(r.Body) r.Body.Close() - err = json.Unmarshal(body, &code) + err = json.Unmarshal(body, &codeStrArr) - code.IsUsed = false - code.DateCreated = &now + for _, code := range codeStrArr { + code := db.ConnectionCodes{ + ConnectionString: code, + IsUsed: false, + DateCreated: &now, + } + codeArr = append(codeArr, code) + } if err != nil { fmt.Println(err) @@ -71,7 +78,7 @@ func (ah *authHandler) CreateConnectionCode(w http.ResponseWriter, r *http.Reque return } - _, err = ah.db.CreateConnectionCode(code) + _, err = ah.db.CreateConnectionCode(codeArr) if err != nil { fmt.Println("=> ERR create connection code", err) diff --git a/mocks/Database.go b/mocks/Database.go index ab0e6066c..fa9f1dd93 100644 --- a/mocks/Database.go +++ b/mocks/Database.go @@ -643,25 +643,27 @@ func (_c *Database_CreateChannel_Call) RunAndReturn(run func(db.Channel) (db.Cha } // CreateConnectionCode provides a mock function with given fields: c -func (_m *Database) CreateConnectionCode(c db.ConnectionCodes) (db.ConnectionCodes, error) { +func (_m *Database) CreateConnectionCode(c []db.ConnectionCodes) ([]db.ConnectionCodes, error) { ret := _m.Called(c) if len(ret) == 0 { panic("no return value specified for CreateConnectionCode") } - var r0 db.ConnectionCodes + var r0 []db.ConnectionCodes var r1 error - if rf, ok := ret.Get(0).(func(db.ConnectionCodes) (db.ConnectionCodes, error)); ok { + if rf, ok := ret.Get(0).(func([]db.ConnectionCodes) ([]db.ConnectionCodes, error)); ok { return rf(c) } - if rf, ok := ret.Get(0).(func(db.ConnectionCodes) db.ConnectionCodes); ok { + if rf, ok := ret.Get(0).(func([]db.ConnectionCodes) []db.ConnectionCodes); ok { r0 = rf(c) } else { - r0 = ret.Get(0).(db.ConnectionCodes) + if ret.Get(0) != nil { + r0 = ret.Get(0).([]db.ConnectionCodes) + } } - if rf, ok := ret.Get(1).(func(db.ConnectionCodes) error); ok { + if rf, ok := ret.Get(1).(func([]db.ConnectionCodes) error); ok { r1 = rf(c) } else { r1 = ret.Error(1) @@ -676,24 +678,24 @@ type Database_CreateConnectionCode_Call struct { } // CreateConnectionCode is a helper method to define mock.On call -// - c db.ConnectionCodes +// - c []db.ConnectionCodes func (_e *Database_Expecter) CreateConnectionCode(c interface{}) *Database_CreateConnectionCode_Call { return &Database_CreateConnectionCode_Call{Call: _e.mock.On("CreateConnectionCode", c)} } -func (_c *Database_CreateConnectionCode_Call) Run(run func(c db.ConnectionCodes)) *Database_CreateConnectionCode_Call { +func (_c *Database_CreateConnectionCode_Call) Run(run func(c []db.ConnectionCodes)) *Database_CreateConnectionCode_Call { _c.Call.Run(func(args mock.Arguments) { - run(args[0].(db.ConnectionCodes)) + run(args[0].([]db.ConnectionCodes)) }) return _c } -func (_c *Database_CreateConnectionCode_Call) Return(_a0 db.ConnectionCodes, _a1 error) *Database_CreateConnectionCode_Call { +func (_c *Database_CreateConnectionCode_Call) Return(_a0 []db.ConnectionCodes, _a1 error) *Database_CreateConnectionCode_Call { _c.Call.Return(_a0, _a1) return _c } -func (_c *Database_CreateConnectionCode_Call) RunAndReturn(run func(db.ConnectionCodes) (db.ConnectionCodes, error)) *Database_CreateConnectionCode_Call { +func (_c *Database_CreateConnectionCode_Call) RunAndReturn(run func([]db.ConnectionCodes) ([]db.ConnectionCodes, error)) *Database_CreateConnectionCode_Call { _c.Call.Return(run) return _c } diff --git a/routes/connection_codes.go b/routes/connection_codes.go index 4a5e6bee7..b151ff5e5 100644 --- a/routes/connection_codes.go +++ b/routes/connection_codes.go @@ -2,6 +2,7 @@ package routes import ( "github.com/go-chi/chi" + "github.com/stakwork/sphinx-tribes/auth" "github.com/stakwork/sphinx-tribes/db" "github.com/stakwork/sphinx-tribes/handlers" ) @@ -10,8 +11,12 @@ func ConnectionCodesRoutes() chi.Router { r := chi.NewRouter() authHandler := handlers.NewAuthHandler(db.DB) r.Group(func(r chi.Router) { - r.Post("/", authHandler.CreateConnectionCode) r.Get("/", authHandler.GetConnectionCode) }) + + r.Group(func(r chi.Router) { + r.Use(auth.ConnectionCodeContext) + r.Post("/", authHandler.CreateConnectionCode) + }) return r } From cc261739d2b1bb6d4840b6ba30d61c475f8bac0d Mon Sep 17 00:00:00 2001 From: elraphty Date: Mon, 4 Mar 2024 20:10:00 +0100 Subject: [PATCH 2/5] changed conection codes to array --- auth/auth.go | 7 ++----- handlers/auth.go | 2 ++ 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/auth/auth.go b/auth/auth.go index fab2bffc4..aabb73375 100644 --- a/auth/auth.go +++ b/auth/auth.go @@ -37,7 +37,7 @@ func PubKeyContext(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { token := r.URL.Query().Get("token") if token == "" { - token = r.Header.Get("x-jwt") + token = r.Header.Get("token") } if token == "" { @@ -149,10 +149,7 @@ func PubKeyContextSuperAdmin(next http.Handler) http.Handler { // ConnectionContext parses token for connection code func ConnectionCodeContext(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - token := r.URL.Query().Get("token") - if token == "" { - token = r.Header.Get("x-jwt") - } + token := r.Header.Get("token") if token == "" { fmt.Println("[auth] no token") diff --git a/handlers/auth.go b/handlers/auth.go index 5d3a5380c..773309f21 100644 --- a/handlers/auth.go +++ b/handlers/auth.go @@ -85,6 +85,8 @@ func (ah *authHandler) CreateConnectionCode(w http.ResponseWriter, r *http.Reque w.WriteHeader(http.StatusBadRequest) return } + w.WriteHeader(http.StatusOK) + json.NewEncoder(w).Encode("Codes created successfully") } func (ah *authHandler) GetConnectionCode(w http.ResponseWriter, _ *http.Request) { From b51275139d57e97825ecb93fd7de9c8416e59c40 Mon Sep 17 00:00:00 2001 From: elraphty Date: Mon, 4 Mar 2024 21:02:21 +0100 Subject: [PATCH 3/5] fixed connection errors --- db/db.go | 6 ++++++ handlers/auth.go | 3 --- handlers/auth_test.go | 34 ++++++++++++++++++++++------------ 3 files changed, 28 insertions(+), 15 deletions(-) diff --git a/db/db.go b/db/db.go index 727b09698..415bf5775 100644 --- a/db/db.go +++ b/db/db.go @@ -1496,6 +1496,12 @@ func (db database) GetPeopleListShort(count uint32) *[]PersonInShort { } func (db database) CreateConnectionCode(c []ConnectionCodes) ([]ConnectionCodes, error) { + now := time.Now() + for _, code := range c { + if code.DateCreated.IsZero() { + code.DateCreated = &now + } + } db.db.Create(&c) return c, nil } diff --git a/handlers/auth.go b/handlers/auth.go index 773309f21..35f283cec 100644 --- a/handlers/auth.go +++ b/handlers/auth.go @@ -5,7 +5,6 @@ import ( "fmt" "io" "net/http" - "time" "github.com/form3tech-oss/jwt-go" "github.com/stakwork/sphinx-tribes/auth" @@ -56,7 +55,6 @@ func (ah *authHandler) GetIsAdmin(w http.ResponseWriter, r *http.Request) { func (ah *authHandler) CreateConnectionCode(w http.ResponseWriter, r *http.Request) { codeArr := []db.ConnectionCodes{} codeStrArr := []string{} - now := time.Now() body, err := io.ReadAll(r.Body) r.Body.Close() @@ -67,7 +65,6 @@ func (ah *authHandler) CreateConnectionCode(w http.ResponseWriter, r *http.Reque code := db.ConnectionCodes{ ConnectionString: code, IsUsed: false, - DateCreated: &now, } codeArr = append(codeArr, code) } diff --git a/handlers/auth_test.go b/handlers/auth_test.go index da98c96ea..461adde96 100644 --- a/handlers/auth_test.go +++ b/handlers/auth_test.go @@ -18,7 +18,6 @@ import ( "github.com/stakwork/sphinx-tribes/db" mocks "github.com/stakwork/sphinx-tribes/mocks" "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/mock" ) func TestGetAdminPubkeys(t *testing.T) { @@ -52,16 +51,21 @@ func TestGetAdminPubkeys(t *testing.T) { } func TestCreateConnectionCode(t *testing.T) { - mockDb := mocks.NewDatabase(t) aHandler := NewAuthHandler(mockDb) t.Run("should create connection code successful", func(t *testing.T) { - codeToBeInserted := db.ConnectionCodes{ - ConnectionString: "custom connection string", + codeToBeInserted := []string{"custom connection string", "custom connection string 2"} + + codeArr := []db.ConnectionCodes{} + for _, code := range codeToBeInserted { + code := db.ConnectionCodes{ + ConnectionString: code, + IsUsed: false, + } + codeArr = append(codeArr, code) } - mockDb.On("CreateConnectionCode", mock.MatchedBy(func(code db.ConnectionCodes) bool { - return code.IsUsed == false && code.ConnectionString == codeToBeInserted.ConnectionString - })).Return(codeToBeInserted, nil).Once() + + mockDb.On("CreateConnectionCode", codeArr).Return(codeArr, nil).Once() body, _ := json.Marshal(codeToBeInserted) req, err := http.NewRequest("POST", "/connectioncodes", bytes.NewBuffer(body)) @@ -77,12 +81,18 @@ func TestCreateConnectionCode(t *testing.T) { }) t.Run("should return error if failed to add connection code", func(t *testing.T) { - codeToBeInserted := db.ConnectionCodes{ - ConnectionString: "custom connection string", + codeToBeInserted := []string{"custom connection string", "custom connection string 2"} + + codeArr := []db.ConnectionCodes{} + for _, code := range codeToBeInserted { + code := db.ConnectionCodes{ + ConnectionString: code, + IsUsed: false, + } + codeArr = append(codeArr, code) } - mockDb.On("CreateConnectionCode", mock.MatchedBy(func(code db.ConnectionCodes) bool { - return code.IsUsed == false && code.ConnectionString == codeToBeInserted.ConnectionString - })).Return(codeToBeInserted, errors.New("failed to create connection")).Once() + + mockDb.On("CreateConnectionCode", codeArr).Return(codeArr, errors.New("failed to create connection")).Once() body, _ := json.Marshal(codeToBeInserted) req, err := http.NewRequest("POST", "/connectioncodes", bytes.NewBuffer(body)) From 4ecf9f54ee07b7b0f1ebfe66fbff9e373309579f Mon Sep 17 00:00:00 2001 From: elraphty Date: Thu, 7 Mar 2024 00:11:49 +0100 Subject: [PATCH 4/5] changed error returned limit to 1 --- utils/utils.go | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/utils/utils.go b/utils/utils.go index 95f41417a..af3f0a51d 100644 --- a/utils/utils.go +++ b/utils/utils.go @@ -7,10 +7,9 @@ import ( ) func GetPaginationParams(r *http.Request) (int, int, string, string, string) { - // there are cases when the request is not passed in if r == nil { - return 0, -1, "updated", "asc", "" + return 0, 1, "updated", "asc", "" } keys := r.URL.Query() From 3ff85d6f961e38e9671d7fc5d6ca3fe7088767f2 Mon Sep 17 00:00:00 2001 From: elraphty Date: Thu, 7 Mar 2024 10:25:02 +0100 Subject: [PATCH 5/5] changed auth header name --- auth/auth.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/auth/auth.go b/auth/auth.go index aabb73375..842961945 100644 --- a/auth/auth.go +++ b/auth/auth.go @@ -37,7 +37,7 @@ func PubKeyContext(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { token := r.URL.Query().Get("token") if token == "" { - token = r.Header.Get("token") + token = r.Header.Get("x-jwt") } if token == "" {