diff --git a/handlers/tribes.go b/handlers/tribes.go index 11a319a16..61f90f495 100644 --- a/handlers/tribes.go +++ b/handlers/tribes.go @@ -31,20 +31,20 @@ func NewTribeHandler(db db.Database) *tribeHandler { } } -func GetAllTribes(w http.ResponseWriter, r *http.Request) { - tribes := db.DB.GetAllTribes() +func (th *tribeHandler) GetAllTribes(w http.ResponseWriter, r *http.Request) { + tribes := th.db.GetAllTribes() w.WriteHeader(http.StatusOK) json.NewEncoder(w).Encode(tribes) } -func GetTotalribes(w http.ResponseWriter, r *http.Request) { - tribesTotal := db.DB.GetTribesTotal() +func (th *tribeHandler) GetTotalribes(w http.ResponseWriter, r *http.Request) { + tribesTotal := th.db.GetTribesTotal() w.WriteHeader(http.StatusOK) json.NewEncoder(w).Encode(tribesTotal) } -func GetListedTribes(w http.ResponseWriter, r *http.Request) { - tribes := db.DB.GetListedTribes(r) +func (th *tribeHandler) GetListedTribes(w http.ResponseWriter, r *http.Request) { + tribes := th.db.GetListedTribes(r) w.WriteHeader(http.StatusOK) json.NewEncoder(w).Encode(tribes) } diff --git a/handlers/tribes_test.go b/handlers/tribes_test.go index 4602ecc11..f9654e721 100644 --- a/handlers/tribes_test.go +++ b/handlers/tribes_test.go @@ -6,9 +6,11 @@ import ( "encoding/json" "net/http" "net/http/httptest" + "strings" "testing" "github.com/go-chi/chi" + "github.com/lib/pq" "github.com/stakwork/sphinx-tribes/auth" "github.com/stakwork/sphinx-tribes/db" mocks "github.com/stakwork/sphinx-tribes/mocks" @@ -471,3 +473,134 @@ func TestGetTribeByUniqueName(t *testing.T) { assert.Equal(t, mockUniqueName, responseData["unique_name"]) }) } + +func TestGetAllTribes(t *testing.T) { + mockDb := mocks.NewDatabase(t) + tHandler := NewTribeHandler(mockDb) + t.Run("should return all tribes", func(t *testing.T) { + rr := httptest.NewRecorder() + handler := http.HandlerFunc(tHandler.GetAllTribes) + + expectedTribes := []db.Tribe{ + {UUID: "uuid", Name: "Tribe1"}, + {UUID: "uuid", Name: "Tribe2"}, + {UUID: "uuid", Name: "Tribe3"}, + } + + rctx := chi.NewRouteContext() + req, err := http.NewRequestWithContext(context.WithValue(context.Background(), chi.RouteCtxKey, rctx), http.MethodGet, "/", nil) + assert.NoError(t, err) + + mockDb.On("GetAllTribes", mock.Anything).Return(expectedTribes) + handler.ServeHTTP(rr, req) + var returnedTribes []db.Tribe + err = json.Unmarshal(rr.Body.Bytes(), &returnedTribes) + assert.NoError(t, err) + assert.Equal(t, http.StatusOK, rr.Code) + assert.EqualValues(t, expectedTribes, returnedTribes) + mockDb.AssertExpectations(t) + + }) +} + +func TestGetTotalTribes(t *testing.T) { + mockDb := mocks.NewDatabase(t) + tHandler := NewTribeHandler(mockDb) + t.Run("should return the total number of tribes", func(t *testing.T) { + rr := httptest.NewRecorder() + handler := http.HandlerFunc(tHandler.GetTotalribes) + + expectedTribes := []db.Tribe{ + {UUID: "uuid", Name: "Tribe1"}, + {UUID: "uuid", Name: "Tribe2"}, + {UUID: "uuid", Name: "Tribe3"}, + } + + expectedTribesCount := int64(len(expectedTribes)) + + rctx := chi.NewRouteContext() + req, err := http.NewRequestWithContext(context.WithValue(context.Background(), chi.RouteCtxKey, rctx), http.MethodGet, "/total", nil) + assert.NoError(t, err) + + mockDb.On("GetTribesTotal", mock.Anything).Return(expectedTribesCount) + + handler.ServeHTTP(rr, req) + var returnedTribesCount int64 + err = json.Unmarshal(rr.Body.Bytes(), &returnedTribesCount) + assert.NoError(t, err) + assert.Equal(t, http.StatusOK, rr.Code) + assert.EqualValues(t, expectedTribesCount, returnedTribesCount) + mockDb.AssertExpectations(t) + + }) +} + +func TestGetListedTribes(t *testing.T) { + mockDb := mocks.NewDatabase(t) + tHandler := NewTribeHandler(mockDb) + + t.Run("should only return tribes associated with a passed tag query", func(t *testing.T) { + rr := httptest.NewRecorder() + handler := http.HandlerFunc(tHandler.GetListedTribes) + expectedTribes := []db.Tribe{ + {UUID: "1", Name: "Tribe 1", Tags: pq.StringArray{"tag1", "tag2", "tag3"}}, + {UUID: "2", Name: "Tribe 2", Tags: pq.StringArray{"tag4", "tag5"}}, + {UUID: "3", Name: "Tribe 3", Tags: pq.StringArray{"tag6", "tag7", "tag8"}}, + } + req, err := http.NewRequest("GET", "/tribes", nil) + if err != nil { + t.Fatal(err) + } + query := req.URL.Query() + tagVals := pq.StringArray{"tag1", "tag4", "tag7"} + tags := strings.Join(tagVals, ",") + query.Set("tags", tags) + req.URL.RawQuery = query.Encode() + if err != nil { + t.Fatal(err) + } + + mockDb.On("GetListedTribes", req).Return(expectedTribes) + handler.ServeHTTP(rr, req) + var returnedTribes []db.Tribe + err = json.Unmarshal(rr.Body.Bytes(), &returnedTribes) + assert.NoError(t, err) + assert.Equal(t, http.StatusOK, rr.Code) + assert.EqualValues(t, expectedTribes, returnedTribes) + + }) + + t.Run("should return all tribes when no tag queries are passed", func(t *testing.T) { + rr := httptest.NewRecorder() + handler := http.HandlerFunc(tHandler.GetListedTribes) + expectedTribes := []db.Tribe{ + {UUID: "1", Name: "Tribe 1", Tags: pq.StringArray{"tag1", "tag2", "tag3"}}, + {UUID: "2", Name: "Tribe 2", Tags: pq.StringArray{"tag4", "tag5"}}, + {UUID: "3", Name: "Tribe 3", Tags: pq.StringArray{"tag6", "tag7", "tag8"}}, + } + + req, err := http.NewRequest("GET", "/tribes", nil) + if err != nil { + t.Fatal(err) + } + query := req.URL.Query() + tagVals := pq.StringArray{"tag1", "tag4", "tag7"} + tags := strings.Join(tagVals, ",") + query.Set("tags", tags) + req.URL.RawQuery = query.Encode() + if err != nil { + t.Fatal(err) + } + + mockDb.On("GetListedTribes", req).Return(expectedTribes) + handler.ServeHTTP(rr, req) + + var returnedTribes []db.Tribe + err = json.Unmarshal(rr.Body.Bytes(), &returnedTribes) + assert.NoError(t, err) + assert.Equal(t, http.StatusOK, rr.Code) + assert.EqualValues(t, expectedTribes, returnedTribes) + + }) + +} diff --git a/routes/tribes.go b/routes/tribes.go index 7e8e68439..bec12b75f 100644 --- a/routes/tribes.go +++ b/routes/tribes.go @@ -10,11 +10,11 @@ func TribeRoutes() chi.Router { r := chi.NewRouter() tribeHandlers := handlers.NewTribeHandler(db.DB) r.Group(func(r chi.Router) { - r.Get("/", handlers.GetListedTribes) + r.Get("/", tribeHandlers.GetListedTribes) r.Get("/app_url/{app_url}", tribeHandlers.GetTribesByAppUrl) r.Get("/app_urls/{app_urls}", handlers.GetTribesByAppUrls) r.Get("/{uuid}", tribeHandlers.GetTribe) - r.Get("/total", handlers.GetTotalribes) + r.Get("/total", tribeHandlers.GetTotalribes) r.Post("/", tribeHandlers.CreateOrEditTribe) }) return r