Skip to content

Commit

Permalink
add a test
Browse files Browse the repository at this point in the history
  • Loading branch information
shueybubbles committed Jun 4, 2024
1 parent 77dcb83 commit 0c4f46d
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 14 deletions.
27 changes: 14 additions & 13 deletions pkg/sqlcmd/commands.go
Original file line number Diff line number Diff line change
Expand Up @@ -264,16 +264,17 @@ func outCommand(s *Sqlcmd, args []string, line uint) error {
if len(args) == 0 || args[0] == "" {
return InvalidCommandError("OUT", line)
}
filePath, err := resolveArgumentVariables(s, []rune(args[0]), true)
if err != nil {
return err
}

switch {
case strings.EqualFold(args[0], "stdout"):
case strings.EqualFold(filePath, "stdout"):
s.SetOutput(os.Stdout)
case strings.EqualFold(args[0], "stderr"):
case strings.EqualFold(filePath, "stderr"):
s.SetOutput(os.Stderr)
default:
filePath, err := resolveArgumentVariables(s, []rune(args[0]), true)
if err != nil {
return err
}
o, err := os.OpenFile(filePath, os.O_TRUNC|os.O_CREATE|os.O_WRONLY, 0o644)
if err != nil {
return InvalidFileError(err, args[0])
Expand All @@ -294,18 +295,18 @@ func outCommand(s *Sqlcmd, args []string, line uint) error {
// errorCommand changes the error writer to use a file
func errorCommand(s *Sqlcmd, args []string, line uint) error {
if len(args) == 0 || args[0] == "" {
return InvalidCommandError("OUT", line)
return InvalidCommandError("ERROR", line)
}
filePath, err := resolveArgumentVariables(s, []rune(args[0]), true)
if err != nil {
return err
}
switch {
case strings.EqualFold(args[0], "stderr"):
case strings.EqualFold(filePath, "stderr"):
s.SetError(os.Stderr)
case strings.EqualFold(args[0], "stdout"):
case strings.EqualFold(filePath, "stdout"):
s.SetError(os.Stdout)
default:
filePath, err := resolveArgumentVariables(s, []rune(args[0]), true)
if err != nil {
return err
}
o, err := os.OpenFile(filePath, os.O_TRUNC|os.O_CREATE|os.O_WRONLY, 0o644)
if err != nil {
return InvalidFileError(err, args[0])
Expand Down
11 changes: 10 additions & 1 deletion pkg/sqlcmd/commands_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -246,6 +246,7 @@ func TestConnectCommand(t *testing.T) {

func TestErrorCommand(t *testing.T) {
s, buf := setupSqlCmdWithMemoryOutput(t)
defer s.SetError(nil)
defer buf.Close()
file, err := os.CreateTemp("", "sqlcmderr")
assert.NoError(t, err, "os.CreateTemp")
Expand All @@ -259,11 +260,14 @@ func TestErrorCommand(t *testing.T) {
// Only some error kinds go to the error output
err = runSqlCmd(t, s, []string{"print N'message'", "RAISERROR(N'Error', 16, 1)", "SELECT 1", ":SETVAR 1", "GO"})
assert.NoError(t, err, "runSqlCmd")
s.SetError(nil)
errText, err := os.ReadFile(file.Name())
if assert.NoError(t, err, "ReadFile") {
assert.Regexp(t, "Msg 50000, Level 16, State 1, Server .*, Line 2"+SqlcmdEol+"Error"+SqlcmdEol, string(errText), "Error file contents: "+string(errText))
}
s.vars.Set("myvar", "stdout")
err = errorCommand(s, []string{"$(myvar)"}, 1)
assert.NoError(t, err, "errorCommand with a variable")
assert.Equal(t, os.Stdout, s.err, "error set to stdout using a variable")
}

func TestOnErrorCommand(t *testing.T) {
Expand Down Expand Up @@ -320,6 +324,11 @@ func TestResolveArgumentVariables(t *testing.T) {
if assert.ErrorContains(t, err, UndefinedVariable("var2").Error(), "fail on unresolved variable") {
assert.Empty(t, actual, "fail on unresolved variable")
}
s.Connect.DisableVariableSubstitution = true
input := "$(var1) notvar"
actual, err = resolveArgumentVariables(s, []rune(input), true)
assert.NoError(t, err)
assert.Equal(t, input, actual, "resolveArgumentVariables when DisableVariableSubstitution is false")
}

func TestExecCommand(t *testing.T) {
Expand Down

0 comments on commit 0c4f46d

Please sign in to comment.