Skip to content

Commit

Permalink
Extensions now get reset at start of a run
Browse files Browse the repository at this point in the history
Signed-off-by: Jimmy Moore <[email protected]>
  • Loading branch information
jimmyaxod committed Sep 19, 2023
1 parent b387cb8 commit b70efb4
Show file tree
Hide file tree
Showing 5 changed files with 53 additions and 29 deletions.
13 changes: 3 additions & 10 deletions config.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,7 @@ type Config[T interfaces.Signature] struct {
context context.Context
Stdout io.Writer
Stderr io.Writer

extensions map[string]extension.InstallableFunc
extensions []extension.HostExtension
}

// NewConfig returns a new Scale Runtime Config
Expand Down Expand Up @@ -87,14 +86,8 @@ func (c *Config[T]) validate() error {
return nil
}

func (c *Config[T]) WithExtensions(e map[string]extension.InstallableFunc) *Config[T] {
if c.extensions == nil {
c.extensions = make(map[string]extension.InstallableFunc)
}
for n, f := range e {
// TODO: Check for stomping etc
c.extensions[n] = f
}
func (c *Config[T]) WithExtension(e extension.HostExtension) *Config[T] {
c.extensions = append(c.extensions, e)
return c
}

Expand Down
5 changes: 5 additions & 0 deletions extension/extension.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,8 @@ type ModuleMemory interface {
type Resizer func(name string, size uint64) (uint64, error)

type InstallableFunc func(mem ModuleMemory, resize Resizer, params []uint64)

type HostExtension interface {
Init() map[string]InstallableFunc
Reset()
}
24 changes: 20 additions & 4 deletions extension/generator/templates/host.go.templ
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ import (
"errors"
"sync/atomic"
"sync"

"github.com/loopholelabs/polyglot"

"github.com/loopholelabs/scale/extension"
Expand All @@ -34,16 +33,30 @@ func hostError(mem extension.ModuleMemory, resize extension.Resizer, err error)

{{ $schema := .schema }}

func InstallExtension(impl {{ .schema.Name }}Ifc) map[string]extension.InstallableFunc {
type hostExt struct {
functions map[string]extension.InstallableFunc
host *{{ .schema.Name }}Host
}

func (he *hostExt) Init() map[string]extension.InstallableFunc {
return he.functions
}

func (he *hostExt) Reset() {
// Reset any instances that have been created.
{{ range $ifc := .schema.Interfaces }}
he.host.instances_{{ $ifc.Name }} = make(map[uint64]{{ $ifc.Name }})
{{ end }}
}

func New(impl {{ .schema.Name }}Ifc) extension.HostExtension {
hostWrapper := &{{ .schema.Name }}Host{ impl: impl }

fns := make(map[string]extension.InstallableFunc)

// Add global functions to the runtime
{{ range $fn := .schema.Functions }}
fns["ext_{{ $schema.Name }}_{{ $fn.Name }}"] = hostWrapper.host_ext_{{ $schema.Name }}_{{ $fn.Name }}

{{ end }}

{{ range $ifc := .schema.Interfaces }}
Expand All @@ -56,7 +69,10 @@ func InstallExtension(impl {{ .schema.Name }}Ifc) map[string]extension.Installab
{{ end }}
{{ end }}

return fns
return &hostExt{
functions: fns,
host: hostWrapper,
}
}

type {{ .schema.Name }}Host struct {
Expand Down
2 changes: 2 additions & 0 deletions instance.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,8 @@ func newInstance[T interfaces.Signature](ctx context.Context, runtime *Scale[T],
}

func (i *Instance[T]) Run(ctx context.Context, signature T) error {
i.runtime.resetExtensions()

m, err := i.head.getModule(signature)
if err != nil {
return fmt.Errorf("failed to get module for function '%s': %w", i.head.template.identifier, err)
Expand Down
38 changes: 23 additions & 15 deletions scale.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,13 @@ func (r *Scale[T]) Instance(next ...Next[T]) (*Instance[T], error) {
return newInstance(r.config.context, r, next...)
}

// Reset any extensions between executions.
func (r *Scale[T]) resetExtensions() {
for _, ext := range r.config.extensions {
ext.Reset()
}
}

func (r *Scale[T]) init() error {
err := r.config.validate()
if err != nil {
Expand All @@ -87,23 +94,24 @@ func (r *Scale[T]) init() error {
envModule := r.runtime.NewHostModuleBuilder("env")

// Install any extensions...
for name, fn := range r.config.extensions {
fmt.Printf("Installing module [%s]\n", name)
wfn := func(n string, f extension.InstallableFunc) func(context.Context, api.Module, []uint64) {
return func(ctx context.Context, mod api.Module, params []uint64) {
fmt.Printf("HOST FUNCTION CALLED %s\n", n)
mem := mod.Memory()
resize := func(name string, size uint64) (uint64, error) {
w, err := mod.ExportedFunction(name).Call(context.Background(), size)
return w[0], err
for _, ext := range r.config.extensions {
fns := ext.Init()
for name, fn := range fns {
wfn := func(n string, f extension.InstallableFunc) func(context.Context, api.Module, []uint64) {
return func(ctx context.Context, mod api.Module, params []uint64) {
mem := mod.Memory()
resize := func(name string, size uint64) (uint64, error) {
w, err := mod.ExportedFunction(name).Call(context.Background(), size)
return w[0], err
}
f(mem, resize, params)
}
f(mem, resize, params)
}
}(name, fn)
}(name, fn)

envModule.NewFunctionBuilder().
WithGoModuleFunction(api.GoModuleFunc(wfn), []api.ValueType{api.ValueTypeI64, api.ValueTypeI32, api.ValueTypeI32}, []api.ValueType{api.ValueTypeI64}).
WithParameterNames("instance", "pointer", "length").Export(name)
envModule.NewFunctionBuilder().
WithGoModuleFunction(api.GoModuleFunc(wfn), []api.ValueType{api.ValueTypeI64, api.ValueTypeI32, api.ValueTypeI32}, []api.ValueType{api.ValueTypeI64}).
WithParameterNames("instance", "pointer", "length").Export(name)
}
}

envHostModuleBuilder := envModule.
Expand Down

0 comments on commit b70efb4

Please sign in to comment.