ssh host key checking with ssh/knownhosts pkg
This commit is contained in:
parent
9648fe8ab9
commit
9d07298eb0
@ -1,8 +1,6 @@
|
|||||||
package backy
|
package backy
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bufio"
|
|
||||||
"encoding/base64"
|
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"log"
|
"log"
|
||||||
@ -14,6 +12,7 @@ import (
|
|||||||
|
|
||||||
"github.com/kevinburke/ssh_config"
|
"github.com/kevinburke/ssh_config"
|
||||||
"golang.org/x/crypto/ssh"
|
"golang.org/x/crypto/ssh"
|
||||||
|
"golang.org/x/crypto/ssh/knownhosts"
|
||||||
)
|
)
|
||||||
|
|
||||||
type SshConfig struct {
|
type SshConfig struct {
|
||||||
@ -58,10 +57,11 @@ func (remoteConfig *Host) ConnectToSSHHost() (*ssh.Client, error) {
|
|||||||
var sshClient *ssh.Client
|
var sshClient *ssh.Client
|
||||||
var connectErr error
|
var connectErr error
|
||||||
|
|
||||||
|
khPath := filepath.Join(os.Getenv("HOME"), ".ssh", "known_hosts")
|
||||||
f, _ := os.Open(filepath.Join(os.Getenv("HOME"), ".ssh", "config"))
|
f, _ := os.Open(filepath.Join(os.Getenv("HOME"), ".ssh", "config"))
|
||||||
cfg, _ := ssh_config.Decode(f)
|
cfg, _ := ssh_config.Decode(f)
|
||||||
for _, host := range cfg.Hosts {
|
for _, host := range cfg.Hosts {
|
||||||
var hostKey ssh.PublicKey
|
// var hostKey ssh.PublicKey
|
||||||
if host.Matches(remoteConfig.Host) {
|
if host.Matches(remoteConfig.Host) {
|
||||||
var identityFile string
|
var identityFile string
|
||||||
if remoteConfig.PrivateKeyPath == "" {
|
if remoteConfig.PrivateKeyPath == "" {
|
||||||
@ -92,11 +92,16 @@ func (remoteConfig *Host) ConnectToSSHHost() (*ssh.Client, error) {
|
|||||||
port = "22"
|
port = "22"
|
||||||
}
|
}
|
||||||
remoteConfig.HostName[index] = hostName + ":" + port
|
remoteConfig.HostName[index] = hostName + ":" + port
|
||||||
hostKey = getHostKey(hostName)
|
|
||||||
println("HostName: " + remoteConfig.HostName[0])
|
println("HostName: " + remoteConfig.HostName[0])
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TODO: Add value/option to config for host key and add bool to check for host key
|
||||||
|
hostKeyCallback, err := knownhosts.New(khPath)
|
||||||
|
if err != nil {
|
||||||
|
log.Fatal("could not create hostkeycallback function: ", err)
|
||||||
|
}
|
||||||
privateKey, err := os.ReadFile(remoteConfig.PrivateKeyPath)
|
privateKey, err := os.ReadFile(remoteConfig.PrivateKeyPath)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("read private key error: %w", err)
|
return nil, fmt.Errorf("read private key error: %w", err)
|
||||||
@ -108,7 +113,7 @@ func (remoteConfig *Host) ConnectToSSHHost() (*ssh.Client, error) {
|
|||||||
sshConfig := &ssh.ClientConfig{
|
sshConfig := &ssh.ClientConfig{
|
||||||
User: remoteConfig.User,
|
User: remoteConfig.User,
|
||||||
Auth: []ssh.AuthMethod{ssh.PublicKeys(signer)},
|
Auth: []ssh.AuthMethod{ssh.PublicKeys(signer)},
|
||||||
HostKeyCallback: ssh.FixedHostKey(hostKey),
|
HostKeyCallback: hostKeyCallback,
|
||||||
Timeout: 5 * time.Second,
|
Timeout: 5 * time.Second,
|
||||||
}
|
}
|
||||||
for _, host := range remoteConfig.HostName {
|
for _, host := range remoteConfig.HostName {
|
||||||
@ -124,38 +129,3 @@ func (remoteConfig *Host) ConnectToSSHHost() (*ssh.Client, error) {
|
|||||||
}
|
}
|
||||||
return sshClient, connectErr
|
return sshClient, connectErr
|
||||||
}
|
}
|
||||||
|
|
||||||
func getHostKey(host string) ssh.PublicKey {
|
|
||||||
// parse OpenSSH known_hosts file
|
|
||||||
// ssh or use ssh-keyscan to get initial key
|
|
||||||
file, err := os.Open(filepath.Join(os.Getenv("HOME"), ".ssh", "known_hosts"))
|
|
||||||
if err != nil {
|
|
||||||
panic(err)
|
|
||||||
}
|
|
||||||
defer file.Close()
|
|
||||||
|
|
||||||
scanner := bufio.NewScanner(file)
|
|
||||||
var hostKey ssh.PublicKey
|
|
||||||
for scanner.Scan() {
|
|
||||||
fields := strings.Split(scanner.Text(), " ")
|
|
||||||
if len(fields) != 3 {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
log.Printf("Field[3]: %s ", fields[3])
|
|
||||||
log.Printf("Base-64 of %s: %s", host, base64.StdEncoding.EncodeToString([]byte(host)))
|
|
||||||
if strings.Contains(fields[3], base64.StdEncoding.EncodeToString([]byte(host))) {
|
|
||||||
var err error
|
|
||||||
hostKey, _, _, _, err = ssh.ParseAuthorizedKey(scanner.Bytes())
|
|
||||||
if err != nil {
|
|
||||||
log.Fatalf("error parsing %q: %v", fields[2], err)
|
|
||||||
}
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if hostKey == nil {
|
|
||||||
log.Fatalf("no hostkey found for %s", host)
|
|
||||||
}
|
|
||||||
|
|
||||||
return hostKey
|
|
||||||
}
|
|
||||||
|
Loading…
Reference in New Issue
Block a user