From 2daae5cf9e9ddece1b6147dc25717ef68225dd63 Mon Sep 17 00:00:00 2001 From: Andrew Woodlee Date: Mon, 26 Dec 2022 23:20:11 -0600 Subject: [PATCH] use ssh stdlib --- pkg/backy/backy.go | 111 ++++++++++++++++++++----------------- pkg/backy/ssh.go | 92 ++++++++++++++++++++++++++++++ pkg/config/backy/config.go | 7 --- pkg/ssh/ssh.go | 33 ----------- 4 files changed, 153 insertions(+), 90 deletions(-) create mode 100644 pkg/backy/ssh.go delete mode 100644 pkg/ssh/ssh.go diff --git a/pkg/backy/backy.go b/pkg/backy/backy.go index 9f5d2cd..441a671 100644 --- a/pkg/backy/backy.go +++ b/pkg/backy/backy.go @@ -1,10 +1,12 @@ package backy import ( + "bytes" + "fmt" + "io" + "os" "os/exec" - "github.com/melbahja/goph" - "git.andrewnw.xyz/CyberShell/backy/pkg/logging" ) @@ -32,24 +34,22 @@ type Commands struct { After Command } +// BackupConfig is a configuration struct that is used to define backups type BackupConfig struct { Name string BackupType string ConfigPath string Cmds Commands - - DstDir string - SrcDir string } /* * Runs a backup configuration */ -func Run(backup BackupConfig) logging.Logging { +func (backup BackupConfig) Run() logging.Logging { beforeConfig := backup.Cmds.Before - beforeOutput := runCmd(beforeConfig) + beforeOutput := beforeConfig.runCmd() if beforeOutput.Err != nil { return logging.Logging{ Output: beforeOutput.Output, @@ -57,7 +57,7 @@ func Run(backup BackupConfig) logging.Logging { } } backupConfig := backup.Cmds.Backup - backupOutput := runCmd(backupConfig) + backupOutput := backupConfig.runCmd() if backupOutput.Err != nil { return logging.Logging{ Output: beforeOutput.Output, @@ -65,12 +65,9 @@ func Run(backup BackupConfig) logging.Logging { } } afterConfig := backup.Cmds.After - afterOutput := runCmd(afterConfig) + afterOutput := afterConfig.runCmd() if afterOutput.Err != nil { - return logging.Logging{ - Output: beforeOutput.Output, - Err: beforeOutput.Err, - } + return afterOutput } return logging.Logging{ Output: afterOutput.Output, @@ -78,54 +75,68 @@ func Run(backup BackupConfig) logging.Logging { } } -func runCmd(cmd Command) logging.Logging { - if !cmd.Empty { - if cmd.Remote { - // Start new ssh connection with private key. - auth, err := goph.Key(cmd.RemoteHost.PrivateKeyPath, cmd.RemoteHost.PrivateKeyPassword) - if err != nil { - return logging.Logging{ - Output: err.Error(), - Err: err, - } - } +func (command Command) runCmd() logging.Logging { - client, err := goph.New(cmd.RemoteHost.User, cmd.RemoteHost.Host, auth) - if err != nil { - return logging.Logging{ - Output: err.Error(), - Err: err, - } - } + var stdoutBuf, stderrBuf bytes.Buffer + var err error + var cmdArgs string + for _, v := range command.Args { + cmdArgs += v + } - // Defer closing the network connection. - defer client.Close() + var remoteHost = &command.RemoteHost + fmt.Printf("\n\nRunning command: " + command.Cmd + " " + cmdArgs + " on host " + command.RemoteHost.Host + "...\n\n") + if command.Remote { - command := cmd.Cmd - for _, v := range cmd.Args { - command += v - } + remoteHost.Port = 22 + remoteHost.Host = command.RemoteHost.Host - // Execute your command. - out, err := client.Run(command) - if err != nil { - return logging.Logging{ - Output: string(out), - Err: err, - } - } + sshc, err := remoteHost.connectToSSHHost() + if err != nil { + panic(fmt.Errorf("ssh dial: %w", err)) } - cmdOut := exec.Command(cmd.Cmd, cmd.Args...) - output, err := cmdOut.Output() + defer sshc.Close() + s, err := sshc.NewSession() + if err != nil { + panic(fmt.Errorf("new ssh session: %w", err)) + } + defer s.Close() + + cmd := command.Cmd + for _, a := range command.Args { + cmd += " " + a + } + + s.Stdout = io.MultiWriter(os.Stdout, &stdoutBuf) + s.Stderr = io.MultiWriter(os.Stderr, &stderrBuf) + err = s.Run(cmd) if err != nil { return logging.Logging{ - Output: string(output), - Err: err, + Output: stdoutBuf.String(), + Err: fmt.Errorf("error running " + cmd + ": " + stderrBuf.String()), + } + } + // fmt.Printf("Output: %s\n", string(output)) + } else { + // shell := "/bin/bash" + localCMD := exec.Command(command.Cmd, command.Args...) + localCMD.Stdout = io.MultiWriter(os.Stdout, &stdoutBuf) + localCMD.Stderr = io.MultiWriter(os.Stderr, &stderrBuf) + err = localCMD.Run() + + if err != nil { + return logging.Logging{ + Output: stdoutBuf.String(), + Err: fmt.Errorf(stderrBuf.String()), } } } return logging.Logging{ - Output: "", + Output: stdoutBuf.String(), Err: nil, } } + +func New() BackupConfig { + return BackupConfig{} +} diff --git a/pkg/backy/ssh.go b/pkg/backy/ssh.go new file mode 100644 index 0000000..d1c7893 --- /dev/null +++ b/pkg/backy/ssh.go @@ -0,0 +1,92 @@ +package backy + +import ( + "errors" + "fmt" + "os" + "os/user" + "path/filepath" + "strings" + + "github.com/kevinburke/ssh_config" + "golang.org/x/crypto/ssh" +) + +type SshConfig struct { + PrivateKey string + Port uint + HostName string + User string +} + +func GetSSHConfig(host string) (SshConfig, error) { + var config SshConfig + hostName := ssh_config.Get(host, "HostName") + if hostName == "" { + return SshConfig{}, errors.New("hostname not found") + } + config.HostName = hostName + privKey, err := ssh_config.GetStrict(host, "IdentityFile") + if err != nil { + return SshConfig{}, err + } + config.PrivateKey = privKey + User := ssh_config.Get(host, "User") + if User == "" { + return SshConfig{}, errors.New("user not found") + } + return config, nil +} + +func (remoteConfig *Host) connectToSSHHost() (*ssh.Client, error) { + var sshc *ssh.Client + var connectErr error + + f, _ := os.Open(filepath.Join(os.Getenv("HOME"), ".ssh", "config")) + cfg, _ := ssh_config.Decode(f) + for _, host := range cfg.Hosts { + if host.Matches(remoteConfig.Host) { + var identityFile string + if remoteConfig.PrivateKeyPath == "" { + identityFile, _ = cfg.Get(remoteConfig.Host, "IdentityFile") + usr, _ := user.Current() + dir := usr.HomeDir + if identityFile == "~" { + // In case of "~", which won't be caught by the "else if" + identityFile = dir + } else if strings.HasPrefix(identityFile, "~/") { + // Use strings.HasPrefix so we don't match paths like + // "/something/~/something/" + identityFile = filepath.Join(dir, identityFile[2:]) + } + remoteConfig.PrivateKeyPath = filepath.Join(identityFile) + } + remoteConfig.HostName, _ = cfg.Get(remoteConfig.Host, "HostName") + if remoteConfig.HostName == "" { + remoteConfig.HostName = remoteConfig.Host + } + port, _ := cfg.Get(remoteConfig.Host, "Port") + if port == "" { + port = "22" + } + privateKey, err := os.ReadFile(remoteConfig.PrivateKeyPath) + remoteConfig.HostName = remoteConfig.HostName + ":" + port + if err != nil { + return nil, fmt.Errorf("read private key error: %w", err) + } + signer, err := ssh.ParsePrivateKey(privateKey) + if err != nil { + return nil, fmt.Errorf("parse private key error: %w", err) + } + sshConfig := &ssh.ClientConfig{ + User: remoteConfig.User, + Auth: []ssh.AuthMethod{ssh.PublicKeys(signer)}, + HostKeyCallback: ssh.InsecureIgnoreHostKey(), + } + sshc, connectErr = ssh.Dial("tcp", remoteConfig.HostName, sshConfig) + break + } + + } + return sshc, connectErr +} diff --git a/pkg/config/backy/config.go b/pkg/config/backy/config.go index e655f2c..7a578b1 100644 --- a/pkg/config/backy/config.go +++ b/pkg/config/backy/config.go @@ -5,10 +5,6 @@ import ( "github.com/spf13/viper" ) -type BackyViperOpts struct { - ConfigFilePath string -} - func ReadConfig(Config *viper.Viper) (*viper.Viper, error) { backyViper := viper.New() @@ -82,9 +78,6 @@ func CreateConfig(backup backy.BackupConfig) backy.BackupConfig { Name: backup.Name, BackupType: backup.BackupType, - DstDir: backup.DstDir, - SrcDir: backup.SrcDir, - ConfigPath: backup.ConfigPath, } diff --git a/pkg/ssh/ssh.go b/pkg/ssh/ssh.go deleted file mode 100644 index f79db0d..0000000 --- a/pkg/ssh/ssh.go +++ /dev/null @@ -1,33 +0,0 @@ -package ssh - -import ( - "errors" - - "github.com/kevinburke/ssh_config" -) - -type SshConfig struct { - PrivateKey string - Port uint - HostName string - User string -} - -func GetSSHConfig(host string) (SshConfig, error) { - var config SshConfig - hostName := ssh_config.Get(host, "HostName") - if hostName == "" { - return SshConfig{}, errors.New("hostname not found") - } - config.HostName = hostName - privKey, err := ssh_config.GetStrict(host, "IdentityFile") - if err != nil { - return SshConfig{}, err - } - config.PrivateKey = privKey - User := ssh_config.Get(host, "User") - if User == "" { - return SshConfig{}, errors.New("user not found") - } - return config, nil -}