diff --git a/client.go b/client.go index 69324fc..060160e 100644 --- a/client.go +++ b/client.go @@ -19,6 +19,7 @@ type Client struct { timeout time.Duration expiration time.Duration idle int + now func() time.Time lock sync.Mutex addrs []string @@ -83,6 +84,24 @@ func SetDefaultTTL(expiration time.Duration) ClientOption { } } +// ClockFunc is a function that returns the current time. +// +// Normally this should just be the time.Now function. +type ClockFunc func() time.Time + +// SetClock sets the ClockFunc used for getting the current time. +// +// If unset the default is to use the time.Now function. +// +// Note this should typically only be set in testing. +func SetClock(f ClockFunc) ClientOption { + return func(c *Client) { + c.lock.Lock() + defer c.lock.Unlock() + c.now = f + } +} + const ( defaultDialTimeout = 5 * time.Second defaultExpiration = 1 * time.Hour @@ -100,6 +119,7 @@ func New(instances []string, opts ...ClientOption) *Client { c.timeout = defaultDialTimeout c.expiration = defaultExpiration c.idle = defaultIdleCount + c.now = time.Now for _, opt := range opts { opt(c) @@ -129,17 +149,24 @@ func (c *Client) Close() error { return c.pools.Close() } -func seconds(expiration time.Duration) (int, error) { - if expiration == 0 { +// seconds returns the number of seconds until expiration, unless the +// expiration is more than 30 days (2_592_000 seconds), in which case the +// absolute timestamp is used and expected by the memcached instance +func (c *Client) seconds(expiration time.Duration) (int, error) { + switch { + case expiration == 0: return 0, nil - } - - if expiration < 1*time.Second { + case expiration < 1*time.Second: return 0, ErrExpiration + case expiration > 2_592_000*time.Second: + unix := c.now() + later := unix.Add(expiration) + s := int(later.Unix()) + return s, nil + default: + s := int(expiration.Seconds()) + return s, nil } - - s := int(expiration.Seconds()) - return s, nil } func (c *Client) do(key string, f func(*iopool.Buffer) error) error { diff --git a/client_test.go b/client_test.go index c0b1d70..ab90c9f 100644 --- a/client_test.go +++ b/client_test.go @@ -29,22 +29,45 @@ func Test_SetDefaultTTL(t *testing.T) { func Test_seconds(t *testing.T) { t.Parallel() + c := &Client{ + now: func() time.Time { + // January 23rd, 2026, 10:24:00 AM + return time.Date(2026, 1, 23, 10, 24, 0, 0, time.UTC) + }, + } + t.Run("zero", func(t *testing.T) { - s, err := seconds(0) + s, err := c.seconds(0) must.NoError(t, err) must.Zero(t, s) }) t.Run("millis", func(t *testing.T) { - _, err := seconds(250 * time.Millisecond) + _, err := c.seconds(250 * time.Millisecond) must.ErrorIs(t, err, ErrExpiration) }) t.Run("seconds", func(t *testing.T) { - s, err := seconds(4 * time.Second) + s, err := c.seconds(4 * time.Second) must.NoError(t, err) must.Eq(t, 4, s) }) + + t.Run("month", func(t *testing.T) { + ttl := 30 * 24 * time.Hour + fix := ttl - (1 * time.Second) + s, err := c.seconds(fix) + must.NoError(t, err) + must.Eq(t, 2591999, s) + }) + + t.Run("longer", func(t *testing.T) { + ttl := 30 * 24 * time.Hour + fix := ttl + (1 * time.Second) + s, err := c.seconds(fix) + must.NoError(t, err) + must.Eq(t, 1771755841, s) // February 22, 2026 10:24:01 AM + }) } func Test_check(t *testing.T) { diff --git a/verbs.go b/verbs.go index 6ee6480..ad42af6 100644 --- a/verbs.go +++ b/verbs.go @@ -82,7 +82,7 @@ func Set[T any](c *Client, key string, item T, opts ...Option) error { return encerr } - expiration, experr := seconds(options.expiration) + expiration, experr := c.seconds(options.expiration) if experr != nil { return experr } @@ -156,7 +156,7 @@ func Add[T any](c *Client, key string, item T, opts ...Option) error { return encerr } - expiration, experr := seconds(options.expiration) + expiration, experr := c.seconds(options.expiration) if experr != nil { return experr }