From b70efb4e0d503fe9bc9dc865aaff02a09ebea264 Mon Sep 17 00:00:00 2001 From: Jimmy Moore Date: Mon, 18 Sep 2023 14:11:46 +0100 Subject: [PATCH] Extensions now get reset at start of a run Signed-off-by: Jimmy Moore --- config.go | 13 ++----- extension/extension.go | 5 +++ extension/generator/templates/host.go.templ | 24 ++++++++++--- instance.go | 2 ++ scale.go | 38 +++++++++++++-------- 5 files changed, 53 insertions(+), 29 deletions(-) diff --git a/config.go b/config.go index d66d2d68..2161a4e3 100644 --- a/config.go +++ b/config.go @@ -50,8 +50,7 @@ type Config[T interfaces.Signature] struct { context context.Context Stdout io.Writer Stderr io.Writer - - extensions map[string]extension.InstallableFunc + extensions []extension.HostExtension } // NewConfig returns a new Scale Runtime Config @@ -87,14 +86,8 @@ func (c *Config[T]) validate() error { return nil } -func (c *Config[T]) WithExtensions(e map[string]extension.InstallableFunc) *Config[T] { - if c.extensions == nil { - c.extensions = make(map[string]extension.InstallableFunc) - } - for n, f := range e { - // TODO: Check for stomping etc - c.extensions[n] = f - } +func (c *Config[T]) WithExtension(e extension.HostExtension) *Config[T] { + c.extensions = append(c.extensions, e) return c } diff --git a/extension/extension.go b/extension/extension.go index 729aec90..f1612467 100644 --- a/extension/extension.go +++ b/extension/extension.go @@ -8,3 +8,8 @@ type ModuleMemory interface { type Resizer func(name string, size uint64) (uint64, error) type InstallableFunc func(mem ModuleMemory, resize Resizer, params []uint64) + +type HostExtension interface { + Init() map[string]InstallableFunc + Reset() +} diff --git a/extension/generator/templates/host.go.templ b/extension/generator/templates/host.go.templ index c9404553..9af83209 100644 --- a/extension/generator/templates/host.go.templ +++ b/extension/generator/templates/host.go.templ @@ -8,7 +8,6 @@ import ( "errors" "sync/atomic" "sync" - "github.com/loopholelabs/polyglot" "github.com/loopholelabs/scale/extension" @@ -34,8 +33,23 @@ func hostError(mem extension.ModuleMemory, resize extension.Resizer, err error) {{ $schema := .schema }} -func InstallExtension(impl {{ .schema.Name }}Ifc) map[string]extension.InstallableFunc { +type hostExt struct { + functions map[string]extension.InstallableFunc + host *{{ .schema.Name }}Host +} + +func (he *hostExt) Init() map[string]extension.InstallableFunc { + return he.functions +} + +func (he *hostExt) Reset() { + // Reset any instances that have been created. + {{ range $ifc := .schema.Interfaces }} + he.host.instances_{{ $ifc.Name }} = make(map[uint64]{{ $ifc.Name }}) + {{ end }} +} +func New(impl {{ .schema.Name }}Ifc) extension.HostExtension { hostWrapper := &{{ .schema.Name }}Host{ impl: impl } fns := make(map[string]extension.InstallableFunc) @@ -43,7 +57,6 @@ func InstallExtension(impl {{ .schema.Name }}Ifc) map[string]extension.Installab // Add global functions to the runtime {{ range $fn := .schema.Functions }} fns["ext_{{ $schema.Name }}_{{ $fn.Name }}"] = hostWrapper.host_ext_{{ $schema.Name }}_{{ $fn.Name }} - {{ end }} {{ range $ifc := .schema.Interfaces }} @@ -56,7 +69,10 @@ func InstallExtension(impl {{ .schema.Name }}Ifc) map[string]extension.Installab {{ end }} {{ end }} - return fns + return &hostExt{ + functions: fns, + host: hostWrapper, + } } type {{ .schema.Name }}Host struct { diff --git a/instance.go b/instance.go index 11083116..778c6aa3 100644 --- a/instance.go +++ b/instance.go @@ -80,6 +80,8 @@ func newInstance[T interfaces.Signature](ctx context.Context, runtime *Scale[T], } func (i *Instance[T]) Run(ctx context.Context, signature T) error { + i.runtime.resetExtensions() + m, err := i.head.getModule(signature) if err != nil { return fmt.Errorf("failed to get module for function '%s': %w", i.head.template.identifier, err) diff --git a/scale.go b/scale.go index 6e5262fb..ac03bd2b 100644 --- a/scale.go +++ b/scale.go @@ -70,6 +70,13 @@ func (r *Scale[T]) Instance(next ...Next[T]) (*Instance[T], error) { return newInstance(r.config.context, r, next...) } +// Reset any extensions between executions. +func (r *Scale[T]) resetExtensions() { + for _, ext := range r.config.extensions { + ext.Reset() + } +} + func (r *Scale[T]) init() error { err := r.config.validate() if err != nil { @@ -87,23 +94,24 @@ func (r *Scale[T]) init() error { envModule := r.runtime.NewHostModuleBuilder("env") // Install any extensions... - for name, fn := range r.config.extensions { - fmt.Printf("Installing module [%s]\n", name) - wfn := func(n string, f extension.InstallableFunc) func(context.Context, api.Module, []uint64) { - return func(ctx context.Context, mod api.Module, params []uint64) { - fmt.Printf("HOST FUNCTION CALLED %s\n", n) - mem := mod.Memory() - resize := func(name string, size uint64) (uint64, error) { - w, err := mod.ExportedFunction(name).Call(context.Background(), size) - return w[0], err + for _, ext := range r.config.extensions { + fns := ext.Init() + for name, fn := range fns { + wfn := func(n string, f extension.InstallableFunc) func(context.Context, api.Module, []uint64) { + return func(ctx context.Context, mod api.Module, params []uint64) { + mem := mod.Memory() + resize := func(name string, size uint64) (uint64, error) { + w, err := mod.ExportedFunction(name).Call(context.Background(), size) + return w[0], err + } + f(mem, resize, params) } - f(mem, resize, params) - } - }(name, fn) + }(name, fn) - envModule.NewFunctionBuilder(). - WithGoModuleFunction(api.GoModuleFunc(wfn), []api.ValueType{api.ValueTypeI64, api.ValueTypeI32, api.ValueTypeI32}, []api.ValueType{api.ValueTypeI64}). - WithParameterNames("instance", "pointer", "length").Export(name) + envModule.NewFunctionBuilder(). + WithGoModuleFunction(api.GoModuleFunc(wfn), []api.ValueType{api.ValueTypeI64, api.ValueTypeI32, api.ValueTypeI32}, []api.ValueType{api.ValueTypeI64}). + WithParameterNames("instance", "pointer", "length").Export(name) + } } envHostModuleBuilder := envModule.