diff --git a/auth/auth.go b/auth/auth.go index 905baf0e9..842961945 100644 --- a/auth/auth.go +++ b/auth/auth.go @@ -146,6 +146,27 @@ func PubKeyContextSuperAdmin(next http.Handler) http.Handler { }) } +// ConnectionContext parses token for connection code +func ConnectionCodeContext(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + token := r.Header.Get("token") + + if token == "" { + fmt.Println("[auth] no token") + http.Error(w, http.StatusText(401), 401) + return + } + + if token != config.Connection_Auth { + fmt.Println("Not a super admin : auth") + http.Error(w, http.StatusText(401), 401) + return + } + ctx := context.WithValue(r.Context(), ContextKey, token) + next.ServeHTTP(w, r.WithContext(ctx)) + }) +} + func AdminCheck(pubkey string) bool { for _, val := range config.SuperAdmins { if val == pubkey { diff --git a/config/config.go b/config/config.go index 520cfb5b0..5feb3f7d9 100644 --- a/config/config.go +++ b/config/config.go @@ -33,6 +33,7 @@ var S3FolderName string var S3Url string var AdminCheck string var AdminDevFreePass = "FREE_PASS" +var Connection_Auth string var S3Client *s3.Client var PresignClient *s3.PresignClient @@ -51,6 +52,7 @@ func InitConfig() { S3FolderName = os.Getenv("S3_FOLDER_NAME") S3Url = os.Getenv("S3_URL") AdminCheck = os.Getenv("ADMIN_CHECK") + Connection_Auth = os.Getenv("CONNECTION_AUTH") // Add to super admins SuperAdmins = StripSuperAdmins(AdminStrings) diff --git a/db/db.go b/db/db.go index a3f18221b..415bf5775 100644 --- a/db/db.go +++ b/db/db.go @@ -1495,10 +1495,12 @@ func (db database) GetPeopleListShort(count uint32) *[]PersonInShort { return &p } -func (db database) CreateConnectionCode(c ConnectionCodes) (ConnectionCodes, error) { - if c.DateCreated == nil { - now := time.Now() - c.DateCreated = &now +func (db database) CreateConnectionCode(c []ConnectionCodes) ([]ConnectionCodes, error) { + now := time.Now() + for _, code := range c { + if code.DateCreated.IsZero() { + code.DateCreated = &now + } } db.db.Create(&c) return c, nil diff --git a/db/interface.go b/db/interface.go index d0b96b2a1..2709abbb1 100644 --- a/db/interface.go +++ b/db/interface.go @@ -79,7 +79,7 @@ type Database interface { CountBounties() uint64 GetPeopleListShort(count uint32) *[]PersonInShort GetConnectionCode() ConnectionCodesShort - CreateConnectionCode(c ConnectionCodes) (ConnectionCodes, error) + CreateConnectionCode(c []ConnectionCodes) ([]ConnectionCodes, error) GetLnUser(lnKey string) int64 CreateLnUser(lnKey string) (Person, error) GetBountiesLeaderboard() []LeaderData diff --git a/handlers/auth.go b/handlers/auth.go index bd46a4040..858f680a4 100644 --- a/handlers/auth.go +++ b/handlers/auth.go @@ -5,7 +5,6 @@ import ( "fmt" "io" "net/http" - "time" "github.com/form3tech-oss/jwt-go" "github.com/stakwork/sphinx-tribes/auth" @@ -54,16 +53,21 @@ func (ah *authHandler) GetIsAdmin(w http.ResponseWriter, r *http.Request) { } func (ah *authHandler) CreateConnectionCode(w http.ResponseWriter, r *http.Request) { - code := db.ConnectionCodes{} - now := time.Now() + codeArr := []db.ConnectionCodes{} + codeStrArr := []string{} body, err := io.ReadAll(r.Body) r.Body.Close() - err = json.Unmarshal(body, &code) + err = json.Unmarshal(body, &codeStrArr) - code.IsUsed = false - code.DateCreated = &now + for _, code := range codeStrArr { + code := db.ConnectionCodes{ + ConnectionString: code, + IsUsed: false, + } + codeArr = append(codeArr, code) + } if err != nil { fmt.Println(err) @@ -71,13 +75,15 @@ func (ah *authHandler) CreateConnectionCode(w http.ResponseWriter, r *http.Reque return } - _, err = ah.db.CreateConnectionCode(code) + _, err = ah.db.CreateConnectionCode(codeArr) if err != nil { fmt.Println("=> ERR create connection code", err) w.WriteHeader(http.StatusBadRequest) return } + w.WriteHeader(http.StatusOK) + json.NewEncoder(w).Encode("Codes created successfully") } func (ah *authHandler) GetConnectionCode(w http.ResponseWriter, _ *http.Request) { diff --git a/handlers/auth_test.go b/handlers/auth_test.go index da98c96ea..461adde96 100644 --- a/handlers/auth_test.go +++ b/handlers/auth_test.go @@ -18,7 +18,6 @@ import ( "github.com/stakwork/sphinx-tribes/db" mocks "github.com/stakwork/sphinx-tribes/mocks" "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/mock" ) func TestGetAdminPubkeys(t *testing.T) { @@ -52,16 +51,21 @@ func TestGetAdminPubkeys(t *testing.T) { } func TestCreateConnectionCode(t *testing.T) { - mockDb := mocks.NewDatabase(t) aHandler := NewAuthHandler(mockDb) t.Run("should create connection code successful", func(t *testing.T) { - codeToBeInserted := db.ConnectionCodes{ - ConnectionString: "custom connection string", + codeToBeInserted := []string{"custom connection string", "custom connection string 2"} + + codeArr := []db.ConnectionCodes{} + for _, code := range codeToBeInserted { + code := db.ConnectionCodes{ + ConnectionString: code, + IsUsed: false, + } + codeArr = append(codeArr, code) } - mockDb.On("CreateConnectionCode", mock.MatchedBy(func(code db.ConnectionCodes) bool { - return code.IsUsed == false && code.ConnectionString == codeToBeInserted.ConnectionString - })).Return(codeToBeInserted, nil).Once() + + mockDb.On("CreateConnectionCode", codeArr).Return(codeArr, nil).Once() body, _ := json.Marshal(codeToBeInserted) req, err := http.NewRequest("POST", "/connectioncodes", bytes.NewBuffer(body)) @@ -77,12 +81,18 @@ func TestCreateConnectionCode(t *testing.T) { }) t.Run("should return error if failed to add connection code", func(t *testing.T) { - codeToBeInserted := db.ConnectionCodes{ - ConnectionString: "custom connection string", + codeToBeInserted := []string{"custom connection string", "custom connection string 2"} + + codeArr := []db.ConnectionCodes{} + for _, code := range codeToBeInserted { + code := db.ConnectionCodes{ + ConnectionString: code, + IsUsed: false, + } + codeArr = append(codeArr, code) } - mockDb.On("CreateConnectionCode", mock.MatchedBy(func(code db.ConnectionCodes) bool { - return code.IsUsed == false && code.ConnectionString == codeToBeInserted.ConnectionString - })).Return(codeToBeInserted, errors.New("failed to create connection")).Once() + + mockDb.On("CreateConnectionCode", codeArr).Return(codeArr, errors.New("failed to create connection")).Once() body, _ := json.Marshal(codeToBeInserted) req, err := http.NewRequest("POST", "/connectioncodes", bytes.NewBuffer(body)) diff --git a/mocks/Database.go b/mocks/Database.go index bc189ad35..b44f69f67 100644 --- a/mocks/Database.go +++ b/mocks/Database.go @@ -643,25 +643,27 @@ func (_c *Database_CreateChannel_Call) RunAndReturn(run func(db.Channel) (db.Cha } // CreateConnectionCode provides a mock function with given fields: c -func (_m *Database) CreateConnectionCode(c db.ConnectionCodes) (db.ConnectionCodes, error) { +func (_m *Database) CreateConnectionCode(c []db.ConnectionCodes) ([]db.ConnectionCodes, error) { ret := _m.Called(c) if len(ret) == 0 { panic("no return value specified for CreateConnectionCode") } - var r0 db.ConnectionCodes + var r0 []db.ConnectionCodes var r1 error - if rf, ok := ret.Get(0).(func(db.ConnectionCodes) (db.ConnectionCodes, error)); ok { + if rf, ok := ret.Get(0).(func([]db.ConnectionCodes) ([]db.ConnectionCodes, error)); ok { return rf(c) } - if rf, ok := ret.Get(0).(func(db.ConnectionCodes) db.ConnectionCodes); ok { + if rf, ok := ret.Get(0).(func([]db.ConnectionCodes) []db.ConnectionCodes); ok { r0 = rf(c) } else { - r0 = ret.Get(0).(db.ConnectionCodes) + if ret.Get(0) != nil { + r0 = ret.Get(0).([]db.ConnectionCodes) + } } - if rf, ok := ret.Get(1).(func(db.ConnectionCodes) error); ok { + if rf, ok := ret.Get(1).(func([]db.ConnectionCodes) error); ok { r1 = rf(c) } else { r1 = ret.Error(1) @@ -676,24 +678,24 @@ type Database_CreateConnectionCode_Call struct { } // CreateConnectionCode is a helper method to define mock.On call -// - c db.ConnectionCodes +// - c []db.ConnectionCodes func (_e *Database_Expecter) CreateConnectionCode(c interface{}) *Database_CreateConnectionCode_Call { return &Database_CreateConnectionCode_Call{Call: _e.mock.On("CreateConnectionCode", c)} } -func (_c *Database_CreateConnectionCode_Call) Run(run func(c db.ConnectionCodes)) *Database_CreateConnectionCode_Call { +func (_c *Database_CreateConnectionCode_Call) Run(run func(c []db.ConnectionCodes)) *Database_CreateConnectionCode_Call { _c.Call.Run(func(args mock.Arguments) { - run(args[0].(db.ConnectionCodes)) + run(args[0].([]db.ConnectionCodes)) }) return _c } -func (_c *Database_CreateConnectionCode_Call) Return(_a0 db.ConnectionCodes, _a1 error) *Database_CreateConnectionCode_Call { +func (_c *Database_CreateConnectionCode_Call) Return(_a0 []db.ConnectionCodes, _a1 error) *Database_CreateConnectionCode_Call { _c.Call.Return(_a0, _a1) return _c } -func (_c *Database_CreateConnectionCode_Call) RunAndReturn(run func(db.ConnectionCodes) (db.ConnectionCodes, error)) *Database_CreateConnectionCode_Call { +func (_c *Database_CreateConnectionCode_Call) RunAndReturn(run func([]db.ConnectionCodes) ([]db.ConnectionCodes, error)) *Database_CreateConnectionCode_Call { _c.Call.Return(run) return _c } diff --git a/routes/connection_codes.go b/routes/connection_codes.go index 4a5e6bee7..b151ff5e5 100644 --- a/routes/connection_codes.go +++ b/routes/connection_codes.go @@ -2,6 +2,7 @@ package routes import ( "github.com/go-chi/chi" + "github.com/stakwork/sphinx-tribes/auth" "github.com/stakwork/sphinx-tribes/db" "github.com/stakwork/sphinx-tribes/handlers" ) @@ -10,8 +11,12 @@ func ConnectionCodesRoutes() chi.Router { r := chi.NewRouter() authHandler := handlers.NewAuthHandler(db.DB) r.Group(func(r chi.Router) { - r.Post("/", authHandler.CreateConnectionCode) r.Get("/", authHandler.GetConnectionCode) }) + + r.Group(func(r chi.Router) { + r.Use(auth.ConnectionCodeContext) + r.Post("/", authHandler.CreateConnectionCode) + }) return r } diff --git a/utils/utils.go b/utils/utils.go index 95f41417a..af3f0a51d 100644 --- a/utils/utils.go +++ b/utils/utils.go @@ -7,10 +7,9 @@ import ( ) func GetPaginationParams(r *http.Request) (int, int, string, string, string) { - // there are cases when the request is not passed in if r == nil { - return 0, -1, "updated", "asc", "" + return 0, 1, "updated", "asc", "" } keys := r.URL.Query()