diff --git a/auxiliary.go b/auxiliary.go index ebfc34e..80141d8 100644 --- a/auxiliary.go +++ b/auxiliary.go @@ -478,6 +478,7 @@ func skipComment(r *bufio.Reader) (bool, error) { func LoadFile(l *State, fileName, mode string) error { var f *os.File + var skipClose bool fileNameIndex := l.Top() + 1 fileError := func(what string) error { fileName, _ := l.ToString(fileNameIndex) @@ -487,7 +488,8 @@ func LoadFile(l *State, fileName, mode string) error { } if fileName == "" { l.PushString("=stdin") - f = os.Stdin + f = l.stdin + skipClose = true } else { l.PushString("@" + fileName) var err error @@ -504,7 +506,7 @@ func LoadFile(l *State, fileName, mode string) error { } s, _ := l.ToString(-1) err := l.Load(r, s, mode) - if f != os.Stdin { + if !skipClose { _ = f.Close() } if err != nil { @@ -531,7 +533,8 @@ func NewStateEx() *State { if l != nil { _ = AtPanic(l, func(l *State) int { s, _ := l.ToString(-1) - fmt.Fprintf(os.Stderr, "PANIC: unprotected error in call to Lua API (%s)\n", s) + fmt.Fprintf(l.stderr, + "PANIC: unprotected error in call to Lua API (%s)\n", s) return 0 }) } diff --git a/base.go b/base.go index 2cffec2..cfbad53 100644 --- a/base.go +++ b/base.go @@ -2,7 +2,6 @@ package lua import ( "io" - "os" "runtime" "strconv" "strings" @@ -215,13 +214,25 @@ var baseLibrary = []RegistryFunction{ panic("unreachable") } if i > 1 { - os.Stdout.WriteString("\t") + _, err := l.stdout.WriteString("\t") + if err != nil { + Errorf(l, "failed writing to stdout: %v", err) + panic("unreachable") + } + } + _, err := l.stdout.WriteString(s) + if err != nil { + Errorf(l, "failed writing to stdout: %v", err) + panic("unreachable") } - os.Stdout.WriteString(s) l.Pop(1) // pop result } - os.Stdout.WriteString("\n") - os.Stdout.Sync() + _, err := l.stdout.WriteString("\n") + if err != nil { + Errorf(l, "failed writing to stdout: %v", err) + panic("unreachable") + } + l.stdout.Sync() return 0 }}, {"rawequal", func(l *State) int { diff --git a/io.go b/io.go index 2bacd86..5bb5502 100644 --- a/io.go +++ b/io.go @@ -316,9 +316,9 @@ func IOOpen(l *State) int { SetFunctions(l, fileHandleMethods, 0) l.Pop(1) - registerStdFile(l, os.Stdin, input, "stdin") - registerStdFile(l, os.Stdout, output, "stdout") - registerStdFile(l, os.Stderr, "", "stderr") + registerStdFile(l, l.stdin, input, "stdin") + registerStdFile(l, l.stdout, output, "stdout") + registerStdFile(l, l.stderr, "", "stderr") return 1 } diff --git a/lua.go b/lua.go index 749247a..2b548ff 100644 --- a/lua.go +++ b/lua.go @@ -5,6 +5,7 @@ import ( "fmt" "io" "math" + "os" "strings" ) @@ -239,6 +240,7 @@ type State struct { errorFunction int // current error handling function (stack index) baseCallInfo callInfo // callInfo for first level (go calling lua) protectFunction func() + stdout, stderr, stdin *os.File } type globalState struct { @@ -442,7 +444,8 @@ func (l *State) Load(r io.Reader, chunkName string, mode string) error { // http://www.lua.org/manual/5.2/manual.html#lua_newstate func NewState() *State { v := float64(VersionNumber) - l := &State{allowHook: true, error: nil, nonYieldableCallCount: 1} + l := &State{allowHook: true, error: nil, nonYieldableCallCount: 1, + stdout: os.Stdout, stderr: os.Stderr, stdin: os.Stdin} g := &globalState{mainThread: l, registry: newTable(), version: &v, memoryErrorMessage: "not enough memory"} l.global = g l.initializeStack() @@ -1515,3 +1518,12 @@ func (l *State) IsNoneOrNil(index int) bool { return l.TypeOf(index) <= TypeNil // // http://www.lua.org/manual/5.2/manual.html#lua_pushglobaltable func (l *State) PushGlobalTable() { l.RawGetInt(RegistryIndex, RegistryIndexGlobals) } + +// SetStdout redirects interpreter stdout to the given *os.File +func (l *State) SetStdout(stdout *os.File) { l.stdout = stdout } + +// SetStderr redirects interpreter stderr to the given *os.File +func (l *State) SetStderr(stderr *os.File) { l.stderr = stderr } + +// SetStdin redirects interpreter stdin from the given *os.File +func (l *State) SetStdin(stdin *os.File) { l.stdin = stdin }