diff --git a/Makefile b/Makefile
index 9cb39ee..9664a77 100644
--- a/Makefile
+++ b/Makefile
@@ -2,10 +2,11 @@ version=1.4.3
 serverExec=room_server
 taskExec=room_task
 collectEventExec=room_collect_event
+saveEventExec=room_save_event
 testTxExec=room_trans
 
 clean:
-	rm -f ${serverExec}* ${taskExec}* ${collectEventExec}* ${testTxExec}*
+	rm -f ${serverExec}* ${taskExec}* ${collectEventExec}* ${saveEventExec}* ${testTxExec}*
 
 build-prod-server:
 	CGO_ENABLED=0 GOOS=linux GOARCH=amd64 go build -ldflags '-w -s -X "main.version=${version}"' -v -o $(serverExec)_${version} cmd/server/main.go
@@ -13,10 +14,13 @@ build-prod-server:
 build-prod-collect-event:
 	CGO_ENABLED=0 GOOS=linux GOARCH=amd64 go build -ldflags '-w -s -X "main.version=${version}"' -v -o $(collectEventExec)_${version} cmd/collect_event/main.go
 
+build-prod-save-event:
+	CGO_ENABLED=0 GOOS=linux GOARCH=amd64 go build -ldflags '-w -s -X "main.version=${version}"' -v -o $(saveEventExec)_${version} cmd/save_event/main.go
+
 build-prod-task:
 	CGO_ENABLED=0 GOOS=linux GOARCH=amd64 go build -ldflags '-w -s -X "main.version=${version}"' -v -o $(taskExec)_${version} cmd/task/main.go
 
-build-prod-all: build-prod-server build-prod-collect-event build-prod-task
+build-prod-all: build-prod-server build-prod-collect-event build-prod-save-event build-prod-task
 
 build-local-server:
 	go build -ldflags '-X "main.version=${version}"' -v -o $(serverExec)_${version} cmd/server/main.go
@@ -24,13 +28,16 @@ build-local-server:
 build-local-collect-event:
 	go build -ldflags '-X "main.version=${version}"' -v -o $(collectEventExec)_${version} cmd/collect_event/main.go
 
+build-local-save-event:
+	go build -ldflags '-X "main.version=${version}"' -v -o $(saveEventExec)_${version} cmd/save_event/main.go
+
 build-local-task:
 	go build -ldflags '-X "main.version=${version}"' -v -o $(taskExec)_${version} cmd/task/main.go
 
 build-local-tx:
 	go build -v -o $(testTxExec) cmd/tools/transaction.go
 
-build-local-all: build-local-server build-local-collect-event build-local-task
+build-local-all: build-local-server build-local-collect-event build-local-save-event build-local-task
 
 build-linux-server:
 	CGO_ENABLED=0 GOOS=linux GOARCH=amd64 go build -ldflags '-X "main.version=${version}"' -v -o $(serverExec)_${version} cmd/server/main.go
@@ -38,10 +45,13 @@ build-linux-server:
 build-linux-collect-event:
 	CGO_ENABLED=0 GOOS=linux GOARCH=amd64 go build -ldflags '-X "main.version=${version}"' -v -o $(collectEventExec)_${version} cmd/collect_event/main.go
 
+build-linux-save-event:
+	CGO_ENABLED=0 GOOS=linux GOARCH=amd64 go build -ldflags '-X "main.version=${version}"' -v -o $(saveEventExec)_${version} cmd/save_event/main.go
+
 build-linux-task:
 	CGO_ENABLED=0 GOOS=linux GOARCH=amd64 go build -ldflags '-X "main.version=${version}"' -v -o $(taskExec)_${version} cmd/task/main.go
 
 build-linux-tx:
 	CGO_ENABLED=0 GOOS=linux GOARCH=amd64 go build -v -o $(testTxExec) cmd/tools/transaction.go
 
-build-linux-all: build-linux-server build-linux-collect-event build-linux-task
+build-linux-all: build-linux-server build-linux-collect-event build-linux-save-event build-linux-task
diff --git a/base/base.go b/base/base.go
index c53de2b..a9c1c68 100644
--- a/base/base.go
+++ b/base/base.go
@@ -16,6 +16,7 @@ import (
 var serverDependency Dependency
 var taskDependency Dependency
 var collectEventDependency CollectEventDependency
+var saveEventDependency SaveEventDependency
 
 var hashTagEventService *HashTagEventService
 var hashTagLoadedCache *cache.Cache
@@ -23,6 +24,7 @@ var hashTagLoadedCache *cache.Cache
 var serverConfig *RoomServerConfig
 var taskConfig *RoomTaskConfig
 var collectEventConfig *RoomCollectEventConfig
+var saveEventConfig *RoomSaveEventConfig
 
 var json = jsoniter.ConfigCompatibleWithStandardLibrary
 
@@ -39,6 +41,9 @@ func initService(
 	if err != nil {
 		return nil, nil, nil, fmt.Errorf("init_metric.%w", err)
 	}
+	if dbClusterConfig.IsEmpty() {
+		return logger, metric, nil, nil
+	}
 	databaseCluster, err := NewDBClusterFromConfig(dbClusterConfig, logger, metric)
 	if err != nil {
 		return nil, nil, nil, fmt.Errorf("init_db.%w", err)
 	}
@@ -147,13 +152,12 @@ func InitCollectEvent(configPath string) error {
 	logger, metric, dbCluster, err := initService(
 		"room.collect_event", collectEventConfig.Log,
-		collectEventConfig.Metric, collectEventConfig.DB)
+		collectEventConfig.Metric, DBClusterConfig{})
 	if err != nil {
 		return err
 	}
 
 	collectEventDependency = CollectEventDependency{
-		DB:     dbCluster,
 		Logger: logger,
 		Metric: metric,
 	}
 
@@ -167,6 +171,39 @@ func InitCollectEvent(configPath string) error {
 	return nil
 }
 
+func InitSaveEvent(configPath string) error {
+	config, err := newConfigFromFile(configPath)
+	if err != nil {
+		return err
+	}
+
+	saveEventConfig = &config.SaveEvent
+	if err = saveEventConfig.init(); err != nil {
+		return err
+	}
+
+	logger, metric, dbCluster, err := initService(
+		"room.save_event", saveEventConfig.Log,
+		saveEventConfig.Metric, saveEventConfig.DB)
+	if err != nil {
+		return err
+	}
+
+	saveEventDependency = SaveEventDependency{
+		DB:     dbCluster,
+		Logger: logger,
+		Metric: metric,
+	}
+
+	logger.Info(
+		"init room save event service",
+		log.String("config", fmt.Sprintf("%+v", saveEventConfig)),
+		log.String("database", dbCluster.String()),
+	)
+
+	return nil
+}
+
 func GetHashTagEventService() *HashTagEventService {
 	return hashTagEventService
 }
@@ -187,6 +224,10 @@ func GetCollectEventConfig() *RoomCollectEventConfig {
 	return collectEventConfig
 }
 
+func GetSaveEventConfig() *RoomSaveEventConfig {
+	return saveEventConfig
+}
+
 func StartServices() {
 	hashTagEventService.Run()
 }
@@ -291,12 +332,31 @@ func GetTaskDependency() Dependency {
 }
 
 type CollectEventDependency struct {
-	DB     *DBCluster
 	Logger *log.Logger
 	Metric *MetricClient
 }
 
 func (dep CollectEventDependency) Check() error {
+	if dep.Logger == nil {
+		return ErrDepLoggerNull
+	}
+	if dep.Metric == nil {
+		return ErrDepMetricNull
+	}
+	return nil
+}
+
+func GetCollectEventDependency() CollectEventDependency {
+	return collectEventDependency
+}
+
+type SaveEventDependency struct {
+	DB     *DBCluster
+	Logger *log.Logger
+	Metric *MetricClient
+}
+
+func (dep SaveEventDependency) Check() error {
 	if dep.DB == nil {
 		return ErrDepDBNull
 	}
@@ -309,6 +369,6 @@ func (dep CollectEventDependency) Check() error {
 	return nil
 }
 
-func GetCollectEventDependency() CollectEventDependency {
-	return collectEventDependency
+func GetSaveEventDependency() SaveEventDependency {
+	return saveEventDependency
 }
diff --git a/base/config.go b/base/config.go
index 462bcc5..f6e863f 100644
--- a/base/config.go
+++ b/base/config.go
@@ -12,6 +12,11 @@ import (
 	"gopkg.in/yaml.v2"
 )
 
+const (
+	defaultGracefulShutdownWaitDuration = 5 * time.Second
+	defaultMonitorConnectionInterval    = 1 * time.Second
+)
+
 func newConfigFromFile(filePath string) (Config, error) {
 	config := Config{}
 	bs, err := readFileFromPath(filePath)
@@ -49,6 +54,7 @@ func readBytes(fp io.Reader) ([]byte, error) {
 
 type Config struct {
 	Server       RoomServerConfig       `yaml:"server"`
 	CollectEvent RoomCollectEventConfig `yaml:"collect_event"`
+	SaveEvent    RoomSaveEventConfig    `yaml:"save_event"`
 	Task         RoomTaskConfig         `yaml:"task"`
 }
@@ -59,6 +65,9 @@ func (config Config) check() error {
 	if err := config.CollectEvent.check(); err != nil {
 		return fmt.Errorf("room_collect_event.%w", err)
 	}
+	if err := config.SaveEvent.check(); err != nil {
+		return fmt.Errorf("room_save_event.%w", err)
+	}
 	if err := config.Task.check(); err != nil {
 		return fmt.Errorf("room_task.%w", err)
 	}
@@ -66,8 +75,15 @@
 }
 
 type RoomServerConfig struct {
-	EnablePProf bool                   `yaml:"enable_pprof"`
-	IsDebug     bool                   `yaml:"is_debug"`
+	EnablePProf bool `yaml:"enable_pprof"`
+	IsDebug     bool `yaml:"is_debug"`
+
+	RawGracefulShutdownWaitDuration string        `yaml:"graceful_shutdown_wait_duration"`
+	GracefulShutdownWaitDuration    time.Duration `yaml:"-"`
+
+	RawMonitorConnectionInterval string        `yaml:"monitor_connection_interval"`
+	MonitorConnectionInterval    time.Duration `yaml:"-"`
+
 	Log     map[string]interface{} `yaml:"log"`
 	Metric  MetricConfig           `yaml:"metric"`
 	LoadKey LoadKeyConfig          `yaml:"load_key"`
@@ -107,6 +123,28 @@ func (config *RoomServerConfig) init() error {
 		return fmt.Errorf("room_server.%w", err)
 	}
 
+	rawGracefulShutdownWaitDuration := config.RawGracefulShutdownWaitDuration
+	if rawGracefulShutdownWaitDuration == "" {
+		config.GracefulShutdownWaitDuration = defaultGracefulShutdownWaitDuration
+	} else {
+		d, err := time.ParseDuration(rawGracefulShutdownWaitDuration)
+		if err != nil {
+			return fmt.Errorf("graceful_shutdown_wait_duration=%s is invalid %w", rawGracefulShutdownWaitDuration, err)
+		}
+		config.GracefulShutdownWaitDuration = d
+	}
+
+	rawMonitorConnectionInterval := config.RawMonitorConnectionInterval
+	if rawMonitorConnectionInterval == "" {
+		config.MonitorConnectionInterval = defaultMonitorConnectionInterval
+	} else {
+		d, err := time.ParseDuration(rawMonitorConnectionInterval)
+		if err != nil {
+			return fmt.Errorf("monitor_connection_interval=%s is invalid %w", rawMonitorConnectionInterval, err)
+		}
+		config.MonitorConnectionInterval = d
+	}
+
 	d, err := time.ParseDuration(config.LoadKey.RawRetryInterval)
 	if err != nil {
 		return fmt.Errorf("load_key.retry_interval=%s is invalid %w", config.LoadKey.RawRetryInterval, err)
 	}
@@ -243,8 +281,6 @@ type RoomCollectEventConfig struct {
 
 	Server CollectEventServiceServerConfig `yaml:"server"`
 
-	SaveDB CollectEventServiceSaveDBConfig `yaml:"save_db"`
-
 	SaveFile CollectEventServiceSaveFileConfig `yaml:"save_file"`
 
 	BufferLimit int `yaml:"buffer_limit"`
@@ -256,8 +292,6 @@
 	RawMonitorInterval string `yaml:"monitor_interval"`
 	MonitorInterval    time.Duration
-
-	DB DBClusterConfig `yaml:"db_cluster"`
 }
 
 func (config RoomCollectEventConfig) check() error {
@@ -270,9 +304,6 @@
 	if err := config.Server.check(); err != nil {
 		return fmt.Errorf("server.%w", err)
 	}
-	if err := config.SaveDB.check(); err != nil {
-		return fmt.Errorf("save_db.%w", err)
-	}
 	if err := config.SaveFile.check(); err != nil {
 		return fmt.Errorf("save_file.%w", err)
 	}
@@ -288,9 +319,6 @@
 	if config.RawMonitorInterval == "" {
 		return errors.New("monitor_interval should not be empty")
 	}
-	if err := config.DB.check(); err != nil {
-		return fmt.Errorf("db_cluster.%w", err)
-	}
 	return nil
 }
@@ -299,13 +327,7 @@ func (config *RoomCollectEventConfig) init() error {
 		return fmt.Errorf("room_collect_event.%w", err)
 	}
 
-	duration, err := time.ParseDuration(config.SaveDB.RawFileAge)
-	if err != nil {
-		return fmt.Errorf("save_db.file_age.%w", err)
-	}
-	config.SaveDB.FileAge = duration
-
-	duration, err = time.ParseDuration(config.SaveFile.RawMaxFileAge)
+	duration, err := time.ParseDuration(config.SaveFile.RawMaxFileAge)
 	if err != nil {
 		return fmt.Errorf("save_file.max_file_age.%w", err)
 	}
@@ -348,7 +370,67 @@ func (config CollectEventServiceServerConfig) check() error {
 	return nil
 }
 
-type CollectEventServiceSaveDBConfig struct {
+type CollectEventServiceSaveFileConfig struct {
+	MaxEventCount int `yaml:"max_event_count"`
+
+	RawMaxFileAge string `yaml:"max_file_age"`
+	MaxFileAge    time.Duration
+
+	FileDirectory string `yaml:"file_directory"`
+}
+
+func (config CollectEventServiceSaveFileConfig) check() error {
+	if config.MaxEventCount <= 0 {
+		return fmt.Errorf("max_event_count=%d, it should be greater than 0", config.MaxEventCount)
+	}
+	if config.RawMaxFileAge == "" {
+		return errors.New("max_file_age should not be empty")
+	}
+	if config.FileDirectory == "" {
+		return errors.New("file_directory should not be empty")
+	}
+	return nil
+}
+
+type RoomSaveEventConfig struct {
+	Log    map[string]interface{} `yaml:"log"`
+	Metric MetricConfig           `yaml:"metric"`
+
+	SaveDB SaveEventServiceSaveDBConfig `yaml:"save_db"`
+
+	DB DBClusterConfig `yaml:"db_cluster"`
+}
+
+func (config RoomSaveEventConfig) check() error {
+	if len(config.Log) == 0 {
+		return errors.New("log should not be empty")
+	}
+	if err := config.Metric.check(); err != nil {
+		return fmt.Errorf("metric.%w", err)
+	}
+	if err := config.SaveDB.check(); err != nil {
+		return fmt.Errorf("save_db.%w", err)
+	}
+	if err := config.DB.check(); err != nil {
+		return fmt.Errorf("db_cluster.%w", err)
+	}
+	return nil
+}
+
+func (config *RoomSaveEventConfig) init() error {
+	if err := config.check(); err != nil {
+		return fmt.Errorf("room_save_event.%w", err)
+	}
+	duration, err := time.ParseDuration(config.SaveDB.RawFileAge)
+	if err != nil {
+		return fmt.Errorf("save_db.file_age.%w", err)
+	}
+	config.SaveDB.FileAge = duration
+
+	return nil
+}
+
+type SaveEventServiceSaveDBConfig struct {
 	RetryTimes      int `yaml:"retry_times"`
 	RetryIntervalMS int `yaml:"retry_interval_ms"`
 	TimeoutMS       int `yaml:"timeout_ms"`
@@ -356,10 +438,12 @@ type CollectEventServiceSaveDBConfig struct {
 	RawFileAge string `yaml:"file_age"`
 	FileAge    time.Duration
 
+	FileDirectory string `yaml:"file_directory"`
+
 	RateLimitPerSecond int `yaml:"rate_limit_per_second"`
 }
 
-func (config CollectEventServiceSaveDBConfig) check() error {
+func (config SaveEventServiceSaveDBConfig) check() error {
 	if config.RetryTimes <= 0 {
 		return fmt.Errorf("retry_times is %d, it should be greater than 0", config.RetryTimes)
 	}
@@ -372,31 +456,12 @@ func (config CollectEventServiceSaveDBConfig) check() error {
 
 	if config.RawFileAge == "" {
 		return errors.New("file_age should not be empty")
 	}
-	if config.RateLimitPerSecond <= 0 {
-		return fmt.Errorf("rate_limit_per_second is %d, it should be greater than 0", config.RateLimitPerSecond)
-	}
-	return nil
-}
-
-type CollectEventServiceSaveFileConfig struct {
-	MaxEventCount int `yaml:"max_event_count"`
-
-	RawMaxFileAge string `yaml:"max_file_age"`
-	MaxFileAge    time.Duration
-
-	FileDirectory string `yaml:"file_directory"`
-}
-
-func (config CollectEventServiceSaveFileConfig) check() error {
-	if config.MaxEventCount <= 0 {
-		return fmt.Errorf("max_event_count=%d, it should be greater than 0", config.MaxEventCount)
-	}
-	if config.RawMaxFileAge == "" {
-		return errors.New("max_file_age should not be empty")
-	}
 	if config.FileDirectory == "" {
 		return errors.New("file_directory should not be empty")
 	}
+	if config.RateLimitPerSecond <= 0 {
+		return fmt.Errorf("rate_limit_per_second is %d, it should be greater than 0", config.RateLimitPerSecond)
+	}
 	return nil
 }
diff --git a/base/db.go b/base/db.go
index 5b38bd8..cd4fbdf 100644
--- a/base/db.go
+++ b/base/db.go
@@ -18,6 +18,10 @@ type DBClusterConfig struct {
 	Shardings []DBConfig `yaml:"shardings"`
 }
 
+func (config DBClusterConfig) IsEmpty() bool {
+	return config.ShardingCount == 0 && len(config.Shardings) == 0
+}
+
 func (config DBClusterConfig) check() error {
 	if config.ShardingCount <= 0 {
 		return errors.New("sharding_count should be greater than 0")
diff --git a/base/hash_tag_event.go b/base/hash_tag_event.go
index 8b34de9..2925f4e 100644
--- a/base/hash_tag_event.go
+++ b/base/hash_tag_event.go
@@ -171,6 +171,10 @@ type HashTagEventServiceEventReportConfig struct {
 	RequestIdleConnTimeout time.Duration
 
 	RequestMaxConn int `yaml:"request_max_conn"`
+
+	RequestMaxRetry          int `yaml:"request_max_retry"`
+	RequestMinRetryBackoffMS int `yaml:"request_min_retry_backoff_ms"`
+	RequestMaxRetryBackoffMS int `yaml:"request_max_retry_backoff_ms"`
 }
 
 func (config HashTagEventServiceEventReportConfig) check() error {
@@ -198,6 +202,21 @@ func (config HashTagEventServiceEventReportConfig) check() error {
 	if config.RequestMaxConn <= 0 {
 		return fmt.Errorf("request_max_conn=%d, it should be greater than 0", config.RequestMaxConn)
 	}
+	if config.RequestMaxRetry <= 0 {
+		return fmt.Errorf("request_max_retry=%d, it should be greater than 0", config.RequestMaxRetry)
+	}
+
+	if v := config.RequestMinRetryBackoffMS; v <= 0 {
+		return fmt.Errorf("request_min_retry_backoff_ms=%d, it should be > 0", v)
+	}
+	if v := config.RequestMaxRetryBackoffMS; v <= 0 {
+		return fmt.Errorf("request_max_retry_backoff_ms=%d, it should be > 0", v)
+	}
+	if config.RequestMinRetryBackoffMS > config.RequestMaxRetryBackoffMS {
+		return fmt.Errorf(
+			"request_min_retry_backoff_ms=%d, request_max_retry_backoff_ms=%d, request_min_retry_backoff_ms should be less than or equal to request_max_retry_backoff_ms",
+			config.RequestMinRetryBackoffMS, config.RequestMaxRetryBackoffMS)
+	}
 	return nil
 }
@@ -411,8 +430,27 @@ func (service *HashTagEventService) _reportEvents(events []HashTagEvent) error {
 	if err != nil {
 		return err
 	}
-	requestBody := bytes.NewReader(bs)
-	resp, err := service.client.Post(service.config.EventReport.URL, HTTPContentTypeJSON, requestBody)
+
+	for attempt := 0; attempt < service.config.EventReport.RequestMaxRetry; attempt++ {
+		if attempt > 0 {
+			waitBeforeRetryDuration := utility.RetryBackoff(
+				uint(attempt),
+				time.Duration(service.config.EventReport.RequestMinRetryBackoffMS)*time.Millisecond,
+				time.Duration(service.config.EventReport.RequestMaxRetryBackoffMS)*time.Millisecond)
+
+			time.Sleep(waitBeforeRetryDuration)
+		}
+		err = reportRequest(service.client, service.config.EventReport.URL, bs)
+		if err == nil {
+			break
+		}
+	}
+	return err
+}
+
+func reportRequest(client *http.Client, url string, body []byte) error {
+	requestBody := bytes.NewReader(body)
+	resp, err := client.Post(url, HTTPContentTypeJSON, requestBody)
 	if err != nil {
 		return err
 	}
diff --git a/cmd/collect_event/main.go b/cmd/collect_event/main.go
index 9f302f0..4c68fa0 100644
--- a/cmd/collect_event/main.go
+++ b/cmd/collect_event/main.go
@@ -36,7 +36,7 @@ func main() {
 	config := base.GetCollectEventConfig()
 	serviceName := "collect_event_service"
 
-	collectEventService, err := service.NewCollectEventService(config, dep.Logger, dep.Metric, dep.DB)
+	collectEventService, err := service.NewCollectEventService(config, dep.Logger, dep.Metric)
 	if err != nil {
 		panic(err)
 	}
diff --git a/cmd/config.template.yaml b/cmd/config.template.yaml
index 0f64ee2..53408b8 100644
--- a/cmd/config.template.yaml
+++ b/cmd/config.template.yaml
@@ -1,6 +1,8 @@
 server:
   enable_pprof: true
   is_debug: true
+  graceful_shutdown_wait_duration: "5s"
+  monitor_connection_interval: "1s"
 
   log:
     console:
@@ -27,6 +29,9 @@ server:
       request_conn_keep_alive_interval: "30s"
      request_idle_conn_timeout: "90s"
      request_max_conn: 100
+      request_max_retry: 5
+      request_min_retry_backoff_ms: 1
+      request_max_retry_backoff_ms: 10
     agg_interval : "1m"
     buffer_limit: 10240000
     monitor_interval: "15s"
@@ -107,17 +112,27 @@ collect_event:
     write_timeout_ms: 1000
     idle_timeout_ms: 1000
 
+  save_file:
+    max_event_count: 1000
+    max_file_age: "10m"
+    file_directory: "/data/room"
+
+save_event:
+  metric:
+    prefix: "bytepower_room.save_event"
+    host: "127.0.0.1:8125"
+
+  log:
+    console:
+      level: debug
+
   save_db:
     retry_times: 3
     retry_interval_ms: 20
     timeout_ms: 2000
-    file_age: "5m"
-    rate_limit_per_second: 100
-
-  save_file:
-    max_event_count: 1000
-    max_file_age: "10m"
+    file_age: "20m"
     file_directory: "/data/room"
+    rate_limit_per_second: 100
 
   db_cluster:
     sharding_count: 5
@@ -244,4 +259,4 @@ task:
   # Valid time units are "ns", "us" (or "µs"), "ms", "s", "m", "h".
   inactive_duration: 2h
   rate_limit_per_second: 100
-  off: false
\ No newline at end of file
+  off: false
diff --git a/cmd/save_event/main.go b/cmd/save_event/main.go
new file mode 100644
index 0000000..f951100
--- /dev/null
+++ b/cmd/save_event/main.go
@@ -0,0 +1,58 @@
+package main
+
+import (
+	"bytepower_room/base"
+	"bytepower_room/base/log"
+	"bytepower_room/service"
+	"fmt"
+	"os"
+	"os/signal"
+	"syscall"
+
+	"github.com/spf13/pflag"
+)
+
+var configPath = pflag.StringP("config", "c", "config.yaml", "config file path")
+var versionFlag = pflag.BoolP("version", "v", false, "service version")
+var version string
+
+func main() {
+	pflag.Parse()
+	if *versionFlag {
+		fmt.Println(version)
+		return
+	}
+
+	if configPath == nil {
+		panic("config not found")
+	}
+	if err := base.InitSaveEvent(*configPath); err != nil {
+		panic(err)
+	}
+	dep := base.GetSaveEventDependency()
+	if err := dep.Check(); err != nil {
+		panic(err)
+	}
+
+	config := base.GetSaveEventConfig()
+	serviceName := "save_event_service"
+	saveEventService, err := service.NewSaveEventService(config, dep.Logger, dep.Metric, dep.DB)
+	if err != nil {
+		panic(err)
+	}
+	dep.Logger.Info("init_save_event_service", log.String("config", fmt.Sprintf("%+v", *saveEventService.Config())))
+
+	saveEventService.Run()
+
+	signalCh := make(chan os.Signal, 1)
+	signal.Notify(signalCh, os.Interrupt, syscall.SIGINT, syscall.SIGTERM)
+
+	sig := <-signalCh
+
+	dep.Logger.Info(
+		fmt.Sprintf("signal received, closing %s ...", serviceName),
+		log.String("signal", sig.String()))
+
+	saveEventService.Stop()
+	dep.Logger.Info(fmt.Sprintf("close %s success", serviceName))
+}
diff --git a/cmd/server/main.go b/cmd/server/main.go
index b24c20c..ce84b69 100644
--- a/cmd/server/main.go
+++ b/cmd/server/main.go
@@ -54,7 +54,7 @@ func main() {
 	sig := <-signalCh
 	logger.Info("signal received, closing service...", log.String("signal", sig.String()))
 
-	roomService.Stop()
+	roomService.Stop(config.GracefulShutdownWaitDuration)
 	logger.Info("room server is stopped, try to stop other related services...")
 	base.StopServices()
 	logger.Info("room server and related service are all closed")
diff --git a/cmd/tools/fix_collect_event/main.go b/cmd/tools/fix_collect_event/main.go
index a8c0b61..089bdb8 100644
--- a/cmd/tools/fix_collect_event/main.go
+++ b/cmd/tools/fix_collect_event/main.go
@@ -42,7 +41,6 @@ func main() {
 	if err = base.InitCollectEvent(*configPath); err != nil {
 		panic(err)
 	}
-	db := base.GetCollectEventDependency().DB
 
 	successCount := 0
 	failedCount := 0
@@ -52,7 +51,7 @@ func main() {
 			"start_save_event:%s, keys=%v, at_is_zero=%t, wt_is_zero=%t\n",
 			event.String(), event.Keys, event.AccessTime.IsZero(), event.WriteTime.IsZero())
 		if !*dryRun {
-			err = service.SaveEvent(context.TODO(), db, event, time.Now())
+			err = service.SaveEvent(context.TODO(), event, time.Now())
 			if err != nil {
 				failedCount += 1
 				logger.Printf("save_event_error:%s, event:%s\n", err.Error(), event.String())
diff --git a/commands/transaction_command.go b/commands/transaction_command.go
index 7c5be66..74fe098 100644
--- a/commands/transaction_command.go
+++ b/commands/transaction_command.go
@@ -8,6 +8,7 @@ import (
 	"strings"
 
 	"github.com/go-redis/redis/v8"
+	"github.com/tidwall/redcon"
 )
 
 type TransactionCloseReason string
@@ -35,6 +36,7 @@ const (
 
 type Transaction struct {
 	tx          *redis.Tx
+	conn        redcon.Conn
 	watchedKeys []string
 	keys        []string
 	status      TransactionStatus
@@ -42,8 +44,8 @@ type Transaction struct {
 	dep base.Dependency
 }
 
-func NewTransaction(dep base.Dependency) *Transaction {
-	return &Transaction{status: TransactionStatusInited, dep: dep}
+func NewTransaction(dep base.Dependency, conn redcon.Conn) *Transaction {
+	return &Transaction{status: TransactionStatusInited, dep: dep, conn: conn}
 }
 
 var errTxKeysNotInSameSlot = errors.New("ERR keys in transaction should be in the same slot")
@@ -63,10 +65,12 @@ func (transaction *Transaction) multi() RESPData {
 		return RESPData{DataType: ErrorRespType, Value: errors.New("ERR MULTI calls can not be nested")}
 	}
 	transaction.status = TransactionStatusStarted
+	//TODO: set conn in tx
 	return RESPData{DataType: SimpleStringRespType, Value: "OK"}
 }
 
 func (transaction *Transaction) reset(reason TransactionCloseReason, status TransactionStatus) error {
+	//TODO: set conn not in tx in defer
 	if transaction.tx != nil {
 		if err := transaction.tx.Close(contextTODO); err != nil {
 			recordTransactionCloseError(transaction.dep.Logger, transaction.dep.Metric, err, reason)
@@ -113,6 +117,7 @@ func (transaction *Transaction) watch(keys ...string) RESPData {
 		"execute transaction command: %s %s",
 		"watch", strings.Join(keys, " "),
 	))
+	// TODO: set conn in tx
 	if _, err := transaction.tx.Watch(contextTODO, keys...).Result(); err != nil {
 		return ConvertErrorToRESPData(err)
 	}
diff --git a/go.mod b/go.mod
index 8aac9de..22b39f4 100644
--- a/go.mod
+++ b/go.mod
@@ -26,4 +26,4 @@ replace github.com/go-redis/redis/v8 v8.4.3 => github.com/byte-power/redis/v8 v8
 
 replace github.com/vmihailenco/bufpool v0.1.11 => github.com/byte-power/bufpool v0.1.13
 
-replace github.com/tidwall/redcon v1.4.4 => github.com/byte-power/redcon v1.4.4
+replace github.com/tidwall/redcon v1.4.4 => ../redcon
diff --git a/service/service_collect_event.go b/service/service_collect_event.go
index 257b8e3..3cd7172 100644
--- a/service/service_collect_event.go
+++ b/service/service_collect_event.go
@@ -1,7 +1,6 @@
 package service
 
 import (
-	"bufio"
 	"bytepower_room/base"
 	"bytepower_room/base/log"
 	"bytepower_room/utility"
@@ -10,7 +9,6 @@ import (
 	"net"
 	"net/http"
 	"os"
-	"path"
 	"path/filepath"
 	"reflect"
 	"strings"
@@ -20,8 +18,6 @@ import (
 	"fmt"
 	"sync"
 	"time"
-
-	"go.uber.org/ratelimit"
 )
 
 const (
@@ -55,7 +51,6 @@ type CollectEventService struct {
 
 	logger *log.Logger
 	metric *base.MetricClient
-	db     *base.DBCluster
 
 	wg     sync.WaitGroup
 	stopCh chan bool
@@ -70,7 +65,6 @@ type CollectEventService struct {
 func NewCollectEventService(
 	config *base.RoomCollectEventConfig,
 	logger *log.Logger, metric *base.MetricClient,
-	db *base.DBCluster,
 ) (*CollectEventService, error) {
 
 	if logger == nil {
@@ -79,9 +73,6 @@ func NewCollectEventService(
 	if metric == nil {
 		return nil, errors.New("metric should not be nil")
 	}
-	if db == nil {
-		return nil, errors.New("db should not be nil")
-	}
 	file, err := NewEventFile(
 		logger, metric, config.SaveFile.FileDirectory,
 		config.SaveFile.MaxEventCount, config.SaveFile.MaxFileAge)
@@ -103,7 +94,6 @@ func NewCollectEventService(
 
 		logger: logger,
 		metric: metric,
-		db:     db,
 
 		wg:     sync.WaitGroup{},
 		stopCh: make(chan bool),
@@ -148,9 +138,6 @@ func (service *CollectEventService) Run() {
 	service.wg.Add(1)
 	go service.saveEventsToFile()
 
-	service.wg.Add(1)
-	go service.saveEventsToDB()
-
 	service.wg.Add(1)
 	go service.mointor(service.config.MonitorInterval)
 }
@@ -293,227 +280,6 @@
-func (service *CollectEventService) saveEventsToDB() {
-	jobName := "save events to db"
-	metricMsg := "save_events_to_db"
-
-	defer func() {
-		service.logger.Info(
-			fmt.Sprintf("stop %s", jobName),
-			log.String("time", time.Now().String()),
-		)
-		service.wg.Done()
-	}()
-	service.logger.Info(
-		fmt.Sprintf("start %s", jobName),
-		log.String("time", time.Now().String()),
-	)
-
-	directory := service.config.SaveFile.FileDirectory
-	interval := 5 * time.Second
-	for {
-		files, err := listEventFilesInDirectory(directory)
-		if err != nil {
-			service.recordError(metricMsg, err, map[string]string{"dir": directory})
-			time.Sleep(interval)
-			continue
-		}
-		for _, file := range files {
-			quit := service.saveEventsFromFileToDB(file, time.Now(), metricMsg)
-			if quit {
-				service.logger.Info(fmt.Sprintf("quit signal received, stop %s", jobName))
-				return
-			}
-		}
-		if atomic.LoadInt32(&service.stop) == 1 {
-			service.logger.Info(fmt.Sprintf("service is stopped, stop %s", jobName))
-			return
-		}
-		time.Sleep(interval)
-	}
-}
-
-func (service *CollectEventService) saveEventsFromFileToDB(file os.DirEntry, processStartTime time.Time, metricMsg string) bool {
-	directory := service.config.SaveFile.FileDirectory
-	needProcess, err := isEventFileNeededToProcess(file, service.config.SaveDB.FileAge, processStartTime)
-	if err != nil {
-		service.recordError(
-			fmt.Sprintf("%s.check_need_process", metricMsg),
-			err, map[string]string{"name": file.Name()},
-		)
-		return false
-	}
-	if !needProcess {
-		return false
-	}
-
-	name := file.Name()
-	service.logger.Info(
-		"start to save events from file to database",
-		log.String("name", name),
-		log.String("start_time", processStartTime.String()),
-	)
-	fullName := path.Join(directory, name)
-	count, quit, errs := service._saveEventsFromFileToDB(fullName, metricMsg)
-	if len(errs) != 0 {
-		service.recordError(
-			fmt.Sprintf("%s.error_count", metricMsg),
-			fmt.Errorf("%d errors", len(errs)),
-			map[string]string{
-				"name":  name,
-				"count": fmt.Sprint(count),
-			},
-		)
-	} else {
-		service.logger.Info(
-			"end to save events from file to database",
-			log.String("name", name),
-			log.Int("count", count),
-			log.String("duration", time.Since(processStartTime).String()),
-		)
-		service.recordSuccessWithCount(metricMsg, count)
-		service.recordSuccessWithDuration(metricMsg, time.Since(processStartTime))
-	}
-	if quit {
-		return quit
-	}
-	// rename file if has errors
-	if len(errs) != 0 {
-		backupName := path.Join(directory, fmt.Sprintf("%s.bak", name))
-		if err := os.Rename(fullName, backupName); err != nil {
-			service.recordError(
-				fmt.Sprintf("%s.backup_file", metricMsg),
-				err,
-				map[string]string{"name": fullName, "backup": backupName},
-			)
-		} else {
-			service.logger.Info(
-				"backup file success",
-				log.String("name", fullName),
-				log.String("backup", backupName),
-			)
-		}
-		return quit
-	}
-
-	// remove file if has errors
-	if err := os.Remove(fullName); err != nil {
-		service.recordError(
-			fmt.Sprintf("%s.remove_file", metricMsg),
-			err,
-			map[string]string{"name": fullName},
-		)
-	} else {
-		service.logger.Info(
-			"remove file success",
-			log.String("name", fullName),
-		)
-	}
-	return quit
-}
-
-func isEventFileNeededToProcess(file os.DirEntry, fileAge time.Duration, t time.Time) (bool, error) {
-	info, err := file.Info()
-	if err != nil {
-		return false, err
-	}
-	return info.ModTime().Add(fileAge).Before(t), nil
-}
-
-func (service *CollectEventService) _saveEventsFromFileToDB(name, metricMsg string) (int, bool, []error) {
-	var errors []error
-	var successCount int
-	var quit bool
-	file, err := os.Open(name)
-	if err != nil {
-		errors = append(errors, err)
-		service.recordError(fmt.Sprintf("%s.open_file", metricMsg), err, map[string]string{"name": name})
-		return successCount, quit, errors
-	}
-	defer func() {
-		if err := file.Close(); err != nil {
-			service.recordError(
-				fmt.Sprintf("%s.close_file", metricMsg),
-				err,
-				map[string]string{"name": name},
-			)
-		}
-	}()
-	scanner := bufio.NewScanner(file)
-	ratelimitBucket := ratelimit.New(service.config.SaveDB.RateLimitPerSecond)
-loop:
-	for scanner.Scan() {
-		var event base.HashTagEvent
-		err := json.Unmarshal(scanner.Bytes(), &event)
-		if err != nil {
-			errors = append(errors, err)
-			service.recordError(
-				fmt.Sprintf("%s.unmarshal_event", metricMsg),
-				err,
-				map[string]string{
-					"name":  name,
-					"event": scanner.Text(),
-				},
-			)
-			continue
-		}
-		select {
-		case <-service.stopCh:
-			quit = true
-			break loop
-		default:
-			ratelimitBucket.Take()
-			if err := service.saveEvent(event); err != nil {
-				errors = append(errors, err)
-				service.recordError(
-					fmt.Sprintf("%s.save_event", metricMsg),
-					err,
-					map[string]string{
-						"name":  name,
-						"event": scanner.Text(),
-					})
-				continue
-			}
-			successCount += 1
-		}
-	}
-	if err := scanner.Err(); err != nil {
-		service.recordError(fmt.Sprintf("%s.scan", metricMsg), err, map[string]string{"name": name})
-		errors = append(errors, err)
-	}
-	return successCount, quit, errors
-}
-
-func (service *CollectEventService) saveEvent(event base.HashTagEvent) error {
-	var err error
-	if err = event.Check(); err != nil {
-		return err
-	}
-	config := service.config.SaveDB
-	ctx, cancel := context.WithTimeout(context.Background(), time.Duration(config.TimeoutMS)*time.Millisecond)
-	defer cancel()
-	retryInterval := time.Duration(config.RetryIntervalMS) * time.Millisecond
-	for i := 0; i < config.RetryTimes; i++ {
-		err = upsertHashTagKeysRecordByEvent(ctx, service.db, event, time.Now())
-		if err != nil {
-			if isRetryErrorForUpdateInTx(err) {
-				service.logger.Warn(
-					"save_event_to_db_retry",
-					log.Error(err),
-					log.String("event", event.String()),
-					log.Int("retry_times", i),
-				)
-				service.recordSuccessWithCount("save_event_to_db_retry", 1)
-				time.Sleep(retryInterval)
-				continue
-			}
-			return err
-		}
-		break
-	}
-	return err
-}
-
 func (service *CollectEventService) mointor(interval time.Duration) {
 	jobName := "mointor"
@@ -690,6 +456,10 @@ func (service *CollectEventService) recordGaugeMetric(metricName string, count i
 }
 
 func (service *CollectEventService) recordError(reason string, err error, info map[string]string) {
+	recordError(service.logger, service.metric, reason, err, info)
+}
+
+func recordError(logger *log.Logger, metric *base.MetricClient, reason string, err error, info map[string]string) {
 	logPairs := make([]log.LogPair, 0)
 	for key, value := range info {
 		logPairs = append(logPairs, log.String(key, value))
 	}
@@ -697,12 +467,12 @@ func (service *CollectEventService) recordError(reason string, err error, info m
 	if err != nil {
 		logPairs = append(logPairs, log.Error(err))
 	}
-	service.logger.Error(reason, logPairs...)
+	logger.Error(reason, logPairs...)
 
 	errorMetricName := "error"
-	service.metric.MetricIncrease(errorMetricName)
+	metric.MetricIncrease(errorMetricName)
 	specificErrorMetricName := fmt.Sprintf("%s.%s", errorMetricName, reason)
-	service.metric.MetricIncrease(specificErrorMetricName)
+	metric.MetricIncrease(specificErrorMetricName)
 }
 
 func (service *CollectEventService) recordWriteResponseError(err error, body []byte) {
@@ -711,10 +481,14 @@ func (service *CollectEventService) recordWriteResponseError(err error, body []b
 }
 
 func (service *CollectEventService) recordSuccessWithDuration(metricName string, duration time.Duration) {
-	service.metric.MetricIncrease(metricName)
+	recordSuccessWithDuration(service.metric, metricName, duration)
+}
+
+func recordSuccessWithDuration(metric *base.MetricClient, metricName string, duration time.Duration) {
+	metric.MetricIncrease(metricName)
 	if duration > time.Duration(0) {
 		durationMetricName := fmt.Sprintf("%s.duration", metricName)
-		service.metric.MetricTimeDuration(durationMetricName, duration)
+		metric.MetricTimeDuration(durationMetricName, duration)
 	}
 }
@@ -803,10 +577,6 @@ func writeSuccessResponse(writer http.ResponseWriter, count int) error {
 	return err
 }
 
-func SaveEvent(ctx context.Context, db *base.DBCluster, event base.HashTagEvent, saveTime time.Time) error {
-	return upsertHashTagKeysRecordByEvent(ctx, db, event, saveTime)
-}
-
 type EventFile struct {
 	name      string
 	directory string
diff --git a/service/service_save_event.go b/service/service_save_event.go
new file mode 100644
index 0000000..a043785
--- /dev/null
+++ b/service/service_save_event.go
@@ -0,0 +1,314 @@
+package service
+
+import (
+	"bufio"
+	"bytepower_room/base"
+	"bytepower_room/base/log"
+	"context"
+	"errors"
+	"fmt"
+	"os"
+	"path"
+	"sync"
+	"time"
+
+	"go.uber.org/ratelimit"
+)
+
+const saveEventServiceName = "save_event"
+
+type SaveEventService struct {
+	config *base.RoomSaveEventConfig
+	logger *log.Logger
+	metric *base.MetricClient
+	db     *base.DBCluster
+	wg     sync.WaitGroup
+
+	ctx         context.Context
+	ctxCancelFn context.CancelFunc
+}
+
+func NewSaveEventService(config *base.RoomSaveEventConfig, logger *log.Logger, metric *base.MetricClient, db *base.DBCluster) (*SaveEventService, error) {
+	if config == nil {
+		return nil, errors.New("config should not be nil")
+	}
+	if logger == nil {
+		return nil, errors.New("logger should not be nil")
+	}
+	if metric == nil {
+		return nil, errors.New("metric should not be nil")
+	}
+	if db == nil {
+		return nil, errors.New("db should not be nil")
+	}
+	ctx, cancel := context.WithCancel(context.Background())
+	service := &SaveEventService{
+		config:      config,
+		logger:      logger,
+		metric:      metric,
+		db:          db,
+		wg:          sync.WaitGroup{},
+		ctx:         ctx,
+		ctxCancelFn: cancel,
+	}
+	return service, nil
+}
+
+func (service *SaveEventService) Run() {
+	service.logger.Info(
+		fmt.Sprintf("start %s", saveEventServiceName),
+		log.String("time", time.Now().String()),
+	)
+	service.wg.Add(1)
+	go service.saveEventsToDB()
+}
+
+func (service *SaveEventService) Config() *base.RoomSaveEventConfig {
+	return service.config
+}
+
+func (service *SaveEventService) Stop() {
+	service.ctxCancelFn()
+	service.wg.Wait()
+	service.logger.Info(
+		fmt.Sprintf("stop %s", saveEventServiceName),
+		log.String("time", time.Now().String()),
+	)
+}
+
+func (service *SaveEventService) saveEventsToDB() {
+	defer func() {
+		service.wg.Done()
+	}()
+
+	directory := service.config.SaveDB.FileDirectory
+	interval := 5 * time.Second
+	for {
+		select {
+		case <-service.ctx.Done():
+			service.logger.Info("context is done, stop work")
+			return
+		default:
+			files, err := listEventFilesInDirectory(directory)
+			if err != nil {
+				recordError(service.logger, service.metric, saveEventServiceName, err, map[string]string{"dir": directory})
+				time.Sleep(interval)
+				continue
+			}
+			for _, file := range files {
+				quit := service.saveEventsFromFileToDB(file, time.Now())
+				if quit {
+					service.logger.Info("quit signal received, stop")
+					return
+				}
+			}
+		}
+		time.Sleep(interval)
+	}
+}
+
+func (service *SaveEventService) saveEventsFromFileToDB(file os.DirEntry, processStartTime time.Time) bool {
+	directory := service.config.SaveDB.FileDirectory
+	needProcess, err := isEventFileNeededToProcess(file, service.config.SaveDB.FileAge, processStartTime)
+	if err != nil {
+		recordError(
+			service.logger,
+			service.metric,
+			"check_need_process",
+			err, map[string]string{"name": file.Name()},
+		)
+		return false
+	}
+	if !needProcess {
+		return false
+	}
+
+	name := file.Name()
+	service.logger.Info(
+		"start to save events from file to database",
+		log.String("name", name),
+		log.String("start_time", processStartTime.String()),
+	)
+	fullName := path.Join(directory, name)
+	count, quit, errs := service._saveEventsFromFileToDB(fullName)
+	if len(errs) != 0 {
+		recordError(
+			service.logger,
+			service.metric,
+			"error_count",
+			fmt.Errorf("%d errors", len(errs)),
+			map[string]string{
+				"name":  name,
+				"count": fmt.Sprint(count),
+			},
+		)
+	} else {
+		service.logger.Info(
+			"end to save events from file to database",
+			log.String("name", name),
+			log.Int("count", count),
+			log.String("duration", time.Since(processStartTime).String()),
+		)
+		service.metric.MetricCount("save_count", count)
+		recordSuccessWithDuration(service.metric, "save_count", time.Since(processStartTime))
+	}
+	if quit {
+		return quit
+	}
+	// rename file if has errors
+	if len(errs) != 0 {
+		backupName := path.Join(directory, fmt.Sprintf("%s.bak", name))
+		if err := os.Rename(fullName, backupName); err != nil {
+			recordError(
+				service.logger,
+				service.metric,
+				"backup_file",
+				err,
+				map[string]string{"name": fullName, "backup": backupName},
+			)
+		} else {
+			service.logger.Info(
+				"backup file success",
+				log.String("name", fullName),
+				log.String("backup", backupName),
+			)
+		}
+		return quit
+	}
+
+	// remove file if no errors
+	if err := os.Remove(fullName); err != nil {
+		recordError(
+			service.logger,
+			service.metric,
+			"remove_file",
+			err,
+			map[string]string{"name": fullName},
+		)
+	} else {
+		service.logger.Info(
+			"remove file success",
+			log.String("name", fullName),
+		)
+	}
+	return quit
+}
+
+func isEventFileNeededToProcess(file os.DirEntry, fileAge time.Duration, t time.Time) (bool, error) {
+	info, err := file.Info()
+	if err != nil {
+		return false, err
+	}
+	return info.ModTime().Add(fileAge).Before(t), nil
+}
+
+func (service *SaveEventService) _saveEventsFromFileToDB(name string) (int, bool, []error) {
+	var errors []error
+	var successCount int
+	var quit bool
+	file, err := os.Open(name)
+	if err != nil {
+		errors = append(errors, err)
+		recordError(
+			service.logger, service.metric,
+			"open_file", err, map[string]string{"name": name})
+		return successCount, quit, errors
+	}
+	defer func() {
+		if err := file.Close(); err != nil {
+			recordError(
+				service.logger,
+				service.metric,
+				"close_file",
+				err,
+				map[string]string{"name": name},
+			)
+		}
+	}()
+	scanner := bufio.NewScanner(file)
+	ratelimitBucket := ratelimit.New(service.config.SaveDB.RateLimitPerSecond)
+loop:
+	for scanner.Scan() {
+		var event base.HashTagEvent
+		err := json.Unmarshal(scanner.Bytes(), &event)
+		if err != nil {
+			errors = append(errors, err)
+			recordError(
+				service.logger,
+				service.metric,
+				"unmarshal_event",
+				err,
+				map[string]string{
+					"name":  name,
+					"event": scanner.Text(),
+				},
+			)
+			continue
+		}
+		select {
+		case <-service.ctx.Done():
+			quit = true
+			break loop
+		default:
+			ratelimitBucket.Take()
+			if err := service.saveEvent(event); err != nil {
+				errors = append(errors, err)
+				recordError(
+					service.logger,
+					service.metric,
+					"save_event_to_db",
+					err,
+					map[string]string{
+						"name":  name,
+						"event": scanner.Text(),
+					})
+				continue
+			}
+			successCount += 1
+		}
+	}
+	if err := scanner.Err(); err != nil {
+		recordError(
+			service.logger,
+			service.metric,
+			"scan",
+			err,
+			map[string]string{"name": name})
+		errors = append(errors, err)
+	}
+	return successCount, quit, errors
+}
+
+func (service *SaveEventService) saveEvent(event base.HashTagEvent) error {
+	var err error
+	if err = event.Check(); err != nil {
+		return err
+	}
+	config := service.config.SaveDB
+	ctx, cancel := context.WithTimeout(context.Background(), time.Duration(config.TimeoutMS)*time.Millisecond)
+	defer cancel()
+	retryInterval := time.Duration(config.RetryIntervalMS) * time.Millisecond
+	for i := 0; i < config.RetryTimes; i++ {
+		err = upsertHashTagKeysRecordByEvent(ctx, service.db, event, time.Now())
+		if err != nil {
+			if isRetryErrorForUpdateInTx(err) {
+				service.logger.Warn(
+					"save_event_to_db_retry",
+					log.Error(err),
+					log.String("event", event.String()),
+					log.Int("retry_times", i),
+				)
+				service.metric.MetricIncrease("save_event_to_db_retry")
+				time.Sleep(retryInterval)
+				continue
+			}
+			return err
+		}
+		break
+	}
+	return err
+}
+
+// SaveEvent relies on base.InitSaveEvent having populated the save-event dependency.
+func SaveEvent(ctx context.Context, event base.HashTagEvent, saveTime time.Time) error {
+	return upsertHashTagKeysRecordByEvent(ctx, base.GetSaveEventDependency().DB, event, saveTime)
+}
diff --git a/service/service_server.go b/service/service_server.go
index c79cff5..3612404 100644
--- a/service/service_server.go
+++ b/service/service_server.go
@@ -8,6 +8,7 @@ import (
 	"context"
 	"errors"
 	"fmt"
+	stdLog "log"
 	"net/http"
 	"os"
 	"strings"
@@ -47,6 +48,7 @@ type RoomService struct {
 	pprofAddress string
 	pprofServer  *http.Server
 	pid          int
+	stopCh       chan bool
 }
 
 func NewRoomService(config *base.RoomServerConfig, dep base.Dependency, host string, port int) (*RoomService, error) {
@@ -68,13 +70,17 @@ func NewRoomService(config *base.RoomServerConfig, dep base.Dependency, host str
 		dep:          dep,
 		address:      fmt.Sprintf("%s:%d", host, port),
 		pprofAddress: fmt.Sprintf("%s:%d", host, port+10000),
-		pid:          os.Getpid()}
+		pid:          os.Getpid(),
+		stopCh:       make(chan bool),
+	}
 	return roomService, nil
 }
 
 func (service *RoomService) Run() {
 	service.logWithAddressAndPid(log.LevelInfo, "server.start")
 	service.server = redcon.NewServer(service.address, service.connServeHandler, service.connAcceptHandler, service.connCloseHandler)
+	logger := stdLog.New(os.Stdout, "room redcon ", stdLog.LstdFlags)
+	service.server.SetLogger(logger)
 	service.server.AcceptError = service.connAcceptErrorHandler
 	listener, err := greuse.Listen("tcp", service.address)
 	if err != nil {
@@ -88,6 +94,8 @@ func (service *RoomService) Run() {
 		}
 	}()
 
+	go service.monitorConnections()
+
 	// start pprof server
 	if service.config.EnablePProf {
 		service.logWithAddressAndPid(log.LevelInfo, "server.pprof_start")
@@ -106,22 +114,62 @@ func (service *RoomService) Run() {
 	}
 }
 
-func (service *RoomService) Stop() {
-	if err := service.server.Close(); err != nil {
+func (service *RoomService) monitorConnections() {
+	metric := service.dep.Metric
+	ticker := time.NewTicker(service.config.MonitorConnectionInterval)
+	defer func() {
+		ticker.Stop()
+	}()
+loop:
+	for {
+		select {
+		case <-ticker.C:
+			connectionCount := service.server.OpenConnectionCount()
+			transactionCount := transactionManager.transactionCount()
+			service.logWithAddressAndPid(
+				log.LevelInfo, "connection.info",
+				log.Int("connection_count", connectionCount),
+				log.Int64("total_connection_count", atomic.LoadInt64(&connectionTotal)),
+				log.Int("transaction_count", transactionCount),
+			)
+			metric.MetricGauge("connection.total", connectionCount)
+			metric.MetricGauge("transaction.total", transactionCount)
+		case <-service.stopCh:
+			break loop
+		}
+	}
+}
+
+func (service *RoomService) Stop(waitDuration time.Duration) {
+	closeResults, err := service.server.Close(waitDuration)
+	if err != nil {
 		service.logWithAddressAndPid(log.LevelError, "error.server.close", log.Error(err))
+		service.dep.Metric.MetricIncrease("error.close.server")
+	} else if len(closeResults.Errs) != 0 {
+		service.logWithAddressAndPid(
+			log.LevelError,
+			"error.server.connection.close.count",
+			log.Int("error_count", len(closeResults.Errs)),
+			log.Int("success_count", closeResults.Count),
+		)
+		service.dep.Metric.MetricCount("error.close.connection", len(closeResults.Errs))
+		for _, err := range closeResults.Errs {
+			service.logWithAddressAndPid(log.LevelError, "error.server.connection.close", log.Error(err))
+		}
+	} else {
+		service.logWithAddressAndPid(log.LevelInfo, "server.connection.close.count", log.Int("count", closeResults.Count))
 	}
 	if service.pprofServer != nil {
 		if err := service.pprofServer.Close(); err != nil {
 			service.logWithAddressAndPid(log.LevelError, "error.server.pprof_close", log.Error(err))
 		}
 	}
+	close(service.stopCh)
 }
 
 func (service *RoomService) connAcceptHandler(conn redcon.Conn) bool {
 	service.dep.Metric.MetricIncrease("connection.accept")
 	connectionCount := atomic.AddInt64(&connectionTotal, 1)
-	service.dep.Metric.MetricGauge("connection.total", connectionCount)
-	service.dep.Metric.MetricGauge("transaction.total", transactionManager.transactionCount())
 	service.logWithAddressAndPid(
 		log.LevelDebug, "connection.accept",
 		log.String("local_addr", conn.NetConn().LocalAddr().String()),
@@ -202,6 +250,13 @@ func (service *RoomService) connServeHandler(conn redcon.Conn, cmds []redcon.Com
 	}
 	service.sendEvents(allCommands, serveStartTime)
 	service.recordCommands(allCommands, results, serveStartTime)
+	if service.server.IsServerClosing() && conn.InTx() {
+		service.logWithAddressAndPid(
+			log.LevelInfo,
+			"conn in tx, cannot close",
+			log.Any("conn", conn),
+		)
+	}
 }
 
 func (service *RoomService) preProcessCommand(cmd redcon.Command, serveStartTime time.Time) (commands.Commander, error) {
@@ -295,7 +350,7 @@ func getTransactionIfNeeded(dep base.Dependency, conn redcon.Conn, command comma
 	transaction := transactionManager.getTransaction(conn)
 	if transaction == nil {
 		if isTransactionNeeded(command) {
-			transaction = commands.NewTransaction(dep)
+			transaction = commands.NewTransaction(dep, conn)
 			transactionManager.addTransaction(conn, transaction)
 			metric.MetricIncrease("transaction.new")
 			logger.Debug(
@@ -387,8 +442,6 @@ func (service *RoomService) connCloseHandler(conn redcon.Conn, err error) {
 			log.Error(err),
 		)
 	}
-	metric.MetricGauge("connection.total", connectionCount)
-	metric.MetricGauge("transaction.total", transactionCount)
 }
 
 func (service *RoomService) logWithAddressAndPid(level log.Level, subject string, logPairs ...log.LogPair) {
diff --git a/service/transaction.go b/service/transaction.go
index 13880e8..26ee50b 100644
--- a/service/transaction.go
+++ b/service/transaction.go
@@ -22,8 +22,10 @@ func (manager *TransactionManager) addTransaction(conn redcon.Conn, tx *commands
 	oldTx := manager.connTransMap[conn]
 	delete(manager.connTransMap, conn)
 	manager.connTransMap[conn] = tx
+	conn.SetTxStatus(true)
 	manager.mutex.Unlock()
 	if oldTx != nil {
+		//TODO: this will never happen
 		oldTx.Close(commands.TransactionCloseReasonReset)
 	}
 }
@@ -38,6 +40,7 @@ func (manager *TransactionManager) removeTransaction(conn redcon.Conn, reason co
 	manager.mutex.Lock()
 	tx := manager.connTransMap[conn]
 	delete(manager.connTransMap, conn)
+	conn.SetTxStatus(false)
 	manager.mutex.Unlock()
 	if tx != nil {
 		tx.Close(reason)
diff --git a/utility/utility.go b/utility/utility.go
index 17dcae7..13b1756 100644
--- a/utility/utility.go
+++ b/utility/utility.go
@@ -5,6 +5,7 @@ import (
 	"fmt"
 	"log"
 	"math"
+	"math/rand"
 	"os"
 	"reflect"
 	"runtime/debug"
@@ -719,3 +720,22 @@ func MergeStringSlicesToStringSet(slices ...[]string) *StringSet {
 	}
 	return set
 }
+
+func RetryBackoff(retry uint, minBackoff, maxBackoff time.Duration) time.Duration {
+	if minBackoff == 0 {
+		return 0
+	}
+
+	d := minBackoff << retry
+	if d < minBackoff {
+		return maxBackoff
+	}
+
+	d = minBackoff + time.Duration(rand.Int63n(int64(d)))
+
+	if d > maxBackoff || d < minBackoff {
+		d = maxBackoff
+	}
+
+	return d
+}
diff --git a/utility/utility_test.go b/utility/utility_test.go
index 18819b9..ecdd750 100644
--- a/utility/utility_test.go
+++ b/utility/utility_test.go
@@ -1,7 +1,9 @@
 package utility
 
 import (
+	"fmt"
 	"testing"
+	"time"
 
 	"github.com/stretchr/testify/assert"
 )
@@ -167,3 +169,25 @@ func TestMergeStringSliceAndRemoveDuplicateItems(t *testing.T) {
 	assert.Equal(t, len(uniqueItems), len(result))
 	assert.ElementsMatch(t, result, uniqueItems)
 }
+
+func TestRetryBackOff(t *testing.T) {
+	min := 2 * time.Second
+	max := 20 * time.Second
+	testCases := []struct {
+		retry    int
+		expected func(t time.Duration) bool
+	}{
+		{-1, func(t time.Duration) bool { return t == max }},
+		{0, func(t time.Duration) bool { return min <= t && t < min*2 }},
+		{1, func(t time.Duration) bool { return min <= t && t < min*3 }},
+		{2, func(t time.Duration) bool { return min <= t && t < min*5 }},
+		{3, func(t time.Duration) bool { return min <= t && t < min*9 }},
+		{4, func(t time.Duration) bool { return min <= t && t <= max }},
+		{5, func(t time.Duration) bool { return min <= t && t <= max }},
+	}
+	for _, testCase := range testCases {
+		r := RetryBackoff(uint(testCase.retry), min, max)
+		fmt.Printf("RetryBackOff retry=%d, min=%s, max=%s, result=%s\n", testCase.retry, min, max, r)
+		assert.True(t, testCase.expected(r))
+	}
+}