diff --git a/drpcmanager/active_streams_test.go b/drpcmanager/active_streams_test.go
index f463b18..d2e5942 100644
--- a/drpcmanager/active_streams_test.go
+++ b/drpcmanager/active_streams_test.go
@@ -22,7 +22,7 @@ func testMuxWriter(t *testing.T) *drpcwire.MuxWriter {
 }
 
 func testStream(t *testing.T, id uint64) *drpcstream.Stream {
-	return drpcstream.New(context.Background(), id, testMuxWriter(t))
+	return drpcstream.New(context.Background(), id, testMuxWriter(t), drpcstream.NewBufferPool())
 }
 
 func TestActiveStreams_AddAndGet(t *testing.T) {
diff --git a/drpcmanager/manager.go b/drpcmanager/manager.go
index 6c77e17..c319ff6 100644
--- a/drpcmanager/manager.go
+++ b/drpcmanager/manager.go
@@ -61,7 +61,8 @@ type Manager struct {
 	wg sync.WaitGroup // tracks active manageStream goroutines
 
 	// streams tracks active streams.
-	streams *activeStreams
+	streams  *activeStreams
+	recvPool *drpcstream.BufferPool
 
 	pdone   drpcsignal.Chan // signals when NewServerStream has registered the new stream
 	invokes chan invokeInfo // completed invoke info from manageReader to NewServerStream
@@ -130,6 +131,7 @@ func NewWithOptions(tr drpc.Transport, kind ManagerKind, opts Options) *Manager
 	m.pendingStreams = make(map[uint64]*pendingStream)
 
 	m.streams = newActiveStreams()
+	m.recvPool = drpcstream.NewBufferPool()
 
 	// set the internal stream options
 	drpcopts.SetStreamTransport(&m.opts.Stream.Internal, m.tr)
@@ -268,7 +270,7 @@ func (m *Manager) newStream(ctx context.Context, sid uint64, kind drpc.StreamKin
 		drpcopts.SetStreamStats(&opts.Internal, cb(rpc))
 	}
 
-	stream := drpcstream.NewWithOptions(ctx, sid, m.wr, opts)
+	stream := drpcstream.NewWithOptions(ctx, sid, m.wr, m.recvPool, opts)
 
 	if err := m.streams.Add(sid, stream); err != nil {
 		return nil, err
diff --git a/drpcstream/buffer_pool.go b/drpcstream/buffer_pool.go
new file mode 100644
index 0000000..0785e51
--- /dev/null
+++ b/drpcstream/buffer_pool.go
@@ -0,0 +1,59 @@
+// Copyright (C) 2026 Cockroach Labs.
+// See LICENSE for copying information.
+
+package drpcstream
+
+import "sync"
+
+// defaultBufferCapacity is the initial capacity of buffers allocated by a
+// BufferPool.
+const defaultBufferCapacity = 4096
+
+// BufferPool wraps sync.Pool to provide reusable byte slices for the
+// stream receive path. Buffers obtained via Get should be returned via
+// Put when no longer needed. Forgetting to Put is safe (GC reclaims)
+// but reduces reuse.
+//
+// A nil *BufferPool and the zero value are both usable: Get allocates a
+// fresh buffer and, on a nil pool, Put discards it. This is what lets a
+// ringBuffer or Stream be constructed without a pool.
+type BufferPool struct {
+	pool sync.Pool
+}
+
+// NewBufferPool returns a new buffer pool.
+func NewBufferPool() *BufferPool {
+	return &BufferPool{
+		pool: sync.Pool{New: func() interface{} { return newBuffer() }},
+	}
+}
+
+// newBuffer allocates an empty buffer with the default capacity.
+func newBuffer() *[]byte {
+	b := make([]byte, 0, defaultBufferCapacity)
+	return &b
+}
+
+// Get returns a zero-length byte slice from the pool, retaining its
+// backing array for reuse. It never returns nil: on a nil receiver, or
+// when the pool has nothing to hand out, a fresh buffer is allocated.
+func (bp *BufferPool) Get() *[]byte {
+	if bp == nil {
+		return newBuffer()
+	}
+	p, ok := bp.pool.Get().(*[]byte)
+	if !ok || p == nil {
+		return newBuffer()
+	}
+	*p = (*p)[:0]
+	return p
+}
+
+// Put returns a buffer to the pool. A nil buffer is ignored, and a nil
+// receiver drops the buffer for the garbage collector to reclaim.
+func (bp *BufferPool) Put(b *[]byte) {
+	if bp == nil || b == nil {
+		return
+	}
+	bp.pool.Put(b)
+}
diff --git a/drpcstream/ring_buffer.go b/drpcstream/ring_buffer.go
index 5cb620a..3056617 100644
--- a/drpcstream/ring_buffer.go
+++ b/drpcstream/ring_buffer.go
@@ -15,11 +15,12 @@ const defaultRingBufferCapacity = 256
 
 // ringBuffer is a bounded single-producer / single-consumer FIFO queue for
 // assembled packet data. It sits between manageReader (producer, calls
-// Enqueue) and the application goroutine (consumer, calls Dequeue/Done).
+// Enqueue) and the application goroutine (consumer, calls Dequeue).
 //
-// Slots are pre-allocated and reused: each slot's backing array grows via
-// append to fit incoming data, then stays at its high-water mark, avoiding
-// per-message allocation in steady state.
+// Buffers are obtained from a shared BufferPool. Enqueue copies data into a
+// pooled buffer; Dequeue returns ownership of that buffer to the caller and
+// advances the tail immediately. The caller is responsible for returning the
+// buffer to the pool via BufferPool.Put.
// // After Close, Dequeue drains any queued messages before returning the close // error. This ensures graceful shutdown (KindClose/KindCloseSend) delivers @@ -28,23 +29,24 @@ type ringBuffer struct { mu sync.Mutex cond sync.Cond - buf [][]byte // ring of byte slices - head int // next write position (producer) - tail int // next read position (consumer) - count int // number of occupied slots + pool *BufferPool // shared pool; nil means allocate fresh each time + buf []*[]byte // ring of pooled buffer pointers + head int // next write position (producer) + tail int // next read position (consumer) + count int // number of occupied slots - held bool // true between Dequeue and Done - err error // terminal error, set by Close + err error // terminal error, set by Close } -func (rb *ringBuffer) init() { +func (rb *ringBuffer) init(pool *BufferPool) { rb.cond.L = &rb.mu - rb.buf = make([][]byte, defaultRingBufferCapacity) + rb.pool = pool + rb.buf = make([]*[]byte, defaultRingBufferCapacity) } -// Enqueue copies data into the next write slot. If the buffer is full, it -// blocks until a slot is freed or the buffer is closed. If the buffer is -// closed, Enqueue returns silently without enqueuing. +// Enqueue copies data into a pooled buffer and places it in the next write +// slot. If the buffer is full, it blocks until a slot is freed or the buffer +// is closed. If the buffer is closed, Enqueue returns silently. func (rb *ringBuffer) Enqueue(data []byte) { rb.mu.Lock() defer rb.mu.Unlock() @@ -56,16 +58,19 @@ func (rb *ringBuffer) Enqueue(data []byte) { return } - rb.buf[rb.head] = append(rb.buf[rb.head][:0], data...) + b := rb.pool.Get() + *b = append(*b, data...) + + rb.buf[rb.head] = b rb.head = (rb.head + 1) % len(rb.buf) rb.count++ rb.cond.Broadcast() } -// Dequeue returns the data from the next read slot. If the buffer is empty, -// it blocks until data is available or the buffer is closed. The returned -// slice is valid until Done is called. 
-func (rb *ringBuffer) Dequeue() ([]byte, error) { +// Dequeue returns the next buffered message. The returned *[]byte is owned +// by the caller; the tail is advanced immediately. If the ring buffer has a +// pool, the caller should return the buffer via BufferPool.Put when done. +func (rb *ringBuffer) Dequeue() (*[]byte, error) { rb.mu.Lock() defer rb.mu.Unlock() @@ -76,37 +81,21 @@ func (rb *ringBuffer) Dequeue() ([]byte, error) { return nil, rb.err } - rb.held = true - return rb.buf[rb.tail], nil -} - -// Done advances the read pointer, making the slot available for reuse. -// It must be called exactly once after each successful Dequeue. -// -// TODO(shubham): remove this method once a shared buffer pool is introduced. -// With a pool, Dequeue will advance the tail immediately and the caller will -// return the buffer to the pool directly. -func (rb *ringBuffer) Done() { - rb.mu.Lock() - defer rb.mu.Unlock() - + b := rb.buf[rb.tail] + rb.buf[rb.tail] = nil rb.tail = (rb.tail + 1) % len(rb.buf) rb.count-- - rb.held = false rb.cond.Broadcast() + + return b, nil } // Close marks the buffer as closed with the given error. All blocked Enqueue -// and Dequeue calls are woken and will return. Close waits for any in-progress -// Dequeue/Done pair to complete before setting the error. Subsequent calls are -// no-ops. +// and Dequeue calls are woken and will return. Subsequent calls are no-ops. 
func (rb *ringBuffer) Close(err error) { rb.mu.Lock() defer rb.mu.Unlock() - for rb.held { - rb.cond.Wait() - } if rb.err != nil { return } diff --git a/drpcstream/ring_buffer_test.go b/drpcstream/ring_buffer_test.go index 8be9c58..d62d633 100644 --- a/drpcstream/ring_buffer_test.go +++ b/drpcstream/ring_buffer_test.go @@ -13,19 +13,18 @@ import ( func TestRingBuffer_EnqueueDequeue(t *testing.T) { var rb ringBuffer - rb.init() + rb.init(NewBufferPool()) rb.Enqueue([]byte("hello")) data, err := rb.Dequeue() assert.NoError(t, err) - assert.DeepEqual(t, data, []byte("hello")) - rb.Done() + assert.DeepEqual(t, *data, []byte("hello")) } func TestRingBuffer_FIFO(t *testing.T) { var rb ringBuffer - rb.init() + rb.init(NewBufferPool()) rb.Enqueue([]byte("first")) rb.Enqueue([]byte("second")) @@ -34,31 +33,30 @@ func TestRingBuffer_FIFO(t *testing.T) { for _, want := range []string{"first", "second", "third"} { data, err := rb.Dequeue() assert.NoError(t, err) - assert.DeepEqual(t, data, []byte(want)) - rb.Done() + assert.DeepEqual(t, *data, []byte(want)) } } func TestRingBuffer_DequeueBlocksUntilEnqueue(t *testing.T) { var rb ringBuffer - rb.init() + rb.init(NewBufferPool()) got := make(chan []byte, 1) go func() { data, err := rb.Dequeue() assert.NoError(t, err) - got <- data + got <- *data }() rb.Enqueue([]byte("delayed")) assert.DeepEqual(t, <-got, []byte("delayed")) - rb.Done() } func TestRingBuffer_EnqueueBlocksWhenFull(t *testing.T) { var rb ringBuffer rb.cond.L = &rb.mu - rb.buf = make([][]byte, 2) // capacity 2 + rb.pool = NewBufferPool() + rb.buf = make([]*[]byte, 2) // capacity 2 rb.Enqueue([]byte("a")) rb.Enqueue([]byte("b")) @@ -73,8 +71,7 @@ func TestRingBuffer_EnqueueBlocksWhenFull(t *testing.T) { // Drain one slot. data, err := rb.Dequeue() assert.NoError(t, err) - assert.DeepEqual(t, data, []byte("a")) - rb.Done() + assert.DeepEqual(t, *data, []byte("a")) // Now the blocked Enqueue should complete. 
<-done @@ -82,18 +79,16 @@ func TestRingBuffer_EnqueueBlocksWhenFull(t *testing.T) { // Verify remaining items. data, err = rb.Dequeue() assert.NoError(t, err) - assert.DeepEqual(t, data, []byte("b")) - rb.Done() + assert.DeepEqual(t, *data, []byte("b")) data, err = rb.Dequeue() assert.NoError(t, err) - assert.DeepEqual(t, data, []byte("c")) - rb.Done() + assert.DeepEqual(t, *data, []byte("c")) } func TestRingBuffer_CloseUnblocksDequeue(t *testing.T) { var rb ringBuffer - rb.init() + rb.init(NewBufferPool()) errch := make(chan error, 1) go func() { @@ -108,7 +103,8 @@ func TestRingBuffer_CloseUnblocksDequeue(t *testing.T) { func TestRingBuffer_CloseUnblocksEnqueue(t *testing.T) { var rb ringBuffer rb.cond.L = &rb.mu - rb.buf = make([][]byte, 1) // capacity 1 + rb.pool = NewBufferPool() + rb.buf = make([]*[]byte, 1) // capacity 1 rb.Enqueue([]byte("fill")) @@ -124,7 +120,7 @@ func TestRingBuffer_CloseUnblocksEnqueue(t *testing.T) { func TestRingBuffer_CloseDrainsQueued(t *testing.T) { var rb ringBuffer - rb.init() + rb.init(NewBufferPool()) rb.Enqueue([]byte("queued")) rb.Close(io.EOF) @@ -132,8 +128,7 @@ func TestRingBuffer_CloseDrainsQueued(t *testing.T) { // Dequeue returns the queued data first. data, err := rb.Dequeue() assert.NoError(t, err) - assert.DeepEqual(t, data, []byte("queued")) - rb.Done() + assert.DeepEqual(t, *data, []byte("queued")) // Next Dequeue returns the close error. 
data, err = rb.Dequeue() @@ -143,7 +138,7 @@ func TestRingBuffer_CloseDrainsQueued(t *testing.T) { func TestRingBuffer_CloseIdempotent(t *testing.T) { var rb ringBuffer - rb.init() + rb.init(NewBufferPool()) rb.Close(io.EOF) rb.Close(io.ErrUnexpectedEOF) // should not overwrite @@ -154,7 +149,7 @@ func TestRingBuffer_CloseIdempotent(t *testing.T) { func TestRingBuffer_EnqueueAfterClose(t *testing.T) { var rb ringBuffer - rb.init() + rb.init(NewBufferPool()) rb.Close(io.EOF) rb.Enqueue([]byte("dropped")) // should not panic or block @@ -163,44 +158,21 @@ func TestRingBuffer_EnqueueAfterClose(t *testing.T) { func TestRingBuffer_SlotReuse(t *testing.T) { var rb ringBuffer rb.cond.L = &rb.mu - rb.buf = make([][]byte, 2) + rb.pool = NewBufferPool() + rb.buf = make([]*[]byte, 2) // Fill and drain a few rounds to exercise slot reuse. for round := 0; round < 5; round++ { rb.Enqueue([]byte("data")) data, err := rb.Dequeue() assert.NoError(t, err) - assert.DeepEqual(t, data, []byte("data")) - rb.Done() + assert.DeepEqual(t, *data, []byte("data")) } } -func TestRingBuffer_CloseWaitsForHeld(t *testing.T) { - var rb ringBuffer - rb.init() - - rb.Enqueue([]byte("msg")) - - // Dequeue the data but don't call Done yet. - data, err := rb.Dequeue() - assert.NoError(t, err) - assert.DeepEqual(t, data, []byte("msg")) - - closed := make(chan struct{}) - go func() { - rb.Close(io.EOF) - close(closed) - }() - - // Close should be blocked because held is true. - // Call Done to release it. 
- rb.Done() - <-closed -} - func TestRingBuffer_ConcurrentProducerConsumer(t *testing.T) { var rb ringBuffer - rb.init() + rb.init(NewBufferPool()) const n = 1000 var wg sync.WaitGroup @@ -218,11 +190,25 @@ func TestRingBuffer_ConcurrentProducerConsumer(t *testing.T) { for i := 0; i < n; i++ { data, err := rb.Dequeue() assert.NoError(t, err) - assert.Equal(t, data[0], byte(i)) - rb.Done() + assert.Equal(t, (*data)[0], byte(i)) } }() wg.Wait() rb.Close(io.EOF) } + +func TestRingBuffer_WithPool(t *testing.T) { + pool := NewBufferPool() + var rb ringBuffer + rb.init(pool) + + rb.Enqueue([]byte("pooled")) + + data, err := rb.Dequeue() + assert.NoError(t, err) + assert.DeepEqual(t, *data, []byte("pooled")) + pool.Put(data) + + rb.Close(io.EOF) +} diff --git a/drpcstream/stream.go b/drpcstream/stream.go index 1fa0460..d8545cf 100644 --- a/drpcstream/stream.go +++ b/drpcstream/stream.go @@ -53,6 +53,7 @@ type Stream struct { id drpcwire.ID wr *drpcwire.MuxWriter + pool *BufferPool recvQueue ringBuffer wbuf []byte @@ -78,15 +79,15 @@ var _ drpc.Stream = (*Stream)(nil) // New returns a new stream bound to the context with the given stream id and // will use the writer to write messages on. It is important use monotonically // increasing stream ids within a single transport. -func New(ctx context.Context, sid uint64, wr *drpcwire.MuxWriter) *Stream { - return NewWithOptions(ctx, sid, wr, Options{}) +func New(ctx context.Context, sid uint64, wr *drpcwire.MuxWriter, pool *BufferPool) *Stream { + return NewWithOptions(ctx, sid, wr, pool, Options{}) } // NewWithOptions returns a new stream bound to the context with the given // stream id and will use the writer to write messages on. It is important use // monotonically increasing stream ids within a single transport. The options // are used to control details of how the Stream operates. 
-func NewWithOptions(ctx context.Context, sid uint64, wr *drpcwire.MuxWriter, opts Options) *Stream { +func NewWithOptions(ctx context.Context, sid uint64, wr *drpcwire.MuxWriter, pool *BufferPool, opts Options) *Stream { var task *trace.Task if trace.IsEnabled() { kind, rpc := drpcopts.GetStreamKind(&opts.Internal), drpcopts.GetStreamRPC(&opts.Internal) @@ -108,12 +109,12 @@ func NewWithOptions(ctx context.Context, sid uint64, wr *drpcwire.MuxWriter, opt pa: pa, - id: drpcwire.ID{Stream: sid}, - wr: wr, + id: drpcwire.ID{Stream: sid}, + wr: wr, + pool: pool, } - // initialize the packet buffer - s.recvQueue.init() + s.recvQueue.init(pool) return s } @@ -414,12 +415,12 @@ func (s *Stream) RawRecv() (data []byte, err error) { s.read.Lock() defer s.read.Unlock() - data, err = s.recvQueue.Dequeue() + b, err := s.recvQueue.Dequeue() if err != nil { return nil, err } - data = append([]byte(nil), data...) - s.recvQueue.Done() + data = append([]byte(nil), *b...) + s.pool.Put(b) return data, nil } @@ -456,12 +457,12 @@ func (s *Stream) MsgRecv(msg drpc.Message, enc drpc.Encoding) (err error) { s.read.Lock() defer s.read.Unlock() - data, err := s.recvQueue.Dequeue() + b, err := s.recvQueue.Dequeue() if err != nil { return err } - err = enc.Unmarshal(data, msg) - s.recvQueue.Done() + err = enc.Unmarshal(*b, msg) + s.pool.Put(b) return err } diff --git a/drpcstream/stream_test.go b/drpcstream/stream_test.go index 9cee8dd..f54b509 100644 --- a/drpcstream/stream_test.go +++ b/drpcstream/stream_test.go @@ -117,7 +117,7 @@ func TestStream_StateTransitions(t *testing.T) { } for _, test := range cases { - st := New(ctx, 1, mw) + st := New(ctx, 1, mw, NewBufferPool()) assert.NoError(t, test.Op(st)) checkErrs(t, test.Send, st.RawWrite(drpcwire.KindMessage, nil)) @@ -169,7 +169,7 @@ func TestStream_Unblocks(t *testing.T) { } for _, test := range cases { - st := New(ctx, 1, mw) + st := New(ctx, 1, mw, NewBufferPool()) ctx.Run(func(ctx context.Context) { _, _ = st.RawRecv() }) 
assert.NoError(t, test.Op(st)) @@ -180,7 +180,7 @@ func TestStream_Unblocks(t *testing.T) { func TestStream_ContextCancel(t *testing.T) { ctx := context.Background() mw := testMuxWriter(t) - st := New(ctx, 0, mw) + st := New(ctx, 0, mw, NewBufferPool()) child, cancel := context.WithCancel(st.Context()) defer cancel() @@ -195,7 +195,7 @@ func TestStream_ConcurrentCloseCancel(t *testing.T) { defer ctx.Close() mw := testMuxWriter(t) - st := New(ctx, 0, mw) + st := New(ctx, 0, mw, NewBufferPool()) // Close and Cancel concurrently should not panic or deadlock. errch := make(chan error, 1) @@ -219,7 +219,7 @@ func TestStream_PacketBufferReuse(t *testing.T) { mw := testMuxWriter(t) data := make([]byte, 20) mid := uint64(1) - st := New(ctx, 1, mw) + st := New(ctx, 1, mw, NewBufferPool()) ctx.Run(func(ctx context.Context) { for !st.IsTerminated() { @@ -265,7 +265,7 @@ func TestStream_PacketBufferReuse(t *testing.T) { func TestHandleFrame_FirstFrameOnFreshStream(t *testing.T) { mw := testMuxWriter(t) for _, messageID := range []uint64{1, 2} { - st := New(context.Background(), 1, mw) + st := New(context.Background(), 1, mw, NewBufferPool()) // Close the ring buffer so KindMessage Enqueue doesn't block. st.recvQueue.Close(io.EOF) err := st.HandleFrame(drpcwire.Frame{ @@ -278,7 +278,7 @@ func TestHandleFrame_FirstFrameOnFreshStream(t *testing.T) { // Invoke and InvokeMetadata frames are rejected on an already-created stream. 
func TestHandleFrame_InvokeOnExistingStream(t *testing.T) { mw := testMuxWriter(t) - st := New(context.Background(), 1, mw) + st := New(context.Background(), 1, mw, NewBufferPool()) err := handleFrame(st, drpcwire.KindInvoke, 1) assert.Error(t, err) @@ -288,7 +288,7 @@ func TestHandleFrame_InvokeOnExistingStream(t *testing.T) { func TestHandleFrame_InvokeMetadataOnExistingStream(t *testing.T) { mw := testMuxWriter(t) - st := New(context.Background(), 1, mw) + st := New(context.Background(), 1, mw, NewBufferPool()) err := handleFrame(st, drpcwire.KindInvokeMetadata, 1) assert.Error(t, err) @@ -299,7 +299,7 @@ func TestHandleFrame_InvokeMetadataOnExistingStream(t *testing.T) { // Frames arriving after the stream is terminated are silently ignored. func TestHandleFrame_AfterTerminated(t *testing.T) { mw := testMuxWriter(t) - st := New(context.Background(), 1, mw) + st := New(context.Background(), 1, mw, NewBufferPool()) // Terminate the stream via cancel. st.Cancel(context.Canceled) @@ -317,7 +317,7 @@ func TestHandleFrame_MessageDeliveredViaRecv(t *testing.T) { defer ctx.Close() mw := testMuxWriter(t) - st := New(ctx, 1, mw) + st := New(ctx, 1, mw, NewBufferPool()) // Launch receiver before sending to avoid Put blocking. recv := make(chan []byte, 1)