diff --git a/cmd/bench/main.go b/cmd/bench/main.go index b546f0f9..5dad190c 100644 --- a/cmd/bench/main.go +++ b/cmd/bench/main.go @@ -207,7 +207,7 @@ func runVM(bytecode *compiler.Bytecode) (time.Duration, objects.Object, error) { start := time.Now() - v := runtime.NewVM(bytecode, globals, nil) + v := runtime.NewVM(bytecode, globals, nil, nil) if err := v.Run(); err != nil { return time.Since(start), nil, err } diff --git a/cmd/tengo/main.go b/cmd/tengo/main.go index 635a2044..3ff57c25 100644 --- a/cmd/tengo/main.go +++ b/cmd/tengo/main.go @@ -148,7 +148,7 @@ func compileAndRun(data []byte, inputFile string) (err error) { return } - machine := runtime.NewVM(bytecode, nil, nil) + machine := runtime.NewVM(bytecode, nil, nil, nil) err = machine.Run() if err != nil { @@ -165,7 +165,7 @@ func runCompiled(data []byte) (err error) { return } - machine := runtime.NewVM(bytecode, nil, nil) + machine := runtime.NewVM(bytecode, nil, nil, nil) err = machine.Run() if err != nil { @@ -216,7 +216,7 @@ func runREPL(in io.Reader, out io.Writer) { bytecode := c.Bytecode() - machine := runtime.NewVM(bytecode, globals, nil) + machine := runtime.NewVM(bytecode, globals, nil, nil) if err := machine.Run(); err != nil { _, _ = fmt.Fprintln(out, err.Error()) continue diff --git a/docs/interoperability.md b/docs/interoperability.md index 547f93fc..6a6dc84d 100644 --- a/docs/interoperability.md +++ b/docs/interoperability.md @@ -118,33 +118,37 @@ Users can add and use a custom user type in Tengo code by implementing [Object]( To securely compile and execute _potentially_ unsafe script code, you can use the following Script functions. -#### Script.DisableBuiltinFunction(name string) +#### Script.SetBuiltinFunctions(funcs []*objects.BuiltinFunction) -DisableBuiltinFunction disables and removes a builtin function from the compiler. Compiler will reports a compile-time error if the given name is referenced. +SetBuiltinFunctions resets all builtin functions in the compiler to the ones provided in the input parameter. Compiler will report a compile-time error if the a function not set is referenced. All builtin functions are included by default unless `SetBuiltinFunctions` is called. ```golang s := script.New([]byte(`print([1, 2, 3])`)) -s.DisableBuiltinFunction("print") +s.SetBuiltinFunctions(nil) -_, err := s.Run() // compile error -``` +_, err := s.Run() // compile error + +s.SetBuiltinFunctions([]*objects.BuiltinFunction{&objects.Builtins[0]}) -Note that when a script is being added to another script as a module (via `Script.AddModule`), it does not inherit the disabled builtin function list from the main script. +_, err := s.Run() // prints [1, 2, 3] +``` -#### Script.DisableStdModule(name string) +#### Script.SetBuiltinModules(modules map[string]*objects.ImmutableMap) -DisableStdModule disables a [standard library](https://github.com/d5/tengo/blob/master/docs/stdlib.md) module. Compile will report a compile-time error if the code tries to import the module with the given name. +SetBuiltinModules resets all [standard library](https://github.com/d5/tengo/blob/master/docs/stdlib.md) modules with modules provided in the input parameter. Compile will report a compile-time error if the code tries to import a module that hasn't been included. All standard library modules are included by default unless `SetBuiltinModules` is called. ```golang -s := script.New([]byte(`import("exec")`)) +s := script.New([]byte(`math := import("math"); a := math.abs(-19.84)`)) -s.DisableStdModule("exec") +s.SetBuiltinModules(nil) -_, err := s.Run() // compile error -``` +_, err := s.Run() // compile error + +s.SetBuiltinModules(map[string]*objects.ImmutableMap{"math": objectPtr(*stdlib.Modules["math"])}) -Note that when a script is being added to another script as a module (via `Script.AddModule`), it does not inherit the disabled standard module list from the main script. +_, err := s.Run() // a = 19.84 +``` #### Script.SetUserModuleLoader(loader compiler.ModuleLoader) diff --git a/objects/builtins.go b/objects/builtins.go index 67553932..c8b2a372 100644 --- a/objects/builtins.go +++ b/objects/builtins.go @@ -1,135 +1,129 @@ package objects -// NamedBuiltinFunc is a named builtin function. -type NamedBuiltinFunc struct { - Name string - Func CallableFunc -} - // Builtins contains all default builtin functions. -var Builtins = []NamedBuiltinFunc{ +var Builtins = []BuiltinFunction{ { - Name: "print", - Func: builtinPrint, + Name: "print", + Value: builtinPrint, }, { - Name: "printf", - Func: builtinPrintf, + Name: "printf", + Value: builtinPrintf, }, { - Name: "sprintf", - Func: builtinSprintf, + Name: "sprintf", + Value: builtinSprintf, }, { - Name: "len", - Func: builtinLen, + Name: "len", + Value: builtinLen, }, { - Name: "copy", - Func: builtinCopy, + Name: "copy", + Value: builtinCopy, }, { - Name: "append", - Func: builtinAppend, + Name: "append", + Value: builtinAppend, }, { - Name: "string", - Func: builtinString, + Name: "string", + Value: builtinString, }, { - Name: "int", - Func: builtinInt, + Name: "int", + Value: builtinInt, }, { - Name: "bool", - Func: builtinBool, + Name: "bool", + Value: builtinBool, }, { - Name: "float", - Func: builtinFloat, + Name: "float", + Value: builtinFloat, }, { - Name: "char", - Func: builtinChar, + Name: "char", + Value: builtinChar, }, { - Name: "bytes", - Func: builtinBytes, + Name: "bytes", + Value: builtinBytes, }, { - Name: "time", - Func: builtinTime, + Name: "time", + Value: builtinTime, }, { - Name: "is_int", - Func: builtinIsInt, + Name: "is_int", + Value: builtinIsInt, }, { - Name: "is_float", - Func: builtinIsFloat, + Name: "is_float", + Value: builtinIsFloat, }, { - Name: "is_string", - Func: builtinIsString, + Name: "is_string", + Value: builtinIsString, }, { - Name: "is_bool", - Func: builtinIsBool, + Name: "is_bool", + Value: builtinIsBool, }, { - Name: "is_char", - Func: builtinIsChar, + Name: "is_char", + Value: builtinIsChar, }, { - Name: "is_bytes", - Func: builtinIsBytes, + Name: "is_bytes", + Value: builtinIsBytes, }, { - Name: "is_array", - Func: builtinIsArray, + Name: "is_array", + Value: builtinIsArray, }, { - Name: "is_immutable_array", - Func: builtinIsImmutableArray, + Name: "is_immutable_array", + Value: builtinIsImmutableArray, }, { - Name: "is_map", - Func: builtinIsMap, + Name: "is_map", + Value: builtinIsMap, }, { - Name: "is_immutable_map", - Func: builtinIsImmutableMap, + Name: "is_immutable_map", + Value: builtinIsImmutableMap, }, { - Name: "is_time", - Func: builtinIsTime, + Name: "is_time", + Value: builtinIsTime, }, { - Name: "is_error", - Func: builtinIsError, + Name: "is_error", + Value: builtinIsError, }, { - Name: "is_undefined", - Func: builtinIsUndefined, + Name: "is_undefined", + Value: builtinIsUndefined, }, { - Name: "is_function", - Func: builtinIsFunction, + Name: "is_function", + Value: builtinIsFunction, }, { - Name: "is_callable", - Func: builtinIsCallable, + Name: "is_callable", + Value: builtinIsCallable, }, { - Name: "to_json", - Func: builtinToJSON, + Name: "to_json", + Value: builtinToJSON, }, { - Name: "from_json", - Func: builtinFromJSON, + Name: "from_json", + Value: builtinFromJSON, }, { - Name: "type_name", - Func: builtinTypeName, + Name: "type_name", + Value: builtinTypeName, }, } diff --git a/runtime/vm.go b/runtime/vm.go index 2708fde7..6e586a98 100644 --- a/runtime/vm.go +++ b/runtime/vm.go @@ -26,7 +26,6 @@ var ( truePtr = &objects.TrueValue falsePtr = &objects.FalseValue undefinedPtr = &objects.UndefinedValue - builtinFuncs []objects.Object ) // VM is a virtual machine that executes the bytecode compiled by Compiler. @@ -43,11 +42,12 @@ type VM struct { curIPLimit int ip int aborting int64 + builtinFuncs []objects.Object builtinModules map[string]*objects.Object } // NewVM creates a VM. -func NewVM(bytecode *compiler.Bytecode, globals []*objects.Object, builtinModules map[string]*objects.Object) *VM { +func NewVM(bytecode *compiler.Bytecode, globals []*objects.Object, builtinFuncs []objects.Object, builtinModules map[string]*objects.Object) *VM { if globals == nil { globals = make([]*objects.Object, GlobalsSize) } @@ -56,6 +56,16 @@ func NewVM(bytecode *compiler.Bytecode, globals []*objects.Object, builtinModule builtinModules = stdlib.Modules } + if builtinFuncs == nil { + builtinFuncs = make([]objects.Object, len(objects.Builtins)) + for idx, fn := range objects.Builtins { + builtinFuncs[idx] = &objects.BuiltinFunction{ + Name: fn.Name, + Value: fn.Value, + } + } + } + frames := make([]Frame, MaxFrames) frames[0].fn = bytecode.MainFunction frames[0].freeVars = nil @@ -74,6 +84,7 @@ func NewVM(bytecode *compiler.Bytecode, globals []*objects.Object, builtinModule curInsts: frames[0].fn.Instructions, curIPLimit: len(frames[0].fn.Instructions) - 1, ip: -1, + builtinFuncs: builtinFuncs, builtinModules: builtinModules, } } @@ -1183,7 +1194,7 @@ mainloop: break mainloop } - v.stack[v.sp] = &builtinFuncs[builtinIndex] + v.stack[v.sp] = &v.builtinFuncs[builtinIndex] v.sp++ case compiler.OpGetBuiltinModule: @@ -1412,13 +1423,3 @@ func indexAssign(dst, src *objects.Object, selectors []*objects.Object) error { return nil } - -func init() { - builtinFuncs = make([]objects.Object, len(objects.Builtins)) - for i, b := range objects.Builtins { - builtinFuncs[i] = &objects.BuiltinFunction{ - Name: b.Name, - Value: b.Func, - } - } -} diff --git a/runtime/vm_test.go b/runtime/vm_test.go index b2aa19f0..c058938d 100644 --- a/runtime/vm_test.go +++ b/runtime/vm_test.go @@ -240,7 +240,7 @@ func traceCompileRun(file *ast.File, symbols map[string]objects.Object, userModu trace = append(trace, fmt.Sprintf("\n[Compiled Constants]\n\n%s", strings.Join(bytecode.FormatConstants(), "\n"))) trace = append(trace, fmt.Sprintf("\n[Compiled Instructions]\n\n%s\n", strings.Join(bytecode.FormatInstructions(), "\n"))) - v = runtime.NewVM(bytecode, globals, nil) + v = runtime.NewVM(bytecode, globals, nil, nil) err = v.Run() { diff --git a/script/script.go b/script/script.go index d084551b..6529c119 100644 --- a/script/script.go +++ b/script/script.go @@ -14,11 +14,11 @@ import ( // Script can simplify compilation and execution of embedded scripts. type Script struct { - variables map[string]*Variable - removedBuiltins map[string]bool - removedStdModules map[string]bool - userModuleLoader compiler.ModuleLoader - input []byte + variables map[string]*Variable + builtinFuncs []objects.Object + builtinModules map[string]*objects.Object + userModuleLoader compiler.ModuleLoader + input []byte } // New creates a Script instance with an input script. @@ -56,22 +56,28 @@ func (s *Script) Remove(name string) bool { return true } -// DisableBuiltinFunction disables a builtin function. -func (s *Script) DisableBuiltinFunction(name string) { - if s.removedBuiltins == nil { - s.removedBuiltins = make(map[string]bool) +// SetBuiltinFunctions allows to define builtin functions. +func (s *Script) SetBuiltinFunctions(funcs []*objects.BuiltinFunction) { + if funcs != nil { + s.builtinFuncs = make([]objects.Object, len(funcs)) + for idx, fn := range funcs { + s.builtinFuncs[idx] = fn + } + } else { + s.builtinFuncs = []objects.Object{} } - - s.removedBuiltins[name] = true } -// DisableStdModule disables a standard library module. -func (s *Script) DisableStdModule(name string) { - if s.removedStdModules == nil { - s.removedStdModules = make(map[string]bool) +// SetBuiltinModules allows to define builtin modules. +func (s *Script) SetBuiltinModules(modules map[string]*objects.ImmutableMap) { + if modules != nil { + s.builtinModules = make(map[string]*objects.Object, len(modules)) + for k, mod := range modules { + s.builtinModules[k] = objectPtr(mod) + } + } else { + s.builtinModules = map[string]*objects.Object{} } - - s.removedStdModules[name] = true } // SetUserModuleLoader sets the user module loader for the compiler. @@ -81,7 +87,7 @@ func (s *Script) SetUserModuleLoader(loader compiler.ModuleLoader) { // Compile compiles the script with all the defined variables, and, returns Compiled object. func (s *Script) Compile() (*Compiled, error) { - symbolTable, stdModules, globals, err := s.prepCompile() + symbolTable, builtinModules, globals, err := s.prepCompile() if err != nil { return nil, err } @@ -95,7 +101,7 @@ func (s *Script) Compile() (*Compiled, error) { return nil, err } - c := compiler.NewCompiler(srcFile, symbolTable, nil, stdModules, nil) + c := compiler.NewCompiler(srcFile, symbolTable, nil, builtinModules, nil) if s.userModuleLoader != nil { c.SetModuleLoader(s.userModuleLoader) @@ -107,7 +113,7 @@ func (s *Script) Compile() (*Compiled, error) { return &Compiled{ symbolTable: symbolTable, - machine: runtime.NewVM(c.Bytecode(), globals, nil), + machine: runtime.NewVM(c.Bytecode(), globals, s.builtinFuncs, s.builtinModules), }, nil } @@ -136,24 +142,36 @@ func (s *Script) RunContext(ctx context.Context) (compiled *Compiled, err error) return } -func (s *Script) prepCompile() (symbolTable *compiler.SymbolTable, stdModules map[string]bool, globals []*objects.Object, err error) { +func (s *Script) prepCompile() (symbolTable *compiler.SymbolTable, builtinModules map[string]bool, globals []*objects.Object, err error) { var names []string for name := range s.variables { names = append(names, name) } symbolTable = compiler.NewSymbolTable() - for idx, fn := range objects.Builtins { - if !s.removedBuiltins[fn.Name] { - symbolTable.DefineBuiltin(idx, fn.Name) + + if s.builtinFuncs == nil { + s.builtinFuncs = make([]objects.Object, len(objects.Builtins)) + for idx, fn := range objects.Builtins { + s.builtinFuncs[idx] = &objects.BuiltinFunction{ + Name: fn.Name, + Value: fn.Value, + } } } - stdModules = make(map[string]bool) - for name := range stdlib.Modules { - if !s.removedStdModules[name] { - stdModules[name] = true - } + if s.builtinModules == nil { + s.builtinModules = stdlib.Modules + } + + for idx, fn := range s.builtinFuncs { + f := fn.(*objects.BuiltinFunction) + symbolTable.DefineBuiltin(idx, f.Name) + } + + builtinModules = make(map[string]bool) + for name := range s.builtinModules { + builtinModules[name] = true } globals = make([]*objects.Object, runtime.GlobalsSize, runtime.GlobalsSize) @@ -178,3 +196,7 @@ func (s *Script) copyVariables() map[string]*Variable { return vars } + +func objectPtr(o objects.Object) *objects.Object { + return &o +} diff --git a/script/script_module_test.go b/script/script_module_test.go index b9793345..71a026e7 100644 --- a/script/script_module_test.go +++ b/script/script_module_test.go @@ -34,7 +34,7 @@ func TestScript_SetUserModuleLoader(t *testing.T) { c, err = scr.Run() assert.NoError(t, err) assert.Equal(t, int64(3), c.Get("out").Value()) - scr.DisableBuiltinFunction("len") + scr.SetBuiltinFunctions(nil) _, err = scr.Run() assert.Error(t, err) @@ -49,7 +49,7 @@ func TestScript_SetUserModuleLoader(t *testing.T) { c, err = scr.Run() assert.NoError(t, err) assert.Equal(t, "Foo", c.Get("out").Value()) - scr.DisableStdModule("text") + scr.SetBuiltinModules(nil) _, err = scr.Run() assert.Error(t, err) diff --git a/script/script_test.go b/script/script_test.go index ee2b7a99..b9fa09c9 100644 --- a/script/script_test.go +++ b/script/script_test.go @@ -4,7 +4,9 @@ import ( "testing" "github.com/d5/tengo/assert" + "github.com/d5/tengo/objects" "github.com/d5/tengo/script" + "github.com/d5/tengo/stdlib" ) func TestScript_Add(t *testing.T) { @@ -37,24 +39,51 @@ func TestScript_Run(t *testing.T) { compiledGet(t, c, "a", int64(5)) } -func TestScript_DisableBuiltinFunction(t *testing.T) { +func TestScript_SetBuiltinFunctions(t *testing.T) { s := script.New([]byte(`a := len([1, 2, 3])`)) c, err := s.Run() assert.NoError(t, err) assert.NotNil(t, c) compiledGet(t, c, "a", int64(3)) - s.DisableBuiltinFunction("len") + + s = script.New([]byte(`a := len([1, 2, 3])`)) + s.SetBuiltinFunctions([]*objects.BuiltinFunction{&objects.Builtins[3]}) + c, err = s.Run() + assert.NoError(t, err) + assert.NotNil(t, c) + compiledGet(t, c, "a", int64(3)) + + s.SetBuiltinFunctions([]*objects.BuiltinFunction{&objects.Builtins[0]}) + _, err = s.Run() + assert.Error(t, err) + + s.SetBuiltinFunctions(nil) _, err = s.Run() assert.Error(t, err) } -func TestScript_DisableStdModule(t *testing.T) { +func TestScript_SetBuiltinModules(t *testing.T) { s := script.New([]byte(`math := import("math"); a := math.abs(-19.84)`)) c, err := s.Run() assert.NoError(t, err) assert.NotNil(t, c) compiledGet(t, c, "a", 19.84) - s.DisableStdModule("math") + + s.SetBuiltinModules(map[string]*objects.ImmutableMap{"math": objectPtr(*stdlib.Modules["math"])}) + c, err = s.Run() + assert.NoError(t, err) + assert.NotNil(t, c) + compiledGet(t, c, "a", 19.84) + + s.SetBuiltinModules(map[string]*objects.ImmutableMap{"os": objectPtr(*stdlib.Modules["os"])}) _, err = s.Run() assert.Error(t, err) + + s.SetBuiltinModules(nil) + _, err = s.Run() + assert.Error(t, err) +} + +func objectPtr(o objects.Object) *objects.ImmutableMap { + return o.(*objects.ImmutableMap) }