diff --git a/handlers/auth.go b/handlers/auth.go index 21bb3af4d..0070d2001 100644 --- a/handlers/auth.go +++ b/handlers/auth.go @@ -64,26 +64,28 @@ func (ah *authHandler) CreateConnectionCode(w http.ResponseWriter, r *http.Reque body, err := io.ReadAll(r.Body) if err != nil { fmt.Println("ReadAll Error", err) + w.WriteHeader(http.StatusBadRequest) + return } r.Body.Close() err = json.Unmarshal(body, &codeBody) if err != nil { - fmt.Println("Could not umarshal connection code body") + fmt.Println("Could not unmarshal connection code body") w.WriteHeader(http.StatusNotAcceptable) return } if codeBody.Pubkey != "" && codeBody.RouteHint == "" { - fmt.Println("route hint missing") - w.WriteHeader(http.StatusNotAcceptable) + w.WriteHeader(http.StatusBadRequest) + json.NewEncoder(w).Encode("Route hint is required when pubkey is provided") return } if codeBody.RouteHint != "" && codeBody.Pubkey == "" { - fmt.Println("pubkey missing missing") w.WriteHeader(http.StatusNotAcceptable) + json.NewEncoder(w).Encode("pubkey is required when Route hint is provided") return } diff --git a/handlers/auth_test.go b/handlers/auth_test.go index f4a720302..c917135f5 100644 --- a/handlers/auth_test.go +++ b/handlers/auth_test.go @@ -93,7 +93,7 @@ func TestCreateConnectionCode(t *testing.T) { body, _ := json.Marshal(data) - req, err := http.NewRequest("POST", "/connectioncodes", bytes.NewBuffer(body)) + req, err := http.NewRequest(http.MethodPost, "/connectioncodes", bytes.NewBuffer(body)) if err != nil { t.Fatal(err) } @@ -106,7 +106,7 @@ func TestCreateConnectionCode(t *testing.T) { t.Run("should return error for malformed request body", func(t *testing.T) { body := []byte(`{"number": "0"}`) - req, err := http.NewRequest("POST", "/connectioncodes", bytes.NewBuffer(body)) + req, err := http.NewRequest(http.MethodPost, "/connectioncodes", bytes.NewBuffer(body)) if err != nil { t.Fatal(err) } @@ -120,7 +120,47 @@ func TestCreateConnectionCode(t *testing.T) { t.Run("should return error for invalid json", func(t *testing.T) { body := []byte(`{"nonumber":0`) - req, err := http.NewRequest("POST", "/connectioncodes", bytes.NewBuffer(body)) + req, err := http.NewRequest(http.MethodPost, "/connectioncodes", bytes.NewBuffer(body)) + if err != nil { + t.Fatal(err) + } + rr := httptest.NewRecorder() + handler := http.HandlerFunc(aHandler.CreateConnectionCode) + + handler.ServeHTTP(rr, req) + assert.Equal(t, http.StatusNotAcceptable, rr.Code) + }) + + t.Run("should return error if pubkey is provided without route hint", func(t *testing.T) { + data := db.InviteBody{ + Number: 1, + Pubkey: "Test_pubkey", + RouteHint: "", + } + + body, _ := json.Marshal(data) + + req, err := http.NewRequest(http.MethodPost, "/connectioncodes", bytes.NewBuffer(body)) + if err != nil { + t.Fatal(err) + } + rr := httptest.NewRecorder() + handler := http.HandlerFunc(aHandler.CreateConnectionCode) + + handler.ServeHTTP(rr, req) + assert.Equal(t, http.StatusBadRequest, rr.Code) + }) + + t.Run("should return error if route hint is provided without pubkey", func(t *testing.T) { + data := db.InviteBody{ + Number: 1, + Pubkey: "", + RouteHint: "Test_Route_hint", + } + + body, _ := json.Marshal(data) + + req, err := http.NewRequest(http.MethodPost, "/connectioncodes", bytes.NewBuffer(body)) if err != nil { t.Fatal(err) } @@ -167,7 +207,7 @@ func TestGetConnectionCode(t *testing.T) { db.TestDB.CreateConnectionCode(codeArr) - req, err := http.NewRequest("GET", "/connectioncodes", nil) + req, err := http.NewRequest(http.MethodGet, "/connectioncodes", nil) if err != nil { t.Fatal(err) } @@ -186,7 +226,28 @@ func TestGetConnectionCode(t *testing.T) { assert.True(t, timeDifference <= tolerance, "Expected DateCreated to be within tolerance") }) + t.Run("should return empty fields if no connection codes exist", func(t *testing.T) { + rr := httptest.NewRecorder() + handler := http.HandlerFunc(aHandler.GetConnectionCode) + req, err := http.NewRequest(http.MethodGet, "/connectioncodes", nil) + if err != nil { + t.Fatal(err) + } + + handler.ServeHTTP(rr, req) + + assert.Equal(t, http.StatusOK, rr.Code) + + var response db.ConnectionCodesShort + err = json.Unmarshal(rr.Body.Bytes(), &response) + if err != nil { + t.Fatal("Failed to unmarshal response:", err) + } + + assert.Empty(t, response.ConnectionString) + assert.Nil(t, response.DateCreated) + }) } func TestGetIsAdmin(t *testing.T) {