tests: beginning of tests using Docker

This commit is contained in:
2025-07-04 09:02:27 -05:00
parent 7be2679b91
commit 305b504ca1
52 changed files with 1423 additions and 521 deletions

View File

@ -8,11 +8,11 @@ import (
"strings"
)
const _AllowedExternalDirectivesName = "DefaultExternalDirvaultvault-filevault-file-envfile-envfileenv"
const _AllowedExternalDirectivesName = "DefaultExternalDirvaultvault-envvault-filevault-file-envfile-envfileenv"
var _AllowedExternalDirectivesIndex = [...]uint8{0, 18, 23, 33, 47, 55, 59, 62}
var _AllowedExternalDirectivesIndex = [...]uint8{0, 18, 23, 32, 42, 56, 64, 68, 71}
const _AllowedExternalDirectivesLowerName = "defaultexternaldirvaultvault-filevault-file-envfile-envfileenv"
const _AllowedExternalDirectivesLowerName = "defaultexternaldirvaultvault-envvault-filevault-file-envfile-envfileenv"
func (i AllowedExternalDirectives) String() string {
if i < 0 || i >= AllowedExternalDirectives(len(_AllowedExternalDirectivesIndex)-1) {
@ -27,40 +27,44 @@ func _AllowedExternalDirectivesNoOp() {
var x [1]struct{}
_ = x[DefaultExternalDir-(0)]
_ = x[AllowedExternalDirectiveVault-(1)]
_ = x[AllowedExternalDirectiveVaultFile-(2)]
_ = x[AllowedExternalDirectiveAll-(3)]
_ = x[AllowedExternalDirectiveFileEnv-(4)]
_ = x[AllowedExternalDirectiveFile-(5)]
_ = x[AllowedExternalDirectiveEnv-(6)]
_ = x[AllowedExternalDirectiveVaultEnv-(2)]
_ = x[AllowedExternalDirectiveVaultFile-(3)]
_ = x[AllowedExternalDirectiveAll-(4)]
_ = x[AllowedExternalDirectiveFileEnv-(5)]
_ = x[AllowedExternalDirectiveFile-(6)]
_ = x[AllowedExternalDirectiveEnv-(7)]
}
var _AllowedExternalDirectivesValues = []AllowedExternalDirectives{DefaultExternalDir, AllowedExternalDirectiveVault, AllowedExternalDirectiveVaultFile, AllowedExternalDirectiveAll, AllowedExternalDirectiveFileEnv, AllowedExternalDirectiveFile, AllowedExternalDirectiveEnv}
var _AllowedExternalDirectivesValues = []AllowedExternalDirectives{DefaultExternalDir, AllowedExternalDirectiveVault, AllowedExternalDirectiveVaultEnv, AllowedExternalDirectiveVaultFile, AllowedExternalDirectiveAll, AllowedExternalDirectiveFileEnv, AllowedExternalDirectiveFile, AllowedExternalDirectiveEnv}
var _AllowedExternalDirectivesNameToValueMap = map[string]AllowedExternalDirectives{
_AllowedExternalDirectivesName[0:18]: DefaultExternalDir,
_AllowedExternalDirectivesLowerName[0:18]: DefaultExternalDir,
_AllowedExternalDirectivesName[18:23]: AllowedExternalDirectiveVault,
_AllowedExternalDirectivesLowerName[18:23]: AllowedExternalDirectiveVault,
_AllowedExternalDirectivesName[23:33]: AllowedExternalDirectiveVaultFile,
_AllowedExternalDirectivesLowerName[23:33]: AllowedExternalDirectiveVaultFile,
_AllowedExternalDirectivesName[33:47]: AllowedExternalDirectiveAll,
_AllowedExternalDirectivesLowerName[33:47]: AllowedExternalDirectiveAll,
_AllowedExternalDirectivesName[47:55]: AllowedExternalDirectiveFileEnv,
_AllowedExternalDirectivesLowerName[47:55]: AllowedExternalDirectiveFileEnv,
_AllowedExternalDirectivesName[55:59]: AllowedExternalDirectiveFile,
_AllowedExternalDirectivesLowerName[55:59]: AllowedExternalDirectiveFile,
_AllowedExternalDirectivesName[59:62]: AllowedExternalDirectiveEnv,
_AllowedExternalDirectivesLowerName[59:62]: AllowedExternalDirectiveEnv,
_AllowedExternalDirectivesName[23:32]: AllowedExternalDirectiveVaultEnv,
_AllowedExternalDirectivesLowerName[23:32]: AllowedExternalDirectiveVaultEnv,
_AllowedExternalDirectivesName[32:42]: AllowedExternalDirectiveVaultFile,
_AllowedExternalDirectivesLowerName[32:42]: AllowedExternalDirectiveVaultFile,
_AllowedExternalDirectivesName[42:56]: AllowedExternalDirectiveAll,
_AllowedExternalDirectivesLowerName[42:56]: AllowedExternalDirectiveAll,
_AllowedExternalDirectivesName[56:64]: AllowedExternalDirectiveFileEnv,
_AllowedExternalDirectivesLowerName[56:64]: AllowedExternalDirectiveFileEnv,
_AllowedExternalDirectivesName[64:68]: AllowedExternalDirectiveFile,
_AllowedExternalDirectivesLowerName[64:68]: AllowedExternalDirectiveFile,
_AllowedExternalDirectivesName[68:71]: AllowedExternalDirectiveEnv,
_AllowedExternalDirectivesLowerName[68:71]: AllowedExternalDirectiveEnv,
}
var _AllowedExternalDirectivesNames = []string{
_AllowedExternalDirectivesName[0:18],
_AllowedExternalDirectivesName[18:23],
_AllowedExternalDirectivesName[23:33],
_AllowedExternalDirectivesName[33:47],
_AllowedExternalDirectivesName[47:55],
_AllowedExternalDirectivesName[55:59],
_AllowedExternalDirectivesName[59:62],
_AllowedExternalDirectivesName[23:32],
_AllowedExternalDirectivesName[32:42],
_AllowedExternalDirectivesName[42:56],
_AllowedExternalDirectivesName[56:64],
_AllowedExternalDirectivesName[64:68],
_AllowedExternalDirectivesName[68:71],
}
// AllowedExternalDirectivesString retrieves an enum value from the enum constants string name.

View File

@ -35,7 +35,7 @@ var Sprintf = fmt.Sprintf
func (command *Command) RunCmd(cmdCtxLogger zerolog.Logger, opts *ConfigOpts) ([]string, error) {
var (
ArgsStr string // concatenating the arguments
ArgsStr string
cmdOutBuf bytes.Buffer
cmdOutWriters io.Writer
errSSH error
@ -55,7 +55,7 @@ func (command *Command) RunCmd(cmdCtxLogger zerolog.Logger, opts *ConfigOpts) ([
ArgsStr += fmt.Sprintf(" %s", v)
}
if command.Type == UserCT {
if command.Type == UserCommandType {
if command.UserOperation == "password" {
cmdCtxLogger.Info().Str("password", command.UserPassword).Msg("user password to be updated")
}
@ -63,19 +63,27 @@ func (command *Command) RunCmd(cmdCtxLogger zerolog.Logger, opts *ConfigOpts) ([
if !IsHostLocal(command.Host) {
outputArr, errSSH = command.RunCmdSSH(cmdCtxLogger, opts)
outputArr, errSSH = command.RunCmdOnHost(cmdCtxLogger, opts)
if errSSH != nil {
return outputArr, errSSH
}
} else {
// Handle package operations
if command.Type == PackageCT && command.PackageOperation == PackOpCheckVersion {
cmdCtxLogger.Info().Str("package", command.PackageName).Msg("Checking package versions")
if command.Type == PackageCommandType && command.PackageOperation == PackageOperationCheckVersion {
opts.Logger.Info().Msg("")
for _, p := range command.Packages {
cmdCtxLogger.Info().Str("package", p.Name).Msg("Checking installed and remote package versions")
}
opts.Logger.Info().Msg("")
// Execute the package version command
cmd := exec.Command(command.Cmd, command.Args...)
cmdOutWriters = io.MultiWriter(&cmdOutBuf)
if IsCmdStdOutEnabled() {
cmdOutWriters = io.MultiWriter(os.Stdout, &cmdOutBuf)
}
cmd.Stdout = cmdOutWriters
cmd.Stderr = cmdOutWriters
@ -87,7 +95,8 @@ func (command *Command) RunCmd(cmdCtxLogger zerolog.Logger, opts *ConfigOpts) ([
}
var localCMD *exec.Cmd
if command.Type == RemoteScriptCT {
if command.Type == RemoteScriptCommandType {
script, err := command.Fetcher.Fetch(command.Cmd)
if err != nil {
return nil, err
@ -104,8 +113,8 @@ func (command *Command) RunCmd(cmdCtxLogger zerolog.Logger, opts *ConfigOpts) ([
if IsCmdStdOutEnabled() {
cmdOutWriters = io.MultiWriter(os.Stdout, &cmdOutBuf)
}
if command.OutputFile != "" {
file, err := os.Create(command.OutputFile)
if command.Output.File != "" {
file, err := os.Create(command.Output.File)
if err != nil {
return nil, fmt.Errorf("error creating output file: %w", err)
}
@ -138,7 +147,7 @@ func (command *Command) RunCmd(cmdCtxLogger zerolog.Logger, opts *ConfigOpts) ([
if str, ok := outMap["output"].(string); ok {
outputArr = append(outputArr, str)
}
if command.OutputToLog {
if command.Output.ToLog {
cmdCtxLogger.Info().Fields(outMap).Send()
}
}
@ -159,8 +168,10 @@ func (command *Command) RunCmd(cmdCtxLogger zerolog.Logger, opts *ConfigOpts) ([
cmdCtxLogger.Info().Str("Command", fmt.Sprintf("Running command %s on local machine", command.Name)).Send()
// execute package commands in a shell
if command.Type == PackageCT {
cmdCtxLogger.Info().Str("package", command.PackageName).Msg("Executing package command")
if command.Type == PackageCommandType {
for _, p := range command.Packages {
cmdCtxLogger.Info().Str("packages", p.Name).Msg("Executing package command")
}
ArgsStr = fmt.Sprintf("%s %s", command.Cmd, ArgsStr)
localCMD = exec.Command("/bin/sh", "-c", ArgsStr)
} else {
@ -168,7 +179,7 @@ func (command *Command) RunCmd(cmdCtxLogger zerolog.Logger, opts *ConfigOpts) ([
}
}
if command.Type == UserCT {
if command.Type == UserCommandType {
if command.UserOperation == "password" {
localCMD.Stdin = command.stdin
cmdCtxLogger.Info().Str("password", command.UserPassword).Msg("user password to be updated")
@ -197,14 +208,14 @@ func (command *Command) RunCmd(cmdCtxLogger zerolog.Logger, opts *ConfigOpts) ([
return outputArr, err
}
if command.Type == UserCT {
if command.Type == UserCommandType {
if command.UserOperation == "add" {
if command.UserSshPubKeys != nil {
var (
f *os.File
err error
userHome []byte
authorizedKeysFile *os.File
err error
userHome []byte
)
cmdCtxLogger.Info().Msg("adding SSH Keys")
@ -212,7 +223,7 @@ func (command *Command) RunCmd(cmdCtxLogger zerolog.Logger, opts *ConfigOpts) ([
localCMD := exec.Command(fmt.Sprintf("grep \"%s\" /etc/passwd | cut -d: -f6", command.Username))
userHome, err = localCMD.CombinedOutput()
if err != nil {
return collectOutput(&cmdOutBuf, command.Name, cmdCtxLogger, command.OutputToLog), fmt.Errorf("error finding user home from /etc/passwd: %v", err)
return collectOutput(&cmdOutBuf, command.Name, cmdCtxLogger, command.Output.ToLog), fmt.Errorf("error finding user home from /etc/passwd: %v", err)
}
command.UserHome = strings.TrimSpace(string(userHome))
@ -221,33 +232,33 @@ func (command *Command) RunCmd(cmdCtxLogger zerolog.Logger, opts *ConfigOpts) ([
if _, err := os.Stat(userSshDir); os.IsNotExist(err) {
err := os.MkdirAll(userSshDir, 0700)
if err != nil {
return collectOutput(&cmdOutBuf, command.Name, cmdCtxLogger, command.OutputToLog), fmt.Errorf("error creating directory %s %v", userSshDir, err)
return collectOutput(&cmdOutBuf, command.Name, cmdCtxLogger, command.Output.ToLog), fmt.Errorf("error creating directory %s %v", userSshDir, err)
}
}
if _, err := os.Stat(fmt.Sprintf("%s/authorized_keys", userSshDir)); os.IsNotExist(err) {
_, err := os.Create(fmt.Sprintf("%s/authorized_keys", userSshDir))
if err != nil {
return collectOutput(&cmdOutBuf, command.Name, cmdCtxLogger, command.OutputToLog), fmt.Errorf("error creating file %s/authorized_keys: %v", userSshDir, err)
return collectOutput(&cmdOutBuf, command.Name, cmdCtxLogger, command.Output.ToLog), fmt.Errorf("error creating file %s/authorized_keys: %v", userSshDir, err)
}
}
f, err = os.OpenFile(fmt.Sprintf("%s/authorized_keys", userSshDir), 0700, os.ModeAppend)
authorizedKeysFile, err = os.OpenFile(fmt.Sprintf("%s/authorized_keys", userSshDir), 0700, os.ModeAppend)
if err != nil {
return collectOutput(&cmdOutBuf, command.Name, cmdCtxLogger, command.OutputToLog), fmt.Errorf("error opening file %s/authorized_keys: %v", userSshDir, err)
return collectOutput(&cmdOutBuf, command.Name, cmdCtxLogger, command.Output.ToLog), fmt.Errorf("error opening file %s/authorized_keys: %v", userSshDir, err)
}
defer f.Close()
defer authorizedKeysFile.Close()
for _, k := range command.UserSshPubKeys {
buf := bytes.NewBufferString(k)
cmdCtxLogger.Info().Str("key", k).Msg("adding SSH key")
if _, err := f.ReadFrom(buf); err != nil {
return collectOutput(&cmdOutBuf, command.Name, cmdCtxLogger, command.OutputToLog), fmt.Errorf("error adding to authorized keys: %v", err)
if _, err := authorizedKeysFile.ReadFrom(buf); err != nil {
return collectOutput(&cmdOutBuf, command.Name, cmdCtxLogger, command.Output.ToLog), fmt.Errorf("error adding to authorized keys: %v", err)
}
}
localCMD = exec.Command(fmt.Sprintf("chown -R %s:%s %s", command.Username, command.Username, userHome))
_, err = localCMD.CombinedOutput()
if err != nil {
return collectOutput(&cmdOutBuf, command.Name, cmdCtxLogger, command.OutputToLog), err
return collectOutput(&cmdOutBuf, command.Name, cmdCtxLogger, command.Output.ToLog), err
}
}
@ -257,16 +268,18 @@ func (command *Command) RunCmd(cmdCtxLogger zerolog.Logger, opts *ConfigOpts) ([
return outputArr, nil
}
func cmdListWorker(msgTemps *msgTemplates, jobs <-chan *CmdList, results chan<- CmdResult, opts *ConfigOpts) {
func cmdListWorker(msgTemps *msgTemplates, jobs <-chan *CmdList, results chan<- string, opts *ConfigOpts) {
for list := range jobs {
fieldsMap := map[string]interface{}{"list": list.Name}
var cmdLogger zerolog.Logger
var commandExecuted *Command
var cmdsRan []string
var outStructArr []outStruct
var hasError bool // Tracks if any command in the list failed
for _, cmd := range list.Order {
cmdToRun := opts.Cmds[cmd]
commandExecuted = cmdToRun
currentCmd := cmdToRun.Name
fieldsMap["cmd"] = currentCmd
cmdLogger = cmdToRun.GenerateLogger(opts)
@ -277,23 +290,21 @@ func cmdListWorker(msgTemps *msgTemplates, jobs <-chan *CmdList, results chan<-
if runErr != nil {
// Log the error and send a failed result
cmdLogger.Err(runErr).Send()
results <- CmdResult{CmdName: cmd, ListName: list.Name, Error: runErr}
// Execute error hooks for the failed command
cmdToRun.ExecuteHooks("error", opts)
// Notify failure
if list.NotifyConfig != nil {
notifyError(cmdLogger, msgTemps, list, cmdsRan, outStructArr, runErr, cmdToRun)
}
// Execute error hooks for the failed command
hasError = true
break
}
// Collect output if required
if list.GetOutput || cmdToRun.GetOutputInList {
if list.GetCommandOutputInNotificationsOnSuccess || cmdToRun.Output.InList {
outStructArr = append(outStructArr, outStruct{
CmdName: currentCmd,
CmdExecuted: currentCmd,
@ -302,27 +313,17 @@ func cmdListWorker(msgTemps *msgTemplates, jobs <-chan *CmdList, results chan<-
}
}
if !hasError && list.NotifyConfig != nil && (list.NotifyOnSuccess || list.GetOutput) {
if !hasError && list.NotifyConfig != nil && list.Notify.OnFailure {
notifySuccess(cmdLogger, msgTemps, list, cmdsRan, outStructArr)
}
for _, cmd := range list.Order {
cmdToRun := opts.Cmds[cmd]
if !hasError {
cmdToRun.ExecuteHooks("success", opts)
}
// Execute final hooks for every command
cmdToRun.ExecuteHooks("final", opts)
if !hasError {
commandExecuted.ExecuteHooks("success", opts)
}
// Send the final result for the list
if hasError {
results <- CmdResult{CmdName: cmdsRan[len(cmdsRan)-1], ListName: list.Name, Error: fmt.Errorf("list execution failed")}
} else {
results <- CmdResult{CmdName: cmdsRan[len(cmdsRan)-1], ListName: list.Name, Error: nil}
}
commandExecuted.ExecuteHooks("final", opts)
results <- "done"
}
}
@ -371,7 +372,7 @@ func (opts *ConfigOpts) RunListConfig(cron string) {
}
configListsLen := len(opts.CmdConfigLists)
listChan := make(chan *CmdList, configListsLen)
results := make(chan CmdResult, configListsLen)
results := make(chan string, configListsLen)
// Start workers
for w := 1; w <= configListsLen; w++ {
@ -391,9 +392,7 @@ func (opts *ConfigOpts) RunListConfig(cron string) {
// Process results
for a := 1; a <= configListsLen; a++ {
result := <-results
opts.Logger.Debug().Msgf("Processing result for list %s, command %s", result.ListName, result.CmdName)
<-results
}
opts.closeHostConnections()
}
@ -460,29 +459,31 @@ func (cmd *Command) ExecuteHooks(hookType string, opts *ConfigOpts) {
case "error":
for _, v := range cmd.Hooks.Error {
errCmd := opts.Cmds[v]
opts.Logger.Info().Msgf("Running error hook command %s", v)
cmdLogger := opts.Logger.With().
Str("backy-cmd", v).Str("hookType", "error").
Logger()
cmdLogger.Info().Msgf("Running error hook command %s", v)
// URGENT: Never returns
_, _ = errCmd.RunCmd(cmdLogger, opts)
return
}
case "success":
for _, v := range cmd.Hooks.Success {
successCmd := opts.Cmds[v]
opts.Logger.Info().Msgf("Running success hook command %s", v)
cmdLogger := opts.Logger.With().
Str("backy-cmd", v).Str("hookType", "success").
Logger()
cmdLogger.Info().Msgf("Running success hook command %s", v)
_, _ = successCmd.RunCmd(cmdLogger, opts)
}
case "final":
for _, v := range cmd.Hooks.Final {
finalCmd := opts.Cmds[v]
opts.Logger.Info().Msgf("Running final hook command %s", v)
cmdLogger := opts.Logger.With().
Str("backy-cmd", v).Str("hookType", "final").
Logger()
cmdLogger.Info().Msgf("Running final hook command %s", v)
_, _ = finalCmd.RunCmd(cmdLogger, opts)
}
}
@ -501,18 +502,27 @@ func (cmd *Command) GenerateLogger(opts *ConfigOpts) zerolog.Logger {
return cmdLogger
}
func (opts *ConfigOpts) ExecCmdsSSH(cmdList []string, hostsList []string) {
func (opts *ConfigOpts) ExecCmdsOnHosts(cmdList []string, hostsList []string) {
// Iterate over hosts and exec commands
for _, h := range hostsList {
host := opts.Hosts[h]
for _, c := range cmdList {
cmd := opts.Cmds[c]
cmd.RemoteHost = host
cmd.Host = host.Host
opts.Logger.Info().Str("host", h).Str("cmd", c).Send()
_, err := cmd.RunCmdSSH(cmd.GenerateLogger(opts), opts)
if err != nil {
opts.Logger.Err(err).Str("host", h).Str("cmd", c).Send()
cmd.Host = h
if IsHostLocal(h) {
_, err := cmd.RunCmd(cmd.GenerateLogger(opts), opts)
if err != nil {
opts.Logger.Err(err).Str("host", h).Str("cmd", c).Send()
}
} else {
cmd.Host = host.Host
opts.Logger.Info().Str("host", h).Str("cmd", c).Send()
_, err := cmd.RunCmdOnHost(cmd.GenerateLogger(opts), opts)
if err != nil {
opts.Logger.Err(err).Str("host", h).Str("cmd", c).Send()
}
}
}
}
@ -530,20 +540,13 @@ func logCommandOutput(command *Command, cmdOutBuf bytes.Buffer, cmdCtxLogger zer
if str, ok := outMap["output"].(string); ok {
outputArr = append(outputArr, str)
}
if command.OutputToLog {
if command.Output.ToLog {
cmdCtxLogger.Info().Fields(outMap).Send()
}
}
return outputArr
}
func (c *Command) GetVariablesFromConf(opts *ConfigOpts) {
c.ScriptEnvFile = replaceVarInString(opts.Vars, c.ScriptEnvFile, opts.Logger)
c.Name = replaceVarInString(opts.Vars, c.Name, opts.Logger)
c.OutputFile = replaceVarInString(opts.Vars, c.OutputFile, opts.Logger)
c.Host = replaceVarInString(opts.Vars, c.Host, opts.Logger)
}
// func executeUserCommands() []string {
// }

83
pkg/backy/backy_test.go Normal file
View File

@ -0,0 +1,83 @@
package backy
import (
"context"
"fmt"
"io"
"log"
"testing"
"git.andrewnw.xyz/CyberShell/backy/pkg/pkgman"
packagemanagercommon "git.andrewnw.xyz/CyberShell/backy/pkg/pkgman/common"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/testcontainers/testcontainers-go"
)
// TestConfigOptions tests the configuration options for the backy package.
func Test_ErrorHook(t *testing.T) {
configFile := "../../tests/ErrorHook.yml"
logFile := "ErrorHook.log"
backyConfigOptions := NewConfigOptions(configFile, SetLogFile(logFile))
backyConfigOptions.InitConfig()
backyConfigOptions.ParseConfigurationFile()
backyConfigOptions.RunListConfig("")
}
func TestSettingCommandInfoPackageCommandDnf(t *testing.T) {
packagecommand := &Command{
Type: PackageCommandType,
PackageManager: "dnf",
Shell: "zsh",
PackageOperation: PackageOperationCheckVersion,
Packages: []packagemanagercommon.Package{{Name: "docker-ce"}},
}
dnfPackage, _ := pkgman.PackageManagerFactory("dnf", pkgman.WithoutAuth())
packagecommand.pkgMan = dnfPackage
PackageCommand := getCommandTypeAndSetCommandInfo(packagecommand)
assert.Equal(t, "dnf", PackageCommand.Cmd)
}
func TestWithDockerFile(t *testing.T) {
ctx := context.Background()
docker, err := testcontainers.Run(ctx, "",
testcontainers.WithDockerfile(testcontainers.FromDockerfile{
Context: "../../tests/docker",
Dockerfile: "Dockerfile",
KeepImage: false,
// BuildOptionsModifier: func(buildOptions *types.ImageBuildOptions) {
// buildOptions.Target = "target2"
// },
}),
)
// docker.
if err != nil {
log.Printf("failed to start container: %v", err)
return
}
r, err := docker.Logs(ctx)
if err != nil {
log.Printf("failed to get logs: %v", err)
return
}
logs, err := io.ReadAll(r)
if err != nil {
log.Printf("failed to read logs: %v", err)
return
}
fmt.Println(string(logs))
require.NoError(t, err)
}

View File

@ -25,29 +25,29 @@ func (i CommandType) String() string {
// Re-run the stringer command to generate them again.
func _CommandTypeNoOp() {
var x [1]struct{}
_ = x[DefaultCT-(0)]
_ = x[ScriptCT-(1)]
_ = x[ScriptFileCT-(2)]
_ = x[RemoteScriptCT-(3)]
_ = x[PackageCT-(4)]
_ = x[UserCT-(5)]
_ = x[DefaultCommandType-(0)]
_ = x[ScriptCommandType-(1)]
_ = x[ScriptFileCommandType-(2)]
_ = x[RemoteScriptCommandType-(3)]
_ = x[PackageCommandType-(4)]
_ = x[UserCommandType-(5)]
}
var _CommandTypeValues = []CommandType{DefaultCT, ScriptCT, ScriptFileCT, RemoteScriptCT, PackageCT, UserCT}
var _CommandTypeValues = []CommandType{DefaultCommandType, ScriptCommandType, ScriptFileCommandType, RemoteScriptCommandType, PackageCommandType, UserCommandType}
var _CommandTypeNameToValueMap = map[string]CommandType{
_CommandTypeName[0:0]: DefaultCT,
_CommandTypeLowerName[0:0]: DefaultCT,
_CommandTypeName[0:6]: ScriptCT,
_CommandTypeLowerName[0:6]: ScriptCT,
_CommandTypeName[6:16]: ScriptFileCT,
_CommandTypeLowerName[6:16]: ScriptFileCT,
_CommandTypeName[16:28]: RemoteScriptCT,
_CommandTypeLowerName[16:28]: RemoteScriptCT,
_CommandTypeName[28:35]: PackageCT,
_CommandTypeLowerName[28:35]: PackageCT,
_CommandTypeName[35:39]: UserCT,
_CommandTypeLowerName[35:39]: UserCT,
_CommandTypeName[0:0]: DefaultCommandType,
_CommandTypeLowerName[0:0]: DefaultCommandType,
_CommandTypeName[0:6]: ScriptCommandType,
_CommandTypeLowerName[0:6]: ScriptCommandType,
_CommandTypeName[6:16]: ScriptFileCommandType,
_CommandTypeLowerName[6:16]: ScriptFileCommandType,
_CommandTypeName[16:28]: RemoteScriptCommandType,
_CommandTypeLowerName[16:28]: RemoteScriptCommandType,
_CommandTypeName[28:35]: PackageCommandType,
_CommandTypeLowerName[28:35]: PackageCommandType,
_CommandTypeName[35:39]: UserCommandType,
_CommandTypeLowerName[35:39]: UserCommandType,
}
var _CommandTypeNames = []string{

View File

@ -95,10 +95,11 @@ func (opts *ConfigOpts) InitConfig() {
} else {
loadDefaultConfigFiles(fetcher, configFiles, backyKoanf, opts)
}
opts.koanf = backyKoanf
}
func (opts *ConfigOpts) ReadConfig() *ConfigOpts {
func (opts *ConfigOpts) ParseConfigurationFile() *ConfigOpts {
setTerminalEnv()
backyKoanf := opts.koanf
@ -129,9 +130,23 @@ func (opts *ConfigOpts) ReadConfig() *ConfigOpts {
log := setupLogger(opts)
opts.Logger = log
hostsFetcher, err := remotefetcher.NewRemoteFetcher(opts.HostsFilePath, opts.Cache)
opts.Logger.Info().Str("hosts file", opts.HostsFilePath).Send()
if err != nil {
logging.ExitWithMSG(fmt.Sprintf("error initializing config fetcher: %v", err), 1, nil)
}
var hostKoanf = koanf.New(".")
if opts.HostsFilePath != "" {
loadConfigFile(hostsFetcher, opts.HostsFilePath, hostKoanf, opts)
unmarshalConfigIntoStruct(hostKoanf, "hosts", &opts.Hosts, opts.Logger)
} else {
unmarshalConfigIntoStruct(backyKoanf, "hosts", &opts.Hosts, opts.Logger)
}
log.Info().Str("config file", opts.ConfigFilePath).Send()
if err := opts.initVault(); err != nil {
if err := opts.initializeVault(); err != nil {
log.Err(err).Send()
}
@ -139,12 +154,10 @@ func (opts *ConfigOpts) ReadConfig() *ConfigOpts {
getCommandEnvironments(opts)
unmarshalConfigIntoStruct(backyKoanf, "hosts", &opts.Hosts, opts.Logger)
resolveHostConfigs(opts)
getHostConfigs(opts)
for k, v := range opts.Vars {
v = getExternalConfigDirectiveValue(v, opts)
v = getExternalConfigDirectiveValue(v, opts, AllowedExternalDirectiveAll)
opts.Vars[k] = v
}
@ -171,13 +184,13 @@ func (opts *ConfigOpts) ReadConfig() *ConfigOpts {
return opts
}
func loadConfigFile(fetcher remotefetcher.RemoteFetcher, filePath string, k *koanf.Koanf, opts *ConfigOpts) {
func loadConfigFile(fetcher remotefetcher.RemoteFetcher, filePath string, koanfConfigParser *koanf.Koanf, opts *ConfigOpts) {
data, err := fetcher.Fetch(filePath)
if err != nil {
logging.ExitWithMSG(generateFileFetchErrorString(filePath, "config", err), 1, nil)
}
if err := k.Load(rawbytes.Provider(data), yaml.Parser()); err != nil {
if err := koanfConfigParser.Load(rawbytes.Provider(data), yaml.Parser()); err != nil {
logging.ExitWithMSG(fmt.Sprintf("error loading config: %v", err), 1, &opts.Logger)
}
}
@ -222,24 +235,23 @@ func validateExecCommandsFromCLI(k *koanf.Koanf, opts *ConfigOpts) {
}
}
func setLoggingOptions(k *koanf.Koanf, opts *ConfigOpts) {
isLoggingVerbose := k.Bool(getLoggingKeyFromConfig("verbose"))
func setLoggingOptions(backyKoanf *koanf.Koanf, opts *ConfigOpts) {
isVerboseLoggingSetInConfig := backyKoanf.Bool(getLoggingKeyFromConfig("verbose"))
// if log file is set in config file and not set on command line, use "./backy.log"
logFile := "./backy.log"
if opts.LogFilePath == "" && k.Exists(getLoggingKeyFromConfig("file")) {
logFile = k.String(getLoggingKeyFromConfig("file"))
if opts.LogFilePath == "" && backyKoanf.Exists(getLoggingKeyFromConfig("file")) {
logFile = backyKoanf.String(getLoggingKeyFromConfig("file"))
opts.LogFilePath = logFile
}
opts.LogFilePath = logFile
zerolog.SetGlobalLevel(zerolog.InfoLevel)
if isLoggingVerbose {
if isVerboseLoggingSetInConfig {
zerolog.SetGlobalLevel(zerolog.DebugLevel)
os.Setenv("BACKY_LOGLEVEL", fmt.Sprintf("%v", zerolog.GlobalLevel()))
}
if k.Bool(getLoggingKeyFromConfig("console-disabled")) {
if backyKoanf.Bool(getLoggingKeyFromConfig("console-disabled")) {
os.Setenv("BACKY_CONSOLE_LOGGING", "")
} else {
os.Setenv("BACKY_CONSOLE_LOGGING", "enabled")
@ -270,18 +282,18 @@ func getCommandEnvironments(opts *ConfigOpts) {
}
}
func resolveHostConfigs(opts *ConfigOpts) {
func getHostConfigs(opts *ConfigOpts) {
for hostConfigName, host := range opts.Hosts {
if host.Host == "" {
host.Host = hostConfigName
}
if host.ProxyJump != "" {
resolveProxyHosts(host, opts)
getProxyHosts(host, opts)
}
}
}
func resolveProxyHosts(host *Host, opts *ConfigOpts) {
func getProxyHosts(host *Host, opts *ConfigOpts) {
proxyHosts := strings.Split(host.ProxyJump, ",")
for _, h := range proxyHosts {
proxyHost, defined := opts.Hosts[h]
@ -321,12 +333,6 @@ func loadCommandLists(opts *ConfigOpts, backyKoanf *koanf.Koanf) {
listsConfig := koanf.New(".")
for _, l := range listConfigFiles {
if loadListConfigFile(l, listsConfig, opts) {
break
}
}
if backyKoanf.Exists("cmdLists") {
if backyKoanf.Exists("cmdLists.file") {
loadCmdListsFile(backyKoanf, listsConfig, opts)
@ -334,6 +340,14 @@ func loadCommandLists(opts *ConfigOpts, backyKoanf *koanf.Koanf) {
unmarshalConfigIntoStruct(backyKoanf, "cmdLists", &opts.CmdConfigLists, opts.Logger)
}
}
if opts.CmdConfigLists == nil {
for _, l := range listConfigFiles {
if loadListConfigFile(l, listsConfig, opts) {
break
}
}
}
}
func isRemoteURL(filePath string) bool {
@ -380,6 +394,7 @@ func loadListConfigFile(filePath string, k *koanf.Koanf, opts *ConfigOpts) bool
func loadCmdListsFile(backyKoanf *koanf.Koanf, listsConfig *koanf.Koanf, opts *ConfigOpts) {
opts.CmdListFile = strings.TrimSpace(backyKoanf.String("cmdLists.file"))
if !path.IsAbs(opts.CmdListFile) {
// TODO: Needs testing - might cause undefined/unexpected behavior if remote config path is used
opts.CmdListFile = path.Join(path.Dir(opts.ConfigFilePath), opts.CmdListFile)
}
@ -456,7 +471,7 @@ func getLoggingKeyFromConfig(key string) string {
// return fmt.Sprintf("cmdLists.%s", list)
// }
func (opts *ConfigOpts) initVault() error {
func (opts *ConfigOpts) initializeVault() error {
if !opts.koanf.Bool("vault.enabled") {
return nil
}
@ -501,6 +516,8 @@ func processCmds(opts *ConfigOpts) error {
// process commands
for cmdName, cmd := range opts.Cmds {
cmd.GetVariablesFromConf(opts)
cmd.Cmd = replaceVarInString(opts.Vars, cmd.Cmd, opts.Logger)
for i, v := range cmd.Args {
v = replaceVarInString(opts.Vars, v, opts.Logger)
cmd.Args[i] = v
@ -508,9 +525,8 @@ func processCmds(opts *ConfigOpts) error {
if cmd.Name == "" {
cmd.Name = cmdName
}
// println("Cmd.Name = " + cmd.Name)
hooks := cmd.Hooks
// resolve hooks
if hooks != nil {
processHookSuccess := processHooks(cmd, hooks.Error, opts, "error")
@ -562,15 +578,15 @@ func processCmds(opts *ConfigOpts) error {
}
}
if cmd.Type == PackageCT {
if cmd.Type == PackageCommandType {
if cmd.PackageManager == "" {
return fmt.Errorf("package manager is required for package command %s", cmd.PackageName)
return fmt.Errorf("package manager is required for package command %s", cmd.Name)
}
if cmd.PackageOperation.String() == "" {
return fmt.Errorf("package operation is required for package command %s", cmd.PackageName)
return fmt.Errorf("package operation is required for package command %s", cmd.Name)
}
if cmd.PackageName == "" {
return fmt.Errorf("package name is required for package command %s", cmd.PackageName)
if cmd.Packages == nil {
return fmt.Errorf("package name is required for package command %s", cmd.Name)
}
var err error
@ -588,7 +604,7 @@ func processCmds(opts *ConfigOpts) error {
}
// Parse user commands
if cmd.Type == UserCT {
if cmd.Type == UserCommandType {
if cmd.Username == "" {
return fmt.Errorf("username is required for user command %s", cmd.Name)
}
@ -605,7 +621,7 @@ func processCmds(opts *ConfigOpts) error {
if cmd.UserOperation == "password" {
opts.Logger.Debug().Msg("changing password for user: " + cmd.Username)
cmd.UserPassword = getExternalConfigDirectiveValue(cmd.UserPassword, opts)
cmd.UserPassword = getExternalConfigDirectiveValue(cmd.UserPassword, opts, AllowedExternalDirectiveAll)
}
if !IsHostLocal(cmd.Host) {
@ -617,7 +633,7 @@ func processCmds(opts *ConfigOpts) error {
}
for indx, key := range cmd.UserSshPubKeys {
opts.Logger.Debug().Msg("adding SSH Keys")
key = getExternalConfigDirectiveValue(key, opts)
key = getExternalConfigDirectiveValue(key, opts, AllowedExternalDirectiveAll)
cmd.UserSshPubKeys[indx] = key
}
if err != nil {
@ -629,7 +645,7 @@ func processCmds(opts *ConfigOpts) error {
}
if cmd.Type == RemoteScriptCT {
if cmd.Type == RemoteScriptCommandType {
var fetchErr error
if !isRemoteURL(cmd.Cmd) {
return fmt.Errorf("remoteScript command %s must be a remote resource", cmdName)
@ -640,9 +656,9 @@ func processCmds(opts *ConfigOpts) error {
}
}
if cmd.OutputFile != "" {
if cmd.Output.File != "" {
var err error
cmd.OutputFile, err = getFullPathWithHomeDir(cmd.OutputFile)
cmd.Output.File, err = getFullPathWithHomeDir(cmd.Output.File)
if err != nil {
return err
}
@ -738,8 +754,9 @@ func replaceVarInString(vars map[string]string, str string, logger zerolog.Logge
return str
}
func VariadicFunctionParameterTest(allowedKeys ...string) {
if contains(allowedKeys, "file") {
println("file param included")
}
func (c *Command) GetVariablesFromConf(opts *ConfigOpts) {
c.ScriptEnvFile = replaceVarInString(opts.Vars, c.ScriptEnvFile, opts.Logger)
c.Name = replaceVarInString(opts.Vars, c.Name, opts.Logger)
c.Output.File = replaceVarInString(opts.Vars, c.Output.File, opts.Logger)
c.Host = replaceVarInString(opts.Vars, c.Host, opts.Logger)
}

133
pkg/backy/lineinfile.go Normal file
View File

@ -0,0 +1,133 @@
package backy
import (
"bufio"
"bytes"
"fmt"
"io"
"regexp"
"strings"
"golang.org/x/crypto/ssh"
)
func sshConnect(user, password, host string, port int) (*ssh.Client, error) {
config := &ssh.ClientConfig{
User: user,
Auth: []ssh.AuthMethod{
ssh.Password(password),
},
HostKeyCallback: ssh.InsecureIgnoreHostKey(),
}
addr := fmt.Sprintf("%s:%d", host, port)
return ssh.Dial("tcp", addr, config)
}
func sshReadFile(client *ssh.Client, remotePath string) (string, error) {
session, err := client.NewSession()
if err != nil {
return "", err
}
defer session.Close()
var b bytes.Buffer
session.Stdout = &b
if err := session.Run(fmt.Sprintf("cat %s", remotePath)); err != nil {
return "", err
}
return b.String(), nil
}
func sshWriteFile(client *ssh.Client, remotePath, content string) error {
session, err := client.NewSession()
if err != nil {
return err
}
defer session.Close()
stdin, err := session.StdinPipe()
if err != nil {
return err
}
go func() {
defer stdin.Close()
io.WriteString(stdin, content)
}()
cmd := fmt.Sprintf("cat > %s", remotePath)
return session.Run(cmd)
}
func lineInString(content, regexpPattern, line string) string {
scanner := bufio.NewScanner(strings.NewReader(content))
var lines []string
found := false
re := regexp.MustCompile(regexpPattern)
for scanner.Scan() {
l := scanner.Text()
if re.MatchString(l) {
found = true
lines = append(lines, line)
} else {
lines = append(lines, l)
}
}
if !found {
lines = append(lines, line)
}
return strings.Join(lines, "\n") + "\n"
}
func Call() {
user := "youruser"
password := "yourpassword"
host := "yourhost"
port := 22
remotePath := "/path/to/remote/file"
client, err := sshConnect(user, password, host, port)
if err != nil {
fmt.Println("SSH connection error:", err)
return
}
defer client.Close()
content, err := sshReadFile(client, remotePath)
if err != nil {
fmt.Println("Read error:", err)
return
}
newContent := lineInString(content, "^foo=", "foo=bar")
if err := sshWriteFile(client, remotePath, newContent); err != nil {
fmt.Println("Write error:", err)
return
}
fmt.Println("Line updated successfully over SSH.")
}
type LineInFile struct {
RemotePath string // Path to the remote file
Pattern string // Regex pattern to match lines
Line string // Line to insert or replace
InsertAfter bool // If true, insert after matched line; else replace
User string // SSH username
Password string // SSH password (use key for production)
Host string // SSH host
Port int // SSH port
regexCompiled *regexp.Regexp // Compiled regex (internal use)
}
// CompileRegex compiles the regex pattern for later use
func (l *LineInFile) CompileRegex() error {
re, err := regexp.Compile(l.Pattern)
if err != nil {
return err
}
l.regexCompiled = re
return nil
}

57
pkg/backy/metrics.go Normal file
View File

@ -0,0 +1,57 @@
package backy
import (
"encoding/json"
"os"
)
type Metrics struct {
SuccessfulExecutions uint64 `json:"successful_executions"`
FailedExecutions uint64 `json:"failed_executions"`
TotalExecutions uint64 `json:"total_executions"`
ExecutionTime float64 `json:"execution_time"` // in seconds
AverageExecutionTime float64 `json:"average_execution_time"` // in seconds
SuccessRate float64 `json:"success_rate"` // percentage of successful executions
FailureRate float64 `json:"failure_rate"` // percentage of failed executions
}
func NewMetrics() *Metrics {
return &Metrics{
SuccessfulExecutions: 0,
FailedExecutions: 0,
TotalExecutions: 0,
ExecutionTime: 0.0,
AverageExecutionTime: 0.0,
SuccessRate: 0.0,
FailureRate: 0.0,
}
}
func (m *Metrics) Update(success bool, executionTime float64) {
m.TotalExecutions++
if success {
m.SuccessfulExecutions++
} else {
m.FailedExecutions++
}
m.ExecutionTime += executionTime
m.AverageExecutionTime = m.ExecutionTime / float64(m.TotalExecutions)
if m.TotalExecutions > 0 {
m.SuccessRate = float64(m.SuccessfulExecutions) / float64(m.TotalExecutions) * 100
m.FailureRate = float64(m.FailedExecutions) / float64(m.TotalExecutions) * 100
}
}
func SaveToFile(metrics *Metrics, filename string) error {
data, err := json.MarshalIndent(metrics, "", " ")
if err != nil {
return err
}
return os.WriteFile(filename, data, 0644)
}
func LoadFromFile(filename string) (*Metrics, error) {
return nil, nil
}

View File

@ -0,0 +1,7 @@
package backy
import "testing"
func TestAddingMetricsForCommand(t *testing.T) {
}

View File

@ -65,7 +65,7 @@ func (opts *ConfigOpts) SetupNotify() {
opts.Logger.Info().Err(fmt.Errorf("error: ID %s not found in mail object", confId)).Str("list", confName).Send()
continue
}
conf.Password = getExternalConfigDirectiveValue(conf.Password, opts)
conf.Password = getExternalConfigDirectiveValue(conf.Password, opts, AllowedExternalDirectiveAll)
opts.Logger.Debug().Str("list", confName).Str("id", confId).Msg("adding mail notification service")
mailConf := setupMail(conf)
services = append(services, mailConf)
@ -75,7 +75,7 @@ func (opts *ConfigOpts) SetupNotify() {
opts.Logger.Info().Err(fmt.Errorf("error: ID %s not found in matrix object", confId)).Str("list", confName).Send()
continue
}
conf.AccessToken = getExternalConfigDirectiveValue(conf.AccessToken, opts)
conf.AccessToken = getExternalConfigDirectiveValue(conf.AccessToken, opts, AllowedExternalDirectiveAll)
opts.Logger.Debug().Str("list", confName).Str("id", confId).Msg("adding matrix notification service")
mtrxConf, mtrxErr := setupMatrix(conf)
if mtrxErr != nil {

View File

@ -26,31 +26,31 @@ func (i PackageOperation) String() string {
func _PackageOperationNoOp() {
var x [1]struct{}
_ = x[DefaultPO-(0)]
_ = x[PackOpInstall-(1)]
_ = x[PackOpUpgrade-(2)]
_ = x[PackOpPurge-(3)]
_ = x[PackOpRemove-(4)]
_ = x[PackOpCheckVersion-(5)]
_ = x[PackOpIsInstalled-(6)]
_ = x[PackageOperationInstall-(1)]
_ = x[PackageOperationUpgrade-(2)]
_ = x[PackageOperationPurge-(3)]
_ = x[PackageOperationRemove-(4)]
_ = x[PackageOperationCheckVersion-(5)]
_ = x[PackageOperationIsInstalled-(6)]
}
var _PackageOperationValues = []PackageOperation{DefaultPO, PackOpInstall, PackOpUpgrade, PackOpPurge, PackOpRemove, PackOpCheckVersion, PackOpIsInstalled}
var _PackageOperationValues = []PackageOperation{DefaultPO, PackageOperationInstall, PackageOperationUpgrade, PackageOperationPurge, PackageOperationRemove, PackageOperationCheckVersion, PackageOperationIsInstalled}
var _PackageOperationNameToValueMap = map[string]PackageOperation{
_PackageOperationName[0:0]: DefaultPO,
_PackageOperationLowerName[0:0]: DefaultPO,
_PackageOperationName[0:7]: PackOpInstall,
_PackageOperationLowerName[0:7]: PackOpInstall,
_PackageOperationName[7:14]: PackOpUpgrade,
_PackageOperationLowerName[7:14]: PackOpUpgrade,
_PackageOperationName[14:19]: PackOpPurge,
_PackageOperationLowerName[14:19]: PackOpPurge,
_PackageOperationName[19:25]: PackOpRemove,
_PackageOperationLowerName[19:25]: PackOpRemove,
_PackageOperationName[25:37]: PackOpCheckVersion,
_PackageOperationLowerName[25:37]: PackOpCheckVersion,
_PackageOperationName[37:48]: PackOpIsInstalled,
_PackageOperationLowerName[37:48]: PackOpIsInstalled,
_PackageOperationName[0:7]: PackageOperationInstall,
_PackageOperationLowerName[0:7]: PackageOperationInstall,
_PackageOperationName[7:14]: PackageOperationUpgrade,
_PackageOperationLowerName[7:14]: PackageOperationUpgrade,
_PackageOperationName[14:19]: PackageOperationPurge,
_PackageOperationLowerName[14:19]: PackageOperationPurge,
_PackageOperationName[19:25]: PackageOperationRemove,
_PackageOperationLowerName[19:25]: PackageOperationRemove,
_PackageOperationName[25:37]: PackageOperationCheckVersion,
_PackageOperationLowerName[25:37]: PackageOperationCheckVersion,
_PackageOperationName[37:48]: PackageOperationIsInstalled,
_PackageOperationLowerName[37:48]: PackageOperationIsInstalled,
}
var _PackageOperationNames = []string{

View File

@ -29,38 +29,38 @@ var TS = strings.TrimSpace
// ConnectToHost connects to a host by looking up the config values in the file ~/.ssh/config
// It uses any set values and looks up an unset values in the config files
// remoteConfig is modified directly. The *ssh.Client is returned as part of remoteConfig,
// remoteHost is modified directly. The *ssh.Client is returned as part of remoteHost,
// If configFile is empty, any required configuration is looked up in the default config files
// If any value is not found, defaults are used
func (remoteConfig *Host) ConnectToHost(opts *ConfigOpts) error {
func (remoteHost *Host) ConnectToHost(opts *ConfigOpts) error {
var connectErr error
if TS(remoteConfig.ConfigFilePath) == "" {
remoteConfig.useDefaultConfig = true
if TS(remoteHost.ConfigFilePath) == "" {
remoteHost.useDefaultConfig = true
}
khPathErr := remoteConfig.GetKnownHosts()
khPathErr := remoteHost.GetKnownHosts()
if khPathErr != nil {
return khPathErr
}
if remoteConfig.ClientConfig == nil {
remoteConfig.ClientConfig = &ssh.ClientConfig{}
if remoteHost.ClientConfig == nil {
remoteHost.ClientConfig = &ssh.ClientConfig{}
}
var configFile *os.File
var sshConfigFileOpenErr error
if !remoteConfig.useDefaultConfig {
if !remoteHost.useDefaultConfig {
var err error
remoteConfig.ConfigFilePath, err = getFullPathWithHomeDir(remoteConfig.ConfigFilePath)
remoteHost.ConfigFilePath, err = getFullPathWithHomeDir(remoteHost.ConfigFilePath)
if err != nil {
return err
}
configFile, sshConfigFileOpenErr = os.Open(remoteConfig.ConfigFilePath)
configFile, sshConfigFileOpenErr = os.Open(remoteHost.ConfigFilePath)
if sshConfigFileOpenErr != nil {
return sshConfigFileOpenErr
}
@ -71,22 +71,22 @@ func (remoteConfig *Host) ConnectToHost(opts *ConfigOpts) error {
return sshConfigFileOpenErr
}
}
remoteConfig.SSHConfigFile = &sshConfigFile{}
remoteConfig.SSHConfigFile.DefaultUserSettings = ssh_config.DefaultUserSettings
remoteHost.SSHConfigFile = &sshConfigFile{}
remoteHost.SSHConfigFile.DefaultUserSettings = ssh_config.DefaultUserSettings
var decodeErr error
remoteConfig.SSHConfigFile.SshConfigFile, decodeErr = ssh_config.Decode(configFile)
remoteHost.SSHConfigFile.SshConfigFile, decodeErr = ssh_config.Decode(configFile)
if decodeErr != nil {
return decodeErr
}
err := remoteConfig.GetProxyJumpFromConfig(opts.Hosts)
err := remoteHost.GetProxyJumpFromConfig(opts.Hosts)
if err != nil {
return err
}
if remoteConfig.ProxyHost != nil {
for _, proxyHost := range remoteConfig.ProxyHost {
if remoteHost.ProxyHost != nil {
for _, proxyHost := range remoteHost.ProxyHost {
err := proxyHost.GetProxyJumpConfig(opts.Hosts, opts)
opts.Logger.Info().Msgf("Proxy host: %s", proxyHost.Host)
if err != nil {
@ -95,49 +95,49 @@ func (remoteConfig *Host) ConnectToHost(opts *ConfigOpts) error {
}
}
remoteConfig.ClientConfig.Timeout = time.Second * 30
remoteHost.ClientConfig.Timeout = time.Second * 30
remoteConfig.GetPrivateKeyFileFromConfig()
remoteHost.GetPrivateKeyFileFromConfig()
remoteConfig.GetPort()
remoteHost.GetPort()
remoteConfig.GetHostName()
remoteHost.GetHostName()
remoteConfig.CombineHostNameWithPort()
remoteHost.CombineHostNameWithPort()
remoteConfig.GetSshUserFromConfig()
remoteHost.GetSshUserFromConfig()
if remoteConfig.HostName == "" {
return errors.Errorf("No hostname found or specified for host %s", remoteConfig.Host)
if remoteHost.HostName == "" {
return errors.Errorf("No hostname found or specified for host %s", remoteHost.Host)
}
err = remoteConfig.GetAuthMethods(opts)
err = remoteHost.GetAuthMethods(opts)
if err != nil {
return err
}
hostKeyCallback, err := knownhosts.New(remoteConfig.KnownHostsFile)
hostKeyCallback, err := knownhosts.New(remoteHost.KnownHostsFile)
if err != nil {
return errors.Wrap(err, "could not create hostkeycallback function")
}
remoteConfig.ClientConfig.HostKeyCallback = hostKeyCallback
remoteHost.ClientConfig.HostKeyCallback = hostKeyCallback
remoteConfig.SshClient, connectErr = remoteConfig.ConnectThroughBastion(opts.Logger)
remoteHost.SshClient, connectErr = remoteHost.ConnectThroughBastion(opts.Logger)
if connectErr != nil {
return connectErr
}
if remoteConfig.SshClient != nil {
opts.Hosts[remoteConfig.Host] = remoteConfig
if remoteHost.SshClient != nil {
opts.Hosts[remoteHost.Host] = remoteHost
return nil
}
opts.Logger.Info().Msgf("Connecting to host %s", remoteConfig.HostName)
remoteConfig.SshClient, connectErr = ssh.Dial("tcp", remoteConfig.HostName, remoteConfig.ClientConfig)
opts.Logger.Info().Msgf("Connecting to host %s", remoteHost.HostName)
remoteHost.SshClient, connectErr = ssh.Dial("tcp", remoteHost.HostName, remoteHost.ClientConfig)
if connectErr != nil {
return connectErr
}
opts.Hosts[remoteConfig.Host] = remoteConfig
opts.Hosts[remoteHost.Host] = remoteHost
return nil
}
@ -227,6 +227,8 @@ func (remoteHost *Host) GetPrivateKeyFileFromConfig() {
var identityFile string
if remoteHost.PrivateKeyPath == "" {
identityFile, _ = remoteHost.SSHConfigFile.SshConfigFile.Get(remoteHost.Host, "IdentityFile")
// println("Identity file:", identityFile)
// println("Host:", remoteHost.Host)
if identityFile == "" {
identityFile, _ = remoteHost.SSHConfigFile.DefaultUserSettings.GetStrict(remoteHost.Host, "IdentityFile")
if identityFile == "" {
@ -238,6 +240,7 @@ func (remoteHost *Host) GetPrivateKeyFileFromConfig() {
identityFile = remoteHost.PrivateKeyPath
}
// println("Identity file:", identityFile)
remoteHost.PrivateKeyPath, _ = getFullPathWithHomeDir(identityFile)
}
@ -326,33 +329,33 @@ func (remoteHost *Host) GetKnownHosts() error {
}
func GetPrivateKeyPassword(key string, opts *ConfigOpts) string {
return getExternalConfigDirectiveValue(key, opts)
return getExternalConfigDirectiveValue(key, opts, AllowedExternalDirectiveAll)
}
// GetPassword gets any password
func GetPassword(pass string, opts *ConfigOpts) string {
return getExternalConfigDirectiveValue(pass, opts)
return getExternalConfigDirectiveValue(pass, opts, AllowedExternalDirectiveAll)
}
func (remoteConfig *Host) GetProxyJumpFromConfig(hosts map[string]*Host) error {
func (remoteHost *Host) GetProxyJumpFromConfig(hosts map[string]*Host) error {
proxyJump, _ := remoteConfig.SSHConfigFile.SshConfigFile.Get(remoteConfig.Host, "ProxyJump")
proxyJump, _ := remoteHost.SSHConfigFile.SshConfigFile.Get(remoteHost.Host, "ProxyJump")
if proxyJump == "" {
proxyJump = remoteConfig.SSHConfigFile.DefaultUserSettings.Get(remoteConfig.Host, "ProxyJump")
proxyJump = remoteHost.SSHConfigFile.DefaultUserSettings.Get(remoteHost.Host, "ProxyJump")
}
if remoteConfig.ProxyJump == "" && proxyJump != "" {
remoteConfig.ProxyJump = proxyJump
if remoteHost.ProxyJump == "" && proxyJump != "" {
remoteHost.ProxyJump = proxyJump
}
proxyJumpHosts := strings.Split(remoteConfig.ProxyJump, ",")
if remoteConfig.ProxyHost == nil && len(proxyJumpHosts) == 1 {
remoteConfig.ProxyJump = proxyJump
proxyJumpHosts := strings.Split(remoteHost.ProxyJump, ",")
if remoteHost.ProxyHost == nil && len(proxyJumpHosts) == 1 {
remoteHost.ProxyJump = proxyJump
proxyHost, proxyHostFound := hosts[proxyJump]
if proxyHostFound {
remoteConfig.ProxyHost = append(remoteConfig.ProxyHost, proxyHost)
remoteHost.ProxyHost = append(remoteHost.ProxyHost, proxyHost)
} else {
if proxyJump != "" {
newProxy := &Host{Host: proxyJump}
remoteConfig.ProxyHost = append(remoteConfig.ProxyHost, newProxy)
remoteHost.ProxyHost = append(remoteHost.ProxyHost, newProxy)
}
}
}
@ -360,25 +363,25 @@ func (remoteConfig *Host) GetProxyJumpFromConfig(hosts map[string]*Host) error {
return nil
}
func (remoteConfig *Host) GetProxyJumpConfig(hosts map[string]*Host, opts *ConfigOpts) error {
func (remoteHost *Host) GetProxyJumpConfig(hosts map[string]*Host, opts *ConfigOpts) error {
if TS(remoteConfig.ConfigFilePath) == "" {
remoteConfig.useDefaultConfig = true
if TS(remoteHost.ConfigFilePath) == "" {
remoteHost.useDefaultConfig = true
}
khPathErr := remoteConfig.GetKnownHosts()
khPathErr := remoteHost.GetKnownHosts()
if khPathErr != nil {
return khPathErr
}
if remoteConfig.ClientConfig == nil {
remoteConfig.ClientConfig = &ssh.ClientConfig{}
if remoteHost.ClientConfig == nil {
remoteHost.ClientConfig = &ssh.ClientConfig{}
}
var configFile *os.File
var sshConfigFileOpenErr error
if !remoteConfig.useDefaultConfig {
if !remoteHost.useDefaultConfig {
configFile, sshConfigFileOpenErr = os.Open(remoteConfig.ConfigFilePath)
configFile, sshConfigFileOpenErr = os.Open(remoteHost.ConfigFilePath)
if sshConfigFileOpenErr != nil {
return sshConfigFileOpenErr
}
@ -389,39 +392,39 @@ func (remoteConfig *Host) GetProxyJumpConfig(hosts map[string]*Host, opts *Confi
return sshConfigFileOpenErr
}
}
remoteConfig.SSHConfigFile = &sshConfigFile{}
remoteConfig.SSHConfigFile.DefaultUserSettings = ssh_config.DefaultUserSettings
remoteHost.SSHConfigFile = &sshConfigFile{}
remoteHost.SSHConfigFile.DefaultUserSettings = ssh_config.DefaultUserSettings
var decodeErr error
remoteConfig.SSHConfigFile.SshConfigFile, decodeErr = ssh_config.Decode(configFile)
remoteHost.SSHConfigFile.SshConfigFile, decodeErr = ssh_config.Decode(configFile)
if decodeErr != nil {
return decodeErr
}
remoteConfig.GetPrivateKeyFileFromConfig()
remoteConfig.GetPort()
remoteConfig.GetHostName()
remoteConfig.CombineHostNameWithPort()
remoteConfig.GetSshUserFromConfig()
remoteConfig.isProxyHost = true
if remoteConfig.HostName == "" {
return errors.Errorf("No hostname found or specified for host %s", remoteConfig.Host)
remoteHost.GetPrivateKeyFileFromConfig()
remoteHost.GetPort()
remoteHost.GetHostName()
remoteHost.CombineHostNameWithPort()
remoteHost.GetSshUserFromConfig()
remoteHost.isProxyHost = true
if remoteHost.HostName == "" {
return errors.Errorf("No hostname found or specified for host %s", remoteHost.Host)
}
err := remoteConfig.GetAuthMethods(opts)
err := remoteHost.GetAuthMethods(opts)
if err != nil {
return err
}
// TODO: Add value/option to config for host key and add bool to check for host key
hostKeyCallback, err := knownhosts.New(remoteConfig.KnownHostsFile)
hostKeyCallback, err := knownhosts.New(remoteHost.KnownHostsFile)
if err != nil {
return fmt.Errorf("could not create hostkeycallback function: %v", err)
}
remoteConfig.ClientConfig.HostKeyCallback = hostKeyCallback
hosts[remoteConfig.Host] = remoteConfig
remoteHost.ClientConfig.HostKeyCallback = hostKeyCallback
hosts[remoteHost.Host] = remoteHost
return nil
}
func (command *Command) RunCmdSSH(cmdCtxLogger zerolog.Logger, opts *ConfigOpts) ([]string, error) {
func (command *Command) RunCmdOnHost(cmdCtxLogger zerolog.Logger, opts *ConfigOpts) ([]string, error) {
var (
ArgsStr string
cmdOutBuf bytes.Buffer
@ -473,14 +476,14 @@ func (command *Command) RunCmdSSH(cmdCtxLogger zerolog.Logger, opts *ConfigOpts)
// Handle command execution based on type
switch command.Type {
case ScriptCT:
case ScriptCommandType:
return command.runScript(commandSession, cmdCtxLogger, &cmdOutBuf)
case RemoteScriptCT:
case RemoteScriptCommandType:
return command.runRemoteScript(commandSession, cmdCtxLogger, &cmdOutBuf)
case ScriptFileCT:
case ScriptFileCommandType:
return command.runScriptFile(commandSession, cmdCtxLogger, &cmdOutBuf)
case PackageCT:
if command.PackageOperation == PackOpCheckVersion {
case PackageCommandType:
if command.PackageOperation == PackageOperationCheckVersion {
commandSession.Stderr = nil
// Execute the package version command remotely
// Parse the output of package version command
@ -497,7 +500,7 @@ func (command *Command) RunCmdSSH(cmdCtxLogger zerolog.Logger, opts *ConfigOpts)
cmdCtxLogger.Debug().Str("cmd + args", ArgsStr).Send()
// Run simple command
if err := commandSession.Run(ArgsStr); err != nil {
return collectOutput(&cmdOutBuf, command.Name, cmdCtxLogger, command.OutputToLog), fmt.Errorf("error running command: %w", err)
return collectOutput(&cmdOutBuf, command.Name, cmdCtxLogger, command.Output.ToLog), fmt.Errorf("error running command: %w", err)
}
}
default:
@ -508,24 +511,24 @@ func (command *Command) RunCmdSSH(cmdCtxLogger zerolog.Logger, opts *ConfigOpts)
}
cmdCtxLogger.Debug().Str("cmd + args", ArgsStr).Send()
if command.Type == UserCT && command.UserOperation == "password" {
if command.Type == UserCommandType && command.UserOperation == "password" {
// cmdCtxLogger.Debug().Msgf("adding stdin")
userNamePass := fmt.Sprintf("%s:%s", command.Username, command.UserPassword)
client, err := sftp.NewClient(command.RemoteHost.SshClient)
if err != nil {
return collectOutput(&cmdOutBuf, command.Name, cmdCtxLogger, command.OutputToLog), fmt.Errorf("error creating sftp client: %v", err)
return collectOutput(&cmdOutBuf, command.Name, cmdCtxLogger, command.Output.ToLog), fmt.Errorf("error creating sftp client: %v", err)
}
uuidFile := uuid.New()
passFilePath := fmt.Sprintf("/tmp/%s", uuidFile.String())
passFile, passFileErr := client.Create(passFilePath)
if passFileErr != nil {
return collectOutput(&cmdOutBuf, command.Name, cmdCtxLogger, command.OutputToLog), fmt.Errorf("error creating file /tmp/%s: %v", uuidFile.String(), passFileErr)
return collectOutput(&cmdOutBuf, command.Name, cmdCtxLogger, command.Output.ToLog), fmt.Errorf("error creating file /tmp/%s: %v", uuidFile.String(), passFileErr)
}
_, err = passFile.Write([]byte(userNamePass))
if err != nil {
return collectOutput(&cmdOutBuf, command.Name, cmdCtxLogger, command.OutputToLog), fmt.Errorf("error writing to file /tmp/%s: %v", uuidFile.String(), err)
return collectOutput(&cmdOutBuf, command.Name, cmdCtxLogger, command.Output.ToLog), fmt.Errorf("error writing to file /tmp/%s: %v", uuidFile.String(), err)
}
ArgsStr = fmt.Sprintf("cat %s | chpasswd", passFilePath)
@ -539,10 +542,12 @@ func (command *Command) RunCmdSSH(cmdCtxLogger zerolog.Logger, opts *ConfigOpts)
// commandSession.Stdin = command.stdin
}
if err := commandSession.Run(ArgsStr); err != nil {
return collectOutput(&cmdOutBuf, command.Name, cmdCtxLogger, command.OutputToLog), fmt.Errorf("error running command: %w", err)
return collectOutput(&cmdOutBuf, command.Name, cmdCtxLogger, command.Output.ToLog), fmt.Errorf("error running command: %w", err)
}
if command.Type == UserCT {
if command.Type == UserCommandType {
// REFACTOR IF/WHEN WINDOWS SUPPORT IS ADDED
if command.UserOperation == "add" {
if command.UserSshPubKeys != nil {
@ -558,41 +563,41 @@ func (command *Command) RunCmdSSH(cmdCtxLogger zerolog.Logger, opts *ConfigOpts)
commandSession, _ = command.RemoteHost.createSSHSession(opts)
userHome, err = commandSession.CombinedOutput(fmt.Sprintf("grep \"%s\" /etc/passwd | cut -d: -f6", command.Username))
if err != nil {
return collectOutput(&cmdOutBuf, command.Name, cmdCtxLogger, command.OutputToLog), fmt.Errorf("error finding user home from /etc/passwd: %v", err)
return collectOutput(&cmdOutBuf, command.Name, cmdCtxLogger, command.Output.ToLog), fmt.Errorf("error finding user home from /etc/passwd: %v", err)
}
command.UserHome = strings.TrimSpace(string(userHome))
userSshDir := fmt.Sprintf("%s/.ssh", command.UserHome)
client, err = sftp.NewClient(command.RemoteHost.SshClient)
if err != nil {
return collectOutput(&cmdOutBuf, command.Name, cmdCtxLogger, command.OutputToLog), fmt.Errorf("error creating sftp client: %v", err)
return collectOutput(&cmdOutBuf, command.Name, cmdCtxLogger, command.Output.ToLog), fmt.Errorf("error creating sftp client: %v", err)
}
err = client.MkdirAll(userSshDir)
if err != nil {
return collectOutput(&cmdOutBuf, command.Name, cmdCtxLogger, command.OutputToLog), fmt.Errorf("error creating directory %s: %v", userSshDir, err)
return collectOutput(&cmdOutBuf, command.Name, cmdCtxLogger, command.Output.ToLog), fmt.Errorf("error creating directory %s: %v", userSshDir, err)
}
_, err = client.Create(fmt.Sprintf("%s/authorized_keys", userSshDir))
if err != nil {
return collectOutput(&cmdOutBuf, command.Name, cmdCtxLogger, command.OutputToLog), fmt.Errorf("error opening file %s/authorized_keys: %v", userSshDir, err)
return collectOutput(&cmdOutBuf, command.Name, cmdCtxLogger, command.Output.ToLog), fmt.Errorf("error opening file %s/authorized_keys: %v", userSshDir, err)
}
f, err = client.OpenFile(fmt.Sprintf("%s/authorized_keys", userSshDir), os.O_APPEND|os.O_CREATE|os.O_WRONLY)
if err != nil {
return collectOutput(&cmdOutBuf, command.Name, cmdCtxLogger, command.OutputToLog), fmt.Errorf("error opening file %s/authorized_keys: %v", userSshDir, err)
return collectOutput(&cmdOutBuf, command.Name, cmdCtxLogger, command.Output.ToLog), fmt.Errorf("error opening file %s/authorized_keys: %v", userSshDir, err)
}
defer f.Close()
for _, k := range command.UserSshPubKeys {
buf := bytes.NewBufferString(k)
cmdCtxLogger.Info().Str("key", k).Msg("adding SSH key")
if _, err := f.ReadFrom(buf); err != nil {
return collectOutput(&cmdOutBuf, command.Name, cmdCtxLogger, command.OutputToLog), fmt.Errorf("error adding to authorized keys: %v", err)
return collectOutput(&cmdOutBuf, command.Name, cmdCtxLogger, command.Output.ToLog), fmt.Errorf("error adding to authorized keys: %v", err)
}
}
commandSession, _ = command.RemoteHost.createSSHSession(opts)
_, err = commandSession.CombinedOutput(fmt.Sprintf("chown -R %s:%s %s", command.Username, command.Username, userHome))
if err != nil {
return collectOutput(&cmdOutBuf, command.Name, cmdCtxLogger, command.OutputToLog), err
return collectOutput(&cmdOutBuf, command.Name, cmdCtxLogger, command.Output.ToLog), err
}
}
@ -600,11 +605,13 @@ func (command *Command) RunCmdSSH(cmdCtxLogger zerolog.Logger, opts *ConfigOpts)
}
}
return collectOutput(&cmdOutBuf, command.Name, cmdCtxLogger, command.OutputToLog), nil
return collectOutput(&cmdOutBuf, command.Name, cmdCtxLogger, command.Output.ToLog), nil
}
func checkPackageVersion(cmdCtxLogger zerolog.Logger, command *Command, commandSession *ssh.Session, cmdOutBuf bytes.Buffer) ([]string, error) {
cmdCtxLogger.Info().Str("package", command.PackageName).Msg("Checking package versions")
for _, p := range command.Packages {
cmdCtxLogger.Info().Str("package", p.Name).Msg("Checking package versions")
}
// Prepare command arguments
ArgsStr := command.Cmd
for _, v := range command.Args {
@ -619,9 +626,9 @@ func checkPackageVersion(cmdCtxLogger zerolog.Logger, command *Command, commandS
_, parseErr := parsePackageVersion(string(cmdOut), cmdCtxLogger, command, cmdOutBuf)
if parseErr != nil {
return collectOutput(&cmdOutBuf, command.Name, cmdCtxLogger, command.OutputToLog), fmt.Errorf("error: package %s not listed: %w", command.PackageName, err)
return collectOutput(&cmdOutBuf, command.Name, cmdCtxLogger, command.Output.ToLog), fmt.Errorf("error: packages %v not listed: %w", command.Packages, err)
}
return collectOutput(&cmdOutBuf, command.Name, cmdCtxLogger, command.OutputToLog), fmt.Errorf("error running %s: %w", ArgsStr, err)
return collectOutput(&cmdOutBuf, command.Name, cmdCtxLogger, command.Output.ToLog), fmt.Errorf("error running %s: %w", ArgsStr, err)
}
return parsePackageVersion(string(cmdOut), cmdCtxLogger, command, cmdOutBuf)
@ -651,7 +658,7 @@ func (command *Command) runScript(session *ssh.Session, cmdCtxLogger zerolog.Log
return collectOutput(outputBuf, command.Name, cmdCtxLogger, true), fmt.Errorf("error waiting for shell: %w", err)
}
return collectOutput(outputBuf, command.Name, cmdCtxLogger, command.OutputToLog), nil
return collectOutput(outputBuf, command.Name, cmdCtxLogger, command.Output.ToLog), nil
}
// runScriptFile handles the execution of script files.
@ -670,7 +677,7 @@ func (command *Command) runScriptFile(session *ssh.Session, cmdCtxLogger zerolog
return collectOutput(outputBuf, command.Name, cmdCtxLogger, true), fmt.Errorf("error waiting for shell: %w", err)
}
return collectOutput(outputBuf, command.Name, cmdCtxLogger, command.OutputToLog), nil
return collectOutput(outputBuf, command.Name, cmdCtxLogger, command.Output.ToLog), nil
}
// prepareScriptBuffer prepares a buffer for inline scripts.
@ -727,10 +734,10 @@ func (command *Command) runRemoteScript(session *ssh.Session, cmdCtxLogger zerol
err = session.Run(command.Shell)
if err != nil {
return collectOutput(outputBuf, command.Name, cmdCtxLogger, command.OutputToLog), fmt.Errorf("error running remote script: %w", err)
return collectOutput(outputBuf, command.Name, cmdCtxLogger, command.Output.ToLog), fmt.Errorf("error running remote script: %w", err)
}
return collectOutput(outputBuf, command.Name, cmdCtxLogger, command.OutputToLog), nil
return collectOutput(outputBuf, command.Name, cmdCtxLogger, command.Output.ToLog), nil
}
// readFileToBuffer reads a file into a buffer.
@ -803,7 +810,7 @@ func (h *Host) DetectOS(opts *ConfigOpts) (string, error) {
return osName, nil
}
func CheckIfHostHasHostName(host string) (bool, string) {
func DoesHostHaveHostName(host string) (bool, string) {
HostName, err := ssh_config.DefaultUserSettings.GetStrict(host, "HostName")
if err != nil {
return false, ""

View File

@ -7,6 +7,7 @@ import (
"strings"
"git.andrewnw.xyz/CyberShell/backy/pkg/pkgman"
packagemanagercommon "git.andrewnw.xyz/CyberShell/backy/pkg/pkgman/common"
"git.andrewnw.xyz/CyberShell/backy/pkg/remotefetcher"
"git.andrewnw.xyz/CyberShell/backy/pkg/usermanager"
vaultapi "github.com/hashicorp/vault/api"
@ -33,7 +34,7 @@ type (
Port uint16 `yaml:"port,omitempty"`
ProxyJump string `yaml:"proxyjump,omitempty"`
Password string `yaml:"password,omitempty"`
PrivateKeyPath string `yaml:"privateKeyPath,omitempty"`
PrivateKeyPath string `yaml:"IdentityFile,omitempty"`
PrivateKeyPassword string `yaml:"privateKeyPassword,omitempty"`
useDefaultConfig bool
User string `yaml:"user,omitempty"`
@ -74,19 +75,19 @@ type (
Environment []string `yaml:"environment,omitempty"`
GetOutputInList bool `yaml:"getOutputInList,omitempty"`
ScriptEnvFile string `yaml:"scriptEnvFile"`
OutputToLog bool `yaml:"outputToLog,omitempty"`
OutputFile string `yaml:"outputFile,omitempty"`
Output struct {
File string `yaml:"file,omitempty"`
ToLog bool `yaml:"toLog,omitempty"`
InList bool `yaml:"inList,omitempty"`
} `yaml:"output"`
// BEGIN PACKAGE COMMAND FIELDS
PackageManager string `yaml:"packageManager,omitempty"`
PackageName string `yaml:"packageName,omitempty"`
Packages []packagemanagercommon.Package `yaml:"packages,omitempty"`
PackageVersion string `yaml:"packageVersion,omitempty"`
@ -135,7 +136,7 @@ type (
// stdin only for userOperation = password (for now)
stdin *strings.Reader
// END USER STRUCT FIELDS
// END USER STRUCommandType FIELDS
}
RemoteSource struct {
@ -150,13 +151,17 @@ type (
BackyOptionFunc func(*ConfigOpts)
CmdList struct {
Name string `yaml:"name,omitempty"`
Cron string `yaml:"cron,omitempty"`
RunCmdOnFailure string `yaml:"runCmdOnFailure,omitempty"`
Order []string `yaml:"order,omitempty"`
Notifications []string `yaml:"notifications,omitempty"`
GetOutput bool `yaml:"getOutput,omitempty"`
NotifyOnSuccess bool `yaml:"notifyOnSuccess,omitempty"`
Name string `yaml:"name,omitempty"`
Cron string `yaml:"cron,omitempty"`
RunCmdOnFailure string `yaml:"runCmdOnFailure,omitempty"`
Order []string `yaml:"order,omitempty"`
Notifications []string `yaml:"notifications,omitempty"`
GetCommandOutputInNotificationsOnSuccess bool `yaml:"sendNotificationOnSuccess,omitempty"`
Notify struct {
OnFailure bool `yaml:"onFailure,omitempty"`
OnSuccess bool `yaml:"onSuccess,omitempty"`
} `yaml:"notify,omitempty"`
NotifyConfig *notify.Notify
Source string `yaml:"source"` // URL to fetch remote commands
@ -185,6 +190,8 @@ type (
ConfigFilePath string
HostsFilePath string
ConfigDir string
LogFilePath string
@ -277,6 +284,20 @@ type (
Error error // Error encountered, if any
}
ListMetrics struct {
Name string
SuccessfulExecutions uint64
FailedExecutions uint64
TotalExecutions uint64
}
CommandMetrics struct {
Name string
SuccessfulExecutions uint64
FailedExecutions uint64
TotalExecutions uint64
}
// use ints so we can use enums
CommandType int
PackageOperation int
@ -285,29 +306,30 @@ type (
//go:generate go run github.com/dmarkham/enumer -linecomment -yaml -text -json -type=CommandType
const (
DefaultCT CommandType = iota //
ScriptCT // script
ScriptFileCT // scriptFile
RemoteScriptCT // remoteScript
PackageCT // package
UserCT // user
DefaultCommandType CommandType = iota //
ScriptCommandType // script
ScriptFileCommandType // scriptFile
RemoteScriptCommandType // remoteScript
PackageCommandType // package
UserCommandType // user
)
//go:generate go run github.com/dmarkham/enumer -linecomment -yaml -text -json -type=PackageOperation
const (
DefaultPO PackageOperation = iota //
PackOpInstall // install
PackOpUpgrade // upgrade
PackOpPurge // purge
PackOpRemove // remove
PackOpCheckVersion // checkVersion
PackOpIsInstalled // isInstalled
DefaultPO PackageOperation = iota //
PackageOperationInstall // install
PackageOperationUpgrade // upgrade
PackageOperationPurge // purge
PackageOperationRemove // remove
PackageOperationCheckVersion // checkVersion
PackageOperationIsInstalled // isInstalled
)
//go:generate go run github.com/dmarkham/enumer -linecomment -yaml -text -json -type=AllowedExternalDirectives
const (
DefaultExternalDir AllowedExternalDirectives = iota
AllowedExternalDirectiveVault // vault
AllowedExternalDirectiveVaultEnv // vault-env
AllowedExternalDirectiveVaultFile // vault-file
AllowedExternalDirectiveAll // vault-file-env
AllowedExternalDirectiveFileEnv // file-env

View File

@ -13,6 +13,7 @@ import (
"os/exec"
"path"
"path/filepath"
"regexp"
"strings"
"git.andrewnw.xyz/CyberShell/backy/pkg/logging"
@ -67,8 +68,14 @@ func SetLogFile(logFile string) BackyOptionFunc {
}
}
// SetCmdStdOut forces the command output to stdout
func SetCmdStdOut(setStdOut bool) BackyOptionFunc {
func SetHostsConfigFile(hostsConfigFile string) BackyOptionFunc {
return func(bco *ConfigOpts) {
bco.HostsFilePath = hostsConfigFile
}
}
// EnableCommandStdOut forces the command output to stdout
func EnableCommandStdOut(setStdOut bool) BackyOptionFunc {
return func(bco *ConfigOpts) {
bco.CmdStdOut = setStdOut
}
@ -81,7 +88,7 @@ func EnableCron() BackyOptionFunc {
}
}
func NewOpts(configFilePath string, opts ...BackyOptionFunc) *ConfigOpts {
func NewConfigOptions(configFilePath string, opts ...BackyOptionFunc) *ConfigOpts {
b := &ConfigOpts{}
b.ConfigFilePath = configFilePath
for _, opt := range opts {
@ -110,7 +117,7 @@ func injectEnvIntoSSH(envVarsToInject environmentVars, process *ssh.Session, opt
goto errEnvFile
}
for key, val := range envMap {
err = process.Setenv(key, GetVaultKey(val, opts, log))
err = process.Setenv(key, getExternalConfigDirectiveValue(val, opts, AllowedExternalDirectiveVault))
if err != nil {
log.Error().Err(err).Send()
@ -125,10 +132,9 @@ errEnvFile:
if strings.Contains(envVal, "=") {
envVarArr := strings.Split(envVal, "=")
err := process.Setenv(envVarArr[0], getExternalConfigDirectiveValue(envVarArr[1], opts))
err := process.Setenv(envVarArr[0], getExternalConfigDirectiveValue(envVarArr[1], opts, AllowedExternalDirectiveVaultFile))
if err != nil {
log.Error().Err(err).Send()
}
}
}
@ -159,7 +165,7 @@ errEnvFile:
for _, envVal := range envVarsToInject.env {
if strings.Contains(envVal, "=") {
envVarArr := strings.Split(envVal, "=")
process.Env = append(process.Env, fmt.Sprintf("%s=%s", envVarArr[0], getExternalConfigDirectiveValue(envVarArr[1], opts)))
process.Env = append(process.Env, fmt.Sprintf("%s=%s", envVarArr[0], getExternalConfigDirectiveValue(envVarArr[1], opts, AllowedExternalDirectiveVault)))
}
}
process.Env = append(process.Env, os.Environ()...)
@ -281,21 +287,21 @@ func expandEnvVars(backyEnv map[string]string, envVars []string) {
func getCommandTypeAndSetCommandInfo(command *Command) *Command {
if command.Type == PackageCT && !command.packageCmdSet {
if command.Type == PackageCommandType && !command.packageCmdSet {
command.packageCmdSet = true
switch command.PackageOperation {
case PackOpInstall:
command.Cmd, command.Args = command.pkgMan.Install(command.PackageName, command.PackageVersion, command.Args)
case PackOpRemove:
command.Cmd, command.Args = command.pkgMan.Remove(command.PackageName, command.Args)
case PackOpUpgrade:
command.Cmd, command.Args = command.pkgMan.Upgrade(command.PackageName, command.PackageVersion)
case PackOpCheckVersion:
command.Cmd, command.Args = command.pkgMan.CheckVersion(command.PackageName, command.PackageVersion)
case PackageOperationInstall:
command.Cmd, command.Args = command.pkgMan.Install(command.Packages, command.Args)
case PackageOperationRemove:
command.Cmd, command.Args = command.pkgMan.Remove(command.Packages, command.Args)
case PackageOperationUpgrade:
command.Cmd, command.Args = command.pkgMan.Upgrade(command.Packages)
case PackageOperationCheckVersion:
command.Cmd, command.Args = command.pkgMan.CheckVersion(command.Packages)
}
}
if command.Type == UserCT && !command.userCmdSet {
if command.Type == UserCommandType && !command.userCmdSet {
command.userCmdSet = true
switch command.UserOperation {
case "add":
@ -327,72 +333,121 @@ func getCommandTypeAndSetCommandInfo(command *Command) *Command {
func parsePackageVersion(output string, cmdCtxLogger zerolog.Logger, command *Command, cmdOutBuf bytes.Buffer) ([]string, error) {
var err error
pkgVersion, err := command.pkgMan.Parse(output)
// println(output)
if err != nil {
cmdCtxLogger.Error().Err(err).Str("package", command.PackageName).Msg("Error parsing package version output")
return collectOutput(&cmdOutBuf, command.Name, cmdCtxLogger, command.OutputToLog), err
var errs []error
pkgVersionOnSystem, errs := command.pkgMan.ParseRemotePackageManagerVersionOutput(output)
if errs != nil {
cmdCtxLogger.Error().Errs("Error parsing package version output", errs).Send()
return collectOutput(&cmdOutBuf, command.Name, cmdCtxLogger, command.Output.ToLog), fmt.Errorf("error parsing package version output: %v", errs)
}
cmdCtxLogger.Info().
Str("Installed", pkgVersion.Installed).
Str("Candidate", pkgVersion.Candidate).
Msg("Package version comparison")
if command.PackageVersion != "" {
if pkgVersion.Installed == command.PackageVersion {
cmdCtxLogger.Info().Msgf("Installed version matches specified version: %s", command.PackageVersion)
} else {
cmdCtxLogger.Info().Msgf("Installed version does not match specified version: %s", command.PackageVersion)
err = fmt.Errorf("Installed version does not match specified version: %s", command.PackageVersion)
for _, p := range pkgVersionOnSystem {
packageIndex := getPackageIndexFromCommand(command, p.Name)
if packageIndex == -1 {
cmdCtxLogger.Error().Str("package", p.Name).Msg("Package not found in command")
continue
}
} else {
if pkgVersion.Installed == pkgVersion.Candidate {
cmdCtxLogger.Info().Msg("Installed and Candidate versions match")
command.Packages[packageIndex].VersionCheck = p.VersionCheck
packageFromCommand := command.Packages[packageIndex]
cmdCtxLogger.Info().
Str("Installed", packageFromCommand.VersionCheck.Installed).
Msg("Package version comparison")
versionLogger := cmdCtxLogger.With().Str("package", packageFromCommand.Name).Logger()
if packageFromCommand.Version != "" {
versionLogger := cmdCtxLogger.With().Str("package", packageFromCommand.Name).Str("Specified Version", packageFromCommand.Version).Logger()
packageVersionRegex, PkgRegexErr := regexp.Compile(packageFromCommand.Version)
if PkgRegexErr != nil {
versionLogger.Error().Err(PkgRegexErr).Msg("Error compiling package version regex")
errs = append(errs, PkgRegexErr)
continue
}
if p.Version == packageFromCommand.Version {
versionLogger.Info().Msgf("Installed version matches specified version: %s", packageFromCommand.Version)
} else if packageVersionRegex.MatchString(p.VersionCheck.Installed) {
versionLogger.Info().Msgf("Installed version contains specified version: %s", packageFromCommand.Version)
} else {
versionLogger.Info().Msg("Installed version does not match specified version")
errs = append(errs, fmt.Errorf("installed version of %s does not match specified version: %s", packageFromCommand.Name, packageFromCommand.Version))
}
} else {
cmdCtxLogger.Info().Msg("Installed and Candidate versions differ")
err = errors.New("Installed and Candidate versions differ")
if p.VersionCheck.Installed == p.VersionCheck.Candidate {
versionLogger.Info().Msg("Installed and Candidate versions match")
} else {
cmdCtxLogger.Info().Msg("Installed and Candidate versions differ")
errs = append(errs, errors.New("installed and Candidate versions differ"))
}
}
}
return collectOutput(&cmdOutBuf, command.Name, cmdCtxLogger, false), err
if errs == nil {
return collectOutput(&cmdOutBuf, command.Name, cmdCtxLogger, command.Output.ToLog), nil
}
return collectOutput(&cmdOutBuf, command.Name, cmdCtxLogger, command.Output.ToLog), fmt.Errorf("error parsing package version output: %v", errs)
}
func getExternalConfigDirectiveValue(key string, opts *ConfigOpts) string {
func getPackageIndexFromCommand(command *Command, name string) int {
for i, v := range command.Packages {
if name == v.Name {
return i
}
}
return -1
}
func getExternalConfigDirectiveValue(key string, opts *ConfigOpts, allowedDirectives AllowedExternalDirectives) string {
if !(strings.HasPrefix(key, externDirectiveStart) && strings.HasSuffix(key, externDirectiveEnd)) {
return key
}
key = replaceVarInString(opts.Vars, key, opts.Logger)
opts.Logger.Debug().Str("expanding external key", key).Send()
if strings.HasPrefix(key, envExternDirectiveStart) {
key = strings.TrimPrefix(key, envExternDirectiveStart)
key = strings.TrimSuffix(key, externDirectiveEnd)
key = os.Getenv(key)
if IsExternalDirectiveEnv(allowedDirectives) {
key = strings.TrimPrefix(key, envExternDirectiveStart)
key = strings.TrimSuffix(key, externDirectiveEnd)
key = os.Getenv(key)
} else {
opts.Logger.Error().Msgf("Config key with value %s does not support env directive", key)
}
}
if strings.HasPrefix(key, externFileDirectiveStart) {
var err error
var keyValue []byte
key = strings.TrimPrefix(key, externFileDirectiveStart)
key = strings.TrimSuffix(key, externDirectiveEnd)
key, err = getFullPathWithHomeDir(key)
if err != nil {
opts.Logger.Err(err).Send()
return ""
if IsExternalDirectiveFile(allowedDirectives) {
var err error
var keyValue []byte
key = strings.TrimPrefix(key, externFileDirectiveStart)
key = strings.TrimSuffix(key, externDirectiveEnd)
key, err = getFullPathWithHomeDir(key)
if err != nil {
opts.Logger.Err(err).Send()
return ""
}
if !path.IsAbs(key) {
key = path.Join(opts.ConfigDir, key)
}
keyValue, err = os.ReadFile(key)
if err != nil {
opts.Logger.Err(err).Send()
return ""
}
key = string(keyValue)
} else {
opts.Logger.Error().Msgf("Config key with value %s does not support file directive", key)
}
if !path.IsAbs(key) {
key = path.Join(opts.ConfigDir, key)
}
keyValue, err = os.ReadFile(key)
if err != nil {
opts.Logger.Err(err).Send()
return ""
}
key = string(keyValue)
}
if strings.HasPrefix(key, vaultExternDirectiveStart) {
key = strings.TrimPrefix(key, vaultExternDirectiveStart)
key = strings.TrimSuffix(key, externDirectiveEnd)
key = GetVaultKey(key, opts, opts.Logger)
if IsExternalDirectiveVault(allowedDirectives) {
key = strings.TrimPrefix(key, vaultExternDirectiveStart)
key = strings.TrimSuffix(key, externDirectiveEnd)
key = GetVaultKey(key, opts, opts.Logger)
} else {
opts.Logger.Error().Msgf("Config key with value %s does not support vault directive", key)
}
}
return key
@ -451,3 +506,15 @@ func GetVaultKey(str string, opts *ConfigOpts, log zerolog.Logger) string {
}
return value
}
func IsExternalDirectiveFile(allowedExternalDirectives AllowedExternalDirectives) bool {
return strings.Contains(allowedExternalDirectives.String(), "file")
}
func IsExternalDirectiveEnv(allowedExternalDirectives AllowedExternalDirectives) bool {
return strings.Contains(allowedExternalDirectives.String(), "env")
}
func IsExternalDirectiveVault(allowedExternalDirectives AllowedExternalDirectives) bool {
return strings.Contains(allowedExternalDirectives.String(), "vault")
}

View File

@ -1,18 +1,19 @@
package apt
import (
"bufio"
"bytes"
"fmt"
"regexp"
"strings"
"git.andrewnw.xyz/CyberShell/backy/pkg/pkgman/pkgcommon"
packagemanagercommon "git.andrewnw.xyz/CyberShell/backy/pkg/pkgman/common"
)
// AptManager implements PackageManager for systems using APT.
type AptManager struct {
useAuth bool // Whether to use an authentication command
authCommand string // The authentication command, e.g., "sudo"
Parser pkgcommon.PackageParser
Parser packagemanagercommon.PackageParser
}
// DefaultAuthCommand is the default command used for authentication.
@ -29,14 +30,13 @@ func NewAptManager() *AptManager {
}
// Install returns the command and arguments for installing a package.
func (a *AptManager) Install(pkg, version string, args []string) (string, []string) {
func (a *AptManager) Install(pkgs []packagemanagercommon.Package, args []string) (string, []string) {
baseCmd := a.prependAuthCommand(DefaultPackageCommand)
baseArgs := []string{"update", "&&", baseCmd, "install", "-y"}
if version != "" {
baseArgs = append(baseArgs, fmt.Sprintf("%s=%s", pkg, version))
} else {
baseArgs = append(baseArgs, pkg)
for _, p := range pkgs {
baseArgs = append(baseArgs, p.Name)
}
if args != nil {
baseArgs = append(baseArgs, args...)
}
@ -44,31 +44,34 @@ func (a *AptManager) Install(pkg, version string, args []string) (string, []stri
}
// Remove returns the command and arguments for removing a package.
func (a *AptManager) Remove(pkg string, args []string) (string, []string) {
func (a *AptManager) Remove(pkgs []packagemanagercommon.Package, args []string) (string, []string) {
baseCmd := a.prependAuthCommand(DefaultPackageCommand)
baseArgs := []string{"remove", "-y", pkg}
baseArgs := []string{"remove", "-y"}
for _, p := range pkgs {
baseArgs = append(baseArgs, p.Name)
}
if args != nil {
baseArgs = append(baseArgs, args...)
}
return baseCmd, baseArgs
}
// Upgrade returns the command and arguments for upgrading a specific package.
func (a *AptManager) Upgrade(pkg, version string) (string, []string) {
func (a *AptManager) Upgrade(pkgs []packagemanagercommon.Package) (string, []string) {
baseCmd := a.prependAuthCommand(DefaultPackageCommand)
baseArgs := []string{"update", "&&", baseCmd, "install", "--only-upgrade", "-y"}
if version != "" {
baseArgs = append(baseArgs, fmt.Sprintf("%s=%s", pkg, version))
} else {
baseArgs = append(baseArgs, pkg)
for _, p := range pkgs {
baseArgs = append(baseArgs, p.Name)
}
return baseCmd, baseArgs
}
// CheckVersion returns the command and arguments for checking the info of a specific package.
func (a *AptManager) CheckVersion(pkg, version string) (string, []string) {
func (a *AptManager) CheckVersion(pkgs []packagemanagercommon.Package) (string, []string) {
baseCmd := a.prependAuthCommand("apt-cache")
baseArgs := []string{"policy", pkg}
baseArgs := []string{"policy"}
for _, p := range pkgs {
baseArgs = append(baseArgs, p.Name)
}
return baseCmd, baseArgs
}
@ -81,7 +84,7 @@ func (a *AptManager) UpgradeAll() (string, []string) {
}
// Configure applies functional options to customize the package manager.
func (a *AptManager) Configure(options ...pkgcommon.PackageManagerOption) {
func (a *AptManager) Configure(options ...packagemanagercommon.PackageManagerOption) {
for _, opt := range options {
opt(a)
}
@ -106,25 +109,56 @@ func (a *AptManager) SetAuthCommand(authCommand string) {
}
// Parse parses the apt-cache policy output to extract Installed and Candidate versions.
func (a *AptManager) Parse(output string) (*pkgcommon.PackageVersion, error) {
func (a *AptManager) ParseRemotePackageManagerVersionOutput(output string) ([]packagemanagercommon.Package, []error) {
var (
packageName string
installedString string
candidateString string
countRelevantLines int
)
// Check for error message in the output
if strings.Contains(output, "Unable to locate package") {
return nil, fmt.Errorf("error: %s", strings.TrimSpace(output))
return nil, []error{fmt.Errorf("error: %s", strings.TrimSpace(output))}
}
packages := []packagemanagercommon.Package{}
outputBuf := bytes.NewBufferString(output)
outputScan := bufio.NewScanner(outputBuf)
for outputScan.Scan() {
line := outputScan.Text()
if !strings.HasPrefix(line, " ") && strings.HasSuffix(line, ":") {
// count++
packageName = strings.TrimSpace(strings.TrimSuffix(line, ":"))
}
if strings.Contains(line, "Installed:") {
countRelevantLines++
installedString = strings.TrimPrefix(strings.TrimSpace(line), "Installed:")
}
if strings.Contains(line, "Candidate:") {
countRelevantLines++
candidateString = strings.TrimPrefix(strings.TrimSpace(line), "Candidate:")
}
if countRelevantLines == 2 {
countRelevantLines = 0
packages = append(packages, packagemanagercommon.Package{
Name: packageName,
VersionCheck: packagemanagercommon.PackageVersion{
Installed: strings.TrimSpace(installedString),
Candidate: strings.TrimSpace(candidateString),
Match: installedString == candidateString,
}},
)
}
}
reInstalled := regexp.MustCompile(`Installed:\s*([^\s]+)`)
reCandidate := regexp.MustCompile(`Candidate:\s*([^\s]+)`)
installedMatch := reInstalled.FindStringSubmatch(output)
candidateMatch := reCandidate.FindStringSubmatch(output)
if len(installedMatch) < 2 || len(candidateMatch) < 2 {
return nil, fmt.Errorf("failed to parse Installed or Candidate versions from apt output. check package name")
}
return &pkgcommon.PackageVersion{
Installed: strings.TrimSpace(installedMatch[1]),
Candidate: strings.TrimSpace(candidateMatch[1]),
Match: installedMatch[1] == candidateMatch[1],
}, nil
return packages, nil
}
func SearchPackages(pkgs []string, version string) (string, []string) {
baseCommand := "dpkg-query"
baseArgs := []string{"-W", "-f='${Package}\t${Architecture}\t${db:Status-Status}\t${Version}\t${Installed-Size}\t${Binary:summary}\n'"}
baseArgs = append(baseArgs, pkgs...)
return baseCommand, baseArgs
}

View File

@ -1,4 +1,4 @@
package pkgcommon
package packagemanagercommon
// PackageManagerOption defines a functional option for configuring a PackageManager.
type PackageManagerOption func(interface{})
@ -15,3 +15,9 @@ type PackageVersion struct {
Match bool
Message string
}
type Package struct {
Name string `yaml:"name"`
Version string `yaml:"version,omitempty"`
VersionCheck PackageVersion
}

View File

@ -5,7 +5,7 @@ import (
"regexp"
"strings"
"git.andrewnw.xyz/CyberShell/backy/pkg/pkgman/pkgcommon"
packagemanagercommon "git.andrewnw.xyz/CyberShell/backy/pkg/pkgman/common"
)
// DnfManager implements PackageManager for systems using YUM.
@ -26,21 +26,21 @@ func NewDnfManager() *DnfManager {
}
// Configure applies functional options to customize the package manager.
func (y *DnfManager) Configure(options ...pkgcommon.PackageManagerOption) {
func (y *DnfManager) Configure(options ...packagemanagercommon.PackageManagerOption) {
for _, opt := range options {
opt(y)
}
}
// Install returns the command and arguments for installing a package.
func (y *DnfManager) Install(pkg, version string, args []string) (string, []string) {
func (y *DnfManager) Install(pkgs []packagemanagercommon.Package, args []string) (string, []string) {
baseCmd := y.prependAuthCommand("dnf")
baseArgs := []string{"install", "-y"}
if version != "" {
baseArgs = append(baseArgs, fmt.Sprintf("%s-%s", pkg, version))
} else {
baseArgs = append(baseArgs, pkg)
for _, p := range pkgs {
baseArgs = append(baseArgs, p.Name)
}
if args != nil {
baseArgs = append(baseArgs, args...)
}
@ -48,9 +48,13 @@ func (y *DnfManager) Install(pkg, version string, args []string) (string, []stri
}
// Remove returns the command and arguments for removing a package.
func (y *DnfManager) Remove(pkg string, args []string) (string, []string) {
func (y *DnfManager) Remove(pkgs []packagemanagercommon.Package, args []string) (string, []string) {
baseCmd := y.prependAuthCommand("dnf")
baseArgs := []string{"remove", "-y", pkg}
baseArgs := []string{"remove", "-y"}
for _, p := range pkgs {
baseArgs = append(baseArgs, p.Name)
}
if args != nil {
baseArgs = append(baseArgs, args...)
}
@ -58,38 +62,41 @@ func (y *DnfManager) Remove(pkg string, args []string) (string, []string) {
}
// Upgrade returns the command and arguments for upgrading a specific package.
func (y *DnfManager) Upgrade(pkg, version string) (string, []string) {
func (y *DnfManager) Upgrade(pkgs []packagemanagercommon.Package) (string, []string) {
baseCmd := y.prependAuthCommand("dnf")
baseArgs := []string{"update", "-y"}
if version != "" {
baseArgs = append(baseArgs, fmt.Sprintf("%s-%s", pkg, version))
} else {
baseArgs = append(baseArgs, pkg)
for _, p := range pkgs {
baseArgs = append(baseArgs, p.Name)
}
return baseCmd, baseArgs
}
// UpgradeAll returns the command and arguments for upgrading all packages.
func (y *DnfManager) UpgradeAll() (string, []string) {
baseCmd := y.prependAuthCommand("dnf")
baseArgs := []string{"update", "-y"}
baseArgs := []string{"upgrade", "-y"}
return baseCmd, baseArgs
}
// CheckVersion returns the command and arguments for checking the info of a specific package.
func (d *DnfManager) CheckVersion(pkg, version string) (string, []string) {
func (d *DnfManager) CheckVersion(pkgs []packagemanagercommon.Package) (string, []string) {
baseCmd := d.prependAuthCommand("dnf")
baseArgs := []string{"info", pkg}
baseArgs := []string{"info"}
for _, p := range pkgs {
baseArgs = append(baseArgs, p.Name)
}
return baseCmd, baseArgs
}
// Parse parses the dnf info output to extract Installed and Candidate versions.
func (d DnfManager) Parse(output string) (*pkgcommon.PackageVersion, error) {
func (d DnfManager) ParseRemotePackageManagerVersionOutput(output string) ([]packagemanagercommon.Package, []error) {
// Check for error message in the output
if strings.Contains(output, "No matching packages to list") {
return nil, fmt.Errorf("error: package not listed")
return nil, []error{fmt.Errorf("error: package not listed")}
}
// Define regular expressions to capture installed and available versions
@ -111,13 +118,10 @@ func (d DnfManager) Parse(output string) (*pkgcommon.PackageVersion, error) {
}
if installedVersion == "" && candidateVersion == "" {
return nil, fmt.Errorf("failed to parse versions from dnf output")
return nil, []error{fmt.Errorf("failed to parse versions from dnf output")}
}
return &pkgcommon.PackageVersion{
Installed: installedVersion,
Candidate: candidateVersion,
}, nil
return nil, nil
}
// prependAuthCommand prepends the authentication command if UseAuth is true.

View File

@ -4,26 +4,26 @@ import (
"fmt"
"git.andrewnw.xyz/CyberShell/backy/pkg/pkgman/apt"
packagemanagercommon "git.andrewnw.xyz/CyberShell/backy/pkg/pkgman/common"
"git.andrewnw.xyz/CyberShell/backy/pkg/pkgman/dnf"
"git.andrewnw.xyz/CyberShell/backy/pkg/pkgman/pkgcommon"
"git.andrewnw.xyz/CyberShell/backy/pkg/pkgman/yum"
)
// PackageManager is an interface used to define common package commands. This shall be implemented by every package.
type PackageManager interface {
Install(pkg, version string, args []string) (string, []string)
Remove(pkg string, args []string) (string, []string)
Upgrade(pkg, version string) (string, []string) // Upgrade a specific package
Install(pkgs []packagemanagercommon.Package, args []string) (string, []string)
Remove(pkgs []packagemanagercommon.Package, args []string) (string, []string)
Upgrade(pkgs []packagemanagercommon.Package) (string, []string) // Upgrade a specific package
UpgradeAll() (string, []string)
CheckVersion(pkg, version string) (string, []string)
Parse(output string) (*pkgcommon.PackageVersion, error)
CheckVersion(pkgs []packagemanagercommon.Package) (string, []string)
ParseRemotePackageManagerVersionOutput(output string) ([]packagemanagercommon.Package, []error)
// Configure applies functional options to customize the package manager.
Configure(options ...pkgcommon.PackageManagerOption)
Configure(options ...packagemanagercommon.PackageManagerOption)
}
// PackageManagerFactory returns the appropriate PackageManager based on the package tool.
// Takes variable number of options.
func PackageManagerFactory(managerType string, options ...pkgcommon.PackageManagerOption) (PackageManager, error) {
func PackageManagerFactory(managerType string, options ...packagemanagercommon.PackageManagerOption) (PackageManager, error) {
var manager PackageManager
switch managerType {
@ -43,7 +43,7 @@ func PackageManagerFactory(managerType string, options ...pkgcommon.PackageManag
}
// WithAuth enables authentication and sets the authentication command.
func WithAuth(authCommand string) pkgcommon.PackageManagerOption {
func WithAuth(authCommand string) packagemanagercommon.PackageManagerOption {
return func(manager interface{}) {
if configurable, ok := manager.(interface {
SetUseAuth(bool)
@ -56,7 +56,7 @@ func WithAuth(authCommand string) pkgcommon.PackageManagerOption {
}
// WithoutAuth disables authentication.
func WithoutAuth() pkgcommon.PackageManagerOption {
func WithoutAuth() packagemanagercommon.PackageManagerOption {
return func(manager interface{}) {
if configurable, ok := manager.(interface {
SetUseAuth(bool)
@ -68,8 +68,8 @@ func WithoutAuth() pkgcommon.PackageManagerOption {
// ConfigurablePackageManager defines methods for setting configuration options.
type ConfigurablePackageManager interface {
pkgcommon.PackageParser
packagemanagercommon.PackageParser
SetUseAuth(useAuth bool)
SetAuthCommand(authCommand string)
SetPackageParser(parser pkgcommon.PackageParser)
SetPackageParser(parser packagemanagercommon.PackageParser)
}

View File

@ -3,8 +3,9 @@ package yum
import (
"fmt"
"regexp"
"strings"
"git.andrewnw.xyz/CyberShell/backy/pkg/pkgman/pkgcommon"
packagemanagercommon "git.andrewnw.xyz/CyberShell/backy/pkg/pkgman/common"
)
// YumManager implements PackageManager for systems using YUM.
@ -25,21 +26,20 @@ func NewYumManager() *YumManager {
}
// Configure applies functional options to customize the package manager.
func (y *YumManager) Configure(options ...pkgcommon.PackageManagerOption) {
func (y *YumManager) Configure(options ...packagemanagercommon.PackageManagerOption) {
for _, opt := range options {
opt(y)
}
}
// Install returns the command and arguments for installing a package.
func (y *YumManager) Install(pkg, version string, args []string) (string, []string) {
func (y *YumManager) Install(pkgs []packagemanagercommon.Package, args []string) (string, []string) {
baseCmd := y.prependAuthCommand("yum")
baseArgs := []string{"install", "-y"}
if version != "" {
baseArgs = append(baseArgs, fmt.Sprintf("%s-%s", pkg, version))
} else {
baseArgs = append(baseArgs, pkg)
for _, p := range pkgs {
baseArgs = append(baseArgs, p.Name)
}
if args != nil {
baseArgs = append(baseArgs, args...)
}
@ -47,9 +47,13 @@ func (y *YumManager) Install(pkg, version string, args []string) (string, []stri
}
// Remove returns the command and arguments for removing a package.
func (y *YumManager) Remove(pkg string, args []string) (string, []string) {
func (y *YumManager) Remove(pkgs []packagemanagercommon.Package, args []string) (string, []string) {
baseCmd := y.prependAuthCommand("yum")
baseArgs := []string{"remove", "-y", pkg}
baseArgs := []string{"remove", "-y"}
for _, p := range pkgs {
baseArgs = append(baseArgs, p.Name)
}
if args != nil {
baseArgs = append(baseArgs, args...)
}
@ -57,14 +61,13 @@ func (y *YumManager) Remove(pkg string, args []string) (string, []string) {
}
// Upgrade returns the command and arguments for upgrading a specific package.
func (y *YumManager) Upgrade(pkg, version string) (string, []string) {
func (y *YumManager) Upgrade(pkgs []packagemanagercommon.Package) (string, []string) {
baseCmd := y.prependAuthCommand("yum")
baseArgs := []string{"update", "-y"}
if version != "" {
baseArgs = append(baseArgs, fmt.Sprintf("%s-%s", pkg, version))
} else {
baseArgs = append(baseArgs, pkg)
for _, p := range pkgs {
baseArgs = append(baseArgs, p.Name)
}
return baseCmd, baseArgs
}
@ -76,17 +79,27 @@ func (y *YumManager) UpgradeAll() (string, []string) {
}
// CheckVersion returns the command and arguments for checking the info of a specific package.
func (y *YumManager) CheckVersion(pkg, version string) (string, []string) {
func (y *YumManager) CheckVersion(pkgs []packagemanagercommon.Package) (string, []string) {
baseCmd := y.prependAuthCommand("yum")
baseArgs := []string{"info", pkg}
baseArgs := []string{"info"}
for _, p := range pkgs {
baseArgs = append(baseArgs, p.Name)
}
return baseCmd, baseArgs
}
// Parse parses the dnf info output to extract Installed and Candidate versions.
func (y YumManager) Parse(output string) (*pkgcommon.PackageVersion, error) {
reInstalled := regexp.MustCompile(`(?m)^Installed Packages\s*Name\s*:\s*\S+\s*Version\s*:\s*([^\s]+)\s*Release\s*:\s*([^\s]+)`)
reAvailable := regexp.MustCompile(`(?m)^Available Packages\s*Name\s*:\s*\S+\s*Version\s*:\s*([^\s]+)\s*Release\s*:\s*([^\s]+)`)
func (y YumManager) ParseRemotePackageManagerVersionOutput(output string) ([]packagemanagercommon.Package, []error) {
// Check for error message in the output
if strings.Contains(output, "No matching packages to list") {
return nil, []error{fmt.Errorf("error: package not listed")}
}
// Define regular expressions to capture installed and available versions
reInstalled := regexp.MustCompile(`(?m)^Installed packages\s*Name\s*:\s*\S+\s*Epoch\s*:\s*\S+\s*Version\s*:\s*([^\s]+)\s*Release\s*:\s*([^\s]+)`)
reAvailable := regexp.MustCompile(`(?m)^Available packages\s*Name\s*:\s*\S+\s*Epoch\s*:\s*\S+\s*Version\s*:\s*([^\s]+)\s*Release\s*:\s*([^\s]+)`)
installedMatch := reInstalled.FindStringSubmatch(output)
candidateMatch := reAvailable.FindStringSubmatch(output)
@ -103,13 +116,10 @@ func (y YumManager) Parse(output string) (*pkgcommon.PackageVersion, error) {
}
if installedVersion == "" && candidateVersion == "" {
return nil, fmt.Errorf("failed to parse versions from dnf output")
return nil, []error{fmt.Errorf("failed to parse versions from dnf output")}
}
return &pkgcommon.PackageVersion{
Installed: installedVersion,
Candidate: candidateVersion,
}, nil
return nil, nil
}
// prependAuthCommand prepends the authentication command if UseAuth is true.