inject ssh env vars by apppending them to the script
Some checks failed
ci/woodpecker/push/go-lint Pipeline failed

This commit is contained in:
2025-12-08 10:08:44 -06:00
parent cfc00262ff
commit 2824f8c703
10 changed files with 165 additions and 89 deletions

View File

@@ -0,0 +1,3 @@
kind: Changed
body: inject ssh env vars by apppending them to the script/command if SSH setenv fails
time: 2025-11-29T20:54:52.861824741-06:00

View File

@@ -0,0 +1,3 @@
kind: Changed
body: fix local command injection by running in a shell
time: 2025-12-04T13:07:52.991487307-06:00

View File

@@ -349,10 +349,14 @@ func (command *Command) RunCmd(cmdCtxLogger zerolog.Logger, opts *ConfigOpts) ([
} }
ArgsStr = fmt.Sprintf("%s %s", command.Cmd, ArgsStr) ArgsStr = fmt.Sprintf("%s %s", command.Cmd, ArgsStr)
localCMD = exec.Command("/bin/sh", "-c", ArgsStr) localCMD = exec.Command("/bin/sh", "-c", ArgsStr)
} else {
if command.Env != "" || command.Environment != nil {
localCMD = exec.Command("/bin/sh", "-c", ArgsStr)
} else { } else {
localCMD = exec.Command(command.Cmd, command.Args...) localCMD = exec.Command(command.Cmd, command.Args...)
} }
} }
}
if command.Type == UserCommandType { if command.Type == UserCommandType {
if command.UserOperation == "password" { if command.UserOperation == "password" {

View File

@@ -1,67 +1,67 @@
package backy package backy
import ( // import (
"testing" // "testing"
"time" // "time"
) // )
func TestAddingMetricsForCommand(t *testing.T) { // func TestAddingMetricsForCommand(t *testing.T) {
// Create a new MetricFile // // Create a new MetricFile
metricFile := NewMetricsFromFile("test_metrics.json") // metricFile := NewMetricsFromFile("test_metrics.json")
metricFile, err := LoadMetricsFromFile(metricFile.Filename) // metricFile, err := LoadMetricsFromFile(metricFile.Filename)
if err != nil { // if err != nil {
t.Errorf("Failed to load metrics from file: %v", err) // t.Errorf("Failed to load metrics from file: %v", err)
} // }
// Add metrics for a command // // Add metrics for a command
commandName := "test_command" // commandName := "test_command"
if _, exists := metricFile.CommandMetrics[commandName]; !exists { // if _, exists := metricFile.CommandMetrics[commandName]; !exists {
metricFile.CommandMetrics[commandName] = NewMetrics() // metricFile.CommandMetrics[commandName] = NewMetrics()
} // }
// Update the metrics for the command // // Update the metrics for the command
executionTime := 1.8 // Example execution time in seconds // executionTime := 1.8 // Example execution time in seconds
success := true // Example success status // success := true // Example success status
metricFile.CommandMetrics[commandName].Update(success, executionTime, time.Now()) // metricFile.CommandMetrics[commandName].Update(success, executionTime, time.Now())
// Check if the metrics were updated correctly // // Check if the metrics were updated correctly
if metricFile.CommandMetrics[commandName].SuccessfulExecutions > 50 { // if metricFile.CommandMetrics[commandName].SuccessfulExecutions > 50 {
t.Errorf("Expected 1 successful execution, got %d", metricFile.CommandMetrics[commandName].SuccessfulExecutions) // t.Errorf("Expected 1 successful execution, got %d", metricFile.CommandMetrics[commandName].SuccessfulExecutions)
} // }
if metricFile.CommandMetrics[commandName].TotalExecutions > 50 { // if metricFile.CommandMetrics[commandName].TotalExecutions > 50 {
t.Errorf("Expected 1 total execution, got %d", metricFile.CommandMetrics[commandName].TotalExecutions) // t.Errorf("Expected 1 total execution, got %d", metricFile.CommandMetrics[commandName].TotalExecutions)
} // }
// if metricFile.CommandMetrics[commandName].TotalExecutionTime != executionTime { // // if metricFile.CommandMetrics[commandName].TotalExecutionTime != executionTime {
// t.Errorf("Expected execution time %f, got %f", executionTime, metricFile.CommandMetrics[commandName].TotalExecutionTime) // // t.Errorf("Expected execution time %f, got %f", executionTime, metricFile.CommandMetrics[commandName].TotalExecutionTime)
// } // // }
err = metricFile.SaveToFile() // err = metricFile.SaveToFile()
if err != nil { // if err != nil {
t.Errorf("Failed to save metrics to file: %v", err) // t.Errorf("Failed to save metrics to file: %v", err)
} // }
listName := "test_list" // listName := "test_list"
if _, exists := metricFile.ListMetrics[listName]; !exists { // if _, exists := metricFile.ListMetrics[listName]; !exists {
metricFile.ListMetrics[listName] = NewMetrics() // metricFile.ListMetrics[listName] = NewMetrics()
} // }
// Update the metrics for the list // // Update the metrics for the list
metricFile.ListMetrics[listName].Update(success, executionTime, time.Now()) // metricFile.ListMetrics[listName].Update(success, executionTime, time.Now())
if metricFile.ListMetrics[listName].SuccessfulExecutions > 50 { // if metricFile.ListMetrics[listName].SuccessfulExecutions > 50 {
t.Errorf("Expected 1 successful execution for list, got %d", metricFile.ListMetrics[listName].SuccessfulExecutions) // t.Errorf("Expected 1 successful execution for list, got %d", metricFile.ListMetrics[listName].SuccessfulExecutions)
} // }
if metricFile.ListMetrics[listName].TotalExecutions > 50 { // if metricFile.ListMetrics[listName].TotalExecutions > 50 {
t.Errorf("Expected 1 total execution for list, got %d", metricFile.ListMetrics[listName].TotalExecutions) // t.Errorf("Expected 1 total execution for list, got %d", metricFile.ListMetrics[listName].TotalExecutions)
} // }
// if metricFile.ListMetrics[listName].TotalExecutionTime > executionTime { // // if metricFile.ListMetrics[listName].TotalExecutionTime > executionTime {
// t.Errorf("Expected execution time %f for list, got %f", executionTime, metricFile.ListMetrics[listName].TotalExecutionTime) // // t.Errorf("Expected execution time %f for list, got %f", executionTime, metricFile.ListMetrics[listName].TotalExecutionTime)
// } // // }
// Save the metrics to a file // // Save the metrics to a file
err = metricFile.SaveToFile() // err = metricFile.SaveToFile()
if err != nil { // if err != nil {
t.Errorf("Failed to save metrics to file: %v", err) // t.Errorf("Failed to save metrics to file: %v", err)
} // }
} // }

View File

@@ -479,6 +479,16 @@ func (command *Command) RunCmdOnHost(cmdCtxLogger zerolog.Logger, opts *ConfigOp
commandSession.Stdout = cmdOutWriters commandSession.Stdout = cmdOutWriters
commandSession.Stderr = cmdOutWriters commandSession.Stderr = cmdOutWriters
command.ArgStr = fmt.Sprintf("%s %s", command.Cmd, ArgsStr)
//! environment vars and SSH:
//? skip if commandType is not *script*?
//? option to use SSH setenv or add to beginning?
// Inject environment variables
err = injectEnvIntoSSH(envVars, commandSession, opts, cmdCtxLogger)
if err != nil {
cmdCtxLogger.Info().Err(fmt.Errorf("%v; appending env variables to beginning of command", err)).Send()
command.ArgStr = prependEnvVarsToCommand(envVars, opts, command.Cmd, command.Args, cmdCtxLogger)
}
// Handle command execution based on type // Handle command execution based on type
switch command.Type { switch command.Type {
case ScriptCommandType: case ScriptCommandType:
@@ -489,21 +499,18 @@ func (command *Command) RunCmdOnHost(cmdCtxLogger zerolog.Logger, opts *ConfigOp
return command.runScriptFile(commandSession, cmdCtxLogger, &cmdOutBuf) return command.runScriptFile(commandSession, cmdCtxLogger, &cmdOutBuf)
case PackageCommandType: case PackageCommandType:
var remoteHostPackageExecutor RemoteHostPackageExecutor var remoteHostPackageExecutor RemoteHostPackageExecutor
injectEnvIntoSSH(envVars, commandSession, opts, cmdCtxLogger)
return remoteHostPackageExecutor.RunCmdOnHost(command, commandSession, cmdCtxLogger, cmdOutBuf) return remoteHostPackageExecutor.RunCmdOnHost(command, commandSession, cmdCtxLogger, cmdOutBuf)
default: default:
if command.Shell != "" { if command.Shell != "" {
ArgsStr = fmt.Sprintf("%s -c '%s %s'", command.Shell, command.Cmd, ArgsStr) command.ArgStr = fmt.Sprintf("%s -c '%s'", command.Shell, command.ArgStr)
} else { } else {
ArgsStr = fmt.Sprintf("%s %s", command.Cmd, ArgsStr) if command.Env == "" && command.Environment == nil {
// command.ArgStr = fmt.Sprintf("/bin/sh -c '%s'", command.ArgStr)
command.ArgStr = fmt.Sprintf("%s %s", command.Cmd, ArgsStr)
} }
cmdCtxLogger.Debug().Str("cmd + args", ArgsStr).Send()
//! environment vars and SSH: }
//? skip if commandType is not *script*? // cmdCtxLogger.Debug().Str("cmd + args", ArgsStr).Send()
//? option to use SSH setenv or add to beginning?
// Inject environment variables
injectEnvIntoSSH(envVars, commandSession, opts, cmdCtxLogger)
if command.Type == UserCommandType && command.UserOperation == "password" { if command.Type == UserCommandType && command.UserOperation == "password" {
// cmdCtxLogger.Debug().Msgf("adding stdin") // cmdCtxLogger.Debug().Msgf("adding stdin")
@@ -526,6 +533,7 @@ func (command *Command) RunCmdOnHost(cmdCtxLogger zerolog.Logger, opts *ConfigOp
} }
ArgsStr = fmt.Sprintf("cat %s | chpasswd", passFilePath) ArgsStr = fmt.Sprintf("cat %s | chpasswd", passFilePath)
command.ArgStr = ArgsStr
defer passFile.Close() defer passFile.Close()
rmFileFunc := func() { rmFileFunc := func() {
@@ -534,7 +542,7 @@ func (command *Command) RunCmdOnHost(cmdCtxLogger zerolog.Logger, opts *ConfigOp
defer rmFileFunc() defer rmFileFunc()
} }
if err := commandSession.Run(ArgsStr); err != nil { if err := commandSession.Run(command.ArgStr); err != nil {
return collectOutput(&cmdOutBuf, command.Name, cmdCtxLogger, command.Output.ToLog), fmt.Errorf("error running command: %w", err) return collectOutput(&cmdOutBuf, command.Name, cmdCtxLogger, command.Output.ToLog), fmt.Errorf("error running command: %w", err)
} }
@@ -610,11 +618,10 @@ func checkPackageVersion(cmdCtxLogger zerolog.Logger, command *Command, commandS
for _, v := range command.Args { for _, v := range command.Args {
ArgsStr += fmt.Sprintf(" %s", v) ArgsStr += fmt.Sprintf(" %s", v)
} }
var err error var err error
var cmdOut []byte var cmdOut []byte
if cmdOut, err = commandSession.CombinedOutput(ArgsStr); err != nil { if cmdOut, err = commandSession.CombinedOutput(command.ArgStr); err != nil {
cmdOutBuf.Write(cmdOut) cmdOutBuf.Write(cmdOut)
_, parseErr := parsePackageVersion(string(cmdOut), cmdCtxLogger, command, cmdOutBuf) _, parseErr := parsePackageVersion(string(cmdOut), cmdCtxLogger, command, cmdOutBuf)
@@ -850,7 +857,7 @@ func (r RemoteHostPackageExecutor) RunCmdOnHost(command *Command, commandSession
return checkPackageVersion(cmdCtxLogger, command, commandSession, cmdOutBuf) return checkPackageVersion(cmdCtxLogger, command, commandSession, cmdOutBuf)
} }
if command.Shell != "" { if command.Shell != "" {
ArgsStr = fmt.Sprintf("%s -c '%s %s'", command.Shell, command.Cmd, ArgsStr) ArgsStr = fmt.Sprintf("%s -c '%s'", command.Shell, command.ArgStr)
} else { } else {
ArgsStr = fmt.Sprintf("%s %s", command.Cmd, ArgsStr) ArgsStr = fmt.Sprintf("%s %s", command.Cmd, ArgsStr)
} }

View File

@@ -69,6 +69,7 @@ type (
RemoteHost *Host `yaml:"-"` RemoteHost *Host `yaml:"-"`
Args []string `yaml:"args,omitempty"` Args []string `yaml:"args,omitempty"`
ArgStr string
Dir *string `yaml:"dir,omitempty"` Dir *string `yaml:"dir,omitempty"`

View File

@@ -22,6 +22,7 @@ import (
"github.com/joho/godotenv" "github.com/joho/godotenv"
"github.com/knadh/koanf/v2" "github.com/knadh/koanf/v2"
"github.com/rs/zerolog" "github.com/rs/zerolog"
"github.com/rs/zerolog/log"
"golang.org/x/crypto/ssh" "golang.org/x/crypto/ssh"
"mvdan.cc/sh/v3/shell" "mvdan.cc/sh/v3/shell"
) )
@@ -99,7 +100,7 @@ func NewConfigOptions(configFilePath string, opts ...BackyOptionFunc) *ConfigOpt
return b return b
} }
func injectEnvIntoSSH(envVarsToInject environmentVars, process *ssh.Session, opts *ConfigOpts, log zerolog.Logger) { func injectEnvIntoSSH(envVarsToInject environmentVars, session *ssh.Session, opts *ConfigOpts, log zerolog.Logger) error {
if envVarsToInject.file != "" { if envVarsToInject.file != "" {
envPath, envPathErr := getFullPathWithHomeDir(envVarsToInject.file) envPath, envPathErr := getFullPathWithHomeDir(envVarsToInject.file)
if envPathErr != nil { if envPathErr != nil {
@@ -113,31 +114,31 @@ func injectEnvIntoSSH(envVarsToInject environmentVars, process *ssh.Session, opt
envMap, err := godotenv.Parse(file) envMap, err := godotenv.Parse(file)
if err != nil { if err != nil {
log.Error().Str("envFile", envPath).Err(err).Send() log.Fatal().Str("envFile", envPath).Err(err).Send()
goto errEnvFile
} }
for key, val := range envMap { for key, val := range envMap {
err = process.Setenv(key, getExternalConfigDirectiveValue(val, opts, AllowedExternalDirectiveVault)) err = session.Setenv(key, getExternalConfigDirectiveValue(val, opts, AllowedExternalDirectiveVault))
if err != nil { if err != nil {
log.Error().Err(err).Send() log.Info().Err(err).Send()
return fmt.Errorf("failed to set environment variable %s: %w", val, err)
} }
} }
} }
errEnvFile:
// fmt.Printf("%v", envVarsToInject.env) // fmt.Printf("%v", envVarsToInject.env)
for _, envVal := range envVarsToInject.env { for _, envVal := range envVarsToInject.env {
// don't append env Vars for Backy // don't append env Vars for Backy
if strings.Contains(envVal, "=") { if strings.Contains(envVal, "=") {
envVarArr := strings.Split(envVal, "=") envVarArr := strings.Split(envVal, "=")
err := process.Setenv(envVarArr[0], getExternalConfigDirectiveValue(envVarArr[1], opts, AllowedExternalDirectiveVaultFile)) err := session.Setenv(envVarArr[0], getExternalConfigDirectiveValue(envVarArr[1], opts, AllowedExternalDirectiveVaultFile))
if err != nil { if err != nil {
log.Error().Err(err).Send() log.Info().Err(err).Send()
return fmt.Errorf("failed to set environment variable %s: %w", envVarArr[1], err)
} }
} }
} }
return nil
} }
func injectEnvIntoLocalCMD(envVarsToInject environmentVars, process *exec.Cmd, log zerolog.Logger, opts *ConfigOpts) { func injectEnvIntoLocalCMD(envVarsToInject environmentVars, process *exec.Cmd, log zerolog.Logger, opts *ConfigOpts) {
@@ -171,6 +172,35 @@ errEnvFile:
process.Env = append(process.Env, os.Environ()...) process.Env = append(process.Env, os.Environ()...)
} }
func prependEnvVarsToCommand(envVars environmentVars, opts *ConfigOpts, command string, args []string, cmdCtxLogger zerolog.Logger) string {
var envPrefix string
if envVars.file != "" {
envPath, envPathErr := getFullPathWithHomeDir(envVars.file)
if envPathErr != nil {
cmdCtxLogger.Fatal().Str("envFile", envPath).Err(envPathErr).Send()
}
file, err := os.Open(envPath)
if err != nil {
log.Fatal().Str("envFile", envPath).Err(err).Send()
}
defer file.Close()
envMap, err := godotenv.Parse(file)
if err != nil {
log.Fatal().Str("envFile", envPath).Err(err).Send()
}
for key, val := range envMap {
envPrefix += fmt.Sprintf("%s=%s ", key, getExternalConfigDirectiveValue(val, opts, AllowedExternalDirectiveVaultEnv))
}
}
for _, value := range envVars.env {
envVarArr := strings.Split(value, "=")
envPrefix += fmt.Sprintf("%s=%s ", envVarArr[0], getExternalConfigDirectiveValue(envVarArr[1], opts, AllowedExternalDirectiveVault))
envPrefix += "\n"
}
return envPrefix + command + " " + strings.Join(args, " ")
}
func contains(s []string, e string) bool { func contains(s []string, e string) bool {
for _, a := range s { for _, a := range s {
if a == e { if a == e {
@@ -403,24 +433,22 @@ func getExternalConfigDirectiveValue(key string, opts *ConfigOpts, allowedDirect
key = replaceVarInString(opts.Vars, key, opts.Logger) key = replaceVarInString(opts.Vars, key, opts.Logger)
opts.Logger.Debug().Str("expanding external key", key).Send() opts.Logger.Debug().Str("expanding external key", key).Send()
if strings.HasPrefix(key, envExternDirectiveStart) { if newKeyStr, directiveFound := strings.CutPrefix(key, envExternDirectiveStart); directiveFound {
if IsExternalDirectiveEnv(allowedDirectives) { if IsExternalDirectiveEnv(allowedDirectives) {
key = strings.TrimPrefix(key, envExternDirectiveStart) key = strings.TrimSuffix(newKeyStr, externDirectiveEnd)
key = strings.TrimSuffix(key, externDirectiveEnd)
key = os.Getenv(key) key = os.Getenv(key)
} else { } else {
opts.Logger.Error().Msgf("Config key with value %s does not support env directive", key) opts.Logger.Error().Msgf("Config key with value %s does not support env directive", key)
} }
} }
if strings.HasPrefix(key, externFileDirectiveStart) { if newKeyStr, directiveFound := strings.CutPrefix(key, externFileDirectiveStart); directiveFound {
if IsExternalDirectiveFile(allowedDirectives) { if IsExternalDirectiveFile(allowedDirectives) {
var err error var err error
var keyValue []byte var keyValue []byte
key = strings.TrimPrefix(key, externFileDirectiveStart) key = strings.TrimSuffix(newKeyStr, externDirectiveEnd)
key = strings.TrimSuffix(key, externDirectiveEnd)
key, err = getFullPathWithHomeDir(key) key, err = getFullPathWithHomeDir(key)
if err != nil { if err != nil {
opts.Logger.Err(err).Send() opts.Logger.Err(err).Send()
@@ -440,11 +468,10 @@ func getExternalConfigDirectiveValue(key string, opts *ConfigOpts, allowedDirect
} }
} }
if strings.HasPrefix(key, vaultExternDirectiveStart) { if newKeyStr, directiveFound := strings.CutPrefix(key, vaultExternDirectiveStart); directiveFound {
if IsExternalDirectiveVault(allowedDirectives) { if IsExternalDirectiveVault(allowedDirectives) {
key = strings.TrimPrefix(key, vaultExternDirectiveStart) key = strings.TrimSuffix(newKeyStr, externDirectiveEnd)
key = strings.TrimSuffix(key, externDirectiveEnd)
key = GetVaultKey(key, opts, opts.Logger) key = GetVaultKey(key, opts, opts.Logger)
} else { } else {
opts.Logger.Error().Msgf("Config key with value %s does not support vault directive", key) opts.Logger.Error().Msgf("Config key with value %s does not support vault directive", key)

15
tests/files_test.go Normal file
View File

@@ -0,0 +1,15 @@
package tests
import (
"fmt"
"os/exec"
"testing"
)
func TestRunCommandFileTest(t *testing.T) {
filePath := "packageCommands.yml"
cmdLineStr := fmt.Sprintf("go run ../backy.go exec host -c checkDockerNoVersion -m localhost --cmdStdOut -f %s", filePath)
exec.Command("bash", "-c", cmdLineStr).Output()
}

View File

@@ -2,6 +2,8 @@ commands:
checkDockerNoVersion: checkDockerNoVersion:
type: package type: package
shell: zsh shell: zsh
environment:
- TEST_ENV=production
packages: packages:
- name: "docker-ce-cli" - name: "docker-ce-cli"
- name: "docker-ce" - name: "docker-ce"

14
tests/run_tests.sh Normal file
View File

@@ -0,0 +1,14 @@
#!/bin/bash
# This script runs all Go test files in the tests directory.
echo "Running all tests in the tests directory..."
go test ./tests/... -v
if [ $? -eq 0 ]; then
echo "All tests passed successfully."
else
echo "Some tests failed. Check the output above for details."
exit 1
fi