diff --git a/backup.go b/backup.go new file mode 100644 index 0000000..80ca6ba --- /dev/null +++ b/backup.go @@ -0,0 +1,160 @@ +package main + +import ( + "bytes" + "fmt" + "io" + "os" + "path/filepath" + "strings" + + "os/exec" + "os/user" + + "git.andrewnw.xyz/CyberShell/backy/pkg/backy" + "github.com/kevinburke/ssh_config" + "golang.org/x/crypto/ssh" +) + +type backupCommand struct { + + // command to run + cmd string + + // host on which to run cmd + host string + + // Shell specifies which shell to run the command in, if any + // Not applicable when host is not local + shell string + // path to the command + + path string + + // cmdArgs is an array that holds the arguments to cmd + cmdArgs []string +} + +var remoteHost *backy.Host + +func main() { + +} + +func (command *backupCommand) runCmd() { + + var cmdArgsStr string + for _, v := range command.cmdArgs { + cmdArgsStr += fmt.Sprintf(" %s", v) + } + fmt.Printf("\n\nRunning command: " + command.cmd + " " + cmdArgsStr + " on host " + command.host + "...\n\n") + if command.host != "local" { + + remoteHost.Port = 22 + remoteHost.Host = command.host + + sshc, err := connectToSSHHost(remoteHost) + if err != nil { + panic(fmt.Errorf("ssh dial: %w", err)) + } + defer sshc.Close() + s, err := sshc.NewSession() + if err != nil { + panic(fmt.Errorf("new ssh session: %w", err)) + } + defer s.Close() + + cmd := command.cmd + for _, a := range command.cmdArgs { + cmd += " " + a + } + + var stdoutBuf, stderrBuf bytes.Buffer + s.Stdout = io.MultiWriter(os.Stdout, &stdoutBuf) + s.Stderr = io.MultiWriter(os.Stderr, &stderrBuf) + err = s.Run(cmd) + if err != nil { + panic(fmt.Errorf("error when running cmd " + cmd + "\n Error: " + err.Error())) + } + // fmt.Printf("Output: %s\n", string(output)) + } else { + // shell := "/bin/bash" + var err error + if command.shell != "" { + cmdArgsStr = fmt.Sprintf("%s %s", command.cmd, cmdArgsStr) + localCMD := exec.Command(command.shell, "-c", cmdArgsStr) + + var stdoutBuf, stderrBuf bytes.Buffer + localCMD.Stdout = io.MultiWriter(os.Stdout, &stdoutBuf) + localCMD.Stderr = io.MultiWriter(os.Stderr, &stderrBuf) + err = localCMD.Run() + if err != nil { + panic(fmt.Errorf("error when running cmd: %s: %w", command.cmd, err)) + } + } + // localCMD := exec.Command(command.cmd, command.args...) + + // var stdoutBuf, stderrBuf bytes.Buffer + // localCMD.Stdout = io.MultiWriter(os.Stdout, &stdoutBuf) + // localCMD.Stderr = io.MultiWriter(os.Stderr, &stderrBuf) + // err = localCMD.Run() + + if err != nil { + panic(fmt.Errorf("error when running cmd: %s: %w", command.cmd, err)) + } + // fmt.Printf("%s\n", string(output)) + } +} + +func connectToSSHHost(remoteConfig *backy.Host) (*ssh.Client, error) { + var sshc *ssh.Client + var connectErr error + + f, _ := os.Open(filepath.Join(os.Getenv("HOME"), ".ssh", "config")) + cfg, _ := ssh_config.Decode(f) + for _, host := range cfg.Hosts { + if host.Matches(remoteConfig.Host) { + var identityFile string + if remoteConfig.PrivateKeyPath == "" { + identityFile, _ = cfg.Get(remoteConfig.Host, "IdentityFile") + usr, _ := user.Current() + dir := usr.HomeDir + if identityFile == "~" { + // In case of "~", which won't be caught by the "else if" + identityFile = dir + } else if strings.HasPrefix(identityFile, "~/") { + // Use strings.HasPrefix so we don't match paths like + // "/something/~/something/" + identityFile = filepath.Join(dir, identityFile[2:]) + } + remoteConfig.PrivateKeyPath = filepath.Join(identityFile) + } + remoteConfig.HostName, _ = cfg.Get(remoteConfig.Host, "HostName") + if remoteConfig.HostName == "" { + remoteConfig.HostName = remoteConfig.Host + } + port, _ := cfg.Get(remoteConfig.Host, "Port") + if port == "" { + port = "22" + } + privateKey, err := os.ReadFile(remoteConfig.PrivateKeyPath) + remoteConfig.HostName = remoteConfig.HostName + ":" + port + if err != nil { + panic(fmt.Errorf("read private key: %w", err)) + } + signer, err := ssh.ParsePrivateKey(privateKey) + if err != nil { + panic(fmt.Errorf("parse private key: %w", err)) + } + sshConfig := &ssh.ClientConfig{ + User: "root", + Auth: []ssh.AuthMethod{ssh.PublicKeys(signer)}, + HostKeyCallback: ssh.InsecureIgnoreHostKey(), + } + sshc, connectErr = ssh.Dial("tcp", remoteConfig.HostName, sshConfig) + break + } + + } + return sshc, connectErr +}