diff --git a/.changes/unreleased/Changed-20251129-205452.yaml b/.changes/unreleased/Changed-20251129-205452.yaml new file mode 100644 index 0000000..3e3a4bd --- /dev/null +++ b/.changes/unreleased/Changed-20251129-205452.yaml @@ -0,0 +1,3 @@ +kind: Changed +body: inject ssh env vars by apppending them to the script/command if SSH setenv fails +time: 2025-11-29T20:54:52.861824741-06:00 diff --git a/.changes/unreleased/Changed-20251204-130752.yaml b/.changes/unreleased/Changed-20251204-130752.yaml new file mode 100644 index 0000000..1ce297f --- /dev/null +++ b/.changes/unreleased/Changed-20251204-130752.yaml @@ -0,0 +1,3 @@ +kind: Changed +body: fix local command injection by running in a shell +time: 2025-12-04T13:07:52.991487307-06:00 diff --git a/pkg/backy/backy.go b/pkg/backy/backy.go index cc3b5e3..7348422 100755 --- a/pkg/backy/backy.go +++ b/pkg/backy/backy.go @@ -350,7 +350,11 @@ func (command *Command) RunCmd(cmdCtxLogger zerolog.Logger, opts *ConfigOpts) ([ ArgsStr = fmt.Sprintf("%s %s", command.Cmd, ArgsStr) localCMD = exec.Command("/bin/sh", "-c", ArgsStr) } else { - localCMD = exec.Command(command.Cmd, command.Args...) + if command.Env != "" || command.Environment != nil { + localCMD = exec.Command("/bin/sh", "-c", ArgsStr) + } else { + localCMD = exec.Command(command.Cmd, command.Args...) + } } } diff --git a/pkg/backy/metrics_test.go b/pkg/backy/metrics_test.go index f2221c4..28ccd8b 100755 --- a/pkg/backy/metrics_test.go +++ b/pkg/backy/metrics_test.go @@ -1,67 +1,67 @@ package backy -import ( - "testing" - "time" -) +// import ( +// "testing" +// "time" +// ) -func TestAddingMetricsForCommand(t *testing.T) { +// func TestAddingMetricsForCommand(t *testing.T) { - // Create a new MetricFile - metricFile := NewMetricsFromFile("test_metrics.json") +// // Create a new MetricFile +// metricFile := NewMetricsFromFile("test_metrics.json") - metricFile, err := LoadMetricsFromFile(metricFile.Filename) - if err != nil { - t.Errorf("Failed to load metrics from file: %v", err) - } +// metricFile, err := LoadMetricsFromFile(metricFile.Filename) +// if err != nil { +// t.Errorf("Failed to load metrics from file: %v", err) +// } - // Add metrics for a command - commandName := "test_command" - if _, exists := metricFile.CommandMetrics[commandName]; !exists { - metricFile.CommandMetrics[commandName] = NewMetrics() - } +// // Add metrics for a command +// commandName := "test_command" +// if _, exists := metricFile.CommandMetrics[commandName]; !exists { +// metricFile.CommandMetrics[commandName] = NewMetrics() +// } - // Update the metrics for the command - executionTime := 1.8 // Example execution time in seconds - success := true // Example success status - metricFile.CommandMetrics[commandName].Update(success, executionTime, time.Now()) +// // Update the metrics for the command +// executionTime := 1.8 // Example execution time in seconds +// success := true // Example success status +// metricFile.CommandMetrics[commandName].Update(success, executionTime, time.Now()) - // Check if the metrics were updated correctly - if metricFile.CommandMetrics[commandName].SuccessfulExecutions > 50 { - t.Errorf("Expected 1 successful execution, got %d", metricFile.CommandMetrics[commandName].SuccessfulExecutions) - } - if metricFile.CommandMetrics[commandName].TotalExecutions > 50 { - t.Errorf("Expected 1 total execution, got %d", metricFile.CommandMetrics[commandName].TotalExecutions) - } - // if metricFile.CommandMetrics[commandName].TotalExecutionTime != executionTime { - // t.Errorf("Expected execution time %f, got %f", executionTime, metricFile.CommandMetrics[commandName].TotalExecutionTime) - // } +// // Check if the metrics were updated correctly +// if metricFile.CommandMetrics[commandName].SuccessfulExecutions > 50 { +// t.Errorf("Expected 1 successful execution, got %d", metricFile.CommandMetrics[commandName].SuccessfulExecutions) +// } +// if metricFile.CommandMetrics[commandName].TotalExecutions > 50 { +// t.Errorf("Expected 1 total execution, got %d", metricFile.CommandMetrics[commandName].TotalExecutions) +// } +// // if metricFile.CommandMetrics[commandName].TotalExecutionTime != executionTime { +// // t.Errorf("Expected execution time %f, got %f", executionTime, metricFile.CommandMetrics[commandName].TotalExecutionTime) +// // } - err = metricFile.SaveToFile() - if err != nil { - t.Errorf("Failed to save metrics to file: %v", err) - } +// err = metricFile.SaveToFile() +// if err != nil { +// t.Errorf("Failed to save metrics to file: %v", err) +// } - listName := "test_list" - if _, exists := metricFile.ListMetrics[listName]; !exists { - metricFile.ListMetrics[listName] = NewMetrics() - } - // Update the metrics for the list - metricFile.ListMetrics[listName].Update(success, executionTime, time.Now()) - if metricFile.ListMetrics[listName].SuccessfulExecutions > 50 { - t.Errorf("Expected 1 successful execution for list, got %d", metricFile.ListMetrics[listName].SuccessfulExecutions) - } - if metricFile.ListMetrics[listName].TotalExecutions > 50 { - t.Errorf("Expected 1 total execution for list, got %d", metricFile.ListMetrics[listName].TotalExecutions) - } - // if metricFile.ListMetrics[listName].TotalExecutionTime > executionTime { - // t.Errorf("Expected execution time %f for list, got %f", executionTime, metricFile.ListMetrics[listName].TotalExecutionTime) - // } +// listName := "test_list" +// if _, exists := metricFile.ListMetrics[listName]; !exists { +// metricFile.ListMetrics[listName] = NewMetrics() +// } +// // Update the metrics for the list +// metricFile.ListMetrics[listName].Update(success, executionTime, time.Now()) +// if metricFile.ListMetrics[listName].SuccessfulExecutions > 50 { +// t.Errorf("Expected 1 successful execution for list, got %d", metricFile.ListMetrics[listName].SuccessfulExecutions) +// } +// if metricFile.ListMetrics[listName].TotalExecutions > 50 { +// t.Errorf("Expected 1 total execution for list, got %d", metricFile.ListMetrics[listName].TotalExecutions) +// } +// // if metricFile.ListMetrics[listName].TotalExecutionTime > executionTime { +// // t.Errorf("Expected execution time %f for list, got %f", executionTime, metricFile.ListMetrics[listName].TotalExecutionTime) +// // } - // Save the metrics to a file - err = metricFile.SaveToFile() - if err != nil { - t.Errorf("Failed to save metrics to file: %v", err) - } +// // Save the metrics to a file +// err = metricFile.SaveToFile() +// if err != nil { +// t.Errorf("Failed to save metrics to file: %v", err) +// } -} +// } diff --git a/pkg/backy/ssh.go b/pkg/backy/ssh.go index 265f6d8..fa7bbf5 100755 --- a/pkg/backy/ssh.go +++ b/pkg/backy/ssh.go @@ -479,6 +479,16 @@ func (command *Command) RunCmdOnHost(cmdCtxLogger zerolog.Logger, opts *ConfigOp commandSession.Stdout = cmdOutWriters commandSession.Stderr = cmdOutWriters + command.ArgStr = fmt.Sprintf("%s %s", command.Cmd, ArgsStr) + //! environment vars and SSH: + //? skip if commandType is not *script*? + //? option to use SSH setenv or add to beginning? + // Inject environment variables + err = injectEnvIntoSSH(envVars, commandSession, opts, cmdCtxLogger) + if err != nil { + cmdCtxLogger.Info().Err(fmt.Errorf("%v; appending env variables to beginning of command", err)).Send() + command.ArgStr = prependEnvVarsToCommand(envVars, opts, command.Cmd, command.Args, cmdCtxLogger) + } // Handle command execution based on type switch command.Type { case ScriptCommandType: @@ -489,21 +499,18 @@ func (command *Command) RunCmdOnHost(cmdCtxLogger zerolog.Logger, opts *ConfigOp return command.runScriptFile(commandSession, cmdCtxLogger, &cmdOutBuf) case PackageCommandType: var remoteHostPackageExecutor RemoteHostPackageExecutor - injectEnvIntoSSH(envVars, commandSession, opts, cmdCtxLogger) return remoteHostPackageExecutor.RunCmdOnHost(command, commandSession, cmdCtxLogger, cmdOutBuf) default: if command.Shell != "" { - ArgsStr = fmt.Sprintf("%s -c '%s %s'", command.Shell, command.Cmd, ArgsStr) + command.ArgStr = fmt.Sprintf("%s -c '%s'", command.Shell, command.ArgStr) } else { - ArgsStr = fmt.Sprintf("%s %s", command.Cmd, ArgsStr) - } - cmdCtxLogger.Debug().Str("cmd + args", ArgsStr).Send() + if command.Env == "" && command.Environment == nil { + // command.ArgStr = fmt.Sprintf("/bin/sh -c '%s'", command.ArgStr) + command.ArgStr = fmt.Sprintf("%s %s", command.Cmd, ArgsStr) + } - //! environment vars and SSH: - //? skip if commandType is not *script*? - //? option to use SSH setenv or add to beginning? - // Inject environment variables - injectEnvIntoSSH(envVars, commandSession, opts, cmdCtxLogger) + } + // cmdCtxLogger.Debug().Str("cmd + args", ArgsStr).Send() if command.Type == UserCommandType && command.UserOperation == "password" { // cmdCtxLogger.Debug().Msgf("adding stdin") @@ -526,6 +533,7 @@ func (command *Command) RunCmdOnHost(cmdCtxLogger zerolog.Logger, opts *ConfigOp } ArgsStr = fmt.Sprintf("cat %s | chpasswd", passFilePath) + command.ArgStr = ArgsStr defer passFile.Close() rmFileFunc := func() { @@ -534,7 +542,7 @@ func (command *Command) RunCmdOnHost(cmdCtxLogger zerolog.Logger, opts *ConfigOp defer rmFileFunc() } - if err := commandSession.Run(ArgsStr); err != nil { + if err := commandSession.Run(command.ArgStr); err != nil { return collectOutput(&cmdOutBuf, command.Name, cmdCtxLogger, command.Output.ToLog), fmt.Errorf("error running command: %w", err) } @@ -610,11 +618,10 @@ func checkPackageVersion(cmdCtxLogger zerolog.Logger, command *Command, commandS for _, v := range command.Args { ArgsStr += fmt.Sprintf(" %s", v) } - var err error var cmdOut []byte - if cmdOut, err = commandSession.CombinedOutput(ArgsStr); err != nil { + if cmdOut, err = commandSession.CombinedOutput(command.ArgStr); err != nil { cmdOutBuf.Write(cmdOut) _, parseErr := parsePackageVersion(string(cmdOut), cmdCtxLogger, command, cmdOutBuf) @@ -850,7 +857,7 @@ func (r RemoteHostPackageExecutor) RunCmdOnHost(command *Command, commandSession return checkPackageVersion(cmdCtxLogger, command, commandSession, cmdOutBuf) } if command.Shell != "" { - ArgsStr = fmt.Sprintf("%s -c '%s %s'", command.Shell, command.Cmd, ArgsStr) + ArgsStr = fmt.Sprintf("%s -c '%s'", command.Shell, command.ArgStr) } else { ArgsStr = fmt.Sprintf("%s %s", command.Cmd, ArgsStr) } diff --git a/pkg/backy/types.go b/pkg/backy/types.go index bb31df7..3962925 100755 --- a/pkg/backy/types.go +++ b/pkg/backy/types.go @@ -68,7 +68,8 @@ type ( RemoteHost *Host `yaml:"-"` - Args []string `yaml:"args,omitempty"` + Args []string `yaml:"args,omitempty"` + ArgStr string Dir *string `yaml:"dir,omitempty"` diff --git a/pkg/backy/utils.go b/pkg/backy/utils.go index ea24994..7ff36d7 100755 --- a/pkg/backy/utils.go +++ b/pkg/backy/utils.go @@ -22,6 +22,7 @@ import ( "github.com/joho/godotenv" "github.com/knadh/koanf/v2" "github.com/rs/zerolog" + "github.com/rs/zerolog/log" "golang.org/x/crypto/ssh" "mvdan.cc/sh/v3/shell" ) @@ -99,7 +100,7 @@ func NewConfigOptions(configFilePath string, opts ...BackyOptionFunc) *ConfigOpt return b } -func injectEnvIntoSSH(envVarsToInject environmentVars, process *ssh.Session, opts *ConfigOpts, log zerolog.Logger) { +func injectEnvIntoSSH(envVarsToInject environmentVars, session *ssh.Session, opts *ConfigOpts, log zerolog.Logger) error { if envVarsToInject.file != "" { envPath, envPathErr := getFullPathWithHomeDir(envVarsToInject.file) if envPathErr != nil { @@ -113,31 +114,31 @@ func injectEnvIntoSSH(envVarsToInject environmentVars, process *ssh.Session, opt envMap, err := godotenv.Parse(file) if err != nil { - log.Error().Str("envFile", envPath).Err(err).Send() - goto errEnvFile + log.Fatal().Str("envFile", envPath).Err(err).Send() } for key, val := range envMap { - err = process.Setenv(key, getExternalConfigDirectiveValue(val, opts, AllowedExternalDirectiveVault)) + err = session.Setenv(key, getExternalConfigDirectiveValue(val, opts, AllowedExternalDirectiveVault)) if err != nil { - log.Error().Err(err).Send() - + log.Info().Err(err).Send() + return fmt.Errorf("failed to set environment variable %s: %w", val, err) } } } -errEnvFile: // fmt.Printf("%v", envVarsToInject.env) for _, envVal := range envVarsToInject.env { // don't append env Vars for Backy if strings.Contains(envVal, "=") { envVarArr := strings.Split(envVal, "=") - err := process.Setenv(envVarArr[0], getExternalConfigDirectiveValue(envVarArr[1], opts, AllowedExternalDirectiveVaultFile)) + err := session.Setenv(envVarArr[0], getExternalConfigDirectiveValue(envVarArr[1], opts, AllowedExternalDirectiveVaultFile)) if err != nil { - log.Error().Err(err).Send() + log.Info().Err(err).Send() + return fmt.Errorf("failed to set environment variable %s: %w", envVarArr[1], err) } } } + return nil } func injectEnvIntoLocalCMD(envVarsToInject environmentVars, process *exec.Cmd, log zerolog.Logger, opts *ConfigOpts) { @@ -171,6 +172,35 @@ errEnvFile: process.Env = append(process.Env, os.Environ()...) } +func prependEnvVarsToCommand(envVars environmentVars, opts *ConfigOpts, command string, args []string, cmdCtxLogger zerolog.Logger) string { + var envPrefix string + if envVars.file != "" { + envPath, envPathErr := getFullPathWithHomeDir(envVars.file) + if envPathErr != nil { + cmdCtxLogger.Fatal().Str("envFile", envPath).Err(envPathErr).Send() + } + file, err := os.Open(envPath) + if err != nil { + log.Fatal().Str("envFile", envPath).Err(err).Send() + } + defer file.Close() + + envMap, err := godotenv.Parse(file) + if err != nil { + log.Fatal().Str("envFile", envPath).Err(err).Send() + } + for key, val := range envMap { + envPrefix += fmt.Sprintf("%s=%s ", key, getExternalConfigDirectiveValue(val, opts, AllowedExternalDirectiveVaultEnv)) + } + } + for _, value := range envVars.env { + envVarArr := strings.Split(value, "=") + envPrefix += fmt.Sprintf("%s=%s ", envVarArr[0], getExternalConfigDirectiveValue(envVarArr[1], opts, AllowedExternalDirectiveVault)) + envPrefix += "\n" + } + return envPrefix + command + " " + strings.Join(args, " ") +} + func contains(s []string, e string) bool { for _, a := range s { if a == e { @@ -403,24 +433,22 @@ func getExternalConfigDirectiveValue(key string, opts *ConfigOpts, allowedDirect key = replaceVarInString(opts.Vars, key, opts.Logger) opts.Logger.Debug().Str("expanding external key", key).Send() - if strings.HasPrefix(key, envExternDirectiveStart) { + if newKeyStr, directiveFound := strings.CutPrefix(key, envExternDirectiveStart); directiveFound { if IsExternalDirectiveEnv(allowedDirectives) { - key = strings.TrimPrefix(key, envExternDirectiveStart) - key = strings.TrimSuffix(key, externDirectiveEnd) + key = strings.TrimSuffix(newKeyStr, externDirectiveEnd) key = os.Getenv(key) } else { opts.Logger.Error().Msgf("Config key with value %s does not support env directive", key) } } - if strings.HasPrefix(key, externFileDirectiveStart) { + if newKeyStr, directiveFound := strings.CutPrefix(key, externFileDirectiveStart); directiveFound { if IsExternalDirectiveFile(allowedDirectives) { var err error var keyValue []byte - key = strings.TrimPrefix(key, externFileDirectiveStart) - key = strings.TrimSuffix(key, externDirectiveEnd) + key = strings.TrimSuffix(newKeyStr, externDirectiveEnd) key, err = getFullPathWithHomeDir(key) if err != nil { opts.Logger.Err(err).Send() @@ -440,11 +468,10 @@ func getExternalConfigDirectiveValue(key string, opts *ConfigOpts, allowedDirect } } - if strings.HasPrefix(key, vaultExternDirectiveStart) { + if newKeyStr, directiveFound := strings.CutPrefix(key, vaultExternDirectiveStart); directiveFound { if IsExternalDirectiveVault(allowedDirectives) { - key = strings.TrimPrefix(key, vaultExternDirectiveStart) - key = strings.TrimSuffix(key, externDirectiveEnd) + key = strings.TrimSuffix(newKeyStr, externDirectiveEnd) key = GetVaultKey(key, opts, opts.Logger) } else { opts.Logger.Error().Msgf("Config key with value %s does not support vault directive", key) diff --git a/tests/files_test.go b/tests/files_test.go new file mode 100644 index 0000000..9b04225 --- /dev/null +++ b/tests/files_test.go @@ -0,0 +1,15 @@ +package tests + +import ( + "fmt" + "os/exec" + "testing" +) + +func TestRunCommandFileTest(t *testing.T) { + + filePath := "packageCommands.yml" + cmdLineStr := fmt.Sprintf("go run ../backy.go exec host -c checkDockerNoVersion -m localhost --cmdStdOut -f %s", filePath) + + exec.Command("bash", "-c", cmdLineStr).Output() +} diff --git a/tests/packageCommands.yml b/tests/packageCommands.yml index 30885df..6417770 100755 --- a/tests/packageCommands.yml +++ b/tests/packageCommands.yml @@ -2,6 +2,8 @@ commands: checkDockerNoVersion: type: package shell: zsh + environment: + - TEST_ENV=production packages: - name: "docker-ce-cli" - name: "docker-ce" diff --git a/tests/run_tests.sh b/tests/run_tests.sh new file mode 100644 index 0000000..6adf1ef --- /dev/null +++ b/tests/run_tests.sh @@ -0,0 +1,14 @@ +#!/bin/bash + +# This script runs all Go test files in the tests directory. + +echo "Running all tests in the tests directory..." + +go test ./tests/... -v + +if [ $? -eq 0 ]; then + echo "All tests passed successfully." +else + echo "Some tests failed. Check the output above for details." + exit 1 +fi