use ssh stdlib
This commit is contained in:
		@@ -1,10 +1,12 @@
 | 
				
			|||||||
package backy
 | 
					package backy
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import (
 | 
					import (
 | 
				
			||||||
 | 
						"bytes"
 | 
				
			||||||
 | 
						"fmt"
 | 
				
			||||||
 | 
						"io"
 | 
				
			||||||
 | 
						"os"
 | 
				
			||||||
	"os/exec"
 | 
						"os/exec"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	"github.com/melbahja/goph"
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	"git.andrewnw.xyz/CyberShell/backy/pkg/logging"
 | 
						"git.andrewnw.xyz/CyberShell/backy/pkg/logging"
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -32,24 +34,22 @@ type Commands struct {
 | 
				
			|||||||
	After  Command
 | 
						After  Command
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// BackupConfig is a configuration struct that is used to define backups
 | 
				
			||||||
type BackupConfig struct {
 | 
					type BackupConfig struct {
 | 
				
			||||||
	Name       string
 | 
						Name       string
 | 
				
			||||||
	BackupType string
 | 
						BackupType string
 | 
				
			||||||
	ConfigPath string
 | 
						ConfigPath string
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	Cmds Commands
 | 
						Cmds Commands
 | 
				
			||||||
 | 
					 | 
				
			||||||
	DstDir string
 | 
					 | 
				
			||||||
	SrcDir string
 | 
					 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
/*
 | 
					/*
 | 
				
			||||||
* Runs a backup configuration
 | 
					* Runs a backup configuration
 | 
				
			||||||
 */
 | 
					 */
 | 
				
			||||||
func Run(backup BackupConfig) logging.Logging {
 | 
					func (backup BackupConfig) Run() logging.Logging {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	beforeConfig := backup.Cmds.Before
 | 
						beforeConfig := backup.Cmds.Before
 | 
				
			||||||
	beforeOutput := runCmd(beforeConfig)
 | 
						beforeOutput := beforeConfig.runCmd()
 | 
				
			||||||
	if beforeOutput.Err != nil {
 | 
						if beforeOutput.Err != nil {
 | 
				
			||||||
		return logging.Logging{
 | 
							return logging.Logging{
 | 
				
			||||||
			Output: beforeOutput.Output,
 | 
								Output: beforeOutput.Output,
 | 
				
			||||||
@@ -57,7 +57,7 @@ func Run(backup BackupConfig) logging.Logging {
 | 
				
			|||||||
		}
 | 
							}
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	backupConfig := backup.Cmds.Backup
 | 
						backupConfig := backup.Cmds.Backup
 | 
				
			||||||
	backupOutput := runCmd(backupConfig)
 | 
						backupOutput := backupConfig.runCmd()
 | 
				
			||||||
	if backupOutput.Err != nil {
 | 
						if backupOutput.Err != nil {
 | 
				
			||||||
		return logging.Logging{
 | 
							return logging.Logging{
 | 
				
			||||||
			Output: beforeOutput.Output,
 | 
								Output: beforeOutput.Output,
 | 
				
			||||||
@@ -65,12 +65,9 @@ func Run(backup BackupConfig) logging.Logging {
 | 
				
			|||||||
		}
 | 
							}
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	afterConfig := backup.Cmds.After
 | 
						afterConfig := backup.Cmds.After
 | 
				
			||||||
	afterOutput := runCmd(afterConfig)
 | 
						afterOutput := afterConfig.runCmd()
 | 
				
			||||||
	if afterOutput.Err != nil {
 | 
						if afterOutput.Err != nil {
 | 
				
			||||||
		return logging.Logging{
 | 
							return afterOutput
 | 
				
			||||||
			Output: beforeOutput.Output,
 | 
					 | 
				
			||||||
			Err:    beforeOutput.Err,
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	return logging.Logging{
 | 
						return logging.Logging{
 | 
				
			||||||
		Output: afterOutput.Output,
 | 
							Output: afterOutput.Output,
 | 
				
			||||||
@@ -78,54 +75,68 @@ func Run(backup BackupConfig) logging.Logging {
 | 
				
			|||||||
	}
 | 
						}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func runCmd(cmd Command) logging.Logging {
 | 
					func (command Command) runCmd() logging.Logging {
 | 
				
			||||||
	if !cmd.Empty {
 | 
					 | 
				
			||||||
		if cmd.Remote {
 | 
					 | 
				
			||||||
			// Start new ssh connection with private key.
 | 
					 | 
				
			||||||
			auth, err := goph.Key(cmd.RemoteHost.PrivateKeyPath, cmd.RemoteHost.PrivateKeyPassword)
 | 
					 | 
				
			||||||
			if err != nil {
 | 
					 | 
				
			||||||
				return logging.Logging{
 | 
					 | 
				
			||||||
					Output: err.Error(),
 | 
					 | 
				
			||||||
					Err:    err,
 | 
					 | 
				
			||||||
				}
 | 
					 | 
				
			||||||
			}
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
			client, err := goph.New(cmd.RemoteHost.User, cmd.RemoteHost.Host, auth)
 | 
						var stdoutBuf, stderrBuf bytes.Buffer
 | 
				
			||||||
			if err != nil {
 | 
						var err error
 | 
				
			||||||
				return logging.Logging{
 | 
						var cmdArgs string
 | 
				
			||||||
					Output: err.Error(),
 | 
						for _, v := range command.Args {
 | 
				
			||||||
					Err:    err,
 | 
							cmdArgs += v
 | 
				
			||||||
				}
 | 
						}
 | 
				
			||||||
			}
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
			// Defer closing the network connection.
 | 
						var remoteHost = &command.RemoteHost
 | 
				
			||||||
			defer client.Close()
 | 
						fmt.Printf("\n\nRunning command: " + command.Cmd + " " + cmdArgs + " on host " + command.RemoteHost.Host + "...\n\n")
 | 
				
			||||||
 | 
						if command.Remote {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
			command := cmd.Cmd
 | 
							remoteHost.Port = 22
 | 
				
			||||||
			for _, v := range cmd.Args {
 | 
							remoteHost.Host = command.RemoteHost.Host
 | 
				
			||||||
				command += v
 | 
					 | 
				
			||||||
			}
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
			// Execute your command.
 | 
							sshc, err := remoteHost.connectToSSHHost()
 | 
				
			||||||
			out, err := client.Run(command)
 | 
							if err != nil {
 | 
				
			||||||
			if err != nil {
 | 
								panic(fmt.Errorf("ssh dial: %w", err))
 | 
				
			||||||
				return logging.Logging{
 | 
					 | 
				
			||||||
					Output: string(out),
 | 
					 | 
				
			||||||
					Err:    err,
 | 
					 | 
				
			||||||
				}
 | 
					 | 
				
			||||||
			}
 | 
					 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
		cmdOut := exec.Command(cmd.Cmd, cmd.Args...)
 | 
							defer sshc.Close()
 | 
				
			||||||
		output, err := cmdOut.Output()
 | 
							s, err := sshc.NewSession()
 | 
				
			||||||
 | 
							if err != nil {
 | 
				
			||||||
 | 
								panic(fmt.Errorf("new ssh session: %w", err))
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							defer s.Close()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							cmd := command.Cmd
 | 
				
			||||||
 | 
							for _, a := range command.Args {
 | 
				
			||||||
 | 
								cmd += " " + a
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							s.Stdout = io.MultiWriter(os.Stdout, &stdoutBuf)
 | 
				
			||||||
 | 
							s.Stderr = io.MultiWriter(os.Stderr, &stderrBuf)
 | 
				
			||||||
 | 
							err = s.Run(cmd)
 | 
				
			||||||
		if err != nil {
 | 
							if err != nil {
 | 
				
			||||||
			return logging.Logging{
 | 
								return logging.Logging{
 | 
				
			||||||
				Output: string(output),
 | 
									Output: stdoutBuf.String(),
 | 
				
			||||||
				Err:    err,
 | 
									Err:    fmt.Errorf("error running " + cmd + ": " + stderrBuf.String()),
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							// fmt.Printf("Output: %s\n", string(output))
 | 
				
			||||||
 | 
						} else {
 | 
				
			||||||
 | 
							// shell := "/bin/bash"
 | 
				
			||||||
 | 
							localCMD := exec.Command(command.Cmd, command.Args...)
 | 
				
			||||||
 | 
							localCMD.Stdout = io.MultiWriter(os.Stdout, &stdoutBuf)
 | 
				
			||||||
 | 
							localCMD.Stderr = io.MultiWriter(os.Stderr, &stderrBuf)
 | 
				
			||||||
 | 
							err = localCMD.Run()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							if err != nil {
 | 
				
			||||||
 | 
								return logging.Logging{
 | 
				
			||||||
 | 
									Output: stdoutBuf.String(),
 | 
				
			||||||
 | 
									Err:    fmt.Errorf(stderrBuf.String()),
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	return logging.Logging{
 | 
						return logging.Logging{
 | 
				
			||||||
		Output: "",
 | 
							Output: stdoutBuf.String(),
 | 
				
			||||||
		Err:    nil,
 | 
							Err:    nil,
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func New() BackupConfig {
 | 
				
			||||||
 | 
						return BackupConfig{}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 
 | 
				
			|||||||
							
								
								
									
										92
									
								
								pkg/backy/ssh.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										92
									
								
								pkg/backy/ssh.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,92 @@
 | 
				
			|||||||
 | 
					package backy
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import (
 | 
				
			||||||
 | 
						"errors"
 | 
				
			||||||
 | 
						"fmt"
 | 
				
			||||||
 | 
						"os"
 | 
				
			||||||
 | 
						"os/user"
 | 
				
			||||||
 | 
						"path/filepath"
 | 
				
			||||||
 | 
						"strings"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						"github.com/kevinburke/ssh_config"
 | 
				
			||||||
 | 
						"golang.org/x/crypto/ssh"
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type SshConfig struct {
 | 
				
			||||||
 | 
						PrivateKey string
 | 
				
			||||||
 | 
						Port       uint
 | 
				
			||||||
 | 
						HostName   string
 | 
				
			||||||
 | 
						User       string
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func GetSSHConfig(host string) (SshConfig, error) {
 | 
				
			||||||
 | 
						var config SshConfig
 | 
				
			||||||
 | 
						hostName := ssh_config.Get(host, "HostName")
 | 
				
			||||||
 | 
						if hostName == "" {
 | 
				
			||||||
 | 
							return SshConfig{}, errors.New("hostname not found")
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						config.HostName = hostName
 | 
				
			||||||
 | 
						privKey, err := ssh_config.GetStrict(host, "IdentityFile")
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							return SshConfig{}, err
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						config.PrivateKey = privKey
 | 
				
			||||||
 | 
						User := ssh_config.Get(host, "User")
 | 
				
			||||||
 | 
						if User == "" {
 | 
				
			||||||
 | 
							return SshConfig{}, errors.New("user not found")
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return config, nil
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (remoteConfig *Host) connectToSSHHost() (*ssh.Client, error) {
 | 
				
			||||||
 | 
						var sshc *ssh.Client
 | 
				
			||||||
 | 
						var connectErr error
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						f, _ := os.Open(filepath.Join(os.Getenv("HOME"), ".ssh", "config"))
 | 
				
			||||||
 | 
						cfg, _ := ssh_config.Decode(f)
 | 
				
			||||||
 | 
						for _, host := range cfg.Hosts {
 | 
				
			||||||
 | 
							if host.Matches(remoteConfig.Host) {
 | 
				
			||||||
 | 
								var identityFile string
 | 
				
			||||||
 | 
								if remoteConfig.PrivateKeyPath == "" {
 | 
				
			||||||
 | 
									identityFile, _ = cfg.Get(remoteConfig.Host, "IdentityFile")
 | 
				
			||||||
 | 
									usr, _ := user.Current()
 | 
				
			||||||
 | 
									dir := usr.HomeDir
 | 
				
			||||||
 | 
									if identityFile == "~" {
 | 
				
			||||||
 | 
										// In case of "~", which won't be caught by the "else if"
 | 
				
			||||||
 | 
										identityFile = dir
 | 
				
			||||||
 | 
									} else if strings.HasPrefix(identityFile, "~/") {
 | 
				
			||||||
 | 
										// Use strings.HasPrefix so we don't match paths like
 | 
				
			||||||
 | 
										// "/something/~/something/"
 | 
				
			||||||
 | 
										identityFile = filepath.Join(dir, identityFile[2:])
 | 
				
			||||||
 | 
									}
 | 
				
			||||||
 | 
									remoteConfig.PrivateKeyPath = filepath.Join(identityFile)
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
								remoteConfig.HostName, _ = cfg.Get(remoteConfig.Host, "HostName")
 | 
				
			||||||
 | 
								if remoteConfig.HostName == "" {
 | 
				
			||||||
 | 
									remoteConfig.HostName = remoteConfig.Host
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
								port, _ := cfg.Get(remoteConfig.Host, "Port")
 | 
				
			||||||
 | 
								if port == "" {
 | 
				
			||||||
 | 
									port = "22"
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
								privateKey, err := os.ReadFile(remoteConfig.PrivateKeyPath)
 | 
				
			||||||
 | 
								remoteConfig.HostName = remoteConfig.HostName + ":" + port
 | 
				
			||||||
 | 
								if err != nil {
 | 
				
			||||||
 | 
									return nil, fmt.Errorf("read private key error: %w", err)
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
								signer, err := ssh.ParsePrivateKey(privateKey)
 | 
				
			||||||
 | 
								if err != nil {
 | 
				
			||||||
 | 
									return nil, fmt.Errorf("parse private key error: %w", err)
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
								sshConfig := &ssh.ClientConfig{
 | 
				
			||||||
 | 
									User:            remoteConfig.User,
 | 
				
			||||||
 | 
									Auth:            []ssh.AuthMethod{ssh.PublicKeys(signer)},
 | 
				
			||||||
 | 
									HostKeyCallback: ssh.InsecureIgnoreHostKey(),
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
								sshc, connectErr = ssh.Dial("tcp", remoteConfig.HostName, sshConfig)
 | 
				
			||||||
 | 
								break
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return sshc, connectErr
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
@@ -5,10 +5,6 @@ import (
 | 
				
			|||||||
	"github.com/spf13/viper"
 | 
						"github.com/spf13/viper"
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
type BackyViperOpts struct {
 | 
					 | 
				
			||||||
	ConfigFilePath string
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func ReadConfig(Config *viper.Viper) (*viper.Viper, error) {
 | 
					func ReadConfig(Config *viper.Viper) (*viper.Viper, error) {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	backyViper := viper.New()
 | 
						backyViper := viper.New()
 | 
				
			||||||
@@ -82,9 +78,6 @@ func CreateConfig(backup backy.BackupConfig) backy.BackupConfig {
 | 
				
			|||||||
		Name:       backup.Name,
 | 
							Name:       backup.Name,
 | 
				
			||||||
		BackupType: backup.BackupType,
 | 
							BackupType: backup.BackupType,
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		DstDir: backup.DstDir,
 | 
					 | 
				
			||||||
		SrcDir: backup.SrcDir,
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
		ConfigPath: backup.ConfigPath,
 | 
							ConfigPath: backup.ConfigPath,
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -1,33 +0,0 @@
 | 
				
			|||||||
package ssh
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
import (
 | 
					 | 
				
			||||||
	"errors"
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	"github.com/kevinburke/ssh_config"
 | 
					 | 
				
			||||||
)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
type SshConfig struct {
 | 
					 | 
				
			||||||
	PrivateKey string
 | 
					 | 
				
			||||||
	Port       uint
 | 
					 | 
				
			||||||
	HostName   string
 | 
					 | 
				
			||||||
	User       string
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func GetSSHConfig(host string) (SshConfig, error) {
 | 
					 | 
				
			||||||
	var config SshConfig
 | 
					 | 
				
			||||||
	hostName := ssh_config.Get(host, "HostName")
 | 
					 | 
				
			||||||
	if hostName == "" {
 | 
					 | 
				
			||||||
		return SshConfig{}, errors.New("hostname not found")
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	config.HostName = hostName
 | 
					 | 
				
			||||||
	privKey, err := ssh_config.GetStrict(host, "IdentityFile")
 | 
					 | 
				
			||||||
	if err != nil {
 | 
					 | 
				
			||||||
		return SshConfig{}, err
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	config.PrivateKey = privKey
 | 
					 | 
				
			||||||
	User := ssh_config.Get(host, "User")
 | 
					 | 
				
			||||||
	if User == "" {
 | 
					 | 
				
			||||||
		return SshConfig{}, errors.New("user not found")
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	return config, nil
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
		Reference in New Issue
	
	Block a user