Skip to content

Commit

Permalink
Merge branch 'develop' into combined-bot-prs-branch
Browse files Browse the repository at this point in the history
  • Loading branch information
shahzadlone authored Dec 8, 2024
2 parents d33143c + b4a3eba commit fb9006c
Show file tree
Hide file tree
Showing 28 changed files with 1,330 additions and 91 deletions.
3 changes: 2 additions & 1 deletion acp/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -631,7 +631,8 @@ Result:
Error: document not found or not authorized to access
```

Sometimes we might want to give a specific access (form a relationship) not just to one identity, but any identity.
Sometimes we might want to give a specific access (i.e. form a relationship) not just with one identity, but with
any identity (includes even requests with no-identity).
In that case we can specify "*" instead of specifying an explicit `actor`:
```sh
defradb client acp relationship add \
Expand Down
10 changes: 8 additions & 2 deletions http/handler_ccip_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ func TestCCIPGet_WithValidData(t *testing.T) {
resHex, err := hex.DecodeString(strings.TrimPrefix(ccipRes.Data, "0x"))
require.NoError(t, err)

assert.JSONEq(t, `{"data": {"User": [{"name": "bob"}]}}`, string(resHex))
assert.JSONEq(t, `{"data": {"User": [{"name": "bob"}, {"name": "adam"}]}}`, string(resHex))
}

func TestCCIPGet_WithSubscription(t *testing.T) {
Expand Down Expand Up @@ -153,7 +153,7 @@ func TestCCIPPost_WithValidData(t *testing.T) {
resHex, err := hex.DecodeString(strings.TrimPrefix(ccipRes.Data, "0x"))
require.NoError(t, err)

assert.JSONEq(t, `{"data": {"User": [{"name": "bob"}]}}`, string(resHex))
assert.JSONEq(t, `{"data": {"User": [{"name": "bob"}, {"name": "adam"}]}}`, string(resHex))
}

func TestCCIPPost_WithInvalidGraphQLRequest(t *testing.T) {
Expand Down Expand Up @@ -210,5 +210,11 @@ func setupDatabase(t *testing.T) client.DB {
err = col.Create(ctx, doc)
require.NoError(t, err)

doc2, err := client.NewDocFromJSON([]byte(`{"name": "adam"}`), col.Definition())
require.NoError(t, err)

err = col.Create(ctx, doc2)
require.NoError(t, err)

return cdb
}
15 changes: 14 additions & 1 deletion http/handler_store.go
Original file line number Diff line number Diff line change
Expand Up @@ -284,7 +284,21 @@ func (s *storeHandler) ExecRequest(rw http.ResponseWriter, req *http.Request) {
var request GraphQLRequest
switch {
case req.URL.Query().Get("query") != "":

request.Query = req.URL.Query().Get("query")

request.OperationName = req.URL.Query().Get("operationName")

variablesFromQuery := req.URL.Query().Get("variables")
if variablesFromQuery != "" {
var variables map[string]any
if err := json.Unmarshal([]byte(variablesFromQuery), &variables); err != nil {
responseJSON(rw, http.StatusBadRequest, errorResponse{err})
return
}
request.Variables = variables
}

case req.Body != nil:
if err := requestJSON(req, &request); err != nil {
responseJSON(rw, http.StatusBadRequest, errorResponse{err})
Expand All @@ -294,7 +308,6 @@ func (s *storeHandler) ExecRequest(rw http.ResponseWriter, req *http.Request) {
responseJSON(rw, http.StatusBadRequest, errorResponse{ErrMissingRequest})
return
}

var options []client.RequestOption
if request.OperationName != "" {
options = append(options, client.WithOperationName(request.OperationName))
Expand Down
98 changes: 98 additions & 0 deletions http/handler_store_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ import (
"io"
"net/http"
"net/http/httptest"
"net/url"
"testing"

"github.com/stretchr/testify/assert"
Expand Down Expand Up @@ -93,3 +94,100 @@ func TestExecRequest_WithInvalidQuery_HasSpecCompliantErrors(t *testing.T) {
"message": "Cannot query field \"invalid\" on type \"User\".",
}})
}

func TestExecRequest_HttpGet_WithOperationName(t *testing.T) {
cdb := setupDatabase(t)

query := `
query UserQuery {
User {
name
}
}
query UserQueryWithDocID {
User {
_docID
name
}
}
`
operationName := "UserQuery"

encodedQuery := url.QueryEscape(query)
encodedOperationName := url.QueryEscape(operationName)

endpointURL := "http://localhost:9181/api/v0/graphql?query=" + encodedQuery + "&operationName=" + encodedOperationName

req := httptest.NewRequest(http.MethodGet, endpointURL, nil)
rec := httptest.NewRecorder()

handler, err := NewHandler(cdb)
require.NoError(t, err)
handler.ServeHTTP(rec, req)

res := rec.Result()
require.NotNil(t, res.Body)

resData, err := io.ReadAll(res.Body)
require.NoError(t, err)

var gqlResponse map[string]any
err = json.Unmarshal(resData, &gqlResponse)
require.NoError(t, err)

// Ensure the response data contains names, but not the _docID field
expectedJSON := `{
"data": {
"User": [
{"name": "bob"},
{"name": "adam"}
]
}
}`
assert.JSONEq(t, expectedJSON, string(resData))
}

func TestExecRequest_HttpGet_WithVariables(t *testing.T) {
cdb := setupDatabase(t)

query := `query getUser($filter: UserFilterArg) {
User(filter: $filter) {
name
}
}`
operationName := "getUser"
variables := `{"filter":{"name":{"_eq":"bob"}}}`

encodedQuery := url.QueryEscape(query)
encodedOperationName := url.QueryEscape(operationName)
encodedVariables := url.QueryEscape(variables)

endpointURL := "http://localhost:9181/api/v0/graphql?query=" + encodedQuery + "&operationName=" + encodedOperationName + "&variables=" + encodedVariables

req := httptest.NewRequest(http.MethodGet, endpointURL, nil)
rec := httptest.NewRecorder()

handler, err := NewHandler(cdb)
require.NoError(t, err)
handler.ServeHTTP(rec, req)

res := rec.Result()
require.NotNil(t, res.Body)

resData, err := io.ReadAll(res.Body)
require.NoError(t, err)

var gqlResponse map[string]any
err = json.Unmarshal(resData, &gqlResponse)
require.NoError(t, err)

// Ensure only bob is returned, because of the filter variable
expectedJSON := `{
"data": {
"User": [
{"name": "bob"}
]
}
}`
assert.JSONEq(t, expectedJSON, string(resData))
}
2 changes: 1 addition & 1 deletion internal/db/base/compare.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ func Compare(a, b any) int {
case bool:
return compareBool(v, b.(bool))
case int:
return compareInt(int64(v), b.(int64))
return compareInt(int64(v), int64(b.(int)))
case int64:
return compareInt(v, b.(int64))
case uint64:
Expand Down
14 changes: 9 additions & 5 deletions internal/db/permission/check.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,18 +67,22 @@ func CheckAccessOfDocOnCollectionWithACP(
return true, nil
}

// At this point if the request is not signatured, then it has no access, because:
// the collection has a policy on it, and the acp is enabled/available,
// and the document is not public (is registered with acp).
var identityValue string
if !identity.HasValue() {
return false, nil
// We can't assume that there is no-access just because there is no identity even if the document
// is registered with acp, this is because it is possible that acp has a registered relation targeting
// "*" (any) actor which would mean that even a request without an identity might be able to access
// a document registered with acp. So we pass an empty `did` to accommodate that case.
identityValue = ""
} else {
identityValue = identity.Value().DID
}

// Now actually check using the signature if this identity has access or not.
hasAccess, err := acpSystem.CheckDocAccess(
ctx,
permission,
identity.Value().DID,
identityValue,
policyID,
resourceName,
docID,
Expand Down
6 changes: 5 additions & 1 deletion internal/planner/average.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ type averageNode struct {
virtualFieldIndex int

execInfo averageExecInfo

aggregateFilter *mapper.Filter
}

type averageExecInfo struct {
Expand All @@ -37,6 +39,7 @@ type averageExecInfo struct {

func (p *Planner) Average(
field *mapper.Aggregate,
filter *mapper.Filter,
) (*averageNode, error) {
var sumField *mapper.Aggregate
var countField *mapper.Aggregate
Expand All @@ -57,6 +60,7 @@ func (p *Planner) Average(
countFieldIndex: countField.Index,
virtualFieldIndex: field.Index,
docMapper: docMapper{field.DocumentMapping},
aggregateFilter: filter,
}, nil
}

Expand Down Expand Up @@ -102,7 +106,7 @@ func (n *averageNode) Next() (bool, error) {
return false, client.NewErrUnhandledType("sum", sumProp)
}

return true, nil
return mapper.RunFilter(n.currentValue, n.aggregateFilter)
}

func (n *averageNode) SetPlan(p planNode) { n.plan = p }
Expand Down
6 changes: 4 additions & 2 deletions internal/planner/count.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ type countNode struct {

virtualFieldIndex int
aggregateMapping []mapper.AggregateTarget
aggregateFilter *mapper.Filter

execInfo countExecInfo
}
Expand All @@ -44,11 +45,12 @@ type countExecInfo struct {
iterations uint64
}

func (p *Planner) Count(field *mapper.Aggregate, host *mapper.Select) (*countNode, error) {
func (p *Planner) Count(field *mapper.Aggregate, host *mapper.Select, filter *mapper.Filter) (*countNode, error) {
return &countNode{
p: p,
virtualFieldIndex: field.Index,
aggregateMapping: field.AggregateTargets,
aggregateFilter: filter,
docMapper: docMapper{field.DocumentMapping},
}, nil
}
Expand Down Expand Up @@ -181,7 +183,7 @@ func (n *countNode) Next() (bool, error) {
}

n.currentValue.Fields[n.virtualFieldIndex] = count
return true, nil
return mapper.RunFilter(n.currentValue, n.aggregateFilter)
}

// countDocs counts the number of documents in a slice, skipping over hidden items
Expand Down
5 changes: 4 additions & 1 deletion internal/planner/max.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ type maxNode struct {
// that contains the result of the aggregate.
virtualFieldIndex int
aggregateMapping []mapper.AggregateTarget
aggregateFilter *mapper.Filter

execInfo maxExecInfo
}
Expand All @@ -45,11 +46,13 @@ type maxExecInfo struct {
func (p *Planner) Max(
field *mapper.Aggregate,
parent *mapper.Select,
filter *mapper.Filter,
) (*maxNode, error) {
return &maxNode{
p: p,
parent: parent,
aggregateMapping: field.AggregateTargets,
aggregateFilter: filter,
virtualFieldIndex: field.Index,
docMapper: docMapper{field.DocumentMapping},
}, nil
Expand Down Expand Up @@ -252,5 +255,5 @@ func (n *maxNode) Next() (bool, error) {
res, _ := max.Int64()
n.currentValue.Fields[n.virtualFieldIndex] = res
}
return true, nil
return mapper.RunFilter(n.currentValue, n.aggregateFilter)
}
5 changes: 4 additions & 1 deletion internal/planner/min.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ type minNode struct {
// that contains the result of the aggregate.
virtualFieldIndex int
aggregateMapping []mapper.AggregateTarget
aggregateFilter *mapper.Filter

execInfo minExecInfo
}
Expand All @@ -45,11 +46,13 @@ type minExecInfo struct {
func (p *Planner) Min(
field *mapper.Aggregate,
parent *mapper.Select,
filter *mapper.Filter,
) (*minNode, error) {
return &minNode{
p: p,
parent: parent,
aggregateMapping: field.AggregateTargets,
aggregateFilter: filter,
virtualFieldIndex: field.Index,
docMapper: docMapper{field.DocumentMapping},
}, nil
Expand Down Expand Up @@ -252,5 +255,5 @@ func (n *minNode) Next() (bool, error) {
res, _ := min.Int64()
n.currentValue.Fields[n.virtualFieldIndex] = res
}
return true, nil
return mapper.RunFilter(n.currentValue, n.aggregateFilter)
}
14 changes: 9 additions & 5 deletions internal/planner/select.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import (
"github.com/sourcenetwork/defradb/internal/core"
"github.com/sourcenetwork/defradb/internal/db/base"
"github.com/sourcenetwork/defradb/internal/keys"
"github.com/sourcenetwork/defradb/internal/planner/filter"
"github.com/sourcenetwork/defradb/internal/planner/mapper"
)

Expand Down Expand Up @@ -344,18 +345,21 @@ func (n *selectNode) initFields(selectReq *mapper.Select) ([]aggregateNode, erro
case *mapper.Aggregate:
var plan aggregateNode
var aggregateError error
var aggregateFilter *mapper.Filter

// extract aggregate filters from the select
selectReq.Filter, aggregateFilter = filter.SplitByFields(selectReq.Filter, f.Field)
switch f.Name {
case request.CountFieldName:
plan, aggregateError = n.planner.Count(f, selectReq)
plan, aggregateError = n.planner.Count(f, selectReq, aggregateFilter)
case request.SumFieldName:
plan, aggregateError = n.planner.Sum(f, selectReq)
plan, aggregateError = n.planner.Sum(f, selectReq, aggregateFilter)
case request.AverageFieldName:
plan, aggregateError = n.planner.Average(f)
plan, aggregateError = n.planner.Average(f, aggregateFilter)
case request.MaxFieldName:
plan, aggregateError = n.planner.Max(f, selectReq)
plan, aggregateError = n.planner.Max(f, selectReq, aggregateFilter)
case request.MinFieldName:
plan, aggregateError = n.planner.Min(f, selectReq)
plan, aggregateError = n.planner.Min(f, selectReq, aggregateFilter)
}

if aggregateError != nil {
Expand Down
Loading

0 comments on commit fb9006c

Please sign in to comment.