Skip to content

Commit

Permalink
[CAPPL-128] limit the amount of fetch calls per request (#894)
Browse files Browse the repository at this point in the history
* feat: limit the amount of fetch calls per request

* fix: move counter to a request level

* fix: define defaultMaxFetchRequests as a const

* fix: rename and change type of fetchRequestsCounter
  • Loading branch information
agparadiso authored Oct 28, 2024
1 parent 8529bcc commit fcae9bd
Show file tree
Hide file tree
Showing 4 changed files with 286 additions and 41 deletions.
50 changes: 33 additions & 17 deletions pkg/workflows/wasm/host/module.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,9 @@ import (
)

type RequestData struct {
response *wasmpb.Response
callWithCtx func(func(context.Context) (*wasmpb.FetchResponse, error)) (*wasmpb.FetchResponse, error)
fetchRequestsCounter int
response *wasmpb.Response
callWithCtx func(func(context.Context) (*wasmpb.FetchResponse, error)) (*wasmpb.FetchResponse, error)
}

type store struct {
Expand Down Expand Up @@ -69,24 +70,26 @@ func (r *store) delete(id string) {
}

var (
defaultTickInterval = 100 * time.Millisecond
defaultTimeout = 2 * time.Second
defaultMaxMemoryMBs = 256
DefaultInitialFuel = uint64(100_000_000)
defaultTickInterval = 100 * time.Millisecond
defaultTimeout = 2 * time.Second
defaultMaxMemoryMBs = 256
DefaultInitialFuel = uint64(100_000_000)
defaultMaxFetchRequests = 5
)

type DeterminismConfig struct {
// Seed is the seed used to generate cryptographically insecure random numbers in the module.
Seed int64
}
type ModuleConfig struct {
TickInterval time.Duration
Timeout *time.Duration
MaxMemoryMBs int64
InitialFuel uint64
Logger logger.Logger
IsUncompressed bool
Fetch func(ctx context.Context, req *wasmpb.FetchRequest) (*wasmpb.FetchResponse, error)
TickInterval time.Duration
Timeout *time.Duration
MaxMemoryMBs int64
InitialFuel uint64
Logger logger.Logger
IsUncompressed bool
Fetch func(ctx context.Context, req *wasmpb.FetchRequest) (*wasmpb.FetchResponse, error)
MaxFetchRequests int

// Labeler is used to emit messages from the module.
Labeler custmsg.MessageEmitter
Expand Down Expand Up @@ -139,6 +142,10 @@ func NewModule(modCfg *ModuleConfig, binary []byte, opts ...func(*ModuleConfig))
}
}

if modCfg.MaxFetchRequests == 0 {
modCfg.MaxFetchRequests = defaultMaxFetchRequests
}

if modCfg.Labeler == nil {
modCfg.Labeler = &unimplementedMessageEmitter{}
}
Expand Down Expand Up @@ -417,7 +424,7 @@ func (m *Module) Run(ctx context.Context, request *wasmpb.Request) (*wasmpb.Resp
return nil, innerErr
}

return nil, fmt.Errorf("error executing runner: %s: %w", storedRequest.response.ErrMsg, innerErr)
return nil, fmt.Errorf("error executing runner: %s: %w", storedRequest.response.ErrMsg, err)
case containsCode(err, wasm.CodeHostErr):
return nil, fmt.Errorf("invariant violation: host errored during sendResponse")
default:
Expand All @@ -434,10 +441,12 @@ func createFetchFn(
reader unsafeReaderFunc,
writer unsafeWriterFunc,
sizeWriter unsafeFixedLengthWriterFunc,
modCfg *ModuleConfig, store *store,
modCfg *ModuleConfig,
requestStore *store,
) func(caller *wasmtime.Caller, respptr int32, resplenptr int32, reqptr int32, reqptrlen int32) int32 {
const errFetchSfx = "error calling fetch"
return func(caller *wasmtime.Caller, respptr int32, resplenptr int32, reqptr int32, reqptrlen int32) int32 {
const errFetchSfx = "error calling fetch"

b, innerErr := reader(caller, reqptr, reqptrlen)
if innerErr != nil {
logger.Errorf("%s: %s", errFetchSfx, innerErr)
Expand All @@ -451,12 +460,19 @@ func createFetchFn(
return ErrnoFault
}

storedRequest, innerErr := store.get(req.Id)
storedRequest, innerErr := requestStore.get(req.Id)
if innerErr != nil {
logger.Errorf("%s: %s", errFetchSfx, innerErr)
return ErrnoFault
}

// limit the number of fetch calls we can make per request
if storedRequest.fetchRequestsCounter >= modCfg.MaxFetchRequests {
logger.Errorf("%s: max number of fetch request %d exceeded", errFetchSfx, modCfg.MaxFetchRequests)
return ErrnoFault
}
storedRequest.fetchRequestsCounter++

fetchResp, innerErr := storedRequest.callWithCtx(func(ctx context.Context) (*wasmpb.FetchResponse, error) {
if ctx == nil {
return nil, errors.New("context is nil")
Expand Down
1 change: 1 addition & 0 deletions pkg/workflows/wasm/host/module_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -227,6 +227,7 @@ func TestCreateFetchFn(t *testing.T) {
Fetch: func(ctx context.Context, req *wasmpb.FetchRequest) (*wasmpb.FetchResponse, error) {
return &wasmpb.FetchResponse{}, nil
},
MaxFetchRequests: 5,
},
store,
)
Expand Down
50 changes: 50 additions & 0 deletions pkg/workflows/wasm/host/test/fetchlimit/cmd/main.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
//go:build wasip1

package main

import (
"net/http"

"github.com/smartcontractkit/chainlink-common/pkg/workflows/wasm"

"github.com/smartcontractkit/chainlink-common/pkg/capabilities/cli/cmd/testdata/fixtures/capabilities/basictrigger"
"github.com/smartcontractkit/chainlink-common/pkg/workflows/sdk"
)

func BuildWorkflow(config []byte) *sdk.WorkflowSpecFactory {
workflow := sdk.NewWorkflowSpecFactory(
sdk.NewWorkflowParams{
Name: "tester",
Owner: "ryan",
},
)

triggerCfg := basictrigger.TriggerConfig{Name: "trigger", Number: 100}
trigger := triggerCfg.New(workflow)

sdk.Compute1[basictrigger.TriggerOutputs, bool](
workflow,
"transform",
sdk.Compute1Inputs[basictrigger.TriggerOutputs]{Arg0: trigger},
func(rsdk sdk.Runtime, outputs basictrigger.TriggerOutputs) (bool, error) {

for i := 0; i < 6; i++ {
_, err := rsdk.Fetch(sdk.FetchRequest{
Method: http.MethodGet,
URL: "https://min-api.cryptocompare.com/data/pricemultifull?fsyms=ETH&tsyms=BTC",
})
if err != nil {
return false, err
}
}

return true, nil
})

return workflow
}
func main() {
runner := wasm.NewRunner()
workflow := BuildWorkflow(runner.Config())
runner.Run(workflow)
}
Loading

0 comments on commit fcae9bd

Please sign in to comment.