From 0c4f46d4835aec1356b3e8691d28e3eaf8036b09 Mon Sep 17 00:00:00 2001 From: davidshi Date: Tue, 4 Jun 2024 09:16:20 -0500 Subject: [PATCH] add a test --- pkg/sqlcmd/commands.go | 27 ++++++++++++++------------- pkg/sqlcmd/commands_test.go | 11 ++++++++++- 2 files changed, 24 insertions(+), 14 deletions(-) diff --git a/pkg/sqlcmd/commands.go b/pkg/sqlcmd/commands.go index acbfa3d2..31c2d9c0 100644 --- a/pkg/sqlcmd/commands.go +++ b/pkg/sqlcmd/commands.go @@ -264,16 +264,17 @@ func outCommand(s *Sqlcmd, args []string, line uint) error { if len(args) == 0 || args[0] == "" { return InvalidCommandError("OUT", line) } + filePath, err := resolveArgumentVariables(s, []rune(args[0]), true) + if err != nil { + return err + } + switch { - case strings.EqualFold(args[0], "stdout"): + case strings.EqualFold(filePath, "stdout"): s.SetOutput(os.Stdout) - case strings.EqualFold(args[0], "stderr"): + case strings.EqualFold(filePath, "stderr"): s.SetOutput(os.Stderr) default: - filePath, err := resolveArgumentVariables(s, []rune(args[0]), true) - if err != nil { - return err - } o, err := os.OpenFile(filePath, os.O_TRUNC|os.O_CREATE|os.O_WRONLY, 0o644) if err != nil { return InvalidFileError(err, args[0]) @@ -294,18 +295,18 @@ func outCommand(s *Sqlcmd, args []string, line uint) error { // errorCommand changes the error writer to use a file func errorCommand(s *Sqlcmd, args []string, line uint) error { if len(args) == 0 || args[0] == "" { - return InvalidCommandError("OUT", line) + return InvalidCommandError("ERROR", line) + } + filePath, err := resolveArgumentVariables(s, []rune(args[0]), true) + if err != nil { + return err } switch { - case strings.EqualFold(args[0], "stderr"): + case strings.EqualFold(filePath, "stderr"): s.SetError(os.Stderr) - case strings.EqualFold(args[0], "stdout"): + case strings.EqualFold(filePath, "stdout"): s.SetError(os.Stdout) default: - filePath, err := resolveArgumentVariables(s, []rune(args[0]), true) - if err != nil { - return err - } o, err := os.OpenFile(filePath, os.O_TRUNC|os.O_CREATE|os.O_WRONLY, 0o644) if err != nil { return InvalidFileError(err, args[0]) diff --git a/pkg/sqlcmd/commands_test.go b/pkg/sqlcmd/commands_test.go index 94caeb2e..9244c660 100644 --- a/pkg/sqlcmd/commands_test.go +++ b/pkg/sqlcmd/commands_test.go @@ -246,6 +246,7 @@ func TestConnectCommand(t *testing.T) { func TestErrorCommand(t *testing.T) { s, buf := setupSqlCmdWithMemoryOutput(t) + defer s.SetError(nil) defer buf.Close() file, err := os.CreateTemp("", "sqlcmderr") assert.NoError(t, err, "os.CreateTemp") @@ -259,11 +260,14 @@ func TestErrorCommand(t *testing.T) { // Only some error kinds go to the error output err = runSqlCmd(t, s, []string{"print N'message'", "RAISERROR(N'Error', 16, 1)", "SELECT 1", ":SETVAR 1", "GO"}) assert.NoError(t, err, "runSqlCmd") - s.SetError(nil) errText, err := os.ReadFile(file.Name()) if assert.NoError(t, err, "ReadFile") { assert.Regexp(t, "Msg 50000, Level 16, State 1, Server .*, Line 2"+SqlcmdEol+"Error"+SqlcmdEol, string(errText), "Error file contents: "+string(errText)) } + s.vars.Set("myvar", "stdout") + err = errorCommand(s, []string{"$(myvar)"}, 1) + assert.NoError(t, err, "errorCommand with a variable") + assert.Equal(t, os.Stdout, s.err, "error set to stdout using a variable") } func TestOnErrorCommand(t *testing.T) { @@ -320,6 +324,11 @@ func TestResolveArgumentVariables(t *testing.T) { if assert.ErrorContains(t, err, UndefinedVariable("var2").Error(), "fail on unresolved variable") { assert.Empty(t, actual, "fail on unresolved variable") } + s.Connect.DisableVariableSubstitution = true + input := "$(var1) notvar" + actual, err = resolveArgumentVariables(s, []rune(input), true) + assert.NoError(t, err) + assert.Equal(t, input, actual, "resolveArgumentVariables when DisableVariableSubstitution is false") } func TestExecCommand(t *testing.T) {