Skip to content

Commit

Permalink
Merge pull request #1583 from stakwork/feat/array_connection_codes
Browse files Browse the repository at this point in the history
Changed connection codes to accept array
  • Loading branch information
elraphty authored Mar 7, 2024
2 parents 24d1e1f + 3ff85d6 commit 8635b75
Show file tree
Hide file tree
Showing 9 changed files with 85 additions and 38 deletions.
21 changes: 21 additions & 0 deletions auth/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,27 @@ func PubKeyContextSuperAdmin(next http.Handler) http.Handler {
})
}

// ConnectionContext parses token for connection code
func ConnectionCodeContext(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
token := r.Header.Get("token")

if token == "" {
fmt.Println("[auth] no token")
http.Error(w, http.StatusText(401), 401)
return
}

if token != config.Connection_Auth {
fmt.Println("Not a super admin : auth")
http.Error(w, http.StatusText(401), 401)
return
}
ctx := context.WithValue(r.Context(), ContextKey, token)
next.ServeHTTP(w, r.WithContext(ctx))
})
}

func AdminCheck(pubkey string) bool {
for _, val := range config.SuperAdmins {
if val == pubkey {
Expand Down
2 changes: 2 additions & 0 deletions config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ var S3FolderName string
var S3Url string
var AdminCheck string
var AdminDevFreePass = "FREE_PASS"
var Connection_Auth string

var S3Client *s3.Client
var PresignClient *s3.PresignClient
Expand All @@ -51,6 +52,7 @@ func InitConfig() {
S3FolderName = os.Getenv("S3_FOLDER_NAME")
S3Url = os.Getenv("S3_URL")
AdminCheck = os.Getenv("ADMIN_CHECK")
Connection_Auth = os.Getenv("CONNECTION_AUTH")

// Add to super admins
SuperAdmins = StripSuperAdmins(AdminStrings)
Expand Down
10 changes: 6 additions & 4 deletions db/db.go
Original file line number Diff line number Diff line change
Expand Up @@ -1495,10 +1495,12 @@ func (db database) GetPeopleListShort(count uint32) *[]PersonInShort {
return &p
}

func (db database) CreateConnectionCode(c ConnectionCodes) (ConnectionCodes, error) {
if c.DateCreated == nil {
now := time.Now()
c.DateCreated = &now
func (db database) CreateConnectionCode(c []ConnectionCodes) ([]ConnectionCodes, error) {
now := time.Now()
for _, code := range c {
if code.DateCreated.IsZero() {
code.DateCreated = &now
}
}
db.db.Create(&c)
return c, nil
Expand Down
2 changes: 1 addition & 1 deletion db/interface.go
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ type Database interface {
CountBounties() uint64
GetPeopleListShort(count uint32) *[]PersonInShort
GetConnectionCode() ConnectionCodesShort
CreateConnectionCode(c ConnectionCodes) (ConnectionCodes, error)
CreateConnectionCode(c []ConnectionCodes) ([]ConnectionCodes, error)
GetLnUser(lnKey string) int64
CreateLnUser(lnKey string) (Person, error)
GetBountiesLeaderboard() []LeaderData
Expand Down
20 changes: 13 additions & 7 deletions handlers/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ import (
"fmt"
"io"
"net/http"
"time"

"github.com/form3tech-oss/jwt-go"
"github.com/stakwork/sphinx-tribes/auth"
Expand Down Expand Up @@ -54,30 +53,37 @@ func (ah *authHandler) GetIsAdmin(w http.ResponseWriter, r *http.Request) {
}

func (ah *authHandler) CreateConnectionCode(w http.ResponseWriter, r *http.Request) {
code := db.ConnectionCodes{}
now := time.Now()
codeArr := []db.ConnectionCodes{}
codeStrArr := []string{}

body, err := io.ReadAll(r.Body)
r.Body.Close()

err = json.Unmarshal(body, &code)
err = json.Unmarshal(body, &codeStrArr)

code.IsUsed = false
code.DateCreated = &now
for _, code := range codeStrArr {
code := db.ConnectionCodes{
ConnectionString: code,
IsUsed: false,
}
codeArr = append(codeArr, code)
}

if err != nil {
fmt.Println(err)
w.WriteHeader(http.StatusNotAcceptable)
return
}

_, err = ah.db.CreateConnectionCode(code)
_, err = ah.db.CreateConnectionCode(codeArr)

if err != nil {
fmt.Println("=> ERR create connection code", err)
w.WriteHeader(http.StatusBadRequest)
return
}
w.WriteHeader(http.StatusOK)
json.NewEncoder(w).Encode("Codes created successfully")
}

func (ah *authHandler) GetConnectionCode(w http.ResponseWriter, _ *http.Request) {
Expand Down
34 changes: 22 additions & 12 deletions handlers/auth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@ import (
"github.com/stakwork/sphinx-tribes/db"
mocks "github.com/stakwork/sphinx-tribes/mocks"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
)

func TestGetAdminPubkeys(t *testing.T) {
Expand Down Expand Up @@ -52,16 +51,21 @@ func TestGetAdminPubkeys(t *testing.T) {
}

func TestCreateConnectionCode(t *testing.T) {

mockDb := mocks.NewDatabase(t)
aHandler := NewAuthHandler(mockDb)
t.Run("should create connection code successful", func(t *testing.T) {
codeToBeInserted := db.ConnectionCodes{
ConnectionString: "custom connection string",
codeToBeInserted := []string{"custom connection string", "custom connection string 2"}

codeArr := []db.ConnectionCodes{}
for _, code := range codeToBeInserted {
code := db.ConnectionCodes{
ConnectionString: code,
IsUsed: false,
}
codeArr = append(codeArr, code)
}
mockDb.On("CreateConnectionCode", mock.MatchedBy(func(code db.ConnectionCodes) bool {
return code.IsUsed == false && code.ConnectionString == codeToBeInserted.ConnectionString
})).Return(codeToBeInserted, nil).Once()

mockDb.On("CreateConnectionCode", codeArr).Return(codeArr, nil).Once()

body, _ := json.Marshal(codeToBeInserted)
req, err := http.NewRequest("POST", "/connectioncodes", bytes.NewBuffer(body))
Expand All @@ -77,12 +81,18 @@ func TestCreateConnectionCode(t *testing.T) {
})

t.Run("should return error if failed to add connection code", func(t *testing.T) {
codeToBeInserted := db.ConnectionCodes{
ConnectionString: "custom connection string",
codeToBeInserted := []string{"custom connection string", "custom connection string 2"}

codeArr := []db.ConnectionCodes{}
for _, code := range codeToBeInserted {
code := db.ConnectionCodes{
ConnectionString: code,
IsUsed: false,
}
codeArr = append(codeArr, code)
}
mockDb.On("CreateConnectionCode", mock.MatchedBy(func(code db.ConnectionCodes) bool {
return code.IsUsed == false && code.ConnectionString == codeToBeInserted.ConnectionString
})).Return(codeToBeInserted, errors.New("failed to create connection")).Once()

mockDb.On("CreateConnectionCode", codeArr).Return(codeArr, errors.New("failed to create connection")).Once()

body, _ := json.Marshal(codeToBeInserted)
req, err := http.NewRequest("POST", "/connectioncodes", bytes.NewBuffer(body))
Expand Down
24 changes: 13 additions & 11 deletions mocks/Database.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

7 changes: 6 additions & 1 deletion routes/connection_codes.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package routes

import (
"github.com/go-chi/chi"
"github.com/stakwork/sphinx-tribes/auth"
"github.com/stakwork/sphinx-tribes/db"
"github.com/stakwork/sphinx-tribes/handlers"
)
Expand All @@ -10,8 +11,12 @@ func ConnectionCodesRoutes() chi.Router {
r := chi.NewRouter()
authHandler := handlers.NewAuthHandler(db.DB)
r.Group(func(r chi.Router) {
r.Post("/", authHandler.CreateConnectionCode)
r.Get("/", authHandler.GetConnectionCode)
})

r.Group(func(r chi.Router) {
r.Use(auth.ConnectionCodeContext)
r.Post("/", authHandler.CreateConnectionCode)
})
return r
}
3 changes: 1 addition & 2 deletions utils/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,9 @@ import (
)

func GetPaginationParams(r *http.Request) (int, int, string, string, string) {

// there are cases when the request is not passed in
if r == nil {
return 0, -1, "updated", "asc", ""
return 0, 1, "updated", "asc", ""
}

keys := r.URL.Query()
Expand Down

0 comments on commit 8635b75

Please sign in to comment.