ssh host key checking with ssh/knownhosts pkg

This commit is contained in:
Andrew Woodlee 2023-01-02 20:02:54 -06:00
parent 9648fe8ab9
commit 9d07298eb0

View File

@ -1,8 +1,6 @@
package backy package backy
import ( import (
"bufio"
"encoding/base64"
"errors" "errors"
"fmt" "fmt"
"log" "log"
@ -14,6 +12,7 @@ import (
"github.com/kevinburke/ssh_config" "github.com/kevinburke/ssh_config"
"golang.org/x/crypto/ssh" "golang.org/x/crypto/ssh"
"golang.org/x/crypto/ssh/knownhosts"
) )
type SshConfig struct { type SshConfig struct {
@ -58,10 +57,11 @@ func (remoteConfig *Host) ConnectToSSHHost() (*ssh.Client, error) {
var sshClient *ssh.Client var sshClient *ssh.Client
var connectErr error var connectErr error
khPath := filepath.Join(os.Getenv("HOME"), ".ssh", "known_hosts")
f, _ := os.Open(filepath.Join(os.Getenv("HOME"), ".ssh", "config")) f, _ := os.Open(filepath.Join(os.Getenv("HOME"), ".ssh", "config"))
cfg, _ := ssh_config.Decode(f) cfg, _ := ssh_config.Decode(f)
for _, host := range cfg.Hosts { for _, host := range cfg.Hosts {
var hostKey ssh.PublicKey // var hostKey ssh.PublicKey
if host.Matches(remoteConfig.Host) { if host.Matches(remoteConfig.Host) {
var identityFile string var identityFile string
if remoteConfig.PrivateKeyPath == "" { if remoteConfig.PrivateKeyPath == "" {
@ -92,11 +92,16 @@ func (remoteConfig *Host) ConnectToSSHHost() (*ssh.Client, error) {
port = "22" port = "22"
} }
remoteConfig.HostName[index] = hostName + ":" + port remoteConfig.HostName[index] = hostName + ":" + port
hostKey = getHostKey(hostName)
println("HostName: " + remoteConfig.HostName[0]) println("HostName: " + remoteConfig.HostName[0])
} }
} }
// TODO: Add value/option to config for host key and add bool to check for host key
hostKeyCallback, err := knownhosts.New(khPath)
if err != nil {
log.Fatal("could not create hostkeycallback function: ", err)
}
privateKey, err := os.ReadFile(remoteConfig.PrivateKeyPath) privateKey, err := os.ReadFile(remoteConfig.PrivateKeyPath)
if err != nil { if err != nil {
return nil, fmt.Errorf("read private key error: %w", err) return nil, fmt.Errorf("read private key error: %w", err)
@ -108,7 +113,7 @@ func (remoteConfig *Host) ConnectToSSHHost() (*ssh.Client, error) {
sshConfig := &ssh.ClientConfig{ sshConfig := &ssh.ClientConfig{
User: remoteConfig.User, User: remoteConfig.User,
Auth: []ssh.AuthMethod{ssh.PublicKeys(signer)}, Auth: []ssh.AuthMethod{ssh.PublicKeys(signer)},
HostKeyCallback: ssh.FixedHostKey(hostKey), HostKeyCallback: hostKeyCallback,
Timeout: 5 * time.Second, Timeout: 5 * time.Second,
} }
for _, host := range remoteConfig.HostName { for _, host := range remoteConfig.HostName {
@ -124,38 +129,3 @@ func (remoteConfig *Host) ConnectToSSHHost() (*ssh.Client, error) {
} }
return sshClient, connectErr return sshClient, connectErr
} }
func getHostKey(host string) ssh.PublicKey {
// parse OpenSSH known_hosts file
// ssh or use ssh-keyscan to get initial key
file, err := os.Open(filepath.Join(os.Getenv("HOME"), ".ssh", "known_hosts"))
if err != nil {
panic(err)
}
defer file.Close()
scanner := bufio.NewScanner(file)
var hostKey ssh.PublicKey
for scanner.Scan() {
fields := strings.Split(scanner.Text(), " ")
if len(fields) != 3 {
continue
}
log.Printf("Field[3]: %s ", fields[3])
log.Printf("Base-64 of %s: %s", host, base64.StdEncoding.EncodeToString([]byte(host)))
if strings.Contains(fields[3], base64.StdEncoding.EncodeToString([]byte(host))) {
var err error
hostKey, _, _, _, err = ssh.ParseAuthorizedKey(scanner.Bytes())
if err != nil {
log.Fatalf("error parsing %q: %v", fields[2], err)
}
break
}
}
if hostKey == nil {
log.Fatalf("no hostkey found for %s", host)
}
return hostKey
}