diff --git a/pkg/backy/ssh.go b/pkg/backy/ssh.go index 4a7356e..481d76d 100644 --- a/pkg/backy/ssh.go +++ b/pkg/backy/ssh.go @@ -1,8 +1,6 @@ package backy import ( - "bufio" - "encoding/base64" "errors" "fmt" "log" @@ -14,6 +12,7 @@ import ( "github.com/kevinburke/ssh_config" "golang.org/x/crypto/ssh" + "golang.org/x/crypto/ssh/knownhosts" ) type SshConfig struct { @@ -58,10 +57,11 @@ func (remoteConfig *Host) ConnectToSSHHost() (*ssh.Client, error) { var sshClient *ssh.Client var connectErr error + khPath := filepath.Join(os.Getenv("HOME"), ".ssh", "known_hosts") f, _ := os.Open(filepath.Join(os.Getenv("HOME"), ".ssh", "config")) cfg, _ := ssh_config.Decode(f) for _, host := range cfg.Hosts { - var hostKey ssh.PublicKey + // var hostKey ssh.PublicKey if host.Matches(remoteConfig.Host) { var identityFile string if remoteConfig.PrivateKeyPath == "" { @@ -92,11 +92,16 @@ func (remoteConfig *Host) ConnectToSSHHost() (*ssh.Client, error) { port = "22" } remoteConfig.HostName[index] = hostName + ":" + port - hostKey = getHostKey(hostName) + println("HostName: " + remoteConfig.HostName[0]) } } + // TODO: Add value/option to config for host key and add bool to check for host key + hostKeyCallback, err := knownhosts.New(khPath) + if err != nil { + log.Fatal("could not create hostkeycallback function: ", err) + } privateKey, err := os.ReadFile(remoteConfig.PrivateKeyPath) if err != nil { return nil, fmt.Errorf("read private key error: %w", err) @@ -108,7 +113,7 @@ func (remoteConfig *Host) ConnectToSSHHost() (*ssh.Client, error) { sshConfig := &ssh.ClientConfig{ User: remoteConfig.User, Auth: []ssh.AuthMethod{ssh.PublicKeys(signer)}, - HostKeyCallback: ssh.FixedHostKey(hostKey), + HostKeyCallback: hostKeyCallback, Timeout: 5 * time.Second, } for _, host := range remoteConfig.HostName { @@ -124,38 +129,3 @@ func (remoteConfig *Host) ConnectToSSHHost() (*ssh.Client, error) { } return sshClient, connectErr } - -func getHostKey(host string) ssh.PublicKey { - // parse OpenSSH known_hosts file - // ssh or use ssh-keyscan to get initial key - file, err := os.Open(filepath.Join(os.Getenv("HOME"), ".ssh", "known_hosts")) - if err != nil { - panic(err) - } - defer file.Close() - - scanner := bufio.NewScanner(file) - var hostKey ssh.PublicKey - for scanner.Scan() { - fields := strings.Split(scanner.Text(), " ") - if len(fields) != 3 { - continue - } - log.Printf("Field[3]: %s ", fields[3]) - log.Printf("Base-64 of %s: %s", host, base64.StdEncoding.EncodeToString([]byte(host))) - if strings.Contains(fields[3], base64.StdEncoding.EncodeToString([]byte(host))) { - var err error - hostKey, _, _, _, err = ssh.ParseAuthorizedKey(scanner.Bytes()) - if err != nil { - log.Fatalf("error parsing %q: %v", fields[2], err) - } - break - } - } - - if hostKey == nil { - log.Fatalf("no hostkey found for %s", host) - } - - return hostKey -}