From c83e2326b9023bd46fee6dcd5777bd4e4f15d496 Mon Sep 17 00:00:00 2001 From: Daisuke Maki Date: Sun, 27 Nov 2016 23:26:47 +0900 Subject: [PATCH 01/26] start rewriting to match server::starter 0.32 --- cli.go | 156 +++++++ cmd/start_server/start_server.go | 154 +------ interface.go | 103 +++++ internal/env/env.go | 169 +++++++ internal/env/env_test.go | 54 +++ internal/env/interface.go | 34 ++ internal/env/options.go | 18 + monitor.go | 100 +++++ options.go | 130 ++++++ signals.go | 64 +++ starter_any.go => signals_any.go | 5 - signals_windows.go | 7 + starter.go | 123 +----- starter_test.go | 72 ++- status_any.go | 10 + starter_windows.go => status_windows.go | 6 +- worker.go | 557 ++++++++++++++++++++++++ 17 files changed, 1444 insertions(+), 318 deletions(-) create mode 100644 cli.go create mode 100644 interface.go create mode 100644 internal/env/env.go create mode 100644 internal/env/env_test.go create mode 100644 internal/env/interface.go create mode 100644 internal/env/options.go create mode 100644 monitor.go create mode 100644 options.go create mode 100644 signals.go rename starter_any.go => signals_any.go (86%) create mode 100644 signals_windows.go create mode 100644 status_any.go rename starter_windows.go => status_windows.go (58%) create mode 100644 worker.go diff --git a/cli.go b/cli.go new file mode 100644 index 0000000..ac39f4c --- /dev/null +++ b/cli.go @@ -0,0 +1,156 @@ +package starter + +import ( + "context" + "errors" + "fmt" + "os" + "reflect" + "strings" + "time" + + flags "github.com/jessevdk/go-flags" +) + +func NewCLI() *CLI { + return &CLI{} +} + +func makeOptionList(opts *options) []Option { + var list []Option + if len(opts.Args) > 0 { + list = append(list, WithArgs(opts.Args)) + } + if opts.AutoRestartInterval.Valid { + list = append(list, WithAutoRestartInterval(time.Duration(opts.AutoRestartInterval.Value)*time.Second)) + } + if opts.Dir != "" { + list = append(list, WithDir(opts.Dir)) + } + if opts.EnableAutoRestart.Valid { + list = 
append(list, WithAutoRestart(opts.EnableAutoRestart.Value)) + } + if opts.Envdir.Valid { + list = append(list, WithEnvdir(opts.Envdir.Value)) + } + if opts.Interval > -1 { + list = append(list, WithInterval(time.Duration(opts.Interval) * time.Second)) + } + if opts.KillOldDelay.Valid { + list = append(list, WithKillOldDelay(time.Duration(opts.KillOldDelay.Value)*time.Second)) + } + if len(opts.Paths) > 0 { + list = append(list, WithPaths(opts.Paths)) + } + if opts.PidFile != "" { + list = append(list, WithPidFile(opts.PidFile)) + } + if len(opts.Ports) > 0 { + list = append(list, WithPorts(opts.Ports)) + } + if opts.SignalOnHUP != "" { + list = append(list, WithSignalOnHUP(SigFromName(opts.SignalOnHUP))) + } + if opts.SignalOnTERM != "" { + list = append(list, WithSignalOnTERM(SigFromName(opts.SignalOnTERM))) + } + if opts.StatusFile != "" { + list = append(list, WithStatusFile(opts.StatusFile)) + } + return list +} +func (cli *CLI) Run(ctx context.Context) error { + var opts options + opts.Interval = -1 // allow 0 + p := flags.NewParser(&opts, flags.PrintErrors|flags.PassDoubleDash) + args, err := p.Parse() + if err != nil || opts.Help { + showHelp() + return nil + } + + if opts.Version { + fmt.Printf("%s\n", version) + return nil + } + + if opts.Interval <= 0 { + opts.Interval = 1 + } + + if len(args) == 0 { + return errors.New("server program not specified") + } + + opts.Command = args[0] + if len(args) > 1 { + opts.Args = args[1:] + } + + s := New(opts.Command, makeOptionList(&opts)...) + return s.Run(ctx) +} + +func showHelp() { + // The ONLY reason we're not using go-flags' help option is + // because I wanted to tweak the format just a bit... but + // there wasn't an easy way to do so + os.Stderr.WriteString(` +Usage: + start_server [options] -- server-prog server-arg1 server-arg2 ... 
+ + # start Plack using Starlet listening at TCP port 8000 + start_server --port=8000 -- plackup -s Starlet --max-workers=100 index.psgi + +Options: +`) + + t := reflect.TypeOf(options{}) + + // This weird indexing stuff is done purely to keep ourselves + // compatible with the original start_server program + // (This is the order that the help is displayed in) + names := []string{ + "Ports", + "Paths", + "Dir", + "Interval", + "SignalOnHUP", + "SignalOnTERM", + "PidFile", + "StatusFile", + "Envdir", + "EnableAutoRestart", + "AutoRestartInterval", + "KillOldDelay", + "Restart", + "Help", + "Version", + } + + for _, name := range names { + f, ok := t.FieldByName(name) + if !ok { + continue + } + + tag := f.Tag + if tag == "" { + continue + } + if s := tag.Get("long"); s != "" { + fmt.Fprintf(os.Stderr, " --%s", s) + if a := tag.Get("arg"); a != "" { + fmt.Fprintf(os.Stderr, "=%s", a) + } + if tag.Get("note") == "unimplemented" { + fmt.Fprintf(os.Stderr, " (UNIMPLEMENTED)") + } + fmt.Fprintf(os.Stderr, ":\n") + } + for _, l := range strings.Split(tag.Get("description"), "\n") { + fmt.Fprintf(os.Stderr, " %s\n", l) + } + fmt.Fprintf(os.Stderr, "\n") + } +} diff --git a/cmd/start_server/start_server.go b/cmd/start_server/start_server.go index 69dc013..008030a 100644 --- a/cmd/start_server/start_server.go +++ b/cmd/start_server/start_server.go @@ -1,161 +1,17 @@ package main import ( + "context" "fmt" "os" - "reflect" - "strings" - "time" - "github.com/jessevdk/go-flags" "github.com/lestrrat/go-server-starter" ) -const version = "0.0.2" - -type options struct { - OptArgs []string - OptCommand string - OptDir string `long:"dir" arg:"path" description:"working directory, start_server do chdir to before exec (optional)"` - OptInterval int `long:"interval" arg:"seconds" description:"minimum interval (in seconds) to respawn the server program (default: 1)"` - OptPorts []string `long:"port" arg:"(port|host:port)" description:"TCP port to listen to (if omitted, will not bind 
to any ports)"` - OptPaths []string `long:"path" arg:"path" description:"path at where to listen using unix socket (optional)"` - OptSignalOnHUP string `long:"signal-on-hup" arg:"Signal" description:"name of the signal to be sent to the server process when start_server\nreceives a SIGHUP (default: TERM). If you use this option, be sure to\nalso use '--signal-on-term' below."` - OptSignalOnTERM string `long:"signal-on-term" arg:"Signal" description:"name of the signal to be sent to the server process when start_server\nreceives a SIGTERM (default: TERM)"` - OptPidFile string `long:"pid-file" arg:"filename" description:"if set, writes the process id of the start_server process to the file"` - OptStatusFile string `long:"status-file" arg:"filename" description:"if set, writes the status of the server process(es) to the file"` - OptEnvdir string `long:"envdir" arg:"Envdir" description:"directory that contains environment variables to the server processes.\nIt is intended for use with \"envdir\" in \"daemontools\". This can be\noverwritten by environment variable \"ENVDIR\"."` - OptEnableAutoRestart bool `long:"enable-auto-restart" description:"enables automatic restart by time. This can be overwritten by\nenvironment variable \"ENABLE_AUTO_RESTART\"." note:"unimplemented"` - OptAutoRestartInterval int `long:"auto-restart-interval" arg:"seconds" description:"automatic restart interval (default 360). It is used with\n\"--enable-auto-restart\" option. This can be overwritten by environment\nvariable \"AUTO_RESTART_INTERVAL\"." note:"unimplemented"` - OptKillOldDelay int `long:"kill-old-delay" arg:"seconds" description:"time to suspend to send a signal to the old worker. The default value is\n5 when \"--enable-auto-restart\" is set, 0 otherwise. 
This can be\noverwritten by environment variable \"KILL_OLD_DELAY\"."` - OptRestart bool `long:"restart" description:"this is a wrapper command that reads the pid of the start_server process\nfrom --pid-file, sends SIGHUP to the process and waits until the\nserver(s) of the older generation(s) die by monitoring the contents of\nthe --status-file" note:"unimplemented"` - OptHelp bool `long:"help" description:"prints this help"` - OptVersion bool `long:"version" description:"prints the version number"` -} - -func (o options) Args() []string { return o.OptArgs } -func (o options) Command() string { return o.OptCommand } -func (o options) Dir() string { return o.OptDir } -func (o options) Interval() time.Duration { return time.Duration(o.OptInterval) * time.Second } -func (o options) PidFile() string { return o.OptPidFile } -func (o options) Ports() []string { return o.OptPorts } -func (o options) Paths() []string { return o.OptPaths } -func (o options) SignalOnHUP() os.Signal { return starter.SigFromName(o.OptSignalOnHUP) } -func (o options) SignalOnTERM() os.Signal { return starter.SigFromName(o.OptSignalOnTERM) } -func (o options) StatusFile() string { return o.OptStatusFile } - -func showHelp() { - // The ONLY reason we're not using go-flags' help option is - // because I wanted to tweak the format just a bit... but - // there wasn't an easy way to do so - os.Stderr.WriteString(` -Usage: - start_server [options] -- server-prog server-arg1 server-arg2 ... 
- - # start Plack using Starlet listening at TCP port 8000 - start_server --port=8000 -- plackup -s Starlet --max-workers=100 index.psgi - -Options: -`) - - t := reflect.TypeOf(options{}) - - // This weird indexing stuff is done purely to keep ourselves - // compatible with the original start_server program - // (This is the order that the help is displayed in) - names := []string{ - "OptPorts", - "OptPaths", - "OptDir", - "OptInterval", - "OptSignalOnHUP", - "OptSignalOnTERM", - "OptPidFile", - "OptStatusFile", - "OptEnvdir", - "OptEnableAutoRestart", - "OptAutoRestartInterval", - "OptKillOldDelay", - "OptRestart", - "OptHelp", - "OptVersion", - } - - for _, name := range names { - f, ok := t.FieldByName(name) - if !ok { - continue - } - - tag := f.Tag - if tag == "" { - continue - } - if s := tag.Get("long"); s != "" { - fmt.Fprintf(os.Stderr, " --%s", s) - if a := tag.Get("arg"); a != "" { - fmt.Fprintf(os.Stderr, "=%s", a) - } - if tag.Get("note") == "unimplemented" { - fmt.Fprintf(os.Stderr, " (UNIMPLEMENTED)") - } - fmt.Fprintf(os.Stderr, ":\n") - } - for _, l := range strings.Split(tag.Get("description"), "\n") { - fmt.Fprintf(os.Stderr, " %s\n", l) - } - fmt.Fprintf(os.Stderr, "\n") - } -} - func main() { - os.Exit(_main()) -} - -func _main() (st int) { - st = 1 - - opts := &options{OptInterval: -1} - p := flags.NewParser(opts, flags.PrintErrors|flags.PassDoubleDash) - args, err := p.Parse() - if err != nil || opts.OptHelp { - showHelp() - return - } - - if opts.OptVersion { - fmt.Printf("%s\n", version) - st = 0 - return - } - - if opts.OptInterval < 0 { - opts.OptInterval = 1 - } - - if len(args) == 0 { - fmt.Fprintf(os.Stderr, "server program not specified\n") - return - } - - opts.OptCommand = args[0] - if len(args) > 1 { - opts.OptArgs = args[1:] - } - - if opts.OptEnvdir != "" { - os.Setenv("ENVDIR", opts.OptEnvdir) - } - - s, err := starter.NewStarter(opts) - if err != nil { - fmt.Fprintf(os.Stderr, "error: %s\n", err) - return - } - if err := 
s.Run(); err != nil { - fmt.Fprintf(os.Stderr, "error: %s\n", err) - return + cli := starter.NewCLI() + if err := cli.Run(context.Background()); err != nil { + fmt.Fprintf(os.Stderr, "%s\n", err.Error()) + os.Exit(1) } - st = 0 - return } diff --git a/interface.go b/interface.go new file mode 100644 index 0000000..cf0eebe --- /dev/null +++ b/interface.go @@ -0,0 +1,103 @@ +package starter + +import ( + "io" + "net" + "os" + "syscall" + "time" + + "github.com/lestrrat/go-server-starter/internal/env" +) + +const version = `0.0.2` + +var successStatus syscall.WaitStatus +var failureStatus syscall.WaitStatus + +type listener struct { + listener net.Listener + spec string // path or port spec +} + +type Option interface { + Name() string + Value() interface{} +} + +type Config interface { + Args() []string + Command() string + Dir() string // Directory to chdir to before executing the command + Interval() time.Duration // Time between checks for liveness + PidFile() string + Ports() []string // Ports to bind to (addr:port or port, so it's a string) + Paths() []string // Paths (UNIX domain socket) to bind to + SignalOnHUP() os.Signal // Signal to send when HUP is received + SignalOnTERM() os.Signal // Signal to send when TERM is received + StatusFile() string +} + +type Starter struct { + options []Option + interval time.Duration + envLoader *env.Loader + noticeWriter io.Writer + extraFiles []*os.File + portSpecs []string + + signalOnHUP os.Signal + signalOnTERM os.Signal + // you can't set this in go: backlog + statusFile string + pidFile string + dir string + ports []string + paths []string + listeners []listener + generation int + command string + args []string +} + +type processState interface { + Pid() int + Sys() interface{} +} + +type dummyProcessState struct { + pid int + status syscall.WaitStatus +} + +func (d dummyProcessState) Pid() int { + return d.pid +} + +func (d dummyProcessState) Sys() interface{} { + return d.status +} + +type WorkerState int + +const ( 
+ WorkerStarted WorkerState = iota + ErrFailedToStart +) + +type CLI struct{} +type boolOpt struct { + Valid bool + Value bool +} +type intOpt struct { + Valid bool + Value int +} + +type stringOpt struct { + Valid bool + Value string +} + + diff --git a/internal/env/env.go b/internal/env/env.go new file mode 100644 index 0000000..3bf7c17 --- /dev/null +++ b/internal/env/env.go @@ -0,0 +1,169 @@ +package env + +import ( + "bytes" + "context" + "io/ioutil" + "os" + "path/filepath" + "strings" +) + +func (e *sysenv) Clearenv() { + os.Clearenv() +} + +func (e *sysenv) Setenv(k, v string) { + os.Setenv(k, v) +} + +func SystemEnvironment() Environment { + return &sysenv{} +} + +func NewLoader(environ ...string) *Loader { + if len(environ) == 0 { + environ = os.Environ() + } + + var envdir string + original := make([]iterItem, 0, len(environ)) + for _, v := range environ { + i := strings.IndexByte(v, '=') + if i <= 0 || i >= len(v)-1 { + continue + } + original = append(original, iterItem{ + key: v[:i], + value: v[i+1:], + }) + if v[:i] == "ENVDIR" { + envdir = v[i+1:] + } + } + + return &Loader{ + original: original, + envdir: envdir, + } +} + +func (l *Loader) Restore(octx context.Context, e Environment) error { + return l.Apply(octx, e, WithLoadEnvdir(false)) +} + +func (l *Loader) Apply(octx context.Context, e Environment, options ...Option) error { + ctx, cancel := context.WithCancel(octx) + defer cancel() + + e.Clearenv() + iter := l.Iterator(ctx, options...) + for iter.Next() { + k, v := iter.KV() + e.Setenv(k, v) + } + + return nil +} + +func (l *Loader) Environ(octx context.Context, options ...Option) []string { + ctx, cancel := context.WithCancel(octx) + defer cancel() + + var environ []string + it := l.Iterator(ctx, options...) 
+ for it.Next() { + k, v := it.KV() + environ = append(environ, k+`=`+v) + } + return environ +} + +func (l *Loader) Iterator(ctx context.Context, options ...Option) *Iterator { + loadEnvdir := true + for _, o := range options { + switch o.Name() { + case LoadEnvdirKey: + loadEnvdir = o.Value().(bool) + } + } + + ch := make(chan *iterItem) + ex := make(chan *iterItem) + defer close(ex) + + go func(m []iterItem, ch, ex chan *iterItem) { + defer close(ch) + for _, it := range m { + select { + case <-ctx.Done(): + return + case ch <- &iterItem{key: it.key, value: it.value}: + } + } + + for { + select { + case <-ctx.Done(): + return + case it, ok := <-ex: + if !ok { + return + } + select { + case <-ctx.Done(): + return + case ch <- it: + } + } + } + }(l.original, ch, ex) + + // meanwhile, load from envdir, if available + if loadEnvdir && l.envdir != "" { + if fi, err := os.Stat(l.envdir); err == nil && fi.IsDir() { + filepath.Walk(l.envdir, func(path string, fi os.FileInfo, err error) error { + // Ignore errors + if err != nil { + return nil + } + + // Do not recurse into directories + if fi.IsDir() && l.envdir != path { + return filepath.SkipDir + } + + buf, err := ioutil.ReadFile(path) + if err != nil { + return nil + } + + ex <- &iterItem{ + key: filepath.Base(path), + value: string(bytes.TrimSpace(buf)), + } + return nil + }) + } + } + + return &Iterator{ + ch: ch, + } +} + +func (iter *Iterator) Next() bool { + iter.nextK = "" + iter.nextV = "" + pair, ok := <-iter.ch + if !ok { + return false + } + iter.nextK = pair.key + iter.nextV = pair.value + return true +} + +func (iter *Iterator) KV() (string, string) { + return iter.nextK, iter.nextV +} diff --git a/internal/env/env_test.go b/internal/env/env_test.go new file mode 100644 index 0000000..f0c99d5 --- /dev/null +++ b/internal/env/env_test.go @@ -0,0 +1,54 @@ +package env_test + +import ( + "context" + "os" + "testing" + "time" + + "github.com/lestrrat/go-server-starter/internal/env" + 
"github.com/stretchr/testify/assert" +) + +func TestIter(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + src := []string{`FOO=foo`, `BAR=bar`, `BAZ=baz`} + l := env.NewLoader(src...) + i := l.Iterator(ctx) + if !assert.NotNil(t, i, "Iterator is ok") { + return + } + + os.Setenv(`QUUX`, `quux`) // This should have no effect + var list []string + for i.Next() { + k, v := i.KV() + t.Logf("%s=%v", k, v) + list = append(list, k+"="+v) + } + + if !assert.Equal(t, src, list) { + return + } +} + +func TestEnviron(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + src := []string{`FOO=foo`, `BAR=bar`, `BAZ=baz`} + l := env.NewLoader(src...) + i := l.Iterator(ctx) + if !assert.NotNil(t, i, "Iterator is ok") { + return + } + + os.Setenv(`QUUX`, `quux`) // This should have no effect + list := l.Environ(ctx) + if !assert.Equal(t, src, list) { + return + } +} + diff --git a/internal/env/interface.go b/internal/env/interface.go new file mode 100644 index 0000000..4189eb4 --- /dev/null +++ b/internal/env/interface.go @@ -0,0 +1,34 @@ +package env + +type Loader struct { + original []iterItem + envdir string +} + +type Iterator struct { + ch chan *iterItem + nextK string + nextV string +} + +type iterItem struct { + key string + value string +} + +type Environment interface { + Clearenv() + Setenv(string, string) +} +type sysenv struct{} + +type Option interface { + Name() string + Value() interface{} +} + +const LoadEnvdirKey = "LoadEnvdirKey" +type option struct { + name string + value interface{} +} diff --git a/internal/env/options.go b/internal/env/options.go new file mode 100644 index 0000000..a344a94 --- /dev/null +++ b/internal/env/options.go @@ -0,0 +1,18 @@ +package env + +func (o *option) Name() string { + return o.name +} + +func (o *option) Value() interface {} { + return o.value +} + +// WithLoadEnvdir specifies if Loader should load the original 
+// environment variables AND the contents of envdir +func WithLoadEnvdir(b bool) Option { + return &option{ + name: LoadEnvdirKey, + value: b, + } +} diff --git a/monitor.go b/monitor.go new file mode 100644 index 0000000..b3f36bf --- /dev/null +++ b/monitor.go @@ -0,0 +1,100 @@ +package starter + +import ( + "context" + "fmt" + "os/exec" + "reflect" + "strconv" + "time" +) + +func ExampleMonitor() { + ctx, cancel := context.WithTimeout(context.Background(), 20*time.Second) + defer cancel() + + ch := make(chan *exec.Cmd, 10) + done := make(chan *exec.Cmd) + go monitor(ctx, ch, done) + + for i := 0; i < 10; i++ { + cmd := exec.Command("sleep", strconv.Itoa(i)) + cmd.Start() + ch <- cmd + } + + for { + select { + case <-ctx.Done(): + fmt.Printf("timeout reached\n") + return + case cmd, ok := <-done: + if !ok { + fmt.Println("monitor exited") + return + } + fmt.Printf("notified: %d\n", cmd.ProcessState.Pid()) + } + } +} + +// monitor is a process (i.e. *exec.Cmd) monitor. It asynchronously +// waits for either (a) one of the monitored processes to exit, or (b) +// a request to watch a new worker. 
+// +// users of this function can "wait" for the exit of a command +// by checking the done channel +func monitor(ctx context.Context, src chan *exec.Cmd, done chan *exec.Cmd) { + defer close(done) + var workers []struct { + Chan chan error + Cmd *exec.Cmd + } + for { + cases := make([]reflect.SelectCase, len(workers)+1) + for i, worker := range workers { + cases[i].Chan = reflect.ValueOf(worker.Chan) + cases[i].Dir = reflect.SelectRecv + } + cases[len(cases)-1].Chan = reflect.ValueOf(src) + cases[len(cases)-1].Dir = reflect.SelectRecv + + chosen, recv, recvOK := reflect.Select(cases) + if !recvOK { + panic("should not get here") + } + + if chosen == len(workers) { + // 1 + max worker index, so must be our "new worker chan" + cmd := recv.Interface().(*exec.Cmd) + ch := make(chan error) + go func() { + ch <- cmd.Wait() + }() + workers = append(workers, struct { + Chan chan error + Cmd *exec.Cmd + }{ + Chan: ch, + Cmd: cmd, + }) + continue + } + + exited := workers[chosen].Cmd + // one of the workers must have finished + // remove the corresponding one + switch { + case len(workers) < 2: + workers = nil + case chosen == 0: + workers = workers[1:] + case chosen == len(workers)-1: + workers = workers[:chosen] + default: + workers = append(workers[:chosen], workers[chosen+1:]...) 
+ } + + done <- exited + } +} diff --git a/options.go b/options.go new file mode 100644 index 0000000..ca7cd73 --- /dev/null +++ b/options.go @@ -0,0 +1,130 @@ +package starter + +import ( + "os" + "strconv" + "time" +) + +type valueOption struct { + name string + value interface{} +} + +func (o *valueOption) Name() string { + return o.name +} + +func (o *valueOption) Value() interface{} { + return o.value +} + +func WithAutoRestart(b bool) Option { + return &valueOption{name: "enable_auto_restart", value: b} +} + +func WithAutoRestartInterval(t time.Duration) Option { + return &valueOption{name: "auto_restart_interval", value: t} +} + +func WithArgs(a []string) Option { + return &valueOption{name: "args", value: a} +} + +func WithDir(dir string) Option { + return &valueOption{name: "dir", value: dir} +} + +func WithEnvdir(dir string) Option { + return &valueOption{name: "envdir", value: dir} +} + +func WithInterval(t time.Duration) Option { + return &valueOption{name: "interval", value: t} +} + +func WithKillOldDelay(t time.Duration) Option { + return &valueOption{name: "kill_old_interval", value: t} +} + +func WithPaths(l []string) Option { + return &valueOption{name: "paths", value: l} +} + +func WithPidFile(s string) Option { + return &valueOption{name: "pid_file", value: s} +} + +func WithPorts(l []string) Option { + return &valueOption{name: "ports", value: l} +} + +func WithSignalOnHUP(s os.Signal) Option { + return &valueOption{name: "signal_on_hup", value: s} +} + +func WithSignalOnTERM(s os.Signal) Option { + return &valueOption{name: "signal_on_term", value: s} +} + +func WithStatusFile(s string) Option { + return &valueOption{name: "status_file", value: s} +} + +func (o *stringOpt) String() string { + return o.Value +} + +func (o *stringOpt) Set(s string) error { + o.Valid = true + o.Value = s + return nil +} + +func (o *intOpt) String() string { + return strconv.FormatInt(int64(o.Value), 10) +} + +func (o *intOpt) Set(s string) error { + i, err := 
strconv.ParseInt(s, 10, 64) + if err != nil { + return err + } + o.Valid = true + o.Value = int(i) + return nil +} + +func (o *boolOpt) String() string { + return strconv.FormatBool(o.Value) +} + +func (o *boolOpt) Set(s string) error { + b, err := strconv.ParseBool(s) + if err != nil { + return err + } + o.Valid = true + o.Value = b + return nil +} + +type options struct { + Args []string + AutoRestartInterval intOpt `long:"auto-restart-interval" arg:"seconds" description:"automatic restart interval (default 360). It is used with\n\"--enable-auto-restart\" option. This can be overwritten by environment\nvariable \"AUTO_RESTART_INTERVAL\"." note:"unimplemented"` + Command string + Dir string `long:"dir" arg:"path" description:"working directory, start_server do chdir to before exec (optional)"` + EnableAutoRestart boolOpt `long:"enable-auto-restart" description:"enables automatic restart by time. This can be overwritten by\nenvironment variable \"ENABLE_AUTO_RESTART\"." note:"unimplemented"` + Envdir stringOpt `long:"envdir" arg:"Envdir" description:"directory that contains environment variables to the server processes.\nIt is intended for use with \"envdir\" in \"daemontools\". This can be\noverwritten by environment variable \"ENVDIR\"."` + Interval int `long:"interval" arg:"seconds" description:"minimum interval (in seconds) to respawn the server program (default: 1)"` + KillOldDelay intOpt `long:"kill-old-delay" arg:"seconds" description:"time to suspend to send a signal to the old worker. The default value is\n5 when \"--enable-auto-restart\" is set, 0 otherwise. 
This can be\noverwritten by environment variable \"KILL_OLD_DELAY\"."` + Paths []string `long:"path" arg:"path" description:"path at where to listen using unix socket (optional)"` + PidFile string `long:"pid-file" arg:"filename" description:"if set, writes the process id of the start_server process to the file"` + Ports []string `long:"port" arg:"(port|host:port)" description:"TCP port to listen to (if omitted, will not bind to any ports)"` + Restart bool `long:"restart" description:"this is a wrapper command that reads the pid of the start_server process\nfrom --pid-file, sends SIGHUP to the process and waits until the\nserver(s) of the older generation(s) die by monitoring the contents of\nthe --status-file" note:"unimplemented"` + SignalOnHUP string `long:"signal-on-hup" arg:"Signal" description:"name of the signal to be sent to the server process when start_server\nreceives a SIGHUP (default: TERM). If you use this option, be sure to\nalso use '--signal-on-term' below."` + SignalOnTERM string `long:"signal-on-term" arg:"Signal" description:"name of the signal to be sent to the server process when start_server\nreceives a SIGTERM (default: TERM)"` + StatusFile string `long:"status-file" arg:"filename" description:"if set, writes the status of the server process(es) to the file"` + Help bool `long:"help" description:"prints this help"` + Version bool `long:"version" description:"printes the version number"` +} diff --git a/signals.go b/signals.go new file mode 100644 index 0000000..d4c747a --- /dev/null +++ b/signals.go @@ -0,0 +1,64 @@ +package starter + +import ( + "os" + "strings" + "syscall" +) + +var niceSigNames map[syscall.Signal]string +var niceNameToSigs map[string]syscall.Signal + +func makeNiceSigNames() map[syscall.Signal]string { + m := map[syscall.Signal]string{ + syscall.SIGABRT: "ABRT", + syscall.SIGALRM: "ALRM", + syscall.SIGBUS: "BUS", + // syscall.SIGEMT: "EMT", + syscall.SIGFPE: "FPE", + syscall.SIGHUP: "HUP", + syscall.SIGILL: "ILL", + // 
syscall.SIGINFO: "INFO", + syscall.SIGINT: "INT", + // syscall.SIGIOT: "IOT", + syscall.SIGKILL: "KILL", + syscall.SIGPIPE: "PIPE", + syscall.SIGQUIT: "QUIT", + syscall.SIGSEGV: "SEGV", + syscall.SIGTERM: "TERM", + syscall.SIGTRAP: "TRAP", + } + + // addPlatformDependentNiceSigNames() is defined in the files + // containing build tags + return addPlatformDependentNiceSigNames(m) +} + +func init() { + niceSigNames = makeNiceSigNames() + niceNameToSigs = make(map[string]syscall.Signal) + for sig, name := range niceSigNames { + niceNameToSigs[name] = sig + } +} + +func signame(s os.Signal) string { + if ss, ok := s.(syscall.Signal); ok { + return niceSigNames[ss] + } + return "UNKNOWN" +} + +// SigFromName returns the signal corresponding to the given signal name string. +// If the given name string is not defined, it returns nil. +func SigFromName(n string) os.Signal { + n = strings.ToUpper(n) + if strings.HasPrefix(n, "SIG") { + n = n[3:] // remove SIG prefix + } + + if sig, ok := niceNameToSigs[n]; ok { + return sig + } + return nil +} diff --git a/starter_any.go b/signals_any.go similarity index 86% rename from starter_any.go rename to signals_any.go index c66bdc2..4bdfda8 100644 --- a/starter_any.go +++ b/signals_any.go @@ -4,11 +4,6 @@ package starter import "syscall" -func init() { - failureStatus = syscall.WaitStatus(255) - successStatus = syscall.WaitStatus(0) -} - func addPlatformDependentNiceSigNames(v map[syscall.Signal]string) map[syscall.Signal]string { v[syscall.SIGCHLD] = "CHLD" v[syscall.SIGCONT] = "CONT" diff --git a/signals_windows.go b/signals_windows.go new file mode 100644 index 0000000..c4e94ae --- /dev/null +++ b/signals_windows.go @@ -0,0 +1,7 @@ +package starter + +import "syscall" + +func addPlatformDependentNiceSigNames(v map[syscall.Signal]string) map[syscall.Signal]string { + return v +} diff --git a/starter.go b/starter.go index c65ce00..3fc6ea8 100644 --- a/starter.go +++ b/starter.go @@ -5,85 +5,20 @@ import ( "net" "os" "os/exec" - 
"os/signal" "strconv" "strings" "syscall" "time" ) -var niceSigNames map[syscall.Signal]string -var niceNameToSigs map[string]syscall.Signal -var successStatus syscall.WaitStatus -var failureStatus syscall.WaitStatus - -func makeNiceSigNamesCommon() map[syscall.Signal]string { - return map[syscall.Signal]string{ - syscall.SIGABRT: "ABRT", - syscall.SIGALRM: "ALRM", - syscall.SIGBUS: "BUS", - // syscall.SIGEMT: "EMT", - syscall.SIGFPE: "FPE", - syscall.SIGHUP: "HUP", - syscall.SIGILL: "ILL", - // syscall.SIGINFO: "INFO", - syscall.SIGINT: "INT", - // syscall.SIGIOT: "IOT", - syscall.SIGKILL: "KILL", - syscall.SIGPIPE: "PIPE", - syscall.SIGQUIT: "QUIT", - syscall.SIGSEGV: "SEGV", - syscall.SIGTERM: "TERM", - syscall.SIGTRAP: "TRAP", +func New(command string, options ...Option) *Starter { + return &Starter{ + command: command, + options: options, // This is stored as-is on purpose. + noticeWriter: os.Stderr, } } -func makeNiceSigNames() map[syscall.Signal]string { - return addPlatformDependentNiceSigNames(makeNiceSigNamesCommon()) -} - -func init() { - niceSigNames = makeNiceSigNames() - niceNameToSigs = make(map[string]syscall.Signal) - for sig, name := range niceSigNames { - niceNameToSigs[name] = sig - } -} - -type listener struct { - listener net.Listener - spec string // path or port spec -} - -type Config interface { - Args() []string - Command() string - Dir() string // Directory to chdir to before executing the command - Interval() time.Duration // Time between checks for liveness - PidFile() string - Ports() []string // Ports to bind to (addr:port or port, so it's a string) - Paths() []string // Paths (UNIX domain socket) to bind to - SignalOnHUP() os.Signal // Signal to send when HUP is received - SignalOnTERM() os.Signal // Signal to send when TERM is received - StatusFile() string -} - -type Starter struct { - interval time.Duration - signalOnHUP os.Signal - signalOnTERM os.Signal - // you can't set this in go: backlog - statusFile string - pidFile string 
- dir string - ports []string - paths []string - listeners []listener - generation int - command string - args []string -} - // NewStarter creates a new Starter object. Config parameter may NOT be // nil, as `Ports` and/or `Paths`, and `Command` are required func NewStarter(c Config) (*Starter, error) { @@ -141,44 +76,6 @@ func grabExitStatus(st processState) syscall.WaitStatus { return exitSt } -type processState interface { - Pid() int - Sys() interface{} -} -type dummyProcessState struct { - pid int - status syscall.WaitStatus -} - -func (d dummyProcessState) Pid() int { - return d.pid -} - -func (d dummyProcessState) Sys() interface{} { - return d.status -} - -func signame(s os.Signal) string { - if ss, ok := s.(syscall.Signal); ok { - return niceSigNames[ss] - } - return "UNKNOWN" -} - -// SigFromName returns the signal corresponding to the given signal name string. -// If the given name string is not defined, it returns nil. -func SigFromName(n string) os.Signal { - n = strings.ToUpper(n) - if strings.HasPrefix(n, "SIG") { - n = n[3:] // remove SIG prefix - } - - if sig, ok := niceNameToSigs[n]; ok { - return sig - } - return nil -} - func setEnv() { if os.Getenv("ENVDIR") == "" { return @@ -214,6 +111,7 @@ func parsePortSpec(addr string) (string, int, error) { return addr, int(port), nil } +/* func (s *Starter) Run() error { defer s.Teardown() @@ -424,6 +322,7 @@ func (s *Starter) Run() error { for pid := range oldWorkers { worker, err := os.FindProcess(pid) if err != nil { + fmt.Fprintf(os.Stderr, "failed to find process %d\n", pid) continue } worker.Signal(s.signalOnHUP) @@ -435,6 +334,7 @@ func (s *Starter) Run() error { return nil } +*/ func getKillOldDelay() time.Duration { // Ignore errors. @@ -447,13 +347,6 @@ func getKillOldDelay() time.Duration { return time.Duration(delay) * time.Second } -type WorkerState int - -const ( - WorkerStarted WorkerState = iota - ErrFailedToStart -) - // StartWorker starts the actual command. 
func (s *Starter) StartWorker(sigCh chan os.Signal, ch chan processState) *os.Process { // Don't give up until we're running. diff --git a/starter_test.go b/starter_test.go index 83d467d..5545d3e 100644 --- a/starter_test.go +++ b/starter_test.go @@ -1,16 +1,14 @@ package starter import ( + "context" "fmt" "io" "io/ioutil" - "log" "net" "os" "os/exec" "path/filepath" - "regexp" - "strings" "syscall" "testing" "time" @@ -82,6 +80,9 @@ func (c config) SignalOnTERM() os.Signal { return SigFromName(c.sigonterm) } func (c config) StatusFile() string { return c.statusfile } func TestRun(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + dir, err := ioutil.TempDir("", fmt.Sprintf("server-starter-test-%d", os.Getpid())) if err != nil { t.Errorf("Failed to create temp directory: %s", err) @@ -106,52 +107,33 @@ func TestRun(t *testing.T) { } ports := []string{"9090", "8080"} - sd, err := NewStarter(&config{ - ports: ports, - command: filepath.Join(dir, "echod"), - }) - if err != nil { - t.Errorf("Failed to create starter: %s", err) - return - } - - doneCh := make(chan struct{}) - readyCh := make(chan struct{}) - go func() { - defer func() { doneCh <- struct{}{} }() - time.AfterFunc(500*time.Millisecond, func() { - readyCh <- struct{}{} - }) - if err := sd.Run(); err != nil { - t.Errorf("sd.Run() failed: %s", err) - } - t.Logf("Exiting...") - }() + sd := New(filepath.Join(dir, "echod"), WithPorts(ports)) + go sd.Run(ctx) - <-readyCh + tick := time.NewTicker(500 * time.Millisecond) + defer tick.Stop() - for _, port := range ports { - _, err := net.Dial("tcp", fmt.Sprintf("127.0.0.1:%s", port)) - if err != nil { - t.Errorf("Error connecing to port '%s': %s", port, err) + ctx, cancel2 := context.WithTimeout(ctx, 5*time.Second) + defer cancel2() + for loop := true; loop; { + select { + case <-ctx.Done(): + t.Errorf("Error connecing: %s", ctx.Err()) + return + case <-tick.C: + ok := 0 + for _, port := range ports { + _, err := 
net.Dial("tcp", fmt.Sprintf("127.0.0.1:%s", port)) + if err == nil { + t.Logf("Successfully connected to port %s", port) + ok++ + } + } + if ok == len(ports) { + loop = false + } } } - - time.AfterFunc(time.Second, sd.Stop) - <-doneCh - - log.Printf("Checking ports...") - - patterns := make([]string, len(ports)) - for i, port := range ports { - patterns[i] = fmt.Sprintf(`%s=\d+`, port) - } - pattern := regexp.MustCompile(strings.Join(patterns, ";")) - - if envPort := os.Getenv("SERVER_STARTER_PORT"); !pattern.MatchString(envPort) { - t.Errorf("SERVER_STARTER_PORT: Expected '%s', but got '%s'", pattern, envPort) - } - } func TestSigFromName(t *testing.T) { diff --git a/status_any.go b/status_any.go new file mode 100644 index 0000000..403ef84 --- /dev/null +++ b/status_any.go @@ -0,0 +1,10 @@ +// +build !windows + +package starter + +import "syscall" + +func init() { + failureStatus = syscall.WaitStatus(255) + successStatus = syscall.WaitStatus(0) +} diff --git a/starter_windows.go b/status_windows.go similarity index 58% rename from starter_windows.go rename to status_windows.go index 49c6ef7..577880b 100644 --- a/starter_windows.go +++ b/status_windows.go @@ -1,3 +1,5 @@ +// +build windows + package starter import "syscall" @@ -6,7 +8,3 @@ func init() { failureStatus = syscall.WaitStatus{ExitCode: 255} successStatus = syscall.WaitStatus{ExitCode: 0} } - -func addPlatformDependentNiceSigNames(v map[syscall.Signal]string) map[syscall.Signal]string { - return v -} diff --git a/worker.go b/worker.go new file mode 100644 index 0000000..779d6ce --- /dev/null +++ b/worker.go @@ -0,0 +1,557 @@ +package starter + +import ( + "bytes" + "context" + "fmt" + "net" + "os" + "os/exec" + "os/signal" + "sort" + "strconv" + "strings" + "syscall" + "time" + + "github.com/lestrrat/go-server-starter/internal/env" + "github.com/pkg/errors" +) + +func (s *Starter) Notice(f string, args ...interface{}) { + var buf bytes.Buffer + fmt.Fprintf(&buf, f, args...) 
+ if buf.Len() == 0 { + return + } + + b := buf.Bytes() + if b[len(b)-1] != '\n' { + buf.WriteByte('\n') + } + buf.WriteTo(s.noticeWriter) +} + +func envAsBool(name string) bool { + b, err := strconv.ParseBool(os.Getenv(name)) + return err == nil && b +} + +func envAsInt(name string) int { + i, _ := strconv.ParseInt(os.Getenv(name), 10, 64) + return int(i) +} + +func envAsDuration(name string) time.Duration { + return time.Duration(envAsInt(name)) * time.Second +} + +// This keeps listening to INT,TERM,HUP, and ALRM signals, +// and queues them up into a destination channel +func acceptSignals(ctx context.Context, dst chan os.Signal) { + src := make(chan os.Signal, 32) // up to 32 signals + signal.Notify(src, syscall.SIGINT, syscall.SIGTERM, syscall.SIGHUP, syscall.SIGALRM) + signal.Ignore(syscall.SIGPIPE) + defer close(dst) + for { + select { + case <-ctx.Done(): + return + case sig, ok := <-src: + if !ok { + return + } + dst <- sig + } + } +} + +func wait(sigCh chan os.Signal, workerDone chan *exec.Cmd) *exec.Cmd { + // Original code in lib/Server/Starter.pm (_wait3) looks... 
interesting + // currently going to punt it in light of just "wait for a process to + // finish or a signal is received" + t := time.NewTicker(time.Second) + defer t.Stop() + for { + select { + case <-t.C: + if len(sigCh) > 0 { + return nil + } + case cmd := <-workerDone: + return cmd + } + } + return nil +} + +var registerCleanupKey struct{} + +func registerCleanup(ctx context.Context, f func()) { + register, ok := ctx.Value(registerCleanupKey).(func(func())) + if !ok { + return + } + register(f) +} + +func cleanup(ctx context.Context, ch chan func()) { + var finalizers []func() + for loop := true; loop; { + select { + case <-ctx.Done(): + loop = false + continue + case f, ok := <-ch: + if ok { + finalizers = append(finalizers, f) + } + } + } + for _, f := range finalizers { + f() + } +} + +func (s *Starter) Run(ctx context.Context) error { + var cancel func() + ctx, cancel = context.WithCancel(ctx) + defer cancel() + + var dir string + var interval time.Duration = time.Second + var paths []string + var pidFile string + var listeners []listener + var ports []string + var sigonhup os.Signal = os.Signal(syscall.SIGTERM) + var sigonterm os.Signal = os.Signal(syscall.SIGTERM) + var statusFile string + + for _, opt := range s.options { + switch opt.Name() { + case "auto_restart_interval": + v := opt.Value().(int) + os.Setenv(`AUTO_RESTART_INTERVAL`, strconv.Itoa(v)) + case "dir": + dir = opt.Value().(string) + case "enable_auto_restart": + b := opt.Value().(bool) + if b { + os.Setenv(`ENABLE_AUTO_RESTART`, `1`) + } else { + os.Setenv(`ENABLE_AUTO_RESTART`, `0`) + } + case "envdir": + os.Setenv("ENVDIR", opt.Value().(string)) + case "interval": + interval = opt.Value().(time.Duration) + case "kill_old_delay": + os.Setenv(`KILL_OLD_DELAY`, strconv.Itoa(int(opt.Value().(time.Duration)/time.Second))) + case "paths": + paths = opt.Value().([]string) + case "pid_file": + pidFile = opt.Value().(string) + case "ports": + ports = opt.Value().([]string) + case "signal_on_hup": + 
sigonhup = opt.Value().(os.Signal) + case "signal_on_term": + sigonterm = opt.Value().(os.Signal) + case "status_file": + statusFile = opt.Value().(string) + } + } + + generation := 0 // This is SERVER_STARTER_GENERATION + os.Setenv(`SERVER_STARTER_GENERATION`, `0`) + + cleanupCh := make(chan func()) + ctx = context.WithValue(ctx, registerCleanupKey, func(f func()) { + cleanupCh <- f + }) + go cleanup(ctx, cleanupCh) + + // start listening + extraFiles := make([]*os.File, 0, len(ports)+len(paths)) + portSpecs := make([]string, 0, len(ports)+len(paths)) + for _, addr := range ports { + var l net.Listener + + host, port, err := parsePortSpec(addr) + if err != nil { + fmt.Fprintf(os.Stderr, "failed to parse addr spec '%s': %s", addr, err) + return err + } + + hostport := fmt.Sprintf("%s:%d", host, port) + l, err = net.Listen("tcp4", hostport) + if err != nil { + fmt.Fprintf(os.Stderr, "failed to listen to %s:%s\n", hostport, err) + return err + } + + spec := "" + if host == "" { + spec = fmt.Sprintf("%d", port) + } else { + spec = fmt.Sprintf("%s:%d", host, port) + } + f, err := l.(*net.TCPListener).File() + if err != nil { + return errors.Wrap(err, "failed to get fd from listener") + } + registerCleanup(ctx, func() { f.Close() }) + extraFiles = append(extraFiles, f) + portSpecs = append(portSpecs, fmt.Sprintf("%s=%d", spec, len(portSpecs)+3)) + listeners = append(listeners, listener{listener: l, spec: spec}) + } + + for _, path := range paths { + var l net.Listener + if fl, err := os.Lstat(path); err == nil && fl.Mode()&os.ModeSocket == os.ModeSocket { + fmt.Fprintf(os.Stderr, "removing existing socket file:%s\n", path) + err = os.Remove(path) + if err != nil { + fmt.Fprintf(os.Stderr, "failed to remove existing socket file:%s:%s\n", path, err) + return err + } + } + _ = os.Remove(path) + l, err := net.Listen("unix", path) + if err != nil { + fmt.Fprintf(os.Stderr, "failed to listen file:%s:%s\n", path, err) + return err + } + f, err := l.(*net.UnixListener).File() + 
if err != nil { + return errors.Wrap(err, "failed to get fd from listener") + } + registerCleanup(ctx, func() { f.Close() }) + extraFiles = append(extraFiles, f) + portSpecs = append(portSpecs, fmt.Sprintf("%s=%d", path, len(portSpecs)+3)) + listeners = append(listeners, listener{listener: l, spec: path}) + } + + os.Setenv("SERVER_STARTER_PORT", strings.Join(portSpecs, ";")) + + // Note: environment variables that are set after this + // will NOT be re-populated + sysenv := env.SystemEnvironment() + envLoader := env.NewLoader() + + var statusFileCreated bool + defer func() { + if statusFileCreated { + os.Remove(statusFile) + } + }() + var currentWorker int // pid + var lastRestartTime time.Time + oldWorkers := map[int]int{} // pid to generation + + var updateStatus func() error + switch fn := statusFile; fn { + case "": + updateStatus = func() error { return nil } + default: + updateStatus = func() error { + tmpfn := fn + "." + strconv.Itoa(os.Getpid()) + f, err := os.OpenFile(tmpfn, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0644) + if err != nil { + return errors.Wrapf(err, "failed to create temporary file:%s", fn) + } + statusFileCreated = true + m := map[int]int{} + for k, v := range oldWorkers { + m[k] = v + } + if currentWorker > 0 { + m[generation] = currentWorker + } + + keys := make([]int, 0, len(oldWorkers)+1) + for k := range oldWorkers { + keys = append(keys, k) + } + sort.Ints(keys) + for _, k := range keys { + fmt.Fprintf(f, "%d:%d\n", k, m[k]) + } + f.Close() + return errors.Wrapf(os.Rename(tmpfn, fn), "failed to rename %s to %s", fn, tmpfn) + } + } + + // This watcher receives commands to watch for. 
+ workerSrc := make(chan *exec.Cmd) + workerDone := make(chan *exec.Cmd) + go monitor(ctx, workerSrc, workerDone) + + // signal handler here queues up signals to the other + // channel, so that we can keep accepting signals while we + // only really handle them once per loop + sigCh := make(chan os.Signal, 32) + go acceptSignals(ctx, sigCh) + + errTryExec := errors.New("keep trying") + startCmd := func(cmd *exec.Cmd) error { + if err := cmd.Start(); err != nil { + s.Notice("%s", err.Error()) + // We would LOVE to continue immediately, but we need to do the + // same check-for-signals and etc here, so we go on.. + } else { + s.Notice("starting new worker %d", cmd.Process.Pid) + } + + // Wait for up to `interval` seconds before + // checking if this command (process) is alive + time.Sleep(interval) + + // Check if we have received any signals while we were + // waiting. this is a very dirty trick in that we are + // mucking with a channel that is potentially being written + // to concurrently :/ + nonhup := 0 + var bufferedSigs []os.Signal + l := len(sigCh) + for i := 0; i < l; i++ { + s := <-sigCh + bufferedSigs = append(bufferedSigs, s) + if s != os.Signal(syscall.SIGHUP) { + // do not immediately stop... read all + nonhup++ + } + } + if len(bufferedSigs) > 0 { + fmt.Printf("%#v\n", bufferedSigs) + go func() { + for _, s := range bufferedSigs { + sigCh <- s + } + }() + if nonhup > 0 { // bailout + return errors.New("received signal while waiting") + } + } + + // Want to check if the given PID is still alive. + // This is not a great way to do it b/c we're not + // even sure the Pid we're looking for is the same + // process as the one we spawned, but... this is + // so far the best we can do + // Note: Does this work on windows? 
+ if cmd.Process != nil { + p, err := os.FindProcess(cmd.Process.Pid) + if err == nil { + if err := p.Signal(os.Signal(syscall.Signal(0))); err == nil { + return nil + } + } + } + + switch { + case cmd.ProcessState != nil: + s.Notice("new worker %d seems to have failed to start, exit status:%d", cmd.ProcessState.Pid(), grabExitStatus(cmd.ProcessState)) + case cmd.Process != nil: + s.Notice("new worker %d seems to have failed to start", cmd.Process.Pid) + default: + s.Notice("new worker seems to have failed to start") + } + return errTryExec + } + + if pidFile != "" { + f, err := os.OpenFile(pidFile, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0644) + if err != nil { + return errors.Wrapf(err, "failed to open file:%s", pidFile) + } + defer f.Close() + defer os.Remove(f.Name()) + + if err := syscall.Flock(int(f.Fd()), syscall.LOCK_EX); err != nil { + return errors.Wrapf(err, "flock failed(%s)", pidFile) + } + fmt.Fprintf(f, "%d\n", os.Getpid()) + if err := f.Sync(); err != nil { + return errors.Wrapf(err, "failed to sync file(%s)", pidFile) + } + } + + newCommand := func() *exec.Cmd { + cmd := exec.Command(s.command, s.args...) 
+ if dir != "" { + cmd.Dir = dir + } + cmd.Stdout = os.Stdout + cmd.Stderr = os.Stderr + cmd.ExtraFiles = extraFiles + return cmd + } + + startWorker := func() error { + for loop := true; loop; { + generation++ + os.Setenv(`SERVER_STARTER_GENERATION`, strconv.Itoa(generation)) + + cmd := newCommand() + switch err := startCmd(cmd); err { + case nil: + loop = false + currentWorker = cmd.Process.Pid + lastRestartTime = time.Now() + updateStatus() + workerSrc <- cmd + case errTryExec: + // keep trying + default: + return errors.Wrap(err, "failed to start command") + } + } + + return nil + } + + var cleanupWorkers = func(sig os.Signal) { + termSig := os.Signal(syscall.SIGTERM) + if sig == termSig { + termSig = sigonterm + } + + if currentWorker > 0 { + oldWorkers[currentWorker] = envAsInt(`SERVER_STARTER_GENERATION`) + currentWorker = 0 + } + var buf bytes.Buffer + fmt.Fprintf(&buf, "received %s, sending %s to all workers:", signame(sig), signame(termSig)) + keys := make([]int, 0, len(oldWorkers)) + for k := range oldWorkers { + keys = append(keys, k) + } + sort.Ints(keys) + for i, k := range keys { + fmt.Fprintf(&buf, "%d", k) + if i < len(keys)-1 { + buf.WriteByte(',') + } + } + s.Notice(buf.String()) + + for _, pid := range keys { + p, err := os.FindProcess(pid) + if err != nil { // XXX to be safe, let's delete this pid + delete(oldWorkers, pid) + } + p.Signal(termSig) + } + + for len(oldWorkers) > 0 { + cmd, ok := <-workerDone + if !ok { + panic("workerDone channel closed while still waiting for children to be reaped") + } + s.Notice("worker %d died, status:%d", cmd.ProcessState.Pid(), grabExitStatus(cmd.ProcessState)) + delete(oldWorkers, cmd.ProcessState.Pid()) + updateStatus() + } + s.Notice("exiting") + } + + if err := startWorker(); err != nil { + return errors.Wrap(err, "failed to start worker") + } + + for { + // wait for next signal (or when auto-restart becomes necessary) + exited := wait(sigCh, workerDone) + + // reload env if necessary + 
envLoader.Apply(ctx, sysenv) + + if envAsBool(`ENABLE_AUTO_RESTART`) { + if os.Getenv("AUTO_RESTART_INTERVAL") == "" { + os.Setenv("AUTO_RESTART_INTERVAL", "360") + } + } + + if exited != nil { // got some command exit + pid := exited.ProcessState.Pid() + if pid == currentWorker { + s.Notice("worker %d died unexpectedly with status: %d, restarting\n", pid, grabExitStatus(exited.ProcessState)) + if err := startWorker(); err != nil { + return errors.Wrap(err, "failed to start worker") + } + } else { + s.Notice("old worker %d died, status:%d", pid, grabExitStatus(exited.ProcessState)) + delete(oldWorkers, pid) + updateStatus() + } + } + + var restart bool + for loop := true; loop; { + select { + case sig := <-sigCh: + switch sig { + case syscall.SIGHUP: + restart = true + loop = false + case syscall.SIGALRM: + loop = false + default: + cleanupWorkers(sig) + return nil + } + default: + loop = false + } + } + + if !restart && envAsBool("ENABLE_AUTO_RESTART") { + autoRestartInterval := envAsDuration("AUTO_RESTART_INTERVAL") + elapsedSinceRestart := time.Since(lastRestartTime) + if elapsedSinceRestart >= autoRestartInterval && len(oldWorkers) == 0 { + s.Notice("autorestart triggered (interval=%s)", autoRestartInterval) + restart = true + } else if elapsedSinceRestart >= autoRestartInterval*2 { + s.Notice("autorestart triggered (forced, interval=%s)", autoRestartInterval) + } + } + + if restart { + oldWorkers[currentWorker] = generation + if err := startWorker(); err != nil { + return errors.Wrap(err, "failed to restart worker") + } + + var buf bytes.Buffer + l := len(oldWorkers) + if l == 0 { + buf.WriteString("none") + } else { + i := 0 + for pid := range oldWorkers { + buf.WriteString(strconv.Itoa(pid)) + if i < l-1 { + buf.WriteByte(',') + } + } + } + s.Notice("new worker is now running, sending %s to old workers: %s", signame(sigonhup), buf.String()) + + killOldDelay := envAsDuration(`KILL_OLD_DELAY`) + if killOldDelay == 0 && envAsBool(`ENABLE_AUTO_RESTART`) { + 
killOldDelay = 5 * time.Second + } + + time.Sleep(killOldDelay) + for pid := range oldWorkers { + worker, err := os.FindProcess(pid) + if err != nil { + continue + } + worker.Signal(sigonhup) + } + } + } +} From 09a9c4b1ef998fdff0e17ac0b09a9fc6b7ef1676 Mon Sep 17 00:00:00 2001 From: Daisuke Maki Date: Sun, 27 Nov 2016 23:37:18 +0900 Subject: [PATCH 02/26] use 1.7, duh --- .travis.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.travis.yml b/.travis.yml index ac72f57..1795ab6 100644 --- a/.travis.yml +++ b/.travis.yml @@ -1,5 +1,5 @@ language: go go: - - 1.5 + - 1.7 - tip sudo: false \ No newline at end of file From bd67b5b25e4709545011c193ef4564a531f7e0b3 Mon Sep 17 00:00:00 2001 From: Daisuke Maki Date: Tue, 29 Nov 2016 16:12:30 +0900 Subject: [PATCH 03/26] tweak signal timing --- cli.go | 4 +-- options.go | 5 +++ signals.go | 4 +-- starter_test.go | 90 +++++++++++++++++++++++++------------------------ worker.go | 61 +++++++++++++++++++++------------ 5 files changed, 94 insertions(+), 70 deletions(-) diff --git a/cli.go b/cli.go index ac39f4c..ca6a5fd 100644 --- a/cli.go +++ b/cli.go @@ -49,10 +49,10 @@ func makeOptionList(opts *options) []Option { list = append(list, WithPorts(opts.Ports)) } if opts.SignalOnHUP != "" { - list = append(list, WithSignalOnHUP(SigFromName(opts.SignalOnHUP))) + list = append(list, WithSignalOnHUP(sigFromName(opts.SignalOnHUP))) } if opts.SignalOnTERM != "" { - list = append(list, WithSignalOnTERM(SigFromName(opts.SignalOnTERM))) + list = append(list, WithSignalOnTERM(sigFromName(opts.SignalOnTERM))) } if opts.StatusFile != "" { list = append(list, WithStatusFile(opts.StatusFile)) diff --git a/options.go b/options.go index ca7cd73..604f73a 100644 --- a/options.go +++ b/options.go @@ -1,6 +1,7 @@ package starter import ( + "io" "os" "strconv" "time" @@ -71,6 +72,10 @@ func WithStatusFile(s string) Option { return &valueOption{name: "status_file", value: s} } +func WithNoticeOutput(w io.Writer) Option { + return 
&valueOption{name: "notice_output", value: w} +} + func (o *stringOpt) String() string { return o.Value } diff --git a/signals.go b/signals.go index d4c747a..3b141ff 100644 --- a/signals.go +++ b/signals.go @@ -49,9 +49,7 @@ func signame(s os.Signal) string { return "UNKNOWN" } -// SigFromName returns the signal corresponding to the given signal name string. -// If the given name string is not defined, it returns nil. -func SigFromName(n string) os.Signal { +func sigFromName(n string) os.Signal { n = strings.ToUpper(n) if strings.HasPrefix(n, "SIG") { n = n[3:] // remove SIG prefix diff --git a/starter_test.go b/starter_test.go index 5545d3e..0c34846 100644 --- a/starter_test.go +++ b/starter_test.go @@ -1,6 +1,7 @@ package starter import ( + "bytes" "context" "fmt" "io" @@ -12,6 +13,8 @@ import ( "syscall" "testing" "time" + + "github.com/stretchr/testify/assert" ) var echoServerTxt = `package main @@ -23,7 +26,6 @@ import ( "os" "os/signal" "syscall" - "time" "github.com/lestrrat/go-server-starter/listener" ) @@ -33,6 +35,11 @@ func main() { fmt.Fprintf(os.Stderr, "Failed to listen: %s\n", err) os.Exit(1) } + defer func() { + for _, l := range listeners { + l.Close() + } + }() handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { io.Copy(w, r.Body) @@ -41,59 +48,27 @@ func main() { http.Serve(l, handler) } - loop := false sigCh := make(chan os.Signal) signal.Notify(sigCh, syscall.SIGTERM, syscall.SIGHUP) - for loop { - select { - case <-sigCh: - loop = false - default: - time.Sleep(time.Second) - } + select { + case <-sigCh: } } ` -type config struct { - args []string - command string - dir string - interval int - pidfile string - ports []string - paths []string - sigonhup string - sigonterm string - statusfile string -} - -func (c config) Args() []string { return c.args } -func (c config) Command() string { return c.command } -func (c config) Dir() string { return c.dir } -func (c config) Interval() time.Duration { return 
time.Duration(c.interval) * time.Second } -func (c config) PidFile() string { return c.pidfile } -func (c config) Ports() []string { return c.ports } -func (c config) Paths() []string { return c.paths } -func (c config) SignalOnHUP() os.Signal { return SigFromName(c.sigonhup) } -func (c config) SignalOnTERM() os.Signal { return SigFromName(c.sigonterm) } -func (c config) StatusFile() string { return c.statusfile } - func TestRun(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() dir, err := ioutil.TempDir("", fmt.Sprintf("server-starter-test-%d", os.Getpid())) - if err != nil { - t.Errorf("Failed to create temp directory: %s", err) + if !assert.NoError(t, err, "failed to create tempdir %s", dir) { return } defer os.RemoveAll(dir) srcFile := filepath.Join(dir, "echod.go") f, err := os.OpenFile(srcFile, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0666) - if err != nil { - t.Errorf("Failed to create %s: %s", srcFile, err) + if !assert.NoError(t, err, "failed to create source file %s", f) { return } io.WriteString(f, echoServerTxt) @@ -101,14 +76,22 @@ func TestRun(t *testing.T) { cmd := exec.Command("go", "build", "-o", filepath.Join(dir, "echod"), ".") cmd.Dir = dir - if output, err := cmd.CombinedOutput(); err != nil { - t.Errorf("Failed to compile %s: %s\n%s", dir, err, output) + if output, err := cmd.CombinedOutput(); !assert.NoError(t, err, "failed to compile echod") { + t.Logf("%s", output) return } ports := []string{"9090", "8080"} - sd := New(filepath.Join(dir, "echod"), WithPorts(ports)) - go sd.Run(ctx) + var output bytes.Buffer + sd := New(filepath.Join(dir, "echod"), WithPorts(ports), WithNoticeOutput(&output)) + + done := make(chan struct{}) + go func() { + defer close(done) + if !assert.NoError(t, sd.Run(ctx), "Run should exit with no errors") { + return + } + }() tick := time.NewTicker(500 * time.Millisecond) defer tick.Stop() @@ -134,11 +117,30 @@ 
func TestRun(t *testing.T) { } } } + + time.Sleep(time.Second) + + var closed bool + select { + case <-done: + // grr, if we got here, done is closed + closed = true + default: + } + + if !closed { + p, _ := os.FindProcess(os.Getpid()) + p.Signal(os.Signal(syscall.SIGTERM)) + } + + <-done + + t.Logf("%s", output.String()) } func TestSigFromName(t *testing.T) { for sig, name := range niceSigNames { - if got := SigFromName(name); sig != got { + if got := sigFromName(name); sig != got { t.Errorf("%v: wants '%v' but got '%v'", name, sig, got) } } @@ -149,7 +151,7 @@ func TestSigFromName(t *testing.T) { "Hup": syscall.SIGHUP, } for name, sig := range variants { - if got := SigFromName(name); sig != got { + if got := sigFromName(name); sig != got { t.Errorf("%v: wants '%v' but got '%v'", name, sig, got) } } diff --git a/worker.go b/worker.go index 779d6ce..fae28ce 100644 --- a/worker.go +++ b/worker.go @@ -4,6 +4,7 @@ import ( "bytes" "context" "fmt" + "io" "net" "os" "os/exec" @@ -66,7 +67,7 @@ func acceptSignals(ctx context.Context, dst chan os.Signal) { } } -func wait(sigCh chan os.Signal, workerDone chan *exec.Cmd) *exec.Cmd { +func wait(ctx context.Context, sigCh chan os.Signal, workerDone chan *exec.Cmd) *exec.Cmd { // Original code in lib/Server/Starter.pm (_wait3) looks... 
interesting // currently going to punt it in light of just "wait for a process to // finish or a signal is received" @@ -74,6 +75,10 @@ func wait(sigCh chan os.Signal, workerDone chan *exec.Cmd) *exec.Cmd { defer t.Stop() for { select { + case <-ctx.Done(): + p, _ := os.FindProcess(os.Getpid()) + p.Signal(os.Signal(syscall.SIGTERM)) + return nil case <-t.C: if len(sigCh) > 0 { return nil @@ -127,6 +132,7 @@ func (s *Starter) Run(ctx context.Context) error { var sigonhup os.Signal = os.Signal(syscall.SIGTERM) var sigonterm os.Signal = os.Signal(syscall.SIGTERM) var statusFile string + var noticeOutput io.Writer = os.Stderr for _, opt := range s.options { switch opt.Name() { @@ -160,7 +166,22 @@ func (s *Starter) Run(ctx context.Context) error { sigonterm = opt.Value().(os.Signal) case "status_file": statusFile = opt.Value().(string) + case "notice_output": + noticeOutput = opt.Value().(io.Writer) + } + } + notice := func(f string, args ...interface{}) { + var buf bytes.Buffer + fmt.Fprintf(&buf, f, args...) 
+ if buf.Len() == 0 { + return + } + + b := buf.Bytes() + if b[len(b)-1] != '\n' { + buf.WriteByte('\n') } + buf.WriteTo(noticeOutput) } generation := 0 // This is SERVER_STARTER_GENERATION @@ -180,14 +201,14 @@ func (s *Starter) Run(ctx context.Context) error { host, port, err := parsePortSpec(addr) if err != nil { - fmt.Fprintf(os.Stderr, "failed to parse addr spec '%s': %s", addr, err) + notice("failed to parse addr spec '%s': %s", addr, err) return err } hostport := fmt.Sprintf("%s:%d", host, port) l, err = net.Listen("tcp4", hostport) if err != nil { - fmt.Fprintf(os.Stderr, "failed to listen to %s:%s\n", hostport, err) + notice("failed to listen to %s:%s\n", hostport, err) return err } @@ -210,17 +231,17 @@ func (s *Starter) Run(ctx context.Context) error { for _, path := range paths { var l net.Listener if fl, err := os.Lstat(path); err == nil && fl.Mode()&os.ModeSocket == os.ModeSocket { - fmt.Fprintf(os.Stderr, "removing existing socket file:%s\n", path) + notice("removing existing socket file:%s\n", path) err = os.Remove(path) if err != nil { - fmt.Fprintf(os.Stderr, "failed to remove existing socket file:%s:%s\n", path, err) + notice("failed to remove existing socket file:%s:%s\n", path, err) return err } } _ = os.Remove(path) l, err := net.Listen("unix", path) if err != nil { - fmt.Fprintf(os.Stderr, "failed to listen file:%s:%s\n", path, err) + notice("failed to listen file:%s:%s\n", path, err) return err } f, err := l.(*net.UnixListener).File() @@ -297,11 +318,11 @@ func (s *Starter) Run(ctx context.Context) error { errTryExec := errors.New("keep trying") startCmd := func(cmd *exec.Cmd) error { if err := cmd.Start(); err != nil { - s.Notice("%s", err.Error()) + notice("%s", err.Error()) // We would LOVE to continue immediately, but we need to do the // same check-for-signals and etc here, so we go on.. 
} else { - s.Notice("starting new worker %d", cmd.Process.Pid) + notice("starting new worker %d", cmd.Process.Pid) } // Wait for up to `interval` seconds before @@ -324,7 +345,6 @@ func (s *Starter) Run(ctx context.Context) error { } } if len(bufferedSigs) > 0 { - fmt.Printf("%#v\n", bufferedSigs) go func() { for _, s := range bufferedSigs { sigCh <- s @@ -352,11 +372,11 @@ func (s *Starter) Run(ctx context.Context) error { switch { case cmd.ProcessState != nil: - s.Notice("new worker %d seems to have failed to start, exit status:%d", cmd.ProcessState.Pid(), grabExitStatus(cmd.ProcessState)) + notice("new worker %d seems to have failed to start, exit status:%d", cmd.ProcessState.Pid(), grabExitStatus(cmd.ProcessState)) case cmd.Process != nil: - s.Notice("new worker %d seems to have failed to start", cmd.Process.Pid) + notice("new worker %d seems to have failed to start", cmd.Process.Pid) default: - s.Notice("new worker seems to have failed to start") + notice("new worker seems to have failed to start") } return errTryExec } @@ -435,7 +455,7 @@ func (s *Starter) Run(ctx context.Context) error { buf.WriteByte(',') } } - s.Notice(buf.String()) + notice(buf.String()) for _, pid := range keys { p, err := os.FindProcess(pid) @@ -450,11 +470,10 @@ func (s *Starter) Run(ctx context.Context) error { if !ok { panic("workerDone channel closed while still waiting for children to be reaped") } - s.Notice("worker %d died, status:%d", cmd.ProcessState.Pid(), grabExitStatus(cmd.ProcessState)) + notice("worker %d died, status:%d", cmd.ProcessState.Pid(), grabExitStatus(cmd.ProcessState)) delete(oldWorkers, cmd.ProcessState.Pid()) updateStatus() } - s.Notice("exiting") } if err := startWorker(); err != nil { @@ -463,7 +482,7 @@ func (s *Starter) Run(ctx context.Context) error { for { // wait for next signal (or when auto-restart becomes necessary) - exited := wait(sigCh, workerDone) + exited := wait(ctx, sigCh, workerDone) // reload env if necessary envLoader.Apply(ctx, sysenv) @@ 
-477,12 +496,12 @@ func (s *Starter) Run(ctx context.Context) error { if exited != nil { // got some command exit pid := exited.ProcessState.Pid() if pid == currentWorker { - s.Notice("worker %d died unexpectedly with status: %d, restarting\n", pid, grabExitStatus(exited.ProcessState)) + notice("worker %d died unexpectedly with status: %d, restarting\n", pid, grabExitStatus(exited.ProcessState)) if err := startWorker(); err != nil { return errors.Wrap(err, "failed to start worker") } } else { - s.Notice("old worker %d died, status:%d", pid, grabExitStatus(exited.ProcessState)) + notice("old worker %d died, status:%d", pid, grabExitStatus(exited.ProcessState)) delete(oldWorkers, pid) updateStatus() } @@ -511,10 +530,10 @@ func (s *Starter) Run(ctx context.Context) error { autoRestartInterval := envAsDuration("AUTO_RESTART_INTERVAL") elapsedSinceRestart := time.Since(lastRestartTime) if elapsedSinceRestart >= autoRestartInterval && len(oldWorkers) == 0 { - s.Notice("autorestart triggered (interval=%s)", autoRestartInterval) + notice("autorestart triggered (interval=%s)", autoRestartInterval) restart = true } else if elapsedSinceRestart >= autoRestartInterval*2 { - s.Notice("autorestart triggered (forced, interval=%s)", autoRestartInterval) + notice("autorestart triggered (forced, interval=%s)", autoRestartInterval) } } @@ -537,7 +556,7 @@ func (s *Starter) Run(ctx context.Context) error { } } } - s.Notice("new worker is now running, sending %s to old workers: %s", signame(sigonhup), buf.String()) + notice("new worker is now running, sending %s to old workers: %s", signame(sigonhup), buf.String()) killOldDelay := envAsDuration(`KILL_OLD_DELAY`) if killOldDelay == 0 && envAsBool(`ENABLE_AUTO_RESTART`) { From 3e28acc40529be5f8d42054bc4e59d0b3f6e5539 Mon Sep 17 00:00:00 2001 From: Daisuke Maki Date: Tue, 29 Nov 2016 16:25:37 +0900 Subject: [PATCH 04/26] refactor building sub-command --- starter_test.go | 50 +++++++++++++++++++++++++++++++------------------ 1 file 
changed, 32 insertions(+), 18 deletions(-) diff --git a/starter_test.go b/starter_test.go index 0c34846..2f699e7 100644 --- a/starter_test.go +++ b/starter_test.go @@ -14,10 +14,11 @@ import ( "testing" "time" + "github.com/pkg/errors" "github.com/stretchr/testify/assert" ) -var echoServerTxt = `package main +var echoServerSrc = `package main import ( "fmt" @@ -56,34 +57,47 @@ func main() { } ` -func TestRun(t *testing.T) { - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) - defer cancel() - - dir, err := ioutil.TempDir("", fmt.Sprintf("server-starter-test-%d", os.Getpid())) - if !assert.NoError(t, err, "failed to create tempdir %s", dir) { - return +func build(name string, src string) (string, func(), error) { + dir, err := ioutil.TempDir("", fmt.Sprintf("server-starter-test-%s-%d", name, os.Getpid())) + if err != nil { + return "", nil, errors.Wrapf(err, "failed to create tempdir %s", dir) } - defer os.RemoveAll(dir) - - srcFile := filepath.Join(dir, "echod.go") + cleanup := func() { + os.RemoveAll(dir) + } + srcFile := filepath.Join(dir, fmt.Sprintf("%s.go", name)) f, err := os.OpenFile(srcFile, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0666) - if !assert.NoError(t, err, "failed to create source file %s", f) { - return + if err != nil { + return "", cleanup, errors.Wrapf(err, "failed to create source file %s", f) } - io.WriteString(f, echoServerTxt) + io.WriteString(f, src) f.Close() - cmd := exec.Command("go", "build", "-o", filepath.Join(dir, "echod"), ".") + result := filepath.Join(dir, name) + cmd := exec.Command("go", "build", "-o", result, ".") cmd.Dir = dir - if output, err := cmd.CombinedOutput(); !assert.NoError(t, err, "failed to compile echod") { - t.Logf("%s", output) + output, err := cmd.CombinedOutput() + if err != nil { + return "", cleanup, errors.Wrapf(err, "failed to compile %s: %s", name, output) + } + return result, cleanup, nil +} + +func TestRun(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 
10*time.Second) + defer cancel() + + cmdname, cleanup, err := build("echod", echoServerSrc) + if cleanup != nil { + defer cleanup() + } + if !assert.NoError(t, err, "build failed") { return } ports := []string{"9090", "8080"} var output bytes.Buffer - sd := New(filepath.Join(dir, "echod"), WithPorts(ports), WithNoticeOutput(&output)) + sd := New(cmdname, WithPorts(ports), WithNoticeOutput(&output)) done := make(chan struct{}) go func() { From 5079ee303c9e3487cbcd16e6680197066aa52e2f Mon Sep 17 00:00:00 2001 From: Daisuke Maki Date: Tue, 29 Nov 2016 18:22:04 +0900 Subject: [PATCH 05/26] test for multiple signals being sent --- cli.go | 2 +- options.go | 2 +- starter_test.go | 193 ++++++++++++++++++++++++++++++++++++------------ 3 files changed, 148 insertions(+), 49 deletions(-) diff --git a/cli.go b/cli.go index ca6a5fd..52aee11 100644 --- a/cli.go +++ b/cli.go @@ -19,7 +19,7 @@ func NewCLI() *CLI { func makeOptionList(opts *options) []Option { var list []Option if len(opts.Args) > 0 { - list = append(list, WithArgs(opts.Args)) + list = append(list, WithArgs(opts.Args...)) } if opts.AutoRestartInterval.Valid { list = append(list, WithAutoRestartInterval(time.Duration(opts.AutoRestartInterval.Value)*time.Second)) diff --git a/options.go b/options.go index 604f73a..efa82f7 100644 --- a/options.go +++ b/options.go @@ -28,7 +28,7 @@ func WithAutoRestartInterval(t time.Duration) Option { return &valueOption{name: "auto_restart_interval", value: t} } -func WithArgs(a []string) Option { +func WithArgs(a ...string) Option { return &valueOption{name: "args", value: a} } diff --git a/starter_test.go b/starter_test.go index 2f699e7..6d5a5e3 100644 --- a/starter_test.go +++ b/starter_test.go @@ -10,10 +10,13 @@ import ( "os" "os/exec" "path/filepath" + "strconv" "syscall" "testing" "time" + "github.com/lestrrat/go-server-starter/internal/env" + tcputil "github.com/lestrrat/go-tcputil" "github.com/pkg/errors" "github.com/stretchr/testify/assert" ) @@ -21,6 +24,7 @@ import ( var 
echoServerSrc = `package main import ( + "flag" "fmt" "io" "net/http" @@ -31,6 +35,10 @@ import ( ) func main() { + var maxSigterm int // number of times we "withstand" a sigterm + flag.IntVar(&maxSigterm, "sigterm", 0, "") + flag.Parse() + listeners, err := listener.ListenAll() if err != nil { fmt.Fprintf(os.Stderr, "Failed to listen: %s\n", err) @@ -49,10 +57,23 @@ func main() { http.Serve(l, handler) } - sigCh := make(chan os.Signal) + sigCh := make(chan os.Signal, 1) signal.Notify(sigCh, syscall.SIGTERM, syscall.SIGHUP) - select { - case <-sigCh: + + sigterm := 0 + for loop := true; loop; { + select { + case s := <-sigCh: + switch s { + case syscall.SIGTERM: + sigterm++ + if maxSigterm <= sigterm { + loop = false + } + default: + // do nothing + } + } } } ` @@ -75,6 +96,7 @@ func build(name string, src string) (string, func(), error) { result := filepath.Join(dir, name) cmd := exec.Command("go", "build", "-o", result, ".") + cmd.Env = nil cmd.Dir = dir output, err := cmd.CombinedOutput() if err != nil { @@ -84,8 +106,9 @@ func build(name string, src string) (string, func(), error) { } func TestRun(t *testing.T) { - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) - defer cancel() + l := env.NewLoader() + sysenv := env.SystemEnvironment() + defer l.Restore(context.Background(), sysenv) cmdname, cleanup, err := build("echod", echoServerSrc) if cleanup != nil { @@ -95,61 +118,137 @@ func TestRun(t *testing.T) { return } - ports := []string{"9090", "8080"} - var output bytes.Buffer - sd := New(cmdname, WithPorts(ports), WithNoticeOutput(&output)) + t.Run("normal execution", func(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() - done := make(chan struct{}) - go func() { - defer close(done) - if !assert.NoError(t, sd.Run(ctx), "Run should exit with no errors") { - return + portCount := 2 + ports := make([]string, portCount) + for i := 0; i < 2; i++ { + p, err := tcputil.EmptyPort() + 
if !assert.NoError(t, err, "failed to find an empty port") { + return + } + ports[i] = strconv.Itoa(p) } - }() + var output bytes.Buffer + defer func() { + t.Logf("%s", output.String()) + }() + sd := New(cmdname, WithPorts(ports), WithNoticeOutput(&output)) - tick := time.NewTicker(500 * time.Millisecond) - defer tick.Stop() + done := make(chan struct{}) + go func() { + defer close(done) + if !assert.NoError(t, sd.Run(ctx), "Run should exit with no errors") { + return + } + }() - ctx, cancel2 := context.WithTimeout(ctx, 5*time.Second) - defer cancel2() - for loop := true; loop; { - select { - case <-ctx.Done(): - t.Errorf("Error connecing: %s", ctx.Err()) - return - case <-tick.C: - ok := 0 - for _, port := range ports { - _, err := net.Dial("tcp", fmt.Sprintf("127.0.0.1:%s", port)) - if err == nil { - t.Logf("Successfully connected to port %s", port) - ok++ + tick := time.NewTicker(500 * time.Millisecond) + defer tick.Stop() + + ctx, cancel2 := context.WithTimeout(ctx, 5*time.Second) + defer cancel2() + for loop := true; loop; { + select { + case <-ctx.Done(): + t.Errorf("Error connecing: %s", ctx.Err()) + return + case <-tick.C: + ok := 0 + for _, port := range ports { + _, err := net.Dial("tcp", fmt.Sprintf("127.0.0.1:%s", port)) + if err == nil { + t.Logf("Successfully connected to port %s", port) + ok++ + } + } + if ok == len(ports) { + loop = false } } - if ok == len(ports) { - loop = false + } + + time.Sleep(time.Second) + + var closed bool + select { + case <-done: + // grr, if we got here, done is closed + closed = true + default: + } + + if !closed { + p, _ := os.FindProcess(os.Getpid()) + p.Signal(os.Signal(syscall.SIGTERM)) + } + + <-done + }) + l.Restore(context.Background(), sysenv) + + t.Run("send multiple signals", func(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + port, err := tcputil.EmptyPort() + if !assert.NoError(t, err, "failed to find an empty port") { + return + } + + var 
output bytes.Buffer + defer func() { + t.Logf("%s", output.String()) + }() + sd := New(cmdname, WithArgs("-sigterm", "2"), WithPorts([]string{strconv.Itoa(port)}), WithNoticeOutput(&output)) + + done := make(chan struct{}) + go func() { + defer close(done) + if !assert.NoError(t, sd.Run(ctx), "Run should exit with no errors") { + return } + }() + + time.Sleep(time.Second) + + var closed bool + select { + case <-done: + closed = true + t.Errorf("unexpected exit") + default: } - } - time.Sleep(time.Second) + if !closed { + p, _ := os.FindProcess(os.Getpid()) + p.Signal(syscall.SIGTERM) + } - var closed bool - select { - case <-done: - // grr, if we got here, done is closed - closed = true - default: - } + time.Sleep(time.Second) - if !closed { - p, _ := os.FindProcess(os.Getpid()) - p.Signal(os.Signal(syscall.SIGTERM)) - } + closed = false + select { + case <-done: + closed = true + t.Errorf("unexpected exit") + default: + } - <-done + if !closed { + p, _ := os.FindProcess(os.Getpid()) + p.Signal(syscall.SIGTERM) + } - t.Logf("%s", output.String()) + select { + case <-ctx.Done(): + t.Errorf("context prematurely ended: %s", ctx.Err()) + case <-done: + } + }) + l.Restore(context.Background(), sysenv) } func TestSigFromName(t *testing.T) { From b62ffd7210cd21b6e0bcde31abd71acda268949a Mon Sep 17 00:00:00 2001 From: Daisuke Maki Date: Tue, 29 Nov 2016 18:27:35 +0900 Subject: [PATCH 06/26] Remove unused code --- env.go | 68 +---- env_test.go | 68 ----- starter.go | 778 ++++++++++++++++++++++++++++++---------------------- worker.go | 576 -------------------------------------- 4 files changed, 458 insertions(+), 1032 deletions(-) delete mode 100644 env_test.go delete mode 100644 worker.go diff --git a/env.go b/env.go index debe20c..b073a9e 100644 --- a/env.go +++ b/env.go @@ -1,65 +1,21 @@ package starter import ( - "bufio" - "errors" "os" - "path/filepath" - "strings" + "strconv" + "time" ) -var errNoEnv = errors.New("no ENVDIR specified, or ENVDIR does not exist") - -func 
reloadEnv() (map[string]string, error) { - dn := os.Getenv("ENVDIR") - if dn == "" { - return nil, errNoEnv - } - - fi, err := os.Stat(dn) - if err != nil { - return nil, err - } - - if !fi.IsDir() { - return nil, err - } - - var m map[string]string - - filepath.Walk(dn, func(path string, fi os.FileInfo, err error) error { - // Ignore errors - if err != nil { - return nil - } - - // Don't go into directories - if fi.IsDir() && dn != path { - return filepath.SkipDir - } - - f, err := os.Open(path) - if err != nil { - return nil - } - defer f.Close() - - envName := filepath.Base(path) - scanner := bufio.NewScanner(f) - if scanner.Scan() { - if m == nil { - m = make(map[string]string) - } - l := scanner.Text() - m[envName] = strings.TrimSpace(l) - } - - return nil - }) +func envAsBool(name string) bool { + b, err := strconv.ParseBool(os.Getenv(name)) + return err == nil && b +} - if m == nil { - return nil, errNoEnv - } +func envAsInt(name string) int { + i, _ := strconv.ParseInt(os.Getenv(name), 10, 64) + return int(i) +} - return m, nil +func envAsDuration(name string) time.Duration { + return time.Duration(envAsInt(name)) * time.Second } diff --git a/env_test.go b/env_test.go deleted file mode 100644 index 8c4e525..0000000 --- a/env_test.go +++ /dev/null @@ -1,68 +0,0 @@ -package starter - -import ( - "io" - "io/ioutil" - "os" - "path/filepath" - "testing" -) - -func TestEnvdir(t *testing.T) { - dir, err := ioutil.TempDir("", "starter_test") - if err != nil { - t.Errorf("Failed to create tempdir: %s", err) - return - } - defer os.RemoveAll(dir) - - files := []string{"FOO", "BAR", "BAZ"} - for _, fn := range files { - longFn := filepath.Join(dir, fn) - - f, err := os.OpenFile(longFn, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0666) - if err != nil { - t.Errorf("Failed to create file '%s': %s", fn, err) - return - } - closed := false - defer func() { - if !closed { - f.Close() - } - }() - - io.WriteString(f, fn) - f.Close() - closed = true - - // save old values and restore 
later, if any - if old := os.Getenv(fn); old != "" { - os.Setenv(fn, "") - defer os.Setenv(fn, old) - } - } - - if old := os.Getenv("ENVDIR"); old != "" { - defer os.Setenv("ENVDIR", old) - } - - os.Setenv("ENVDIR", dir) - m, err := reloadEnv() - if err != nil { - t.Errorf("reloadEnv failed: %s", err) - return - } - - for _, fn := range files { - v, ok := m[fn] - if !ok { - t.Errorf("Expected environment variable '%s' to exist") - return - } - if v != fn { - t.Errorf("Expected environment variable '%s' to be '%s'", fn, fn) - return - } - } -} diff --git a/starter.go b/starter.go index 3fc6ea8..f3dffc7 100644 --- a/starter.go +++ b/starter.go @@ -1,68 +1,29 @@ package starter import ( + "bytes" + "context" "fmt" + "io" "net" "os" "os/exec" + "os/signal" + "sort" "strconv" "strings" "syscall" "time" + + "github.com/lestrrat/go-server-starter/internal/env" + "github.com/pkg/errors" ) func New(command string, options ...Option) *Starter { return &Starter{ command: command, options: options, // This is stored as-is on purpose. - noticeWriter: os.Stderr, - } -} - -// NewStarter creates a new Starter object. 
Config parameter may NOT be -// nil, as `Ports` and/or `Paths`, and `Command` are required -func NewStarter(c Config) (*Starter, error) { - if c == nil { - return nil, fmt.Errorf("config argument must be non-nil") - } - - var signalOnHUP os.Signal = syscall.SIGTERM - var signalOnTERM os.Signal = syscall.SIGTERM - if s := c.SignalOnHUP(); s != nil { - signalOnHUP = s - } - if s := c.SignalOnTERM(); s != nil { - signalOnTERM = s - } - - if c.Command() == "" { - return nil, fmt.Errorf("argument Command must be specified") - } - if _, err := exec.LookPath(c.Command()); err != nil { - return nil, err - } - - s := &Starter{ - args: c.Args(), - command: c.Command(), - dir: c.Dir(), - interval: c.Interval(), - listeners: make([]listener, 0, len(c.Ports())+len(c.Paths())), - pidFile: c.PidFile(), - ports: c.Ports(), - paths: c.Paths(), - signalOnHUP: signalOnHUP, - signalOnTERM: signalOnTERM, - statusFile: c.StatusFile(), } - - return s, nil - -} - -func (s Starter) Stop() { - p, _ := os.FindProcess(os.Getpid()) - p.Signal(syscall.SIGTERM) } func grabExitStatus(st processState) syscall.WaitStatus { @@ -76,22 +37,6 @@ func grabExitStatus(st processState) syscall.WaitStatus { return exitSt } -func setEnv() { - if os.Getenv("ENVDIR") == "" { - return - } - - m, err := reloadEnv() - if err != nil && err != errNoEnv { - // do something - fmt.Fprintf(os.Stderr, "failed to load from envdir: %s\n", err) - } - - for k, v := range m { - os.Setenv(k, v) - } -} - func parsePortSpec(addr string) (string, int, error) { i := strings.IndexByte(addr, ':') portPart := "" @@ -111,39 +56,168 @@ func parsePortSpec(addr string) (string, int, error) { return addr, int(port), nil } -/* -func (s *Starter) Run() error { - defer s.Teardown() +// This keeps listening to INT,TERM,HUP, and ALRM signals, +// and queues them up into a destination channel +func acceptSignals(ctx context.Context, dst chan os.Signal) { + src := make(chan os.Signal, 32) // up to 32 signals + signal.Notify(src, syscall.SIGINT, 
syscall.SIGTERM, syscall.SIGHUP, syscall.SIGALRM) + signal.Ignore(syscall.SIGPIPE) + defer close(dst) + for { + select { + case <-ctx.Done(): + return + case sig, ok := <-src: + if !ok { + return + } + dst <- sig + } + } +} - if s.pidFile != "" { - f, err := os.OpenFile(s.pidFile, os.O_EXCL|os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0644) - if err != nil { - return err +func wait(ctx context.Context, sigCh chan os.Signal, workerDone chan *exec.Cmd) *exec.Cmd { + // Original code in lib/Server/Starter.pm (_wait3) looks... interesting + // currently going to punt it in light of just "wait for a process to + // finish or a signal is received" + t := time.NewTicker(time.Second) + defer t.Stop() + for { + select { + case <-ctx.Done(): + p, _ := os.FindProcess(os.Getpid()) + p.Signal(os.Signal(syscall.SIGTERM)) + return nil + case <-t.C: + if len(sigCh) > 0 { + return nil + } + case cmd := <-workerDone: + return cmd } + } + return nil +} - if err := syscall.Flock(int(f.Fd()), syscall.LOCK_EX); err != nil { - return err +var registerCleanupKey struct{} + +func registerCleanup(ctx context.Context, f func()) { + register, ok := ctx.Value(registerCleanupKey).(func(func())) + if !ok { + return + } + register(f) +} + +func cleanup(ctx context.Context, ch chan func()) { + var finalizers []func() + for loop := true; loop; { + select { + case <-ctx.Done(): + loop = false + continue + case f, ok := <-ch: + if ok { + finalizers = append(finalizers, f) + } } - fmt.Fprintf(f, "%d", os.Getpid()) - defer func() { - os.Remove(f.Name()) - f.Close() - }() } + for _, f := range finalizers { + f() + } +} - for _, addr := range s.ports { +func (s *Starter) Run(ctx context.Context) error { + var cancel func() + ctx, cancel = context.WithCancel(ctx) + defer cancel() + + var dir string + var interval time.Duration = time.Second + var paths []string + var pidFile string + var listeners []listener + var ports []string + var sigonhup os.Signal = os.Signal(syscall.SIGTERM) + var sigonterm os.Signal = 
os.Signal(syscall.SIGTERM) + var statusFile string + var noticeOutput io.Writer = os.Stderr + + for _, opt := range s.options { + switch opt.Name() { + case "auto_restart_interval": + v := opt.Value().(int) + os.Setenv(`AUTO_RESTART_INTERVAL`, strconv.Itoa(v)) + case "dir": + dir = opt.Value().(string) + case "enable_auto_restart": + b := opt.Value().(bool) + if b { + os.Setenv(`ENABLE_AUTO_RESTART`, `1`) + } else { + os.Setenv(`ENABLE_AUTO_RESTART`, `0`) + } + case "envdir": + os.Setenv("ENVDIR", opt.Value().(string)) + case "interval": + interval = opt.Value().(time.Duration) + case "kill_old_delay": + os.Setenv(`KILL_OLD_DELAY`, strconv.Itoa(int(opt.Value().(time.Duration)/time.Second))) + case "paths": + paths = opt.Value().([]string) + case "pid_file": + pidFile = opt.Value().(string) + case "ports": + ports = opt.Value().([]string) + case "signal_on_hup": + sigonhup = opt.Value().(os.Signal) + case "signal_on_term": + sigonterm = opt.Value().(os.Signal) + case "status_file": + statusFile = opt.Value().(string) + case "notice_output": + noticeOutput = opt.Value().(io.Writer) + } + } + notice := func(f string, args ...interface{}) { + var buf bytes.Buffer + fmt.Fprintf(&buf, f, args...) 
+ if buf.Len() == 0 { + return + } + + b := buf.Bytes() + if b[len(b)-1] != '\n' { + buf.WriteByte('\n') + } + buf.WriteTo(noticeOutput) + } + + generation := 0 // This is SERVER_STARTER_GENERATION + os.Setenv(`SERVER_STARTER_GENERATION`, `0`) + + cleanupCh := make(chan func()) + ctx = context.WithValue(ctx, registerCleanupKey, func(f func()) { + cleanupCh <- f + }) + go cleanup(ctx, cleanupCh) + + // start listening + extraFiles := make([]*os.File, 0, len(ports)+len(paths)) + portSpecs := make([]string, 0, len(ports)+len(paths)) + for _, addr := range ports { var l net.Listener host, port, err := parsePortSpec(addr) if err != nil { - fmt.Fprintf(os.Stderr, "failed to parse addr spec '%s': %s", addr, err) + notice("failed to parse addr spec '%s': %s", addr, err) return err } hostport := fmt.Sprintf("%s:%d", host, port) l, err = net.Listen("tcp4", hostport) if err != nil { - fmt.Fprintf(os.Stderr, "failed to listen to %s:%s\n", hostport, err) + notice("failed to listen to %s:%s\n", hostport, err) return err } @@ -153,319 +227,359 @@ func (s *Starter) Run() error { } else { spec = fmt.Sprintf("%s:%d", host, port) } - s.listeners = append(s.listeners, listener{listener: l, spec: spec}) + f, err := l.(*net.TCPListener).File() + if err != nil { + return errors.Wrap(err, "failed to get fd from listener") + } + registerCleanup(ctx, func() { f.Close() }) + extraFiles = append(extraFiles, f) + portSpecs = append(portSpecs, fmt.Sprintf("%s=%d", spec, len(portSpecs)+3)) + listeners = append(listeners, listener{listener: l, spec: spec}) } - for _, path := range s.paths { + for _, path := range paths { var l net.Listener if fl, err := os.Lstat(path); err == nil && fl.Mode()&os.ModeSocket == os.ModeSocket { - fmt.Fprintf(os.Stderr, "removing existing socket file:%s\n", path) + notice("removing existing socket file:%s\n", path) err = os.Remove(path) if err != nil { - fmt.Fprintf(os.Stderr, "failed to remove existing socket file:%s:%s\n", path, err) + notice("failed to remove 
existing socket file:%s:%s\n", path, err) return err } } _ = os.Remove(path) l, err := net.Listen("unix", path) if err != nil { - fmt.Fprintf(os.Stderr, "failed to listen file:%s:%s\n", path, err) + notice("failed to listen file:%s:%s\n", path, err) return err } - s.listeners = append(s.listeners, listener{listener: l, spec: path}) + f, err := l.(*net.UnixListener).File() + if err != nil { + return errors.Wrap(err, "failed to get fd from listener") + } + registerCleanup(ctx, func() { f.Close() }) + extraFiles = append(extraFiles, f) + portSpecs = append(portSpecs, fmt.Sprintf("%s=%d", path, len(portSpecs)+3)) + listeners = append(listeners, listener{listener: l, spec: path}) } - s.generation = 0 - os.Setenv("SERVER_STARTER_GENERATION", fmt.Sprintf("%d", s.generation)) - - // XXX Not portable - sigCh := make(chan os.Signal, 1) - signal.Notify(sigCh, - syscall.SIGHUP, - syscall.SIGINT, - syscall.SIGTERM, - syscall.SIGQUIT, - ) - - // Okay, ready to launch the program now... - setEnv() - workerCh := make(chan processState) - p := s.StartWorker(sigCh, workerCh) - oldWorkers := make(map[int]int) - var sigReceived os.Signal - var sigToSend os.Signal - - statusCh := make(chan map[int]int) - go func(fn string, ch chan map[int]int) { - for wmap := range ch { - if fn == "" { - continue - } + os.Setenv("SERVER_STARTER_PORT", strings.Join(portSpecs, ";")) + + // Note: environment variables that are set after this + // will NOT be re-populated + sysenv := env.SystemEnvironment() + envLoader := env.NewLoader() - f, err := os.OpenFile(fn, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0644) + var statusFileCreated bool + defer func() { + if statusFileCreated { + os.Remove(statusFile) + } + }() + var currentWorker int // pid + var lastRestartTime time.Time + oldWorkers := map[int]int{} // pid to generation + + var updateStatus func() error + switch fn := statusFile; fn { + case "": + updateStatus = func() error { return nil } + default: + updateStatus = func() error { + tmpfn := fn + "." 
+ strconv.Itoa(os.Getpid()) + f, err := os.OpenFile(tmpfn, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0644) if err != nil { - continue + return errors.Wrapf(err, "failed to create temporary file:%s", fn) } - - for gen, pid := range wmap { - fmt.Fprintf(f, "%d:%d\n", gen, pid) + statusFileCreated = true + m := map[int]int{} + for k, v := range oldWorkers { + m[k] = v + } + if currentWorker > 0 { + m[generation] = currentWorker } + keys := make([]int, 0, len(oldWorkers)+1) + for k := range oldWorkers { + keys = append(keys, k) + } + sort.Ints(keys) + for _, k := range keys { + fmt.Fprintf(f, "%d:%d\n", k, m[k]) + } f.Close() + return errors.Wrapf(os.Rename(tmpfn, fn), "failed to rename %s to %s", fn, tmpfn) } - }(s.statusFile, statusCh) + } - defer func() { - if p != nil { - oldWorkers[p.Pid] = s.generation - } - - fmt.Fprintf(os.Stderr, "received %s, sending %s to all workers:", - signame(sigReceived), - signame(sigToSend), - ) - size := len(oldWorkers) - i := 0 - for pid := range oldWorkers { - i++ - fmt.Fprintf(os.Stderr, "%d", pid) - if i < size { - fmt.Fprintf(os.Stderr, ",") - } + // This watcher receives commands to watch for. + workerSrc := make(chan *exec.Cmd) + workerDone := make(chan *exec.Cmd) + go monitor(ctx, workerSrc, workerDone) + + // signal handler here queues up signals to the other + // channel, so that we can keep accepting signals while we + // only really handle them once per loop + sigCh := make(chan os.Signal, 32) + go acceptSignals(ctx, sigCh) + + errTryExec := errors.New("keep trying") + startCmd := func(cmd *exec.Cmd) error { + if err := cmd.Start(); err != nil { + notice("%s", err.Error()) + // We would LOVE to continue immediately, but we need to do the + // same check-for-signals and etc here, so we go on.. 
+ } else { + notice("starting new worker %d", cmd.Process.Pid) } - fmt.Fprintf(os.Stderr, "\n") - for pid := range oldWorkers { - worker, err := os.FindProcess(pid) - if err != nil { - continue + // Wait for up to `interval` seconds before + // checking if this command (process) is alive + time.Sleep(interval) + + // Check if we have received any signals while we were + // waiting. this is a very dirty trick in that we are + // mucking with a channel that is potentially being written + // to concurrently :/ + nonhup := 0 + var bufferedSigs []os.Signal + l := len(sigCh) + for i := 0; i < l; i++ { + s := <-sigCh + bufferedSigs = append(bufferedSigs, s) + if s != os.Signal(syscall.SIGHUP) { + // do not immediately stop... read all + nonhup++ } - worker.Signal(sigToSend) } - - for len(oldWorkers) > 0 { - st := <-workerCh - fmt.Fprintf(os.Stderr, "worker %d died, status:%d\n", st.Pid(), grabExitStatus(st)) - delete(oldWorkers, st.Pid()) + if len(bufferedSigs) > 0 { + go func() { + for _, s := range bufferedSigs { + sigCh <- s + } + }() + if nonhup > 0 { // bailout + return errors.New("received signal while waiting") + } } - fmt.Fprintf(os.Stderr, "exiting\n") - }() - - // var lastRestartTime time.Time - for { // outer loop - setEnv() - // Just wait for the worker to exit, or for us to receive a signal - for { - // restart = 2: force restart - // restart = 1 and no workers: force restart - // restart = 0: no restart - restart := 0 - - select { - case st := <-workerCh: - // oops, the worker exited? 
check for its pid - if p.Pid == st.Pid() { // current worker - exitSt := grabExitStatus(st) - fmt.Fprintf(os.Stderr, "worker %d died unexpectedly with status %d, restarting\n", p.Pid, exitSt) - p = s.StartWorker(sigCh, workerCh) - // lastRestartTime = time.Now() - } else { - exitSt := grabExitStatus(st) - fmt.Fprintf(os.Stderr, "old worker %d died, status:%d\n", st.Pid(), exitSt) - delete(oldWorkers, st.Pid()) - } - case sigReceived = <-sigCh: - // Temporary fix - switch sigReceived { - case syscall.SIGHUP: - // When we receive a HUP signal, we need to spawn a new worker - fmt.Fprintf(os.Stderr, "received HUP (num_old_workers=TODO)\n") - restart = 1 - sigToSend = s.signalOnHUP - case syscall.SIGTERM: - sigToSend = s.signalOnTERM - return nil - default: - sigToSend = syscall.SIGTERM + // Want to check if the given PID is still alive. + // This is not a great way to do it b/c we're not + // even sure the Pid we're looking for is the same + // process as the one we spawned, but... this is + // so far the best we can do + // Note: Does this work on windows? 
+ if cmd.Process != nil { + p, err := os.FindProcess(cmd.Process.Pid) + if err == nil { + if err := p.Signal(os.Signal(syscall.Signal(0))); err == nil { return nil } } + } - if restart > 1 || restart > 0 && len(oldWorkers) == 0 { - fmt.Fprintf(os.Stderr, "spawning a new worker (num_old_workers=TODO)\n") - oldWorkers[p.Pid] = s.generation - p = s.StartWorker(sigCh, workerCh) - fmt.Fprintf(os.Stderr, "new worker is now running, sending %s to old workers:", signame(sigToSend)) - size := len(oldWorkers) - if size == 0 { - fmt.Fprintf(os.Stderr, "none\n") - } else { - i := 0 - for pid := range oldWorkers { - i++ - fmt.Fprintf(os.Stderr, "%d", pid) - if i < size { - fmt.Fprintf(os.Stderr, ",") - } - } - fmt.Fprintf(os.Stderr, "\n") - - killOldDelay := getKillOldDelay() - fmt.Fprintf(os.Stderr, "sleep %d secs\n", int(killOldDelay/time.Second)) - if killOldDelay > 0 { - time.Sleep(killOldDelay) - } - - fmt.Fprintf(os.Stderr, "killing old workers\n") - - for pid := range oldWorkers { - worker, err := os.FindProcess(pid) - if err != nil { - fmt.Fprintf(os.Stderr, "failed to find process %d\n", pid) - continue - } - worker.Signal(s.signalOnHUP) - } - } - } + switch { + case cmd.ProcessState != nil: + notice("new worker %d seems to have failed to start, exit status:%d", cmd.ProcessState.Pid(), grabExitStatus(cmd.ProcessState)) + case cmd.Process != nil: + notice("new worker %d seems to have failed to start", cmd.Process.Pid) + default: + notice("new worker seems to have failed to start") } + return errTryExec } - return nil -} -*/ - -func getKillOldDelay() time.Duration { - // Ignore errors. 
- delay, _ := strconv.ParseInt(os.Getenv("KILL_OLD_DELAY"), 10, 0) - autoRestart, _ := strconv.ParseBool(os.Getenv("ENABLE_AUTO_RESTART")) - if autoRestart && delay == 0 { - delay = 5 - } + if pidFile != "" { + f, err := os.OpenFile(pidFile, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0644) + if err != nil { + return errors.Wrapf(err, "failed to open file:%s", pidFile) + } + defer f.Close() + defer os.Remove(f.Name()) - return time.Duration(delay) * time.Second -} + if err := syscall.Flock(int(f.Fd()), syscall.LOCK_EX); err != nil { + return errors.Wrapf(err, "flock failed(%s)", pidFile) + } + fmt.Fprintf(f, "%d\n", os.Getpid()) + if err := f.Sync(); err != nil { + return errors.Wrapf(err, "failed to sync file(%s)", pidFile) + } + } -// StartWorker starts the actual command. -func (s *Starter) StartWorker(sigCh chan os.Signal, ch chan processState) *os.Process { - // Don't give up until we're running. - for { - pid := -1 + newCommand := func() *exec.Cmd { cmd := exec.Command(s.command, s.args...) 
- if s.dir != "" { - cmd.Dir = s.dir + if dir != "" { + cmd.Dir = dir } cmd.Stdout = os.Stdout cmd.Stderr = os.Stderr + cmd.ExtraFiles = extraFiles + return cmd + } - // This whole section here basically sets up the env - // var and the file descriptors that are inherited by the - // external process - files := make([]*os.File, len(s.ports)+len(s.paths)) - ports := make([]string, len(s.ports)+len(s.paths)) - for i, l := range s.listeners { - // file descriptor numbers in ExtraFiles turn out to be - // index + 3, so we can just hard code it - var f *os.File - var err error - switch l.listener.(type) { - case *net.TCPListener: - f, err = l.listener.(*net.TCPListener).File() - case *net.UnixListener: - f, err = l.listener.(*net.UnixListener).File() + startWorker := func() error { + for loop := true; loop; { + generation++ + os.Setenv(`SERVER_STARTER_GENERATION`, strconv.Itoa(generation)) + + cmd := newCommand() + switch err := startCmd(cmd); err { + case nil: + loop = false + currentWorker = cmd.Process.Pid + lastRestartTime = time.Now() + updateStatus() + workerSrc <- cmd + case errTryExec: + // keep trying default: - panic("Unknown listener type") + return errors.Wrap(err, "failed to start command") } - if err != nil { - panic(err) + } + + return nil + } + + var cleanupWorkers = func(sig os.Signal) { + termSig := os.Signal(syscall.SIGTERM) + if sig == termSig { + termSig = sigonterm + } + + if currentWorker > 0 { + oldWorkers[currentWorker] = envAsInt(`SERVER_STARTER_GENERATION`) + currentWorker = 0 + } + var buf bytes.Buffer + fmt.Fprintf(&buf, "received %s, sending %s to all workers:", signame(sig), signame(termSig)) + keys := make([]int, 0, len(oldWorkers)) + for k := range oldWorkers { + keys = append(keys, k) + } + sort.Ints(keys) + for i, k := range keys { + fmt.Fprintf(&buf, "%d", k) + if i < len(keys)-1 { + buf.WriteByte(',') } - defer f.Close() - ports[i] = fmt.Sprintf("%s=%d", l.spec, i+3) - files[i] = f } - cmd.ExtraFiles = files + notice(buf.String()) - 
s.generation++ - os.Setenv("SERVER_STARTER_PORT", strings.Join(ports, ";")) - os.Setenv("SERVER_STARTER_GENERATION", fmt.Sprintf("%d", s.generation)) + for _, pid := range keys { + p, err := os.FindProcess(pid) + if err != nil { // XXX to be safe, let's delete this pid + delete(oldWorkers, pid) + } + p.Signal(termSig) + } - // Now start! - if err := cmd.Start(); err != nil { - fmt.Fprintf(os.Stderr, "failed to exec %s: %s\n", cmd.Path, err) - } else { - // Save pid... - pid = cmd.Process.Pid - fmt.Fprintf(os.Stderr, "starting new worker %d\n", pid) - - // Wait for interval before checking if the process is alive - tch := time.After(s.interval) - sigs := []os.Signal{} - for loop := true; loop; { - select { - case <-tch: - // bail out - loop = false - case sig := <-sigCh: - sigs = append(sigs, sig) - } + for len(oldWorkers) > 0 { + cmd, ok := <-workerDone + if !ok { + panic("workerDone channel closed while still waiting for children to be reaped") } + notice("worker %d died, status:%d", cmd.ProcessState.Pid(), grabExitStatus(cmd.ProcessState)) + delete(oldWorkers, cmd.ProcessState.Pid()) + updateStatus() + } + } - // if received any signals, during the wait, we bail out - gotSig := false - if len(sigs) > 0 { - for _, sig := range sigs { - // we need to resend these signals so it can be caught in the - // main routine... 
- go func() { sigCh <- sig }() - if sysSig, ok := sig.(syscall.Signal); ok { - if sysSig != syscall.SIGHUP { - gotSig = true - } - } - } + if err := startWorker(); err != nil { + return errors.Wrap(err, "failed to start worker") + } + + for { + // wait for next signal (or when auto-restart becomes necessary) + exited := wait(ctx, sigCh, workerDone) + + // reload env if necessary + envLoader.Apply(ctx, sysenv) + + if envAsBool(`ENABLE_AUTO_RESTART`) { + if os.Getenv("AUTO_RESTART_INTERVAL") == "" { + os.Setenv("AUTO_RESTART_INTERVAL", "360") } + } - // Check if we can find a process by its pid - p, err := os.FindProcess(pid) - if gotSig || err == nil { - // No error? We were successful! Make sure we capture - // the program exiting - go func() { - err := cmd.Wait() - if err != nil { - ch <- err.(*exec.ExitError).ProcessState - } else { - ch <- &dummyProcessState{pid: pid, status: successStatus} - } - }() - // Bail out - return p + if exited != nil { // got some command exit + pid := exited.ProcessState.Pid() + if pid == currentWorker { + notice("worker %d died unexpectedly with status: %d, restarting\n", pid, grabExitStatus(exited.ProcessState)) + if err := startWorker(); err != nil { + return errors.Wrap(err, "failed to start worker") + } + } else { + notice("old worker %d died, status:%d", pid, grabExitStatus(exited.ProcessState)) + delete(oldWorkers, pid) + updateStatus() } + } + var restart bool + for loop := true; loop; { + select { + case sig := <-sigCh: + switch sig { + case syscall.SIGHUP: + restart = true + loop = false + case syscall.SIGALRM: + loop = false + default: + cleanupWorkers(sig) + return nil + } + default: + loop = false + } } - // If we fall through here, we prematurely exited :/ - // Make sure to wait to release resources - cmd.Wait() - for _, f := range cmd.ExtraFiles { - f.Close() + + if !restart && envAsBool("ENABLE_AUTO_RESTART") { + autoRestartInterval := envAsDuration("AUTO_RESTART_INTERVAL") + elapsedSinceRestart := 
time.Since(lastRestartTime) + if elapsedSinceRestart >= autoRestartInterval && len(oldWorkers) == 0 { + notice("autorestart triggered (interval=%s)", autoRestartInterval) + restart = true + } else if elapsedSinceRestart >= autoRestartInterval*2 { + notice("autorestart triggered (forced, interval=%s)", autoRestartInterval) + } } - fmt.Fprintf(os.Stderr, "new worker %d seems to have failed to start\n", pid) - } + if restart { + oldWorkers[currentWorker] = generation + if err := startWorker(); err != nil { + return errors.Wrap(err, "failed to restart worker") + } - // never reached - return nil -} + var buf bytes.Buffer + l := len(oldWorkers) + if l == 0 { + buf.WriteString("none") + } else { + i := 0 + for pid := range oldWorkers { + buf.WriteString(strconv.Itoa(pid)) + if i < l-1 { + buf.WriteByte(',') + } + } + } + notice("new worker is now running, sending %s to old workers: %s", signame(sigonhup), buf.String()) -func (s *Starter) Teardown() error { - if s.statusFile != "" { - os.Remove(s.statusFile) - } + killOldDelay := envAsDuration(`KILL_OLD_DELAY`) + if killOldDelay == 0 && envAsBool(`ENABLE_AUTO_RESTART`) { + killOldDelay = 5 * time.Second + } - for _, l := range s.listeners { - l.listener.Close() + time.Sleep(killOldDelay) + for pid := range oldWorkers { + worker, err := os.FindProcess(pid) + if err != nil { + continue + } + worker.Signal(sigonhup) + } + } } - - return nil } diff --git a/worker.go b/worker.go deleted file mode 100644 index fae28ce..0000000 --- a/worker.go +++ /dev/null @@ -1,576 +0,0 @@ -package starter - -import ( - "bytes" - "context" - "fmt" - "io" - "net" - "os" - "os/exec" - "os/signal" - "sort" - "strconv" - "strings" - "syscall" - "time" - - "github.com/lestrrat/go-server-starter/internal/env" - "github.com/pkg/errors" -) - -func (s *Starter) Notice(f string, args ...interface{}) { - var buf bytes.Buffer - fmt.Fprintf(&buf, f, args...) 
- if buf.Len() == 0 { - return - } - - b := buf.Bytes() - if b[len(b)-1] != '\n' { - buf.WriteByte('\n') - } - buf.WriteTo(s.noticeWriter) -} - -func envAsBool(name string) bool { - b, err := strconv.ParseBool(os.Getenv(name)) - return err == nil && b -} - -func envAsInt(name string) int { - i, _ := strconv.ParseInt(os.Getenv(name), 10, 64) - return int(i) -} - -func envAsDuration(name string) time.Duration { - return time.Duration(envAsInt(name)) * time.Second -} - -// This keeps listening to INT,TERM,HUP, and ALRM signals, -// and queues them up into a destination channel -func acceptSignals(ctx context.Context, dst chan os.Signal) { - src := make(chan os.Signal, 32) // up to 32 signals - signal.Notify(src, syscall.SIGINT, syscall.SIGTERM, syscall.SIGHUP, syscall.SIGALRM) - signal.Ignore(syscall.SIGPIPE) - defer close(dst) - for { - select { - case <-ctx.Done(): - return - case sig, ok := <-src: - if !ok { - return - } - dst <- sig - } - } -} - -func wait(ctx context.Context, sigCh chan os.Signal, workerDone chan *exec.Cmd) *exec.Cmd { - // Original code in lib/Server/Starter.pm (_wait3) looks... 
interesting - // currently going to punt it in light of just "wait for a process to - // finish or a signal is received" - t := time.NewTicker(time.Second) - defer t.Stop() - for { - select { - case <-ctx.Done(): - p, _ := os.FindProcess(os.Getpid()) - p.Signal(os.Signal(syscall.SIGTERM)) - return nil - case <-t.C: - if len(sigCh) > 0 { - return nil - } - case cmd := <-workerDone: - return cmd - } - } - return nil -} - -var registerCleanupKey struct{} - -func registerCleanup(ctx context.Context, f func()) { - register, ok := ctx.Value(registerCleanupKey).(func(func())) - if !ok { - return - } - register(f) -} - -func cleanup(ctx context.Context, ch chan func()) { - var finalizers []func() - for loop := true; loop; { - select { - case <-ctx.Done(): - loop = false - continue - case f, ok := <-ch: - if ok { - finalizers = append(finalizers, f) - } - } - } - for _, f := range finalizers { - f() - } -} - -func (s *Starter) Run(ctx context.Context) error { - var cancel func() - ctx, cancel = context.WithCancel(ctx) - defer cancel() - - var dir string - var interval time.Duration = time.Second - var paths []string - var pidFile string - var listeners []listener - var ports []string - var sigonhup os.Signal = os.Signal(syscall.SIGTERM) - var sigonterm os.Signal = os.Signal(syscall.SIGTERM) - var statusFile string - var noticeOutput io.Writer = os.Stderr - - for _, opt := range s.options { - switch opt.Name() { - case "auto_restart_interval": - v := opt.Value().(int) - os.Setenv(`AUTO_RESTART_INTERVAL`, strconv.Itoa(v)) - case "dir": - dir = opt.Value().(string) - case "enable_auto_restart": - b := opt.Value().(bool) - if b { - os.Setenv(`ENABLE_AUTO_RESTART`, `1`) - } else { - os.Setenv(`ENABLE_AUTO_RESTART`, `0`) - } - case "envdir": - os.Setenv("ENVDIR", opt.Value().(string)) - case "interval": - interval = opt.Value().(time.Duration) - case "kill_old_delay": - os.Setenv(`KILL_OLD_DELAY`, strconv.Itoa(int(opt.Value().(time.Duration)/time.Second))) - case "paths": - paths 
= opt.Value().([]string) - case "pid_file": - pidFile = opt.Value().(string) - case "ports": - ports = opt.Value().([]string) - case "signal_on_hup": - sigonhup = opt.Value().(os.Signal) - case "signal_on_term": - sigonterm = opt.Value().(os.Signal) - case "status_file": - statusFile = opt.Value().(string) - case "notice_output": - noticeOutput = opt.Value().(io.Writer) - } - } - notice := func(f string, args ...interface{}) { - var buf bytes.Buffer - fmt.Fprintf(&buf, f, args...) - if buf.Len() == 0 { - return - } - - b := buf.Bytes() - if b[len(b)-1] != '\n' { - buf.WriteByte('\n') - } - buf.WriteTo(noticeOutput) - } - - generation := 0 // This is SERVER_STARTER_GENERATION - os.Setenv(`SERVER_STARTER_GENERATION`, `0`) - - cleanupCh := make(chan func()) - ctx = context.WithValue(ctx, registerCleanupKey, func(f func()) { - cleanupCh <- f - }) - go cleanup(ctx, cleanupCh) - - // start listening - extraFiles := make([]*os.File, 0, len(ports)+len(paths)) - portSpecs := make([]string, 0, len(ports)+len(paths)) - for _, addr := range ports { - var l net.Listener - - host, port, err := parsePortSpec(addr) - if err != nil { - notice("failed to parse addr spec '%s': %s", addr, err) - return err - } - - hostport := fmt.Sprintf("%s:%d", host, port) - l, err = net.Listen("tcp4", hostport) - if err != nil { - notice("failed to listen to %s:%s\n", hostport, err) - return err - } - - spec := "" - if host == "" { - spec = fmt.Sprintf("%d", port) - } else { - spec = fmt.Sprintf("%s:%d", host, port) - } - f, err := l.(*net.TCPListener).File() - if err != nil { - return errors.Wrap(err, "failed to get fd from listener") - } - registerCleanup(ctx, func() { f.Close() }) - extraFiles = append(extraFiles, f) - portSpecs = append(portSpecs, fmt.Sprintf("%s=%d", spec, len(portSpecs)+3)) - listeners = append(listeners, listener{listener: l, spec: spec}) - } - - for _, path := range paths { - var l net.Listener - if fl, err := os.Lstat(path); err == nil && fl.Mode()&os.ModeSocket == 
os.ModeSocket { - notice("removing existing socket file:%s\n", path) - err = os.Remove(path) - if err != nil { - notice("failed to remove existing socket file:%s:%s\n", path, err) - return err - } - } - _ = os.Remove(path) - l, err := net.Listen("unix", path) - if err != nil { - notice("failed to listen file:%s:%s\n", path, err) - return err - } - f, err := l.(*net.UnixListener).File() - if err != nil { - return errors.Wrap(err, "failed to get fd from listener") - } - registerCleanup(ctx, func() { f.Close() }) - extraFiles = append(extraFiles, f) - portSpecs = append(portSpecs, fmt.Sprintf("%s=%d", path, len(portSpecs)+3)) - listeners = append(listeners, listener{listener: l, spec: path}) - } - - os.Setenv("SERVER_STARTER_PORT", strings.Join(portSpecs, ";")) - - // Note: environment variables that are set after this - // will NOT be re-populated - sysenv := env.SystemEnvironment() - envLoader := env.NewLoader() - - var statusFileCreated bool - defer func() { - if statusFileCreated { - os.Remove(statusFile) - } - }() - var currentWorker int // pid - var lastRestartTime time.Time - oldWorkers := map[int]int{} // pid to generation - - var updateStatus func() error - switch fn := statusFile; fn { - case "": - updateStatus = func() error { return nil } - default: - updateStatus = func() error { - tmpfn := fn + "." 
+ strconv.Itoa(os.Getpid()) - f, err := os.OpenFile(tmpfn, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0644) - if err != nil { - return errors.Wrapf(err, "failed to create temporary file:%s", fn) - } - statusFileCreated = true - m := map[int]int{} - for k, v := range oldWorkers { - m[k] = v - } - if currentWorker > 0 { - m[generation] = currentWorker - } - - keys := make([]int, 0, len(oldWorkers)+1) - for k := range oldWorkers { - keys = append(keys, k) - } - sort.Ints(keys) - for _, k := range keys { - fmt.Fprintf(f, "%d:%d\n", k, m[k]) - } - f.Close() - return errors.Wrapf(os.Rename(tmpfn, fn), "failed to rename %s to %s", fn, tmpfn) - } - } - - // This watcher receives commands to watch for. - workerSrc := make(chan *exec.Cmd) - workerDone := make(chan *exec.Cmd) - go monitor(ctx, workerSrc, workerDone) - - // signal handler here queues up signals to the other - // channel, so that we can keep accepting signals while we - // only really handle them once per loop - sigCh := make(chan os.Signal, 32) - go acceptSignals(ctx, sigCh) - - errTryExec := errors.New("keep trying") - startCmd := func(cmd *exec.Cmd) error { - if err := cmd.Start(); err != nil { - notice("%s", err.Error()) - // We would LOVE to continue immediately, but we need to do the - // same check-for-signals and etc here, so we go on.. - } else { - notice("starting new worker %d", cmd.Process.Pid) - } - - // Wait for up to `interval` seconds before - // checking if this command (process) is alive - time.Sleep(interval) - - // Check if we have received any signals while we were - // waiting. this is a very dirty trick in that we are - // mucking with a channel that is potentially being written - // to concurrently :/ - nonhup := 0 - var bufferedSigs []os.Signal - l := len(sigCh) - for i := 0; i < l; i++ { - s := <-sigCh - bufferedSigs = append(bufferedSigs, s) - if s != os.Signal(syscall.SIGHUP) { - // do not immediately stop... 
read all - nonhup++ - } - } - if len(bufferedSigs) > 0 { - go func() { - for _, s := range bufferedSigs { - sigCh <- s - } - }() - if nonhup > 0 { // bailout - return errors.New("received signal while waiting") - } - } - - // Want to check if the given PID is still alive. - // This is not a great way to do it b/c we're not - // even sure the Pid we're looking for is the same - // process as the one we spawned, but... this is - // so far the best we can do - // Note: Does this work on windows? - if cmd.Process != nil { - p, err := os.FindProcess(cmd.Process.Pid) - if err == nil { - if err := p.Signal(os.Signal(syscall.Signal(0))); err == nil { - return nil - } - } - } - - switch { - case cmd.ProcessState != nil: - notice("new worker %d seems to have failed to start, exit status:%d", cmd.ProcessState.Pid(), grabExitStatus(cmd.ProcessState)) - case cmd.Process != nil: - notice("new worker %d seems to have failed to start", cmd.Process.Pid) - default: - notice("new worker seems to have failed to start") - } - return errTryExec - } - - if pidFile != "" { - f, err := os.OpenFile(pidFile, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0644) - if err != nil { - return errors.Wrapf(err, "failed to open file:%s", pidFile) - } - defer f.Close() - defer os.Remove(f.Name()) - - if err := syscall.Flock(int(f.Fd()), syscall.LOCK_EX); err != nil { - return errors.Wrapf(err, "flock failed(%s)", pidFile) - } - fmt.Fprintf(f, "%d\n", os.Getpid()) - if err := f.Sync(); err != nil { - return errors.Wrapf(err, "failed to sync file(%s)", pidFile) - } - } - - newCommand := func() *exec.Cmd { - cmd := exec.Command(s.command, s.args...) 
- if dir != "" { - cmd.Dir = dir - } - cmd.Stdout = os.Stdout - cmd.Stderr = os.Stderr - cmd.ExtraFiles = extraFiles - return cmd - } - - startWorker := func() error { - for loop := true; loop; { - generation++ - os.Setenv(`SERVER_STARTER_GENERATION`, strconv.Itoa(generation)) - - cmd := newCommand() - switch err := startCmd(cmd); err { - case nil: - loop = false - currentWorker = cmd.Process.Pid - lastRestartTime = time.Now() - updateStatus() - workerSrc <- cmd - case errTryExec: - // keep trying - default: - return errors.Wrap(err, "failed to start command") - } - } - - return nil - } - - var cleanupWorkers = func(sig os.Signal) { - termSig := os.Signal(syscall.SIGTERM) - if sig == termSig { - termSig = sigonterm - } - - if currentWorker > 0 { - oldWorkers[currentWorker] = envAsInt(`SERVER_STARTER_GENERATION`) - currentWorker = 0 - } - var buf bytes.Buffer - fmt.Fprintf(&buf, "received %s, sending %s to all workers:", signame(sig), signame(termSig)) - keys := make([]int, 0, len(oldWorkers)) - for k := range oldWorkers { - keys = append(keys, k) - } - sort.Ints(keys) - for i, k := range keys { - fmt.Fprintf(&buf, "%d", k) - if i < len(keys)-1 { - buf.WriteByte(',') - } - } - notice(buf.String()) - - for _, pid := range keys { - p, err := os.FindProcess(pid) - if err != nil { // XXX to be safe, let's delete this pid - delete(oldWorkers, pid) - } - p.Signal(termSig) - } - - for len(oldWorkers) > 0 { - cmd, ok := <-workerDone - if !ok { - panic("workerDone channel closed while still waiting for children to be reaped") - } - notice("worker %d died, status:%d", cmd.ProcessState.Pid(), grabExitStatus(cmd.ProcessState)) - delete(oldWorkers, cmd.ProcessState.Pid()) - updateStatus() - } - } - - if err := startWorker(); err != nil { - return errors.Wrap(err, "failed to start worker") - } - - for { - // wait for next signal (or when auto-restart becomes necessary) - exited := wait(ctx, sigCh, workerDone) - - // reload env if necessary - envLoader.Apply(ctx, sysenv) - - if 
envAsBool(`ENABLE_AUTO_RESTART`) { - if os.Getenv("AUTO_RESTART_INTERVAL") == "" { - os.Setenv("AUTO_RESTART_INTERVAL", "360") - } - } - - if exited != nil { // got some command exit - pid := exited.ProcessState.Pid() - if pid == currentWorker { - notice("worker %d died unexpectedly with status: %d, restarting\n", pid, grabExitStatus(exited.ProcessState)) - if err := startWorker(); err != nil { - return errors.Wrap(err, "failed to start worker") - } - } else { - notice("old worker %d died, status:%d", pid, grabExitStatus(exited.ProcessState)) - delete(oldWorkers, pid) - updateStatus() - } - } - - var restart bool - for loop := true; loop; { - select { - case sig := <-sigCh: - switch sig { - case syscall.SIGHUP: - restart = true - loop = false - case syscall.SIGALRM: - loop = false - default: - cleanupWorkers(sig) - return nil - } - default: - loop = false - } - } - - if !restart && envAsBool("ENABLE_AUTO_RESTART") { - autoRestartInterval := envAsDuration("AUTO_RESTART_INTERVAL") - elapsedSinceRestart := time.Since(lastRestartTime) - if elapsedSinceRestart >= autoRestartInterval && len(oldWorkers) == 0 { - notice("autorestart triggered (interval=%s)", autoRestartInterval) - restart = true - } else if elapsedSinceRestart >= autoRestartInterval*2 { - notice("autorestart triggered (forced, interval=%s)", autoRestartInterval) - } - } - - if restart { - oldWorkers[currentWorker] = generation - if err := startWorker(); err != nil { - return errors.Wrap(err, "failed to restart worker") - } - - var buf bytes.Buffer - l := len(oldWorkers) - if l == 0 { - buf.WriteString("none") - } else { - i := 0 - for pid := range oldWorkers { - buf.WriteString(strconv.Itoa(pid)) - if i < l-1 { - buf.WriteByte(',') - } - } - } - notice("new worker is now running, sending %s to old workers: %s", signame(sigonhup), buf.String()) - - killOldDelay := envAsDuration(`KILL_OLD_DELAY`) - if killOldDelay == 0 && envAsBool(`ENABLE_AUTO_RESTART`) { - killOldDelay = 5 * time.Second - } - - 
time.Sleep(killOldDelay) - for pid := range oldWorkers { - worker, err := os.FindProcess(pid) - if err != nil { - continue - } - worker.Signal(sigonhup) - } - } - } -} From 1fc837f17838b0ac7da38a76460bf5d113098f95 Mon Sep 17 00:00:00 2001 From: Daisuke Maki Date: Tue, 29 Nov 2016 18:34:20 +0900 Subject: [PATCH 07/26] increase the wait --- starter_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/starter_test.go b/starter_test.go index 6d5a5e3..521fa2e 100644 --- a/starter_test.go +++ b/starter_test.go @@ -212,7 +212,7 @@ func TestRun(t *testing.T) { } }() - time.Sleep(time.Second) + time.Sleep(2 * time.Second) var closed bool select { From a5728ffc391d56ecff970df7ba26a5c0401c82ac Mon Sep 17 00:00:00 2001 From: Daisuke Maki Date: Tue, 29 Nov 2016 20:01:11 +0900 Subject: [PATCH 08/26] forgot to pass args, various tweaks --- options.go | 8 +++++++ signals.go | 3 ++- starter.go | 16 ++++++++++--- starter_test.go | 62 +++++++++++++++++++++++++++++++++++++++---------- 4 files changed, 73 insertions(+), 16 deletions(-) diff --git a/options.go b/options.go index efa82f7..619d5c3 100644 --- a/options.go +++ b/options.go @@ -76,6 +76,14 @@ func WithNoticeOutput(w io.Writer) Option { return &valueOption{name: "notice_output", value: w} } +func WithLogStdout(w io.Writer) Option { + return &valueOption{name: "log_stdout", value: w} +} + +func WithLogStderr(w io.Writer) Option { + return &valueOption{name: "log_stderr", value: w} +} + func (o *stringOpt) String() string { return o.Value } diff --git a/signals.go b/signals.go index 3b141ff..2224f83 100644 --- a/signals.go +++ b/signals.go @@ -1,6 +1,7 @@ package starter import ( + "fmt" "os" "strings" "syscall" @@ -46,7 +47,7 @@ func signame(s os.Signal) string { if ss, ok := s.(syscall.Signal); ok { return niceSigNames[ss] } - return "UNKNOWN" + return fmt.Sprintf("UNKNOWN (%s)", s) } func sigFromName(n string) os.Signal { diff --git a/starter.go b/starter.go index f3dffc7..e41cb30 100644 --- a/starter.go 
+++ b/starter.go @@ -132,6 +132,7 @@ func (s *Starter) Run(ctx context.Context) error { ctx, cancel = context.WithCancel(ctx) defer cancel() + var cmdArgs []string var dir string var interval time.Duration = time.Second var paths []string @@ -142,9 +143,13 @@ func (s *Starter) Run(ctx context.Context) error { var sigonterm os.Signal = os.Signal(syscall.SIGTERM) var statusFile string var noticeOutput io.Writer = os.Stderr + var logStdout io.Writer = os.Stdout + var logStderr io.Writer = os.Stderr for _, opt := range s.options { switch opt.Name() { + case "args": + cmdArgs = opt.Value().([]string) case "auto_restart_interval": v := opt.Value().(int) os.Setenv(`AUTO_RESTART_INTERVAL`, strconv.Itoa(v)) @@ -177,6 +182,10 @@ func (s *Starter) Run(ctx context.Context) error { statusFile = opt.Value().(string) case "notice_output": noticeOutput = opt.Value().(io.Writer) + case "log_stdout": + logStdout = opt.Value().(io.Writer) + case "log_stderr": + logStderr = opt.Value().(io.Writer) } } notice := func(f string, args ...interface{}) { @@ -408,12 +417,12 @@ func (s *Starter) Run(ctx context.Context) error { } newCommand := func() *exec.Cmd { - cmd := exec.Command(s.command, s.args...) + cmd := exec.Command(s.command, cmdArgs...) 
if dir != "" { cmd.Dir = dir } - cmd.Stdout = os.Stdout - cmd.Stderr = os.Stderr + cmd.Stdout = logStdout + cmd.Stderr = logStderr cmd.ExtraFiles = extraFiles return cmd } @@ -522,6 +531,7 @@ func (s *Starter) Run(ctx context.Context) error { case sig := <-sigCh: switch sig { case syscall.SIGHUP: + notice("received HUP, spawning a new worker") restart = true loop = false case syscall.SIGALRM: diff --git a/starter_test.go b/starter_test.go index 521fa2e..53b2e33 100644 --- a/starter_test.go +++ b/starter_test.go @@ -35,8 +35,9 @@ import ( ) func main() { - var maxSigterm int // number of times we "withstand" a sigterm - flag.IntVar(&maxSigterm, "sigterm", 0, "") + fmt.Fprintf(os.Stderr, "Starting echod (%d)\n", os.Getpid()) + var maxSigusr1 int // number of times we "withstand" a sigusr1 + flag.IntVar(&maxSigusr1, "sigusr1", 0, "") flag.Parse() listeners, err := listener.ListenAll() @@ -54,22 +55,29 @@ func main() { io.Copy(w, r.Body) }) for _, l := range listeners { - http.Serve(l, handler) + go http.Serve(l, handler) } + fmt.Fprintf(os.Stderr, "echod: Waiting for signal (max USR1 = %d)\n", maxSigusr1) sigCh := make(chan os.Signal, 1) - signal.Notify(sigCh, syscall.SIGTERM, syscall.SIGHUP) + signal.Notify(sigCh, syscall.SIGTERM, syscall.SIGHUP, syscall.SIGUSR1) - sigterm := 0 + sigusr1 := 0 for loop := true; loop; { select { case s := <-sigCh: + fmt.Fprintf(os.Stderr, "echod: received %s\n", s) switch s { - case syscall.SIGTERM: - sigterm++ - if maxSigterm <= sigterm { + case syscall.SIGUSR1: + sigusr1++ + if maxSigusr1 > sigusr1 { + fmt.Fprintf(os.Stderr, "echod: got USR1, ignoring (max = %d, count = %d)\n", maxSigusr1, sigusr1) + } else { + fmt.Fprintf(os.Stderr, "echod: reached max USR1 limit (%d)\n", maxSigusr1) loop = false } + case syscall.SIGTERM: + loop = false default: // do nothing } @@ -190,7 +198,11 @@ func TestRun(t *testing.T) { l.Restore(context.Background(), sysenv) t.Run("send multiple signals", func(t *testing.T) { - ctx, cancel := 
context.WithTimeout(context.Background(), 10*time.Second) + // Note: this test does NOT test that the same echod server has received + // signals, because doing that intelligently would require rpc between + // this test code and the echod, and I really am in no mood to do it + // for now. However, visually it looks like it's doing the right job + ctx, cancel := context.WithTimeout(context.Background(), 20*time.Second) defer cancel() port, err := tcputil.EmptyPort() @@ -202,7 +214,14 @@ func TestRun(t *testing.T) { defer func() { t.Logf("%s", output.String()) }() - sd := New(cmdname, WithArgs("-sigterm", "2"), WithPorts([]string{strconv.Itoa(port)}), WithNoticeOutput(&output)) + sd := New(cmdname, + WithArgs("--sigusr1=2"), + WithPorts([]string{strconv.Itoa(port)}), + WithNoticeOutput(&output), + WithLogStdout(&output), + WithLogStderr(&output), + WithSignalOnHUP(syscall.SIGUSR1), + ) done := make(chan struct{}) go func() { @@ -224,10 +243,25 @@ func TestRun(t *testing.T) { if !closed { p, _ := os.FindProcess(os.Getpid()) - p.Signal(syscall.SIGTERM) + p.Signal(syscall.SIGHUP) } - time.Sleep(time.Second) + time.Sleep(2 * time.Second) + + closed = false + select { + case <-done: + closed = true + t.Errorf("unexpected exit") + default: + } + + if !closed { + p, _ := os.FindProcess(os.Getpid()) + p.Signal(syscall.SIGHUP) + } + + time.Sleep(2 * time.Second) closed = false select { @@ -242,9 +276,13 @@ func TestRun(t *testing.T) { p.Signal(syscall.SIGTERM) } + time.Sleep(time.Second) + select { case <-ctx.Done(): t.Errorf("context prematurely ended: %s", ctx.Err()) + p, _ := os.FindProcess(os.Getpid()) + p.Signal(syscall.SIGTERM) case <-done: } }) From c0dfba657b14b8276bc64d946c17e79156cfcb06 Mon Sep 17 00:00:00 2001 From: Daisuke Maki Date: Tue, 29 Nov 2016 20:08:56 +0900 Subject: [PATCH 09/26] Remove more unnecessary code --- interface.go | 62 ++--------------------------------------------- status_any.go | 10 -------- status_unix.go | 9 +++++++ status_windows.go | 3 
--- 4 files changed, 11 insertions(+), 73 deletions(-) delete mode 100644 status_any.go create mode 100644 status_unix.go diff --git a/interface.go b/interface.go index cf0eebe..14a1de4 100644 --- a/interface.go +++ b/interface.go @@ -1,18 +1,12 @@ package starter import ( - "io" "net" - "os" "syscall" - "time" - - "github.com/lestrrat/go-server-starter/internal/env" ) const version = `0.0.2` -var successStatus syscall.WaitStatus var failureStatus syscall.WaitStatus type listener struct { @@ -25,39 +19,9 @@ type Option interface { Value() interface{} } -type Config interface { - Args() []string - Command() string - Dir() string // Dirctory to chdir to before executing the command - Interval() time.Duration // Time between checks for liveness - PidFile() string - Ports() []string // Ports to bind to (addr:port or port, so it's a string) - Paths() []string // Paths (UNIX domain socket) to bind to - SignalOnHUP() os.Signal // Signal to send when HUP is received - SignalOnTERM() os.Signal // Signal to send when TERM is received - StatusFile() string -} - type Starter struct { - options []Option - interval time.Duration - envLoader *env.Loader - noticeWriter io.Writer - extraFiles []*os.File - portSpecs []string - - signalOnHUP os.Signal - signalOnTERM os.Signal - // you can't set this in go: backlog - statusFile string - pidFile string - dir string - ports []string - paths []string - listeners []listener - generation int - command string - args []string + options []Option + command string } type processState interface { @@ -65,26 +29,6 @@ type processState interface { Sys() interface{} } -type dummyProcessState struct { - pid int - status syscall.WaitStatus -} - -func (d dummyProcessState) Pid() int { - return d.pid -} - -func (d dummyProcessState) Sys() interface{} { - return d.status -} - -type WorkerState int - -const ( - WorkerStarted WorkerState = iota - ErrFailedToStart -) - type CLI struct{} type boolOpt struct { Valid bool @@ -99,5 +43,3 @@ type stringOpt 
struct { Valid bool Value string } - - diff --git a/status_any.go b/status_any.go deleted file mode 100644 index 403ef84..0000000 --- a/status_any.go +++ /dev/null @@ -1,10 +0,0 @@ -// +build !windows - -package starter - -import "syscall" - -func init() { - failureStatus = syscall.WaitStatus(255) - successStatus = syscall.WaitStatus(0) -} diff --git a/status_unix.go b/status_unix.go new file mode 100644 index 0000000..bdfb2c3 --- /dev/null +++ b/status_unix.go @@ -0,0 +1,9 @@ +// +build !windows + +package starter + +import "syscall" + +func init() { + failureStatus = syscall.WaitStatus(0) +} diff --git a/status_windows.go b/status_windows.go index 577880b..d96f6a9 100644 --- a/status_windows.go +++ b/status_windows.go @@ -1,10 +1,7 @@ -// +build windows - package starter import "syscall" func init() { failureStatus = syscall.WaitStatus{ExitCode: 255} - successStatus = syscall.WaitStatus{ExitCode: 0} } From 424616b18484f98f0593c9766f04eae0d61669c6 Mon Sep 17 00:00:00 2001 From: Daisuke Maki Date: Thu, 1 Dec 2016 08:59:35 +0900 Subject: [PATCH 10/26] bye bye flags pkg, using an adhoc parser to achieve my goals --- cli.go | 53 +++++++++++------- cli_test.go | 152 ++++++++++++++++++++++++++++++++++++++++++++++++++++ options.go | 83 ++++++++++++++++++++++++++++ 3 files changed, 268 insertions(+), 20 deletions(-) create mode 100644 cli_test.go diff --git a/cli.go b/cli.go index 52aee11..740e35f 100644 --- a/cli.go +++ b/cli.go @@ -2,14 +2,13 @@ package starter import ( "context" - "errors" "fmt" "os" "reflect" "strings" "time" - flags "github.com/jessevdk/go-flags" + "github.com/pkg/errors" ) func NewCLI() *CLI { @@ -34,7 +33,7 @@ func makeOptionList(opts *options) []Option { list = append(list, WithEnvdir(opts.Envdir.Value)) } if opts.Interval > -1 { - list = append(list, WithInterval(time.Duration(opts.Interval) * time.Second)) + list = append(list, WithInterval(time.Duration(opts.Interval)*time.Second)) } if opts.KillOldDelay.Valid { list = append(list, 
WithKillOldDelay(time.Duration(opts.KillOldDelay.Value)*time.Second)) @@ -59,35 +58,49 @@ func makeOptionList(opts *options) []Option { } return list } -func (cli *CLI) Run(ctx context.Context) error { + +func (cli *CLI) ParseArgs(args ...string) (*options, error) { var opts options opts.Interval = -1 // allow 0 - p := flags.NewParser(&opts, flags.PrintErrors|flags.PassDoubleDash) - args, err := p.Parse() - if err != nil || opts.Help { - showHelp() - return nil + if err := opts.Parse(args...); err != nil { + return nil, errors.Wrap(err, "failed to parse arguments") } - if opts.Version { - fmt.Printf("%s\n", version) - return nil + if opts.Interval < 0 { + opts.Interval = 1 } - if opts.Interval <= 0 { - opts.Interval = 1 + if len(opts.Args) == 0 { + return nil, errors.New("server program not specified") } - if len(args) == 0 { - return errors.New("server program not specified") + opts.Command = opts.Args[0] + if len(opts.Args) > 1 { + opts.Args = opts.Args[1:] + } else { + opts.Args = []string(nil) } - opts.Command = args[0] - if len(args) > 1 { - opts.Args = args[1:] + return &opts, nil +} + +func (cli *CLI) Run(ctx context.Context) error { + opts, err := cli.ParseArgs(os.Args...) + if err != nil { + return err + } + + if opts.Help { + showHelp() + return nil + } + + if opts.Version { + fmt.Printf("%s\n", version) + return nil } - s := New(opts.Command, makeOptionList(&opts)...) + s := New(opts.Command, makeOptionList(opts)...) 
return s.Run(ctx) } diff --git a/cli_test.go b/cli_test.go new file mode 100644 index 0000000..256e91a --- /dev/null +++ b/cli_test.go @@ -0,0 +1,152 @@ +package starter + +import ( + "errors" + "fmt" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func findInOptionList(t *testing.T, opts *options, name string, val interface{}) error { + for _, o := range makeOptionList(opts) { + switch o.Name() { + case name: + if !assert.Equal(t, val, o.Value(), "option value matches") { + return errors.New("option value does not match") + } + return nil + } + } + t.Errorf("failed to find option") + return errors.New("failed to find option") +} + +func TestCLIArgs(t *testing.T) { + c := NewCLI() + + t.Run("no parameters", func(t *testing.T) { + opts, err := c.ParseArgs("ls") + if !assert.NoError(t, err, "cli.ParseArgs should succeed") { + t.Logf("%s", err) + return + } + + expected := options{ + Command: "ls", + Interval: 1, + } + + if !assert.Equal(t, &expected, opts) { + return + } + }) + + t.Run("--interval=0", func(t *testing.T) { + opts, err := c.ParseArgs("ls", "--interval=0") + if !assert.NoError(t, err, "cli.ParseArgs should succeed") { + return + } + + expected := options{ + Command: "ls", + Interval: 0, + } + + if !assert.Equal(t, &expected, opts) { + return + } + if err := findInOptionList(t, opts, "interval", time.Duration(0)); err != nil { + return + } + }) + + t.Run("--dir=foo", func(t *testing.T) { + opts, err := c.ParseArgs("ls", "--dir=foo") + if !assert.NoError(t, err, "cli.ParseArgs should succeed") { + return + } + + expected := options{ + Command: "ls", + Interval: 1, + Dir: "foo", + } + + if !assert.Equal(t, &expected, opts) { + return + } + if err := findInOptionList(t, opts, "dir", "foo"); err != nil { + return + } + }) + + for _, val := range []bool{true, false} { + arg := fmt.Sprintf("--enable-auto-restart=%t", val) + t.Run(arg, func(t *testing.T) { + opts, err := c.ParseArgs("ls", arg) + if !assert.NoError(t, err, "cli.ParseArgs 
should succeed") { + return + } + + expected := options{ + Command: "ls", + Interval: 1, + EnableAutoRestart: boolOpt{ Valid: true, Value: val }, + } + + if !assert.Equal(t, &expected, opts) { + return + } + if err := findInOptionList(t, opts, "enable_auto_restart", val); err != nil { + return + } + }) + } + + for name, sig := range niceNameToSigs { + hupArg := fmt.Sprintf("--signal-on-hup=%s", name) + t.Run(hupArg, func(t *testing.T) { + opts, err := c.ParseArgs("ls", hupArg) + if !assert.NoError(t, err, "cli.ParseArgs should succeed") { + return + } + + expected := options{ + Command: "ls", + Interval: 1, + SignalOnHUP: name, + } + + if !assert.Equal(t, &expected, opts) { + return + } + + if err := findInOptionList(t, opts, "signal_on_hup", sig); err != nil { + return + } + }) + termArg := fmt.Sprintf("--signal-on-term=%s", name) + t.Run(termArg, func(t *testing.T) { + opts, err := c.ParseArgs("ls", termArg) + if !assert.NoError(t, err, "cli.ParseArgs should succeed") { + return + } + + expected := options{ + Command: "ls", + Interval: 1, + SignalOnTERM: name, + } + + if !assert.Equal(t, &expected, opts) { + return + } + + if err := findInOptionList(t, opts, "signal_on_term", sig); err != nil { + return + } + }) + } +} diff --git a/options.go b/options.go index 619d5c3..0637155 100644 --- a/options.go +++ b/options.go @@ -3,8 +3,12 @@ package starter import ( "io" "os" + "reflect" "strconv" + "strings" "time" + + "github.com/pkg/errors" ) type valueOption struct { @@ -122,6 +126,85 @@ func (o *boolOpt) Set(s string) error { return nil } +type optsetter interface { + Set(string) error +} + +var osv = reflect.TypeOf((*optsetter)(nil)).Elem() + +func (o *options) Parse(args ...string) error { + rv := reflect.ValueOf(o).Elem() + tv := rv.Type() + names := map[string]reflect.Value{} + for i := 0; i < tv.NumField(); i++ { + f := tv.Field(i) + if f.PkgPath != "" || f.Anonymous { + continue + } + names[f.Tag.Get("long")] = rv.Field(i) + } + + var arguments []string + for 
len(args) > 0 { + arg := args[0] + args = args[1:] + l := len(arg) + if l == 2 && arg == "--" { + // stop processing, everything after this is an argument + if len(args) > 0 { + arguments = append(arguments, args[1:]...) + } + args = []string(nil) // force loop termination + break + } + + if !strings.HasPrefix(arg, "--") { + arguments = append(arguments, arg) + continue + } + end := l + var opval string + if ei := strings.IndexByte(arg, '='); ei > -1 { + end = ei + if end < l-1 { + opval = arg[end+1:] + } else { + return errors.Errorf("invalid argument '%s'", arg) + } + } else { + // is the next argument the argument to this option + if len(args) > 0 && !strings.HasPrefix(args[0], "--") { + opval = args[0] + args = args[1:] + } + } + opname := arg[2:end] + f := names[opname] + opvalv := reflect.ValueOf(opval) + switch f.Kind() { + case reflect.Struct: + if reflect.PtrTo(f.Type()).Implements(osv) { + f.Addr().MethodByName("Set").Call([]reflect.Value{opvalv}) + } else if opvalv.Type().AssignableTo(f.Type()) { + f.Set(opvalv) + } + case reflect.String: + f.Set(opvalv) + case reflect.Int: + i, err := strconv.ParseInt(opval, 10, 64) + if err != nil { + return err + } + f.Set(reflect.ValueOf(int(i))) + case reflect.Slice: + f.Set(reflect.Append(f, opvalv)) + } + } + + o.Args = arguments + return nil +} + type options struct { Args []string AutoRestartInterval intOpt `long:"auto-restart-interval" arg:"seconds" description:"automatic restart interval (default 360). It is used with\n\"--enable-auto-restart\" option. This can be overwritten by environment\nvariable \"AUTO_RESTART_INTERVAL\"." 
note:"unimplemented"` From d94e387ba03c7bbdbe2e303078008c476130e03e Mon Sep 17 00:00:00 2001 From: Daisuke Maki Date: Thu, 1 Dec 2016 18:00:46 +0900 Subject: [PATCH 11/26] more tests for cli args --- cli_test.go | 56 +++++++++++++++++++++++++++++++++++++++++++++++------ 1 file changed, 50 insertions(+), 6 deletions(-) diff --git a/cli_test.go b/cli_test.go index 256e91a..087f4a8 100644 --- a/cli_test.go +++ b/cli_test.go @@ -43,21 +43,22 @@ func TestCLIArgs(t *testing.T) { } }) - t.Run("--interval=0", func(t *testing.T) { - opts, err := c.ParseArgs("ls", "--interval=0") + t.Run("--auto-restart-interval=5", func(t *testing.T) { + opts, err := c.ParseArgs("ls", "--auto-restart-interval=5") if !assert.NoError(t, err, "cli.ParseArgs should succeed") { return } expected := options{ - Command: "ls", - Interval: 0, + Command: "ls", + Interval: 1, + AutoRestartInterval: intOpt{Valid: true, Value: 5}, } if !assert.Equal(t, &expected, opts) { return } - if err := findInOptionList(t, opts, "interval", time.Duration(0)); err != nil { + if err := findInOptionList(t, opts, "auto_restart_interval", 5*time.Second); err != nil { return } }) @@ -93,7 +94,7 @@ func TestCLIArgs(t *testing.T) { expected := options{ Command: "ls", Interval: 1, - EnableAutoRestart: boolOpt{ Valid: true, Value: val }, + EnableAutoRestart: boolOpt{Valid: true, Value: val}, } if !assert.Equal(t, &expected, opts) { @@ -105,6 +106,49 @@ func TestCLIArgs(t *testing.T) { }) } + t.Run("--envdir=foo", func(t *testing.T) { + opts, err := c.ParseArgs("ls", "--envdir=foo") + if !assert.NoError(t, err, "cli.ParseArgs should succeed") { + return + } + + expected := options{ + Command: "ls", + Interval: 1, + Envdir: stringOpt{Valid: true, Value:"foo"}, + } + + if !assert.Equal(t, &expected, opts) { + return + } + if err := findInOptionList(t, opts, "envdir", "foo"); err != nil { + return + } + }) + + // 0 is a special case, so we must test + for i := 0; i < 2; i++ { + arg := fmt.Sprintf("--interval=%d", i) + t.Run(arg, 
func(t *testing.T) { + opts, err := c.ParseArgs("ls", arg) + if !assert.NoError(t, err, "cli.ParseArgs should succeed") { + return + } + + expected := options{ + Command: "ls", + Interval: i, + } + + if !assert.Equal(t, &expected, opts) { + return + } + if err := findInOptionList(t, opts, "interval", time.Duration(i)*time.Second); err != nil { + return + } + }) + } + for name, sig := range niceNameToSigs { hupArg := fmt.Sprintf("--signal-on-hup=%s", name) t.Run(hupArg, func(t *testing.T) { From 6f7a8465095ad05983d3aab5b4e048def3453b00 Mon Sep 17 00:00:00 2001 From: Daisuke Maki Date: Fri, 2 Dec 2016 09:23:33 +0900 Subject: [PATCH 12/26] test and fix kill-old-delay --- cli_test.go | 26 +++++++++++++++++++++++++- options.go | 2 +- starter.go | 8 +++++++- 3 files changed, 33 insertions(+), 3 deletions(-) diff --git a/cli_test.go b/cli_test.go index 087f4a8..de5e2c8 100644 --- a/cli_test.go +++ b/cli_test.go @@ -115,7 +115,7 @@ func TestCLIArgs(t *testing.T) { expected := options{ Command: "ls", Interval: 1, - Envdir: stringOpt{Valid: true, Value:"foo"}, + Envdir: stringOpt{Valid: true, Value: "foo"}, } if !assert.Equal(t, &expected, opts) { @@ -193,4 +193,28 @@ func TestCLIArgs(t *testing.T) { } }) } + + for _, i := range []int{5, 10} { + arg := fmt.Sprintf("--kill-old-delay=%d", i) + t.Run(arg, func(t *testing.T) { + opts, err := c.ParseArgs("ls", arg) + if !assert.NoError(t, err, "cli.ParseArgs should succeed") { + return + } + + expected := options{ + Command: "ls", + Interval: 1, + KillOldDelay: intOpt{Valid: true, Value: i}, + } + + if !assert.Equal(t, &expected, opts) { + return + } + if err := findInOptionList(t, opts, "kill_old_delay", time.Duration(i)*time.Second); err != nil { + return + } + }) + } + } diff --git a/options.go b/options.go index 0637155..431526a 100644 --- a/options.go +++ b/options.go @@ -49,7 +49,7 @@ func WithInterval(t time.Duration) Option { } func WithKillOldDelay(t time.Duration) Option { - return &valueOption{name: 
"kill_old_interval", value: t} + return &valueOption{name: "kill_old_delay", value: t} } func WithPaths(l []string) Option { diff --git a/starter.go b/starter.go index e41cb30..e913fd8 100644 --- a/starter.go +++ b/starter.go @@ -145,6 +145,7 @@ func (s *Starter) Run(ctx context.Context) error { var noticeOutput io.Writer = os.Stderr var logStdout io.Writer = os.Stdout var logStderr io.Writer = os.Stderr + var killOldDelay time.Duration = 5 * time.Second for _, opt := range s.options { switch opt.Name() { @@ -167,7 +168,7 @@ func (s *Starter) Run(ctx context.Context) error { case "interval": interval = opt.Value().(time.Duration) case "kill_old_delay": - os.Setenv(`KILL_OLD_DELAY`, strconv.Itoa(int(opt.Value().(time.Duration)/time.Second))) + killOldDelay = opt.Value().(time.Duration) case "paths": paths = opt.Value().([]string) case "pid_file": @@ -188,6 +189,11 @@ func (s *Starter) Run(ctx context.Context) error { logStderr = opt.Value().(io.Writer) } } + + if envAsBool(`ENABLE_AUTO_RESTART`) { + os.Setenv(`KILL_OLD_DELAY`, strconv.Itoa(int(killOldDelay/time.Second))) + } + notice := func(f string, args ...interface{}) { var buf bytes.Buffer fmt.Fprintf(&buf, f, args...) 
From 931c188219dc14ee1e377adfb67302a59df830aa Mon Sep 17 00:00:00 2001 From: Daisuke Maki Date: Fri, 2 Dec 2016 09:39:22 +0900 Subject: [PATCH 13/26] more options --- cli_test.go | 83 +++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 83 insertions(+) diff --git a/cli_test.go b/cli_test.go index de5e2c8..9861464 100644 --- a/cli_test.go +++ b/cli_test.go @@ -3,6 +3,7 @@ package starter import ( "errors" "fmt" + "strings" "testing" "time" @@ -217,4 +218,86 @@ func TestCLIArgs(t *testing.T) { }) } + paths := []string{ + "/tmp/foo.sock", + "/tmp/bar.sock", + } + for i := 1; i <= 2; i++ { + args := make([]string, len(paths)) + for i, p := range paths { + args[i] = "--path=" + p + } + name := strings.Join(args, " ") + t.Run(name, func(t *testing.T) { + opts, err := c.ParseArgs(append([]string{"ls"}, args[:i]...)...) + if !assert.NoError(t, err, "cli.ParseArgs should succeed") { + return + } + + expected := options{ + Command: "ls", + Interval: 1, + Paths: paths[:i], + } + + if !assert.Equal(t, &expected, opts) { + return + } + if err := findInOptionList(t, opts, "paths", paths[:i]); err != nil { + return + } + }) + } + + t.Run("--pid-file=/path/to/foo", func(t *testing.T) { + opts, err := c.ParseArgs("ls", "--pid-file=/path/to/foo") + if !assert.NoError(t, err, "cli.ParseArgs should succeed") { + return + } + + expected := options{ + Command: "ls", + Interval: 1, + PidFile: "/path/to/foo", + } + + if !assert.Equal(t, &expected, opts) { + return + } + if err := findInOptionList(t, opts, "pid_file", "/path/to/foo"); err != nil { + return + } + }) + + ports := []string{ + "8080", + "0.0.0.0:9090", + } + for i := 1; i <= 2; i++ { + args := make([]string, len(ports)) + for i, p := range ports { + args[i] = "--port=" + p + } + name := strings.Join(args, " ") + t.Run(name, func(t *testing.T) { + opts, err := c.ParseArgs(append([]string{"ls"}, args[:i]...)...) 
+ if !assert.NoError(t, err, "cli.ParseArgs should succeed") { + return + } + + expected := options{ + Command: "ls", + Interval: 1, + Ports: ports[:i], + } + + if !assert.Equal(t, &expected, opts) { + return + } + if err := findInOptionList(t, opts, "ports", ports[:i]); err != nil { + return + } + }) + } + } From 7d5339fd4b71a7248440f56a113d10eb83b6edcb Mon Sep 17 00:00:00 2001 From: Daisuke Maki Date: Fri, 2 Dec 2016 10:18:06 +0900 Subject: [PATCH 14/26] Add restarter --- cli.go | 9 +++++ interface.go | 5 +++ restarter.go | 102 +++++++++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 116 insertions(+) create mode 100644 restarter.go diff --git a/cli.go b/cli.go index 740e35f..c2572df 100644 --- a/cli.go +++ b/cli.go @@ -100,6 +100,15 @@ func (cli *CLI) Run(ctx context.Context) error { return nil } + if opts.Restart { + if opts.PidFile == "" || opts.StatusFile == "" { + return errors.New("--restart option requires --pid-file and --status-file to be set as well") + } + + s := NewRestarter(opts.PidFile, opts.StatusFile) + return s.Run(ctx) + } + s := New(opts.Command, makeOptionList(opts)...) 
return s.Run(ctx) } diff --git a/interface.go b/interface.go index 14a1de4..3078b03 100644 --- a/interface.go +++ b/interface.go @@ -19,6 +19,11 @@ type Option interface { Value() interface{} } +type Restarter struct { + pidFile string + statusFile string +} + type Starter struct { options []Option command string diff --git a/restarter.go b/restarter.go new file mode 100644 index 0000000..942c422 --- /dev/null +++ b/restarter.go @@ -0,0 +1,102 @@ +package starter + +import ( + "bufio" + "bytes" + "context" + "io/ioutil" + "os" + "sort" + "strconv" + "strings" + "syscall" + "time" + + "github.com/pkg/errors" +) + +func NewRestarter(pidFile, statusFile string) *Restarter { + return &Restarter{ + pidFile: pidFile, + statusFile: statusFile, + } +} + +func (s *Restarter) Run(ctx context.Context) error { + pidbuf, err := ioutil.ReadFile(s.pidFile) + if err != nil { + return errors.Wrapf(err, "failed to open file:%s", s.pidFile) + } + pid, err := strconv.ParseInt(string(pidbuf), 10, 64) + if err != nil { + return errors.Wrap(err, "failed to parse pid") + } + + p, err := os.FindProcess(int(pid)) + if err != nil { + return errors.Wrapf(err, "failed to find process:%s", pidbuf) + } + + if err := p.Signal(syscall.SIGHUP); err != nil { + return errors.Wrap(err, "failed to send SIGHUP to the server process") + } + + getGenerations := func(file string) ([]int, error) { + genbuf, err := ioutil.ReadFile(file) + if err != nil { + return nil, errors.Wrapf(err, "failed to open file:%s", file) + } + scanner := bufio.NewScanner(bytes.NewReader(genbuf)) + genmap := make(map[int]struct{}) + for scanner.Scan() { + txt := scanner.Text() + i := strings.IndexByte(txt, ':') + if i <= 0 { + continue + } + gen, err := strconv.ParseInt(string(txt[:i]), 10, 64) + if err != nil { + continue + } + + genmap[int(gen)] = struct{}{} + } + + var generations []int + for k := range genmap { + generations = append(generations, k) + } + sort.Ints(generations) + return generations, nil + } + + var waitFor 
int + { + generations, err := getGenerations(s.statusFile) + if err != nil { + return errors.Wrap(err, "failed to find generations") + } + + if len(generations) == 0 { + return errors.New("no active process found in the status file") + } + + waitFor = generations[len(generations)-1] + 1 + } + + t := time.NewTicker(time.Second) + defer t.Stop() + + for { + select { + case <-t.C: + generations, err := getGenerations(s.statusFile) + if err != nil { + return errors.Wrap(err, "failed to find generations") + } + if len(generations) == 1 && generations[0] == waitFor { + return nil + } + } + } +} From d3e6463c63186de71e28b540ab7169a9340af4b1 Mon Sep 17 00:00:00 2001 From: Daisuke Maki Date: Fri, 2 Dec 2016 10:18:54 +0900 Subject: [PATCH 15/26] fix options --- options.go | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/options.go b/options.go index 431526a..84c10fc 100644 --- a/options.go +++ b/options.go @@ -207,17 +207,17 @@ func (o *options) Parse(args ...string) error { type options struct { Args []string - AutoRestartInterval intOpt `long:"auto-restart-interval" arg:"seconds" description:"automatic restart interval (default 360). It is used with\n\"--enable-auto-restart\" option. This can be overwritten by environment\nvariable \"AUTO_RESTART_INTERVAL\"." note:"unimplemented"` + AutoRestartInterval intOpt `long:"auto-restart-interval" arg:"seconds" description:"automatic restart interval (default 360). It is used with\n\"--enable-auto-restart\" option. This can be overwritten by environment\nvariable \"AUTO_RESTART_INTERVAL\"."` Command string Dir string `long:"dir" arg:"path" description:"working directory, start_server do chdir to before exec (optional)"` - EnableAutoRestart boolOpt `long:"enable-auto-restart" description:"enables automatic restart by time. This can be overwritten by\nenvironment variable \"ENABLE_AUTO_RESTART\"." 
note:"unimplemented"` + EnableAutoRestart boolOpt `long:"enable-auto-restart" description:"enables automatic restart by time. This can be overwritten by\nenvironment variable \"ENABLE_AUTO_RESTART\"."` Envdir stringOpt `long:"envdir" arg:"Envdir" description:"directory that contains environment variables to the server processes.\nIt is intended for use with \"envdir\" in \"daemontools\". This can be\noverwritten by environment variable \"ENVDIR\"."` Interval int `long:"interval" arg:"seconds" description:"minimum interval (in seconds) to respawn the server program (default: 1)"` KillOldDelay intOpt `long:"kill-old-delay" arg:"seconds" description:"time to suspend to send a signal to the old worker. The default value is\n5 when \"--enable-auto-restart\" is set, 0 otherwise. This can be\noverwritten by environment variable \"KILL_OLD_DELAY\"."` Paths []string `long:"path" arg:"path" description:"path at where to listen using unix socket (optional)"` PidFile string `long:"pid-file" arg:"filename" description:"if set, writes the process id of the start_server process to the file"` Ports []string `long:"port" arg:"(port|host:port)" description:"TCP port to listen to (if omitted, will not bind to any ports)"` - Restart bool `long:"restart" description:"this is a wrapper command that reads the pid of the start_server process\nfrom --pid-file, sends SIGHUP to the process and waits until the\nserver(s) of the older generation(s) die by monitoring the contents of\nthe --status-file" note:"unimplemented"` + Restart bool `long:"restart" description:"this is a wrapper command that reads the pid of the start_server process\nfrom --pid-file, sends SIGHUP to the process and waits until the\nserver(s) of the older generation(s) die by monitoring the contents of\nthe --status-file"` SignalOnHUP string `long:"signal-on-hup" arg:"Signal" description:"name of the signal to be sent to the server process when start_server\nreceives a SIGHUP (default: TERM). 
If you use this option, be sure to\nalso use '--signal-on-term' below."` SignalOnTERM string `long:"signal-on-term" arg:"Signal" description:"name of the signal to be sent to the server process when start_server\nreceives a SIGTERM (default: TERM)"` StatusFile string `long:"status-file" arg:"filename" description:"if set, writes the status of the server process(es) to the file"` From 1ad9ab088dfdd5cf3be2f462a8268b66dd595924 Mon Sep 17 00:00:00 2001 From: Daisuke Maki Date: Fri, 2 Dec 2016 10:27:30 +0900 Subject: [PATCH 16/26] test --restart --- cli_test.go | 20 ++++++++++++++++++++ options.go | 6 ++++++ 2 files changed, 26 insertions(+) diff --git a/cli_test.go b/cli_test.go index 9861464..adb1e72 100644 --- a/cli_test.go +++ b/cli_test.go @@ -300,4 +300,24 @@ func TestCLIArgs(t *testing.T) { }) } + for _, val := range []bool{true, false} { + arg := fmt.Sprintf("--restart=%t", val) + t.Run(arg, func(t *testing.T) { + opts, err := c.ParseArgs("ls", arg) + if !assert.NoError(t, err, "cli.ParseArgs should succeed") { + return + } + + expected := options{ + Command: "ls", + Interval: 1, + Restart: val, + } + + if !assert.Equal(t, &expected, opts) { + return + } + }) + } + } diff --git a/options.go b/options.go index 84c10fc..bb4b3fc 100644 --- a/options.go +++ b/options.go @@ -190,6 +190,12 @@ func (o *options) Parse(args ...string) error { } case reflect.String: f.Set(opvalv) + case reflect.Bool: + b, err := strconv.ParseBool(opval) + if err != nil { + return err + } + f.Set(reflect.ValueOf(b)) case reflect.Int: i, err := strconv.ParseInt(opval, 10, 64) if err != nil { From d211fcc0f29c8bc868e3a3f55d31f5cfa1925cf0 Mon Sep 17 00:00:00 2001 From: Daisuke Maki Date: Fri, 2 Dec 2016 10:29:09 +0900 Subject: [PATCH 17/26] test status-file --- cli_test.go | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/cli_test.go b/cli_test.go index adb1e72..1d27e11 100644 --- a/cli_test.go +++ b/cli_test.go @@ -320,4 +320,24 @@ func TestCLIArgs(t *testing.T) { }) } 
+ t.Run("--status-file=/path/to/foo", func(t *testing.T) { + opts, err := c.ParseArgs("ls", "--status-file=/path/to/foo") + if !assert.NoError(t, err, "cli.ParseArgs should succeed") { + return + } + + expected := options{ + Command: "ls", + Interval: 1, + StatusFile: "/path/to/foo", + } + + if !assert.Equal(t, &expected, opts) { + return + } + if err := findInOptionList(t, opts, "status_file", "/path/to/foo"); err != nil { + return + } + }) + } From 384947dc73d6fb39b2727601a859328087942133 Mon Sep 17 00:00:00 2001 From: Daisuke Maki Date: Fri, 2 Dec 2016 10:46:36 +0900 Subject: [PATCH 18/26] fix extra argument handling --- cli.go | 2 +- options.go | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/cli.go b/cli.go index c2572df..005714d 100644 --- a/cli.go +++ b/cli.go @@ -85,7 +85,7 @@ func (cli *CLI) ParseArgs(args ...string) (*options, error) { } func (cli *CLI) Run(ctx context.Context) error { - opts, err := cli.ParseArgs(os.Args...) + opts, err := cli.ParseArgs(os.Args[1:]...) if err != nil { return err } diff --git a/options.go b/options.go index bb4b3fc..1dc5b29 100644 --- a/options.go +++ b/options.go @@ -152,7 +152,7 @@ func (o *options) Parse(args ...string) error { if l == 2 && arg == "--" { // stop processing, everything after this is an argument if len(args) > 0 { - arguments = append(arguments, args[1:]...) + arguments = append(arguments, args...) 
} args = []string(nil) // force loop termination break From 20b8ca7d7091ee0b1b108c34a5a518a6039cb4ad Mon Sep 17 00:00:00 2001 From: Daisuke Maki Date: Fri, 2 Dec 2016 10:53:21 +0900 Subject: [PATCH 19/26] Fix boolean/help handling --- cli.go | 18 ++++++++++-------- options.go | 6 ++++++ 2 files changed, 16 insertions(+), 8 deletions(-) diff --git a/cli.go b/cli.go index 005714d..dc46630 100644 --- a/cli.go +++ b/cli.go @@ -70,15 +70,17 @@ func (cli *CLI) ParseArgs(args ...string) (*options, error) { opts.Interval = 1 } - if len(opts.Args) == 0 { - return nil, errors.New("server program not specified") - } + if !opts.Help && !opts.Version && !opts.Restart { + if len(opts.Args) == 0 { + return nil, errors.New("server program not specified") + } - opts.Command = opts.Args[0] - if len(opts.Args) > 1 { - opts.Args = opts.Args[1:] - } else { - opts.Args = []string(nil) + opts.Command = opts.Args[0] + if len(opts.Args) > 1 { + opts.Args = opts.Args[1:] + } else { + opts.Args = []string(nil) + } } return &opts, nil diff --git a/options.go b/options.go index 1dc5b29..b2c55fb 100644 --- a/options.go +++ b/options.go @@ -163,17 +163,20 @@ func (o *options) Parse(args ...string) error { continue } end := l + var hasOpval bool var opval string if ei := strings.IndexByte(arg, '='); ei > -1 { end = ei if end < l-1 { opval = arg[end+1:] + hasOpval = true } else { return errors.Errorf("invalid argument '%s'", arg) } } else { // is the next argument the argument to this option if len(args) > 0 && !strings.HasPrefix(args[0], "--") { + hasOpval = true opval = args[0] args = args[1:] } @@ -191,6 +194,9 @@ func (o *options) Parse(args ...string) error { case reflect.String: f.Set(opvalv) case reflect.Bool: + if !hasOpval { + opval = "true" + } b, err := strconv.ParseBool(opval) if err != nil { return err From cc5a17b90dc18e8f84e6f9cef06500740837a4d1 Mon Sep 17 00:00:00 2001 From: Daisuke Maki Date: Fri, 2 Dec 2016 10:58:56 +0900 Subject: [PATCH 20/26] fix test --- cli_test.go | 11 
+++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/cli_test.go b/cli_test.go index 1d27e11..1753f9a 100644 --- a/cli_test.go +++ b/cli_test.go @@ -303,16 +303,23 @@ func TestCLIArgs(t *testing.T) { for _, val := range []bool{true, false} { arg := fmt.Sprintf("--restart=%t", val) t.Run(arg, func(t *testing.T) { - opts, err := c.ParseArgs("ls", arg) + args := []string{arg} + if !val { + args = append([]string{"ls"}, args...) + } + opts, err := c.ParseArgs(args...) if !assert.NoError(t, err, "cli.ParseArgs should succeed") { + t.Logf("%s", err) return } expected := options{ - Command: "ls", Interval: 1, Restart: val, } + if !val { + expected.Command = "ls" + } if !assert.Equal(t, &expected, opts) { return From ba7fc93291e07e5293a17691856e5562cb63c67b Mon Sep 17 00:00:00 2001 From: Daisuke Maki Date: Tue, 13 Dec 2016 09:58:12 +0900 Subject: [PATCH 21/26] align with server::starter 0.33 --- starter.go | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/starter.go b/starter.go index e913fd8..f5f7f71 100644 --- a/starter.go +++ b/starter.go @@ -584,8 +584,12 @@ func (s *Starter) Run(ctx context.Context) error { notice("new worker is now running, sending %s to old workers: %s", signame(sigonhup), buf.String()) killOldDelay := envAsDuration(`KILL_OLD_DELAY`) - if killOldDelay == 0 && envAsBool(`ENABLE_AUTO_RESTART`) { - killOldDelay = 5 * time.Second + if killOldDelay == 0 { + if envAsBool(`ENABLE_AUTO_RESTART`) { + killOldDelay = 5 * time.Second + } else { + killOldDelay = 1 + } } time.Sleep(killOldDelay) From a7eea5f1371be882740f6de4302951a4ea87ca72 Mon Sep 17 00:00:00 2001 From: Daisuke Maki Date: Sun, 26 Feb 2017 16:45:20 +0900 Subject: [PATCH 22/26] Use 1.8 --- .travis.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.travis.yml b/.travis.yml index 1795ab6..db0764c 100644 --- a/.travis.yml +++ b/.travis.yml @@ -1,5 +1,5 @@ language: go go: - - 1.7 + - 1.8.x - tip sudo: false \ No newline at end of file 
From 4b20913637b5298c1918ccac659c27f5ffeb615d Mon Sep 17 00:00:00 2001 From: Daisuke Maki Date: Mon, 27 Feb 2017 09:27:13 +0900 Subject: [PATCH 23/26] Use go-envload --- internal/env/env.go | 169 -------------------------------------- internal/env/env_test.go | 54 ------------ internal/env/interface.go | 34 -------- internal/env/options.go | 18 ---- starter.go | 7 +- starter_test.go | 11 ++- 6 files changed, 8 insertions(+), 285 deletions(-) delete mode 100644 internal/env/env.go delete mode 100644 internal/env/env_test.go delete mode 100644 internal/env/interface.go delete mode 100644 internal/env/options.go diff --git a/internal/env/env.go b/internal/env/env.go deleted file mode 100644 index 3bf7c17..0000000 --- a/internal/env/env.go +++ /dev/null @@ -1,169 +0,0 @@ -package env - -import ( - "bytes" - "context" - "io/ioutil" - "os" - "path/filepath" - "strings" -) - -func (e *sysenv) Clearenv() { - os.Clearenv() -} - -func (e *sysenv) Setenv(k, v string) { - os.Setenv(k, v) -} - -func SystemEnvironment() Environment { - return &sysenv{} -} - -func NewLoader(environ ...string) *Loader { - if len(environ) == 0 { - environ = os.Environ() - } - - var envdir string - original := make([]iterItem, 0, len(environ)) - for _, v := range environ { - i := strings.IndexByte(v, '=') - if i <= 0 || i >= len(v)-1 { - continue - } - original = append(original, iterItem{ - key: v[:i], - value: v[i+1:], - }) - if v[:i] == "ENVDIR" { - envdir = v[i+1:] - } - } - - return &Loader{ - original: original, - envdir: envdir, - } -} - -func (l *Loader) Restore(octx context.Context, e Environment) error { - return l.Apply(octx, e, WithLoadEnvdir(false)) -} - -func (l *Loader) Apply(octx context.Context, e Environment, options ...Option) error { - ctx, cancel := context.WithCancel(octx) - defer cancel() - - e.Clearenv() - iter := l.Iterator(ctx, options...) 
- for iter.Next() { - k, v := iter.KV() - e.Setenv(k, v) - } - - return nil -} - -func (l *Loader) Environ(octx context.Context, options ...Option) []string { - ctx, cancel := context.WithCancel(octx) - defer cancel() - - var environ []string - it := l.Iterator(ctx, options...) - for it.Next() { - k, v := it.KV() - environ = append(environ, k+`=`+v) - } - return environ -} - -func (l *Loader) Iterator(ctx context.Context, options ...Option) *Iterator { - loadEnvdir := true - for _, o := range options { - switch o.Name() { - case LoadEnvdirKey: - loadEnvdir = o.Value().(bool) - } - } - - ch := make(chan *iterItem) - ex := make(chan *iterItem) - defer close(ex) - - go func(m []iterItem, ch, ex chan *iterItem) { - defer close(ch) - for _, it := range m { - select { - case <-ctx.Done(): - return - case ch <- &iterItem{key: it.key, value: it.value}: - } - } - - for { - select { - case <-ctx.Done(): - return - case it, ok := <-ex: - if !ok { - return - } - select { - case <-ctx.Done(): - return - case ch <- it: - } - } - } - }(l.original, ch, ex) - - // meanwhile, load from envdir, if available - if loadEnvdir && l.envdir != "" { - if fi, err := os.Stat(l.envdir); err == nil && fi.IsDir() { - filepath.Walk(l.envdir, func(path string, fi os.FileInfo, err error) error { - // Ignore errors - if err != nil { - return nil - } - - // Do not recurse into directories - if fi.IsDir() && l.envdir != path { - return filepath.SkipDir - } - - buf, err := ioutil.ReadFile(path) - if err != nil { - return nil - } - - ex <- &iterItem{ - key: filepath.Base(path), - value: string(bytes.TrimSpace(buf)), - } - return nil - }) - } - } - - return &Iterator{ - ch: ch, - } -} - -func (iter *Iterator) Next() bool { - iter.nextK = "" - iter.nextV = "" - pair, ok := <-iter.ch - if !ok { - return false - } - iter.nextK = pair.key - iter.nextV = pair.value - return true -} - -func (iter *Iterator) KV() (string, string) { - return iter.nextK, iter.nextV -} diff --git a/internal/env/env_test.go 
b/internal/env/env_test.go deleted file mode 100644 index f0c99d5..0000000 --- a/internal/env/env_test.go +++ /dev/null @@ -1,54 +0,0 @@ -package env_test - -import ( - "context" - "os" - "testing" - "time" - - "github.com/lestrrat/go-server-starter/internal/env" - "github.com/stretchr/testify/assert" -) - -func TestIter(t *testing.T) { - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) - defer cancel() - - src := []string{`FOO=foo`, `BAR=bar`, `BAZ=baz`} - l := env.NewLoader(src...) - i := l.Iterator(ctx) - if !assert.NotNil(t, i, "Iterator is ok") { - return - } - - os.Setenv(`QUUX`, `quux`) // This should have no effect - var list []string - for i.Next() { - k, v := i.KV() - t.Logf("%s=%v", k, v) - list = append(list, k+"="+v) - } - - if !assert.Equal(t, src, list) { - return - } -} - -func TestEnviron(t *testing.T) { - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) - defer cancel() - - src := []string{`FOO=foo`, `BAR=bar`, `BAZ=baz`} - l := env.NewLoader(src...) 
- i := l.Iterator(ctx) - if !assert.NotNil(t, i, "Iterator is ok") { - return - } - - os.Setenv(`QUUX`, `quux`) // This should have no effect - list := l.Environ(ctx) - if !assert.Equal(t, src, list) { - return - } -} - diff --git a/internal/env/interface.go b/internal/env/interface.go deleted file mode 100644 index 4189eb4..0000000 --- a/internal/env/interface.go +++ /dev/null @@ -1,34 +0,0 @@ -package env - -type Loader struct { - original []iterItem - envdir string -} - -type Iterator struct { - ch chan *iterItem - nextK string - nextV string -} - -type iterItem struct { - key string - value string -} - -type Environment interface { - Clearenv() - Setenv(string, string) -} -type sysenv struct{} - -type Option interface { - Name() string - Value() interface{} -} - -const LoadEnvdirKey = "LoadEnvdirKey" -type option struct { - name string - value interface{} -} diff --git a/internal/env/options.go b/internal/env/options.go deleted file mode 100644 index a344a94..0000000 --- a/internal/env/options.go +++ /dev/null @@ -1,18 +0,0 @@ -package env - -func (o *option) Name() string { - return o.name -} - -func (o *option) Value() interface {} { - return o.value -} - -// WithLoadEnvdir specifies if Loader should load the original -// environment variables AND the contents of envdir -func WithLoadEnvdir(b bool) Option { - return &option{ - name: LoadEnvdirKey, - value: b, - } -} diff --git a/starter.go b/starter.go index f5f7f71..0af3d29 100644 --- a/starter.go +++ b/starter.go @@ -15,7 +15,7 @@ import ( "syscall" "time" - "github.com/lestrrat/go-server-starter/internal/env" + envload "github.com/lestrrat/go-envload" "github.com/pkg/errors" ) @@ -282,8 +282,7 @@ func (s *Starter) Run(ctx context.Context) error { // Note: environment variables that are set after this // will NOT be re-populated - sysenv := env.SystemEnvironment() - envLoader := env.NewLoader() + envLoader := envload.New() var statusFileCreated bool defer func() { @@ -509,7 +508,7 @@ func (s *Starter) 
Run(ctx context.Context) error { exited := wait(ctx, sigCh, workerDone) // reload env if necessary - envLoader.Apply(ctx, sysenv) + envLoader.Restore(envload.WithLoadEnvdir(true)) if envAsBool(`ENABLE_AUTO_RESTART`) { if os.Getenv("AUTO_RESTART_INTERVAL") == "" { diff --git a/starter_test.go b/starter_test.go index 53b2e33..5be6e42 100644 --- a/starter_test.go +++ b/starter_test.go @@ -15,7 +15,7 @@ import ( "testing" "time" - "github.com/lestrrat/go-server-starter/internal/env" + envload "github.com/lestrrat/go-envload" tcputil "github.com/lestrrat/go-tcputil" "github.com/pkg/errors" "github.com/stretchr/testify/assert" @@ -114,9 +114,8 @@ func build(name string, src string) (string, func(), error) { } func TestRun(t *testing.T) { - l := env.NewLoader() - sysenv := env.SystemEnvironment() - defer l.Restore(context.Background(), sysenv) + l := envload.New() + defer l.Restore() cmdname, cleanup, err := build("echod", echoServerSrc) if cleanup != nil { @@ -195,7 +194,7 @@ func TestRun(t *testing.T) { <-done }) - l.Restore(context.Background(), sysenv) + l.Restore() t.Run("send multiple signals", func(t *testing.T) { // Note: this test does NOT test that the same echod server has received @@ -286,7 +285,7 @@ func TestRun(t *testing.T) { case <-done: } }) - l.Restore(context.Background(), sysenv) + l.Restore() } func TestSigFromName(t *testing.T) { From a7173e032b2ee5435b38b0b021cbfe95e07a6d5d Mon Sep 17 00:00:00 2001 From: Daisuke Maki Date: Mon, 27 Feb 2017 09:46:54 +0900 Subject: [PATCH 24/26] Use glide --- glide.lock | 26 ++++++++++++++++++++++++++ glide.yaml | 9 +++++++++ 2 files changed, 35 insertions(+) create mode 100644 glide.lock create mode 100644 glide.yaml diff --git a/glide.lock b/glide.lock new file mode 100644 index 0000000..8cdb324 --- /dev/null +++ b/glide.lock @@ -0,0 +1,26 @@ +hash: 834007827c1b49861ec10a1fc2fcb83425427181a3bfc0aec736d2ea54f39bbc +updated: 2017-02-27T09:46:02.670776913+09:00 +imports: +- name: github.com/lestrrat/go-envload + 
version: d7bc6278d4641441e356c80af00fd627f9b199f0 +- name: github.com/pkg/errors + version: 248dadf4e9068a0b3e79f02ed0a610d935de5302 +- name: golang.org/x/net + version: 10c134ea0df15f7e34d789338c7a2d76cc7a3ab9 + subpackages: + - context +testImports: +- name: github.com/davecgh/go-spew + version: 04cdfd42973bb9c8589fd6a731800cf222fde1a9 + subpackages: + - spew +- name: github.com/lestrrat/go-tcputil + version: 3cae7c8432aeb817399f001e0db3f5bb04e8b8a6 +- name: github.com/pmezard/go-difflib + version: d8ed2627bdf02c080bf22230dbb337003b7aba2d + subpackages: + - difflib +- name: github.com/stretchr/testify + version: 4d4bfba8f1d1027c4fdbe371823030df51419987 + subpackages: + - assert diff --git a/glide.yaml b/glide.yaml new file mode 100644 index 0000000..08a7506 --- /dev/null +++ b/glide.yaml @@ -0,0 +1,9 @@ +package: github.com/lestrrat/go-server-starter +import: +- package: github.com/lestrrat/go-envload +- package: github.com/pkg/errors +testImport: +- package: github.com/lestrrat/go-tcputil +- package: github.com/stretchr/testify + subpackages: + - assert From bb58920e7ea6f6c311dcb082c0814fa5c01a5a61 Mon Sep 17 00:00:00 2001 From: Daisuke Maki Date: Mon, 27 Feb 2017 09:51:33 +0900 Subject: [PATCH 25/26] Add missing <-ctx.Done() --- monitor.go | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/monitor.go b/monitor.go index b3f36bf..7d45493 100644 --- a/monitor.go +++ b/monitor.go @@ -51,6 +51,12 @@ func monitor(ctx context.Context, src chan *exec.Cmd, done chan *exec.Cmd) { Cmd *exec.Cmd } for { + select { + case <-ctx.Done(): + return + default: + } + cases := make([]reflect.SelectCase, len(workers)+1) for i, worker := range workers { cases[i].Chan = reflect.ValueOf(worker.Chan) From 67abe7ca71c645267d86c999b9ab0cea21d2ac2c Mon Sep 17 00:00:00 2001 From: Daisuke Maki Date: Mon, 27 Feb 2017 10:27:34 +0900 Subject: [PATCH 26/26] Righto. 
remove redundant cases --- monitor.go | 4 ---- 1 file changed, 4 deletions(-) diff --git a/monitor.go b/monitor.go index 7d45493..d273e44 100644 --- a/monitor.go +++ b/monitor.go @@ -93,10 +93,6 @@ func monitor(ctx context.Context, src chan *exec.Cmd, done chan *exec.Cmd) { switch { case len(workers) < 2: workers = nil - case chosen == 0: - workers = workers[1:] - case chosen == len(workers)-1: - workers = workers[:chosen] default: workers = append(workers[:chosen], workers[chosen+1:]...) }