diff --git a/pkg/backy/backy.go b/pkg/backy/backy.go index 57c29c9..8d2910c 100644 --- a/pkg/backy/backy.go +++ b/pkg/backy/backy.go @@ -11,11 +11,12 @@ import ( ) // Host defines a host to which to connect -// If not provided, the values will be looked up in the default ssh config file +// If not provided, the values will be looked up in the default ssh config files type Host struct { + ConfigFilePath string Empty bool Host string - HostName string + HostName []string Port uint16 PrivateKeyPath string PrivateKeyPassword string diff --git a/pkg/backy/ssh.go b/pkg/backy/ssh.go index d396cd4..31c1aa5 100644 --- a/pkg/backy/ssh.go +++ b/pkg/backy/ssh.go @@ -13,25 +13,36 @@ import ( ) type SshConfig struct { - PrivateKey string - Port uint - HostName string - User string + // Config file to open + configFile string + + // Private key path + privateKey string + + // Port to connect to + port uint16 + + // host to check + host string + + // host name to connect to + hostName []string + + user string } -func GetSSHConfig(host string) (SshConfig, error) { - var config SshConfig - hostName := ssh_config.Get(host, "HostName") - if hostName == "" { +func (config SshConfig) GetSSHConfig() (SshConfig, error) { + hostNames := ssh_config.GetAll(config.host, "HostName") + if hostNames == nil { return SshConfig{}, errors.New("hostname not found") } - config.HostName = hostName - privKey, err := ssh_config.GetStrict(host, "IdentityFile") + config.hostName = hostNames + privKey, err := ssh_config.GetStrict(config.host, "IdentityFile") if err != nil { return SshConfig{}, err } - config.PrivateKey = privKey - User := ssh_config.Get(host, "User") + config.privateKey = privKey + User := ssh_config.Get(config.host, "User") if User == "" { return SshConfig{}, errors.New("user not found") } @@ -39,6 +50,7 @@ func GetSSHConfig(host string) (SshConfig, error) { } func (remoteConfig *Host) ConnectToSSHHost() (*ssh.Client, error) { + var sshc *ssh.Client var connectErr error @@ -61,16 +73,24 @@ func (remoteConfig *Host) ConnectToSSHHost() (*ssh.Client, error) { } remoteConfig.PrivateKeyPath = filepath.Join(identityFile) } - remoteConfig.HostName, _ = cfg.Get(remoteConfig.Host, "HostName") - if remoteConfig.HostName == "" { - remoteConfig.HostName = remoteConfig.Host - } - port, _ := cfg.Get(remoteConfig.Host, "Port") - if port == "" { - port = "22" + remoteConfig.HostName, _ = cfg.GetAll(remoteConfig.Host, "HostName") + if remoteConfig.HostName == nil { + remoteConfig.HostName[0] = remoteConfig.Host + port, _ := cfg.Get(remoteConfig.Host, "Port") + if port == "" { + port = "22" + } + remoteConfig.HostName[0] = remoteConfig.HostName[0] + ":" + port + } else { + for index, hostName := range remoteConfig.HostName { + port, _ := cfg.Get(remoteConfig.Host, "Port") + if port == "" { + port = "22" + } + remoteConfig.HostName[index] = hostName + ":" + port + } } privateKey, err := os.ReadFile(remoteConfig.PrivateKeyPath) - remoteConfig.HostName = remoteConfig.HostName + ":" + port if err != nil { return nil, fmt.Errorf("read private key error: %w", err) } @@ -83,7 +103,12 @@ func (remoteConfig *Host) ConnectToSSHHost() (*ssh.Client, error) { Auth: []ssh.AuthMethod{ssh.PublicKeys(signer)}, HostKeyCallback: ssh.InsecureIgnoreHostKey(), } - sshc, connectErr = ssh.Dial("tcp", remoteConfig.HostName, sshConfig) + for _, host := range remoteConfig.HostName { + sshc, connectErr = ssh.Dial("tcp", host, sshConfig) + if connectErr != nil { + continue + } + } break }