From b1701d16ef43915830b1abaab5581054c1d080ab Mon Sep 17 00:00:00 2001 From: JT Olds Date: Sun, 27 Sep 2015 12:33:28 -0600 Subject: [PATCH] Support redirecting stdout/stdin/stderr to/from other sources Like os/exec's Cmd object, it is sometimes useful to be able to control where and how the child program is interacting with the world. By doing this, we can now support building lua shells over telnet or other things. --- auxiliary.go | 9 ++++++--- base.go | 21 ++++++++++++++++----- io.go | 6 +++--- lua.go | 14 +++++++++++++- 4 files changed, 38 insertions(+), 12 deletions(-) diff --git a/auxiliary.go b/auxiliary.go index ebfc34e..80141d8 100644 --- a/auxiliary.go +++ b/auxiliary.go @@ -478,6 +478,7 @@ func skipComment(r *bufio.Reader) (bool, error) { func LoadFile(l *State, fileName, mode string) error { var f *os.File + var skipClose bool fileNameIndex := l.Top() + 1 fileError := func(what string) error { fileName, _ := l.ToString(fileNameIndex) @@ -487,7 +488,8 @@ func LoadFile(l *State, fileName, mode string) error { } if fileName == "" { l.PushString("=stdin") - f = os.Stdin + f = l.stdin + skipClose = true } else { l.PushString("@" + fileName) var err error @@ -504,7 +506,7 @@ func LoadFile(l *State, fileName, mode string) error { } s, _ := l.ToString(-1) err := l.Load(r, s, mode) - if f != os.Stdin { + if !skipClose { _ = f.Close() } if err != nil { @@ -531,7 +533,8 @@ func NewStateEx() *State { if l != nil { _ = AtPanic(l, func(l *State) int { s, _ := l.ToString(-1) - fmt.Fprintf(os.Stderr, "PANIC: unprotected error in call to Lua API (%s)\n", s) + fmt.Fprintf(l.stderr, + "PANIC: unprotected error in call to Lua API (%s)\n", s) return 0 }) } diff --git a/base.go b/base.go index 2cffec2..cfbad53 100644 --- a/base.go +++ b/base.go @@ -2,7 +2,6 @@ package lua import ( "io" - "os" "runtime" "strconv" "strings" @@ -215,13 +214,25 @@ var baseLibrary = []RegistryFunction{ panic("unreachable") } if i > 1 { - os.Stdout.WriteString("\t") + _, err := l.stdout.WriteString("\t") + if err != nil { + Errorf(l, "failed writing to stdout: %v", err) + panic("unreachable") + } + } + _, err := l.stdout.WriteString(s) + if err != nil { + Errorf(l, "failed writing to stdout: %v", err) + panic("unreachable") } - os.Stdout.WriteString(s) l.Pop(1) // pop result } - os.Stdout.WriteString("\n") - os.Stdout.Sync() + _, err := l.stdout.WriteString("\n") + if err != nil { + Errorf(l, "failed writing to stdout: %v", err) + panic("unreachable") + } + l.stdout.Sync() return 0 }}, {"rawequal", func(l *State) int { diff --git a/io.go b/io.go index 2bacd86..5bb5502 100644 --- a/io.go +++ b/io.go @@ -316,9 +316,9 @@ func IOOpen(l *State) int { SetFunctions(l, fileHandleMethods, 0) l.Pop(1) - registerStdFile(l, os.Stdin, input, "stdin") - registerStdFile(l, os.Stdout, output, "stdout") - registerStdFile(l, os.Stderr, "", "stderr") + registerStdFile(l, l.stdin, input, "stdin") + registerStdFile(l, l.stdout, output, "stdout") + registerStdFile(l, l.stderr, "", "stderr") return 1 } diff --git a/lua.go b/lua.go index 749247a..2b548ff 100644 --- a/lua.go +++ b/lua.go @@ -5,6 +5,7 @@ import ( "fmt" "io" "math" + "os" "strings" ) @@ -239,6 +240,7 @@ type State struct { errorFunction int // current error handling function (stack index) baseCallInfo callInfo // callInfo for first level (go calling lua) protectFunction func() + stdout, stderr, stdin *os.File } type globalState struct { @@ -442,7 +444,8 @@ func (l *State) Load(r io.Reader, chunkName string, mode string) error { // http://www.lua.org/manual/5.2/manual.html#lua_newstate func NewState() *State { v := float64(VersionNumber) - l := &State{allowHook: true, error: nil, nonYieldableCallCount: 1} + l := &State{allowHook: true, error: nil, nonYieldableCallCount: 1, + stdout: os.Stdout, stderr: os.Stderr, stdin: os.Stdin} g := &globalState{mainThread: l, registry: newTable(), version: &v, memoryErrorMessage: "not enough memory"} l.global = g l.initializeStack() @@ -1515,3 +1518,12 @@ func (l *State) IsNoneOrNil(index int) bool { return l.TypeOf(index) <= TypeNil // // http://www.lua.org/manual/5.2/manual.html#lua_pushglobaltable func (l *State) PushGlobalTable() { l.RawGetInt(RegistryIndex, RegistryIndexGlobals) } + +// SetStdout redirects interpreter stdout to the given *os.File +func (l *State) SetStdout(stdout *os.File) { l.stdout = stdout } + +// SetStderr redirects interpreter stderr to the given *os.File +func (l *State) SetStderr(stderr *os.File) { l.stderr = stderr } + +// SetStdin redirects interpreter stdin from the given *os.File +func (l *State) SetStdin(stdin *os.File) { l.stdin = stdin }