Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support redirecting stdout/stdin/stderr to/from other sources #135

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions auxiliary.go
Original file line number Diff line number Diff line change
Expand Up @@ -492,7 +492,7 @@ func LoadFile(l *State, fileName, mode string) error {
}
if fileName == "" {
l.PushString("=stdin")
f = os.Stdin
f = l.stdin
} else {
l.PushString("@" + fileName)
var err error
Expand All @@ -509,7 +509,7 @@ func LoadFile(l *State, fileName, mode string) error {
}
s, _ := l.ToString(-1)
err := l.Load(r, s, mode)
if f != os.Stdin {
if f != l.stdin {
_ = f.Close()
}
switch err {
Expand Down Expand Up @@ -538,7 +538,7 @@ func NewStateEx() *State {
if l != nil {
_ = AtPanic(l, func(l *State) int {
s, _ := l.ToString(-1)
fmt.Fprintf(os.Stderr, "PANIC: unprotected error in call to Lua API (%s)\n", s)
fmt.Fprintf(l.stderr, "PANIC: unprotected error in call to Lua API (%s)\n", s)
return 0
})
}
Expand Down
17 changes: 12 additions & 5 deletions base.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package lua

import (
"io"
"os"
"runtime"
"strconv"
"strings"
Expand Down Expand Up @@ -215,13 +214,13 @@ var baseLibrary = []RegistryFunction{
panic("unreachable")
}
if i > 1 {
os.Stdout.WriteString("\t")
l.writeToStdout("\t")
}
os.Stdout.WriteString(s)
l.writeToStdout(s)
l.Pop(1) // pop result
}
os.Stdout.WriteString("\n")
os.Stdout.Sync()
l.writeToStdout("\n")
l.stdout.Sync()
return 0
}},
{"rawequal", func(l *State) int {
Expand Down Expand Up @@ -328,3 +327,11 @@ func BaseOpen(l *State) int {
l.SetField(-2, "_VERSION")
return 1
}

func (l *State) writeToStdout(s string) {
_, err := l.stdout.WriteString(s)
if err != nil {
Errorf(l, "failed writing to stdout: %v", err)
panic("unreachable")
}
}
8 changes: 8 additions & 0 deletions go.mod
Original file line number Diff line number Diff line change
@@ -1,3 +1,11 @@
module github.com/Shopify/go-lua

go 1.22

require github.com/stretchr/testify v1.9.0

require (
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
)
10 changes: 10 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
6 changes: 3 additions & 3 deletions io.go
Original file line number Diff line number Diff line change
Expand Up @@ -316,9 +316,9 @@ func IOOpen(l *State) int {
SetFunctions(l, fileHandleMethods, 0)
l.Pop(1)

registerStdFile(l, os.Stdin, input, "stdin")
registerStdFile(l, os.Stdout, output, "stdout")
registerStdFile(l, os.Stderr, "", "stderr")
registerStdFile(l, l.stdin, input, "stdin")
registerStdFile(l, l.stdout, output, "stdout")
registerStdFile(l, l.stderr, "", "stderr")

return 1
}
13 changes: 12 additions & 1 deletion lua.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"fmt"
"io"
"math"
"os"
"strings"
)

Expand Down Expand Up @@ -239,6 +240,7 @@ type State struct {
errorFunction int // current error handling function (stack index)
baseCallInfo callInfo // callInfo for first level (go calling lua)
protectFunction func()
stdout, stderr, stdin *os.File
}

type globalState struct {
Expand Down Expand Up @@ -455,7 +457,7 @@ func (l *State) Dump(w io.Writer) error {
// http://www.lua.org/manual/5.2/manual.html#lua_newstate
func NewState() *State {
v := float64(VersionNumber)
l := &State{allowHook: true, error: nil, nonYieldableCallCount: 1}
l := &State{allowHook: true, error: nil, nonYieldableCallCount: 1, stdout: os.Stdout, stderr: os.Stderr, stdin: os.Stdin}
g := &globalState{mainThread: l, registry: newTable(), version: &v, memoryErrorMessage: "not enough memory"}
l.global = g
l.initializeStack()
Expand Down Expand Up @@ -1529,3 +1531,12 @@ func (l *State) IsNoneOrNil(index int) bool { return l.TypeOf(index) <= TypeNil
//
// http://www.lua.org/manual/5.2/manual.html#lua_pushglobaltable
func (l *State) PushGlobalTable() { l.RawGetInt(RegistryIndex, RegistryIndexGlobals) }

// SetStdout redirects interpreter stdout to the given *os.File
func (l *State) SetStdout(stdout *os.File) { l.stdout = stdout }

// SetStderr redirects interpreter stderr to the given *os.File
func (l *State) SetStderr(stderr *os.File) { l.stderr = stderr }

// SetStdin redirects interpreter stdin from the given *os.File
func (l *State) SetStdin(stdin *os.File) { l.stdin = stdin }
67 changes: 63 additions & 4 deletions lua_test.go
Original file line number Diff line number Diff line change
@@ -1,23 +1,28 @@
package lua
package lua_test

import (
"fmt"
"io"
"os"
"testing"

"github.com/Shopify/go-lua"
"github.com/stretchr/testify/assert"
)

func TestPushFStringPointer(t *testing.T) {
l := NewState()
l := lua.NewState()
l.PushFString("%p %s", l, "test")

expected := fmt.Sprintf("%p %s", l, "test")
actual := CheckString(l, -1)
actual := lua.CheckString(l, -1)
if expected != actual {
t.Errorf("PushFString, expected \"%s\" but found \"%s\"", expected, actual)
}
}

func TestToBooleanOutOfRange(t *testing.T) {
l := NewState()
l := lua.NewState()
l.SetTop(0)
l.PushBoolean(false)
l.PushBoolean(true)
Expand All @@ -29,3 +34,57 @@ func TestToBooleanOutOfRange(t *testing.T) {
}
}
}

func TestInputOutputErrorRedirect(t *testing.T) {
t.Run("test redirect stdout", func(t *testing.T) {
// use a variable as stdout
reader, writer, err := os.Pipe()
assert.Nil(t, err)

// setup runtime state
l := lua.NewState()
l.SetStdout(writer)
lua.OpenLibraries(l)

// run lua code
err = lua.DoString(l, "print(1 + 1)")
assert.Nil(t, err)

writer.Close()
output, err := io.ReadAll(reader)
assert.Nil(t, err)

assert.Equal(t, "2\n", string(output))
})

t.Run("test std redirect", func(t *testing.T) {
// create a pipe to stdin and add some lua code it it
inputReader, inputWriter, err := os.Pipe()
assert.Nil(t, err)

outputReader, outputWriter, err := os.Pipe()
assert.Nil(t, err)

// setup runtime state
l := lua.NewState()
l.SetStdin(inputReader)
l.SetStdout(outputWriter)
lua.OpenLibraries(l)

// write to the file input
_, err = inputWriter.Write([]byte("print(1 + 1)"))
assert.Nil(t, err)
assert.Nil(t, inputWriter.Close())

// run lua code
err = lua.DoFile(l, "")
assert.Nil(t, err)
assert.Nil(t, outputWriter.Close())

output, err := io.ReadAll(outputReader)
assert.Nil(t, err)

assert.Equal(t, "2\n", string(output))
})

}
Loading