Close host connections after all command have been run, added some flags to version subcommand

This commit is contained in:
Andrew 2023-03-10 16:01:02 -06:00
parent 9ffa2e473e
commit 2ca5f193e4
5 changed files with 93 additions and 16 deletions

View File

@ -33,4 +33,9 @@ func Backup(cmd *cobra.Command, args []string) {
backyConfOpts.InitConfig() backyConfOpts.InitConfig()
config := backy.ReadConfig(backyConfOpts) config := backy.ReadConfig(backyConfOpts)
config.RunBackyConfig("") config.RunBackyConfig("")
for _, host := range config.Hosts {
if host.SshClient != nil {
host.SshClient.Close()
}
}
} }

View File

@ -11,15 +11,21 @@ const versionStr = "0.2.4"
var ( var (
versionCmd = &cobra.Command{ versionCmd = &cobra.Command{
Use: "version", Use: "version [flags]",
Short: "Prints the version and exits.", Short: "Prints the version and exits.",
Run: version, Run: version,
} }
numOnly bool
) )
func version(cmd *cobra.Command, args []string) { func version(cmd *cobra.Command, args []string) {
cmd.PersistentFlags().BoolVarP(&numOnly, "num", "n", true, "Output the version number only.")
if numOnly {
fmt.Printf("%s\n", versionStr) fmt.Printf("%s\n", versionStr)
} else {
fmt.Printf("Version: %s", versionStr)
}
os.Exit(0) os.Exit(0)
} }

View File

@ -29,7 +29,7 @@ var Sprintf = fmt.Sprintf
// The environment of local commands will be the machine's environment plus any extra // The environment of local commands will be the machine's environment plus any extra
// variables specified in the Env file or Environment. // variables specified in the Env file or Environment.
// Dir can also be specified for local commands. // Dir can also be specified for local commands.
func (command *Command) RunCmd(log *zerolog.Logger, hosts map[string]*Host) ([]string, error) { func (command *Command) RunCmd(log *zerolog.Logger, backyConf *BackyConfigFile) ([]string, error) {
var ( var (
outputArr []string outputArr []string
@ -50,11 +50,12 @@ func (command *Command) RunCmd(log *zerolog.Logger, hosts map[string]*Host) ([]s
if command.Host != nil { if command.Host != nil {
log.Info().Str("Command", fmt.Sprintf("Running command %s %s on host %s", command.Cmd, ArgsStr, *command.Host)).Send() log.Info().Str("Command", fmt.Sprintf("Running command %s %s on host %s", command.Cmd, ArgsStr, *command.Host)).Send()
err := command.RemoteHost.ConnectToSSHHost(log, hosts) if command.RemoteHost.SshClient == nil {
err := command.RemoteHost.ConnectToSSHHost(log, backyConf)
if err != nil { if err != nil {
return nil, err return nil, err
} }
defer command.RemoteHost.SshClient.Close() }
commandSession, err := command.RemoteHost.SshClient.NewSession() commandSession, err := command.RemoteHost.SshClient.NewSession()
if err != nil { if err != nil {
log.Err(fmt.Errorf("new ssh session: %w", err)).Send() log.Err(fmt.Errorf("new ssh session: %w", err)).Send()
@ -100,8 +101,11 @@ func (command *Command) RunCmd(log *zerolog.Logger, hosts map[string]*Host) ([]s
var err error var err error
if command.Shell != "" { if command.Shell != "" {
log.Info().Str("Command", fmt.Sprintf("Running command %s %s on local machine in %s", command.Cmd, ArgsStr, command.Shell)).Send() log.Info().Str("Command", fmt.Sprintf("Running command %s %s on local machine in %s", command.Cmd, ArgsStr, command.Shell)).Send()
ArgsStr = fmt.Sprintf("%s %s", command.Cmd, ArgsStr) ArgsStr = fmt.Sprintf("%s %s", command.Cmd, ArgsStr)
localCMD := exec.Command(command.Shell, "-c", ArgsStr) localCMD := exec.Command(command.Shell, "-c", ArgsStr)
if command.Dir != nil { if command.Dir != nil {
localCMD.Dir = *command.Dir localCMD.Dir = *command.Dir
} }
@ -115,8 +119,11 @@ func (command *Command) RunCmd(log *zerolog.Logger, hosts map[string]*Host) ([]s
localCMD.Stdout = cmdOutWriters localCMD.Stdout = cmdOutWriters
localCMD.Stderr = cmdOutWriters localCMD.Stderr = cmdOutWriters
err = localCMD.Run() err = localCMD.Run()
outScanner := bufio.NewScanner(&cmdOutBuf) outScanner := bufio.NewScanner(&cmdOutBuf)
for outScanner.Scan() { for outScanner.Scan() {
outMap := make(map[string]interface{}) outMap := make(map[string]interface{})
outMap["cmd"] = command.Cmd outMap["cmd"] = command.Cmd
@ -156,6 +163,7 @@ func (command *Command) RunCmd(log *zerolog.Logger, hosts map[string]*Host) ([]s
outMap := make(map[string]interface{}) outMap := make(map[string]interface{})
outMap["cmd"] = command.Cmd outMap["cmd"] = command.Cmd
outMap["output"] = outScanner.Text() outMap["output"] = outScanner.Text()
if str, ok := outMap["output"].(string); ok { if str, ok := outMap["output"].(string); ok {
outputArr = append(outputArr, str) outputArr = append(outputArr, str)
} }
@ -175,54 +183,78 @@ func cmdListWorker(msgTemps *msgTemplates, jobs <-chan *CmdList, config *BackyCo
var currentCmd string var currentCmd string
fieldsMap := make(map[string]interface{}) fieldsMap := make(map[string]interface{})
fieldsMap["list"] = list.Name fieldsMap["list"] = list.Name
cmdLog := config.Logger.Info() cmdLog := config.Logger.Info()
var count int var count int
var cmdsRan []string var cmdsRan []string
for _, cmd := range list.Order { for _, cmd := range list.Order {
currentCmd = config.Cmds[cmd].Cmd currentCmd = config.Cmds[cmd].Cmd
fieldsMap["cmd"] = config.Cmds[cmd].Cmd fieldsMap["cmd"] = config.Cmds[cmd].Cmd
cmdLog.Fields(fieldsMap).Send() cmdLog.Fields(fieldsMap).Send()
cmdToRun := config.Cmds[cmd] cmdToRun := config.Cmds[cmd]
cmdLogger := config.Logger.With(). cmdLogger := config.Logger.With().
Str("backy-cmd", cmd). Str("backy-cmd", cmd).Str("Host", "local machine").
Logger() Logger()
outputArr, runOutErr := cmdToRun.RunCmd(&cmdLogger, config.Hosts)
if cmdToRun.Host != nil {
cmdLogger = config.Logger.With().
Str("backy-cmd", cmd).Str("Host", *cmdToRun.Host).
Logger()
}
outputArr, runOutErr := cmdToRun.RunCmd(&cmdLogger, config)
count++ count++
if runOutErr != nil { if runOutErr != nil {
var errMsg bytes.Buffer var errMsg bytes.Buffer
if list.NotifyConfig != nil { if list.NotifyConfig != nil {
errStruct := make(map[string]interface{}) errStruct := make(map[string]interface{})
errStruct["listName"] = list.Name errStruct["listName"] = list.Name
errStruct["Command"] = currentCmd errStruct["Command"] = currentCmd
errStruct["Err"] = runOutErr errStruct["Err"] = runOutErr
errStruct["CmdsRan"] = cmdsRan errStruct["CmdsRan"] = cmdsRan
errStruct["Output"] = outputArr errStruct["Output"] = outputArr
tmpErr := msgTemps.err.Execute(&errMsg, errStruct) tmpErr := msgTemps.err.Execute(&errMsg, errStruct)
if tmpErr != nil { if tmpErr != nil {
config.Logger.Err(tmpErr).Send() config.Logger.Err(tmpErr).Send()
} }
notifySendErr := list.NotifyConfig.Send(context.Background(), fmt.Sprintf("List %s failed on command %s ", list.Name, cmd), errMsg.String()) notifySendErr := list.NotifyConfig.Send(context.Background(), fmt.Sprintf("List %s failed on command %s ", list.Name, cmd), errMsg.String())
if notifySendErr != nil { if notifySendErr != nil {
config.Logger.Err(notifySendErr).Send() config.Logger.Err(notifySendErr).Send()
} }
} }
config.Logger.Err(runOutErr).Send() config.Logger.Err(runOutErr).Send()
break break
} else { } else {
if count == len(list.Order) { if count == len(list.Order) {
cmdsRan = append(cmdsRan, cmd) cmdsRan = append(cmdsRan, cmd)
var successMsg bytes.Buffer var successMsg bytes.Buffer
if list.NotifyConfig != nil { if list.NotifyConfig != nil {
successStruct := make(map[string]interface{}) successStruct := make(map[string]interface{})
successStruct["listName"] = list.Name successStruct["listName"] = list.Name
successStruct["CmdsRan"] = cmdsRan successStruct["CmdsRan"] = cmdsRan
tmpErr := msgTemps.success.Execute(&successMsg, successStruct) tmpErr := msgTemps.success.Execute(&successMsg, successStruct)
if tmpErr != nil { if tmpErr != nil {
config.Logger.Err(tmpErr).Send() config.Logger.Err(tmpErr).Send()
break break
} }
err := list.NotifyConfig.Send(context.Background(), fmt.Sprintf("List %s succeded", list.Name), successMsg.String()) err := list.NotifyConfig.Send(context.Background(), fmt.Sprintf("List %s succeded", list.Name), successMsg.String())
if err != nil { if err != nil {
config.Logger.Err(err).Send() config.Logger.Err(err).Send()
} }
@ -275,6 +307,7 @@ func (config *BackyConfigFile) RunBackyConfig(cron string) {
<-results <-results
} }
config.closeHostConnections()
} }
func (config *BackyConfigFile) ExecuteCmds(opts *BackyConfigOpts) { func (config *BackyConfigFile) ExecuteCmds(opts *BackyConfigOpts) {
@ -283,10 +316,32 @@ func (config *BackyConfigFile) ExecuteCmds(opts *BackyConfigOpts) {
cmdLogger := config.Logger.With(). cmdLogger := config.Logger.With().
Str("backy-cmd", cmd). Str("backy-cmd", cmd).
Logger() Logger()
_, runErr := cmdToRun.RunCmd(&cmdLogger, config.Hosts) _, runErr := cmdToRun.RunCmd(&cmdLogger, config)
if runErr != nil { if runErr != nil {
config.Logger.Err(runErr).Send() config.Logger.Err(runErr).Send()
} }
} }
config.closeHostConnections()
}
func (c *BackyConfigFile) closeHostConnections() {
for _, host := range c.Hosts {
if host.SshClient != nil {
if _, err := host.SshClient.NewSession(); err == nil {
c.Logger.Info().Msgf("Closing host connection %s", host.HostName)
host.SshClient.Close()
}
}
for _, proxyHost := range host.ProxyHost {
if proxyHost.SshClient != nil {
if _, err := host.SshClient.NewSession(); err == nil {
c.Logger.Info().Msgf("Closing connection to proxy host %s", host.HostName)
host.SshClient.Close()
}
}
}
}
} }

View File

@ -231,12 +231,14 @@ func ReadConfig(opts *BackyConfigOpts) *BackyConfigFile {
cmd.RemoteHost.HostName = host.HostName cmd.RemoteHost.HostName = host.HostName
} }
} else { } else {
backyConfigFile.Hosts[*cmd.Host] = &Host{Host: *cmd.Host}
cmd.RemoteHost = &Host{Host: *cmd.Host} cmd.RemoteHost = &Host{Host: *cmd.Host}
} }
} }
} }
backyConfigFile.SetupNotify() backyConfigFile.SetupNotify()
opts.ConfigFile = backyConfigFile
return backyConfigFile return backyConfigFile
} }

View File

@ -28,7 +28,7 @@ var TS = strings.TrimSpace
// It returns an ssh.Client used to run commands against. // It returns an ssh.Client used to run commands against.
// If configFile is empty, any required configuration is looked up in the default config files // If configFile is empty, any required configuration is looked up in the default config files
// If any value is not found, defaults are used // If any value is not found, defaults are used
func (remoteConfig *Host) ConnectToSSHHost(log *zerolog.Logger, hosts map[string]*Host) error { func (remoteConfig *Host) ConnectToSSHHost(log *zerolog.Logger, config *BackyConfigFile) error {
// var sshClient *ssh.Client // var sshClient *ssh.Client
var connectErr error var connectErr error
@ -68,28 +68,31 @@ func (remoteConfig *Host) ConnectToSSHHost(log *zerolog.Logger, hosts map[string
return decodeErr return decodeErr
} }
err := remoteConfig.GetProxyJumpFromConfig(hosts) err := remoteConfig.GetProxyJumpFromConfig(config.Hosts)
if err != nil { if err != nil {
return err return err
} }
if remoteConfig.ProxyHost != nil { if remoteConfig.ProxyHost != nil {
for _, proxyHost := range remoteConfig.ProxyHost { for _, proxyHost := range remoteConfig.ProxyHost {
log.Info().Msgf("Proxy Host %s", proxyHost.Host) err := proxyHost.GetProxyJumpConfig(config.Hosts)
err := proxyHost.GetProxyJumpConfig(hosts) log.Info().Msgf("Proxy host: %s", proxyHost.Host)
if err != nil { if err != nil {
return err return err
} }
} }
} }
remoteConfig.ClientConfig.Timeout = time.Second * 30 remoteConfig.ClientConfig.Timeout = time.Second * 30
remoteConfig.GetPrivateKeyFileFromConfig() remoteConfig.GetPrivateKeyFileFromConfig()
remoteConfig.GetPort() remoteConfig.GetPort()
remoteConfig.GetHostName() remoteConfig.GetHostName()
remoteConfig.CombineHostNameWithPort() remoteConfig.CombineHostNameWithPort()
remoteConfig.GetSshUserFromConfig() remoteConfig.GetSshUserFromConfig()
if remoteConfig.HostName == "" { if remoteConfig.HostName == "" {
return errors.New("No hostname found or specified") return errors.Errorf("No hostname found or specified for host %s", remoteConfig.Host)
} }
err = remoteConfig.GetAuthMethods() err = remoteConfig.GetAuthMethods()
if err != nil { if err != nil {
return err return err
@ -107,6 +110,7 @@ func (remoteConfig *Host) ConnectToSSHHost(log *zerolog.Logger, hosts map[string
return connectErr return connectErr
} }
if remoteConfig.SshClient != nil { if remoteConfig.SshClient != nil {
config.Hosts[remoteConfig.Host] = remoteConfig
return nil return nil
} }
@ -115,6 +119,7 @@ func (remoteConfig *Host) ConnectToSSHHost(log *zerolog.Logger, hosts map[string
if connectErr != nil { if connectErr != nil {
return connectErr return connectErr
} }
config.Hosts[remoteConfig.Host] = remoteConfig
return nil return nil
} }
@ -242,6 +247,7 @@ func (remoteHost *Host) ConnectThroughBastion(log *zerolog.Logger) (*ssh.Client,
if err != nil { if err != nil {
return nil, err return nil, err
} }
remoteHost.ProxyHost[0].SshClient = bClient
// Dial a connection to the service host, from the bastion // Dial a connection to the service host, from the bastion
conn, err := bClient.Dial("tcp", remoteHost.HostName) conn, err := bClient.Dial("tcp", remoteHost.HostName)
@ -345,6 +351,7 @@ func (remoteConfig *Host) GetProxyJumpFromConfig(hosts map[string]*Host) error {
return nil return nil
} }
func (remoteConfig *Host) GetProxyJumpConfig(hosts map[string]*Host) error { func (remoteConfig *Host) GetProxyJumpConfig(hosts map[string]*Host) error {
if TS(remoteConfig.ConfigFilePath) == "" { if TS(remoteConfig.ConfigFilePath) == "" {
remoteConfig.useDefaultConfig = true remoteConfig.useDefaultConfig = true
@ -386,7 +393,7 @@ func (remoteConfig *Host) GetProxyJumpConfig(hosts map[string]*Host) error {
remoteConfig.CombineHostNameWithPort() remoteConfig.CombineHostNameWithPort()
remoteConfig.GetSshUserFromConfig() remoteConfig.GetSshUserFromConfig()
if remoteConfig.HostName == "" { if remoteConfig.HostName == "" {
return errors.New("No hostname found or specified") return errors.Errorf("No hostname found or specified for host %s", remoteConfig.Host)
} }
err := remoteConfig.GetAuthMethods() err := remoteConfig.GetAuthMethods()
if err != nil { if err != nil {
@ -399,5 +406,7 @@ func (remoteConfig *Host) GetProxyJumpConfig(hosts map[string]*Host) error {
return errors.Wrap(err, "could not create hostkeycallback function") return errors.Wrap(err, "could not create hostkeycallback function")
} }
remoteConfig.ClientConfig.HostKeyCallback = hostKeyCallback remoteConfig.ClientConfig.HostKeyCallback = hostKeyCallback
hosts[remoteConfig.Host] = remoteConfig
return nil return nil
} }