From 2ca5f193e4e539848b42f7470fbc79a87482bd68 Mon Sep 17 00:00:00 2001 From: Andrew Date: Fri, 10 Mar 2023 16:01:02 -0600 Subject: [PATCH] Close host connections after all command have been run, added some flags to version subcommand --- cmd/backup.go | 5 ++++ cmd/version.go | 10 +++++-- pkg/backy/backy.go | 71 ++++++++++++++++++++++++++++++++++++++++----- pkg/backy/config.go | 2 ++ pkg/backy/ssh.go | 21 ++++++++++---- 5 files changed, 93 insertions(+), 16 deletions(-) diff --git a/cmd/backup.go b/cmd/backup.go index 8f20526..a1fc11d 100644 --- a/cmd/backup.go +++ b/cmd/backup.go @@ -33,4 +33,9 @@ func Backup(cmd *cobra.Command, args []string) { backyConfOpts.InitConfig() config := backy.ReadConfig(backyConfOpts) config.RunBackyConfig("") + for _, host := range config.Hosts { + if host.SshClient != nil { + host.SshClient.Close() + } + } } diff --git a/cmd/version.go b/cmd/version.go index 35a1a8a..a7a1d3a 100644 --- a/cmd/version.go +++ b/cmd/version.go @@ -11,15 +11,21 @@ const versionStr = "0.2.4" var ( versionCmd = &cobra.Command{ - Use: "version", + Use: "version [flags]", Short: "Prints the version and exits.", Run: version, } + numOnly bool ) func version(cmd *cobra.Command, args []string) { - fmt.Printf("%s\n", versionStr) + cmd.PersistentFlags().BoolVarP(&numOnly, "num", "n", true, "Output the version number only.") + if numOnly { + fmt.Printf("%s\n", versionStr) + } else { + fmt.Printf("Version: %s", versionStr) + } os.Exit(0) } diff --git a/pkg/backy/backy.go b/pkg/backy/backy.go index 59e26c0..7d5a737 100644 --- a/pkg/backy/backy.go +++ b/pkg/backy/backy.go @@ -29,7 +29,7 @@ var Sprintf = fmt.Sprintf // The environment of local commands will be the machine's environment plus any extra // variables specified in the Env file or Environment. // Dir can also be specified for local commands. -func (command *Command) RunCmd(log *zerolog.Logger, hosts map[string]*Host) ([]string, error) { +func (command *Command) RunCmd(log *zerolog.Logger, backyConf *BackyConfigFile) ([]string, error) { var ( outputArr []string @@ -50,11 +50,12 @@ func (command *Command) RunCmd(log *zerolog.Logger, hosts map[string]*Host) ([]s if command.Host != nil { log.Info().Str("Command", fmt.Sprintf("Running command %s %s on host %s", command.Cmd, ArgsStr, *command.Host)).Send() - err := command.RemoteHost.ConnectToSSHHost(log, hosts) - if err != nil { - return nil, err + if command.RemoteHost.SshClient == nil { + err := command.RemoteHost.ConnectToSSHHost(log, backyConf) + if err != nil { + return nil, err + } } - defer command.RemoteHost.SshClient.Close() commandSession, err := command.RemoteHost.SshClient.NewSession() if err != nil { log.Err(fmt.Errorf("new ssh session: %w", err)).Send() @@ -100,8 +101,11 @@ func (command *Command) RunCmd(log *zerolog.Logger, hosts map[string]*Host) ([]s var err error if command.Shell != "" { log.Info().Str("Command", fmt.Sprintf("Running command %s %s on local machine in %s", command.Cmd, ArgsStr, command.Shell)).Send() + ArgsStr = fmt.Sprintf("%s %s", command.Cmd, ArgsStr) + localCMD := exec.Command(command.Shell, "-c", ArgsStr) + if command.Dir != nil { localCMD.Dir = *command.Dir } @@ -115,8 +119,11 @@ func (command *Command) RunCmd(log *zerolog.Logger, hosts map[string]*Host) ([]s localCMD.Stdout = cmdOutWriters localCMD.Stderr = cmdOutWriters + err = localCMD.Run() + outScanner := bufio.NewScanner(&cmdOutBuf) + for outScanner.Scan() { outMap := make(map[string]interface{}) outMap["cmd"] = command.Cmd @@ -156,6 +163,7 @@ func (command *Command) RunCmd(log *zerolog.Logger, hosts map[string]*Host) ([]s outMap := make(map[string]interface{}) outMap["cmd"] = command.Cmd outMap["output"] = outScanner.Text() + if str, ok := outMap["output"].(string); ok { outputArr = append(outputArr, str) } @@ -175,54 +183,78 @@ func cmdListWorker(msgTemps *msgTemplates, jobs <-chan *CmdList, config *BackyCo var currentCmd string fieldsMap := make(map[string]interface{}) fieldsMap["list"] = list.Name + cmdLog := config.Logger.Info() + var count int var cmdsRan []string + for _, cmd := range list.Order { currentCmd = config.Cmds[cmd].Cmd + fieldsMap["cmd"] = config.Cmds[cmd].Cmd cmdLog.Fields(fieldsMap).Send() cmdToRun := config.Cmds[cmd] + cmdLogger := config.Logger.With(). - Str("backy-cmd", cmd). + Str("backy-cmd", cmd).Str("Host", "local machine"). Logger() - outputArr, runOutErr := cmdToRun.RunCmd(&cmdLogger, config.Hosts) + + if cmdToRun.Host != nil { + cmdLogger = config.Logger.With(). + Str("backy-cmd", cmd).Str("Host", *cmdToRun.Host). + Logger() + } + + outputArr, runOutErr := cmdToRun.RunCmd(&cmdLogger, config) count++ if runOutErr != nil { var errMsg bytes.Buffer if list.NotifyConfig != nil { errStruct := make(map[string]interface{}) + errStruct["listName"] = list.Name errStruct["Command"] = currentCmd errStruct["Err"] = runOutErr errStruct["CmdsRan"] = cmdsRan errStruct["Output"] = outputArr + tmpErr := msgTemps.err.Execute(&errMsg, errStruct) + if tmpErr != nil { config.Logger.Err(tmpErr).Send() } + notifySendErr := list.NotifyConfig.Send(context.Background(), fmt.Sprintf("List %s failed on command %s ", list.Name, cmd), errMsg.String()) + if notifySendErr != nil { config.Logger.Err(notifySendErr).Send() } } + config.Logger.Err(runOutErr).Send() + break } else { if count == len(list.Order) { cmdsRan = append(cmdsRan, cmd) var successMsg bytes.Buffer + if list.NotifyConfig != nil { successStruct := make(map[string]interface{}) successStruct["listName"] = list.Name successStruct["CmdsRan"] = cmdsRan + tmpErr := msgTemps.success.Execute(&successMsg, successStruct) + if tmpErr != nil { config.Logger.Err(tmpErr).Send() break } + err := list.NotifyConfig.Send(context.Background(), fmt.Sprintf("List %s succeded", list.Name), successMsg.String()) + if err != nil { config.Logger.Err(err).Send() } @@ -275,6 +307,7 @@ func (config *BackyConfigFile) RunBackyConfig(cron string) { <-results } + config.closeHostConnections() } func (config *BackyConfigFile) ExecuteCmds(opts *BackyConfigOpts) { @@ -283,10 +316,32 @@ func (config *BackyConfigFile) ExecuteCmds(opts *BackyConfigOpts) { cmdLogger := config.Logger.With(). Str("backy-cmd", cmd). Logger() - _, runErr := cmdToRun.RunCmd(&cmdLogger, config.Hosts) + _, runErr := cmdToRun.RunCmd(&cmdLogger, config) if runErr != nil { config.Logger.Err(runErr).Send() } } + config.closeHostConnections() + +} + +func (c *BackyConfigFile) closeHostConnections() { + for _, host := range c.Hosts { + + if host.SshClient != nil { + if _, err := host.SshClient.NewSession(); err == nil { + c.Logger.Info().Msgf("Closing host connection %s", host.HostName) + host.SshClient.Close() + } + } + for _, proxyHost := range host.ProxyHost { + if proxyHost.SshClient != nil { + if _, err := host.SshClient.NewSession(); err == nil { + c.Logger.Info().Msgf("Closing connection to proxy host %s", host.HostName) + host.SshClient.Close() + } + } + } + } } diff --git a/pkg/backy/config.go b/pkg/backy/config.go index dfca7a1..840cc02 100644 --- a/pkg/backy/config.go +++ b/pkg/backy/config.go @@ -231,12 +231,14 @@ func ReadConfig(opts *BackyConfigOpts) *BackyConfigFile { cmd.RemoteHost.HostName = host.HostName } } else { + backyConfigFile.Hosts[*cmd.Host] = &Host{Host: *cmd.Host} cmd.RemoteHost = &Host{Host: *cmd.Host} } } } backyConfigFile.SetupNotify() + opts.ConfigFile = backyConfigFile return backyConfigFile } diff --git a/pkg/backy/ssh.go b/pkg/backy/ssh.go index 55772cb..13570b5 100644 --- a/pkg/backy/ssh.go +++ b/pkg/backy/ssh.go @@ -28,7 +28,7 @@ var TS = strings.TrimSpace // It returns an ssh.Client used to run commands against. // If configFile is empty, any required configuration is looked up in the default config files // If any value is not found, defaults are used -func (remoteConfig *Host) ConnectToSSHHost(log *zerolog.Logger, hosts map[string]*Host) error { +func (remoteConfig *Host) ConnectToSSHHost(log *zerolog.Logger, config *BackyConfigFile) error { // var sshClient *ssh.Client var connectErr error @@ -68,28 +68,31 @@ func (remoteConfig *Host) ConnectToSSHHost(log *zerolog.Logger, hosts map[string return decodeErr } - err := remoteConfig.GetProxyJumpFromConfig(hosts) + err := remoteConfig.GetProxyJumpFromConfig(config.Hosts) if err != nil { return err } if remoteConfig.ProxyHost != nil { for _, proxyHost := range remoteConfig.ProxyHost { - log.Info().Msgf("Proxy Host %s", proxyHost.Host) - err := proxyHost.GetProxyJumpConfig(hosts) + err := proxyHost.GetProxyJumpConfig(config.Hosts) + log.Info().Msgf("Proxy host: %s", proxyHost.Host) if err != nil { return err } } } + remoteConfig.ClientConfig.Timeout = time.Second * 30 remoteConfig.GetPrivateKeyFileFromConfig() remoteConfig.GetPort() remoteConfig.GetHostName() remoteConfig.CombineHostNameWithPort() remoteConfig.GetSshUserFromConfig() + if remoteConfig.HostName == "" { - return errors.New("No hostname found or specified") + return errors.Errorf("No hostname found or specified for host %s", remoteConfig.Host) } + err = remoteConfig.GetAuthMethods() if err != nil { return err @@ -107,6 +110,7 @@ func (remoteConfig *Host) ConnectToSSHHost(log *zerolog.Logger, hosts map[string return connectErr } if remoteConfig.SshClient != nil { + config.Hosts[remoteConfig.Host] = remoteConfig return nil } @@ -115,6 +119,7 @@ func (remoteConfig *Host) ConnectToSSHHost(log *zerolog.Logger, hosts map[string if connectErr != nil { return connectErr } + config.Hosts[remoteConfig.Host] = remoteConfig return nil } @@ -242,6 +247,7 @@ func (remoteHost *Host) ConnectThroughBastion(log *zerolog.Logger) (*ssh.Client, if err != nil { return nil, err } + remoteHost.ProxyHost[0].SshClient = bClient // Dial a connection to the service host, from the bastion conn, err := bClient.Dial("tcp", remoteHost.HostName) @@ -345,6 +351,7 @@ func (remoteConfig *Host) GetProxyJumpFromConfig(hosts map[string]*Host) error { return nil } + func (remoteConfig *Host) GetProxyJumpConfig(hosts map[string]*Host) error { if TS(remoteConfig.ConfigFilePath) == "" { remoteConfig.useDefaultConfig = true @@ -386,7 +393,7 @@ func (remoteConfig *Host) GetProxyJumpConfig(hosts map[string]*Host) error { remoteConfig.CombineHostNameWithPort() remoteConfig.GetSshUserFromConfig() if remoteConfig.HostName == "" { - return errors.New("No hostname found or specified") + return errors.Errorf("No hostname found or specified for host %s", remoteConfig.Host) } err := remoteConfig.GetAuthMethods() if err != nil { @@ -399,5 +406,7 @@ func (remoteConfig *Host) GetProxyJumpConfig(hosts map[string]*Host) error { return errors.Wrap(err, "could not create hostkeycallback function") } remoteConfig.ClientConfig.HostKeyCallback = hostKeyCallback + hosts[remoteConfig.Host] = remoteConfig + return nil }