use ssh stdlib

This commit is contained in:
Andrew Woodlee 2022-12-26 23:20:11 -06:00
parent ab12b02dc9
commit 2daae5cf9e
4 changed files with 153 additions and 90 deletions

View File

@ -1,10 +1,12 @@
package backy package backy
import ( import (
"bytes"
"fmt"
"io"
"os"
"os/exec" "os/exec"
"github.com/melbahja/goph"
"git.andrewnw.xyz/CyberShell/backy/pkg/logging" "git.andrewnw.xyz/CyberShell/backy/pkg/logging"
) )
@ -32,24 +34,22 @@ type Commands struct {
After Command After Command
} }
// BackupConfig is a configuration struct that is used to define backups
type BackupConfig struct { type BackupConfig struct {
Name string Name string
BackupType string BackupType string
ConfigPath string ConfigPath string
Cmds Commands Cmds Commands
DstDir string
SrcDir string
} }
/* /*
* Runs a backup configuration * Runs a backup configuration
*/ */
func Run(backup BackupConfig) logging.Logging { func (backup BackupConfig) Run() logging.Logging {
beforeConfig := backup.Cmds.Before beforeConfig := backup.Cmds.Before
beforeOutput := runCmd(beforeConfig) beforeOutput := beforeConfig.runCmd()
if beforeOutput.Err != nil { if beforeOutput.Err != nil {
return logging.Logging{ return logging.Logging{
Output: beforeOutput.Output, Output: beforeOutput.Output,
@ -57,7 +57,7 @@ func Run(backup BackupConfig) logging.Logging {
} }
} }
backupConfig := backup.Cmds.Backup backupConfig := backup.Cmds.Backup
backupOutput := runCmd(backupConfig) backupOutput := backupConfig.runCmd()
if backupOutput.Err != nil { if backupOutput.Err != nil {
return logging.Logging{ return logging.Logging{
Output: beforeOutput.Output, Output: beforeOutput.Output,
@ -65,12 +65,9 @@ func Run(backup BackupConfig) logging.Logging {
} }
} }
afterConfig := backup.Cmds.After afterConfig := backup.Cmds.After
afterOutput := runCmd(afterConfig) afterOutput := afterConfig.runCmd()
if afterOutput.Err != nil { if afterOutput.Err != nil {
return logging.Logging{ return afterOutput
Output: beforeOutput.Output,
Err: beforeOutput.Err,
}
} }
return logging.Logging{ return logging.Logging{
Output: afterOutput.Output, Output: afterOutput.Output,
@ -78,54 +75,68 @@ func Run(backup BackupConfig) logging.Logging {
} }
} }
func runCmd(cmd Command) logging.Logging { func (command Command) runCmd() logging.Logging {
if !cmd.Empty {
if cmd.Remote {
// Start new ssh connection with private key.
auth, err := goph.Key(cmd.RemoteHost.PrivateKeyPath, cmd.RemoteHost.PrivateKeyPassword)
if err != nil {
return logging.Logging{
Output: err.Error(),
Err: err,
}
}
client, err := goph.New(cmd.RemoteHost.User, cmd.RemoteHost.Host, auth) var stdoutBuf, stderrBuf bytes.Buffer
if err != nil { var err error
return logging.Logging{ var cmdArgs string
Output: err.Error(), for _, v := range command.Args {
Err: err, cmdArgs += v
} }
}
// Defer closing the network connection. var remoteHost = &command.RemoteHost
defer client.Close() fmt.Printf("\n\nRunning command: " + command.Cmd + " " + cmdArgs + " on host " + command.RemoteHost.Host + "...\n\n")
if command.Remote {
command := cmd.Cmd remoteHost.Port = 22
for _, v := range cmd.Args { remoteHost.Host = command.RemoteHost.Host
command += v
}
// Execute your command. sshc, err := remoteHost.connectToSSHHost()
out, err := client.Run(command) if err != nil {
if err != nil { panic(fmt.Errorf("ssh dial: %w", err))
return logging.Logging{
Output: string(out),
Err: err,
}
}
} }
cmdOut := exec.Command(cmd.Cmd, cmd.Args...) defer sshc.Close()
output, err := cmdOut.Output() s, err := sshc.NewSession()
if err != nil {
panic(fmt.Errorf("new ssh session: %w", err))
}
defer s.Close()
cmd := command.Cmd
for _, a := range command.Args {
cmd += " " + a
}
s.Stdout = io.MultiWriter(os.Stdout, &stdoutBuf)
s.Stderr = io.MultiWriter(os.Stderr, &stderrBuf)
err = s.Run(cmd)
if err != nil { if err != nil {
return logging.Logging{ return logging.Logging{
Output: string(output), Output: stdoutBuf.String(),
Err: err, Err: fmt.Errorf("error running " + cmd + ": " + stderrBuf.String()),
}
}
// fmt.Printf("Output: %s\n", string(output))
} else {
// shell := "/bin/bash"
localCMD := exec.Command(command.Cmd, command.Args...)
localCMD.Stdout = io.MultiWriter(os.Stdout, &stdoutBuf)
localCMD.Stderr = io.MultiWriter(os.Stderr, &stderrBuf)
err = localCMD.Run()
if err != nil {
return logging.Logging{
Output: stdoutBuf.String(),
Err: fmt.Errorf(stderrBuf.String()),
} }
} }
} }
return logging.Logging{ return logging.Logging{
Output: "", Output: stdoutBuf.String(),
Err: nil, Err: nil,
} }
} }
func New() BackupConfig {
return BackupConfig{}
}

92
pkg/backy/ssh.go Normal file
View File

@ -0,0 +1,92 @@
package backy
import (
"errors"
"fmt"
"os"
"os/user"
"path/filepath"
"strings"
"github.com/kevinburke/ssh_config"
"golang.org/x/crypto/ssh"
)
type SshConfig struct {
PrivateKey string
Port uint
HostName string
User string
}
func GetSSHConfig(host string) (SshConfig, error) {
var config SshConfig
hostName := ssh_config.Get(host, "HostName")
if hostName == "" {
return SshConfig{}, errors.New("hostname not found")
}
config.HostName = hostName
privKey, err := ssh_config.GetStrict(host, "IdentityFile")
if err != nil {
return SshConfig{}, err
}
config.PrivateKey = privKey
User := ssh_config.Get(host, "User")
if User == "" {
return SshConfig{}, errors.New("user not found")
}
return config, nil
}
func (remoteConfig *Host) connectToSSHHost() (*ssh.Client, error) {
var sshc *ssh.Client
var connectErr error
f, _ := os.Open(filepath.Join(os.Getenv("HOME"), ".ssh", "config"))
cfg, _ := ssh_config.Decode(f)
for _, host := range cfg.Hosts {
if host.Matches(remoteConfig.Host) {
var identityFile string
if remoteConfig.PrivateKeyPath == "" {
identityFile, _ = cfg.Get(remoteConfig.Host, "IdentityFile")
usr, _ := user.Current()
dir := usr.HomeDir
if identityFile == "~" {
// In case of "~", which won't be caught by the "else if"
identityFile = dir
} else if strings.HasPrefix(identityFile, "~/") {
// Use strings.HasPrefix so we don't match paths like
// "/something/~/something/"
identityFile = filepath.Join(dir, identityFile[2:])
}
remoteConfig.PrivateKeyPath = filepath.Join(identityFile)
}
remoteConfig.HostName, _ = cfg.Get(remoteConfig.Host, "HostName")
if remoteConfig.HostName == "" {
remoteConfig.HostName = remoteConfig.Host
}
port, _ := cfg.Get(remoteConfig.Host, "Port")
if port == "" {
port = "22"
}
privateKey, err := os.ReadFile(remoteConfig.PrivateKeyPath)
remoteConfig.HostName = remoteConfig.HostName + ":" + port
if err != nil {
return nil, fmt.Errorf("read private key error: %w", err)
}
signer, err := ssh.ParsePrivateKey(privateKey)
if err != nil {
return nil, fmt.Errorf("parse private key error: %w", err)
}
sshConfig := &ssh.ClientConfig{
User: remoteConfig.User,
Auth: []ssh.AuthMethod{ssh.PublicKeys(signer)},
HostKeyCallback: ssh.InsecureIgnoreHostKey(),
}
sshc, connectErr = ssh.Dial("tcp", remoteConfig.HostName, sshConfig)
break
}
}
return sshc, connectErr
}

View File

@ -5,10 +5,6 @@ import (
"github.com/spf13/viper" "github.com/spf13/viper"
) )
type BackyViperOpts struct {
ConfigFilePath string
}
func ReadConfig(Config *viper.Viper) (*viper.Viper, error) { func ReadConfig(Config *viper.Viper) (*viper.Viper, error) {
backyViper := viper.New() backyViper := viper.New()
@ -82,9 +78,6 @@ func CreateConfig(backup backy.BackupConfig) backy.BackupConfig {
Name: backup.Name, Name: backup.Name,
BackupType: backup.BackupType, BackupType: backup.BackupType,
DstDir: backup.DstDir,
SrcDir: backup.SrcDir,
ConfigPath: backup.ConfigPath, ConfigPath: backup.ConfigPath,
} }

View File

@ -1,33 +0,0 @@
package ssh
import (
"errors"
"github.com/kevinburke/ssh_config"
)
type SshConfig struct {
PrivateKey string
Port uint
HostName string
User string
}
func GetSSHConfig(host string) (SshConfig, error) {
var config SshConfig
hostName := ssh_config.Get(host, "HostName")
if hostName == "" {
return SshConfig{}, errors.New("hostname not found")
}
config.HostName = hostName
privKey, err := ssh_config.GetStrict(host, "IdentityFile")
if err != nil {
return SshConfig{}, err
}
config.PrivateKey = privKey
User := ssh_config.Get(host, "User")
if User == "" {
return SshConfig{}, errors.New("user not found")
}
return config, nil
}