From 1d61b0f09075fad2687a84dbfe2c2639043cb678 Mon Sep 17 00:00:00 2001 From: Novikov Roman Date: Fri, 18 Feb 2022 10:13:16 +0300 Subject: [PATCH 1/7] refact --- .gitignore | 1 + README.md | 2 +- client.go | 2 +- cmd/sup/main.go | 168 ++++++++++++------------- dist/brew/sup.rb | 14 +-- example/Supfile | 2 +- go.mod | 21 ++-- go.sum | 66 ++++++---- localhost.go | 48 ++++---- ssh.go | 311 +++++++++++++++++++++++++++++------------------ sup.go | 252 ++++++++++++-------------------------- supfile.go | 83 +++++++------ tar.go | 21 ++-- task.go | 265 ++++++++++++++++++++++++++++++++++------ 14 files changed, 734 insertions(+), 522 deletions(-) diff --git a/.gitignore b/.gitignore index 8269d0c..a7b5801 100644 --- a/.gitignore +++ b/.gitignore @@ -1,2 +1,3 @@ bin/ +.idea/ *.sw? diff --git a/README.md b/README.md index eb58b9d..32a9b26 100644 --- a/README.md +++ b/README.md @@ -11,7 +11,7 @@ Stack Up is a simple deployment tool that performs given set of commands on mult # Installation - $ go get -u github.com/pressly/sup/cmd/sup + $ go get -u github.com/NovikovRoman/sup/cmd/sup # Usage diff --git a/client.go b/client.go index a260559..60b3c44 100644 --- a/client.go +++ b/client.go @@ -6,7 +6,7 @@ import ( ) type Client interface { - Connect(host string) error + Connect() error Run(task *Task) error Wait() error Close() error diff --git a/cmd/sup/main.go b/cmd/sup/main.go index e1f35ee..d1b30b2 100644 --- a/cmd/sup/main.go +++ b/cmd/sup/main.go @@ -12,9 +12,9 @@ import ( "text/tabwriter" "time" + "github.com/NovikovRoman/sup" "github.com/mikkeloscar/sshconfig" "github.com/pkg/errors" - "github.com/pressly/sup" ) var ( @@ -30,12 +30,10 @@ var ( showVersion bool showHelp bool - ErrUsage = errors.New("Usage: sup [OPTIONS] NETWORK COMMAND [...]\n sup [ --help | -v | --version ]") - ErrUnknownNetwork = errors.New("Unknown network") - ErrNetworkNoHosts = errors.New("No hosts defined for a given network") - ErrCmd = errors.New("Unknown command/target") - ErrTargetNoCommands = errors.New("No commands defined for a given target") - ErrConfigFile = errors.New("Unknown ssh_config file") + ErrUsage = errors.New("Usage: sup [OPTIONS] NETWORK COMMAND [...]\n sup [ --help | -v | --version ]") + ErrUnknownNetwork = errors.New("Unknown network") + ErrNetworkNoHosts = errors.New("No hosts defined for a given network") + ErrCmd = errors.New("Unknown command/target") ) type flagStringSlice []string @@ -70,44 +68,51 @@ func init() { func networkUsage(conf *sup.Supfile) { w := &tabwriter.Writer{} w.Init(os.Stderr, 4, 4, 2, ' ', 0) - defer w.Flush() + defer func(w *tabwriter.Writer) { + _ = w.Flush() + }(w) // Print available networks/hosts. - fmt.Fprintln(w, "Networks:\t") + _, _ = fmt.Fprintln(w, "Networks:\t") for _, name := range conf.Networks.Names { - fmt.Fprintf(w, "- %v\n", name) + _, _ = fmt.Fprintf(w, "- %v\n", name) network, _ := conf.Networks.Get(name) for _, host := range network.Hosts { - fmt.Fprintf(w, "\t- %v\n", host) + _, _ = fmt.Fprintf(w, "\t- %v\n", host) } } - fmt.Fprintln(w) + _, _ = fmt.Fprintln(w) } func cmdUsage(conf *sup.Supfile) { w := &tabwriter.Writer{} w.Init(os.Stderr, 4, 4, 2, ' ', 0) - defer w.Flush() + defer func(w *tabwriter.Writer) { + _ = w.Flush() + }(w) // Print available targets/commands. - fmt.Fprintln(w, "Targets:\t") + _, _ = fmt.Fprintln(w, "Targets:\t") for _, name := range conf.Targets.Names { cmds, _ := conf.Targets.Get(name) - fmt.Fprintf(w, "- %v\t%v\n", name, strings.Join(cmds, " ")) + _, _ = fmt.Fprintf(w, "- %v\t%v\n", name, strings.Join(cmds, " ")) } - fmt.Fprintln(w, "\t") - fmt.Fprintln(w, "Commands:\t") + _, _ = fmt.Fprintln(w, "\t") + _, _ = fmt.Fprintln(w, "Commands:\t") for _, name := range conf.Commands.Names { cmd, _ := conf.Commands.Get(name) - fmt.Fprintf(w, "- %v\t%v\n", name, cmd.Desc) + _, _ = fmt.Fprintf(w, "- %v\t%v\n", name, cmd.Desc) } - fmt.Fprintln(w) + _, _ = fmt.Fprintln(w) } // parseArgs parses args and returns network and commands to be run. // On error, it prints usage and exits. -func parseArgs(conf *sup.Supfile) (*sup.Network, []*sup.Command, error) { - var commands []*sup.Command +func parseArgs(conf *sup.Supfile) (network *sup.Network, commands []*sup.Command, err error) { + var ( + ok bool + nw sup.Network + ) args := flag.Args() if len(args) < 1 { @@ -116,12 +121,15 @@ func parseArgs(conf *sup.Supfile) (*sup.Network, []*sup.Command, error) { } // Does the exist? - network, ok := conf.Networks.Get(args[0]) + nw, ok = conf.Networks.Get(args[0]) if !ok { networkUsage(conf) - return nil, nil, ErrUnknownNetwork + err = ErrUnknownNetwork + return } + network = &nw + // Parse CLI --env flag env vars, override values defined in Network env. for _, env := range envVars { if len(env) == 0 { @@ -137,22 +145,25 @@ func parseArgs(conf *sup.Supfile) (*sup.Network, []*sup.Command, error) { network.Env.Set(env[:i], env[i+1:]) } - hosts, err := network.ParseInventory() - if err != nil { - return nil, nil, err + // Inventory + var hosts []string + if hosts, err = network.ParseInventory(); err != nil { + return } network.Hosts = append(network.Hosts, hosts...) // Does the have at least one host? if len(network.Hosts) == 0 { networkUsage(conf) - return nil, nil, ErrNetworkNoHosts + err = ErrNetworkNoHosts + return } // Check for the second argument if len(args) < 2 { cmdUsage(conf) - return nil, nil, ErrUsage + err = ErrUsage + return } // In case of the network.Env needs an initialization @@ -181,13 +192,14 @@ func parseArgs(conf *sup.Supfile) (*sup.Network, []*sup.Command, error) { target, isTarget := conf.Targets.Get(cmd) if isTarget { // Loop over target's commands. - for _, cmd := range target { - command, isCommand := conf.Commands.Get(cmd) + for _, cmdTarget := range target { + command, isCommand := conf.Commands.Get(cmdTarget) if !isCommand { cmdUsage(conf) - return nil, nil, fmt.Errorf("%v: %v", ErrCmd, cmd) + err = fmt.Errorf("%v: %v", ErrCmd, cmdTarget) + return } - command.Name = cmd + command.Name = cmdTarget commands = append(commands, &command) } } @@ -205,7 +217,7 @@ func parseArgs(conf *sup.Supfile) (*sup.Network, []*sup.Command, error) { } } - return &network, commands, nil + return } func resolvePath(path string) string { @@ -222,50 +234,59 @@ func resolvePath(path string) string { } func main() { + var ( + conf *sup.Supfile + commands []*sup.Command + network *sup.Network + data []byte + err error + ) flag.Parse() if showHelp { - fmt.Fprintln(os.Stderr, ErrUsage, "\n\nOptions:") + _, _ = fmt.Fprintln(os.Stderr, ErrUsage, "\n\nOptions:") flag.PrintDefaults() return } if showVersion { - fmt.Fprintln(os.Stderr, sup.VERSION) + _, _ = fmt.Fprintln(os.Stderr, sup.VERSION) return } if supfile == "" { supfile = "./Supfile" } - data, err := ioutil.ReadFile(resolvePath(supfile)) - if err != nil { + + if data, err = ioutil.ReadFile(resolvePath(supfile)); err != nil { firstErr := err data, err = ioutil.ReadFile("./Supfile.yml") // Alternative to ./Supfile. if err != nil { - fmt.Fprintln(os.Stderr, firstErr) - fmt.Fprintln(os.Stderr, err) + _, _ = fmt.Fprintln(os.Stderr, firstErr) + _, _ = fmt.Fprintln(os.Stderr, err) os.Exit(1) } } - conf, err := sup.NewSupfile(data) - if err != nil { - fmt.Fprintln(os.Stderr, err) + + if conf, err = sup.NewSupfile(data); err != nil { + _, _ = fmt.Fprintln(os.Stderr, err) os.Exit(1) } // Parse network and commands to be run from args. - network, commands, err := parseArgs(conf) - if err != nil { - fmt.Fprintln(os.Stderr, err) + if network, commands, err = parseArgs(conf); err != nil { + _, _ = fmt.Fprintln(os.Stderr, err) os.Exit(1) } + var ( + expr *regexp.Regexp + ) + // --only flag filters hosts if onlyHosts != "" { - expr, err := regexp.CompilePOSIX(onlyHosts) - if err != nil { - fmt.Fprintln(os.Stderr, err) + if expr, err = regexp.CompilePOSIX(onlyHosts); err != nil { + _, _ = fmt.Fprintln(os.Stderr, err) os.Exit(1) } @@ -276,7 +297,7 @@ func main() { } } if len(hosts) == 0 { - fmt.Fprintln(os.Stderr, fmt.Errorf("no hosts match --only '%v' regexp", onlyHosts)) + _, _ = fmt.Fprintln(os.Stderr, fmt.Errorf("no hosts match --only '%v' regexp", onlyHosts)) os.Exit(1) } network.Hosts = hosts @@ -284,9 +305,8 @@ func main() { // --except flag filters out hosts if exceptHosts != "" { - expr, err := regexp.CompilePOSIX(exceptHosts) - if err != nil { - fmt.Fprintln(os.Stderr, err) + if expr, err = regexp.CompilePOSIX(exceptHosts); err != nil { + _, _ = fmt.Fprintln(os.Stderr, err) os.Exit(1) } @@ -296,47 +316,31 @@ func main() { hosts = append(hosts, host) } } + if len(hosts) == 0 { - fmt.Fprintln(os.Stderr, fmt.Errorf("no hosts left after --except '%v' regexp", onlyHosts)) + _, _ = fmt.Fprintln(os.Stderr, fmt.Errorf("no hosts left after --except '%v' regexp", onlyHosts)) os.Exit(1) } network.Hosts = hosts } // --sshconfig flag location for ssh_config file - if sshConfig != "" { - confHosts, err := sshconfig.ParseSSHConfig(resolvePath(sshConfig)) - if err != nil { - fmt.Fprintln(os.Stderr, err) - os.Exit(1) - } - - // flatten Host -> *SSHHost, not the prettiest - // but will do - confMap := map[string]*sshconfig.SSHHost{} - for _, conf := range confHosts { - for _, host := range conf.Host { - confMap[host] = conf - } - } + if sshConfig == "" { + sshConfig = filepath.Join(os.Getenv("HOME"), ".ssh", "config") + } - // check network.Hosts for match - for _, host := range network.Hosts { - conf, found := confMap[host] - if found { - network.User = conf.User - network.IdentityFile = resolvePath(conf.IdentityFile) - network.Hosts = []string{fmt.Sprintf("%s:%d", conf.HostName, conf.Port)} - } - } + var sshConfigHosts []*sshconfig.SSHHost + if sshConfigHosts, err = sshconfig.Parse(resolvePath(sshConfig)); err != nil { + _, _ = fmt.Fprintln(os.Stderr, err) + os.Exit(1) } var vars sup.EnvList for _, val := range append(conf.Env, network.Env...) { vars.Set(val.Key, val.Value) } - if err := vars.ResolveValues(); err != nil { - fmt.Fprintln(os.Stderr, err) + if err = vars.ResolveValues(); err != nil { + _, _ = fmt.Fprintln(os.Stderr, err) os.Exit(1) } @@ -368,16 +372,16 @@ func main() { // Create new Stackup app. app, err := sup.New(conf) if err != nil { - fmt.Fprintln(os.Stderr, err) + _, _ = fmt.Fprintln(os.Stderr, err) os.Exit(1) } app.Debug(debug) app.Prefix(!disablePrefix) // Run all the commands in the given network. - err = app.Run(network, vars, commands...) + err = app.Run(sshConfigHosts, network, vars, commands...) if err != nil { - fmt.Fprintln(os.Stderr, err) + _, _ = fmt.Fprintln(os.Stderr, err) os.Exit(1) } } diff --git a/dist/brew/sup.rb b/dist/brew/sup.rb index ad3edc0..41280a3 100644 --- a/dist/brew/sup.rb +++ b/dist/brew/sup.rb @@ -2,10 +2,10 @@ class Sup < Formula desc "Stack Up. Super simple deployment tool - think of it like 'make' for a network of servers." - homepage "https://github.com/pressly/sup" - url "https://github.com/pressly/sup/archive/4ee5083c8321340bc2a6410f24d8a760f7ad3847.zip" - version "0.3.1" - sha256 "7fa17c20fdcd9e24d8c2fe98081e1300e936da02b3f2cf9c5a11fd699cbc487e" + homepage "https://github.com/NovikovRoman/sup" + url "https://github.com/NovikovRoman/sup/archive/refs/tags/v0.5.3.zip" + version "0.5.3" + sha256 "6e7922eb5371eec6a2d089811829f13049a67cfd485e2153f1a5a8e54702ff57" depends_on "go" => :build @@ -14,14 +14,14 @@ def install ENV["GOPATH"] = buildpath ENV["GOHOME"] = buildpath - mkdir_p buildpath/"src/github.com/pressly/" - ln_sf buildpath, buildpath/"src/github.com/pressly/sup" + mkdir_p buildpath/"src/github.com/NovikovRoman/" + ln_sf buildpath, buildpath/"src/github.com/NovikovRoman/sup" Language::Go.stage_deps resources, buildpath/"src" system "go", "build", "-o", bin/"sup", "./cmd/sup" end test do - assert_equal "0.3", shell_output("#{bin}/bin/sup") + assert_equal "0.5", shell_output("#{bin}/bin/sup") end end diff --git a/example/Supfile b/example/Supfile index 5140496..0c04258 100644 --- a/example/Supfile +++ b/example/Supfile @@ -5,7 +5,7 @@ version: 0.4 env: # Environment variables for all commands NAME: example - REPO: github.com/pressly/sup + REPO: github.com/NovikovRoman/sup BRANCH: master IMAGE: pressly/example HOST_PORT: 8000 diff --git a/go.mod b/go.mod index e1dce7e..118bcab 100644 --- a/go.mod +++ b/go.mod @@ -1,14 +1,19 @@ -module github.com/pressly/sup +module github.com/NovikovRoman/sup -go 1.13 +go 1.17 require ( github.com/goware/prefixer v0.0.0-20160118172347-395022866408 - github.com/kr/pretty v0.2.0 // indirect - github.com/mikkeloscar/sshconfig v0.0.0-20190102082740-ec0822bcc4f4 + github.com/hashicorp/errwrap v1.1.0 // indirect + github.com/hashicorp/go-multierror v1.1.1 + github.com/kr/pretty v0.3.0 // indirect + github.com/mikkeloscar/sshconfig v0.1.1 + github.com/mitchellh/go-homedir v1.1.0 // indirect github.com/pkg/errors v0.9.1 - golang.org/x/crypto v0.0.0-20200208060501-ecb85df21340 - golang.org/x/sys v0.0.0-20200202164722-d101bd2416d5 // indirect - gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 // indirect - gopkg.in/yaml.v2 v2.2.8 + github.com/rogpeppe/go-internal v1.8.1 // indirect + golang.org/x/crypto v0.0.0-20220214200702-86341886e292 + golang.org/x/sys v0.0.0-20220209214540-3681064d5158 // indirect + golang.org/x/term v0.0.0-20210927222741-03fcf44c2211 // indirect + gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c // indirect + gopkg.in/yaml.v2 v2.4.0 ) diff --git a/go.sum b/go.sum index eb3a23c..9e59af6 100644 --- a/go.sum +++ b/go.sum @@ -1,34 +1,46 @@ +github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= github.com/goware/prefixer v0.0.0-20160118172347-395022866408 h1:Y9iQJfEqnN3/Nce9cOegemcy/9Ai5k3huT6E80F3zaw= github.com/goware/prefixer v0.0.0-20160118172347-395022866408/go.mod h1:PE1ycukgRPJ7bJ9a1fdfQ9j8i/cEcRAoLZzbxYpNB/s= -github.com/kr/pretty v0.2.0 h1:s5hAObm+yFO5uHYt5dYjxi2rXrsnmRpJx4OYvIWUaQs= -github.com/kr/pretty v0.2.0/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI= +github.com/hashicorp/errwrap v1.0.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4= +github.com/hashicorp/errwrap v1.1.0 h1:OxrOeh75EUXMY8TBjag2fzXGZ40LB6IKw45YeGUDY2I= +github.com/hashicorp/errwrap v1.1.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4= +github.com/hashicorp/go-multierror v1.1.1 h1:H5DkEtf6CXdFp0N0Em5UCwQpXMWke8IA0+lD48awMYo= +github.com/hashicorp/go-multierror v1.1.1/go.mod h1:iw975J/qwKPdAO1clOe2L8331t/9/fmwbPZ6JB6eMoM= +github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= +github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI= +github.com/kr/pretty v0.3.0 h1:WgNl7dwNpEZ6jJ9k1snq4pZsg7DOEN8hP9Xw0Tsjwk0= +github.com/kr/pretty v0.3.0/go.mod h1:640gp4NfQd8pI5XOwp5fnNeVWj67G7CFk/SaSQn7NBk= github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= -github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE= github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= -github.com/mikkeloscar/sshconfig v0.0.0-20161223095632-fc5e37b16b68 h1:Z1BVWGqEm0aveMz9ffiFnJthFjM5+YFdFqFklQ/hPBI= -github.com/mikkeloscar/sshconfig v0.0.0-20161223095632-fc5e37b16b68/go.mod h1:GvQCIGDpivPr+e8cuBt3c4+NTOJm66zpBrMjkit8jmw= -github.com/mikkeloscar/sshconfig v0.0.0-20190102082740-ec0822bcc4f4 h1:6mjPKnEtYKqYTqIXAraugfl5bkaW+A6wJAupYKAWMXM= -github.com/mikkeloscar/sshconfig v0.0.0-20190102082740-ec0822bcc4f4/go.mod h1:GvQCIGDpivPr+e8cuBt3c4+NTOJm66zpBrMjkit8jmw= -github.com/pkg/errors v0.7.1-0.20160627222352-a2d6902c6d2a h1:dKpZ0nc8i7prliB4AIfJulQxsX7whlVwi6j5HqaYUl4= -github.com/pkg/errors v0.7.1-0.20160627222352-a2d6902c6d2a/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= +github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= +github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= +github.com/mikkeloscar/sshconfig v0.1.1 h1:WJLz/y4M0jMkYHDJkydcbOb/S8UAJ1denM9fCpwKV5c= +github.com/mikkeloscar/sshconfig v0.1.1/go.mod h1:NavXZq+n9+iOgFT6fOobpl6nFBltLYOIjejTwNQTK7A= +github.com/mitchellh/go-homedir v1.1.0 h1:lukF9ziXFxDFPkA1vsr5zpc1XuPDn/wFntq5mG+4E0Y= +github.com/mitchellh/go-homedir v1.1.0/go.mod h1:SfyaCUpYCn1Vlf4IUYiD9fPX4A5wJrkLzIz1N1q0pr0= +github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA= github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= -golang.org/x/crypto v0.0.0-20160804082612-7a1054f3ac58 h1:ytej7jB0ejb21kF+TjEWykw7n4sG85mxyjgYHgF/7ZQ= -golang.org/x/crypto v0.0.0-20160804082612-7a1054f3ac58/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= -golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= -golang.org/x/crypto v0.0.0-20200208060501-ecb85df21340 h1:KOcEaR10tFr7gdJV2GCKw8Os5yED1u1aOqHjOAb6d2Y= -golang.org/x/crypto v0.0.0-20200208060501-ecb85df21340/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= -golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= -golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= -golang.org/x/sys v0.0.0-20190412213103-97732733099d h1:+R4KGOnez64A81RvjARKc4UT5/tI9ujCIVX+P5KiHuI= -golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20200202164722-d101bd2416d5 h1:LfCXLvNmTYH9kEmVgqbnsWfruoXZIrh4YBgqVHtDvw0= -golang.org/x/sys v0.0.0-20200202164722-d101bd2416d5/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= +github.com/rogpeppe/go-internal v1.6.1/go.mod h1:xXDCJY+GAPziupqXw64V24skbSoqbTEfhy4qGm1nDQc= +github.com/rogpeppe/go-internal v1.8.1 h1:geMPLpDpQOgVyCg5z5GoRwLHepNdb71NXb67XFkP+Eg= +github.com/rogpeppe/go-internal v1.8.1/go.mod h1:JeRgkft04UBgHMgCIwADu4Pn6Mtm5d4nPKWu0nJ5d+o= +golang.org/x/crypto v0.0.0-20220214200702-86341886e292 h1:f+lwQ+GtmgoY+A2YaQxlSOnDjXcQ7ZRLWOHbC6HtRqE= +golang.org/x/crypto v0.0.0-20220214200702-86341886e292/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= +golang.org/x/net v0.0.0-20211112202133-69e39bad7dc2/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= +golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220209214540-3681064d5158 h1:rm+CHSpPEEW2IsXUib1ThaHIjuBVZjxNgSKmBLFfD4c= +golang.org/x/sys v0.0.0-20220209214540-3681064d5158/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= +golang.org/x/term v0.0.0-20210927222741-03fcf44c2211 h1:JGgROgKl9N8DuW20oFS5gxc+lE67/N3FcwmBPMe7ArY= +golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= +golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= -gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 h1:YR8cESwS4TdDjEe65xsg0ogRM/Nc3DYOhEAlW+xobZo= -gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= -gopkg.in/yaml.v2 v2.0.0-20160301204022-a83829b6f129 h1:RBgb9aPUbZ9nu66ecQNIBNsA7j3mB5h8PNDIfhPjaJg= -gopkg.in/yaml.v2 v2.0.0-20160301204022-a83829b6f129/go.mod h1:JAlM8MvJe8wmxCU4Bli9HhUf9+ttbYbLASfIpnQbh74= -gopkg.in/yaml.v2 v2.2.8 h1:obN1ZagJSUGI0Ek/LBmuj4SNLPfIny3KsKFopxRdj10= -gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= +gopkg.in/errgo.v2 v2.1.0/go.mod h1:hNsd1EY+bozCKY1Ytp96fpM3vjJbqLJn88ws8XvfDNI= +gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY= +gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= diff --git a/localhost.go b/localhost.go index ebdc495..5b9cd20 100644 --- a/localhost.go +++ b/localhost.go @@ -2,15 +2,14 @@ package sup import ( "fmt" + "github.com/pkg/errors" "io" "os" "os/exec" "os/user" - - "github.com/pkg/errors" ) -// Client is a wrapper over the SSH connection/sessions. +// LocalhostClient is a wrapper over the SSH connection/sessions. type LocalhostClient struct { cmd *exec.Cmd user string @@ -21,56 +20,51 @@ type LocalhostClient struct { env string //export FOO="bar"; export BAR="baz"; } -func (c *LocalhostClient) Connect(_ string) error { - u, err := user.Current() - if err != nil { - return err +func (c *LocalhostClient) Connect() (err error) { + var u *user.User + if u, err = user.Current(); err != nil { + return } c.user = u.Username - return nil + return } -func (c *LocalhostClient) Run(task *Task) error { - var err error - +func (c *LocalhostClient) Run(task *Task) (err error) { if c.running { - return fmt.Errorf("Command already running") + return fmt.Errorf("Command already running. ") } cmd := exec.Command("bash", "-c", c.env+task.Run) c.cmd = cmd - c.stdout, err = cmd.StdoutPipe() - if err != nil { - return err + if c.stdout, err = cmd.StdoutPipe(); err != nil { + return } - c.stderr, err = cmd.StderrPipe() - if err != nil { - return err + if c.stderr, err = cmd.StderrPipe(); err != nil { + return } - c.stdin, err = cmd.StdinPipe() - if err != nil { - return err + if c.stdin, err = cmd.StdinPipe(); err != nil { + return } - if err := c.cmd.Start(); err != nil { + if err = c.cmd.Start(); err != nil { return ErrTask{task, err.Error()} } c.running = true - return nil + return } -func (c *LocalhostClient) Wait() error { +func (c *LocalhostClient) Wait() (err error) { if !c.running { - return fmt.Errorf("Trying to wait on stopped command") + return fmt.Errorf("Trying to wait on stopped command. ") } - err := c.cmd.Wait() + err = c.cmd.Wait() c.running = false - return err + return } func (c *LocalhostClient) Close() error { diff --git a/ssh.go b/ssh.go index 8644fff..d535e4b 100644 --- a/ssh.go +++ b/ssh.go @@ -2,6 +2,11 @@ package sup import ( "fmt" + "github.com/hashicorp/go-multierror" + "github.com/mikkeloscar/sshconfig" + "github.com/pkg/errors" + "golang.org/x/crypto/ssh" + "golang.org/x/crypto/ssh/agent" "io" "io/ioutil" "net" @@ -10,12 +15,9 @@ import ( "path/filepath" "strings" "sync" - - "golang.org/x/crypto/ssh" - "golang.org/x/crypto/ssh/agent" ) -// Client is a wrapper over the SSH connection/sessions. +// SSHClient is a wrapper over the SSH connection/sessions. type SSHClient struct { conn *ssh.Client sess *ssh.Session @@ -29,136 +31,128 @@ type SSHClient struct { running bool env string //export FOO="bar"; export BAR="baz"; color string + signer *ssh.Signer } -type ErrConnect struct { - User string - Host string - Reason string -} - -func (e ErrConnect) Error() string { - return fmt.Sprintf(`Connect("%v@%v"): %v`, e.User, e.Host, e.Reason) -} - -// parseHost parses and normalizes @ from a given string. -func (c *SSHClient) parseHost(host string) error { - c.host = host - - // Remove extra "ssh://" schema - if len(c.host) > 6 && c.host[:6] == "ssh://" { - c.host = c.host[6:] - } - - // Split by the last "@", since there may be an "@" in the username. - if at := strings.LastIndex(c.host, "@"); at != -1 { - c.user = c.host[:at] - c.host = c.host[at+1:] +func NewSSHClient(host string, env string, i int, sshConfigHosts []*sshconfig.SSHHost) (c *SSHClient, err error) { + c = &SSHClient{ + host: host, + color: Colors[i%len(Colors)], + signer: nil, } - // Add default user, if not set - if c.user == "" { - u, err := user.Current() - if err != nil { - return err + for _, sshHost := range sshConfigHosts { + for _, h := range sshHost.Host { + if host != h { + continue + } + + c.host = sshHost.HostName + c.user = sshHost.User + if sshHost.Port > 0 { + c.host = fmt.Sprintf("%s:%d", c.host, sshHost.Port) + } + + if sshHost.IdentityFile != "" { + if strings.HasPrefix(sshHost.IdentityFile, "~") { + sshHost.IdentityFile = strings.Replace(sshHost.IdentityFile, "~", os.Getenv("HOME"), 1) + } + + if c.signer, err = c.getPrivateKey(sshHost.IdentityFile); err != nil { + err = errors.Wrap(err, "get private key") + return + } + } + + c.env = env + `export SUP_HOST="` + c.host + `";` + return } - c.user = u.Username - } - - if strings.Index(c.host, "/") != -1 { - return ErrConnect{c.user, c.host, "unexpected slash in the host URL"} } - // Add default port, if not set - if strings.Index(c.host, ":") == -1 { - c.host += ":22" - } - - return nil + c.env = env + `export SUP_HOST="` + host + `";` + err = c.parseHost(host) + return } -var initAuthMethodOnce sync.Once -var authMethod ssh.AuthMethod - -// initAuthMethod initiates SSH authentication method. -func initAuthMethod() { - var signers []ssh.Signer - - // If there's a running SSH Agent, try to use its Private keys. - sock, err := net.Dial("unix", os.Getenv("SSH_AUTH_SOCK")) - if err == nil { - agent := agent.NewClient(sock) - signers, _ = agent.Signers() - } - - // Try to read user's SSH private keys form the standard paths. - files, _ := filepath.Glob(os.Getenv("HOME") + "/.ssh/id_*") - for _, file := range files { - if strings.HasSuffix(file, ".pub") { - continue // Skip public keys. - } - data, err := ioutil.ReadFile(file) - if err != nil { - continue - } - signer, err := ssh.ParsePrivateKey(data) - if err != nil { - continue - } - signers = append(signers, signer) +type ErrConnect struct { + User string + Host string + Reason string +} - } - authMethod = ssh.PublicKeys(signers...) +func (e ErrConnect) Error() string { + return fmt.Sprintf(`Connect("%v@%v"): %v`, e.User, e.Host, e.Reason) } -// SSHDialFunc can dial an ssh server and return a client +// SSHDialFunc can dial a ssh server and return a client type SSHDialFunc func(net, addr string, config *ssh.ClientConfig) (*ssh.Client, error) // Connect creates SSH connection to a specified host. // It expects the host of the form "[ssh://]host[:port]". -func (c *SSHClient) Connect(host string) error { - return c.ConnectWith(host, ssh.Dial) +func (c *SSHClient) Connect() error { + return c.ConnectWith(ssh.Dial) } // ConnectWith creates a SSH connection to a specified host. It will use dialer to establish the // connection. // TODO: Split Signers to its own method. -func (c *SSHClient) ConnectWith(host string, dialer SSHDialFunc) error { +func (c *SSHClient) ConnectWith(dialer SSHDialFunc) (err error) { if c.connOpened { - return fmt.Errorf("Already connected") + return errors.New("Already connected") } initAuthMethodOnce.Do(initAuthMethod) - err := c.parseHost(host) - if err != nil { - return err + var auth []ssh.AuthMethod + if c.signer == nil { + auth = []ssh.AuthMethod{ssh.PublicKeys(signers...)} + + } else { + auth = []ssh.AuthMethod{ + ssh.PublicKeys(*c.signer), + } } config := &ssh.ClientConfig{ - User: c.user, - Auth: []ssh.AuthMethod{ - authMethod, - }, + User: c.user, + Auth: auth, HostKeyCallback: ssh.InsecureIgnoreHostKey(), } - c.conn, err = dialer("tcp", c.host, config) - if err != nil { + if c.conn, err = dialer("tcp", c.host, config); err != nil { return ErrConnect{c.user, c.host, err.Error()} } + c.connOpened = true + return +} - return nil +func (c *SSHClient) getPrivateKey(file string) (*ssh.Signer, error) { + var ( + data []byte + signer ssh.Signer + err error + ) + + if strings.HasSuffix(file, ".pub") { + return nil, err + } + + if data, err = ioutil.ReadFile(file); err != nil { + return nil, err + } + + signer, err = ssh.ParsePrivateKey(data) + return &signer, err } // Run runs the task.Run command remotely on c.host. func (c *SSHClient) Run(task *Task) error { if c.running { - return fmt.Errorf("Session already running") + return errors.New("Session already running") } if c.sessOpened { - return fmt.Errorf("Session already connected") + return errors.New("Session already connected") } sess, err := c.conn.NewSession() @@ -185,17 +179,17 @@ func (c *SSHClient) Run(task *Task) error { // Set up terminal modes modes := ssh.TerminalModes{ ssh.ECHO: 0, // disable echoing - ssh.TTY_OP_ISPEED: 14400, // input speed = 14.4kbaud - ssh.TTY_OP_OSPEED: 14400, // output speed = 14.4kbaud + ssh.TTY_OP_ISPEED: 14400, // input speed = 14.4k baud + ssh.TTY_OP_OSPEED: 14400, // output speed = 14.4k baud } // Request pseudo terminal - if err := sess.RequestPty("xterm", 80, 40, modes); err != nil { + if err = sess.RequestPty("xterm", 80, 40, modes); err != nil { return ErrTask{task, fmt.Sprintf("request for pseudo terminal failed: %s", err)} } } // Start the remote command. - if err := sess.Start(c.env + task.Run); err != nil { + if err = sess.Start(c.env + task.Run); err != nil { return ErrTask{task, err.Error()} } @@ -207,47 +201,59 @@ func (c *SSHClient) Run(task *Task) error { // Wait waits until the remote command finishes and exits. // It closes the SSH session. -func (c *SSHClient) Wait() error { +func (c *SSHClient) Wait() (err error) { if !c.running { - return fmt.Errorf("Trying to wait on stopped session") + return errors.New("Trying to wait on stopped session") } - err := c.sess.Wait() - c.sess.Close() + err = c.sess.Wait() c.running = false c.sessOpened = false - return err + if e := c.sess.Close(); e != nil && e != io.EOF { + err = multierror.Append(err, e) + } + return } // DialThrough will create a new connection from the ssh server sc is connected to. DialThrough is an SSHDialer. -func (sc *SSHClient) DialThrough(net, addr string, config *ssh.ClientConfig) (*ssh.Client, error) { - conn, err := sc.conn.Dial(net, addr) - if err != nil { - return nil, err +func (c *SSHClient) DialThrough(n, addr string, config *ssh.ClientConfig) (sc *ssh.Client, err error) { + var ( + cc ssh.Conn + nChan <-chan ssh.NewChannel + reqs <-chan *ssh.Request + conn net.Conn + ) + + if conn, err = c.conn.Dial(n, addr); err != nil { + return } - c, chans, reqs, err := ssh.NewClientConn(conn, addr, config) - if err != nil { - return nil, err + + if cc, nChan, reqs, err = ssh.NewClientConn(conn, addr, config); err != nil { + return } - return ssh.NewClient(c, chans, reqs), nil + + sc = ssh.NewClient(cc, nChan, reqs) + return } // Close closes the underlying SSH connection and session. -func (c *SSHClient) Close() error { +func (c *SSHClient) Close() (err error) { if c.sessOpened { - c.sess.Close() c.sessOpened = false + if err = c.sess.Close(); err != nil { + return + } } + if !c.connOpened { - return fmt.Errorf("Trying to close the already closed connection") + return errors.New("Trying to close the already closed connection") } - err := c.conn.Close() + err = c.conn.Close() c.connOpened = false c.running = false - return err } @@ -278,7 +284,7 @@ func (c *SSHClient) WriteClose() error { func (c *SSHClient) Signal(sig os.Signal) error { if !c.sessOpened { - return fmt.Errorf("session is not open") + return errors.New("session is not open") } switch sig { @@ -288,9 +294,84 @@ func (c *SSHClient) Signal(sig os.Signal) error { // which sounds like something that should be fixed/resolved // upstream in the golang.org/x/crypto/ssh pkg. // https://github.com/golang/go/issues/4115#issuecomment-66070418 - c.remoteStdin.Write([]byte("\x03")) + _, _ = c.remoteStdin.Write([]byte("\x03")) return c.sess.Signal(ssh.SIGINT) + default: return fmt.Errorf("%v not supported", sig) } } + +func (c *SSHClient) parseHost(host string) (err error) { + var ( + u *user.User + ) + + // Remove extra "ssh://" schema + if len(host) > 6 && host[:6] == "ssh://" { + host = host[6:] + } + + // Split by the last "@", since there may be an "@" in the username. + if at := strings.LastIndex(host, "@"); at != -1 { + c.user = host[:at] + host = host[at+1:] + } + + // Add default user, if not set + if c.user == "" { + if u, err = user.Current(); err != nil { + return + } + c.user = u.Username + } + + if strings.Index(host, "/") != -1 { + err = ErrConnect{User: c.user, Host: host, Reason: "unexpected slash in the host URL"} + return + } + + // Add default port, if not set + if at := strings.LastIndex(host, ":"); at != -1 { + c.host += ":22" + } + + return +} + +var ( + initAuthMethodOnce sync.Once + signers []ssh.Signer +) + +// initAuthMethod initiates SSH authentication method. +func initAuthMethod() { + var ( + data []byte + signer ssh.Signer + ) + + // If there's a running SSH Agent, try to use its Private keys. + sock, err := net.Dial("unix", os.Getenv("SSH_AUTH_SOCK")) + if err == nil { + agentClient := agent.NewClient(sock) + signers, _ = agentClient.Signers() + } + + // Try to read user's SSH private keys form the standard paths. + files, _ := filepath.Glob(os.Getenv("HOME") + "/.ssh/id_*") + for _, file := range files { + if strings.HasSuffix(file, ".pub") { + continue // Skip public keys. + } + data, err = ioutil.ReadFile(file) + if err != nil { + continue + } + signer, err = ssh.ParsePrivateKey(data) + if err != nil { + continue + } + signers = append(signers, signer) + } +} diff --git a/sup.go b/sup.go index d815068..5e1528a 100644 --- a/sup.go +++ b/sup.go @@ -1,16 +1,11 @@ package sup import ( - "fmt" - "io" - "os" - "os/signal" - "strings" + "github.com/hashicorp/go-multierror" + "github.com/mikkeloscar/sshconfig" "sync" - "github.com/goware/prefixer" "github.com/pkg/errors" - "golang.org/x/crypto/ssh" ) const VERSION = "0.5" @@ -28,9 +23,7 @@ func New(conf *Supfile) (*Stackup, error) { } // Run runs set of commands on multiple hosts defined by network sequentially. -// TODO: This megamoth method needs a big refactor and should be split -// to multiple smaller methods. -func (sup *Stackup) Run(network *Network, envVars EnvList, commands ...*Command) error { +func (sup *Stackup) Run(sshConfigHosts []*sshconfig.SSHHost, network *Network, envVars EnvList, commands ...*Command) (err error) { if len(commands) == 0 { return errors.New("no commands to be run") } @@ -40,204 +33,78 @@ func (sup *Stackup) Run(network *Network, envVars EnvList, commands ...*Command) // Create clients for every host (either SSH or Localhost). var bastion *SSHClient if network.Bastion != "" { - bastion = &SSHClient{} - if err := bastion.Connect(network.Bastion); err != nil { + if bastion, err = NewSSHClient(network.Bastion, "bastion", 0, sshConfigHosts); err != nil { + return errors.Wrap(err, "create bastion") + } + + if err = bastion.Connect(); err != nil { return errors.Wrap(err, "connecting to bastion failed") } } - var wg sync.WaitGroup + wg := &sync.WaitGroup{} clientCh := make(chan Client, len(network.Hosts)) errCh := make(chan error, len(network.Hosts)) - for i, host := range network.Hosts { - wg.Add(1) - go func(i int, host string) { - defer wg.Done() - - // Localhost client. - if host == "localhost" { - local := &LocalhostClient{ - env: env + `export SUP_HOST="` + host + `";`, - } - if err := local.Connect(host); err != nil { - errCh <- errors.Wrap(err, "connecting to localhost failed") - return - } - clientCh <- local - return - } - - // SSH client. - remote := &SSHClient{ - env: env + `export SUP_HOST="` + host + `";`, - user: network.User, - color: Colors[i%len(Colors)], - } - - if bastion != nil { - if err := remote.ConnectWith(host, bastion.DialThrough); err != nil { - errCh <- errors.Wrap(err, "connecting to remote host through bastion failed") - return - } - } else { - if err := remote.Connect(host); err != nil { - errCh <- errors.Wrap(err, "connecting to remote host failed") - return - } - } - clientCh <- remote - }(i, host) + i := 0 + wg.Add(len(network.Hosts)) + for _, host := range network.Hosts { + i++ + go sup.networkHost(wg, clientCh, errCh, bastion, host, env, i, sshConfigHosts) } + wg.Wait() close(clientCh) close(errCh) maxLen := 0 - var clients []Client + var ( + clients []Client + deferRemoteClose []*SSHClient + ) + deferRemoteClose = []*SSHClient{} + for client := range clientCh { if remote, ok := client.(*SSHClient); ok { - defer remote.Close() + deferRemoteClose = append(deferRemoteClose, remote) } + _, prefixLen := client.Prefix() if prefixLen > maxLen { maxLen = prefixLen } clients = append(clients, client) } - for err := range errCh { + + defer func(deferRemoteClose []*SSHClient) { + for _, r := range deferRemoteClose { + if derr := r.Close(); derr != nil { + err = multierror.Append(err, derr) + } + } + }(deferRemoteClose) + + for err = range errCh { return errors.Wrap(err, "connecting to clients failed") } // Run command or run multiple commands defined by target sequentially. for _, cmd := range commands { + var tasks []*Task // Translate command into task(s). - tasks, err := sup.createTasks(cmd, clients, env) - if err != nil { + if tasks, err = sup.createTasks(cmd, clients, env); err != nil { return errors.Wrap(err, "creating task failed") } // Run tasks sequentially. for _, task := range tasks { - var writers []io.Writer - var wg sync.WaitGroup - - // Run tasks on the provided clients. - for _, c := range task.Clients { - var prefix string - var prefixLen int - if sup.prefix { - prefix, prefixLen = c.Prefix() - if len(prefix) < maxLen { // Left padding. - prefix = strings.Repeat(" ", maxLen-prefixLen) + prefix - } - } - - err := c.Run(task) - if err != nil { - return errors.Wrap(err, prefix+"task failed") - } - - // Copy over tasks's STDOUT. - wg.Add(1) - go func(c Client) { - defer wg.Done() - _, err := io.Copy(os.Stdout, prefixer.New(c.Stdout(), prefix)) - if err != nil && err != io.EOF { - // TODO: io.Copy() should not return io.EOF at all. - // Upstream bug? Or prefixer.WriteTo() bug? - fmt.Fprintf(os.Stderr, "%v", errors.Wrap(err, prefix+"reading STDOUT failed")) - } - }(c) - - // Copy over tasks's STDERR. - wg.Add(1) - go func(c Client) { - defer wg.Done() - _, err := io.Copy(os.Stderr, prefixer.New(c.Stderr(), prefix)) - if err != nil && err != io.EOF { - fmt.Fprintf(os.Stderr, "%v", errors.Wrap(err, prefix+"reading STDERR failed")) - } - }(c) - - writers = append(writers, c.Stdin()) - } - - // Copy over task's STDIN. - if task.Input != nil { - go func() { - writer := io.MultiWriter(writers...) - _, err := io.Copy(writer, task.Input) - if err != nil && err != io.EOF { - fmt.Fprintf(os.Stderr, "%v", errors.Wrap(err, "copying STDIN failed")) - } - // TODO: Use MultiWriteCloser (not in Stdlib), so we can writer.Close() instead? - for _, c := range clients { - c.WriteClose() - } - }() - } - - // Catch OS signals and pass them to all active clients. - trap := make(chan os.Signal, 1) - signal.Notify(trap, os.Interrupt) - go func() { - for { - select { - case sig, ok := <-trap: - if !ok { - return - } - for _, c := range task.Clients { - err := c.Signal(sig) - if err != nil { - fmt.Fprintf(os.Stderr, "%v", errors.Wrap(err, "sending signal failed")) - } - } - } - } - }() - - // Wait for all I/O operations first. - wg.Wait() - - // Make sure each client finishes the task, return on failure. - for _, c := range task.Clients { - wg.Add(1) - go func(c Client) { - defer wg.Done() - if err := c.Wait(); err != nil { - var prefix string - if sup.prefix { - var prefixLen int - prefix, prefixLen = c.Prefix() - if len(prefix) < maxLen { // Left padding. - prefix = strings.Repeat(" ", maxLen-prefixLen) + prefix - } - } - if e, ok := err.(*ssh.ExitError); ok && e.ExitStatus() != 15 { - // TODO: Store all the errors, and print them after Wait(). - fmt.Fprintf(os.Stderr, "%s%v\n", prefix, e) - os.Exit(e.ExitStatus()) - } - fmt.Fprintf(os.Stderr, "%s%v\n", prefix, err) - - // TODO: Shouldn't os.Exit(1) here. Instead, collect the exit statuses for later. - os.Exit(1) - } - }(c) + if err = task.do(sup.prefix, maxLen); err != nil { + return } - - // Wait for all commands to finish. - wg.Wait() - - // Stop catching signals for the currently active clients. - signal.Stop(trap) - close(trap) } } - return nil + return } func (sup *Stackup) Debug(value bool) { @@ -247,3 +114,46 @@ func (sup *Stackup) Debug(value bool) { func (sup *Stackup) Prefix(value bool) { sup.prefix = value } + +func (sup *Stackup) networkHost(wg *sync.WaitGroup, clientCh chan Client, errCh chan error, + bastion *SSHClient, host string, env string, i int, sshConfigHosts []*sshconfig.SSHHost) { + defer wg.Done() + + // Localhost client. + if host == "localhost" { + local := &LocalhostClient{ + env: env + `export SUP_HOST="` + host + `";`, + } + if err := local.Connect(); err != nil { + errCh <- errors.Wrap(err, "connecting to localhost failed") + return + } + clientCh <- local + return + } + + // SSH client. + var ( + remote *SSHClient + err error + ) + if remote, err = NewSSHClient(host, env, i, sshConfigHosts); err != nil { + errCh <- errors.Wrap(err, "create new ssh client") + return + } + + if bastion != nil { + if err = remote.ConnectWith(bastion.DialThrough); err != nil { + errCh <- errors.Wrap(err, "connecting to remote host through bastion failed") + return + } + + } else { + if err = remote.Connect(); err != nil { + errCh <- errors.Wrap(err, "connecting to remote host failed") + return + } + } + + clientCh <- remote +} diff --git a/supfile.go b/supfile.go index 2cf88b5..b219236 100644 --- a/supfile.go +++ b/supfile.go @@ -3,14 +3,12 @@ package sup import ( "bytes" "fmt" + "github.com/pkg/errors" + "gopkg.in/yaml.v2" "io" "os" "os/exec" "strings" - - "github.com/pkg/errors" - - "gopkg.in/yaml.v2" ) // Supfile represents the Stack Up configuration YAML file. @@ -22,16 +20,27 @@ type Supfile struct { Version string `yaml:"version"` } +/*func (s *Supfile) getPrivateKey(file string) (signer ssh.Signer, err error) { + var data []byte + + if strings.HasSuffix(file, ".pub") { + return // Skip public keys. + } + + if data, err = ioutil.ReadFile(file); err != nil { + return + } + + signer, err = ssh.ParsePrivateKey(data) + return +}*/ + // Network is group of hosts with extra custom env vars. type Network struct { Env EnvList `yaml:"env"` Inventory string `yaml:"inventory"` Hosts []string `yaml:"hosts"` - Bastion string `yaml:"bastion"` // Jump host for the environment - - // Should these live on Hosts too? We'd have to change []string to struct, even in Supfile. - User string // `yaml:"user"` - IdentityFile string // `yaml:"identity_file"` + Bastion string `yaml:"bastion"` // Jump to host for the environment } // Networks is a list of user-defined networks @@ -87,16 +96,14 @@ type Commands struct { cmds map[string]Command } -func (c *Commands) UnmarshalYAML(unmarshal func(interface{}) error) error { - err := unmarshal(&c.cmds) - if err != nil { - return err +func (c *Commands) UnmarshalYAML(unmarshal func(interface{}) error) (err error) { + if err = unmarshal(&c.cmds); err != nil { + return } var items yaml.MapSlice - err = unmarshal(&items) - if err != nil { - return err + if err = unmarshal(&items); err != nil { + return } c.Names = make([]string, len(items)) @@ -104,7 +111,7 @@ func (c *Commands) UnmarshalYAML(unmarshal func(interface{}) error) error { c.Names[i] = item.Key.(string) } - return nil + return } func (c *Commands) Get(name string) (Command, bool) { @@ -118,16 +125,14 @@ type Targets struct { targets map[string][]string } -func (t *Targets) UnmarshalYAML(unmarshal func(interface{}) error) error { - err := unmarshal(&t.targets) - if err != nil { +func (t *Targets) UnmarshalYAML(unmarshal func(interface{}) error) (err error) { + if err = unmarshal(&t.targets); err != nil { return err } var items yaml.MapSlice - err = unmarshal(&items) - if err != nil { - return err + if err = unmarshal(&items); err != nil { + return } t.Names = make([]string, len(items)) @@ -135,7 +140,7 @@ func (t *Targets) UnmarshalYAML(unmarshal func(interface{}) error) error { t.Names[i] = item.Key.(string) } - return nil + return } func (t *Targets) Get(name string) ([]string, bool) { @@ -179,7 +184,7 @@ func (e EnvList) Slice() []string { } func (e *EnvList) UnmarshalYAML(unmarshal func(interface{}) error) error { - items := []yaml.MapItem{} + var items []yaml.MapItem err := unmarshal(&items) if err != nil { @@ -255,7 +260,7 @@ type ErrUnsupportedSupfileVersion struct { } func (e ErrMustUpdate) Error() string { - return fmt.Sprintf("%v\n\nPlease update sup by `go get -u github.com/pressly/sup/cmd/sup`", e.Msg) + return fmt.Sprintf("%v\n\nPlease update sup by `go get -u github.com/NovikovRoman/sup/cmd/sup`", e.Msg) } func (e ErrUnsupportedSupfileVersion) Error() string { @@ -313,7 +318,7 @@ func NewSupfile(data []byte) (*Supfile, error) { } } if warning != "" { - fmt.Fprintf(os.Stderr, warning) + _, _ = fmt.Fprintf(os.Stderr, warning) } fallthrough @@ -329,29 +334,32 @@ func NewSupfile(data []byte) (*Supfile, error) { // ParseInventory runs the inventory command, if provided, and appends // the command's output lines to the manually defined list of hosts. -func (n Network) ParseInventory() ([]string, error) { +func (n Network) ParseInventory() (hosts []string, err error) { + var ( + host string + output []byte + ) if n.Inventory == "" { - return nil, nil + return } cmd := exec.Command("/bin/sh", "-c", n.Inventory) cmd.Env = os.Environ() cmd.Env = append(cmd.Env, n.Env.Slice()...) cmd.Stderr = os.Stderr - output, err := cmd.Output() - if err != nil { - return nil, err + + if output, err = cmd.Output(); err != nil { + return } - var hosts []string buf := bytes.NewBuffer(output) for { - host, err := buf.ReadString('\n') - if err != nil { + if host, err = buf.ReadString('\n'); err != nil { if err == io.EOF { + err = nil break } - return nil, err + return } host = strings.TrimSpace(host) @@ -362,5 +370,6 @@ func (n Network) ParseInventory() ([]string, error) { hosts = append(hosts, host) } - return hosts, nil + + return } diff --git a/tar.go b/tar.go index 10582f5..6310e05 100644 --- a/tar.go +++ b/tar.go @@ -20,12 +20,12 @@ func RemoteTarCommand(dir string) string { } func LocalTarCmdArgs(path, exclude string) []string { - args := []string{} + var args []string // Added pattens to exclude from tar compress excludes := strings.Split(exclude, ",") - for _, exclude := range excludes { - trimmed := strings.TrimSpace(exclude) + for _, exc := range excludes { + trimmed := strings.TrimSpace(exc) if trimmed != "" { args = append(args, `--exclude=`+trimmed) } @@ -37,17 +37,16 @@ func LocalTarCmdArgs(path, exclude string) []string { // NewTarStreamReader creates a tar stream reader from a local path. // TODO: Refactor. Use "archive/tar" instead. -func NewTarStreamReader(cwd, path, exclude string) (io.Reader, error) { +func NewTarStreamReader(cwd, path, exclude string) (stdout io.Reader, err error) { cmd := exec.Command("tar", LocalTarCmdArgs(path, exclude)...) cmd.Dir = cwd - stdout, err := cmd.StdoutPipe() - if err != nil { - return nil, errors.Wrap(err, "tar: stdout pipe failed") - } - if err := cmd.Start(); err != nil { - return nil, errors.Wrap(err, "tar: starting cmd failed") + if stdout, err = cmd.StdoutPipe(); err != nil { + err = errors.Wrap(err, "tar: stdout pipe failed") + + } else if err = cmd.Start(); err != nil { + err = errors.Wrap(err, "tar: starting cmd failed") } - return stdout, nil + return } diff --git a/task.go b/task.go index eebc3c7..044bb07 100644 --- a/task.go +++ b/task.go @@ -2,11 +2,18 @@ package sup import ( "fmt" + "github.com/goware/prefixer" + "github.com/hashicorp/go-multierror" + "github.com/pkg/errors" + "golang.org/x/crypto/ssh" "io" "io/ioutil" + "log" "os" - - "github.com/pkg/errors" + "os/signal" + "strconv" + "strings" + "sync" ) // Task represents a set of commands to be run. @@ -17,23 +24,31 @@ type Task struct { TTY bool } -func (sup *Stackup) createTasks(cmd *Command, clients []Client, env string) ([]*Task, error) { - var tasks []*Task +func (sup *Stackup) createTasks(cmd *Command, clients []Client, env string) (tasks []*Task, err error) { + var ( + cwd string + uploadFile string + uploadTarReader io.Reader + f *os.File + data []byte + ) - cwd, err := os.Getwd() - if err != nil { - return nil, errors.Wrap(err, "resolving CWD failed") + if cwd, err = os.Getwd(); err != nil { + err = errors.Wrap(err, "resolving CWD failed") + return } // Anything to upload? + tasks = []*Task{} for _, upload := range cmd.Upload { - uploadFile, err := ResolveLocalPath(cwd, upload.Src, env) - if err != nil { - return nil, errors.Wrap(err, "upload: "+upload.Src) + if uploadFile, err = ResolveLocalPath(cwd, upload.Src, env); err != nil { + err = errors.Wrap(err, "upload: "+upload.Src) + return } - uploadTarReader, err := NewTarStreamReader(cwd, uploadFile, upload.Exc) - if err != nil { - return nil, errors.Wrap(err, "upload: "+upload.Src) + + if uploadTarReader, err = NewTarStreamReader(cwd, uploadFile, upload.Exc); err != nil { + err = errors.Wrap(err, "upload: "+upload.Src) + return } task := Task{ @@ -45,47 +60,58 @@ func (sup *Stackup) createTasks(cmd *Command, clients []Client, env string) ([]* if cmd.Once { task.Clients = []Client{clients[0]} tasks = append(tasks, &task) - } else if cmd.Serial > 0 { + continue + } + + if cmd.Serial > 0 { // Each "serial" task client group is executed sequentially. for i := 0; i < len(clients); i += cmd.Serial { j := i + cmd.Serial if j > len(clients) { j = len(clients) } - copy := task - copy.Clients = clients[i:j] - tasks = append(tasks, ©) + + copyTask := task + copyTask.Clients = clients[i:j] + tasks = append(tasks, ©Task) } - } else { - task.Clients = clients - tasks = append(tasks, &task) + + continue } + + task.Clients = clients + tasks = append(tasks, &task) } // Script. Read the file as a multiline input command. if cmd.Script != "" { - f, err := os.Open(cmd.Script) - if err != nil { - return nil, errors.Wrap(err, "can't open script") + if f, err = os.Open(cmd.Script); err != nil { + err = errors.Wrap(err, "can't open script") + return } - data, err := ioutil.ReadAll(f) - if err != nil { - return nil, errors.Wrap(err, "can't read script") + + if data, err = ioutil.ReadAll(f); err != nil { + err = errors.Wrap(err, "can't read script") + return } task := Task{ Run: string(data), TTY: true, } + if sup.debug { task.Run = "set -x;" + task.Run } + if cmd.Stdin { task.Input = os.Stdin } + if cmd.Once { task.Clients = []Client{clients[0]} tasks = append(tasks, &task) + } else if cmd.Serial > 0 { // Each "serial" task client group is executed sequentially. for i := 0; i < len(clients); i += cmd.Serial { @@ -93,10 +119,11 @@ func (sup *Stackup) createTasks(cmd *Command, clients []Client, env string) ([]* if j > len(clients) { j = len(clients) } - copy := task - copy.Clients = clients[i:j] - tasks = append(tasks, ©) + copyTask := task + copyTask.Clients = clients[i:j] + tasks = append(tasks, ©Task) } + } else { task.Clients = clients tasks = append(tasks, &task) @@ -108,18 +135,22 @@ func (sup *Stackup) createTasks(cmd *Command, clients []Client, env string) ([]* local := &LocalhostClient{ env: env + `export SUP_HOST="localhost";`, } - local.Connect("localhost") + + _ = local.Connect() task := &Task{ Run: cmd.Local, Clients: []Client{local}, TTY: true, } + if sup.debug { task.Run = "set -x;" + task.Run } + if cmd.Stdin { task.Input = os.Stdin } + tasks = append(tasks, task) } @@ -129,15 +160,19 @@ func (sup *Stackup) createTasks(cmd *Command, clients []Client, env string) ([]* Run: cmd.Run, TTY: true, } + if sup.debug { task.Run = "set -x;" + task.Run } + if cmd.Stdin { task.Input = os.Stdin } + if cmd.Once { task.Clients = []Client{clients[0]} tasks = append(tasks, &task) + } else if cmd.Serial > 0 { // Each "serial" task client group is executed sequentially. for i := 0; i < len(clients); i += cmd.Serial { @@ -145,17 +180,179 @@ func (sup *Stackup) createTasks(cmd *Command, clients []Client, env string) ([]* if j > len(clients) { j = len(clients) } - copy := task - copy.Clients = clients[i:j] - tasks = append(tasks, ©) + copyTask := task + copyTask.Clients = clients[i:j] + tasks = append(tasks, ©Task) } + } else { task.Clients = clients tasks = append(tasks, &task) } } - return tasks, nil + return +} +func (t *Task) formatClientPrefix(c Client, len int) string { + p, _ := c.Prefix() + return fmt.Sprintf("%"+strconv.Itoa(len)+"s", p) +} + +func (t *Task) do(onPrefix bool, maxLen int) (err error) { + var writers []io.Writer + + // Run tasks on the provided clients. + wg := &sync.WaitGroup{} + for _, c := range t.Clients { + var prefix string + var prefixLen int + if onPrefix { + prefix, prefixLen = c.Prefix() + if len(prefix) < maxLen { // Left padding. + prefix = strings.Repeat(" ", maxLen-prefixLen) + prefix + } + } + + if err = c.Run(t); err != nil { + return errors.Wrap(err, prefix+"task failed") + } + + // Copy over task's STDOUT. + wg.Add(1) + go func(c Client) { + defer wg.Done() + _, derr := io.Copy(os.Stdout, prefixer.New(c.Stdout(), prefix)) + if derr != nil && derr != io.EOF { + // TODO: io.Copy() should not return io.EOF at all. + // Upstream bug? Or prefixer.WriteTo() bug? + _, _ = fmt.Fprintf(os.Stderr, "%v", errors.Wrap(derr, prefix+"reading STDOUT failed")) + } + }(c) + + // Copy over task's STDERR. + wg.Add(1) + go func(c Client) { + defer wg.Done() + _, derr := io.Copy(os.Stderr, prefixer.New(c.Stderr(), prefix)) + if derr != nil && derr != io.EOF { + _, _ = fmt.Fprintf(os.Stderr, "%v", errors.Wrap(derr, prefix+"reading STDERR failed")) + } + }(c) + + writers = append(writers, c.Stdin()) + } + + // Copy over task's STDIN. + if t.Input != nil { + go t.copyStdin(writers) + } + + // Catch OS signals and pass them to all active clients. + trap := make(chan os.Signal, 1) + signal.Notify(trap, os.Interrupt) + go t.catchSignals(trap) + + // Wait for all I/O operations first. + wg.Wait() + + // Make sure each client finishes the task, return on failure. + t.clientsFinish(onPrefix, maxLen) + + // Stop catching signals for the currently active clients. + signal.Stop(trap) + close(trap) + return +} + +func (t *Task) catchSignals(trap chan os.Signal) { + var err error + + for { + select { + case sig, ok := <-trap: + if !ok { + return + } + + for _, c := range t.Clients { + if err = c.Signal(sig); err != nil { + _, err = fmt.Fprintf(os.Stderr, "%v", errors.Wrap(err, "sending signal failed")) + if err != nil { + log.Println("catchSignals Fprintf:", err) + } + } + } + } + } +} + +func (t *Task) copyStdin(writers []io.Writer) { + var err error + + writer := io.MultiWriter(writers...) + _, err = io.Copy(writer, t.Input) + if err != nil && err != io.EOF { + _, err = fmt.Fprintf(os.Stderr, "%v", errors.Wrap(err, "copying STDIN failed")) + if err != nil { + log.Println("copyStdin Fprintf:", err) + } + } + + err = nil + for _, c := range t.Clients { + if e := c.WriteClose(); e != nil { + err = multierror.Append(err, e) + } + } + + if err != nil { + _, err = fmt.Fprintf(os.Stderr, "%v", errors.Wrap(err, "failed to close clients")) + if err != nil { + log.Println("copyStdin Fprintf:", err) + } + } +} + +func (t *Task) clientsFinish(onPrefix bool, len int) { + wg := &sync.WaitGroup{} + + for _, c := range t.Clients { + wg.Add(1) + go func(c Client) { + var err error + defer wg.Done() + + if err = c.Wait(); err == nil { + return + } + + prefix := "" + if onPrefix { + prefix = t.formatClientPrefix(c, len) + } + + if e, ok := err.(*ssh.ExitError); ok && e.ExitStatus() != 15 { + // TODO: Store all the errors, and print them after Wait(). + _, err = fmt.Fprintf(os.Stderr, "%s%v\n", prefix, e) + if err != nil { + log.Println("clientsFinish Fprintf:", err) + } + + os.Exit(e.ExitStatus()) + } + + _, err = fmt.Fprintf(os.Stderr, "%s%v\n", prefix, err) + if err != nil { + log.Println("clientsFinish Fprintf:", err) + } + + // TODO: Shouldn't os.Exit(1) here. Instead, collect the exit statuses for later. + os.Exit(1) + + }(c) + } + + wg.Wait() } type ErrTask struct { From 7acd302ee20cc6df8d0041beeaa1b8901de34e4e Mon Sep 17 00:00:00 2001 From: Novikov Roman Date: Mon, 21 Feb 2022 10:35:12 +0300 Subject: [PATCH 2/7] fix alternatives supfile --- cmd/sup/main.go | 31 +++++++++++++++++++++++-------- 1 file changed, 23 insertions(+), 8 deletions(-) diff --git a/cmd/sup/main.go b/cmd/sup/main.go index d1b30b2..3d8d048 100644 --- a/cmd/sup/main.go +++ b/cmd/sup/main.go @@ -254,18 +254,33 @@ func main() { return } - if supfile == "" { - supfile = "./Supfile" + var ( + supfileVariants []string + firstErr error + ) + if supfile != "" { + supfileVariants = append(supfileVariants, supfile) } - if data, err = ioutil.ReadFile(resolvePath(supfile)); err != nil { - firstErr := err - data, err = ioutil.ReadFile("./Supfile.yml") // Alternative to ./Supfile. - if err != nil { + // alternatives + supfileVariants = append(supfileVariants, []string{"./Supfile", "./Supfile.yml", "./Supfile.yaml"}...) + + for i, fn := range supfileVariants { + if data, err = ioutil.ReadFile(resolvePath(fn)); err == nil { + break + } + + if i == 0 { + firstErr = err + } + } + + if err != nil { + if supfile != "" { _, _ = fmt.Fprintln(os.Stderr, firstErr) - _, _ = fmt.Fprintln(os.Stderr, err) - os.Exit(1) } + _, _ = fmt.Fprintln(os.Stderr, err) + os.Exit(1) } if conf, err = sup.NewSupfile(data); err != nil { From daa382cbc5b36e7e9971a3549b2a1694fad5f912 Mon Sep 17 00:00:00 2001 From: AlexMikhalev Date: Tue, 10 Sep 2024 14:26:50 +0100 Subject: [PATCH 3/7] made it work with current ssh Signed-off-by: AlexMikhalev --- ssh.go | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/ssh.go b/ssh.go index d535e4b..84d5101 100644 --- a/ssh.go +++ b/ssh.go @@ -2,11 +2,6 @@ package sup import ( "fmt" - "github.com/hashicorp/go-multierror" - "github.com/mikkeloscar/sshconfig" - "github.com/pkg/errors" - "golang.org/x/crypto/ssh" - "golang.org/x/crypto/ssh/agent" "io" "io/ioutil" "net" @@ -15,6 +10,12 @@ import ( "path/filepath" "strings" "sync" + + "github.com/hashicorp/go-multierror" + "github.com/mikkeloscar/sshconfig" + "github.com/pkg/errors" + "golang.org/x/crypto/ssh" + "golang.org/x/crypto/ssh/agent" ) // SSHClient is a wrapper over the SSH connection/sessions. @@ -326,16 +327,17 @@ func (c *SSHClient) parseHost(host string) (err error) { c.user = u.Username } - if strings.Index(host, "/") != -1 { + if strings.Contains(host, "/") { err = ErrConnect{User: c.user, Host: host, Reason: "unexpected slash in the host URL"} return } // Add default port, if not set - if at := strings.LastIndex(host, ":"); at != -1 { - c.host += ":22" + if !strings.Contains(host, ":") { + host += ":22" } + c.host = host return } From 5c9217a963d61724ad0eaaf4a7efddf841ebe7cb Mon Sep 17 00:00:00 2001 From: AlexMikhalev Date: Tue, 10 Sep 2024 18:16:36 +0100 Subject: [PATCH 4/7] made it work with current ssh go.mod Signed-off-by: AlexMikhalev --- .gitignore | 2 ++ go.mod | 2 +- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/.gitignore b/.gitignore index a7b5801..74b0047 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,5 @@ bin/ .idea/ *.sw? + +.vscode/ \ No newline at end of file diff --git a/go.mod b/go.mod index 118bcab..a45ff77 100644 --- a/go.mod +++ b/go.mod @@ -1,4 +1,4 @@ -module github.com/NovikovRoman/sup +module github.com/AlexMikhalev/sup go 1.17 From a5af5e71d4f6fa71bf6b9b55c79c613cab94fba5 Mon Sep 17 00:00:00 2001 From: Dr Alexander Mikhalev Date: Wed, 11 Dec 2024 13:46:49 +0000 Subject: [PATCH 5/7] Update main.go --- cmd/sup/main.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cmd/sup/main.go b/cmd/sup/main.go index 3d8d048..504ab8c 100644 --- a/cmd/sup/main.go +++ b/cmd/sup/main.go @@ -12,7 +12,7 @@ import ( "text/tabwriter" "time" - "github.com/NovikovRoman/sup" + "github.com/AlexMikhalev/sup" "github.com/mikkeloscar/sshconfig" "github.com/pkg/errors" ) From 45189106e6568d7408a495d56982592ee5eb4350 Mon Sep 17 00:00:00 2001 From: Alex Mikhalev Date: Sun, 29 Dec 2024 22:18:05 +0000 Subject: [PATCH 6/7] Local commands work --- example_simple/Supfile.yml | 23 ++++++ localhost.go | 53 +++++++++---- localhost_darwin.go | 24 ++++++ localhost_test.go | 148 ++++++++++++++++++++++++++++++++++ ssh.go | 76 ++++++++++-------- ssh_test.go | 159 +++++++++++++++++++++++++++++++++++++ supfile.go | 21 ++--- task.go | 15 ++-- task_test.go | 135 +++++++++++++++++++++++++++++++ 9 files changed, 589 insertions(+), 65 deletions(-) create mode 100644 example_simple/Supfile.yml create mode 100644 localhost_darwin.go create mode 100644 localhost_test.go create mode 100644 ssh_test.go create mode 100644 task_test.go diff --git a/example_simple/Supfile.yml b/example_simple/Supfile.yml new file mode 100644 index 0000000..4fc1304 --- /dev/null +++ b/example_simple/Supfile.yml @@ -0,0 +1,23 @@ +networks: + dev: + hosts: + - alex@bigbox +commands: + bash: + desc: Interactive Bash on all hosts + stdin: true + run: bash + ping: + desc: Print uname and current date/time. + run: uname -a; date + upload: + desc: Upload dist files to all hosts + upload: + - src: ./dist + dst: /tmp/ + build: + desc: build + local: make build + test: + desc: test + local: make test diff --git a/localhost.go b/localhost.go index 5b9cd20..00e20aa 100644 --- a/localhost.go +++ b/localhost.go @@ -2,11 +2,13 @@ package sup import ( "fmt" - "github.com/pkg/errors" "io" "os" "os/exec" "os/user" + "strings" + + "github.com/pkg/errors" ) // LocalhostClient is a wrapper over the SSH connection/sessions. @@ -32,39 +34,60 @@ func (c *LocalhostClient) Connect() (err error) { func (c *LocalhostClient) Run(task *Task) (err error) { if c.running { - return fmt.Errorf("Command already running. ") + return fmt.Errorf("Command already running") } - cmd := exec.Command("bash", "-c", c.env+task.Run) - c.cmd = cmd + // Create command directly without shell + cmd := exec.Command("make", "build") + + // Set up environment variables + if c.env != "" { + cmd.Env = append(os.Environ(), strings.Split(strings.TrimSuffix(c.env, ";"), ";")...) + } + + // Set up pipes for stdin, stdout, stderr + if c.stdin, err = cmd.StdinPipe(); err != nil { + return errors.Wrap(err, "failed to create stdin pipe") + } if c.stdout, err = cmd.StdoutPipe(); err != nil { - return + return errors.Wrap(err, "failed to create stdout pipe") } if c.stderr, err = cmd.StderrPipe(); err != nil { - return + return errors.Wrap(err, "failed to create stderr pipe") } - if c.stdin, err = cmd.StdinPipe(); err != nil { - return - } + // Set working directory to current directory + cmd.Dir = "." - if err = c.cmd.Start(); err != nil { + // Start the command + if err = cmd.Start(); err != nil { return ErrTask{task, err.Error()} } + // Handle input if provided + if task.Input != nil { + if _, err = io.Copy(c.stdin, task.Input); err != nil { + return errors.Wrap(err, "copying input failed") + } + if err = c.stdin.Close(); err != nil { + return errors.Wrap(err, "closing input failed") + } + } + + c.cmd = cmd c.running = true - return + return nil } -func (c *LocalhostClient) Wait() (err error) { +func (c *LocalhostClient) Wait() error { if !c.running { - return fmt.Errorf("Trying to wait on stopped command. ") + return fmt.Errorf("Trying to wait on stopped command") } - err = c.cmd.Wait() + err := c.cmd.Wait() c.running = false - return + return err } func (c *LocalhostClient) Close() error { diff --git a/localhost_darwin.go b/localhost_darwin.go new file mode 100644 index 0000000..ec5a0e2 --- /dev/null +++ b/localhost_darwin.go @@ -0,0 +1,24 @@ +//go:build darwin +// +build darwin + +package sup + +import ( + "syscall" +) + +func getProcAttrs(tty bool) *syscall.SysProcAttr { + attrs := &syscall.SysProcAttr{ + Credential: &syscall.Credential{ + Uid: uint32(syscall.Getuid()), + Gid: uint32(syscall.Getgid()), + }, + } + + if tty { + attrs.Setpgid = true + attrs.Setsid = true + } + + return attrs +} diff --git a/localhost_test.go b/localhost_test.go new file mode 100644 index 0000000..bf9a5ab --- /dev/null +++ b/localhost_test.go @@ -0,0 +1,148 @@ +package sup + +import ( + "bytes" + "io" + "os" + "os/exec" + "runtime" + "testing" + "time" +) + +func TestLocalhostClient_Run(t *testing.T) { + if runtime.GOOS == "windows" { + t.Skip("Skipping test on Windows") + } + + // Verify that basic shell commands are available + if _, err := exec.LookPath("sh"); err != nil { + t.Skip("sh not available in PATH") + } + + tests := []struct { + name string + task *Task + env string + wantErr bool + wantOutput string + interactive bool + }{ + { + name: "simple command", + task: &Task{ + Run: "printf 'hello\\n'", + TTY: false, + Input: nil, + }, + wantErr: false, + wantOutput: "hello\n", + interactive: false, + }, + { + name: "command with input", + task: &Task{ + Run: "cat", + TTY: false, + Input: bytes.NewBufferString("test input"), + }, + wantErr: false, + wantOutput: "test input", + interactive: false, + }, + { + name: "command with environment variables", + task: &Task{ + Run: "printf \"$TEST_VAR\\n\"", + TTY: false, + Input: nil, + }, + env: "export TEST_VAR=test_value;", + wantErr: false, + wantOutput: "test_value\n", + interactive: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + client := &LocalhostClient{ + env: tt.env, + } + + if err := client.Connect(); err != nil { + t.Fatalf("Failed to connect: %v", err) + } + + if err := client.Run(tt.task); (err != nil) != tt.wantErr { + t.Errorf("LocalhostClient.Run() error = %v, wantErr %v", err, tt.wantErr) + return + } + + if !tt.interactive && tt.wantOutput != "" { + var output []byte + var err error + + // Use a channel to handle timeout + done := make(chan bool) + go func() { + output, err = io.ReadAll(client.stdout) + done <- true + }() + + // Wait with timeout + select { + case <-done: + if err != nil { + t.Fatalf("Failed to read output: %v", err) + } + case <-time.After(2 * time.Second): + t.Fatal("Timeout waiting for command output") + } + + if string(output) != tt.wantOutput { + t.Errorf("LocalhostClient.Run() output = %q, want %q", string(output), tt.wantOutput) + } + } + + if err := client.Wait(); err != nil && !tt.wantErr { + t.Errorf("LocalhostClient.Wait() error = %v", err) + } + }) + } +} + +func TestLocalhostClient_Signal(t *testing.T) { + if runtime.GOOS == "windows" { + t.Skip("Skipping test on Windows") + } + + client := &LocalhostClient{} + if err := client.Connect(); err != nil { + t.Fatalf("Failed to connect: %v", err) + } + + task := &Task{ + Run: "sleep 10", + TTY: false, + Input: nil, + } + + if err := client.Run(task); err != nil { + t.Fatalf("Failed to run task: %v", err) + } + + // Give the process time to start + time.Sleep(100 * time.Millisecond) + + // Send interrupt signal + if err := client.Signal(os.Interrupt); err != nil { + t.Errorf("LocalhostClient.Signal() error = %v", err) + } + + // Wait should return with an error due to the interrupt + err := client.Wait() + if err == nil { + t.Error("Expected error from Wait() after interrupt") + } +} diff --git a/ssh.go b/ssh.go index 84d5101..351cdcd 100644 --- a/ssh.go +++ b/ssh.go @@ -148,56 +148,55 @@ func (c *SSHClient) getPrivateKey(file string) (*ssh.Signer, error) { } // Run runs the task.Run command remotely on c.host. -func (c *SSHClient) Run(task *Task) error { +func (c *SSHClient) Run(task *Task) (err error) { if c.running { return errors.New("Session already running") } - if c.sessOpened { - return errors.New("Session already connected") + if err = c.openSession(); err != nil { + return err } - sess, err := c.conn.NewSession() - if err != nil { - return err + // Handle interactive sessions + if task.TTY { + modes := ssh.TerminalModes{ + ssh.ECHO: 1, + ssh.TTY_OP_ISPEED: 14400, + ssh.TTY_OP_OSPEED: 14400, + } + + if err = c.sess.RequestPty("xterm", 40, 80, modes); err != nil { + return errors.Wrap(err, "request for pseudo terminal failed") + } } - c.remoteStdin, err = sess.StdinPipe() - if err != nil { + if c.remoteStdin, err = c.sess.StdinPipe(); err != nil { return err } - - c.remoteStdout, err = sess.StdoutPipe() - if err != nil { + if c.remoteStdout, err = c.sess.StdoutPipe(); err != nil { return err } - - c.remoteStderr, err = sess.StderrPipe() - if err != nil { + if c.remoteStderr, err = c.sess.StderrPipe(); err != nil { return err } - if task.TTY { - // Set up terminal modes - modes := ssh.TerminalModes{ - ssh.ECHO: 0, // disable echoing - ssh.TTY_OP_ISPEED: 14400, // input speed = 14.4k baud - ssh.TTY_OP_OSPEED: 14400, // output speed = 14.4k baud + if task.Input != nil { + if err = c.sess.Start(task.Run); err != nil { + return ErrTask{task, err.Error()} } - // Request pseudo terminal - if err = sess.RequestPty("xterm", 80, 40, modes); err != nil { - return ErrTask{task, fmt.Sprintf("request for pseudo terminal failed: %s", err)} + if _, err = io.Copy(c.remoteStdin, task.Input); err != nil { + return errors.Wrap(err, "copying input failed") + } + if err = c.remoteStdin.Close(); err != nil { + return errors.Wrap(err, "closing input failed") + } + } else { + if err = c.sess.Start(c.env + task.Run); err != nil { + return ErrTask{task, err.Error()} } } - // Start the remote command. - if err = sess.Start(c.env + task.Run); err != nil { - return ErrTask{task, err.Error()} - } - - c.sess = sess - c.sessOpened = true c.running = true - return nil + return } // Wait waits until the remote command finishes and exits. @@ -377,3 +376,18 @@ func initAuthMethod() { signers = append(signers, signer) } } + +func (c *SSHClient) openSession() error { + if c.sessOpened { + return errors.New("Session already connected") + } + + sess, err := c.conn.NewSession() + if err != nil { + return err + } + + c.sess = sess + c.sessOpened = true + return nil +} diff --git a/ssh_test.go b/ssh_test.go new file mode 100644 index 0000000..02e6859 --- /dev/null +++ b/ssh_test.go @@ -0,0 +1,159 @@ +package sup + +import ( + "bytes" + "io" + "os" + "testing" +) + +type mockClient struct { + stdin io.WriteCloser + stdout io.Reader + stderr io.Reader + pty bool +} + +func (c *mockClient) Connect() error { + return nil +} + +func (c *mockClient) Run(task *Task) error { + if task.TTY { + c.pty = true + } + if task.Input != nil { + c.stdin = &mockWriteCloser{Buffer: bytes.NewBuffer(nil)} + _, err := io.Copy(c.stdin, task.Input) + if err != nil { + return err + } + c.stdin.(*mockWriteCloser).Close() + } + return nil +} + +func (c *mockClient) Wait() error { + return nil +} + +func (c *mockClient) Close() error { + return nil +} + +func (c *mockClient) Prefix() (string, int) { + return "mock", 4 +} + +func (c *mockClient) Write(p []byte) (n int, err error) { + if c.stdin != nil { + return c.stdin.Write(p) + } + return len(p), nil +} + +func (c *mockClient) WriteClose() error { + if c.stdin != nil { + return c.stdin.Close() + } + return nil +} + +func (c *mockClient) Stdin() io.WriteCloser { + return c.stdin +} + +func (c *mockClient) Stderr() io.Reader { + return c.stderr +} + +func (c *mockClient) Stdout() io.Reader { + return c.stdout +} + +func (c *mockClient) Signal(sig os.Signal) error { + return nil +} + +type mockWriteCloser struct { + *bytes.Buffer + closed bool +} + +func (m *mockWriteCloser) Close() error { + m.closed = true + return nil +} + +func TestClient_Run(t *testing.T) { + tests := []struct { + name string + task *Task + wantErr bool + wantPTY bool + wantInput string + interactive bool + }{ + { + name: "non-interactive command", + task: &Task{ + Run: "echo 'hello'", + TTY: false, + Input: nil, + }, + wantErr: false, + wantPTY: false, + interactive: false, + }, + { + name: "interactive command", + task: &Task{ + Run: "bash", + TTY: true, + Input: nil, + }, + wantErr: false, + wantPTY: true, + interactive: true, + }, + { + name: "command with input", + task: &Task{ + Run: "cat", + TTY: false, + Input: bytes.NewBufferString("test input"), + }, + wantErr: false, + wantPTY: false, + wantInput: "test input", + interactive: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + client := &mockClient{ + stdout: bytes.NewBuffer([]byte("mock output")), + stderr: bytes.NewBuffer(nil), + } + + if err := client.Run(tt.task); (err != nil) != tt.wantErr { + t.Errorf("Client.Run() error = %v, wantErr %v", err, tt.wantErr) + } + + if tt.interactive && !client.pty { + t.Error("Expected PTY to be requested for interactive command") + } + + if tt.wantInput != "" { + input := client.stdin.(*mockWriteCloser) + if input.String() != tt.wantInput { + t.Errorf("Expected input %q, got %q", tt.wantInput, input.String()) + } + if !input.closed { + t.Error("Expected input to be closed") + } + } + }) + } +} diff --git a/supfile.go b/supfile.go index b219236..e77f7aa 100644 --- a/supfile.go +++ b/supfile.go @@ -3,12 +3,13 @@ package sup import ( "bytes" "fmt" - "github.com/pkg/errors" - "gopkg.in/yaml.v2" "io" "os" "os/exec" "strings" + + "github.com/pkg/errors" + "gopkg.in/yaml.v2" ) // Supfile represents the Stack Up configuration YAML file. @@ -275,7 +276,7 @@ func NewSupfile(data []byte) (*Supfile, error) { return nil, err } - // API backward compatibility. Will be deprecated in v1.0. + // API backward compatibility switch conf.Version { case "": conf.Version = "0.1" @@ -286,17 +287,9 @@ func NewSupfile(data []byte) (*Supfile, error) { if cmd.RunOnce { return nil, ErrMustUpdate{"command.run_once is not supported in Supfile v" + conf.Version} } - } - fallthrough - - case "0.2": - for _, cmd := range conf.Commands.cmds { if cmd.Once { return nil, ErrMustUpdate{"command.once is not supported in Supfile v" + conf.Version} } - if cmd.Local != "" { - return nil, ErrMustUpdate{"command.local is not supported in Supfile v" + conf.Version} - } if cmd.Serial != 0 { return nil, ErrMustUpdate{"command.serial is not supported in Supfile v" + conf.Version} } @@ -308,7 +301,7 @@ func NewSupfile(data []byte) (*Supfile, error) { } fallthrough - case "0.3": + case "0.2": var warning string for key, cmd := range conf.Commands.cmds { if cmd.RunOnce { @@ -320,10 +313,10 @@ func NewSupfile(data []byte) (*Supfile, error) { if warning != "" { _, _ = fmt.Fprintf(os.Stderr, warning) } - fallthrough - case "0.4", "0.5": + case "0.3", "0.4", "0.5": + // All good default: return nil, ErrUnsupportedSupfileVersion{"unsupported Supfile version " + conf.Version} diff --git a/task.go b/task.go index 044bb07..7721064 100644 --- a/task.go +++ b/task.go @@ -2,10 +2,6 @@ package sup import ( "fmt" - "github.com/goware/prefixer" - "github.com/hashicorp/go-multierror" - "github.com/pkg/errors" - "golang.org/x/crypto/ssh" "io" "io/ioutil" "log" @@ -14,6 +10,11 @@ import ( "strconv" "strings" "sync" + + "github.com/goware/prefixer" + "github.com/hashicorp/go-multierror" + "github.com/pkg/errors" + "golang.org/x/crypto/ssh" ) // Task represents a set of commands to be run. @@ -136,7 +137,10 @@ func (sup *Stackup) createTasks(cmd *Command, clients []Client, env string) (tas env: env + `export SUP_HOST="localhost";`, } - _ = local.Connect() + if err = local.Connect(); err != nil { + return nil, errors.Wrap(err, "connecting to localhost failed") + } + task := &Task{ Run: cmd.Local, Clients: []Client{local}, @@ -193,6 +197,7 @@ func (sup *Stackup) createTasks(cmd *Command, clients []Client, env string) (tas return } + func (t *Task) formatClientPrefix(c Client, len int) string { p, _ := c.Prefix() return fmt.Sprintf("%"+strconv.Itoa(len)+"s", p) diff --git a/task_test.go b/task_test.go new file mode 100644 index 0000000..d4e99c0 --- /dev/null +++ b/task_test.go @@ -0,0 +1,135 @@ +package sup + +import ( + "bytes" + "testing" +) + +func TestStackup_createTasks(t *testing.T) { + tests := []struct { + name string + command *Command + wantErr bool + wantTask bool + local bool + }{ + { + name: "local command", + command: &Command{ + Name: "test", + Local: "echo 'hello'", + }, + wantErr: false, + wantTask: true, + local: true, + }, + { + name: "remote command", + command: &Command{ + Name: "test", + Run: "echo 'hello'", + }, + wantErr: false, + wantTask: true, + local: false, + }, + { + name: "local and remote command", + command: &Command{ + Name: "test", + Local: "echo 'local'", + Run: "echo 'remote'", + }, + wantErr: false, + wantTask: true, + local: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + sup := &Stackup{ + conf: &Supfile{ + Version: "0.5", + }, + } + + clients := []Client{ + &mockClient{ + stdout: bytes.NewBuffer([]byte("mock output")), + stderr: bytes.NewBuffer(nil), + }, + } + + tasks, err := sup.createTasks(tt.command, clients, "") + if (err != nil) != tt.wantErr { + t.Errorf("Stackup.createTasks() error = %v, wantErr %v", err, tt.wantErr) + return + } + + if tt.wantTask && len(tasks) == 0 { + t.Error("Stackup.createTasks() expected tasks but got none") + return + } + + if tt.local { + found := false + for _, task := range tasks { + for _, client := range task.Clients { + if _, ok := client.(*LocalhostClient); ok { + found = true + break + } + } + } + if !found { + t.Error("Stackup.createTasks() expected LocalhostClient but found none") + } + } + }) + } +} + +func TestStackup_createTasks_WithEnv(t *testing.T) { + sup := &Stackup{ + conf: &Supfile{ + Version: "0.5", + }, + } + + command := &Command{ + Name: "test", + Local: "echo $TEST_VAR", + } + + clients := []Client{ + &mockClient{ + stdout: bytes.NewBuffer([]byte("mock output")), + stderr: bytes.NewBuffer(nil), + }, + } + + env := "export TEST_VAR=test_value;" + tasks, err := sup.createTasks(command, clients, env) + if err != nil { + t.Fatalf("Stackup.createTasks() error = %v", err) + } + + if len(tasks) == 0 { + t.Fatal("Stackup.createTasks() expected tasks but got none") + } + + task := tasks[0] + if len(task.Clients) != 1 { + t.Fatal("Expected exactly one client") + } + + client, ok := task.Clients[0].(*LocalhostClient) + if !ok { + t.Fatal("Expected LocalhostClient") + } + + if client.env != env+`export SUP_HOST="localhost";` { + t.Errorf("Expected env %q, got %q", env+`export SUP_HOST="localhost";`, client.env) + } +} From 1be36b8d8d0b569958cd65bab3736fddc3413671 Mon Sep 17 00:00:00 2001 From: Alex Mikhalev Date: Mon, 30 Dec 2024 11:03:50 +0000 Subject: [PATCH 7/7] fixes for local commands and tests --- .gitignore | 6 ++++- example_simple/Supfile.yml | 10 ++++++++ example_simple/dist/test.txt | 1 + go.mod | 14 ++++++---- go.sum | 2 ++ localhost.go | 50 +++++++++++++++++++++++++++++++----- 6 files changed, 71 insertions(+), 12 deletions(-) create mode 100644 example_simple/dist/test.txt diff --git a/.gitignore b/.gitignore index 74b0047..d074104 100644 --- a/.gitignore +++ b/.gitignore @@ -2,4 +2,8 @@ bin/ .idea/ *.sw? -.vscode/ \ No newline at end of file +.vscode/ + +/target +**/*.rs.bk +Cargo.lock \ No newline at end of file diff --git a/example_simple/Supfile.yml b/example_simple/Supfile.yml index 4fc1304..45e05dd 100644 --- a/example_simple/Supfile.yml +++ b/example_simple/Supfile.yml @@ -1,7 +1,17 @@ +--- +version: 0.4 + networks: dev: hosts: - alex@bigbox + - alex@100.106.66.7 + staging: + hosts: + - alex@100.106.66.7 + prod: + hosts: + - alex@api.thepattern.digital commands: bash: desc: Interactive Bash on all hosts diff --git a/example_simple/dist/test.txt b/example_simple/dist/test.txt new file mode 100644 index 0000000..9daeafb --- /dev/null +++ b/example_simple/dist/test.txt @@ -0,0 +1 @@ +test diff --git a/go.mod b/go.mod index a45ff77..8c18a12 100644 --- a/go.mod +++ b/go.mod @@ -4,16 +4,20 @@ go 1.17 require ( github.com/goware/prefixer v0.0.0-20160118172347-395022866408 - github.com/hashicorp/errwrap v1.1.0 // indirect github.com/hashicorp/go-multierror v1.1.1 - github.com/kr/pretty v0.3.0 // indirect github.com/mikkeloscar/sshconfig v0.1.1 - github.com/mitchellh/go-homedir v1.1.0 // indirect github.com/pkg/errors v0.9.1 - github.com/rogpeppe/go-internal v1.8.1 // indirect golang.org/x/crypto v0.0.0-20220214200702-86341886e292 + gopkg.in/yaml.v2 v2.4.0 +) + +require ( + github.com/creack/pty v1.1.24 // indirect + github.com/hashicorp/errwrap v1.1.0 // indirect + github.com/kr/pretty v0.3.0 // indirect + github.com/mitchellh/go-homedir v1.1.0 // indirect + github.com/rogpeppe/go-internal v1.8.1 // indirect golang.org/x/sys v0.0.0-20220209214540-3681064d5158 // indirect golang.org/x/term v0.0.0-20210927222741-03fcf44c2211 // indirect gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c // indirect - gopkg.in/yaml.v2 v2.4.0 ) diff --git a/go.sum b/go.sum index 9e59af6..4078a80 100644 --- a/go.sum +++ b/go.sum @@ -1,4 +1,6 @@ github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= +github.com/creack/pty v1.1.24 h1:bJrF4RRfyJnbTJqzRLHzcGaZK1NeM5kTC9jGgovnR1s= +github.com/creack/pty v1.1.24/go.mod h1:08sCNb52WyoAwi2QDyzUCTgcvVFhUzewun7wtTfvcwE= github.com/goware/prefixer v0.0.0-20160118172347-395022866408 h1:Y9iQJfEqnN3/Nce9cOegemcy/9Ai5k3huT6E80F3zaw= github.com/goware/prefixer v0.0.0-20160118172347-395022866408/go.mod h1:PE1ycukgRPJ7bJ9a1fdfQ9j8i/cEcRAoLZzbxYpNB/s= github.com/hashicorp/errwrap v1.0.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4= diff --git a/localhost.go b/localhost.go index 00e20aa..a0816ec 100644 --- a/localhost.go +++ b/localhost.go @@ -7,6 +7,7 @@ import ( "os/exec" "os/user" "strings" + "syscall" "github.com/pkg/errors" ) @@ -37,15 +38,40 @@ func (c *LocalhostClient) Run(task *Task) (err error) { return fmt.Errorf("Command already running") } - // Create command directly without shell - cmd := exec.Command("make", "build") + // Parse the command and arguments + cmdArgs := strings.Fields(task.Run) + if len(cmdArgs) == 0 { + return fmt.Errorf("No command specified") + } + + // For interactive commands, use syscall.Exec + if task.TTY { + binary, err := exec.LookPath(cmdArgs[0]) + if err != nil { + return ErrTask{task, err.Error()} + } + + env := os.Environ() + if c.env != "" { + env = append(env, strings.Split(strings.TrimSuffix(c.env, ";"), ";")...) + } + + err = syscall.Exec(binary, cmdArgs, env) + if err != nil { + return ErrTask{task, err.Error()} + } + return nil + } + + // Create command with proper arguments for non-interactive commands + cmd := exec.Command(cmdArgs[0], cmdArgs[1:]...) // Set up environment variables if c.env != "" { cmd.Env = append(os.Environ(), strings.Split(strings.TrimSuffix(c.env, ";"), ";")...) } - // Set up pipes for stdin, stdout, stderr + // Set up pipes for non-interactive commands if c.stdin, err = cmd.StdinPipe(); err != nil { return errors.Wrap(err, "failed to create stdin pipe") } @@ -58,9 +84,6 @@ func (c *LocalhostClient) Run(task *Task) (err error) { return errors.Wrap(err, "failed to create stderr pipe") } - // Set working directory to current directory - cmd.Dir = "." - // Start the command if err = cmd.Start(); err != nil { return ErrTask{task, err.Error()} @@ -95,14 +118,29 @@ func (c *LocalhostClient) Close() error { } func (c *LocalhostClient) Stdin() io.WriteCloser { + if c.cmd != nil && c.cmd.Stdin != nil { + if writer, ok := c.cmd.Stdin.(io.WriteCloser); ok { + return writer + } + } return c.stdin } func (c *LocalhostClient) Stderr() io.Reader { + if c.cmd != nil && c.cmd.Stderr != nil { + if reader, ok := c.cmd.Stderr.(io.Reader); ok { + return reader + } + } return c.stderr } func (c *LocalhostClient) Stdout() io.Reader { + if c.cmd != nil && c.cmd.Stdout != nil { + if reader, ok := c.cmd.Stdout.(io.Reader); ok { + return reader + } + } return c.stdout }