diff --git a/db.go b/db.go index b8dcdcf..61655ef 100644 --- a/db.go +++ b/db.go @@ -1,6 +1,7 @@ package surrealdb import ( + "context" "encoding/json" "errors" "fmt" @@ -10,8 +11,9 @@ import ( const statusOK = "OK" var ( - InvalidResponse = errors.New("invalid SurrealDB response") - QueryError = errors.New("error occurred processing the SurrealDB query") + ErrInvalidResponse = errors.New("invalid SurrealDB response") + ErrQuery = errors.New("error occurred processing the SurrealDB query") + ErrInvalidLoginResponse = errors.New("invalid login response") ) // DB is a client for the SurrealDB database that holds are websocket connection. @@ -20,8 +22,8 @@ type DB struct { } // New Creates a new DB instance given a WebSocket URL. -func New(url string) (*DB, error) { - ws, err := NewWebsocket(url) +func New(ctx context.Context, url string) (*DB, error) { + ws, err := NewWebsocket(ctx, url) if err != nil { return nil, err } @@ -29,12 +31,15 @@ func New(url string) (*DB, error) { } // Unmarshal loads a SurrealDB response into a struct. -func Unmarshal(data interface{}, v interface{}) error { - var ok bool +func Unmarshal(data any, v any) error { + // TODO: make this function obsolete + // currently, we get JSON bytes from the connection, unmarshall them to a *any, marshall them back into + // JSON and then unmarshall them into the target struct + // This is cumbersome to use and expensive to run - assertedData, ok := data.([]interface{}) + assertedData, ok := data.([]any) if !ok { - return InvalidResponse + return ErrInvalidResponse } sliceFlag := isSlice(v) @@ -59,27 +64,27 @@ func Unmarshal(data interface{}, v interface{}) error { // UnmarshalRaw loads a raw SurrealQL response returned by Query into a struct. Queries that return with results will // return ok = true, and queries that return with no results will return ok = false. -func UnmarshalRaw(rawData interface{}, v interface{}) (ok bool, err error) { - var data []interface{} - if data, ok = rawData.([]interface{}); !ok { - return false, InvalidResponse +func UnmarshalRaw(rawData any, v any) (ok bool, err error) { + data, ok := rawData.([]any) + if !ok { + return false, ErrInvalidResponse } - var responseObj map[string]interface{} - if responseObj, ok = data[0].(map[string]interface{}); !ok { - return false, InvalidResponse + responseObj, ok := data[0].(map[string]any) + if !ok { + return false, ErrInvalidResponse } - var status string - if status, ok = responseObj["status"].(string); !ok { - return false, InvalidResponse + status, ok := responseObj["status"].(string) + if !ok { + return false, ErrInvalidResponse } if status != statusOK { - return false, QueryError + return false, ErrQuery } result := responseObj["result"] - if len(result.([]interface{})) == 0 { + if len(result.([]any)) == 0 { return false, nil } err = Unmarshal(result, v) @@ -95,86 +100,117 @@ func UnmarshalRaw(rawData interface{}, v interface{}) (ok bool, err error) { // -------------------------------------------------- // Close closes the underlying WebSocket connection. -func (self *DB) Close() { - _ = self.ws.Close() +func (db *DB) Close() error { + return db.ws.Close() } // -------------------------------------------------- // Use is a method to select the namespace and table to use. -func (self *DB) Use(ns string, db string) (interface{}, error) { - return self.send("use", ns, db) +func (db *DB) Use(ctx context.Context, ns string, dbname string) (any, error) { + return db.send(ctx, "use", ns, dbname) } -func (self *DB) Info() (interface{}, error) { - return self.send("info") +func (db *DB) Info(ctx context.Context) (any, error) { + return db.send(ctx, "info") } // Signup is a helper method for signing up a new user. -func (self *DB) Signup(vars interface{}) (interface{}, error) { - return self.send("signup", vars) +func (db *DB) Signup(ctx context.Context, vars any) (any, error) { + return db.send(ctx, "signup", vars) +} + +// SignupUser is a helper method for signing in a user and returning a typed response +func (db *DB) SignupUser(ctx context.Context, vars UserInfo) (*AuthenticationResult, error) { + authResult := &AuthenticationResult{Success: false} + result, err := db.send(ctx, "signup", vars) + if err != nil { + return authResult, err + } + + err = authResult.fromQuery(result) + + return authResult, err } // Signin is a helper method for signing in a user. -func (self *DB) Signin(vars interface{}) (interface{}, error) { - return self.send("signin", vars) +func (db *DB) Signin(ctx context.Context, vars UserInfo) (any, error) { + return db.send(ctx, "signin", vars) } -func (self *DB) Invalidate() (interface{}, error) { - return self.send("invalidate") +// SigninUser is a helper method for signing in a user and returning a typed response +// Note: This will probably fail when signing in as a root user, but for +// a regular user(via a scope for example) we get a JWT response +func (db *DB) SigninUser(ctx context.Context, vars UserInfo) (*AuthenticationResult, error) { + authResult := &AuthenticationResult{Success: false} + result, err := db.send(ctx, "signin", vars) + if err != nil { + return authResult, err + } + if err != nil { + return authResult, err + } + + err = authResult.fromQuery(result) + + return authResult, err +} + +func (db *DB) Invalidate(ctx context.Context) (any, error) { + return db.send(ctx, "invalidate") } -func (self *DB) Authenticate(token string) (interface{}, error) { - return self.send("authenticate", token) +func (db *DB) Authenticate(ctx context.Context, token string) (any, error) { + return db.send(ctx, "authenticate", token) } // -------------------------------------------------- -func (self *DB) Live(table string) (interface{}, error) { - return self.send("live", table) +func (db *DB) Live(ctx context.Context, table string) (any, error) { + return db.send(ctx, "live", table) } -func (self *DB) Kill(query string) (interface{}, error) { - return self.send("kill", query) +func (db *DB) Kill(ctx context.Context, query string) (any, error) { + return db.send(ctx, "kill", query) } -func (self *DB) Let(key string, val interface{}) (interface{}, error) { - return self.send("let", key, val) +func (db *DB) Let(ctx context.Context, key string, val any) (any, error) { + return db.send(ctx, "let", key, val) } // Query is a convenient method for sending a query to the database. -func (self *DB) Query(sql string, vars interface{}) (interface{}, error) { - return self.send("query", sql, vars) +func (db *DB) Query(ctx context.Context, sql string, vars any) (any, error) { + return db.send(ctx, "query", sql, vars) } // Select a table or record from the database. -func (self *DB) Select(what string) (interface{}, error) { - return self.send("select", what) +func (db *DB) Select(ctx context.Context, what string) (any, error) { + return db.send(ctx, "select", what) } // Creates a table or record in the database like a POST request. -func (self *DB) Create(thing string, data interface{}) (interface{}, error) { - return self.send("create", thing, data) +func (db *DB) Create(ctx context.Context, thing string, data any) (any, error) { + return db.send(ctx, "create", thing, data) } // Update a table or record in the database like a PUT request. -func (self *DB) Update(what string, data interface{}) (interface{}, error) { - return self.send("update", what, data) +func (db *DB) Update(ctx context.Context, what string, data any) (any, error) { + return db.send(ctx, "update", what, data) } // Change a table or record in the database like a PATCH request. -func (self *DB) Change(what string, data interface{}) (interface{}, error) { - return self.send("change", what, data) +func (db *DB) Change(ctx context.Context, what string, data any) (any, error) { + return db.send(ctx, "change", what, data) } // Modify applies a series of JSONPatches to a table or record. -func (self *DB) Modify(what string, data []Patch) (interface{}, error) { - return self.send("modify", what, data) +func (db *DB) Modify(ctx context.Context, what string, data []Patch) (any, error) { + return db.send(ctx, "modify", what, data) } // Delete a table or a row from the database like a DELETE request. -func (self *DB) Delete(what string) (interface{}, error) { - return self.send("delete", what) +func (db *DB) Delete(ctx context.Context, what string) (any, error) { + return db.send(ctx, "delete", what) } // -------------------------------------------------- @@ -182,44 +218,45 @@ func (self *DB) Delete(what string) (interface{}, error) { // -------------------------------------------------- // send is a helper method for sending a query to the database. -func (self *DB) send(method string, params ...interface{}) (interface{}, error) { +func (db *DB) send(ctx context.Context, method string, params ...any) (any, error) { // generate an id for the action, this is used to distinguish its response - id := xid(16) + id := xid() // chn: the channel where the server response will arrive, err: the channel where errors will come - chn, err := self.ws.Once(id, method) + chn := db.ws.Once(id, method) // here we send the args through our websocket connection - self.ws.Send(id, method, params) + db.ws.Send(id, method, params) + + select { + case <-ctx.Done(): + return nil, ctx.Err() - for { - select { + case r := <-chn: + if r.err != nil { + return nil, r.err + } + switch method { + case "delete": + return nil, nil + case "select": + return db.resp(method, params, r.value) + case "create": + return db.resp(method, params, r.value) + case "update": + return db.resp(method, params, r.value) + case "change": + return db.resp(method, params, r.value) + case "modify": + return db.resp(method, params, r.value) default: - case e := <-err: - return nil, e - case r := <-chn: - switch method { - case "delete": - return nil, nil - case "select": - return self.resp(method, params, r) - case "create": - return self.resp(method, params, r) - case "update": - return self.resp(method, params, r) - case "change": - return self.resp(method, params, r) - case "modify": - return self.resp(method, params, r) - default: - return r, nil - } + return r.value, nil } } } // resp is a helper method for parsing the response from a query. -func (self *DB) resp(_ string, params []interface{}, res interface{}) (interface{}, error) { +func (db *DB) resp(_ string, params []any, res any) (any, error) { arg, ok := params[0].(string) @@ -227,9 +264,10 @@ func (self *DB) resp(_ string, params []interface{}, res interface{}) (interface return res, nil } + // TODO: explian what that condition is for if strings.Contains(arg, ":") { - arr, ok := res.([]interface{}) + arr, ok := res.([]any) if !ok { return nil, PermissionError{what: arg} @@ -248,15 +286,10 @@ func (self *DB) resp(_ string, params []interface{}, res interface{}) (interface } func isSlice(possibleSlice interface{}) bool { - slice := false - - switch v := possibleSlice.(type) { - default: - res := fmt.Sprintf("%s", v) - if res == "[]" || res == "&[]" || res == "*[]" { - slice = true - } + res := fmt.Sprintf("%s", possibleSlice) + if res == "[]" || res == "&[]" || res == "*[]" { + return true } - return slice + return false } diff --git a/db_test.go b/db_test.go index 61f269a..ed9f00b 100644 --- a/db_test.go +++ b/db_test.go @@ -1,9 +1,13 @@ package surrealdb_test import ( + "context" "fmt" - "github.com/surrealdb/surrealdb.go" + "log" "testing" + + "github.com/surrealdb/surrealdb.go" + "github.com/test-go/testify/suite" ) // a simple user struct for testing @@ -15,7 +19,10 @@ type testUser struct { // an example test for creating a new entry in surrealdb func ExampleNew() { - db, err := surrealdb.New("ws://localhost:8000/rpc") + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + db, err := surrealdb.New(ctx, "ws://localhost:8000/rpc") if err != nil { panic(err) @@ -27,28 +34,31 @@ func ExampleNew() { } func ExampleDB_Delete() { - db, err := surrealdb.New("ws://localhost:8000/rpc") + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + db, err := surrealdb.New(ctx, "ws://localhost:8000/rpc") if err != nil { panic(err) } defer db.Close() - _, err = db.Signin(map[string]interface{}{ - "user": "root", - "pass": "root", + _, err = db.Signin(ctx, surrealdb.UserInfo{ + User: "root", + Password: "root", }) if err != nil { panic(err) } - _, err = db.Use("test", "test") + _, err = db.Use(ctx, "test", "test") if err != nil { panic(err) } - userData, err := db.Create("users", testUser{ + userData, err := db.Create(ctx, "users", testUser{ Username: "johnny", Password: "123", }) @@ -61,7 +71,7 @@ func ExampleDB_Delete() { } // Delete the users... - _, err = db.Delete("users") + _, err = db.Delete(ctx, "users") if err != nil { panic(err) @@ -71,7 +81,10 @@ func ExampleDB_Delete() { } func ExampleDB_Create() { - db, err := surrealdb.New("ws://localhost:8000/rpc") + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + db, err := surrealdb.New(ctx, "ws://localhost:8000/rpc") if err != nil { panic(err) @@ -79,22 +92,22 @@ func ExampleDB_Create() { defer db.Close() - signin, err := db.Signin(map[string]interface{}{ - "user": "root", - "pass": "root", + signin, err := db.Signin(ctx, surrealdb.UserInfo{ + User: "root", + Password: "root", }) if err != nil { panic(err) } - _, err = db.Use("test", "test") + _, err = db.Use(ctx, "test", "test") if err != nil || signin == nil { panic(err) } - userMap, err := db.Create("users", map[string]interface{}{ + userMap, err := db.Create(ctx, "users", map[string]any{ "username": "john", "password": "123", }) @@ -103,7 +116,7 @@ func ExampleDB_Create() { panic(err) } - userData, err := db.Create("users", testUser{ + userData, err := db.Create(ctx, "users", testUser{ Username: "johnny", Password: "123", }) @@ -120,37 +133,41 @@ func ExampleDB_Create() { } func ExampleDB_Select() { - db, err := surrealdb.New("ws://localhost:8000/rpc") + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + db, err := surrealdb.New(ctx, "ws://localhost:8000/rpc") if err != nil { panic(err) } defer db.Close() - _, err = db.Signin(map[string]interface{}{ - "user": "root", - "pass": "root", + _, err = db.Signin(ctx, surrealdb.UserInfo{ + User: "root", + Password: "root", }) if err != nil { panic(err) } - _, err = db.Use("test", "test") + _, err = db.Use(ctx, "test", "test") if err != nil { panic(err) } - _, err = db.Create("users", testUser{ + _, err = db.Create(ctx, "users", testUser{ Username: "johnnyjohn", Password: "123", }) - userData, err := db.Select("users") + userData, err := db.Select(ctx, "users") // unmarshal the data into a user slice var users []testUser + log.Print(userData) err = surrealdb.Unmarshal(userData, &users) if err != nil { panic(err) @@ -166,28 +183,31 @@ func ExampleDB_Select() { } func ExampleDB_Update() { - db, err := surrealdb.New("ws://localhost:8000/rpc") + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + db, err := surrealdb.New(ctx, "ws://localhost:8000/rpc") if err != nil { panic(err) } defer db.Close() - _, err = db.Signin(map[string]interface{}{ - "user": "root", - "pass": "root", + _, err = db.Signin(ctx, surrealdb.UserInfo{ + User: "root", + Password: "root", }) if err != nil { panic(err) } - _, err = db.Use("test", "test") + _, err = db.Use(ctx, "test", "test") if err != nil { panic(err) } - userData, err := db.Create("users", testUser{ + userData, err := db.Create(ctx, "users", testUser{ Username: "johnny", Password: "123", }) @@ -202,7 +222,7 @@ func ExampleDB_Update() { user.Password = "456" // Update the user - userData, err = db.Update("users", &user) + userData, err = db.Update(ctx, "users", &user) if err != nil { panic(err) @@ -223,28 +243,31 @@ func ExampleDB_Update() { } func TestUnmarshalRaw(t *testing.T) { - db, err := surrealdb.New("ws://localhost:8000/rpc") + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + db, err := surrealdb.New(ctx, "ws://localhost:8000/rpc") if err != nil { panic(err) } defer db.Close() - _, err = db.Signin(map[string]interface{}{ - "user": "root", - "pass": "root", + _, err = db.Signin(ctx, surrealdb.UserInfo{ + User: "root", + Password: "root", }) if err != nil { panic(err) } - _, err = db.Use("test", "test") + _, err = db.Use(ctx, "test", "test") if err != nil { panic(err) } - _, err = db.Delete("users") + _, err = db.Delete(ctx, "users") if err != nil { panic(err) } @@ -252,9 +275,9 @@ func TestUnmarshalRaw(t *testing.T) { username := "johnny" password := "123" - //create test user with raw SurrealQL and unmarshal + // create test user with raw SurrealQL and unmarshal - userData, err := db.Query("create users:johnny set Username = $user, Password = $pass", map[string]interface{}{ + userData, err := db.Query(ctx, "create users:johnny set Username = $user, Password = $pass", map[string]any{ "user": username, "pass": password, }) @@ -271,9 +294,9 @@ func TestUnmarshalRaw(t *testing.T) { panic("response does not match the request") } - //send query with empty result and unmarshal + // send query with empty result and unmarshal - userData, err = db.Query("select * from users where id = $id", map[string]interface{}{ + userData, err = db.Query(ctx, "select * from users where id = $id", map[string]any{ "id": "users:jim", }) if err != nil { @@ -292,19 +315,22 @@ func TestUnmarshalRaw(t *testing.T) { } func ExampleDB_Modify() { - db, err := surrealdb.New("ws://localhost:8000/rpc") + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + db, err := surrealdb.New(ctx, "ws://localhost:8000/rpc") if err != nil { panic(err) } defer db.Close() - _, err = db.Signin(map[string]interface{}{ - "user": "root", - "pass": "root", + _, err = db.Signin(ctx, surrealdb.UserInfo{ + User: "root", + Password: "root", }) - _, err = db.Use("test", "test") + _, err = db.Use(ctx, "test", "test") - _, err = db.Create("users:999", map[string]interface{}{ + _, err = db.Create(ctx, "users:999", map[string]any{ "username": "john999", "password": "123", }) @@ -318,18 +344,92 @@ func ExampleDB_Modify() { } // Update the user - _, err = db.Modify("users:999", patches) + _, err = db.Modify(ctx, "users:999", patches) if err != nil { panic(err) } - user2, err := db.Select("users:999") + user2, err := db.Select(ctx, "users:999") if err != nil { panic(err) } // // TODO: this needs to simplified for the end user somehow - fmt.Println((user2).(map[string]interface{})["age"]) + fmt.Println((user2).(map[string]any)["age"]) // // Output: 44 } + +type TestDatabaseTestSuite struct { + suite.Suite + ctx context.Context + db *surrealdb.DB +} + +func TestDatabaseSuite(t *testing.T) { + suite.Run(t, new(TestDatabaseTestSuite)) +} + +func (suite *TestDatabaseTestSuite) SetupTest() { + ctx := context.Background() + + rpcUrl := surrealdb.GetEnvOrDefault("SURREALDB_RPC_URL", "ws://localhost:8000/rpc") + user := surrealdb.GetEnvOrDefault("SURREALDB_USER", "root") + pass := surrealdb.GetEnvOrDefault("SURREALDB_PASS", "root") + + db, err := surrealdb.New(ctx, rpcUrl) + suite.Require().NoError(err) + + _, err = db.Signin(ctx, surrealdb.UserInfo{ + User: user, + Password: pass, + }) + suite.Require().NoError(err) + + _, err = db.Use(ctx, "test", "test") + suite.Require().NoError(err) + + suite.db = db + suite.ctx = ctx +} + +func (suite *TestDatabaseTestSuite) TearDownSuite() { + suite.db.Close() +} + +func (suite *TestDatabaseTestSuite) Test_FailingUserSignin() { + // NOTE: this query fails for some reason but works when I run it manually... + // DEFINE SCOPE test_account_scope + // SIGNIN ( SELECT * FROM user WHERE username = $user AND crypto::argon2::compare(password, $pass) ) + // SIGNUP ( CREATE user SET username = $user, password = crypto::argon2::generate($pass) ) + // ; + // result, err := suite.db.Query(suite.ctx, scopeQuery, map[string]any{}) + // suite.Require().NoError(err) + // suite.Require().NotNil(result) + + authResult, err := suite.db.SigninUser(suite.ctx, surrealdb.UserInfo{ + User: "test_username", + Password: "test_password", + Namespace: "test_account_scope", + Database: "test", + Scope: "test", + }) + + suite.Require().Error(err) + suite.Require().NotNil(authResult) + suite.Require().False(authResult.Success) + + authResult, err = suite.db.SignupUser(suite.ctx, surrealdb.UserInfo{ + User: "test_username", + Password: "test_password", + Namespace: "test", + Database: "test", + Scope: "test_account_scope", + }) + suite.Require().NoError(err) + suite.Require().NotNil(authResult) + suite.Require().True(authResult.Success) + suite.Require().NotZero(authResult.Token) + suite.Require().NotZero(authResult.TokenData) + suite.Require().Equal(authResult.TokenData.Scope, "test_account_scope") +} diff --git a/env_util.go b/env_util.go new file mode 100644 index 0000000..43ef1b7 --- /dev/null +++ b/env_util.go @@ -0,0 +1,12 @@ +package surrealdb + +import "os" + +func GetEnvOrDefault(key, defaultValue string) string { + value := os.Getenv(key) + if value == "" { + return defaultValue + } + + return value +} diff --git a/err.go b/err.go index 01c40a8..01e1ba0 100644 --- a/err.go +++ b/err.go @@ -8,6 +8,6 @@ type PermissionError struct { what string } -func (self PermissionError) Error() string { - return fmt.Sprint("Unable to access record:", self.what) +func (pe PermissionError) Error() string { + return fmt.Sprint("Unable to access record:", pe.what) } diff --git a/go.mod b/go.mod index 44e4fcc..293cf30 100644 --- a/go.mod +++ b/go.mod @@ -2,4 +2,14 @@ module github.com/surrealdb/surrealdb.go go 1.18 -require github.com/gorilla/websocket v1.5.0 +require ( + github.com/gorilla/websocket v1.5.0 + github.com/test-go/testify v1.1.4 +) + +require ( + github.com/davecgh/go-spew v1.1.1 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect + github.com/stretchr/testify v1.8.0 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect +) diff --git a/go.sum b/go.sum index e5a03d4..16dc1a2 100644 --- a/go.sum +++ b/go.sum @@ -1,2 +1,19 @@ +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/gorilla/websocket v1.5.0 h1:PPwGk2jz7EePpoHN/+ClbZu8SPxiqlu12wZP/3sWmnc= github.com/gorilla/websocket v1.5.0/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= +github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.8.0 h1:pSgiaMZlXftHpm5L7V1+rVB+AZJydKsMxsQBIJw4PKk= +github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= +github.com/test-go/testify v1.1.4 h1:Tf9lntrKUMHiXQ07qBScBTSA0dhYQlu83hswqelv1iE= +github.com/test-go/testify v1.1.4/go.mod h1:rH7cfJo/47vWGdi4GPj16x3/t1xGOj2YxzmNQzk2ghU= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/id.go b/id.go index d30c5ed..46d0200 100644 --- a/id.go +++ b/id.go @@ -1,18 +1,14 @@ package surrealdb import ( - "fmt" - "math/rand" - "time" + "strconv" + "sync/atomic" ) -func xid(length int) string { - // Generate a new seed - rand.Seed(time.Now().UnixNano()) - // Create a random byte slice - b := make([]byte, length) - // Fill the byte slice with data - rand.Read(b) - // Return the byte slice as a string - return fmt.Sprintf("%x", b)[:length] +var _currentid uint64 + +// generate an incrementing id for uniqueness purposes +func xid() string { + id := atomic.AddUint64(&_currentid, 1) + return strconv.FormatUint(id, 16) } diff --git a/rpc.go b/rpc.go index d56ffc0..59904ae 100644 --- a/rpc.go +++ b/rpc.go @@ -12,22 +12,22 @@ func (r *RPCError) Error() string { // RPCRequest represents an incoming JSON-RPC request type RPCRequest struct { - ID interface{} `json:"id" msgpack:"id"` + ID any `json:"id" msgpack:"id"` Async bool `json:"async,omitempty" msgpack:"async,omitempty"` Method string `json:"method,omitempty" msgpack:"method,omitempty"` - Params []interface{} `json:"params,omitempty" msgpack:"params,omitempty"` + Params []any `json:"params,omitempty" msgpack:"params,omitempty"` } // RPCResponse represents an outgoing JSON-RPC response type RPCResponse struct { - ID interface{} `json:"id" msgpack:"id"` + ID any `json:"id" msgpack:"id"` Error *RPCError `json:"error,omitempty" msgpack:"error,omitempty"` - Result interface{} `json:"result,omitempty" msgpack:"result,omitempty"` + Result any `json:"result,omitempty" msgpack:"result,omitempty"` } // RPCNotification represents an outgoing JSON-RPC notification type RPCNotification struct { - ID interface{} `json:"id" msgpack:"id"` + ID any `json:"id" msgpack:"id"` Method string `json:"method,omitempty" msgpack:"method,omitempty"` - Params []interface{} `json:"params,omitempty" msgpack:"params,omitempty"` + Params []any `json:"params,omitempty" msgpack:"params,omitempty"` } diff --git a/types.go b/types.go index a2f3e1d..97eda38 100644 --- a/types.go +++ b/types.go @@ -1,8 +1,93 @@ package surrealdb +import ( + "encoding/base64" + "encoding/json" + "errors" + "strings" +) + +var ( + ErrInvalidToken = errors.New("token string is invalid") +) + // Patch represents a patch object set to MODIFY a record type Patch struct { Op string `json:"op"` Path string `json:"path"` Value any `json:"value"` } + +// UserInfo TODO: A way to make User and Password use different names via configuration +// This method only works if your scope is configured with those namings also, otherwise auth will fail +type UserInfo struct { + User string `json:"user"` + Password string `json:"pass"` + Namespace string `json:"NS,omitempty"` + Database string `json:"DB,omitempty"` + Scope string `json:"SC,omitempty"` +} + +type AuthenticationResult struct { + Success bool `json:"success"` + Token string `json:"token"` + + TokenData +} + +func (data *AuthenticationResult) fromQuery(result any) error { + if result == nil || result == "" { + return ErrInvalidLoginResponse + } + if _, ok := result.(string); !ok { + return ErrInvalidLoginResponse + } + tokenData, err := TokenData{}.FromToken(result.(string)) + if err != nil { + return err + } + + data.Success = true + data.Token = result.(string) + data.TokenData = tokenData + + return nil +} + +type TokenData struct { + IssuedAt int `json:"iat"` + NotBefore int `json:"nbf"` + ExpiresAt int `json:"exp"` + Issuer string `json:"iss"` + Namespace string `json:"ns"` + Database string `json:"db"` + Scope string `json:"sc"` + Id string `json:"id"` +} + +func (token TokenData) FromToken(tokenString string) (TokenData, error) { + data := TokenData{} + + if tokenString == "" { + return data, ErrInvalidToken + } + + segments := strings.Split(tokenString, ".") + if len(segments) != 3 { + return data, ErrInvalidToken + } + + // Decode the payload + payload, err := base64.RawStdEncoding.DecodeString(segments[1]) + if err != nil { + return data, err + } + + // Unmarshal the payload + err = json.Unmarshal(payload, &data) + if err != nil { + return data, err + } + + return data, nil +} diff --git a/types_test.go b/types_test.go new file mode 100644 index 0000000..c2fbcd5 --- /dev/null +++ b/types_test.go @@ -0,0 +1,62 @@ +package surrealdb_test + +import ( + "encoding/base64" + "fmt" + "testing" + + "github.com/surrealdb/surrealdb.go" + "github.com/test-go/testify/suite" +) + +type TestTypesTestSuite struct { + suite.Suite +} + +func TestTypesSuite(t *testing.T) { + suite.Run(t, new(TestTypesTestSuite)) +} + +func (suite *TestTypesTestSuite) SetupTest() { + +} + +func (suite *TestTypesTestSuite) TearDownSuite() { + +} + +func (suite *TestTypesTestSuite) Test_InvalidTokenString() { + tokenData, err := surrealdb.TokenData{}.FromToken("") + suite.Require().Error(err, surrealdb.ErrInvalidToken) + suite.Require().Zero(tokenData) +} +func (suite *TestTypesTestSuite) Test_InvalidTokenSegments() { + tokenData, err := surrealdb.TokenData{}.FromToken("ffff.bbbb") + suite.Require().Error(err, surrealdb.ErrInvalidToken) + suite.Require().Zero(tokenData) +} + +func (suite *TestTypesTestSuite) Test_InvalidTokenBase64() { + tokenData, err := surrealdb.TokenData{}.FromToken("ffff.bbbb.xxx") + suite.Require().Error(err, surrealdb.ErrInvalidToken) + suite.Require().Zero(tokenData) +} + +func (suite *TestTypesTestSuite) Test_InvalidTokenJson() { + invalid := base64.StdEncoding.EncodeToString([]byte("{pls, fail}")) + tokenData, err := surrealdb.TokenData{}.FromToken(fmt.Sprintf("ffff.%s.xxx", invalid)) + suite.Require().Error(err, surrealdb.ErrInvalidToken) + suite.Require().Zero(tokenData) +} + +func (suite *TestTypesTestSuite) Test_ValidToken() { + tokenData, err := surrealdb.TokenData{}.FromToken("eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzUxMiJ9.eyJpYXQiOjE2NjQwNjM5NzYsIm5iZiI6MTY2NDA2Mzk3NiwiZXhwIjoxNjY0MDY3NTc2LCJpc3MiOiJTdXJyZWFsREIiLCJucyI6InRlc3QiLCJkYiI6ImFwcGxpY2F0aW9uIiwic2MiOiJhY2NvdW50IiwiaWQiOiJ1c2VyOnl3amVhaG44Y3Y0a3JjeDI5a201In0.Fw5v2vBShVnMnNBKuuOjC24HWrMrBKMeADNgkEcebmjAJzpRIxPEEn5Ehr_a70Jnsl5xi7vl4-r5M2QxfODLZw") + suite.Require().NoError(err, surrealdb.ErrInvalidToken) + suite.Require().NotZero(tokenData) + + suite.Require().Equal("test", tokenData.Namespace) + suite.Require().Equal("application", tokenData.Database) + suite.Require().Equal("account", tokenData.Scope) + suite.Require().Equal("user:ywjeahn8cv4krcx29km5", tokenData.Id) + suite.Require().Equal("SurrealDB", tokenData.Issuer) +} diff --git a/ws.go b/ws.go index 505c007..633c66e 100644 --- a/ws.go +++ b/ws.go @@ -1,6 +1,7 @@ package surrealdb import ( + "context" "encoding/json" "sync" @@ -8,18 +9,21 @@ import ( ) type WS struct { - ws *websocket.Conn // websocket connection - quit chan error // stops: MAIN LOOP - send chan<- *RPCRequest // sender channel + ws *websocket.Conn // websocket connection + send chan<- *RPCRequest // sender channel recv <-chan *RPCResponse // receive channel emit struct { + // TODO: use the lock less, through smaller locks (separate once/when locks ?) + // or ideally by removing locks altogether lock sync.Mutex // pause threads to avoid conflicts - once map[interface{}][]func(error, interface{}) // once listeners - when map[interface{}][]func(error, interface{}) // when listeners + + // do the callbacks really need to be a list ? + once map[any][]func(error, any) // once listeners + when map[any][]func(error, any) // when listeners } } -func NewWebsocket(url string) (*WS, error) { +func NewWebsocket(ctx context.Context, url string) (*WS, error) { dialer := websocket.DefaultDialer dialer.EnableCompression = true @@ -30,8 +34,13 @@ func NewWebsocket(url string) (*WS, error) { } ws := &WS{ws: so} + + // initilialize the callback maps here so we don't need to check them at runtime + ws.emit.once = make(map[any][]func(error, any)) + ws.emit.when = make(map[any][]func(error, any)) + // setup loops and channels - ws.initialise() + ws.initialise(ctx) return ws, nil @@ -41,17 +50,17 @@ func NewWebsocket(url string) (*WS, error) { // Public methods // -------------------------------------------------- -func (self *WS) Close() error { +func (ws *WS) Close() error { msg := websocket.FormatCloseMessage(1000, "") - return self.ws.WriteMessage(websocket.CloseMessage, msg) + return ws.ws.WriteMessage(websocket.CloseMessage, msg) } -func (self *WS) Send(id string, method string, params []interface{}) { +func (ws *WS) Send(id string, method string, params []any) { go func() { - self.send <- &RPCRequest{ + ws.send <- &RPCRequest{ ID: id, Method: method, Params: params, @@ -60,45 +69,42 @@ func (self *WS) Send(id string, method string, params []interface{}) { } +type responseValue struct { + value any + err error +} + // Subscribe to once() -func (self *WS) Once(id, method string) (<-chan interface{}, <-chan error) { - - err := make(chan error) - res := make(chan interface{}) - - self.once(id, func(e error, r interface{}) { - switch { - case e != nil: - err <- e - close(err) - close(res) - case e == nil: - res <- r - close(err) - close(res) +func (ws *WS) Once(id, method string) <-chan responseValue { + + out := make(chan responseValue) + + ws.once(id, func(e error, r any) { + out <- responseValue{ + value: r, + err: e, } + close(out) }) - return res, err + return out } // Subscribe to when() -func (self *WS) When(id, method string) (<-chan interface{}, <-chan error) { +func (ws *WS) When(id, method string) <-chan responseValue { + // TODO: make this cancellable (use of context.Context ?) - err := make(chan error) - res := make(chan interface{}) + out := make(chan responseValue) - self.when(id, func(e error, r interface{}) { - switch { - case e != nil: - err <- e - case e == nil: - res <- r + ws.when(id, func(e error, r any) { + out <- responseValue{ + value: r, + err: e, } }) - return res, err + return out } @@ -106,86 +112,78 @@ func (self *WS) When(id, method string) (<-chan interface{}, <-chan error) { // Private methods // -------------------------------------------------- -func (self *WS) once(id interface{}, fn func(error, interface{})) { +func (ws *WS) once(id any, fn func(error, any)) { // pauses traffic in others threads, so we can add the new listener without conflicts - self.emit.lock.Lock() - defer self.emit.lock.Unlock() - // if its our first listener, we need to setup the map - if self.emit.once == nil { - self.emit.once = make(map[interface{}][]func(error, interface{})) - } + ws.emit.lock.Lock() + defer ws.emit.lock.Unlock() - self.emit.once[id] = append(self.emit.once[id], fn) + ws.emit.once[id] = append(ws.emit.once[id], fn) } // WHEN SYSTEM ISN'T BEEING USED, MAYBE FOR FUTURE IN-DATABASE EVENTS AND/OR REAL TIME stuffs. -func (self *WS) when(id interface{}, fn func(error, interface{})) { +func (ws *WS) when(id any, fn func(error, any)) { // pauses traffic in others threads, so we can add the new listener without conflicts - self.emit.lock.Lock() - defer self.emit.lock.Unlock() - - // if its our first listener, we need to setup the map - if self.emit.when == nil { - self.emit.when = make(map[interface{}][]func(error, interface{})) - } + ws.emit.lock.Lock() + defer ws.emit.lock.Unlock() - self.emit.when[id] = append(self.emit.when[id], fn) + ws.emit.when[id] = append(ws.emit.when[id], fn) } -func (self *WS) done(id interface{}, err error, res interface{}) { +func (ws *WS) done(id any, err error, res any) { // pauses traffic in others threads, so we can modify listeners without conflicts - self.emit.lock.Lock() - defer self.emit.lock.Unlock() + ws.emit.lock.Lock() + defer ws.emit.lock.Unlock() // if our events map exist - if self.emit.when != nil { + if ws.emit.when != nil { // if theres some listener aiming to this id response - if _, ok := self.emit.when[id]; ok { + if when, ok := ws.emit.when[id]; ok { // dispatch the event, starting from the end, so we prioritize the new ones - for i := len(self.emit.when[id]) - 1; i >= 0; i-- { + for i := len(when) - 1; i >= 0; i-- { // invoke callback - self.emit.when[id][i](err, res) + when[i](err, res) } } } // if our events map exist - if self.emit.once != nil { + if ws.emit.once != nil { // if theres some listener aiming to this id response - if _, ok := self.emit.once[id]; ok { + if once, ok := ws.emit.once[id]; ok { // dispatch the event, starting from the end, so we prioritize the new ones - for i := len(self.emit.once[id]) - 1; i >= 0; i-- { + for i := len(once) - 1; i >= 0; i-- { // invoke callback - self.emit.once[id][i](err, res) + once[i](err, res) // erase this listener - self.emit.once[id][i] = nil + once[i] = nil - // remove this listener from the list - self.emit.once[id] = self.emit.once[id][:i] } + + // remove all listeners + ws.emit.once[id] = once[0:] } } } -func (self *WS) read(v interface{}) (err error) { +func (ws *WS) read(v any) (err error) { - _, r, err := self.ws.NextReader() + _, r, err := ws.ws.NextReader() if err != nil { return err } @@ -194,14 +192,17 @@ func (self *WS) read(v interface{}) (err error) { } -func (self *WS) write(v interface{}) (err error) { +func (ws *WS) write(v any) (err error) { - w, err := self.ws.NextWriter(websocket.TextMessage) + w, err := ws.ws.NextWriter(websocket.TextMessage) if err != nil { return err } - err = json.NewEncoder(w).Encode(v) + enc := json.NewEncoder(w) + // the default HTML escaping messes with select arrows + enc.SetEscapeHTML(false) + err = enc.Encode(v) if err != nil { return err } @@ -210,30 +211,26 @@ func (self *WS) write(v interface{}) (err error) { } -func (self *WS) initialise() { +func (ws *WS) initialise(ctx context.Context) { send := make(chan *RPCRequest) recv := make(chan *RPCResponse) - quit := make(chan error, 1) // stops: MAIN LOOP - exit := make(chan int, 1) // stops: RECEIVER LOOP, SENDER LOOP - + ctx, cancel := context.WithCancel(ctx) // RECEIVER LOOP go func() { - loop: for { select { - case <-exit: - break loop // stops: THIS LOOP + case <-ctx.Done(): + return default: var res RPCResponse - err := self.read(&res) // wait and unmarshal UPCOMING response + err := ws.read(&res) // wait and unmarshal UPCOMING response if err != nil { - self.Close() - quit <- err // stops: MAIN LOOP - exit <- 0 // stops: RECEIVER LOOP, SENDER LOOP - break loop // stops: THIS LOOP + ws.Close() + cancel() + return } recv <- &res // redirect response to: MAIN LOOP @@ -245,20 +242,18 @@ func (self *WS) initialise() { // SENDER LOOP go func() { - loop: for { select { - case <-exit: - break loop // stops: THIS LOOP + case <-ctx.Done(): + return // stops: THIS LOOP case res := <-send: - err := self.write(res) // marshal and send + err := ws.write(res) // marshal and send if err != nil { - self.Close() - quit <- err // stops: MAIN LOOP - exit <- 0 // stops: RECEIVER LOOP, SENDER LOOP - break loop // stops: THIS LOOP + ws.Close() + cancel() + return // stops: THIS LOOP } } @@ -270,20 +265,19 @@ func (self *WS) initialise() { go func() { for { select { - case <-self.quit: - break - case res := <-self.recv: + case <-ctx.Done(): + return + case res := <-ws.recv: switch { case res.Error == nil: - self.done(res.ID, nil, res.Result) + ws.done(res.ID, nil, res.Result) case res.Error != nil: - self.done(res.ID, res.Error, res.Result) + ws.done(res.ID, res.Error, res.Result) } } } }() - self.send = send - self.recv = recv - self.quit = quit // stops: MAIN LOOP + ws.send = send + ws.recv = recv }