diff --git a/db/db.go b/db/db.go index 1eb8689ef..d3d2a3edc 100644 --- a/db/db.go +++ b/db/db.go @@ -1659,9 +1659,11 @@ func (db database) GetConnectionCode() ConnectionCodesShort { db.db.Raw(`SELECT connection_string, date_created FROM connectioncodes WHERE is_used =? ORDER BY id DESC LIMIT 1`, false).Find(&c) - db.db.Model(&ConnectionCodes{}).Where("connection_string = ?", c.ConnectionString).Updates(map[string]interface{}{ - "is_used": true, - }) + if c.ConnectionString != "" { + db.db.Model(&ConnectionCodes{}).Where("connection_string = ?", c.ConnectionString).Updates(map[string]interface{}{ + "is_used": true, + }) + } return c } diff --git a/db/structs.go b/db/structs.go index bad3aa492..3b73425c2 100644 --- a/db/structs.go +++ b/db/structs.go @@ -226,6 +226,9 @@ type ConnectionCodes struct { ConnectionString string `json:"connection_string"` IsUsed bool `json:"is_used"` DateCreated *time.Time `json:"date_created"` + Pubkey string `json:"pubkey"` + RouteHint string `json:"route_hint"` + SatsAmount int64 `json:"sats_amount"` } type ConnectionCodesShort struct { diff --git a/db/structsv2.go b/db/structsv2.go index 2e68ca46f..6d29aa466 100644 --- a/db/structsv2.go +++ b/db/structsv2.go @@ -80,5 +80,8 @@ type InviteReponse struct { } type InviteBody struct { - Number uint `json:"number"` + Number uint `json:"number"` + Pubkey string `json:"pubkey"` + RouteHint string `json:"route_hint"` + SatsAmount uint64 `json:"sats_amount"` } diff --git a/handlers/auth.go b/handlers/auth.go index 867143e9f..ce0d35ba2 100644 --- a/handlers/auth.go +++ b/handlers/auth.go @@ -17,7 +17,7 @@ import ( type authHandler struct { db db.Database - makeConnectionCodeRequest func() string + makeConnectionCodeRequest func(amt_msat uint64, alias string, pubkey string, route_hint string) string decodeJwt func(token string) (jwt.MapClaims, error) encodeJwt func(pubkey string) (string, error) } @@ -64,46 +64,70 @@ func (ah *authHandler) CreateConnectionCode(w http.ResponseWriter, r *http.Reque body, err := io.ReadAll(r.Body) if err != nil { fmt.Println("ReadAll Error", err) + w.WriteHeader(http.StatusBadRequest) + return } r.Body.Close() err = json.Unmarshal(body, &codeBody) - if err != nil { - fmt.Println("Could not umarshal connection code body") + fmt.Println("Could not unmarshal connection code body") w.WriteHeader(http.StatusNotAcceptable) return } + if codeBody.RouteHint != "" && codeBody.Pubkey == "" { + w.WriteHeader(http.StatusNotAcceptable) + json.NewEncoder(w).Encode("pubkey is required when Route hint is provided") + return + } + + if codeBody.Pubkey != "" && codeBody.RouteHint == "" { + w.WriteHeader(http.StatusBadRequest) + json.NewEncoder(w).Encode("Route hint is required when pubkey is provided") + return + } + + if codeBody.SatsAmount == 0 { + w.WriteHeader(http.StatusBadRequest) + json.NewEncoder(w).Encode("Sats amount must be greater than 0") + return + } + for i := 0; i < int(codeBody.Number); i++ { - code := ah.makeConnectionCodeRequest() + + amtMsat := codeBody.SatsAmount * 1000 + code := ah.makeConnectionCodeRequest(amtMsat, "new_user", codeBody.Pubkey, codeBody.RouteHint) if code != "" { newCode := db.ConnectionCodes{ ConnectionString: code, IsUsed: false, + Pubkey: codeBody.Pubkey, + RouteHint: codeBody.RouteHint, + SatsAmount: int64(codeBody.SatsAmount), } codeArr = append(codeArr, newCode) } } _, err = ah.db.CreateConnectionCode(codeArr) - if err != nil { fmt.Println("[auth] => ERR create connection code", err) w.WriteHeader(http.StatusBadRequest) return } + w.WriteHeader(http.StatusOK) json.NewEncoder(w).Encode("Codes created successfully") } -func MakeConnectionCodeRequest() string { +func MakeConnectionCodeRequest(amt_msat uint64, alias string, pubkey string, route_hint string) string { url := fmt.Sprintf("%s/invite", config.V2BotUrl) client := http.Client{} // Build v2 keysend payment data - bodyData := utils.BuildV2ConnectionCodes(100, "new_user") + bodyData := utils.BuildV2ConnectionCodes(amt_msat, alias, pubkey, route_hint) jsonBody := []byte(bodyData) req, _ := http.NewRequest(http.MethodPost, url, bytes.NewBuffer(jsonBody)) @@ -111,7 +135,6 @@ func MakeConnectionCodeRequest() string { req.Header.Set("Content-Type", "application/json") res, err := client.Do(req) - if err != nil { log.Printf("[Invite] Request Failed: %s", err) return "" @@ -120,20 +143,19 @@ func MakeConnectionCodeRequest() string { defer res.Body.Close() body, err := io.ReadAll(res.Body) - if err != nil { log.Printf("Could not read invite body: %s", err) + return "" } - inviteReponse := db.InviteReponse{} - err = json.Unmarshal(body, &inviteReponse) - + inviteResponse := db.InviteReponse{} + err = json.Unmarshal(body, &inviteResponse) if err != nil { fmt.Println("Could not get connection code") return "" } - return inviteReponse.Invite + return inviteResponse.Invite } func (ah *authHandler) GetConnectionCode(w http.ResponseWriter, _ *http.Request) { diff --git a/handlers/auth_test.go b/handlers/auth_test.go index 46b73bd05..98b388539 100644 --- a/handlers/auth_test.go +++ b/handlers/auth_test.go @@ -62,10 +62,17 @@ func TestCreateConnectionCode(t *testing.T) { rr := httptest.NewRecorder() handler := http.HandlerFunc(aHandler.CreateConnectionCode) data := db.InviteBody{ - Number: 2, + Number: 2, + SatsAmount: 100, + Pubkey: "test_pubkey", + RouteHint: "test_route_hint", } - aHandler.makeConnectionCodeRequest = func() string { + aHandler.makeConnectionCodeRequest = func(amt_msat uint64, alias string, pubkey string, route_hint string) string { + assert.Equal(t, uint64(100000), amt_msat) + assert.Equal(t, "new_user", alias) + assert.Equal(t, "test_pubkey", pubkey) + assert.Equal(t, "test_route_hint", route_hint) return "22222222222222222" } @@ -83,9 +90,38 @@ func TestCreateConnectionCode(t *testing.T) { assert.NotEmpty(t, codes) }) - t.Run("should return error if failed to add connection code", func(t *testing.T) { + t.Run("should return error if sats amount is zero", func(t *testing.T) { data := db.InviteBody{ - Number: 0, + Number: 2, + SatsAmount: 0, + Pubkey: "test_pubkey", + RouteHint: "test_route_hint", + } + + body, _ := json.Marshal(data) + + req, err := http.NewRequest(http.MethodPost, "/connectioncodes", bytes.NewBuffer(body)) + if err != nil { + t.Fatal(err) + } + rr := httptest.NewRecorder() + handler := http.HandlerFunc(aHandler.CreateConnectionCode) + + handler.ServeHTTP(rr, req) + assert.Equal(t, http.StatusBadRequest, rr.Code) + + var response string + err = json.NewDecoder(rr.Body).Decode(&response) + assert.NoError(t, err) + assert.Equal(t, "Sats amount must be greater than 0", response) + }) + + t.Run("should return error if pubkey provided without route hint", func(t *testing.T) { + data := db.InviteBody{ + Number: 2, + SatsAmount: 100, + Pubkey: "test_pubkey", + RouteHint: "", } body, _ := json.Marshal(data) @@ -99,6 +135,11 @@ func TestCreateConnectionCode(t *testing.T) { handler.ServeHTTP(rr, req) assert.Equal(t, http.StatusBadRequest, rr.Code) + + var response string + err = json.NewDecoder(rr.Body).Decode(&response) + assert.NoError(t, err) + assert.Equal(t, "Route hint is required when pubkey is provided", response) }) t.Run("should return error for malformed request body", func(t *testing.T) { @@ -136,12 +177,10 @@ func TestGetConnectionCode(t *testing.T) { aHandler := NewAuthHandler(db.TestDB) t.Run("should return connection code from db", func(t *testing.T) { - rr := httptest.NewRecorder() handler := http.HandlerFunc(aHandler.GetConnectionCode) codeStrArr := []string{"sampleCode1"} - codeArr := []db.ConnectionCodes{} now := time.Now() @@ -151,8 +190,10 @@ func TestGetConnectionCode(t *testing.T) { ConnectionString: code, IsUsed: false, DateCreated: &now, + Pubkey: "test_pubkey", + RouteHint: "test_route_hint", + SatsAmount: 100, } - codeArr = append(codeArr, code) } @@ -181,9 +222,7 @@ func TestGetConnectionCode(t *testing.T) { timeDifference = -timeDifference } assert.True(t, timeDifference <= tolerance, "Expected DateCreated to be within tolerance") - }) - } func TestGetIsAdmin(t *testing.T) { diff --git a/tribes.sql b/tribes.sql index ffe65a710..084d7bccd 100644 --- a/tribes.sql +++ b/tribes.sql @@ -137,9 +137,12 @@ VALUES ALTER TABLE IF EXISTS tribes ADD COLUMN IF NOT EXISTS preview VARCHAR NULL; -CREATE TABLE connectioncodes { +CREATE TABLE connectioncodes ( id SERIAL PRIMARY KEY, connection_string TEXT, is_used boolean, - date_created timestamptz -} \ No newline at end of file + date_created timestamptz, + pubkey TEXT, + route_hint TEXT, + sats_amount bigint +) \ No newline at end of file diff --git a/utils/utils.go b/utils/utils.go index 820d3dc0f..c093ab9b2 100644 --- a/utils/utils.go +++ b/utils/utils.go @@ -91,7 +91,12 @@ func BuildV2KeysendBodyData(amount uint, receiver_pubkey string, route_hint stri return bodyData } -func BuildV2ConnectionCodes(amt_msat uint, alias string) string { - bodyData := fmt.Sprintf(`{"amt_msat": %d, "alias": "%s"}`, amt_msat, alias) +func BuildV2ConnectionCodes(amt_msat uint64, alias string, pubkey string, route_hint string) string { + bodyData := fmt.Sprintf(`{ + "amt_msat": %d, + "alias": "%s", + "pubkey": "%s", + "route_hint": "%s" + }`, amt_msat, alias, pubkey, route_hint) return bodyData }