Skip to content

Commit

Permalink
Implemented UTs
Browse files Browse the repository at this point in the history
  • Loading branch information
FaisalIqbal211 committed Feb 22, 2024
1 parent 54731b8 commit dd215b5
Show file tree
Hide file tree
Showing 3 changed files with 137 additions and 29 deletions.
23 changes: 15 additions & 8 deletions handlers/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,17 +7,24 @@ import (
"net/http"
"time"

"github.com/form3tech-oss/jwt-go"
"github.com/stakwork/sphinx-tribes/auth"
"github.com/stakwork/sphinx-tribes/config"
"github.com/stakwork/sphinx-tribes/db"
)

type authHandler struct {
db db.Database
db db.Database
decodeJwt func(token string) (jwt.MapClaims, error)
encodeJwt func(pubkey string) (string, error)
}

func NewAuthHandler(db db.Database) *authHandler {
return &authHandler{db: db}
return &authHandler{
db: db,
decodeJwt: auth.DecodeJwt,
encodeJwt: auth.EncodeJwt,
}
}

func GetAdminPubkeys(w http.ResponseWriter, r *http.Request) {
Expand All @@ -31,7 +38,7 @@ func GetAdminPubkeys(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}

func GetIsAdmin(w http.ResponseWriter, r *http.Request) {
func (ah *authHandler) GetIsAdmin(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
pubKeyFromAuth, _ := ctx.Value(auth.ContextKey).(string)
isAdmin := auth.AdminCheck(pubKeyFromAuth)
Expand Down Expand Up @@ -165,11 +172,11 @@ func ReceiveLnAuthData(w http.ResponseWriter, r *http.Request) {
json.NewEncoder(w).Encode(responseMsg)
}

func RefreshToken(w http.ResponseWriter, r *http.Request) {
func (ah *authHandler) RefreshToken(w http.ResponseWriter, r *http.Request) {
token := r.Header.Get("x-jwt")

responseData := make(map[string]interface{})
claims, err := auth.DecodeJwt(token)
claims, err := ah.decodeJwt(token)

if err != nil {
fmt.Println("Failed to parse JWT")
Expand All @@ -180,11 +187,11 @@ func RefreshToken(w http.ResponseWriter, r *http.Request) {

pubkey := fmt.Sprint(claims["pubkey"])

userCount := db.DB.GetLnUser(pubkey)
userCount := ah.db.GetLnUser(pubkey)

if userCount > 0 {
// Generate a new token
tokenString, err := auth.EncodeJwt(pubkey)
tokenString, err := ah.encodeJwt(pubkey)

if err != nil {
fmt.Println("error creating refresh JWT")
Expand All @@ -193,7 +200,7 @@ func RefreshToken(w http.ResponseWriter, r *http.Request) {
return
}

person := db.DB.GetPersonByPubkey(pubkey)
person := ah.db.GetPersonByPubkey(pubkey)
user := returnUserMap(person)

responseData["k1"] = ""
Expand Down
138 changes: 119 additions & 19 deletions handlers/auth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package handlers

import (
"bytes"
"context"
"encoding/json"
"errors"
"net/http"
Expand All @@ -11,6 +12,8 @@ import (
"testing"
"time"

"github.com/form3tech-oss/jwt-go"
"github.com/stakwork/sphinx-tribes/auth"
"github.com/stakwork/sphinx-tribes/config"
"github.com/stakwork/sphinx-tribes/db"
mocks "github.com/stakwork/sphinx-tribes/mocks"
Expand All @@ -19,29 +22,33 @@ import (
)

func TestGetAdminPubkeys(t *testing.T) {
// set the admins and init the config to update superadmins
os.Setenv("ADMINS", "test")
config.InitConfig()

req, err := http.NewRequest("GET", "/admin_pubkeys", nil)
if err != nil {
t.Fatal(err)
}
rr := httptest.NewRecorder()
handler := http.HandlerFunc(GetAdminPubkeys)
t.Run("Should test that all admin pubkeys is returned", func(t *testing.T) {
// set the admins and init the config to update superadmins
os.Setenv("ADMINS", "test")
os.Setenv("RELAY_URL", "RelayUrl")
os.Setenv("RELAY_AUTH_KEY", "RelayAuthKey")
config.InitConfig()

req, err := http.NewRequest("GET", "/admin_pubkeys", nil)
if err != nil {
t.Fatal(err)
}
rr := httptest.NewRecorder()
handler := http.HandlerFunc(GetAdminPubkeys)

handler.ServeHTTP(rr, req)
handler.ServeHTTP(rr, req)

if status := rr.Code; status != http.StatusOK {
t.Errorf("handler returned wrong status code: got %v want %v",
status, http.StatusOK)
}
if status := rr.Code; status != http.StatusOK {
t.Errorf("handler returned wrong status code: got %v want %v",
status, http.StatusOK)
}

expected := `{"pubkeys":["test"]}`
if strings.TrimRight(rr.Body.String(), "\n") != expected {
expected := `{"pubkeys":["test"]}`
if strings.TrimRight(rr.Body.String(), "\n") != expected {

t.Errorf("handler returned unexpected body: expected %s pubkeys %s is there a space after?", expected, rr.Body.String())
}
t.Errorf("handler returned unexpected body: expected %s pubkeys %s is there a space after?", expected, rr.Body.String())
}
})
}

func TestCreateConnectionCode(t *testing.T) {
Expand Down Expand Up @@ -145,3 +152,96 @@ func TestGetConnectionCode(t *testing.T) {
})

}

func TestGetIsAdmin(t *testing.T) {
mockDb := mocks.NewDatabase(t)
aHandler := NewAuthHandler(mockDb)

t.Run("Should test that GetIsAdmin returns a 401 error if the user is not an admin", func(t *testing.T) {
req, err := http.NewRequest("GET", "/admin/auth", nil)
if err != nil {
t.Fatal(err)
}
rr := httptest.NewRecorder()
handler := http.HandlerFunc(aHandler.GetIsAdmin)

pubKey := "non_admin_pubkey"
ctx := context.WithValue(req.Context(), auth.ContextKey, pubKey)
req = req.WithContext(ctx)

handler.ServeHTTP(rr, req)

assert.Equal(t, http.StatusUnauthorized, rr.Code)
})

t.Run("Should test that a 200 status code is returned if the user is an admin", func(t *testing.T) {
req, err := http.NewRequest("GET", "/admin/auth", nil)
if err != nil {
t.Fatal(err)
}
rr := httptest.NewRecorder()
handler := http.HandlerFunc(aHandler.GetIsAdmin)

adminPubKey := config.SuperAdmins[0]
ctx := context.WithValue(req.Context(), auth.ContextKey, adminPubKey)
req = req.WithContext(ctx)

handler.ServeHTTP(rr, req)

assert.Equal(t, http.StatusOK, rr.Code)
})
}

func TestRefreshToken(t *testing.T) {
mockDb := mocks.NewDatabase(t)
aHandler := NewAuthHandler(mockDb)

t.Run("Should test that a user token can be refreshed", func(t *testing.T) {
mockToken := "mock_token"
mockUserPubkey := "mock_pubkey"
mockPerson := db.Person{
ID: 1,
OwnerPubKey: mockUserPubkey,
}
mockDb.On("GetLnUser", mockUserPubkey).Return(int64(1)).Once()
mockDb.On("GetPersonByPubkey", mockUserPubkey).Return(mockPerson).Once()

// Mock JWT decoding
mockClaims := jwt.MapClaims{
"pubkey": mockUserPubkey,
}
mockDecodeJwt := func(token string) (jwt.MapClaims, error) {
return mockClaims, nil
}
aHandler.decodeJwt = mockDecodeJwt

// Mock JWT encoding
mockEncodedToken := "encoded_mock_token"
mockEncodeJwt := func(pubkey string) (string, error) {
return mockEncodedToken, nil
}
aHandler.encodeJwt = mockEncodeJwt

// Create request with mock token in header
req, err := http.NewRequest("GET", "/refresh_jwt", nil)
if err != nil {
t.Fatal(err)
}
req.Header.Set("x-jwt", mockToken)

// Serve request
rr := httptest.NewRecorder()
handler := http.HandlerFunc(aHandler.RefreshToken)
handler.ServeHTTP(rr, req)

// Verify response
assert.Equal(t, http.StatusOK, rr.Code)
var responseData map[string]interface{}
err = json.Unmarshal(rr.Body.Bytes(), &responseData)
if err != nil {
t.Fatalf("Error decoding JSON response: %s", err)
}
assert.Equal(t, true, responseData["status"])
assert.Equal(t, mockEncodedToken, responseData["jwt"])
})
}
5 changes: 3 additions & 2 deletions routes/index.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import (
func NewRouter() *http.Server {
r := initChi()
tribeHandlers := handlers.NewTribeHandler(db.DB)
authHandler := handlers.NewAuthHandler(db.DB)

r.Mount("/tribes", TribeRoutes())
r.Mount("/bots", BotsRoutes())
Expand Down Expand Up @@ -74,13 +75,13 @@ func NewRouter() *http.Server {
r.Delete("/ticket/{pubKey}/{created}", handlers.DeleteTicketByAdmin)
r.Get("/poll/invoice/{paymentRequest}", handlers.PollInvoice)
r.Post("/meme_upload", handlers.MemeImageUpload)
r.Get("/admin/auth", handlers.GetIsAdmin)
r.Get("/admin/auth", authHandler.GetIsAdmin)
})

r.Group(func(r chi.Router) {
r.Get("/lnauth_login", handlers.ReceiveLnAuthData)
r.Get("/lnauth", handlers.GetLnurlAuth)
r.Get("/refresh_jwt", handlers.RefreshToken)
r.Get("/refresh_jwt", authHandler.RefreshToken)
r.Post("/invoices", handlers.GenerateInvoice)
r.Post("/budgetinvoices", handlers.GenerateBudgetInvoice)
})
Expand Down

0 comments on commit dd215b5

Please sign in to comment.