// Package main runs a command either on the local machine or on a remote
// host reached over SSH, streaming the command's output to this process's
// stdout and stderr.
package main

import (
	"bytes"
	"fmt"
	"io"
	"net"
	"os"
	"os/exec"
	"os/user"
	"path/filepath"
	"strings"

	"git.andrewnw.xyz/CyberShell/backy/pkg/backy"
	"github.com/kevinburke/ssh_config"
	"golang.org/x/crypto/ssh"
)

// backupCommand describes a single command and where it should run.
type backupCommand struct {
	// command to run
	cmd string

	// host on which to run cmd
	host string

	// Shell specifies which shell to run the command in, if any
	// Not applicable when host is not local
	shell string

	// path to the command
	path string

	// cmdArgs is an array that holds the arguments to cmd
	cmdArgs []string
}

// remoteHost is the host description handed to connectToSSHHost for remote
// commands. It is allocated on first use if nothing else has set it.
//
// NOTE(review): this is shared mutable state; fields filled in for one host
// (e.g. PrivateKeyPath) stay set when a later command targets another host.
var remoteHost *backy.Host

func main() {
}

// runCmd runs the command on command.host: through SSH when the host is
// anything other than "local", otherwise as a local process.
//
// It panics if the command cannot be started or exits unsuccessfully.
func (c *backupCommand) runCmd() {
	// Constant format string: a '%' inside the command or its arguments must
	// be printed verbatim rather than interpreted as a formatting verb.
	fmt.Printf("\n\nRunning command: %s %s on host %s...\n\n", c.cmd, c.argString(), c.host)

	var err error
	if c.host != "local" {
		err = c.runRemote()
	} else {
		err = c.runLocal()
	}
	if err != nil {
		panic(err)
	}
}

// argString returns the arguments as one string in which every argument is
// preceded by a single space (empty when there are no arguments).
func (c *backupCommand) argString() string {
	var b strings.Builder
	for _, arg := range c.cmdArgs {
		b.WriteByte(' ')
		b.WriteString(arg)
	}
	return b.String()
}

// runRemote executes the command on c.host in a new SSH session.
func (c *backupCommand) runRemote() error {
	// The package-level pointer has no initializer in this file; writing
	// through it while nil would crash.
	if remoteHost == nil {
		remoteHost = &backy.Host{}
	}
	remoteHost.Port = 22
	remoteHost.Host = c.host

	client, err := connectToSSHHost(remoteHost)
	if err != nil {
		return fmt.Errorf("ssh dial: %w", err)
	}
	defer client.Close()

	session, err := client.NewSession()
	if err != nil {
		return fmt.Errorf("new ssh session: %w", err)
	}
	defer session.Close()

	// Arguments are appended unquoted; the remote side receives one command
	// line and splits it itself.
	remoteCmd := c.cmd + c.argString()

	// Output is streamed to this process and also copied into buffers.
	// Nothing reads the buffers yet.
	var stdoutBuf, stderrBuf bytes.Buffer
	session.Stdout = io.MultiWriter(os.Stdout, &stdoutBuf)
	session.Stderr = io.MultiWriter(os.Stderr, &stderrBuf)

	if err := session.Run(remoteCmd); err != nil {
		return fmt.Errorf("error when running cmd: %s: %w", remoteCmd, err)
	}
	return nil
}

// runLocal executes the command on this machine. When a shell is configured
// the command line is passed to it via "-c"; otherwise the command is
// executed directly with its arguments.
func (c *backupCommand) runLocal() error {
	var localCMD *exec.Cmd
	if c.shell != "" {
		localCMD = exec.Command(c.shell, "-c", c.cmd+" "+c.argString())
	} else {
		// Without a shell the command used to be skipped silently.
		localCMD = exec.Command(c.cmd, c.cmdArgs...)
	}

	// Output is streamed to this process and also copied into buffers.
	// Nothing reads the buffers yet.
	var stdoutBuf, stderrBuf bytes.Buffer
	localCMD.Stdout = io.MultiWriter(os.Stdout, &stdoutBuf)
	localCMD.Stderr = io.MultiWriter(os.Stderr, &stderrBuf)

	if err := localCMD.Run(); err != nil {
		return fmt.Errorf("error when running cmd: %s: %w", c.cmd, err)
	}
	return nil
}

// connectToSSHHost opens an SSH connection to remoteConfig.Host using the
// settings found in $HOME/.ssh/config.
//
// On the first host entry matching remoteConfig.Host it fills in
// remoteConfig.PrivateKeyPath (when empty) from IdentityFile, sets
// remoteConfig.HostName to the "host:port" dial address, and dials.
//
// It returns an error when the config cannot be read or decoded, when no
// entry matches, when the private key cannot be loaded, or when dialing fails.
func connectToSSHHost(remoteConfig *backy.Host) (*ssh.Client, error) {
	if remoteConfig == nil {
		return nil, fmt.Errorf("nil host configuration")
	}

	configPath := filepath.Join(os.Getenv("HOME"), ".ssh", "config")
	f, err := os.Open(configPath)
	if err != nil {
		return nil, fmt.Errorf("open ssh config: %w", err)
	}
	defer f.Close()

	cfg, err := ssh_config.Decode(f)
	if err != nil {
		return nil, fmt.Errorf("decode ssh config %q: %w", configPath, err)
	}

	for _, host := range cfg.Hosts {
		if !host.Matches(remoteConfig.Host) {
			continue
		}
		// Only the first matching entry triggers a connection attempt.
		return dialSSHHost(cfg, remoteConfig)
	}

	// A nil client with a nil error would crash the caller on first use.
	return nil, fmt.Errorf("no ssh config entry matches host %q", remoteConfig.Host)
}

// dialSSHHost resolves the key, host name and port for remoteConfig.Host
// from cfg and dials the resulting address with public-key authentication.
func dialSSHHost(cfg *ssh_config.Config, remoteConfig *backy.Host) (*ssh.Client, error) {
	if remoteConfig.PrivateKeyPath == "" {
		keyPath, err := identityFilePath(cfg, remoteConfig.Host)
		if err != nil {
			return nil, err
		}
		remoteConfig.PrivateKeyPath = keyPath
	}

	hostName, err := cfg.Get(remoteConfig.Host, "HostName")
	if err != nil {
		return nil, fmt.Errorf("look up HostName for %q: %w", remoteConfig.Host, err)
	}
	if hostName == "" {
		hostName = remoteConfig.Host
	}

	port, err := cfg.Get(remoteConfig.Host, "Port")
	if err != nil {
		return nil, fmt.Errorf("look up Port for %q: %w", remoteConfig.Host, err)
	}
	if port == "" {
		port = "22"
	}

	// JoinHostPort brackets IPv6 literals, which plain concatenation does not.
	remoteConfig.HostName = net.JoinHostPort(hostName, port)

	privateKey, err := os.ReadFile(remoteConfig.PrivateKeyPath)
	if err != nil {
		return nil, fmt.Errorf("read private key: %w", err)
	}
	signer, err := ssh.ParsePrivateKey(privateKey)
	if err != nil {
		return nil, fmt.Errorf("parse private key: %w", err)
	}

	sshConfig := &ssh.ClientConfig{
		// NOTE(review): the login user is fixed; a User entry in the ssh
		// config is not consulted.
		User: "root",
		Auth: []ssh.AuthMethod{ssh.PublicKeys(signer)},
		// SECURITY: the server's host key is not verified, so the connection
		// is open to man-in-the-middle attacks. Replace with a known_hosts
		// based callback before relying on this outside a trusted network.
		HostKeyCallback: ssh.InsecureIgnoreHostKey(),
	}

	return ssh.Dial("tcp", remoteConfig.HostName, sshConfig)
}

// identityFilePath returns the cleaned IdentityFile path configured for
// alias, with a leading "~" or "~/" replaced by the current user's home
// directory.
func identityFilePath(cfg *ssh_config.Config, alias string) (string, error) {
	identityFile, err := cfg.Get(alias, "IdentityFile")
	if err != nil {
		return "", fmt.Errorf("look up IdentityFile for %q: %w", alias, err)
	}
	if identityFile == "" {
		return "", fmt.Errorf("no IdentityFile configured for host %q", alias)
	}

	// Only a leading "~" is expanded, so paths such as "/something/~/something"
	// are left alone.
	if identityFile == "~" || strings.HasPrefix(identityFile, "~/") {
		usr, err := user.Current()
		if err != nil {
			return "", fmt.Errorf("determine current user: %w", err)
		}
		if identityFile == "~" {
			identityFile = usr.HomeDir
		} else {
			identityFile = filepath.Join(usr.HomeDir, identityFile[2:])
		}
	}

	return filepath.Clean(identityFile), nil
}