diff --git a/.travis.yml b/.travis.yml index ac72f57..db0764c 100644 --- a/.travis.yml +++ b/.travis.yml @@ -1,5 +1,5 @@ language: go go: - - 1.5 + - 1.8.x - tip sudo: false \ No newline at end of file diff --git a/cli.go b/cli.go new file mode 100644 index 0000000..dc46630 --- /dev/null +++ b/cli.go @@ -0,0 +1,180 @@ +package starter + +import ( + "context" + "fmt" + "os" + "reflect" + "strings" + "time" + + "github.com/pkg/errors" +) + +func NewCLI() *CLI { + return &CLI{} +} + +func makeOptionList(opts *options) []Option { + var list []Option + if len(opts.Args) > 0 { + list = append(list, WithArgs(opts.Args...)) + } + if opts.AutoRestartInterval.Valid { + list = append(list, WithAutoRestartInterval(time.Duration(opts.AutoRestartInterval.Value)*time.Second)) + } + if opts.Dir != "" { + list = append(list, WithDir(opts.Dir)) + } + if opts.EnableAutoRestart.Valid { + list = append(list, WithAutoRestart(opts.EnableAutoRestart.Value)) + } + if opts.Envdir.Valid { + list = append(list, WithEnvdir(opts.Envdir.Value)) + } + if opts.Interval > -1 { + list = append(list, WithInterval(time.Duration(opts.Interval)*time.Second)) + } + if opts.KillOldDelay.Valid { + list = append(list, WithKillOldDelay(time.Duration(opts.KillOldDelay.Value)*time.Second)) + } + if len(opts.Paths) > 0 { + list = append(list, WithPaths(opts.Paths)) + } + if opts.PidFile != "" { + list = append(list, WithPidFile(opts.PidFile)) + } + if len(opts.Ports) > 0 { + list = append(list, WithPorts(opts.Ports)) + } + if opts.SignalOnHUP != "" { + list = append(list, WithSignalOnHUP(sigFromName(opts.SignalOnHUP))) + } + if opts.SignalOnTERM != "" { + list = append(list, WithSignalOnTERM(sigFromName(opts.SignalOnTERM))) + } + if opts.StatusFile != "" { + list = append(list, WithStatusFile(opts.StatusFile)) + } + return list +} + +func (cli *CLI) ParseArgs(args ...string) (*options, error) { + var opts options + opts.Interval = -1 // allow 0 + if err := opts.Parse(args...); err != nil { + return nil, errors.Wrap(err, "failed to parse arguments") + } + + if opts.Interval < 0 { + opts.Interval = 1 + } + + if !opts.Help && !opts.Version && !opts.Restart { + if len(opts.Args) == 0 { + return nil, errors.New("server program not specified") + } + + opts.Command = opts.Args[0] + if len(opts.Args) > 1 { + opts.Args = opts.Args[1:] + } else { + opts.Args = []string(nil) + } + } + + return &opts, nil +} + +func (cli *CLI) Run(ctx context.Context) error { + opts, err := cli.ParseArgs(os.Args[1:]...) + if err != nil { + return err + } + + if opts.Help { + showHelp() + return nil + } + + if opts.Version { + fmt.Printf("%s\n", version) + return nil + } + + if opts.Restart { + if opts.PidFile == "" || opts.StatusFile == "" { + return errors.New("--restart option requires --pid-file and --status-file to be set as well") + } + + s := NewRestarter(opts.PidFile, opts.StatusFile) + return s.Run(ctx) + } + + s := New(opts.Command, makeOptionList(opts)...) + return s.Run(ctx) +} + +func showHelp() { + // The ONLY reason we're not using go-flags' help option is + // because I wanted to tweak the format just a bit... but + // there wasn't an easy way to do so + os.Stderr.WriteString(` +Usage: + start_server [options] -- server-prog server-arg1 server-arg2 ... + + # start Plack using Starlet listening at TCP port 8000 + start_server --port=8000 -- plackup -s Starlet --max-workers=100 index.psgi + +Options: +`) + + t := reflect.TypeOf(options{}) + + // This weird indexing stuff is done purely to keep ourselves + // compatible with the original start_server program + // (This is the order that the help is displayed in) + names := []string{ + "Ports", + "Paths", + "Dir", + "Interval", + "SignalOnHUP", + "SignalOnTERM", + "PidFile", + "StatusFile", + "Envdir", + "EnableAutoRestart", + "AutoRestartInterval", + "KillOldDelay", + "Restart", + "Help", + "Version", + } + + for _, name := range names { + f, ok := t.FieldByName(name) + if !ok { + continue + } + + tag := f.Tag + if tag == "" { + continue + } + if s := tag.Get("long"); s != "" { + fmt.Fprintf(os.Stderr, " --%s", s) + if a := tag.Get("arg"); a != "" { + fmt.Fprintf(os.Stderr, "=%s", a) + } + if tag.Get("note") == "unimplemented" { + fmt.Fprintf(os.Stderr, " (UNIMPLEMENTED)") + } + fmt.Fprintf(os.Stderr, ":\n") + } + for _, l := range strings.Split(tag.Get("description"), "\n") { + fmt.Fprintf(os.Stderr, " %s\n", l) + } + fmt.Fprintf(os.Stderr, "\n") + } +} diff --git a/cli_test.go b/cli_test.go new file mode 100644 index 0000000..1753f9a --- /dev/null +++ b/cli_test.go @@ -0,0 +1,350 @@ +package starter + +import ( + "errors" + "fmt" + "strings" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func findInOptionList(t *testing.T, opts *options, name string, val interface{}) error { + for _, o := range makeOptionList(opts) { + switch o.Name() { + case name: + if !assert.Equal(t, val, o.Value(), "option value matches") { + return errors.New("option value does not match") + } + return nil + } + } + t.Errorf("failed to find option") + return errors.New("failed to find option") +} + +func TestCLIArgs(t *testing.T) { + c := NewCLI() + + t.Run("no parameters", func(t *testing.T) { + opts, err := c.ParseArgs("ls") + if !assert.NoError(t, err, "cli.ParseArgs should succeed") { + t.Logf("%s", err) + return + } + + expected := options{ + Command: "ls", + Interval: 1, + } + + if !assert.Equal(t, &expected, opts) { + return + } + }) + + t.Run("--auto-restart-interval=5", func(t *testing.T) { + opts, err := c.ParseArgs("ls", "--auto-restart-interval=5") + if !assert.NoError(t, err, "cli.ParseArgs should succeed") { + return + } + + expected := options{ + Command: "ls", + Interval: 1, + AutoRestartInterval: intOpt{Valid: true, Value: 5}, + } + + if !assert.Equal(t, &expected, opts) { + return + } + if err := findInOptionList(t, opts, "auto_restart_interval", 5*time.Second); err != nil { + return + } + }) + + t.Run("--dir=foo", func(t *testing.T) { + opts, err := c.ParseArgs("ls", "--dir=foo") + if !assert.NoError(t, err, "cli.ParseArgs should succeed") { + return + } + + expected := options{ + Command: "ls", + Interval: 1, + Dir: "foo", + } + + if !assert.Equal(t, &expected, opts) { + return + } + if err := findInOptionList(t, opts, "dir", "foo"); err != nil { + return + } + }) + + for _, val := range []bool{true, false} { + arg := fmt.Sprintf("--enable-auto-restart=%t", val) + t.Run(arg, func(t *testing.T) { + opts, err := c.ParseArgs("ls", arg) + if !assert.NoError(t, err, "cli.ParseArgs should succeed") { + return + } + + expected := options{ + Command: "ls", + Interval: 1, + EnableAutoRestart: boolOpt{Valid: true, Value: val}, + } + + if !assert.Equal(t, &expected, opts) { + return + } + if err := findInOptionList(t, opts, "enable_auto_restart", val); err != nil { + return + } + }) + } + + t.Run("--envdir=foo", func(t *testing.T) { + opts, err := c.ParseArgs("ls", "--envdir=foo") + if !assert.NoError(t, err, "cli.ParseArgs should succeed") { + return + } + + expected := options{ + Command: "ls", + Interval: 1, + Envdir: stringOpt{Valid: true, Value: "foo"}, + } + + if !assert.Equal(t, &expected, opts) { + return + } + if err := findInOptionList(t, opts, "envdir", "foo"); err != nil { + return + } + }) + + // 0 is a special case, so we must test + for i := 0; i < 2; i++ { + arg := fmt.Sprintf("--interval=%d", i) + t.Run(arg, func(t *testing.T) { + opts, err := c.ParseArgs("ls", arg) + if !assert.NoError(t, err, "cli.ParseArgs should succeed") { + return + } + + expected := options{ + Command: "ls", + Interval: i, + } + + if !assert.Equal(t, &expected, opts) { + return + } + if err := findInOptionList(t, opts, "interval", time.Duration(i)*time.Second); err != nil { + return + } + }) + } + + for name, sig := range niceNameToSigs { + hupArg := fmt.Sprintf("--signal-on-hup=%s", name) + t.Run(hupArg, func(t *testing.T) { + opts, err := c.ParseArgs("ls", hupArg) + if !assert.NoError(t, err, "cli.ParseArgs should succeed") { + return + } + + expected := options{ + Command: "ls", + Interval: 1, + SignalOnHUP: name, + } + + if !assert.Equal(t, &expected, opts) { + return + } + + if err := findInOptionList(t, opts, "signal_on_hup", sig); err != nil { + return + } + }) + termArg := fmt.Sprintf("--signal-on-term=%s", name) + t.Run(termArg, func(t *testing.T) { + opts, err := c.ParseArgs("ls", termArg) + if !assert.NoError(t, err, "cli.ParseArgs should succeed") { + return + } + + expected := options{ + Command: "ls", + Interval: 1, + SignalOnTERM: name, + } + + if !assert.Equal(t, &expected, opts) { + return + } + + if err := findInOptionList(t, opts, "signal_on_term", sig); err != nil { + return + } + }) + } + + for _, i := range []int{5, 10} { + arg := fmt.Sprintf("--kill-old-delay=%d", i) + t.Run(arg, func(t *testing.T) { + opts, err := c.ParseArgs("ls", arg) + if !assert.NoError(t, err, "cli.ParseArgs should succeed") { + return + } + + expected := options{ + Command: "ls", + Interval: 1, + KillOldDelay: intOpt{Valid: true, Value: i}, + } + + if !assert.Equal(t, &expected, opts) { + return + } + if err := findInOptionList(t, opts, "kill_old_delay", time.Duration(i)*time.Second); err != nil { + return + } + }) + } + + paths := []string{ + "/tmp/foo.sock", + "/tmp/bar.sock", + } + for i := 1; i <= 2; i++ { + args := make([]string, len(paths)) + for i, p := range paths { + args[i] = "--path=" + p + } + name := strings.Join(args, " ") + t.Run(name, func(t *testing.T) { + opts, err := c.ParseArgs(append([]string{"ls"}, args[:i]...)...) + if !assert.NoError(t, err, "cli.ParseArgs should succeed") { + return + } + + expected := options{ + Command: "ls", + Interval: 1, + Paths: paths[:i], + } + + if !assert.Equal(t, &expected, opts) { + return + } + if err := findInOptionList(t, opts, "paths", paths[:i]); err != nil { + return + } + }) + } + + t.Run("--pid-file=/path/to/foo", func(t *testing.T) { + opts, err := c.ParseArgs("ls", "--pid-file=/path/to/foo") + if !assert.NoError(t, err, "cli.ParseArgs should succeed") { + return + } + + expected := options{ + Command: "ls", + Interval: 1, + PidFile: "/path/to/foo", + } + + if !assert.Equal(t, &expected, opts) { + return + } + if err := findInOptionList(t, opts, "pid_file", "/path/to/foo"); err != nil { + return + } + }) + + ports := []string{ + "8080", + "0.0.0.0:9090", + } + for i := 1; i <= 2; i++ { + args := make([]string, len(ports)) + for i, p := range ports { + args[i] = "--port=" + p + } + name := strings.Join(args, " ") + t.Run(name, func(t *testing.T) { + opts, err := c.ParseArgs(append([]string{"ls"}, args[:i]...)...) + if !assert.NoError(t, err, "cli.ParseArgs should succeed") { + return + } + + expected := options{ + Command: "ls", + Interval: 1, + Ports: ports[:i], + } + + if !assert.Equal(t, &expected, opts) { + return + } + if err := findInOptionList(t, opts, "ports", ports[:i]); err != nil { + return + } + }) + } + + for _, val := range []bool{true, false} { + arg := fmt.Sprintf("--restart=%t", val) + t.Run(arg, func(t *testing.T) { + args := []string{arg} + if !val { + args = append([]string{"ls"}, args...) + } + opts, err := c.ParseArgs(args...) + if !assert.NoError(t, err, "cli.ParseArgs should succeed") { + t.Logf("%s", err) + return + } + + expected := options{ + Interval: 1, + Restart: val, + } + if !val { + expected.Command = "ls" + } + + if !assert.Equal(t, &expected, opts) { + return + } + }) + } + + t.Run("--status-file=/path/to/foo", func(t *testing.T) { + opts, err := c.ParseArgs("ls", "--status-file=/path/to/foo") + if !assert.NoError(t, err, "cli.ParseArgs should succeed") { + return + } + + expected := options{ + Command: "ls", + Interval: 1, + StatusFile: "/path/to/foo", + } + + if !assert.Equal(t, &expected, opts) { + return + } + if err := findInOptionList(t, opts, "status_file", "/path/to/foo"); err != nil { + return + } + }) + +} diff --git a/cmd/start_server/start_server.go b/cmd/start_server/start_server.go index 69dc013..008030a 100644 --- a/cmd/start_server/start_server.go +++ b/cmd/start_server/start_server.go @@ -1,161 +1,17 @@ package main import ( + "context" "fmt" "os" - "reflect" - "strings" - "time" - "github.com/jessevdk/go-flags" "github.com/lestrrat/go-server-starter" ) -const version = "0.0.2" - -type options struct { - OptArgs []string - OptCommand string - OptDir string `long:"dir" arg:"path" description:"working directory, start_server do chdir to before exec (optional)"` - OptInterval int `long:"interval" arg:"seconds" description:"minimum interval (in seconds) to respawn the server program (default: 1)"` - OptPorts []string `long:"port" arg:"(port|host:port)" description:"TCP port to listen to (if omitted, will not bind to any ports)"` - OptPaths []string `long:"path" arg:"path" description:"path at where to listen using unix socket (optional)"` - OptSignalOnHUP string `long:"signal-on-hup" arg:"Signal" description:"name of the signal to be sent to the server process when start_server\nreceives a SIGHUP (default: TERM). If you use this option, be sure to\nalso use '--signal-on-term' below."` - OptSignalOnTERM string `long:"signal-on-term" arg:"Signal" description:"name of the signal to be sent to the server process when start_server\nreceives a SIGTERM (default: TERM)"` - OptPidFile string `long:"pid-file" arg:"filename" description:"if set, writes the process id of the start_server process to the file"` - OptStatusFile string `long:"status-file" arg:"filename" description:"if set, writes the status of the server process(es) to the file"` - OptEnvdir string `long:"envdir" arg:"Envdir" description:"directory that contains environment variables to the server processes.\nIt is intended for use with \"envdir\" in \"daemontools\". This can be\noverwritten by environment variable \"ENVDIR\"."` - OptEnableAutoRestart bool `long:"enable-auto-restart" description:"enables automatic restart by time. This can be overwritten by\nenvironment variable \"ENABLE_AUTO_RESTART\"." note:"unimplemented"` - OptAutoRestartInterval int `long:"auto-restart-interval" arg:"seconds" description:"automatic restart interval (default 360). It is used with\n\"--enable-auto-restart\" option. This can be overwritten by environment\nvariable \"AUTO_RESTART_INTERVAL\"." note:"unimplemented"` - OptKillOldDelay int `long:"kill-old-delay" arg:"seconds" description:"time to suspend to send a signal to the old worker. The default value is\n5 when \"--enable-auto-restart\" is set, 0 otherwise. This can be\noverwritten by environment variable \"KILL_OLD_DELAY\"."` - OptRestart bool `long:"restart" description:"this is a wrapper command that reads the pid of the start_server process\nfrom --pid-file, sends SIGHUP to the process and waits until the\nserver(s) of the older generation(s) die by monitoring the contents of\nthe --status-file" note:"unimplemented"` - OptHelp bool `long:"help" description:"prints this help"` - OptVersion bool `long:"version" description:"prints the version number"` -} - -func (o options) Args() []string { return o.OptArgs } -func (o options) Command() string { return o.OptCommand } -func (o options) Dir() string { return o.OptDir } -func (o options) Interval() time.Duration { return time.Duration(o.OptInterval) * time.Second } -func (o options) PidFile() string { return o.OptPidFile } -func (o options) Ports() []string { return o.OptPorts } -func (o options) Paths() []string { return o.OptPaths } -func (o options) SignalOnHUP() os.Signal { return starter.SigFromName(o.OptSignalOnHUP) } -func (o options) SignalOnTERM() os.Signal { return starter.SigFromName(o.OptSignalOnTERM) } -func (o options) StatusFile() string { return o.OptStatusFile } - -func showHelp() { - // The ONLY reason we're not using go-flags' help option is - // because I wanted to tweak the format just a bit... but - // there wasn't an easy way to do so - os.Stderr.WriteString(` -Usage: - start_server [options] -- server-prog server-arg1 server-arg2 ... - - # start Plack using Starlet listening at TCP port 8000 - start_server --port=8000 -- plackup -s Starlet --max-workers=100 index.psgi - -Options: -`) - - t := reflect.TypeOf(options{}) - - // This weird indexing stuff is done purely to keep ourselves - // compatible with the original start_server program - // (This is the order that the help is displayed in) - names := []string{ - "OptPorts", - "OptPaths", - "OptDir", - "OptInterval", - "OptSignalOnHUP", - "OptSignalOnTERM", - "OptPidFile", - "OptStatusFile", - "OptEnvdir", - "OptEnableAutoRestart", - "OptAutoRestartInterval", - "OptKillOldDelay", - "OptRestart", - "OptHelp", - "OptVersion", - } - - for _, name := range names { - f, ok := t.FieldByName(name) - if !ok { - continue - } - - tag := f.Tag - if tag == "" { - continue - } - if s := tag.Get("long"); s != "" { - fmt.Fprintf(os.Stderr, " --%s", s) - if a := tag.Get("arg"); a != "" { - fmt.Fprintf(os.Stderr, "=%s", a) - } - if tag.Get("note") == "unimplemented" { - fmt.Fprintf(os.Stderr, " (UNIMPLEMENTED)") - } - fmt.Fprintf(os.Stderr, ":\n") - } - for _, l := range strings.Split(tag.Get("description"), "\n") { - fmt.Fprintf(os.Stderr, " %s\n", l) - } - fmt.Fprintf(os.Stderr, "\n") - } -} - func main() { - os.Exit(_main()) -} - -func _main() (st int) { - st = 1 - - opts := &options{OptInterval: -1} - p := flags.NewParser(opts, flags.PrintErrors|flags.PassDoubleDash) - args, err := p.Parse() - if err != nil || opts.OptHelp { - showHelp() - return - } - - if opts.OptVersion { - fmt.Printf("%s\n", version) - st = 0 - return - } - - if opts.OptInterval < 0 { - opts.OptInterval = 1 - } - - if len(args) == 0 { - fmt.Fprintf(os.Stderr, "server program not specified\n") - return - } - - opts.OptCommand = args[0] - if len(args) > 1 { - opts.OptArgs = args[1:] - } - - if opts.OptEnvdir != "" { - os.Setenv("ENVDIR", opts.OptEnvdir) - } - - s, err := starter.NewStarter(opts) - if err != nil { - fmt.Fprintf(os.Stderr, "error: %s\n", err) - return - } - if err := s.Run(); err != nil { - fmt.Fprintf(os.Stderr, "error: %s\n", err) - return + cli := starter.NewCLI() + if err := cli.Run(context.Background()); err != nil { + fmt.Fprintf(os.Stderr, "%s\n", err.Error()) + os.Exit(1) } - st = 0 - return } diff --git a/env.go b/env.go index debe20c..b073a9e 100644 --- a/env.go +++ b/env.go @@ -1,65 +1,21 @@ package starter import ( - "bufio" - "errors" "os" - "path/filepath" - "strings" + "strconv" + "time" ) -var errNoEnv = errors.New("no ENVDIR specified, or ENVDIR does not exist") - -func reloadEnv() (map[string]string, error) { - dn := os.Getenv("ENVDIR") - if dn == "" { - return nil, errNoEnv - } - - fi, err := os.Stat(dn) - if err != nil { - return nil, err - } - - if !fi.IsDir() { - return nil, err - } - - var m map[string]string - - filepath.Walk(dn, func(path string, fi os.FileInfo, err error) error { - // Ignore errors - if err != nil { - return nil - } - - // Don't go into directories - if fi.IsDir() && dn != path { - return filepath.SkipDir - } - - f, err := os.Open(path) - if err != nil { - return nil - } - defer f.Close() - - envName := filepath.Base(path) - scanner := bufio.NewScanner(f) - if scanner.Scan() { - if m == nil { - m = make(map[string]string) - } - l := scanner.Text() - m[envName] = strings.TrimSpace(l) - } - - return nil - }) +func envAsBool(name string) bool { + b, err := strconv.ParseBool(os.Getenv(name)) + return err == nil && b +} - if m == nil { - return nil, errNoEnv - } +func envAsInt(name string) int { + i, _ := strconv.ParseInt(os.Getenv(name), 10, 64) + return int(i) +} - return m, nil +func envAsDuration(name string) time.Duration { + return time.Duration(envAsInt(name)) * time.Second } diff --git a/env_test.go b/env_test.go deleted file mode 100644 index 8c4e525..0000000 --- a/env_test.go +++ /dev/null @@ -1,68 +0,0 @@ -package starter - -import ( - "io" - "io/ioutil" - "os" - "path/filepath" - "testing" -) - -func TestEnvdir(t *testing.T) { - dir, err := ioutil.TempDir("", "starter_test") - if err != nil { - t.Errorf("Failed to create tempdir: %s", err) - return - } - defer os.RemoveAll(dir) - - files := []string{"FOO", "BAR", "BAZ"} - for _, fn := range files { - longFn := filepath.Join(dir, fn) - - f, err := os.OpenFile(longFn, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0666) - if err != nil { - t.Errorf("Failed to create file '%s': %s", fn, err) - return - } - closed := false - defer func() { - if !closed { - f.Close() - } - }() - - io.WriteString(f, fn) - f.Close() - closed = true - - // save old values and restore later, if any - if old := os.Getenv(fn); old != "" { - os.Setenv(fn, "") - defer os.Setenv(fn, old) - } - } - - if old := os.Getenv("ENVDIR"); old != "" { - defer os.Setenv("ENVDIR", old) - } - - os.Setenv("ENVDIR", dir) - m, err := reloadEnv() - if err != nil { - t.Errorf("reloadEnv failed: %s", err) - return - } - - for _, fn := range files { - v, ok := m[fn] - if !ok { - t.Errorf("Expected environment variable '%s' to exist") - return - } - if v != fn { - t.Errorf("Expected environment variable '%s' to be '%s'", fn, fn) - return - } - } -} diff --git a/glide.lock b/glide.lock new file mode 100644 index 0000000..8cdb324 --- /dev/null +++ b/glide.lock @@ -0,0 +1,26 @@ +hash: 834007827c1b49861ec10a1fc2fcb83425427181a3bfc0aec736d2ea54f39bbc +updated: 2017-02-27T09:46:02.670776913+09:00 +imports: +- name: github.com/lestrrat/go-envload + version: d7bc6278d4641441e356c80af00fd627f9b199f0 +- name: github.com/pkg/errors + version: 248dadf4e9068a0b3e79f02ed0a610d935de5302 +- name: golang.org/x/net + version: 10c134ea0df15f7e34d789338c7a2d76cc7a3ab9 + subpackages: + - context +testImports: +- name: github.com/davecgh/go-spew + version: 04cdfd42973bb9c8589fd6a731800cf222fde1a9 + subpackages: + - spew +- name: github.com/lestrrat/go-tcputil + version: 3cae7c8432aeb817399f001e0db3f5bb04e8b8a6 +- name: github.com/pmezard/go-difflib + version: d8ed2627bdf02c080bf22230dbb337003b7aba2d + subpackages: + - difflib +- name: github.com/stretchr/testify + version: 4d4bfba8f1d1027c4fdbe371823030df51419987 + subpackages: + - assert diff --git a/glide.yaml b/glide.yaml new file mode 100644 index 0000000..08a7506 --- /dev/null +++ b/glide.yaml @@ -0,0 +1,9 @@ +package: github.com/lestrrat/go-server-starter +import: +- package: github.com/lestrrat/go-envload +- package: github.com/pkg/errors +testImport: +- package: github.com/lestrrat/go-tcputil +- package: github.com/stretchr/testify + subpackages: + - assert diff --git a/interface.go b/interface.go new file mode 100644 index 0000000..3078b03 --- /dev/null +++ b/interface.go @@ -0,0 +1,50 @@ +package starter + +import ( + "net" + "syscall" +) + +const version = `0.0.2` + +var failureStatus syscall.WaitStatus + +type listener struct { + listener net.Listener + spec string // path or port spec +} + +type Option interface { + Name() string + Value() interface{} +} + +type Restarter struct { + pidFile string + statusFile string +} + +type Starter struct { + options []Option + command string +} + +type processState interface { + Pid() int + Sys() interface{} +} + +type CLI struct{} +type boolOpt struct { + Valid bool + Value bool +} +type intOpt struct { + Valid bool + Value int +} + +type stringOpt struct { + Valid bool + Value string +} diff --git a/monitor.go b/monitor.go new file mode 100644 index 0000000..d273e44 --- /dev/null +++ b/monitor.go @@ -0,0 +1,102 @@ +package starter + +import ( + "context" + "fmt" + "os/exec" + "reflect" + "strconv" + "time" +) + +func ExampleMonitor() { + ctx, cancel := context.WithTimeout(context.Background(), 20*time.Second) + defer cancel() + + ch := make(chan *exec.Cmd, 10) + done := make(chan *exec.Cmd) + go monitor(ctx, ch, done) + + for i := 0; i < 10; i++ { + cmd := exec.Command("sleep", strconv.Itoa(i)) + cmd.Start() + ch <- cmd + } + + for { + select { + case <-ctx.Done(): + fmt.Printf("timeout reached\n") + return + case cmd, ok := <-done: + if !ok { + fmt.Println("monitor exited") + return + } + fmt.Printf("notified: %d\n", cmd.ProcessState.Pid()) + } + } +} + +// monitor is a process (i.e. *exec.Cmd) monitor. it asynchronously +// listens for either (a) one of the monitored process exits, or (b) +// we get a request to watch for a new worker. +// +// users of this function can "wait" for the exit of a command +// by checking the done channel +func monitor(ctx context.Context, src chan *exec.Cmd, done chan *exec.Cmd) { + defer close(done) + var workers []struct { + Chan chan error + Cmd *exec.Cmd + } + for { + select { + case <-ctx.Done(): + return + default: + } + + cases := make([]reflect.SelectCase, len(workers)+1) + for i, worker := range workers { + cases[i].Chan = reflect.ValueOf(worker.Chan) + cases[i].Dir = reflect.SelectRecv + } + cases[len(cases)-1].Chan = reflect.ValueOf(src) + cases[len(cases)-1].Dir = reflect.SelectRecv + + chosen, recv, recvOK := reflect.Select(cases) + if !recvOK { + panic("should not get here") + } + + if chosen == len(workers) { + // 1 + max worker index, so must be our "new worker chan" + cmd := recv.Interface().(*exec.Cmd) + ch := make(chan error) + go func() { + ch <- cmd.Wait() + }() + workers = append(workers, struct { + Chan chan error + Cmd *exec.Cmd + }{ + Chan: ch, + Cmd: cmd, + }) + continue + } + + exited := workers[chosen].Cmd + // one of the workers must have finished + // remove the corresponding one + switch { + case len(workers) < 2: + workers = nil + default: + workers = append(workers[:chosen], workers[chosen+1:]...) + } + + done <- exited + } +} diff --git a/options.go b/options.go new file mode 100644 index 0000000..b2c55fb --- /dev/null +++ b/options.go @@ -0,0 +1,238 @@ +package starter + +import ( + "io" + "os" + "reflect" + "strconv" + "strings" + "time" + + "github.com/pkg/errors" +) + +type valueOption struct { + name string + value interface{} +} + +func (o *valueOption) Name() string { + return o.name +} + +func (o *valueOption) Value() interface{} { + return o.value +} + +func WithAutoRestart(b bool) Option { + return &valueOption{name: "enable_auto_restart", value: b} +} + +func WithAutoRestartInterval(t time.Duration) Option { + return &valueOption{name: "auto_restart_interval", value: t} +} + +func WithArgs(a ...string) Option { + return &valueOption{name: "args", value: a} +} + +func WithDir(dir string) Option { + return &valueOption{name: "dir", value: dir} +} + +func WithEnvdir(dir string) Option { + return &valueOption{name: "envdir", value: dir} +} + +func WithInterval(t time.Duration) Option { + return &valueOption{name: "interval", value: t} +} + +func WithKillOldDelay(t time.Duration) Option { + return &valueOption{name: "kill_old_delay", value: t} +} + +func WithPaths(l []string) Option { + return &valueOption{name: "paths", value: l} +} + +func WithPidFile(s string) Option { + return &valueOption{name: "pid_file", value: s} +} + +func WithPorts(l []string) Option { + return &valueOption{name: "ports", value: l} +} + +func WithSignalOnHUP(s os.Signal) Option { + return &valueOption{name: "signal_on_hup", value: s} +} + +func WithSignalOnTERM(s os.Signal) Option { + return &valueOption{name: "signal_on_term", value: s} +} + +func WithStatusFile(s string) Option { + return &valueOption{name: "status_file", value: s} +} + +func WithNoticeOutput(w io.Writer) Option { + return &valueOption{name: "notice_output", value: w} +} + +func WithLogStdout(w io.Writer) Option { + return &valueOption{name: "log_stdout", value: w} +} + +func WithLogStderr(w io.Writer) Option { + return &valueOption{name: "log_stderr", value: w} +} + +func (o *stringOpt) String() string { + return o.Value +} + +func (o *stringOpt) Set(s string) error { + o.Valid = true + o.Value = s + return nil +} + +func (o *intOpt) String() string { + return strconv.FormatInt(int64(o.Value), 10) +} + +func (o *intOpt) Set(s string) error { + i, err := strconv.ParseInt(s, 10, 64) + if err != nil { + return err + } + o.Valid = true + o.Value = int(i) + return nil +} + +func (o *boolOpt) String() string { + return strconv.FormatBool(o.Value) +} + +func (o *boolOpt) Set(s string) error { + b, err := strconv.ParseBool(s) + if err != nil { + return err + } + o.Valid = true + o.Value = b + return nil +} + +type optsetter interface { + Set(string) error +} + +var osv = reflect.TypeOf((*optsetter)(nil)).Elem() + +func (o *options) Parse(args ...string) error { + rv := reflect.ValueOf(o).Elem() + tv := rv.Type() + names := map[string]reflect.Value{} + for i := 0; i < tv.NumField(); i++ { + f := tv.Field(i) + if f.PkgPath != "" || f.Anonymous { + continue + } + names[f.Tag.Get("long")] = rv.Field(i) + } + + var arguments []string + for len(args) > 0 { + arg := args[0] + args = args[1:] + l := len(arg) + if l == 2 && arg == "--" { + // stop processing, everything after this is an argument + if len(args) > 0 { + arguments = append(arguments, args...) + } + args = []string(nil) // force loop termination + break + } + + if !strings.HasPrefix(arg, "--") { + arguments = append(arguments, arg) + continue + } + end := l + var hasOpval bool + var opval string + if ei := strings.IndexByte(arg, '='); ei > -1 { + end = ei + if end < l-1 { + opval = arg[end+1:] + hasOpval = true + } else { + return errors.Errorf("invalid argument '%s'", arg) + } + } else { + // is the next argument the argument to this option + if len(args) > 0 && !strings.HasPrefix(args[0], "--") { + hasOpval = true + opval = args[0] + args = args[1:] + } + } + opname := arg[2:end] + f := names[opname] + opvalv := reflect.ValueOf(opval) + switch f.Kind() { + case reflect.Struct: + if reflect.PtrTo(f.Type()).Implements(osv) { + f.Addr().MethodByName("Set").Call([]reflect.Value{opvalv}) + } else if opvalv.Type().AssignableTo(f.Type()) { + f.Set(opvalv) + } + case reflect.String: + f.Set(opvalv) + case reflect.Bool: + if !hasOpval { + opval = "true" + } + b, err := strconv.ParseBool(opval) + if err != nil { + return err + } + f.Set(reflect.ValueOf(b)) + case reflect.Int: + i, err := strconv.ParseInt(opval, 10, 64) + if err != nil { + return err + } + f.Set(reflect.ValueOf(int(i))) + case reflect.Slice: + f.Set(reflect.Append(f, opvalv)) + } + } + + o.Args = arguments + return nil +} + +type options struct { + Args []string + AutoRestartInterval intOpt `long:"auto-restart-interval" arg:"seconds" description:"automatic restart interval (default 360). It is used with\n\"--enable-auto-restart\" option. This can be overwritten by environment\nvariable \"AUTO_RESTART_INTERVAL\"."` + Command string + Dir string `long:"dir" arg:"path" description:"working directory, start_server do chdir to before exec (optional)"` + EnableAutoRestart boolOpt `long:"enable-auto-restart" description:"enables automatic restart by time. This can be overwritten by\nenvironment variable \"ENABLE_AUTO_RESTART\"."` + Envdir stringOpt `long:"envdir" arg:"Envdir" description:"directory that contains environment variables to the server processes.\nIt is intended for use with \"envdir\" in \"daemontools\". This can be\noverwritten by environment variable \"ENVDIR\"."` + Interval int `long:"interval" arg:"seconds" description:"minimum interval (in seconds) to respawn the server program (default: 1)"` + KillOldDelay intOpt `long:"kill-old-delay" arg:"seconds" description:"time to suspend to send a signal to the old worker. The default value is\n5 when \"--enable-auto-restart\" is set, 0 otherwise. This can be\noverwritten by environment variable \"KILL_OLD_DELAY\"."` + Paths []string `long:"path" arg:"path" description:"path at where to listen using unix socket (optional)"` + PidFile string `long:"pid-file" arg:"filename" description:"if set, writes the process id of the start_server process to the file"` + Ports []string `long:"port" arg:"(port|host:port)" description:"TCP port to listen to (if omitted, will not bind to any ports)"` + Restart bool `long:"restart" description:"this is a wrapper command that reads the pid of the start_server process\nfrom --pid-file, sends SIGHUP to the process and waits until the\nserver(s) of the older generation(s) die by monitoring the contents of\nthe --status-file"` + SignalOnHUP string `long:"signal-on-hup" arg:"Signal" description:"name of the signal to be sent to the server process when start_server\nreceives a SIGHUP (default: TERM). If you use this option, be sure to\nalso use '--signal-on-term' below."` + SignalOnTERM string `long:"signal-on-term" arg:"Signal" description:"name of the signal to be sent to the server process when start_server\nreceives a SIGTERM (default: TERM)"` + StatusFile string `long:"status-file" arg:"filename" description:"if set, writes the status of the server process(es) to the file"` + Help bool `long:"help" description:"prints this help"` + Version bool `long:"version" description:"printes the version number"` +} diff --git a/restarter.go b/restarter.go new file mode 100644 index 0000000..942c422 --- /dev/null +++ b/restarter.go @@ -0,0 +1,102 @@ +package starter + +import ( + "bufio" + "bytes" + "context" + "io/ioutil" + "os" + "sort" + "strconv" + "strings" + "syscall" + "time" + + "github.com/pkg/errors" +) + +func NewRestarter(pidFile, statusFile string) *Restarter { + return &Restarter{ + pidFile: pidFile, + statusFile: statusFile, + } +} + +func (s *Restarter) Run(ctx context.Context) error { + pidbuf, err := ioutil.ReadFile(s.pidFile) + if err != nil { + return errors.Wrapf(err, "failed to open file:%s", s.pidFile) + } + pid, err := strconv.ParseInt(string(pidbuf), 10, 64) + if err != nil { + return errors.Wrap(err, "failed to parse pid") + } + + p, err := os.FindProcess(int(pid)) + if err != nil { + return errors.Wrapf(err, "failed to find process:%s", pidbuf) + } + + if err := p.Signal(syscall.SIGHUP); err != nil { + return errors.Wrap(err, "failed to send SIGHUP to the server process") + } + + getGenerations := func(file string) ([]int, error) { + genbuf, err := ioutil.ReadFile(file) + if err != nil { + return nil, errors.Wrapf(err, "failed to open file:%s", file) + } + scanner := bufio.NewScanner(bytes.NewReader(genbuf)) + genmap := make(map[int]struct{}) + for scanner.Scan() { + txt := scanner.Text() + i := strings.IndexByte(txt, ':') + if i <= 0 { + continue + } + gen, err := strconv.ParseInt(string(txt[:i]), 10, 64) + if err != nil { + continue + } + + genmap[int(gen)] = struct{}{} + } + + var generations []int + for k := range genmap { + generations = append(generations, k) + } + sort.Ints(generations) + return generations, nil + } + + var waitFor int + { + generations, err := getGenerations(s.statusFile) + if err != nil { + return errors.Wrap(err, "failed to find generations") + } + + if len(generations) == 0 { + return errors.New("no active process found in the status file") + } + + waitFor = generations[len(generations)-1] + 1 + } + + t := time.NewTicker(time.Second) + defer t.Stop() + + for { + select { + case <-t.C: + generations, err := getGenerations(s.statusFile) + if err != nil { + return errors.Wrap(err, "failed to find generations") + } + if len(generations) == 1 && generations[0] == waitFor { + return nil + } + } + } +} diff --git a/signals.go b/signals.go new file mode 100644 index 0000000..2224f83 --- /dev/null +++ b/signals.go @@ -0,0 +1,63 @@ +package starter + +import ( + "fmt" + "os" + "strings" + "syscall" +) + +var niceSigNames map[syscall.Signal]string +var niceNameToSigs map[string]syscall.Signal + +func makeNiceSigNames() map[syscall.Signal]string { + m := map[syscall.Signal]string{ + syscall.SIGABRT: "ABRT", + syscall.SIGALRM: "ALRM", + syscall.SIGBUS: "BUS", + // syscall.SIGEMT: "EMT", + syscall.SIGFPE: "FPE", + syscall.SIGHUP: "HUP", + syscall.SIGILL: "ILL", + // syscall.SIGINFO: "INFO", + syscall.SIGINT: "INT", + // syscall.SIGIOT: "IOT", + syscall.SIGKILL: "KILL", + syscall.SIGPIPE: "PIPE", + syscall.SIGQUIT: "QUIT", + syscall.SIGSEGV: "SEGV", + syscall.SIGTERM: "TERM", + syscall.SIGTRAP: "TRAP", + } + + // addPlatformDepdentNiceSigNames() is defined in the files + // containing build tags + return addPlatformDependentNiceSigNames(m) +} + +func init() { + niceSigNames = makeNiceSigNames() + niceNameToSigs = make(map[string]syscall.Signal) + for sig, name := range niceSigNames { + niceNameToSigs[name] = sig + } +} + +func signame(s os.Signal) string { + if ss, ok := s.(syscall.Signal); ok { + return niceSigNames[ss] + } + return fmt.Sprintf("UNKNOWN (%s)", s) +} + +func sigFromName(n string) os.Signal { + n = strings.ToUpper(n) + if strings.HasPrefix(n, "SIG") { + n = n[3:] // remove SIG prefix + } + + if sig, ok := niceNameToSigs[n]; ok { + return sig + } + return nil +} diff --git a/starter_any.go b/signals_any.go similarity index 86% rename from starter_any.go rename to signals_any.go index c66bdc2..4bdfda8 100644 --- a/starter_any.go +++ b/signals_any.go @@ -4,11 +4,6 @@ package starter import "syscall" -func init() { - failureStatus = syscall.WaitStatus(255) - successStatus = syscall.WaitStatus(0) -} - func addPlatformDependentNiceSigNames(v map[syscall.Signal]string) map[syscall.Signal]string { v[syscall.SIGCHLD] = "CHLD" v[syscall.SIGCONT] = "CONT" diff --git a/starter_windows.go b/signals_windows.go similarity index 54% rename from starter_windows.go rename to signals_windows.go index 49c6ef7..c4e94ae 100644 --- a/starter_windows.go +++ b/signals_windows.go @@ -2,11 +2,6 @@ package starter import "syscall" -func init() { - failureStatus = syscall.WaitStatus{ExitCode: 255} - successStatus = syscall.WaitStatus{ExitCode: 0} -} - func addPlatformDependentNiceSigNames(v map[syscall.Signal]string) map[syscall.Signal]string { return v } diff --git a/starter.go b/starter.go index c65ce00..0af3d29 100644 --- a/starter.go +++ b/starter.go @@ -1,133 +1,29 @@ package starter import ( + "bytes" + "context" "fmt" + "io" "net" "os" "os/exec" "os/signal" + "sort" "strconv" "strings" "syscall" "time" -) - -var niceSigNames map[syscall.Signal]string -var niceNameToSigs map[string]syscall.Signal -var successStatus syscall.WaitStatus -var failureStatus syscall.WaitStatus - -func makeNiceSigNamesCommon() map[syscall.Signal]string { - return map[syscall.Signal]string{ - syscall.SIGABRT: "ABRT", - syscall.SIGALRM: "ALRM", - syscall.SIGBUS: "BUS", - // syscall.SIGEMT: "EMT", - syscall.SIGFPE: "FPE", - syscall.SIGHUP: "HUP", - syscall.SIGILL: "ILL", - // syscall.SIGINFO: "INFO", - syscall.SIGINT: "INT", - // syscall.SIGIOT: "IOT", - syscall.SIGKILL: "KILL", - syscall.SIGPIPE: "PIPE", - syscall.SIGQUIT: "QUIT", - syscall.SIGSEGV: "SEGV", - syscall.SIGTERM: "TERM", - syscall.SIGTRAP: "TRAP", - } -} - -func makeNiceSigNames() map[syscall.Signal]string { - return addPlatformDependentNiceSigNames(makeNiceSigNamesCommon()) -} - -func init() { - niceSigNames = makeNiceSigNames() - niceNameToSigs = make(map[string]syscall.Signal) - for sig, name := range niceSigNames { - niceNameToSigs[name] = sig - } -} - -type listener struct { - listener net.Listener - spec string // path or port spec -} - -type Config interface { - Args() []string - Command() string - Dir() string // Directory to chdir to before executing the command - Interval() time.Duration // Time between checks for liveness - PidFile() string - Ports() []string // Ports to bind to (addr:port or port, so it's a string) - Paths() []string // Paths (UNIX domain socket) to bind to - SignalOnHUP() os.Signal // Signal to send when HUP is received - SignalOnTERM() os.Signal // Signal to send when TERM is received - StatusFile() string -} -type Starter struct { - interval time.Duration - signalOnHUP os.Signal - signalOnTERM os.Signal - // you can't set this in go: backlog - statusFile string - pidFile string - dir string - ports []string - paths []string - listeners []listener - generation int - command string - args []string -} - -// NewStarter creates a new Starter object. Config parameter may NOT be -// nil, as `Ports` and/or `Paths`, and `Command` are required -func NewStarter(c Config) (*Starter, error) { - if c == nil { - return nil, fmt.Errorf("config argument must be non-nil") - } - - var signalOnHUP os.Signal = syscall.SIGTERM - var signalOnTERM os.Signal = syscall.SIGTERM - if s := c.SignalOnHUP(); s != nil { - signalOnHUP = s - } - if s := c.SignalOnTERM(); s != nil { - signalOnTERM = s - } - - if c.Command() == "" { - return nil, fmt.Errorf("argument Command must be specified") - } - if _, err := exec.LookPath(c.Command()); err != nil { - return nil, err - } + envload "github.com/lestrrat/go-envload" + "github.com/pkg/errors" +) - s := &Starter{ - args: c.Args(), - command: c.Command(), - dir: c.Dir(), - interval: c.Interval(), - listeners: make([]listener, 0, len(c.Ports())+len(c.Paths())), - pidFile: c.PidFile(), - ports: c.Ports(), - paths: c.Paths(), - signalOnHUP: signalOnHUP, - signalOnTERM: signalOnTERM, - statusFile: c.StatusFile(), +func New(command string, options ...Option) *Starter { + return &Starter{ + command: command, + options: options, // This is stored as-is on purpose. } - - return s, nil - -} - -func (s Starter) Stop() { - p, _ := os.FindProcess(os.Getpid()) - p.Signal(syscall.SIGTERM) } func grabExitStatus(st processState) syscall.WaitStatus { @@ -141,111 +37,202 @@ func grabExitStatus(st processState) syscall.WaitStatus { return exitSt } -type processState interface { - Pid() int - Sys() interface{} -} -type dummyProcessState struct { - pid int - status syscall.WaitStatus -} +func parsePortSpec(addr string) (string, int, error) { + i := strings.IndexByte(addr, ':') + portPart := "" + if i < 0 { + portPart = addr + addr = "" + } else { + portPart = addr[i+1:] + addr = addr[:i] + } -func (d dummyProcessState) Pid() int { - return d.pid -} + port, err := strconv.ParseInt(portPart, 10, 64) + if err != nil { + return "", -1, err + } -func (d dummyProcessState) Sys() interface{} { - return d.status + return addr, int(port), nil } -func signame(s os.Signal) string { - if ss, ok := s.(syscall.Signal); ok { - return niceSigNames[ss] +// This keeps listening to INT,TERM,HUP, and ALRM signals, +// and queues them up into a destination channel +func acceptSignals(ctx context.Context, dst chan os.Signal) { + src := make(chan os.Signal, 32) // up to 32 signals + signal.Notify(src, syscall.SIGINT, syscall.SIGTERM, syscall.SIGHUP, syscall.SIGALRM) + signal.Ignore(syscall.SIGPIPE) + defer close(dst) + for { + select { + case <-ctx.Done(): + return + case sig, ok := <-src: + if !ok { + return + } + dst <- sig + } } - return "UNKNOWN" } -// SigFromName returns the signal corresponding to the given signal name string. -// If the given name string is not defined, it returns nil. -func SigFromName(n string) os.Signal { - n = strings.ToUpper(n) - if strings.HasPrefix(n, "SIG") { - n = n[3:] // remove SIG prefix - } - - if sig, ok := niceNameToSigs[n]; ok { - return sig +func wait(ctx context.Context, sigCh chan os.Signal, workerDone chan *exec.Cmd) *exec.Cmd { + // Original code in lib/Server/Starter.pm (_wait3) looks... interesting + // currently going to punt it in light of just "wait for a process to + // finish or a signal is received" + t := time.NewTicker(time.Second) + defer t.Stop() + for { + select { + case <-ctx.Done(): + p, _ := os.FindProcess(os.Getpid()) + p.Signal(os.Signal(syscall.SIGTERM)) + return nil + case <-t.C: + if len(sigCh) > 0 { + return nil + } + case cmd := <-workerDone: + return cmd + } } return nil } -func setEnv() { - if os.Getenv("ENVDIR") == "" { +var registerCleanupKey struct{} + +func registerCleanup(ctx context.Context, f func()) { + register, ok := ctx.Value(registerCleanupKey).(func(func())) + if !ok { return } + register(f) +} - m, err := reloadEnv() - if err != nil && err != errNoEnv { - // do something - fmt.Fprintf(os.Stderr, "failed to load from envdir: %s\n", err) +func cleanup(ctx context.Context, ch chan func()) { + var finalizers []func() + for loop := true; loop; { + select { + case <-ctx.Done(): + loop = false + continue + case f, ok := <-ch: + if ok { + finalizers = append(finalizers, f) + } + } } - - for k, v := range m { - os.Setenv(k, v) + for _, f := range finalizers { + f() } } -func parsePortSpec(addr string) (string, int, error) { - i := strings.IndexByte(addr, ':') - portPart := "" - if i < 0 { - portPart = addr - addr = "" - } else { - portPart = addr[i+1:] - addr = addr[:i] +func (s *Starter) Run(ctx context.Context) error { + var cancel func() + ctx, cancel = context.WithCancel(ctx) + defer cancel() + + var cmdArgs []string + var dir string + var interval time.Duration = time.Second + var paths []string + var pidFile string + var listeners []listener + var ports []string + var sigonhup os.Signal = os.Signal(syscall.SIGTERM) + var sigonterm os.Signal = os.Signal(syscall.SIGTERM) + var statusFile string + var noticeOutput io.Writer = os.Stderr + var logStdout io.Writer = os.Stdout + var logStderr io.Writer = os.Stderr + var killOldDelay time.Duration = 5 * time.Second + + for _, opt := range s.options { + switch opt.Name() { + case "args": + cmdArgs = opt.Value().([]string) + case "auto_restart_interval": + v := opt.Value().(int) + os.Setenv(`AUTO_RESTART_INTERVAL`, strconv.Itoa(v)) + case "dir": + dir = opt.Value().(string) + case "enable_auto_restart": + b := opt.Value().(bool) + if b { + os.Setenv(`ENABLE_AUTO_RESTART`, `1`) + } else { + os.Setenv(`ENABLE_AUTO_RESTART`, `0`) + } + case "envdir": + os.Setenv("ENVDIR", opt.Value().(string)) + case "interval": + interval = opt.Value().(time.Duration) + case "kill_old_delay": + killOldDelay = opt.Value().(time.Duration) + case "paths": + paths = opt.Value().([]string) + case "pid_file": + pidFile = opt.Value().(string) + case "ports": + ports = opt.Value().([]string) + case "signal_on_hup": + sigonhup = opt.Value().(os.Signal) + case "signal_on_term": + sigonterm = opt.Value().(os.Signal) + case "status_file": + statusFile = opt.Value().(string) + case "notice_output": + noticeOutput = opt.Value().(io.Writer) + case "log_stdout": + logStdout = opt.Value().(io.Writer) + case "log_stderr": + logStderr = opt.Value().(io.Writer) + } } - port, err := strconv.ParseInt(portPart, 10, 64) - if err != nil { - return "", -1, err + if envAsBool(`ENABLE_AUTO_RESTART`) { + os.Setenv(`KILL_OLD_DELAY`, strconv.Itoa(int(killOldDelay/time.Second))) } - return addr, int(port), nil -} - -func (s *Starter) Run() error { - defer s.Teardown() - - if s.pidFile != "" { - f, err := os.OpenFile(s.pidFile, os.O_EXCL|os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0644) - if err != nil { - return err + notice := func(f string, args ...interface{}) { + var buf bytes.Buffer + fmt.Fprintf(&buf, f, args...) + if buf.Len() == 0 { + return } - if err := syscall.Flock(int(f.Fd()), syscall.LOCK_EX); err != nil { - return err + b := buf.Bytes() + if b[len(b)-1] != '\n' { + buf.WriteByte('\n') } - fmt.Fprintf(f, "%d", os.Getpid()) - defer func() { - os.Remove(f.Name()) - f.Close() - }() + buf.WriteTo(noticeOutput) } - for _, addr := range s.ports { + generation := 0 // This is SERVER_STARTER_GENERATION + os.Setenv(`SERVER_STARTER_GENERATION`, `0`) + + cleanupCh := make(chan func()) + ctx = context.WithValue(ctx, registerCleanupKey, func(f func()) { + cleanupCh <- f + }) + go cleanup(ctx, cleanupCh) + + // start listening + extraFiles := make([]*os.File, 0, len(ports)+len(paths)) + portSpecs := make([]string, 0, len(ports)+len(paths)) + for _, addr := range ports { var l net.Listener host, port, err := parsePortSpec(addr) if err != nil { - fmt.Fprintf(os.Stderr, "failed to parse addr spec '%s': %s", addr, err) + notice("failed to parse addr spec '%s': %s", addr, err) return err } hostport := fmt.Sprintf("%s:%d", host, port) l, err = net.Listen("tcp4", hostport) if err != nil { - fmt.Fprintf(os.Stderr, "failed to listen to %s:%s\n", hostport, err) + notice("failed to listen to %s:%s\n", hostport, err) return err } @@ -255,324 +242,363 @@ func (s *Starter) Run() error { } else { spec = fmt.Sprintf("%s:%d", host, port) } - s.listeners = append(s.listeners, listener{listener: l, spec: spec}) + f, err := l.(*net.TCPListener).File() + if err != nil { + return errors.Wrap(err, "failed to get fd from listener") + } + registerCleanup(ctx, func() { f.Close() }) + extraFiles = append(extraFiles, f) + portSpecs = append(portSpecs, fmt.Sprintf("%s=%d", spec, len(portSpecs)+3)) + listeners = append(listeners, listener{listener: l, spec: spec}) } - for _, path := range s.paths { + for _, path := range paths { var l net.Listener if fl, err := os.Lstat(path); err == nil && fl.Mode()&os.ModeSocket == os.ModeSocket { - fmt.Fprintf(os.Stderr, "removing existing socket file:%s\n", path) + notice("removing existing socket file:%s\n", path) err = os.Remove(path) if err != nil { - fmt.Fprintf(os.Stderr, "failed to remove existing socket file:%s:%s\n", path, err) + notice("failed to remove existing socket file:%s:%s\n", path, err) return err } } _ = os.Remove(path) l, err := net.Listen("unix", path) if err != nil { - fmt.Fprintf(os.Stderr, "failed to listen file:%s:%s\n", path, err) + notice("failed to listen file:%s:%s\n", path, err) return err } - s.listeners = append(s.listeners, listener{listener: l, spec: path}) + f, err := l.(*net.UnixListener).File() + if err != nil { + return errors.Wrap(err, "failed to get fd from listener") + } + registerCleanup(ctx, func() { f.Close() }) + extraFiles = append(extraFiles, f) + portSpecs = append(portSpecs, fmt.Sprintf("%s=%d", path, len(portSpecs)+3)) + listeners = append(listeners, listener{listener: l, spec: path}) } - s.generation = 0 - os.Setenv("SERVER_STARTER_GENERATION", fmt.Sprintf("%d", s.generation)) - - // XXX Not portable - sigCh := make(chan os.Signal, 1) - signal.Notify(sigCh, - syscall.SIGHUP, - syscall.SIGINT, - syscall.SIGTERM, - syscall.SIGQUIT, - ) - - // Okay, ready to launch the program now... - setEnv() - workerCh := make(chan processState) - p := s.StartWorker(sigCh, workerCh) - oldWorkers := make(map[int]int) - var sigReceived os.Signal - var sigToSend os.Signal - - statusCh := make(chan map[int]int) - go func(fn string, ch chan map[int]int) { - for wmap := range ch { - if fn == "" { - continue - } + os.Setenv("SERVER_STARTER_PORT", strings.Join(portSpecs, ";")) - f, err := os.OpenFile(fn, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0644) + // Note: environment variables that are set after this + // will NOT be re-populated + envLoader := envload.New() + + var statusFileCreated bool + defer func() { + if statusFileCreated { + os.Remove(statusFile) + } + }() + var currentWorker int // pid + var lastRestartTime time.Time + oldWorkers := map[int]int{} // pid to generation + + var updateStatus func() error + switch fn := statusFile; fn { + case "": + updateStatus = func() error { return nil } + default: + updateStatus = func() error { + tmpfn := fn + "." + strconv.Itoa(os.Getpid()) + f, err := os.OpenFile(tmpfn, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0644) if err != nil { - continue + return errors.Wrapf(err, "failed to create temporary file:%s", fn) } - - for gen, pid := range wmap { - fmt.Fprintf(f, "%d:%d\n", gen, pid) + statusFileCreated = true + m := map[int]int{} + for k, v := range oldWorkers { + m[k] = v + } + if currentWorker > 0 { + m[generation] = currentWorker } + keys := make([]int, 0, len(oldWorkers)+1) + for k := range oldWorkers { + keys = append(keys, k) + } + sort.Ints(keys) + for _, k := range keys { + fmt.Fprintf(f, "%d:%d\n", k, m[k]) + } f.Close() + return errors.Wrapf(os.Rename(tmpfn, fn), "failed to rename %s to %s", fn, tmpfn) } - }(s.statusFile, statusCh) + } - defer func() { - if p != nil { - oldWorkers[p.Pid] = s.generation - } - - fmt.Fprintf(os.Stderr, "received %s, sending %s to all workers:", - signame(sigReceived), - signame(sigToSend), - ) - size := len(oldWorkers) - i := 0 - for pid := range oldWorkers { - i++ - fmt.Fprintf(os.Stderr, "%d", pid) - if i < size { - fmt.Fprintf(os.Stderr, ",") - } + // This watcher receives commands to watch for. + workerSrc := make(chan *exec.Cmd) + workerDone := make(chan *exec.Cmd) + go monitor(ctx, workerSrc, workerDone) + + // signal handler here queues up signals to the other + // channel, so that we can keep accepting signals while we + // only really handle them once per loop + sigCh := make(chan os.Signal, 32) + go acceptSignals(ctx, sigCh) + + errTryExec := errors.New("keep trying") + startCmd := func(cmd *exec.Cmd) error { + if err := cmd.Start(); err != nil { + notice("%s", err.Error()) + // We would LOVE to continue immediately, but we need to do the + // same check-for-signals and etc here, so we go on.. + } else { + notice("starting new worker %d", cmd.Process.Pid) } - fmt.Fprintf(os.Stderr, "\n") - for pid := range oldWorkers { - worker, err := os.FindProcess(pid) - if err != nil { - continue + // Wait for up to `interval` seconds before + // checking if this command (process) is alive + time.Sleep(interval) + + // Check if we have received any signals while we were + // waiting. this is a very dirty trick in that we are + // mucking with a channel that is potentially being written + // to concurrently :/ + nonhup := 0 + var bufferedSigs []os.Signal + l := len(sigCh) + for i := 0; i < l; i++ { + s := <-sigCh + bufferedSigs = append(bufferedSigs, s) + if s != os.Signal(syscall.SIGHUP) { + // do not immediately stop... read all + nonhup++ } - worker.Signal(sigToSend) } - - for len(oldWorkers) > 0 { - st := <-workerCh - fmt.Fprintf(os.Stderr, "worker %d died, status:%d\n", st.Pid(), grabExitStatus(st)) - delete(oldWorkers, st.Pid()) + if len(bufferedSigs) > 0 { + go func() { + for _, s := range bufferedSigs { + sigCh <- s + } + }() + if nonhup > 0 { // bailout + return errors.New("received signal while waiting") + } } - fmt.Fprintf(os.Stderr, "exiting\n") - }() - - // var lastRestartTime time.Time - for { // outer loop - setEnv() - // Just wait for the worker to exit, or for us to receive a signal - for { - // restart = 2: force restart - // restart = 1 and no workers: force restart - // restart = 0: no restart - restart := 0 - - select { - case st := <-workerCh: - // oops, the worker exited? check for its pid - if p.Pid == st.Pid() { // current worker - exitSt := grabExitStatus(st) - fmt.Fprintf(os.Stderr, "worker %d died unexpectedly with status %d, restarting\n", p.Pid, exitSt) - p = s.StartWorker(sigCh, workerCh) - // lastRestartTime = time.Now() - } else { - exitSt := grabExitStatus(st) - fmt.Fprintf(os.Stderr, "old worker %d died, status:%d\n", st.Pid(), exitSt) - delete(oldWorkers, st.Pid()) - } - case sigReceived = <-sigCh: - // Temporary fix - switch sigReceived { - case syscall.SIGHUP: - // When we receive a HUP signal, we need to spawn a new worker - fmt.Fprintf(os.Stderr, "received HUP (num_old_workers=TODO)\n") - restart = 1 - sigToSend = s.signalOnHUP - case syscall.SIGTERM: - sigToSend = s.signalOnTERM - return nil - default: - sigToSend = syscall.SIGTERM + // Want to check if the given PID is still alive. + // This is not a great way to do it b/c we're not + // even sure the Pid we're looking for is the same + // process as the one we spawned, but... this is + // so far the best we can do + // Note: Does this work on windows? + if cmd.Process != nil { + p, err := os.FindProcess(cmd.Process.Pid) + if err == nil { + if err := p.Signal(os.Signal(syscall.Signal(0))); err == nil { return nil } } + } - if restart > 1 || restart > 0 && len(oldWorkers) == 0 { - fmt.Fprintf(os.Stderr, "spawning a new worker (num_old_workers=TODO)\n") - oldWorkers[p.Pid] = s.generation - p = s.StartWorker(sigCh, workerCh) - fmt.Fprintf(os.Stderr, "new worker is now running, sending %s to old workers:", signame(sigToSend)) - size := len(oldWorkers) - if size == 0 { - fmt.Fprintf(os.Stderr, "none\n") - } else { - i := 0 - for pid := range oldWorkers { - i++ - fmt.Fprintf(os.Stderr, "%d", pid) - if i < size { - fmt.Fprintf(os.Stderr, ",") - } - } - fmt.Fprintf(os.Stderr, "\n") + switch { + case cmd.ProcessState != nil: + notice("new worker %d seems to have failed to start, exit status:%d", cmd.ProcessState.Pid(), grabExitStatus(cmd.ProcessState)) + case cmd.Process != nil: + notice("new worker %d seems to have failed to start", cmd.Process.Pid) + default: + notice("new worker seems to have failed to start") + } + return errTryExec + } - killOldDelay := getKillOldDelay() - fmt.Fprintf(os.Stderr, "sleep %d secs\n", int(killOldDelay/time.Second)) - if killOldDelay > 0 { - time.Sleep(killOldDelay) - } + if pidFile != "" { + f, err := os.OpenFile(pidFile, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0644) + if err != nil { + return errors.Wrapf(err, "failed to open file:%s", pidFile) + } + defer f.Close() + defer os.Remove(f.Name()) - fmt.Fprintf(os.Stderr, "killing old workers\n") + if err := syscall.Flock(int(f.Fd()), syscall.LOCK_EX); err != nil { + return errors.Wrapf(err, "flock failed(%s)", pidFile) + } + fmt.Fprintf(f, "%d\n", os.Getpid()) + if err := f.Sync(); err != nil { + return errors.Wrapf(err, "failed to sync file(%s)", pidFile) + } + } - for pid := range oldWorkers { - worker, err := os.FindProcess(pid) - if err != nil { - continue - } - worker.Signal(s.signalOnHUP) - } - } - } + newCommand := func() *exec.Cmd { + cmd := exec.Command(s.command, cmdArgs...) + if dir != "" { + cmd.Dir = dir } + cmd.Stdout = logStdout + cmd.Stderr = logStderr + cmd.ExtraFiles = extraFiles + return cmd } - return nil -} + startWorker := func() error { + for loop := true; loop; { + generation++ + os.Setenv(`SERVER_STARTER_GENERATION`, strconv.Itoa(generation)) + + cmd := newCommand() + switch err := startCmd(cmd); err { + case nil: + loop = false + currentWorker = cmd.Process.Pid + lastRestartTime = time.Now() + updateStatus() + workerSrc <- cmd + case errTryExec: + // keep trying + default: + return errors.Wrap(err, "failed to start command") + } + } -func getKillOldDelay() time.Duration { - // Ignore errors. - delay, _ := strconv.ParseInt(os.Getenv("KILL_OLD_DELAY"), 10, 0) - autoRestart, _ := strconv.ParseBool(os.Getenv("ENABLE_AUTO_RESTART")) - if autoRestart && delay == 0 { - delay = 5 + return nil } - return time.Duration(delay) * time.Second -} - -type WorkerState int + var cleanupWorkers = func(sig os.Signal) { + termSig := os.Signal(syscall.SIGTERM) + if sig == termSig { + termSig = sigonterm + } -const ( - WorkerStarted WorkerState = iota - ErrFailedToStart -) + if currentWorker > 0 { + oldWorkers[currentWorker] = envAsInt(`SERVER_STARTER_GENERATION`) + currentWorker = 0 + } + var buf bytes.Buffer + fmt.Fprintf(&buf, "received %s, sending %s to all workers:", signame(sig), signame(termSig)) + keys := make([]int, 0, len(oldWorkers)) + for k := range oldWorkers { + keys = append(keys, k) + } + sort.Ints(keys) + for i, k := range keys { + fmt.Fprintf(&buf, "%d", k) + if i < len(keys)-1 { + buf.WriteByte(',') + } + } + notice(buf.String()) -// StartWorker starts the actual command. -func (s *Starter) StartWorker(sigCh chan os.Signal, ch chan processState) *os.Process { - // Don't give up until we're running. - for { - pid := -1 - cmd := exec.Command(s.command, s.args...) - if s.dir != "" { - cmd.Dir = s.dir - } - cmd.Stdout = os.Stdout - cmd.Stderr = os.Stderr - - // This whole section here basically sets up the env - // var and the file descriptors that are inherited by the - // external process - files := make([]*os.File, len(s.ports)+len(s.paths)) - ports := make([]string, len(s.ports)+len(s.paths)) - for i, l := range s.listeners { - // file descriptor numbers in ExtraFiles turn out to be - // index + 3, so we can just hard code it - var f *os.File - var err error - switch l.listener.(type) { - case *net.TCPListener: - f, err = l.listener.(*net.TCPListener).File() - case *net.UnixListener: - f, err = l.listener.(*net.UnixListener).File() - default: - panic("Unknown listener type") + for _, pid := range keys { + p, err := os.FindProcess(pid) + if err != nil { // XXX to be safe, let's delete this pid + delete(oldWorkers, pid) } - if err != nil { - panic(err) + p.Signal(termSig) + } + + for len(oldWorkers) > 0 { + cmd, ok := <-workerDone + if !ok { + panic("workerDone channel closed while still waiting for children to be reaped") } - defer f.Close() - ports[i] = fmt.Sprintf("%s=%d", l.spec, i+3) - files[i] = f + notice("worker %d died, status:%d", cmd.ProcessState.Pid(), grabExitStatus(cmd.ProcessState)) + delete(oldWorkers, cmd.ProcessState.Pid()) + updateStatus() } - cmd.ExtraFiles = files + } - s.generation++ - os.Setenv("SERVER_STARTER_PORT", strings.Join(ports, ";")) - os.Setenv("SERVER_STARTER_GENERATION", fmt.Sprintf("%d", s.generation)) + if err := startWorker(); err != nil { + return errors.Wrap(err, "failed to start worker") + } - // Now start! - if err := cmd.Start(); err != nil { - fmt.Fprintf(os.Stderr, "failed to exec %s: %s\n", cmd.Path, err) - } else { - // Save pid... - pid = cmd.Process.Pid - fmt.Fprintf(os.Stderr, "starting new worker %d\n", pid) - - // Wait for interval before checking if the process is alive - tch := time.After(s.interval) - sigs := []os.Signal{} - for loop := true; loop; { - select { - case <-tch: - // bail out - loop = false - case sig := <-sigCh: - sigs = append(sigs, sig) - } + for { + // wait for next signal (or when auto-restart becomes necessary) + exited := wait(ctx, sigCh, workerDone) + + // reload env if necessary + envLoader.Restore(envload.WithLoadEnvdir(true)) + + if envAsBool(`ENABLE_AUTO_RESTART`) { + if os.Getenv("AUTO_RESTART_INTERVAL") == "" { + os.Setenv("AUTO_RESTART_INTERVAL", "360") } + } - // if received any signals, during the wait, we bail out - gotSig := false - if len(sigs) > 0 { - for _, sig := range sigs { - // we need to resend these signals so it can be caught in the - // main routine... - go func() { sigCh <- sig }() - if sysSig, ok := sig.(syscall.Signal); ok { - if sysSig != syscall.SIGHUP { - gotSig = true - } - } + if exited != nil { // got some command exit + pid := exited.ProcessState.Pid() + if pid == currentWorker { + notice("worker %d died unexpectedly with status: %d, restarting\n", pid, grabExitStatus(exited.ProcessState)) + if err := startWorker(); err != nil { + return errors.Wrap(err, "failed to start worker") } + } else { + notice("old worker %d died, status:%d", pid, grabExitStatus(exited.ProcessState)) + delete(oldWorkers, pid) + updateStatus() } + } - // Check if we can find a process by its pid - p, err := os.FindProcess(pid) - if gotSig || err == nil { - // No error? We were successful! Make sure we capture - // the program exiting - go func() { - err := cmd.Wait() - if err != nil { - ch <- err.(*exec.ExitError).ProcessState - } else { - ch <- &dummyProcessState{pid: pid, status: successStatus} - } - }() - // Bail out - return p + var restart bool + for loop := true; loop; { + select { + case sig := <-sigCh: + switch sig { + case syscall.SIGHUP: + notice("received HUP, spawning a new worker") + restart = true + loop = false + case syscall.SIGALRM: + loop = false + default: + cleanupWorkers(sig) + return nil + } + default: + loop = false } - } - // If we fall through here, we prematurely exited :/ - // Make sure to wait to release resources - cmd.Wait() - for _, f := range cmd.ExtraFiles { - f.Close() + + if !restart && envAsBool("ENABLE_AUTO_RESTART") { + autoRestartInterval := envAsDuration("AUTO_RESTART_INTERVAL") + elapsedSinceRestart := time.Since(lastRestartTime) + if elapsedSinceRestart >= autoRestartInterval && len(oldWorkers) == 0 { + notice("autorestart triggered (interval=%s)", autoRestartInterval) + restart = true + } else if elapsedSinceRestart >= autoRestartInterval*2 { + notice("autorestart triggered (forced, interval=%s)", autoRestartInterval) + } } - fmt.Fprintf(os.Stderr, "new worker %d seems to have failed to start\n", pid) - } + if restart { + oldWorkers[currentWorker] = generation + if err := startWorker(); err != nil { + return errors.Wrap(err, "failed to restart worker") + } - // never reached - return nil -} + var buf bytes.Buffer + l := len(oldWorkers) + if l == 0 { + buf.WriteString("none") + } else { + i := 0 + for pid := range oldWorkers { + buf.WriteString(strconv.Itoa(pid)) + if i < l-1 { + buf.WriteByte(',') + } + } + } + notice("new worker is now running, sending %s to old workers: %s", signame(sigonhup), buf.String()) -func (s *Starter) Teardown() error { - if s.statusFile != "" { - os.Remove(s.statusFile) - } + killOldDelay := envAsDuration(`KILL_OLD_DELAY`) + if killOldDelay == 0 { + if envAsBool(`ENABLE_AUTO_RESTART`) { + killOldDelay = 5 * time.Second + } else { + killOldDelay = 1 + } + } - for _, l := range s.listeners { - l.listener.Close() + time.Sleep(killOldDelay) + for pid := range oldWorkers { + worker, err := os.FindProcess(pid) + if err != nil { + continue + } + worker.Signal(sigonhup) + } + } } - - return nil } diff --git a/starter_test.go b/starter_test.go index 83d467d..5be6e42 100644 --- a/starter_test.go +++ b/starter_test.go @@ -1,162 +1,296 @@ package starter import ( + "bytes" + "context" "fmt" "io" "io/ioutil" - "log" "net" "os" "os/exec" "path/filepath" - "regexp" - "strings" + "strconv" "syscall" "testing" "time" + + envload "github.com/lestrrat/go-envload" + tcputil "github.com/lestrrat/go-tcputil" + "github.com/pkg/errors" + "github.com/stretchr/testify/assert" ) -var echoServerTxt = `package main +var echoServerSrc = `package main import ( + "flag" "fmt" "io" "net/http" "os" "os/signal" "syscall" - "time" "github.com/lestrrat/go-server-starter/listener" ) func main() { + fmt.Fprintf(os.Stderr, "Starting echod (%d)\n", os.Getpid()) + var maxSigusr1 int // number of times we "withstand" a sigusr1 + flag.IntVar(&maxSigusr1, "sigusr1", 0, "") + flag.Parse() + listeners, err := listener.ListenAll() if err != nil { fmt.Fprintf(os.Stderr, "Failed to listen: %s\n", err) os.Exit(1) } + defer func() { + for _, l := range listeners { + l.Close() + } + }() handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { io.Copy(w, r.Body) }) for _, l := range listeners { - http.Serve(l, handler) + go http.Serve(l, handler) } - loop := false - sigCh := make(chan os.Signal) - signal.Notify(sigCh, syscall.SIGTERM, syscall.SIGHUP) - for loop { + fmt.Fprintf(os.Stderr, "echod: Waiting for signal (max USR1 = %d)\n", maxSigusr1) + sigCh := make(chan os.Signal, 1) + signal.Notify(sigCh, syscall.SIGTERM, syscall.SIGHUP, syscall.SIGUSR1) + + sigusr1 := 0 + for loop := true; loop; { select { - case <-sigCh: - loop = false - default: - time.Sleep(time.Second) + case s := <-sigCh: + fmt.Fprintf(os.Stderr, "echod: received %s\n", s) + switch s { + case syscall.SIGUSR1: + sigusr1++ + if maxSigusr1 > sigusr1 { + fmt.Fprintf(os.Stderr, "echod: got USR1, ignoring (max = %d, count = %d)\n", maxSigusr1, sigusr1) + } else { + fmt.Fprintf(os.Stderr, "echod: reached max USR1 limit (%d)\n", maxSigusr1) + loop = false + } + case syscall.SIGTERM: + loop = false + default: + // do nothing + } } } } ` -type config struct { - args []string - command string - dir string - interval int - pidfile string - ports []string - paths []string - sigonhup string - sigonterm string - statusfile string -} - -func (c config) Args() []string { return c.args } -func (c config) Command() string { return c.command } -func (c config) Dir() string { return c.dir } -func (c config) Interval() time.Duration { return time.Duration(c.interval) * time.Second } -func (c config) PidFile() string { return c.pidfile } -func (c config) Ports() []string { return c.ports } -func (c config) Paths() []string { return c.paths } -func (c config) SignalOnHUP() os.Signal { return SigFromName(c.sigonhup) } -func (c config) SignalOnTERM() os.Signal { return SigFromName(c.sigonterm) } -func (c config) StatusFile() string { return c.statusfile } - -func TestRun(t *testing.T) { - dir, err := ioutil.TempDir("", fmt.Sprintf("server-starter-test-%d", os.Getpid())) +func build(name string, src string) (string, func(), error) { + dir, err := ioutil.TempDir("", fmt.Sprintf("server-starter-test-%s-%d", name, os.Getpid())) if err != nil { - t.Errorf("Failed to create temp directory: %s", err) - return + return "", nil, errors.Wrapf(err, "failed to create tempdir %s", dir) } - defer os.RemoveAll(dir) - - srcFile := filepath.Join(dir, "echod.go") + cleanup := func() { + os.RemoveAll(dir) + } + srcFile := filepath.Join(dir, fmt.Sprintf("%s.go", name)) f, err := os.OpenFile(srcFile, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0666) if err != nil { - t.Errorf("Failed to create %s: %s", srcFile, err) - return + return "", cleanup, errors.Wrapf(err, "failed to create source file %s", f) } - io.WriteString(f, echoServerTxt) + io.WriteString(f, src) f.Close() - cmd := exec.Command("go", "build", "-o", filepath.Join(dir, "echod"), ".") + result := filepath.Join(dir, name) + cmd := exec.Command("go", "build", "-o", result, ".") + cmd.Env = nil cmd.Dir = dir - if output, err := cmd.CombinedOutput(); err != nil { - t.Errorf("Failed to compile %s: %s\n%s", dir, err, output) - return + output, err := cmd.CombinedOutput() + if err != nil { + return "", cleanup, errors.Wrapf(err, "failed to compile %s: %s", name, output) } + return result, cleanup, nil +} - ports := []string{"9090", "8080"} - sd, err := NewStarter(&config{ - ports: ports, - command: filepath.Join(dir, "echod"), - }) - if err != nil { - t.Errorf("Failed to create starter: %s", err) +func TestRun(t *testing.T) { + l := envload.New() + defer l.Restore() + + cmdname, cleanup, err := build("echod", echoServerSrc) + if cleanup != nil { + defer cleanup() + } + if !assert.NoError(t, err, "build failed") { return } - doneCh := make(chan struct{}) - readyCh := make(chan struct{}) - go func() { - defer func() { doneCh <- struct{}{} }() - time.AfterFunc(500*time.Millisecond, func() { - readyCh <- struct{}{} - }) - if err := sd.Run(); err != nil { - t.Errorf("sd.Run() failed: %s", err) - } - t.Logf("Exiting...") - }() + t.Run("normal execution", func(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + portCount := 2 + ports := make([]string, portCount) + for i := 0; i < 2; i++ { + p, err := tcputil.EmptyPort() + if !assert.NoError(t, err, "failed to find an empty port") { + return + } + ports[i] = strconv.Itoa(p) + } + var output bytes.Buffer + defer func() { + t.Logf("%s", output.String()) + }() + sd := New(cmdname, WithPorts(ports), WithNoticeOutput(&output)) + + done := make(chan struct{}) + go func() { + defer close(done) + if !assert.NoError(t, sd.Run(ctx), "Run should exit with no errors") { + return + } + }() - <-readyCh + tick := time.NewTicker(500 * time.Millisecond) + defer tick.Stop() - for _, port := range ports { - _, err := net.Dial("tcp", fmt.Sprintf("127.0.0.1:%s", port)) - if err != nil { - t.Errorf("Error connecing to port '%s': %s", port, err) + ctx, cancel2 := context.WithTimeout(ctx, 5*time.Second) + defer cancel2() + for loop := true; loop; { + select { + case <-ctx.Done(): + t.Errorf("Error connecing: %s", ctx.Err()) + return + case <-tick.C: + ok := 0 + for _, port := range ports { + _, err := net.Dial("tcp", fmt.Sprintf("127.0.0.1:%s", port)) + if err == nil { + t.Logf("Successfully connected to port %s", port) + ok++ + } + } + if ok == len(ports) { + loop = false + } + } } - } - time.AfterFunc(time.Second, sd.Stop) - <-doneCh + time.Sleep(time.Second) - log.Printf("Checking ports...") + var closed bool + select { + case <-done: + // grr, if we got here, done is closed + closed = true + default: + } - patterns := make([]string, len(ports)) - for i, port := range ports { - patterns[i] = fmt.Sprintf(`%s=\d+`, port) - } - pattern := regexp.MustCompile(strings.Join(patterns, ";")) + if !closed { + p, _ := os.FindProcess(os.Getpid()) + p.Signal(os.Signal(syscall.SIGTERM)) + } - if envPort := os.Getenv("SERVER_STARTER_PORT"); !pattern.MatchString(envPort) { - t.Errorf("SERVER_STARTER_PORT: Expected '%s', but got '%s'", pattern, envPort) - } + <-done + }) + l.Restore() + + t.Run("send multiple signals", func(t *testing.T) { + // Note: this test does NOT test that the same echod server has received + // signals, because doing that intelligently would require rpc between + // this test code and the echod, and I really am in no mood to do it + // for now. However, visually it looks like it's doing the right job + ctx, cancel := context.WithTimeout(context.Background(), 20*time.Second) + defer cancel() + + port, err := tcputil.EmptyPort() + if !assert.NoError(t, err, "failed to find an empty port") { + return + } + + var output bytes.Buffer + defer func() { + t.Logf("%s", output.String()) + }() + sd := New(cmdname, + WithArgs("--sigusr1=2"), + WithPorts([]string{strconv.Itoa(port)}), + WithNoticeOutput(&output), + WithLogStdout(&output), + WithLogStderr(&output), + WithSignalOnHUP(syscall.SIGUSR1), + ) + + done := make(chan struct{}) + go func() { + defer close(done) + if !assert.NoError(t, sd.Run(ctx), "Run should exit with no errors") { + return + } + }() + + time.Sleep(2 * time.Second) + + var closed bool + select { + case <-done: + closed = true + t.Errorf("unexpected exit") + default: + } + + if !closed { + p, _ := os.FindProcess(os.Getpid()) + p.Signal(syscall.SIGHUP) + } + + time.Sleep(2 * time.Second) + + closed = false + select { + case <-done: + closed = true + t.Errorf("unexpected exit") + default: + } + + if !closed { + p, _ := os.FindProcess(os.Getpid()) + p.Signal(syscall.SIGHUP) + } + + time.Sleep(2 * time.Second) + + closed = false + select { + case <-done: + closed = true + t.Errorf("unexpected exit") + default: + } + + if !closed { + p, _ := os.FindProcess(os.Getpid()) + p.Signal(syscall.SIGTERM) + } + time.Sleep(time.Second) + + select { + case <-ctx.Done(): + t.Errorf("context prematurely ended: %s", ctx.Err()) + p, _ := os.FindProcess(os.Getpid()) + p.Signal(syscall.SIGTERM) + case <-done: + } + }) + l.Restore() } func TestSigFromName(t *testing.T) { for sig, name := range niceSigNames { - if got := SigFromName(name); sig != got { + if got := sigFromName(name); sig != got { t.Errorf("%v: wants '%v' but got '%v'", name, sig, got) } } @@ -167,7 +301,7 @@ func TestSigFromName(t *testing.T) { "Hup": syscall.SIGHUP, } for name, sig := range variants { - if got := SigFromName(name); sig != got { + if got := sigFromName(name); sig != got { t.Errorf("%v: wants '%v' but got '%v'", name, sig, got) } } diff --git a/status_unix.go b/status_unix.go new file mode 100644 index 0000000..bdfb2c3 --- /dev/null +++ b/status_unix.go @@ -0,0 +1,9 @@ +// +build !windows + +package starter + +import "syscall" + +func init() { + failureStatus = syscall.WaitStatus(0) +} diff --git a/status_windows.go b/status_windows.go new file mode 100644 index 0000000..d96f6a9 --- /dev/null +++ b/status_windows.go @@ -0,0 +1,7 @@ +package starter + +import "syscall" + +func init() { + failureStatus = syscall.WaitStatus{ExitCode: 255} +}