diff --git a/cli/p2p_info.go b/cli/p2p_info.go index 36adfb8fac..cc30d37701 100644 --- a/cli/p2p_info.go +++ b/cli/p2p_info.go @@ -12,8 +12,6 @@ package cli import ( "github.com/spf13/cobra" - - "github.com/sourcenetwork/defradb/http" ) func MakeP2PInfoCommand() *cobra.Command { @@ -22,7 +20,7 @@ func MakeP2PInfoCommand() *cobra.Command { Short: "Get peer info from a DefraDB node", Long: `Get peer info from a DefraDB node`, RunE: func(cmd *cobra.Command, args []string) error { - db := cmd.Context().Value(dbContextKey).(*http.Client) + db := mustGetContextHTTP(cmd) return writeJSON(cmd, db.PeerInfo()) }, } diff --git a/cli/purge.go b/cli/purge.go index 5880e021b8..0e2552b625 100644 --- a/cli/purge.go +++ b/cli/purge.go @@ -12,8 +12,6 @@ package cli import ( "github.com/spf13/cobra" - - "github.com/sourcenetwork/defradb/http" ) func MakePurgeCommand() *cobra.Command { @@ -24,7 +22,7 @@ func MakePurgeCommand() *cobra.Command { Long: `Delete all persisted data and restart. WARNING this operation cannot be reversed.`, RunE: func(cmd *cobra.Command, args []string) error { - db := mustGetContextDB(cmd).(*http.Client) + db := mustGetContextHTTP(cmd) if !force { return ErrPurgeForceFlagRequired } diff --git a/cli/utils.go b/cli/utils.go index 845cea671b..fb9b5a6d3f 100644 --- a/cli/utils.go +++ b/cli/utils.go @@ -60,35 +60,42 @@ const ( // // If a db is not set in the current context this function panics. func mustGetContextDB(cmd *cobra.Command) client.DB { - return cmd.Context().Value(dbContextKey).(client.DB) + return cmd.Context().Value(dbContextKey).(client.DB) //nolint:forcetypeassert } // mustGetContextStore returns the store for the current command context. // // If a store is not set in the current context this function panics. func mustGetContextStore(cmd *cobra.Command) client.Store { - return cmd.Context().Value(dbContextKey).(client.Store) + return cmd.Context().Value(dbContextKey).(client.Store) //nolint:forcetypeassert } // mustGetContextP2P returns the p2p implementation for the current command context. // // If a p2p implementation is not set in the current context this function panics. func mustGetContextP2P(cmd *cobra.Command) client.P2P { - return cmd.Context().Value(dbContextKey).(client.P2P) + return cmd.Context().Value(dbContextKey).(client.P2P) //nolint:forcetypeassert +} + +// mustGetContextHTTP returns the http client for the current command context. +// +// If http client is not set in the current context this function panics. +func mustGetContextHTTP(cmd *cobra.Command) *http.Client { + return cmd.Context().Value(dbContextKey).(*http.Client) //nolint:forcetypeassert } // mustGetContextConfig returns the config for the current command context. // // If a config is not set in the current context this function panics. func mustGetContextConfig(cmd *cobra.Command) *viper.Viper { - return cmd.Context().Value(cfgContextKey).(*viper.Viper) + return cmd.Context().Value(cfgContextKey).(*viper.Viper) //nolint:forcetypeassert } // mustGetContextRootDir returns the rootdir for the current command context. // // If a rootdir is not set in the current context this function panics. func mustGetContextRootDir(cmd *cobra.Command) string { - return cmd.Context().Value(rootDirContextKey).(string) + return cmd.Context().Value(rootDirContextKey).(string) //nolint:forcetypeassert } // tryGetContextCollection returns the collection for the current command context diff --git a/http/handler.go b/http/handler.go index cdb09767c6..336dfc54d3 100644 --- a/http/handler.go +++ b/http/handler.go @@ -12,7 +12,6 @@ package http import ( "context" - "fmt" "net/http" "sync" @@ -100,9 +99,10 @@ func NewHandler(db client.DB) (*Handler, error) { func (h *Handler) Transaction(id uint64) (datastore.Txn, error) { tx, ok := h.txs.Load(id) if !ok { - return nil, fmt.Errorf("invalid transaction id") + return nil, ErrInvalidTransactionId } - return tx.(datastore.Txn), nil + + return mustGetDataStoreTxn(tx), nil } func (h *Handler) ServeHTTP(w http.ResponseWriter, req *http.Request) { diff --git a/http/handler_acp.go b/http/handler_acp.go index d359d5085e..f9ef17cbee 100644 --- a/http/handler_acp.go +++ b/http/handler_acp.go @@ -15,18 +15,12 @@ import ( "net/http" "github.com/getkin/kin-openapi/openapi3" - - "github.com/sourcenetwork/defradb/client" ) type acpHandler struct{} func (s *acpHandler) AddPolicy(rw http.ResponseWriter, req *http.Request) { - db, ok := req.Context().Value(dbContextKey).(client.DB) - if !ok { - responseJSON(rw, http.StatusBadRequest, errorResponse{NewErrFailedToGetContext("db")}) - return - } + db := mustGetContextClientDB(req) policyBytes, err := io.ReadAll(req.Body) if err != nil { @@ -47,11 +41,7 @@ func (s *acpHandler) AddPolicy(rw http.ResponseWriter, req *http.Request) { } func (s *acpHandler) AddDocActorRelationship(rw http.ResponseWriter, req *http.Request) { - db, ok := req.Context().Value(dbContextKey).(client.DB) - if !ok { - responseJSON(rw, http.StatusBadRequest, errorResponse{NewErrFailedToGetContext("db")}) - return - } + db := mustGetContextClientDB(req) var message addDocActorRelationshipRequest err := requestJSON(req, &message) @@ -76,11 +66,7 @@ func (s *acpHandler) AddDocActorRelationship(rw http.ResponseWriter, req *http.R } func (s *acpHandler) DeleteDocActorRelationship(rw http.ResponseWriter, req *http.Request) { - db, ok := req.Context().Value(dbContextKey).(client.DB) - if !ok { - responseJSON(rw, http.StatusBadRequest, errorResponse{NewErrFailedToGetContext("db")}) - return - } + db := mustGetContextClientDB(req) var message deleteDocActorRelationshipRequest err := requestJSON(req, &message) diff --git a/http/handler_ccip.go b/http/handler_ccip.go index 5b9aeb5402..f4855d69d7 100644 --- a/http/handler_ccip.go +++ b/http/handler_ccip.go @@ -18,8 +18,6 @@ import ( "github.com/getkin/kin-openapi/openapi3" "github.com/go-chi/chi/v5" - - "github.com/sourcenetwork/defradb/client" ) type ccipHandler struct{} @@ -35,7 +33,7 @@ type CCIPResponse struct { // ExecCCIP handles GraphQL over Cross Chain Interoperability Protocol requests. func (c *ccipHandler) ExecCCIP(rw http.ResponseWriter, req *http.Request) { - store := req.Context().Value(dbContextKey).(client.Store) + store := mustGetContextClientStore(req) var ccipReq CCIPRequest switch req.Method { diff --git a/http/handler_collection.go b/http/handler_collection.go index 8f45a7948f..ddade699e3 100644 --- a/http/handler_collection.go +++ b/http/handler_collection.go @@ -40,7 +40,7 @@ type CollectionUpdateRequest struct { } func (s *collectionHandler) Create(rw http.ResponseWriter, req *http.Request) { - col := req.Context().Value(colContextKey).(client.Collection) + col := mustGetContextClientCollection(req) data, err := io.ReadAll(req.Body) if err != nil { @@ -89,7 +89,7 @@ func (s *collectionHandler) Create(rw http.ResponseWriter, req *http.Request) { } func (s *collectionHandler) DeleteWithFilter(rw http.ResponseWriter, req *http.Request) { - col := req.Context().Value(colContextKey).(client.Collection) + col := mustGetContextClientCollection(req) var request CollectionDeleteRequest if err := requestJSON(req, &request); err != nil { @@ -106,7 +106,7 @@ func (s *collectionHandler) DeleteWithFilter(rw http.ResponseWriter, req *http.R } func (s *collectionHandler) UpdateWithFilter(rw http.ResponseWriter, req *http.Request) { - col := req.Context().Value(colContextKey).(client.Collection) + col := mustGetContextClientCollection(req) var request CollectionUpdateRequest if err := requestJSON(req, &request); err != nil { @@ -123,7 +123,7 @@ func (s *collectionHandler) UpdateWithFilter(rw http.ResponseWriter, req *http.R } func (s *collectionHandler) Update(rw http.ResponseWriter, req *http.Request) { - col := req.Context().Value(colContextKey).(client.Collection) + col := mustGetContextClientCollection(req) docID, err := client.NewDocIDFromString(chi.URLParam(req, "docID")) if err != nil { @@ -160,7 +160,7 @@ func (s *collectionHandler) Update(rw http.ResponseWriter, req *http.Request) { } func (s *collectionHandler) Delete(rw http.ResponseWriter, req *http.Request) { - col := req.Context().Value(colContextKey).(client.Collection) + col := mustGetContextClientCollection(req) docID, err := client.NewDocIDFromString(chi.URLParam(req, "docID")) if err != nil { @@ -177,7 +177,7 @@ func (s *collectionHandler) Delete(rw http.ResponseWriter, req *http.Request) { } func (s *collectionHandler) Get(rw http.ResponseWriter, req *http.Request) { - col := req.Context().Value(colContextKey).(client.Collection) + col := mustGetContextClientCollection(req) showDeleted, _ := strconv.ParseBool(req.URL.Query().Get("show_deleted")) docID, err := client.NewDocIDFromString(chi.URLParam(req, "docID")) @@ -211,7 +211,7 @@ type DocIDResult struct { } func (s *collectionHandler) GetAllDocIDs(rw http.ResponseWriter, req *http.Request) { - col := req.Context().Value(colContextKey).(client.Collection) + col := mustGetContextClientCollection(req) flusher, ok := rw.(http.Flusher) if !ok { @@ -252,7 +252,7 @@ func (s *collectionHandler) GetAllDocIDs(rw http.ResponseWriter, req *http.Reque } func (s *collectionHandler) CreateIndex(rw http.ResponseWriter, req *http.Request) { - col := req.Context().Value(colContextKey).(client.Collection) + col := mustGetContextClientCollection(req) var indexDesc client.IndexDescription if err := requestJSON(req, &indexDesc); err != nil { @@ -268,7 +268,7 @@ func (s *collectionHandler) CreateIndex(rw http.ResponseWriter, req *http.Reques } func (s *collectionHandler) GetIndexes(rw http.ResponseWriter, req *http.Request) { - store := req.Context().Value(dbContextKey).(client.Store) + store := mustGetContextClientStore(req) indexesMap, err := store.GetAllIndexes(req.Context()) if err != nil { @@ -283,7 +283,7 @@ func (s *collectionHandler) GetIndexes(rw http.ResponseWriter, req *http.Request } func (s *collectionHandler) DropIndex(rw http.ResponseWriter, req *http.Request) { - col := req.Context().Value(colContextKey).(client.Collection) + col := mustGetContextClientCollection(req) err := col.DropIndex(req.Context(), chi.URLParam(req, "index")) if err != nil { diff --git a/http/handler_extras.go b/http/handler_extras.go index c891e9befc..1f14cc40a7 100644 --- a/http/handler_extras.go +++ b/http/handler_extras.go @@ -15,7 +15,6 @@ import ( "github.com/getkin/kin-openapi/openapi3" - "github.com/sourcenetwork/defradb/client" "github.com/sourcenetwork/defradb/event" ) @@ -23,7 +22,7 @@ import ( type extrasHandler struct{} func (s *extrasHandler) Purge(rw http.ResponseWriter, req *http.Request) { - db := req.Context().Value(dbContextKey).(client.DB) + db := mustGetContextClientDB(req) rw.WriteHeader(http.StatusOK) // write the response before we restart to purge db.Events().Publish(event.NewMessage(event.PurgeName, nil)) } diff --git a/http/handler_lens.go b/http/handler_lens.go index 94ef9c2abe..f6d20465f0 100644 --- a/http/handler_lens.go +++ b/http/handler_lens.go @@ -15,14 +15,12 @@ import ( "github.com/getkin/kin-openapi/openapi3" "github.com/sourcenetwork/immutable/enumerable" - - "github.com/sourcenetwork/defradb/client" ) type lensHandler struct{} func (s *lensHandler) ReloadLenses(rw http.ResponseWriter, req *http.Request) { - store := req.Context().Value(dbContextKey).(client.Store) + store := mustGetContextClientStore(req) err := store.LensRegistry().ReloadLenses(req.Context()) if err != nil { @@ -33,7 +31,7 @@ func (s *lensHandler) ReloadLenses(rw http.ResponseWriter, req *http.Request) { } func (s *lensHandler) SetMigration(rw http.ResponseWriter, req *http.Request) { - store := req.Context().Value(dbContextKey).(client.Store) + store := mustGetContextClientStore(req) var request setMigrationRequest if err := requestJSON(req, &request); err != nil { @@ -50,7 +48,7 @@ func (s *lensHandler) SetMigration(rw http.ResponseWriter, req *http.Request) { } func (s *lensHandler) MigrateUp(rw http.ResponseWriter, req *http.Request) { - store := req.Context().Value(dbContextKey).(client.Store) + store := mustGetContextClientStore(req) var request migrateRequest if err := requestJSON(req, &request); err != nil { @@ -75,7 +73,7 @@ func (s *lensHandler) MigrateUp(rw http.ResponseWriter, req *http.Request) { } func (s *lensHandler) MigrateDown(rw http.ResponseWriter, req *http.Request) { - store := req.Context().Value(dbContextKey).(client.Store) + store := mustGetContextClientStore(req) var request migrateRequest if err := requestJSON(req, &request); err != nil { diff --git a/http/handler_p2p.go b/http/handler_p2p.go index 13fc88a90c..941b6f4b5b 100644 --- a/http/handler_p2p.go +++ b/http/handler_p2p.go @@ -21,7 +21,7 @@ import ( type p2pHandler struct{} func (s *p2pHandler) PeerInfo(rw http.ResponseWriter, req *http.Request) { - p2p, ok := req.Context().Value(dbContextKey).(client.P2P) + p2p, ok := tryGetContextClientP2P(req) if !ok { responseJSON(rw, http.StatusBadRequest, errorResponse{ErrP2PDisabled}) return @@ -30,7 +30,7 @@ func (s *p2pHandler) PeerInfo(rw http.ResponseWriter, req *http.Request) { } func (s *p2pHandler) SetReplicator(rw http.ResponseWriter, req *http.Request) { - p2p, ok := req.Context().Value(dbContextKey).(client.P2P) + p2p, ok := tryGetContextClientP2P(req) if !ok { responseJSON(rw, http.StatusBadRequest, errorResponse{ErrP2PDisabled}) return @@ -50,7 +50,7 @@ func (s *p2pHandler) SetReplicator(rw http.ResponseWriter, req *http.Request) { } func (s *p2pHandler) DeleteReplicator(rw http.ResponseWriter, req *http.Request) { - p2p, ok := req.Context().Value(dbContextKey).(client.P2P) + p2p, ok := tryGetContextClientP2P(req) if !ok { responseJSON(rw, http.StatusBadRequest, errorResponse{ErrP2PDisabled}) return @@ -70,7 +70,7 @@ func (s *p2pHandler) DeleteReplicator(rw http.ResponseWriter, req *http.Request) } func (s *p2pHandler) GetAllReplicators(rw http.ResponseWriter, req *http.Request) { - p2p, ok := req.Context().Value(dbContextKey).(client.P2P) + p2p, ok := tryGetContextClientP2P(req) if !ok { responseJSON(rw, http.StatusBadRequest, errorResponse{ErrP2PDisabled}) return @@ -85,7 +85,7 @@ func (s *p2pHandler) GetAllReplicators(rw http.ResponseWriter, req *http.Request } func (s *p2pHandler) AddP2PCollection(rw http.ResponseWriter, req *http.Request) { - p2p, ok := req.Context().Value(dbContextKey).(client.P2P) + p2p, ok := tryGetContextClientP2P(req) if !ok { responseJSON(rw, http.StatusBadRequest, errorResponse{ErrP2PDisabled}) return @@ -105,7 +105,7 @@ func (s *p2pHandler) AddP2PCollection(rw http.ResponseWriter, req *http.Request) } func (s *p2pHandler) RemoveP2PCollection(rw http.ResponseWriter, req *http.Request) { - p2p, ok := req.Context().Value(dbContextKey).(client.P2P) + p2p, ok := tryGetContextClientP2P(req) if !ok { responseJSON(rw, http.StatusBadRequest, errorResponse{ErrP2PDisabled}) return @@ -125,7 +125,7 @@ func (s *p2pHandler) RemoveP2PCollection(rw http.ResponseWriter, req *http.Reque } func (s *p2pHandler) GetAllP2PCollections(rw http.ResponseWriter, req *http.Request) { - p2p, ok := req.Context().Value(dbContextKey).(client.P2P) + p2p, ok := tryGetContextClientP2P(req) if !ok { responseJSON(rw, http.StatusBadRequest, errorResponse{ErrP2PDisabled}) return diff --git a/http/handler_store.go b/http/handler_store.go index e08f2aa9cf..86ab9aeb2d 100644 --- a/http/handler_store.go +++ b/http/handler_store.go @@ -26,7 +26,7 @@ import ( type storeHandler struct{} func (s *storeHandler) BasicImport(rw http.ResponseWriter, req *http.Request) { - store := req.Context().Value(dbContextKey).(client.Store) + store := mustGetContextClientStore(req) var config client.BackupConfig if err := requestJSON(req, &config); err != nil { @@ -42,7 +42,7 @@ func (s *storeHandler) BasicImport(rw http.ResponseWriter, req *http.Request) { } func (s *storeHandler) BasicExport(rw http.ResponseWriter, req *http.Request) { - store := req.Context().Value(dbContextKey).(client.Store) + store := mustGetContextClientStore(req) var config client.BackupConfig if err := requestJSON(req, &config); err != nil { @@ -58,7 +58,7 @@ func (s *storeHandler) BasicExport(rw http.ResponseWriter, req *http.Request) { } func (s *storeHandler) AddSchema(rw http.ResponseWriter, req *http.Request) { - store := req.Context().Value(dbContextKey).(client.Store) + store := mustGetContextClientStore(req) schema, err := io.ReadAll(req.Body) if err != nil { @@ -74,7 +74,7 @@ func (s *storeHandler) AddSchema(rw http.ResponseWriter, req *http.Request) { } func (s *storeHandler) PatchSchema(rw http.ResponseWriter, req *http.Request) { - store := req.Context().Value(dbContextKey).(client.Store) + store := mustGetContextClientStore(req) var message patchSchemaRequest err := requestJSON(req, &message) @@ -92,7 +92,7 @@ func (s *storeHandler) PatchSchema(rw http.ResponseWriter, req *http.Request) { } func (s *storeHandler) PatchCollection(rw http.ResponseWriter, req *http.Request) { - store := req.Context().Value(dbContextKey).(client.Store) + store := mustGetContextClientStore(req) var patch string err := requestJSON(req, &patch) @@ -110,7 +110,7 @@ func (s *storeHandler) PatchCollection(rw http.ResponseWriter, req *http.Request } func (s *storeHandler) SetActiveSchemaVersion(rw http.ResponseWriter, req *http.Request) { - store := req.Context().Value(dbContextKey).(client.Store) + store := mustGetContextClientStore(req) schemaVersionID, err := io.ReadAll(req.Body) if err != nil { @@ -126,7 +126,7 @@ func (s *storeHandler) SetActiveSchemaVersion(rw http.ResponseWriter, req *http. } func (s *storeHandler) AddView(rw http.ResponseWriter, req *http.Request) { - store := req.Context().Value(dbContextKey).(client.Store) + store := mustGetContextClientStore(req) var message addViewRequest err := requestJSON(req, &message) @@ -145,7 +145,7 @@ func (s *storeHandler) AddView(rw http.ResponseWriter, req *http.Request) { } func (s *storeHandler) SetMigration(rw http.ResponseWriter, req *http.Request) { - store := req.Context().Value(dbContextKey).(client.Store) + store := mustGetContextClientStore(req) var cfg client.LensConfig if err := requestJSON(req, &cfg); err != nil { @@ -162,7 +162,7 @@ func (s *storeHandler) SetMigration(rw http.ResponseWriter, req *http.Request) { } func (s *storeHandler) GetCollection(rw http.ResponseWriter, req *http.Request) { - store := req.Context().Value(dbContextKey).(client.Store) + store := mustGetContextClientStore(req) options := client.CollectionFetchOptions{} if req.URL.Query().Has("name") { @@ -198,7 +198,7 @@ func (s *storeHandler) GetCollection(rw http.ResponseWriter, req *http.Request) } func (s *storeHandler) GetSchema(rw http.ResponseWriter, req *http.Request) { - store := req.Context().Value(dbContextKey).(client.Store) + store := mustGetContextClientStore(req) options := client.SchemaFetchOptions{} if req.URL.Query().Has("version_id") { @@ -220,7 +220,7 @@ func (s *storeHandler) GetSchema(rw http.ResponseWriter, req *http.Request) { } func (s *storeHandler) RefreshViews(rw http.ResponseWriter, req *http.Request) { - store := req.Context().Value(dbContextKey).(client.Store) + store := mustGetContextClientStore(req) options := client.CollectionFetchOptions{} if req.URL.Query().Has("name") { @@ -252,7 +252,7 @@ func (s *storeHandler) RefreshViews(rw http.ResponseWriter, req *http.Request) { } func (s *storeHandler) GetAllIndexes(rw http.ResponseWriter, req *http.Request) { - store := req.Context().Value(dbContextKey).(client.Store) + store := mustGetContextClientStore(req) indexes, err := store.GetAllIndexes(req.Context()) if err != nil { @@ -263,7 +263,7 @@ func (s *storeHandler) GetAllIndexes(rw http.ResponseWriter, req *http.Request) } func (s *storeHandler) PrintDump(rw http.ResponseWriter, req *http.Request) { - db := req.Context().Value(dbContextKey).(client.DB) + db := mustGetContextClientDB(req) if err := db.PrintDump(req.Context()); err != nil { responseJSON(rw, http.StatusBadRequest, errorResponse{err}) @@ -279,7 +279,7 @@ type GraphQLRequest struct { } func (s *storeHandler) ExecRequest(rw http.ResponseWriter, req *http.Request) { - store := req.Context().Value(dbContextKey).(client.Store) + store := mustGetContextClientStore(req) var request GraphQLRequest switch { diff --git a/http/handler_tx.go b/http/handler_tx.go index e28acab3df..e1ac38376c 100644 --- a/http/handler_tx.go +++ b/http/handler_tx.go @@ -13,13 +13,9 @@ package http import ( "net/http" "strconv" - "sync" "github.com/getkin/kin-openapi/openapi3" "github.com/go-chi/chi/v5" - - "github.com/sourcenetwork/defradb/client" - "github.com/sourcenetwork/defradb/datastore" ) type txHandler struct{} @@ -29,8 +25,8 @@ type CreateTxResponse struct { } func (h *txHandler) NewTxn(rw http.ResponseWriter, req *http.Request) { - db := req.Context().Value(dbContextKey).(client.DB) - txs := req.Context().Value(txsContextKey).(*sync.Map) + db := mustGetContextClientDB(req) + txs := mustGetContextSyncMap(req) readOnly, _ := strconv.ParseBool(req.URL.Query().Get("read_only")) tx, err := db.NewTxn(req.Context(), readOnly) @@ -43,8 +39,8 @@ func (h *txHandler) NewTxn(rw http.ResponseWriter, req *http.Request) { } func (h *txHandler) NewConcurrentTxn(rw http.ResponseWriter, req *http.Request) { - db := req.Context().Value(dbContextKey).(client.DB) - txs := req.Context().Value(txsContextKey).(*sync.Map) + db := mustGetContextClientDB(req) + txs := mustGetContextSyncMap(req) readOnly, _ := strconv.ParseBool(req.URL.Query().Get("read_only")) tx, err := db.NewConcurrentTxn(req.Context(), readOnly) @@ -57,41 +53,46 @@ func (h *txHandler) NewConcurrentTxn(rw http.ResponseWriter, req *http.Request) } func (h *txHandler) Commit(rw http.ResponseWriter, req *http.Request) { - txs := req.Context().Value(txsContextKey).(*sync.Map) + txs := mustGetContextSyncMap(req) - txId, err := strconv.ParseUint(chi.URLParam(req, "id"), 10, 64) + txID, err := strconv.ParseUint(chi.URLParam(req, "id"), 10, 64) if err != nil { responseJSON(rw, http.StatusBadRequest, errorResponse{ErrInvalidTransactionId}) return } - txVal, ok := txs.Load(txId) + txVal, ok := txs.Load(txID) if !ok { responseJSON(rw, http.StatusBadRequest, errorResponse{ErrInvalidTransactionId}) return } - err = txVal.(datastore.Txn).Commit(req.Context()) + + dsTxn := mustGetDataStoreTxn(txVal) + err = dsTxn.Commit(req.Context()) if err != nil { responseJSON(rw, http.StatusBadRequest, errorResponse{err}) return } - txs.Delete(txId) + txs.Delete(txID) rw.WriteHeader(http.StatusOK) } func (h *txHandler) Discard(rw http.ResponseWriter, req *http.Request) { - txs := req.Context().Value(txsContextKey).(*sync.Map) + txs := mustGetContextSyncMap(req) - txId, err := strconv.ParseUint(chi.URLParam(req, "id"), 10, 64) + txID, err := strconv.ParseUint(chi.URLParam(req, "id"), 10, 64) if err != nil { responseJSON(rw, http.StatusBadRequest, errorResponse{ErrInvalidTransactionId}) return } - txVal, ok := txs.LoadAndDelete(txId) + txVal, ok := txs.LoadAndDelete(txID) if !ok { responseJSON(rw, http.StatusBadRequest, errorResponse{ErrInvalidTransactionId}) return } - txVal.(datastore.Txn).Discard(req.Context()) + + dsTxn := mustGetDataStoreTxn(txVal) + dsTxn.Discard(req.Context()) + rw.WriteHeader(http.StatusOK) } diff --git a/http/middleware.go b/http/middleware.go index d02c3d6470..cc98473711 100644 --- a/http/middleware.go +++ b/http/middleware.go @@ -26,26 +26,6 @@ import ( "github.com/sourcenetwork/defradb/internal/db" ) -const ( - // txHeaderName is the name of the transaction header. - // This header should contain a valid transaction id. - txHeaderName = "x-defradb-tx" -) - -type contextKey string - -var ( - // txsContextKey is the context key for the transaction *sync.Map - txsContextKey = contextKey("txs") - // dbContextKey is the context key for the client.DB - dbContextKey = contextKey("db") - // colContextKey is the context key for the client.Collection - // - // If a transaction exists, all operations will be executed - // in the current transaction context. - colContextKey = contextKey("col") -) - // CorsMiddleware handles cross origin request func CorsMiddleware(allowedOrigins []string) func(http.Handler) http.Handler { return cors.Handler(cors.Options{ @@ -76,7 +56,7 @@ func ApiMiddleware(db client.DB, txs *sync.Map) func(http.Handler) http.Handler // TransactionMiddleware sets the transaction context for the current request. func TransactionMiddleware(next http.Handler) http.Handler { return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { - txs := req.Context().Value(txsContextKey).(*sync.Map) + txs := mustGetContextSyncMap(req) txValue := req.Header.Get(txHeaderName) if txValue == "" { @@ -104,7 +84,7 @@ func TransactionMiddleware(next http.Handler) http.Handler { // CollectionMiddleware sets the collection context for the current request. func CollectionMiddleware(next http.Handler) http.Handler { return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { - db := req.Context().Value(dbContextKey).(client.DB) + db := mustGetContextClientDB(req) col, err := db.GetCollectionByName(req.Context(), chi.URLParam(req, "name")) if err != nil { diff --git a/http/utils.go b/http/utils.go index 176fe3d035..d371c802e6 100644 --- a/http/utils.go +++ b/http/utils.go @@ -14,8 +14,76 @@ import ( "encoding/json" "io" "net/http" + "sync" + + "github.com/sourcenetwork/defradb/client" + "github.com/sourcenetwork/defradb/datastore" +) + +const ( + // txHeaderName is the name of the transaction header. + // This header should contain a valid transaction id. + txHeaderName = "x-defradb-tx" +) + +type contextKey string + +var ( + // txsContextKey is the context key for the transaction *sync.Map + txsContextKey = contextKey("txs") + // dbContextKey is the context key for the client.DB + dbContextKey = contextKey("db") + // colContextKey is the context key for the client.Collection + // + // If a transaction exists, all operations will be executed + // in the current transaction context. + colContextKey = contextKey("col") ) +// mustGetContextClientCollection returns the client collection from the http request context or panics. +// +// This should only be called from functions within the http package. +func mustGetContextClientCollection(req *http.Request) client.Collection { + return req.Context().Value(colContextKey).(client.Collection) //nolint:forcetypeassert +} + +// mustGetContextSyncMap returns the sync map from the http request context or panics. +// +// This should only be called from functions within the http package. +func mustGetContextSyncMap(req *http.Request) *sync.Map { + return req.Context().Value(txsContextKey).(*sync.Map) //nolint:forcetypeassert +} + +// mustGetContextClientDB returns the client DB from the http request context or panics. +// +// This should only be called from functions within the http package. +func mustGetContextClientDB(req *http.Request) client.DB { + return req.Context().Value(dbContextKey).(client.DB) //nolint:forcetypeassert +} + +// mustGetContextClientStore returns the client store from the http request context or panics. +// +// This should only be called from functions within the http package. +func mustGetContextClientStore(req *http.Request) client.Store { + return req.Context().Value(dbContextKey).(client.Store) //nolint:forcetypeassert +} + +// mustGetDataStoreTxn returns the datastore transaction or panics. +// +// This should only be called from functions within the http package. +func mustGetDataStoreTxn(tx any) datastore.Txn { + return tx.(datastore.Txn) //nolint:forcetypeassert +} + +// tryGetContextClientP2P returns the P2P client from the http request context and a boolean +// indicating if p2p was enabled. +// +// This should only be called from functions within the http package. +func tryGetContextClientP2P(req *http.Request) (client.P2P, bool) { + p2p, ok := req.Context().Value(dbContextKey).(client.P2P) + return p2p, ok +} + func requestJSON(req *http.Request, out any) error { data, err := io.ReadAll(req.Body) if err != nil { diff --git a/tools/configs/golangci.yaml b/tools/configs/golangci.yaml index d8162783df..1b4abe0718 100644 --- a/tools/configs/golangci.yaml +++ b/tools/configs/golangci.yaml @@ -113,6 +113,7 @@ linters: - errcheck - errorlint - forbidigo + - forcetypeassert - goconst - gofmt - goheader @@ -151,6 +152,69 @@ issues: linters: - goheader + # Exclude running force type assert check in these file paths, we are ignoring these files for now + # because there are many linter complaints in them, we want to resolve all of them eventually. + # TODO: https://github.com/sourcenetwork/defradb/issues/3154 + # Note: The last item must not have a `|` at the end otherwise linter ignores everyfile. + - path: "(\ + client/document.go|\ + client/normal_value_test.go|\ + net/grpc.go|\ + node/store_badger.go|\ + internal/connor/eq.go|\ + internal/core/block/block.go|\ + internal/core/block/block_test.go|\ + internal/core/key_test.go|\ + internal/core/view_item.go|\ + internal/db/backup.go|\ + internal/db/base/compare.go|\ + internal/db/collection.go|\ + internal/db/context.go|\ + internal/db/fetcher/indexer_iterators.go|\ + internal/db/index_test.go|\ + internal/db/indexed_docs_test.go|\ + internal/db/merge.go|\ + internal/db/merge_test.go|\ + internal/db/p2p_replicator.go|\ + internal/db/p2p_replicator_test.go|\ + internal/db/p2p_schema_root.go|\ + internal/db/p2p_schema_root_test.go|\ + internal/lens/fetcher.go|\ + internal/merkle/clock/clock.go|\ + internal/merkle/crdt/merklecrdt.go|\ + internal/planner/arbitrary_join.go|\ + internal/planner/filter/complex.go|\ + internal/planner/filter/copy.go|\ + internal/planner/filter/copy_field.go|\ + internal/planner/filter/copy_test.go|\ + internal/planner/filter/extract_properties.go|\ + internal/planner/filter/normalize.go|\ + internal/planner/filter/unwrap_relation.go|\ + internal/planner/group.go|\ + internal/planner/lens.go|\ + internal/planner/mapper/mapper.go|\ + internal/planner/mapper/targetable.go|\ + internal/planner/planner.go|\ + internal/planner/sum.go|\ + internal/planner/view.go|\ + internal/request/graphql/parser/commit.go|\ + internal/request/graphql/parser/filter.go|\ + internal/request/graphql/parser/mutation.go|\ + internal/request/graphql/parser/query.go|\ + internal/request/graphql/parser/request.go|\ + internal/request/graphql/schema/collection.go|\ + internal/request/graphql/schema/generate.go|\ + tests/gen|\ + tests/integration/utils.go|\ + tests/integration/explain.go|\ + tests/integration/events.go|\ + tests/integration/acp.go|\ + tests/integration/schema/default_fields.go|\ + tests/predefined/gen_predefined.go\ + )" + linters: + - forcetypeassert + # Independently from option `exclude` we use default exclude patterns, # it can be disabled by this option. To list all # excluded by default patterns execute `golangci-lint run --help`.