From 13adf4ca4df4929c9838eb16372c11539fbe156e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marco=20Ant=C3=B4nio?= Date: Mon, 20 May 2024 13:06:28 -0300 Subject: [PATCH 1/3] fea: reuse @jtolio implementation - Base PR: https://github.com/Shopify/go-lua/pull/43/ --- auxiliary.go | 6 +++--- base.go | 17 ++++++++++++----- go.mod | 8 ++++++++ io.go | 6 +++--- lua.go | 13 ++++++++++++- 5 files changed, 38 insertions(+), 12 deletions(-) diff --git a/auxiliary.go b/auxiliary.go index 78dc064..60c3a28 100644 --- a/auxiliary.go +++ b/auxiliary.go @@ -492,7 +492,7 @@ func LoadFile(l *State, fileName, mode string) error { } if fileName == "" { l.PushString("=stdin") - f = os.Stdin + f = l.stdin } else { l.PushString("@" + fileName) var err error @@ -509,7 +509,7 @@ func LoadFile(l *State, fileName, mode string) error { } s, _ := l.ToString(-1) err := l.Load(r, s, mode) - if f != os.Stdin { + if f != l.stdin { _ = f.Close() } switch err { @@ -538,7 +538,7 @@ func NewStateEx() *State { if l != nil { _ = AtPanic(l, func(l *State) int { s, _ := l.ToString(-1) - fmt.Fprintf(os.Stderr, "PANIC: unprotected error in call to Lua API (%s)\n", s) + fmt.Fprintf(l.stderr, "PANIC: unprotected error in call to Lua API (%s)\n", s) return 0 }) } diff --git a/base.go b/base.go index 2cffec2..e352c5e 100644 --- a/base.go +++ b/base.go @@ -2,7 +2,6 @@ package lua import ( "io" - "os" "runtime" "strconv" "strings" @@ -215,13 +214,13 @@ var baseLibrary = []RegistryFunction{ panic("unreachable") } if i > 1 { - os.Stdout.WriteString("\t") + l.writeToStdout("\t") } - os.Stdout.WriteString(s) + l.writeToStdout(s) l.Pop(1) // pop result } - os.Stdout.WriteString("\n") - os.Stdout.Sync() + l.writeToStdout("\n") + l.stdout.Sync() return 0 }}, {"rawequal", func(l *State) int { @@ -328,3 +327,11 @@ func BaseOpen(l *State) int { l.SetField(-2, "_VERSION") return 1 } + +func (l *State) writeToStdout(s string) { + _, err := l.stdout.WriteString(s) + if err != nil { + Errorf(l, "failed writing to stdout: %v", err) + panic("unreachable") + } +} diff --git a/go.mod b/go.mod index 34db174..50b4f06 100644 --- a/go.mod +++ b/go.mod @@ -1,3 +1,11 @@ module github.com/Shopify/go-lua go 1.22 + +require github.com/stretchr/testify v1.9.0 + +require ( + github.com/davecgh/go-spew v1.1.1 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect +) diff --git a/io.go b/io.go index 2bacd86..5bb5502 100644 --- a/io.go +++ b/io.go @@ -316,9 +316,9 @@ func IOOpen(l *State) int { SetFunctions(l, fileHandleMethods, 0) l.Pop(1) - registerStdFile(l, os.Stdin, input, "stdin") - registerStdFile(l, os.Stdout, output, "stdout") - registerStdFile(l, os.Stderr, "", "stderr") + registerStdFile(l, l.stdin, input, "stdin") + registerStdFile(l, l.stdout, output, "stdout") + registerStdFile(l, l.stderr, "", "stderr") return 1 } diff --git a/lua.go b/lua.go index 68514e8..999c3f2 100644 --- a/lua.go +++ b/lua.go @@ -5,6 +5,7 @@ import ( "fmt" "io" "math" + "os" "strings" ) @@ -239,6 +240,7 @@ type State struct { errorFunction int // current error handling function (stack index) baseCallInfo callInfo // callInfo for first level (go calling lua) protectFunction func() + stdout, stderr, stdin *os.File } type globalState struct { @@ -455,7 +457,7 @@ func (l *State) Dump(w io.Writer) error { // http://www.lua.org/manual/5.2/manual.html#lua_newstate func NewState() *State { v := float64(VersionNumber) - l := &State{allowHook: true, error: nil, nonYieldableCallCount: 1} + l := &State{allowHook: true, error: nil, nonYieldableCallCount: 1, stdout: os.Stdout, stderr: os.Stderr, stdin: os.Stdin} g := &globalState{mainThread: l, registry: newTable(), version: &v, memoryErrorMessage: "not enough memory"} l.global = g l.initializeStack() @@ -1529,3 +1531,12 @@ func (l *State) IsNoneOrNil(index int) bool { return l.TypeOf(index) <= TypeNil // // http://www.lua.org/manual/5.2/manual.html#lua_pushglobaltable func (l *State) PushGlobalTable() { l.RawGetInt(RegistryIndex, RegistryIndexGlobals) } + +// SetStdout redirects interpreter stdout to the given *os.File +func (l *State) SetStdout(stdout *os.File) { l.stdout = stdout } + +// SetStderr redirects interpreter stderr to the given *os.File +func (l *State) SetStderr(stderr *os.File) { l.stderr = stderr } + +// SetStdin redirects interpreter stdin from the given *os.File +func (l *State) SetStdin(stdin *os.File) { l.stdin = stdin } From a132991df2a80d4f5c85ebd4a3bbb6e2fd8ed2c5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marco=20Ant=C3=B4nio?= Date: Mon, 20 May 2024 14:43:13 -0300 Subject: [PATCH 2/3] ref: move lua test to lua_test package - to avoid confusion with the lua package - and also fix missing references when running tests --- go.sum | 10 ++++++++++ lua_test.go | 10 ++++++---- 2 files changed, 16 insertions(+), 4 deletions(-) create mode 100644 go.sum diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..60ce688 --- /dev/null +++ b/go.sum @@ -0,0 +1,10 @@ +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= +github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/lua_test.go b/lua_test.go index 57bbd9a..1051ee9 100644 --- a/lua_test.go +++ b/lua_test.go @@ -1,23 +1,25 @@ -package lua +package lua_test import ( "fmt" "testing" + + "github.com/Shopify/go-lua" ) func TestPushFStringPointer(t *testing.T) { - l := NewState() + l := lua.NewState() l.PushFString("%p %s", l, "test") expected := fmt.Sprintf("%p %s", l, "test") - actual := CheckString(l, -1) + actual := lua.CheckString(l, -1) if expected != actual { t.Errorf("PushFString, expected \"%s\" but found \"%s\"", expected, actual) } } func TestToBooleanOutOfRange(t *testing.T) { - l := NewState() + l := lua.NewState() l.SetTop(0) l.PushBoolean(false) l.PushBoolean(true) From 3498ae7e7cd85f258cdcd7b6aa26df2a1441ea6b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marco=20Ant=C3=B4nio?= Date: Mon, 20 May 2024 15:27:26 -0300 Subject: [PATCH 3/3] enh: test input and output redirect --- lua_test.go | 57 +++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 57 insertions(+) diff --git a/lua_test.go b/lua_test.go index 1051ee9..e4bfdca 100644 --- a/lua_test.go +++ b/lua_test.go @@ -2,9 +2,12 @@ package lua_test import ( "fmt" + "io" + "os" "testing" "github.com/Shopify/go-lua" + "github.com/stretchr/testify/assert" ) func TestPushFStringPointer(t *testing.T) { @@ -31,3 +34,57 @@ func TestToBooleanOutOfRange(t *testing.T) { } } } + +func TestInputOutputErrorRedirect(t *testing.T) { + t.Run("test redirect stdout", func(t *testing.T) { + // use a variable as stdout + reader, writer, err := os.Pipe() + assert.Nil(t, err) + + // setup runtime state + l := lua.NewState() + l.SetStdout(writer) + lua.OpenLibraries(l) + + // run lua code + err = lua.DoString(l, "print(1 + 1)") + assert.Nil(t, err) + + writer.Close() + output, err := io.ReadAll(reader) + assert.Nil(t, err) + + assert.Equal(t, "2\n", string(output)) + }) + + t.Run("test std redirect", func(t *testing.T) { + // create a pipe to stdin and add some lua code it it + inputReader, inputWriter, err := os.Pipe() + assert.Nil(t, err) + + outputReader, outputWriter, err := os.Pipe() + assert.Nil(t, err) + + // setup runtime state + l := lua.NewState() + l.SetStdin(inputReader) + l.SetStdout(outputWriter) + lua.OpenLibraries(l) + + // write to the file input + _, err = inputWriter.Write([]byte("print(1 + 1)")) + assert.Nil(t, err) + assert.Nil(t, inputWriter.Close()) + + // run lua code + err = lua.DoFile(l, "") + assert.Nil(t, err) + assert.Nil(t, outputWriter.Close()) + + output, err := io.ReadAll(outputReader) + assert.Nil(t, err) + + assert.Equal(t, "2\n", string(output)) + }) + +}