From ef3cc68f556e6bc01cb2c39e31eb37b41db4e4a9 Mon Sep 17 00:00:00 2001 From: Shubham Dhama Date: Tue, 31 Mar 2026 16:41:59 +0530 Subject: [PATCH 01/15] drpcmanager: remove InactivityTimeout option We are not using it and not planning to use it anytime soon, so till then it's just a burden to maintain. --- drpcmanager/manager.go | 20 -------------------- drpcmanager/manager_test.go | 11 ----------- 2 files changed, 31 deletions(-) diff --git a/drpcmanager/manager.go b/drpcmanager/manager.go index e2e0993..286f157 100644 --- a/drpcmanager/manager.go +++ b/drpcmanager/manager.go @@ -13,7 +13,6 @@ import ( "sync" "sync/atomic" "syscall" - "time" "github.com/zeebo/errs" grpcmetadata "google.golang.org/grpc/metadata" @@ -48,13 +47,6 @@ type Options struct { // being flushed when the cancel happens. SoftCancel bool - // InactivityTimeout is the amount of time the manager will wait when - // creating a NewServerStream. It only includes the time it is reading - // packets from the remote client. In other words, it only includes the time - // that the client could delay before invoking an RPC. If zero or negative, - // no timeout is used. - InactivityTimeout time.Duration - // Internal contains options that are for internal use only. Internal drpcopts.Manager @@ -468,19 +460,7 @@ func (m *Manager) NewServerStream(ctx context.Context) (stream *drpcstream.Strea } }() - var timeoutCh <-chan time.Time - - // set up the timeout on the context if necessary. - if timeout := m.opts.InactivityTimeout; timeout > 0 { - timer := time.NewTimer(timeout) - defer timer.Stop() - timeoutCh = timer.C - } - select { - case <-timeoutCh: - return nil, "", context.DeadlineExceeded - case <-ctx.Done(): return nil, "", ctx.Err() diff --git a/drpcmanager/manager_test.go b/drpcmanager/manager_test.go index 0b7beea..95cc375 100644 --- a/drpcmanager/manager_test.go +++ b/drpcmanager/manager_test.go @@ -30,17 +30,6 @@ func closed(ch <-chan struct{}) bool { } } -func TestTimeout(t *testing.T) { - tr := make(blockingTransport) - man := NewWithOptions(tr, Options{ - InactivityTimeout: time.Millisecond, - }) - defer func() { _ = man.Close() }() - - _, _, err := man.NewServerStream(context.Background()) - assert.That(t, errors.Is(err, context.DeadlineExceeded)) -} - func TestDrpcMetadata(t *testing.T) { ctx := drpctest.NewTracker(t) defer ctx.Close() From 5b7b70f932dc1b6d8a5b02e76915f88a85e4b80d Mon Sep 17 00:00:00 2001 From: Shubham Dhama Date: Wed, 8 Apr 2026 14:45:40 +0530 Subject: [PATCH 02/15] drpccache: remove the package --- drpccache/README.md | 74 ------------------------ drpccache/cache.go | 92 ------------------------------ drpccache/cache_test.go | 71 ----------------------- drpccache/doc.go | 5 -- drpcserver/server.go | 6 -- internal/integration/cache_test.go | 63 -------------------- 6 files changed, 311 deletions(-) delete mode 100644 drpccache/README.md delete mode 100644 drpccache/cache.go delete mode 100644 drpccache/cache_test.go delete mode 100644 drpccache/doc.go delete mode 100644 internal/integration/cache_test.go diff --git a/drpccache/README.md b/drpccache/README.md deleted file mode 100644 index 5739abc..0000000 --- a/drpccache/README.md +++ /dev/null @@ -1,74 +0,0 @@ -# package drpccache - -`import "storj.io/drpc/drpccache"` - -Package drpccache implements per stream cache for drpc. - -## Usage - -#### func WithContext - -```go -func WithContext(parent context.Context, cache *Cache) context.Context -``` -WithContext returns a context with the value cache associated with the context. - -#### type Cache - -```go -type Cache struct { -} -``` - -Cache is a per stream cache. - -#### func FromContext - -```go -func FromContext(ctx context.Context) *Cache -``` -FromContext returns a cache from a context. - -Example usage: - - cache := drpccache.FromContext(stream.Context()) - if cache != nil { - value := cache.LoadOrCreate("initialized", func() (interface{}) { - return 42 - }) - } - -#### func New - -```go -func New() *Cache -``` -New returns a new cache. - -#### func (*Cache) Clear - -```go -func (cache *Cache) Clear() -``` -Clear clears the cache. - -#### func (*Cache) Load - -```go -func (cache *Cache) Load(key interface{}) interface{} -``` -Load returns the value with the given key. - -#### func (*Cache) LoadOrCreate - -```go -func (cache *Cache) LoadOrCreate(key interface{}, fn func() interface{}) interface{} -``` -LoadOrCreate returns the value with the given key. - -#### func (*Cache) Store - -```go -func (cache *Cache) Store(key, value interface{}) -``` -Store sets the value at a key. diff --git a/drpccache/cache.go b/drpccache/cache.go deleted file mode 100644 index 107b777..0000000 --- a/drpccache/cache.go +++ /dev/null @@ -1,92 +0,0 @@ -// Copyright (C) 2020 Storj Labs, Inc. -// See LICENSE for copying information. - -package drpccache - -import ( - "context" - "sync" -) - -type cacheKey struct{} - -// Cache is a per stream cache. -type Cache struct { - mu sync.Mutex - values map[interface{}]interface{} -} - -// New returns a new cache. -func New() *Cache { return &Cache{} } - -// FromContext returns a cache from a context. -// -// Example usage: -// -// cache := drpccache.FromContext(stream.Context()) -// if cache != nil { -// value := cache.LoadOrCreate("initialized", func() (interface{}) { -// return 42 -// }) -// } -func FromContext(ctx context.Context) *Cache { - cache, _ := ctx.Value(cacheKey{}).(*Cache) - return cache -} - -// WithContext returns a context with the value cache associated with the context. -func WithContext(parent context.Context, cache *Cache) context.Context { - return context.WithValue(parent, cacheKey{}, cache) -} - -// init ensures that the values map exist. -func (cache *Cache) init() { - if cache.values == nil { - cache.values = map[interface{}]interface{}{} - } -} - -// Clear clears the cache. -func (cache *Cache) Clear() { - cache.mu.Lock() - defer cache.mu.Unlock() - - cache.values = nil -} - -// Store sets the value at a key. -func (cache *Cache) Store(key, value interface{}) { - cache.mu.Lock() - defer cache.mu.Unlock() - - cache.init() - cache.values[key] = value -} - -// Load returns the value with the given key. -func (cache *Cache) Load(key interface{}) interface{} { - cache.mu.Lock() - defer cache.mu.Unlock() - - if cache.values == nil { - return nil - } - - return cache.values[key] -} - -// LoadOrCreate returns the value with the given key. -func (cache *Cache) LoadOrCreate(key interface{}, fn func() interface{}) interface{} { - cache.mu.Lock() - defer cache.mu.Unlock() - - cache.init() - - value, ok := cache.values[key] - if !ok { - value = fn() - cache.values[key] = value - } - - return value -} diff --git a/drpccache/cache_test.go b/drpccache/cache_test.go deleted file mode 100644 index bd98dbf..0000000 --- a/drpccache/cache_test.go +++ /dev/null @@ -1,71 +0,0 @@ -// Copyright (C) 2021 Storj Labs, Inc. -// See LICENSE for copying information. - -package drpccache - -import ( - "context" - "testing" - - "github.com/zeebo/assert" -) - -func TestWithContext(t *testing.T) { - ctx := context.Background() - assert.Nil(t, FromContext(ctx)) - - cache := New() - ctx = WithContext(ctx, cache) - assert.Equal(t, cache, FromContext(ctx)) -} - -func TestLoad(t *testing.T) { - cache := New() - - assert.Nil(t, cache.Load("key1")) - assert.Nil(t, cache.Load("key2")) - - cache.Store("key1", "val1") - - assert.Equal(t, cache.Load("key1"), "val1") - assert.Nil(t, cache.Load("key2")) - - cache.Store("key2", "val2") - - assert.Equal(t, cache.Load("key1"), "val1") - assert.Equal(t, cache.Load("key2"), "val2") -} - -func TestClear(t *testing.T) { - cache := New() - - cache.Store("key1", "val1") - cache.Store("key2", "val2") - - assert.Equal(t, cache.Load("key1"), "val1") - assert.Equal(t, cache.Load("key2"), "val2") - - cache.Clear() - - assert.Nil(t, cache.Load("key1")) - assert.Nil(t, cache.Load("key2")) -} - -func TestLoadOrCreate(t *testing.T) { - f := func(val interface{}) func() interface{} { - return func() interface{} { return val } - } - - cache := New() - - assert.Nil(t, cache.Load("key1")) - assert.Nil(t, cache.Load("key2")) - - assert.Equal(t, cache.LoadOrCreate("key1", f("key1")), "key1") - assert.Equal(t, cache.LoadOrCreate("key1", f("key2")), "key1") - - assert.Equal(t, cache.Load("key1"), "key1") - assert.Nil(t, cache.Load("key2")) - - cache.LoadOrCreate("key1", func() interface{} { panic("called") }) -} diff --git a/drpccache/doc.go b/drpccache/doc.go deleted file mode 100644 index 6926e5d..0000000 --- a/drpccache/doc.go +++ /dev/null @@ -1,5 +0,0 @@ -// Copyright (C) 2019 Storj Labs, Inc. -// See LICENSE for copying information. - -// Package drpccache implements per stream cache for drpc. -package drpccache diff --git a/drpcserver/server.go b/drpcserver/server.go index 13c8186..c11c1c6 100644 --- a/drpcserver/server.go +++ b/drpcserver/server.go @@ -12,7 +12,6 @@ import ( "github.com/zeebo/errs" "storj.io/drpc" - "storj.io/drpc/drpccache" "storj.io/drpc/drpcctx" "storj.io/drpc/drpcmanager" "storj.io/drpc/drpcstats" @@ -143,11 +142,6 @@ func (s *Server) ServeOne(ctx context.Context, tr drpc.Transport) (err error) { man := drpcmanager.NewWithOptions(tr, s.opts.Manager) defer func() { err = errs.Combine(err, man.Close()) }() - cache := drpccache.New() - defer cache.Clear() - - ctx = drpccache.WithContext(ctx, cache) - for { stream, rpc, err := man.NewServerStream(ctx) if err != nil { diff --git a/internal/integration/cache_test.go b/internal/integration/cache_test.go deleted file mode 100644 index d4e1c95..0000000 --- a/internal/integration/cache_test.go +++ /dev/null @@ -1,63 +0,0 @@ -// Copyright (C) 2020 Storj Labs, Inc. -// See LICENSE for copying information. - -package integration - -import ( - "context" - "testing" - - "github.com/zeebo/assert" - "github.com/zeebo/errs" - - "storj.io/drpc/drpccache" - "storj.io/drpc/drpctest" -) - -func TestCache(t *testing.T) { - ctx := drpctest.NewTracker(t) - defer ctx.Close() - - // create a server that signals then waits for the context to die - cli, close := createConnection(t, impl{ - Method1Fn: func(ctx context.Context, _ *In) (*Out, error) { - cache := drpccache.FromContext(ctx) - if cache == nil { - return nil, errs.New("no cache associated with context") - } - cache.LoadOrCreate("value", func() interface{} { return 42 }) - return &Out{Out: 123}, nil - }, - Method2Fn: func(stream DRPCService_Method2Stream) error { - cache := drpccache.FromContext(stream.Context()) - if cache == nil { - return errs.New("no cache associated with context") - } - value, _ := cache.Load("value").(int) - return stream.SendAndClose(&Out{Out: int64(value)}) - }, - }) - defer close() - - { // value not yet cached - stream, err := cli.Method2(ctx) - assert.NoError(t, err) - out, err := stream.CloseAndRecv() - assert.NoError(t, err) - assert.True(t, Equal(out, &Out{Out: 0})) - } - - { // store value in the cache - out, err := cli.Method1(ctx, in(1)) - assert.NoError(t, err) - assert.True(t, Equal(out, &Out{Out: 123})) - } - - { // expected value in the cache - stream, err := cli.Method2(ctx) - assert.NoError(t, err) - out, err := stream.CloseAndRecv() - assert.NoError(t, err) - assert.True(t, Equal(out, &Out{Out: 42})) - } -} From bda48aace03c836188b3c525ac17132a2b1d76c2 Mon Sep 17 00:00:00 2001 From: Shubham Dhama Date: Mon, 6 Apr 2026 08:55:16 +0530 Subject: [PATCH 03/15] drpc: enable stream multiplexing Enable multiple concurrent streams over a single transport. This is the foundational change that replaces the single-stream-at-a-time model with true multiplexing, allowing clients and servers to run multiple RPCs concurrently on a shared connection. Background: Previously, the Manager enforced single-stream semantics: a semaphore (sem) allowed only one active stream, and each new stream had to wait for the previous one to finish (waitForPreviousStream). Stream IDs were required to be monotonically increasing (checkStreamMonotonicity), and a single PacketAssembler was shared across all invoke sequences. This was simple and correct for one-at-a-time RPCs but incompatible with multiplexing. Structural changes: Manager: - Remove the semaphore (sem) and acquireSemaphore/waitForPreviousStream. Multiple streams can now be created concurrently without blocking on each other. - Remove checkStreamMonotonicity. With multiplexing, frames from different streams arrive interleaved; monotonicity is not meaningful. - Remove lastFrameID/lastFrameKind tracking fields (only used by the monotonicity check). - Replace the single shared PacketAssembler with a per-stream invokesAssembler map (map[uint64]*invokeAssembler). Each stream's invoke/metadata frame sequence is assembled independently. - Remove SoftCancel option (see error semantics below). - Remove GetLatest from streamRegistry; manageReader now dispatches frames by looking up the stream ID in the registry directly. Server: - ServeOne now spawns a goroutine per incoming RPC via sync.WaitGroup, rather than handling RPCs sequentially. Errors from individual RPC handlers are logged (via opts.Log) rather than terminating the connection. Stream: - NewWithOptions no longer calls wr.Reset() on the shared Writer. With multiplexing, multiple streams share the same Writer; resetting it would discard buffered frames from other streams. - SendCancel no longer returns a (busy, error) tuple. It blocks on the stream's write lock instead of returning busy=true when another write is in progress. This guarantees the cancel frame is sent (or fails with an IO error), at the cost of waiting for any in-progress write to finish. A future writer queue will eliminate this blocking. Error and cancellation semantics: The central design principle is that the manageReader goroutine is the single authority on transport health. It is the only goroutine that reads from the transport, and if the transport fails, it will detect the failure and terminate the connection. Write-side errors propagate to the caller but do not directly terminate the connection; the reader will independently detect the broken transport (since an IO write failure implies the transport is broken, and the next read will also fail). This matches gRPC's approach: when loopyWriter encounters an IO error, it does not close the connection. Instead, it relies on the reader to detect the failure and clean up. Error classification: Connection-scoped (terminates all streams): - Transport read error: manageReader fails to read a frame. - Frame assembly error: corrupted wire data that cannot be parsed. - Protocol error: e.g., receiving an invoke on an existing stream, or an unknown non-control packet kind. - Manager.Close(): explicit shutdown by the application. Stream-scoped (only affects that stream): - Application error: the RPC handler returns an error, which is sent via SendError (KindError) and terminates only that stream. - Remote close: receiving KindClose or KindCloseSend terminates or half-closes only that stream. - Remote cancel: receiving KindCancel terminates only that stream. - Remote error: receiving KindError terminates only that stream. - Write error (MsgSend, SendError, CloseSend, SendCancel): the error propagates to the caller. The stream is terminated locally. The manageReader goroutine will detect the transport failure on its next read and terminate the connection. Context cancellation: When a stream's context is cancelled, manageStream: 1. Attempts to send a KindCancel frame (SendCancel). This blocks until any in-progress write on that stream completes, then sends the cancel. If the send fails (IO error), the error is logged. The reader will catch the transport failure. 2. Cancels the stream locally (stream.Cancel), which terminates the stream and causes any blocked Send/Recv to return the context error. 3. Waits for the stream to finish (stream.Finished). The SoftCancel option is removed. Previously, SoftCancel=false would terminate the entire connection when a stream's context was cancelled (calling m.terminate if the stream wasn't already finished). With multiplexing, cancelling one stream must never kill the connection. SoftCancel=true behavior (send cancel, then cancel locally) is now the only behavior, simplified to always block for the write lock rather than returning "busy" and falling back to a hard cancel. Manager termination: When the manager terminates (from any connection-scoped error), it closes the transport and the stream registry. Each active stream's manageStream goroutine detects termination via m.sigs.term, cancels its stream, and waits for it to finish. Manager.Close() then waits for all stream goroutines (m.wg.Wait) and the reader goroutine before returning. Known limitations: - The shared drpcwire.Writer is protected by a mutex. All streams serialize their writes through this single writer. - SendCancel blocks on the stream's write lock. If a stream has a large in-progress write, the cancel is delayed. - packetBuffer is single-slot (Put blocks until Get+Done). A slow consumer stream blocks manageReader, stalling frame delivery to all streams. This needs to be addressed with per-stream buffering or async delivery. - Conn.Invoke() holds a mutex for the entire unary RPC duration, serializing concurrent unary RPCs. Streaming RPCs (NewStream) are not affected. --- drpcclient/dialoptions.go | 1 - drpcmanager/active_streams.go | 14 -- drpcmanager/manager.go | 193 +++++----------------------- drpcserver/server.go | 19 ++- drpcstream/stream.go | 14 +- internal/integration/common_test.go | 7 +- internal/integration/simple_test.go | 125 ++++++++++++++++++ 7 files changed, 178 insertions(+), 195 deletions(-) diff --git a/drpcclient/dialoptions.go b/drpcclient/dialoptions.go index 473f018..91b406d 100644 --- a/drpcclient/dialoptions.go +++ b/drpcclient/dialoptions.go @@ -129,7 +129,6 @@ func DialContext(ctx context.Context, address string, opts ...DialOption) (conn Stream: drpcstream.Options{ MaximumBufferSize: 0, // unlimited }, - SoftCancel: false, }, }), nil } diff --git a/drpcmanager/active_streams.go b/drpcmanager/active_streams.go index f422375..963f012 100644 --- a/drpcmanager/active_streams.go +++ b/drpcmanager/active_streams.go @@ -63,20 +63,6 @@ func (r *activeStreams) Get(id uint64) (*drpcstream.Stream, bool) { return s, ok } -// GetLatest returns the stream with the highest ID, or nil if empty. -func (r *activeStreams) GetLatest() *drpcstream.Stream { - r.mu.RLock() - defer r.mu.RUnlock() - - var latest *drpcstream.Stream - for _, s := range r.streams { - if latest == nil || latest.ID() < s.ID() { - latest = s - } - } - return latest -} - // Close marks the collection as closed, preventing future Add calls. // It does not cancel any streams. func (r *activeStreams) Close() { diff --git a/drpcmanager/manager.go b/drpcmanager/manager.go index 286f157..9421958 100644 --- a/drpcmanager/manager.go +++ b/drpcmanager/manager.go @@ -40,13 +40,6 @@ type Options struct { // Stream are passed to any streams the manager creates. Stream drpcstream.Options - // SoftCancel controls if a context cancel will cause the transport to be - // closed or, if true, a soft cancel message will be attempted if possible. - // A soft cancel can reduce the amount of closed and dialed connections at - // the potential cost of higher latencies if there is latent data still - // being flushed when the cancel happens. - SoftCancel bool - // Internal contains options that are for internal use only. Internal drpcopts.Manager @@ -66,9 +59,6 @@ type Manager struct { rd *drpcwire.Reader opts Options - lastFrameID drpcwire.ID - lastFrameKind drpcwire.Kind - // next client stream ID, incremented atomically lastStreamID atomic.Uint64 @@ -79,14 +69,12 @@ type Manager struct { // races with new stream's Add). streams *activeStreams - sem drpcsignal.Chan // held by the active stream - - pdone drpcsignal.Chan // signals when NewServerStream has added the new stream + pdone drpcsignal.Chan // signals when NewServerStream has registered the new stream invokes chan invokeInfo // completed invoke info from manageReader to NewServerStream - // Below fields are owned by the manageReader goroutine, used in handleInvokeFrame. - metadata map[string]string // accumulated invoke metadata - pa drpcwire.PacketAssembler // assembles invoke/metadata frames into packets + // invokesAssembler is owned by the manageReader goroutine, used in + // handleInvokeFrame. + invokesAssembler map[uint64]*invokeAssembler sigs struct { term drpcsignal.Signal // set when the manager should start terminating @@ -95,6 +83,11 @@ type Manager struct { } } +type invokeAssembler struct { + metadata map[string]string // accumulated invoke metadata + pa drpcwire.PacketAssembler // assembles invoke/metadata frames into packets +} + // invokeInfo carries the assembled invoke data from manageReader to // NewServerStream. It is reused across invocations; call Reset between uses. type invokeInfo struct { @@ -120,14 +113,12 @@ func NewWithOptions(tr drpc.Transport, opts Options) *Manager { invokes: make(chan invokeInfo), } - // this semaphore controls the number of concurrent streams. it MUST be 1. - m.sem.Make(1) - // a buffer of size 1 allows NewServerStream to signal it is done creating a // new server stream without having to coordinate with manageReader. m.pdone.Make(1) - m.pa = drpcwire.NewPacketAssembler() + m.invokesAssembler = make(map[uint64]*invokeAssembler) + m.streams = newActiveStreams() // set the internal stream options @@ -147,65 +138,6 @@ func (m *Manager) log(what string, cb func() string) { } } -// -// helpers -// - -// acquireSemaphore attempts to acquire the semaphore protecting streams. If the -// context is canceled or the manager is terminated, it returns an error. -func (m *Manager) acquireSemaphore(ctx context.Context) error { - if err, ok := m.sigs.term.Get(); ok { - return err - } else if err := ctx.Err(); err != nil { - return err - } - - select { - case <-ctx.Done(): - return ctx.Err() - - case <-m.sigs.term.Signal(): - return m.sigs.term.Err() - - case m.sem.Get() <- struct{}{}: - if err := m.waitForPreviousStream(ctx); err != nil { - m.sem.Recv() - return err - } - return nil - } -} - -// waitForPreviousStream will, if there was a previous stream, ensure it is -// Closed and then wait until it is in the Finished state, where it will no -// longer make any reads or writes on the transport. It exits early if the -// context is canceled or the manager is terminated. -func (m *Manager) waitForPreviousStream(ctx context.Context) (err error) { - prev := m.streams.GetLatest() - if prev == nil { - return nil - } - - // if the stream is not finished yet, we need to wait for it to be - // finished before letting the next stream to start. - if prev.IsFinished() { - return nil - } - - m.log("WAIT", prev.String) - - select { - case <-ctx.Done(): - return ctx.Err() - - case <-m.sigs.term.Signal(): - return m.sigs.term.Err() - - case <-prev.Finished(): - return nil - } -} - // terminate puts the Manager into a terminal state and closes any resources // that need to be closed to signal the state change. func (m *Manager) terminate(err error) { @@ -239,54 +171,39 @@ func (m *Manager) manageReader() { m.log("READ", incomingFrame.String) - if ok := m.checkStreamMonotonicity(incomingFrame); !ok { - m.terminate(managerClosed.Wrap(drpc.ProtocolError.New("id monotonicity violation"))) - return - } + stream, ok := m.streams.Get(incomingFrame.ID.Stream) - switch curr := m.streams.GetLatest(); { - // If the frame is for the current stream, deliver it. - case curr != nil && incomingFrame.ID.Stream == curr.ID(): - if err := curr.HandleFrame(incomingFrame); err != nil { + switch { + // if the packet is for an active stream, deliver it. + case ok && stream != nil: + if err := stream.HandleFrame(incomingFrame); err != nil { m.terminate(managerClosed.Wrap(err)) return } - // If a frame arrives for an old stream, just ignore it. - case curr != nil && incomingFrame.ID.Stream < curr.ID(): - - // If an invoke sequence is being sent for a new stream, close any - // old unterminated stream and forward it to be handled. case incomingFrame.Kind == drpcwire.KindInvoke || incomingFrame.Kind == drpcwire.KindInvokeMetadata: - if curr != nil && !curr.IsTerminated() { - curr.Cancel(context.Canceled) - } if err := m.handleInvokeFrame(incomingFrame); err != nil { m.terminate(managerClosed.Wrap(err)) return } + // silently drop packet for an unknown stream default: m.log("DROP", incomingFrame.String) } } } -func (m *Manager) checkStreamMonotonicity(incomingFrame drpcwire.Frame) bool { - ok := incomingFrame.ID.Stream >= m.lastFrameID.Stream - m.lastFrameKind = incomingFrame.Kind - m.lastFrameID = incomingFrame.ID - if incomingFrame.Done { - m.lastFrameID.Message += 1 - } - return ok -} - // handleInvokeFrame assembles invoke/metadata frames into complete packets and // forwards the finished invoke info to NewServerStream via m.newServerStreamInfo. // Metadata packets are accumulated; the invoke packet triggers the send. func (m *Manager) handleInvokeFrame(fr drpcwire.Frame) error { - pkt, packetReady, err := m.pa.AppendFrame(fr) + ia, ok := m.invokesAssembler[fr.ID.Stream] + if !ok { + ia = &invokeAssembler{pa: drpcwire.NewPacketAssembler()} + m.invokesAssembler[fr.ID.Stream] = ia + } + pkt, packetReady, err := ia.pa.AppendFrame(fr) if err != nil { return err } @@ -300,20 +217,19 @@ func (m *Manager) handleInvokeFrame(fr drpcwire.Frame) error { if err != nil { return err } - m.metadata = meta + ia.metadata = meta return nil } // Invoke packet completes the sequence. Send to NewServerStream. select { - case m.invokes <- invokeInfo{sid: pkt.ID.Stream, data: pkt.Data, metadata: m.metadata}: + case m.invokes <- invokeInfo{sid: pkt.ID.Stream, data: pkt.Data, metadata: ia.metadata}: // Wait for NewServerStream to finish stream creation before reading the // next frame. This guarantees curr is set for subsequent non-invoke // packets. m.pdone.Recv() - - m.pa.Reset() - m.metadata = nil + // TODO: reuse invoke assembler + delete(m.invokesAssembler, fr.ID.Stream) case <-m.sigs.term.Signal(): } return nil @@ -362,47 +278,18 @@ func (m *Manager) manageStream(ctx context.Context, stream *drpcstream.Stream) { } stream.Cancel(err) <-stream.Finished() - m.sem.Recv() case <-stream.Finished(): - m.sem.Recv() case <-ctx.Done(): m.log("CANCEL", stream.String) - if m.opts.SoftCancel { - // allow a new stream to begin. - m.sem.Recv() - - // attempt to send the soft cancel. if it fails or if the stream is - // busy sending something else, then we have to hard cancel. - if busy, err := stream.SendCancel(ctx.Err()); err != nil { - m.terminate(err) - } else if busy { - m.log("BUSY", stream.String) - m.terminate(ctx.Err()) - } - stream.Cancel(ctx.Err()) - - // wait for the stream to signal that it is finished. - <-stream.Finished() - } else { - // If the stream isn't already finished, we have to terminate the - // transport to do an active cancel. If it is already finished, - // there is no need. - if !stream.Cancel(ctx.Err()) { - m.log("UNFIN", stream.String) - m.terminate(ctx.Err()) - } else { - m.log("CLEAN", stream.String) - } - - // wait for the stream to signal that it is finished. - <-stream.Finished() - - // allow a new stream to begin. - m.sem.Recv() + if err := stream.SendCancel(ctx.Err()); err != nil { + // SendCancel can fail if it's an IO error which reader will catch. + m.log("SendCancel", func() string { return fmt.Sprintf("%s: %s", stream.String(), err) }) } + stream.Cancel(ctx.Err()) + <-stream.Finished() } } @@ -421,9 +308,6 @@ func (m *Manager) Closed() <-chan struct{} { // the return result is only valid until the next call to NewClientStream or // NewServerStream. func (m *Manager) Unblocked() <-chan struct{} { - if prev := m.streams.GetLatest(); prev != nil { - return prev.Context().Done() - } return closedCh } @@ -440,10 +324,6 @@ func (m *Manager) Close() error { // NewClientStream starts a stream on the managed transport for use by a client. func (m *Manager) NewClientStream(ctx context.Context, rpc string) (stream *drpcstream.Stream, err error) { - if err := m.acquireSemaphore(ctx); err != nil { - return nil, err - } - return m.newStream(ctx, m.lastStreamID.Add(1), drpc.StreamKindClient, rpc) } @@ -451,15 +331,6 @@ func (m *Manager) NewClientStream(ctx context.Context, rpc string) (stream *drpc // It does this by waiting for the client to issue an invoke message and // returning the details. func (m *Manager) NewServerStream(ctx context.Context) (stream *drpcstream.Stream, rpc string, err error) { - if err := m.acquireSemaphore(ctx); err != nil { - return nil, "", err - } - defer func() { - if err != nil { - m.sem.Recv() - } - }() - select { case <-ctx.Done(): return nil, "", ctx.Err() diff --git a/drpcserver/server.go b/drpcserver/server.go index c11c1c6..73a2bfb 100644 --- a/drpcserver/server.go +++ b/drpcserver/server.go @@ -140,16 +140,27 @@ func (s *Server) ServeOne(ctx context.Context, tr drpc.Transport) (err error) { } man := drpcmanager.NewWithOptions(tr, s.opts.Manager) - defer func() { err = errs.Combine(err, man.Close()) }() + var wg sync.WaitGroup + defer func() { + wg.Wait() + err = errs.Combine(err, man.Close()) + }() for { stream, rpc, err := man.NewServerStream(ctx) if err != nil { return errs.Wrap(err) } - if err := s.handleRPC(stream, rpc); err != nil { - return errs.Wrap(err) - } + // TODO: add worker pool + wg.Add(1) + go func() { + defer wg.Done() + if err := s.handleRPC(stream, rpc); err != nil { + if s.opts.Log != nil { + s.opts.Log(err) + } + } + }() } } diff --git a/drpcstream/stream.go b/drpcstream/stream.go index 5f099b6..e9448d2 100644 --- a/drpcstream/stream.go +++ b/drpcstream/stream.go @@ -104,7 +104,8 @@ func NewWithOptions(ctx context.Context, sid uint64, wr *drpcwire.Writer, opts O pa: pa, id: drpcwire.ID{Stream: sid}, - wr: wr.Reset(), + // TODO: think more deeply on the consequences here. + wr: wr, } // initialize the packet buffer @@ -567,18 +568,13 @@ func (s *Stream) SendError(serr error) (err error) { // context.Canceled and sends a cancel error to the remote side for a soft // cancel. It is a no-op if the stream is already terminated. It returns true // for busy if writes are already blocked and a hard cancel is required. -func (s *Stream) SendCancel(err error) (busy bool, _ error) { +func (s *Stream) SendCancel(err error) error { s.log("CALL", func() string { return "SendCancel()" }) s.mu.Lock() - if !s.write.Unlocked() { // if writes are happening, then we have to do a hard cancel. - s.mu.Unlock() - return true, nil - } - if s.sigs.term.IsSet() { s.mu.Unlock() - return false, nil + return nil } defer s.checkFinished() @@ -589,7 +585,7 @@ func (s *Stream) SendCancel(err error) (busy bool, _ error) { s.terminate(err) s.mu.Unlock() - return false, s.checkCancelError(s.sendPacketLocked(drpcwire.KindCancel, true, nil)) + return s.checkCancelError(s.sendPacketLocked(drpcwire.KindCancel, true, nil)) } // Close terminates the stream and sends that the stream has been closed to the diff --git a/internal/integration/common_test.go b/internal/integration/common_test.go index 5acb61a..1fd555e 100644 --- a/internal/integration/common_test.go +++ b/internal/integration/common_test.go @@ -16,7 +16,6 @@ import ( "google.golang.org/grpc/status" "storj.io/drpc/drpcconn" - "storj.io/drpc/drpcmanager" "storj.io/drpc/drpcmetadata" "storj.io/drpc/drpcmux" "storj.io/drpc/drpcserver" @@ -45,11 +44,7 @@ func createRawConnection(t testing.TB, server DRPCServiceServer, ctx *drpctest.T assert.NoError(t, DRPCRegisterService(mux, server)) srv := drpcserver.New(mux) ctx.Run(func(ctx context.Context) { _ = srv.ServeOne(ctx, c1) }) - return drpcconn.NewWithOptions(c2, drpcconn.Options{ - Manager: drpcmanager.Options{ - SoftCancel: true, - }, - }) + return drpcconn.NewWithOptions(c2, drpcconn.Options{}) } func createConnection(t testing.TB, server DRPCServiceServer) (DRPCServiceClient, func()) { diff --git a/internal/integration/simple_test.go b/internal/integration/simple_test.go index 26a2e8c..b0c2c24 100644 --- a/internal/integration/simple_test.go +++ b/internal/integration/simple_test.go @@ -85,6 +85,131 @@ func TestSimple(t *testing.T) { } } +func TestMultiplexedStreams(t *testing.T) { + ctx := drpctest.NewTracker(t) + defer ctx.Close() + + // Echo server: sends back each received message immediately. + echoServer := impl{ + Method1Fn: standardImpl.Method1Fn, + Method2Fn: standardImpl.Method2Fn, + Method3Fn: standardImpl.Method3Fn, + Method4Fn: func(stream DRPCService_Method4Stream) error { + for { + msg, err := stream.Recv() + if err != nil { + return nil + } + if err := stream.Send(&Out{Out: msg.In}); err != nil { + return err + } + } + }, + } + + cli, close := createConnection(t, echoServer) + defer close() + + // Open two bidi streams on the same connection. + s1, err := cli.Method4(ctx) + assert.NoError(t, err) + + s2, err := cli.Method4(ctx) + assert.NoError(t, err) + + // Send on both streams interleaved. + assert.NoError(t, s1.Send(&In{In: 1})) + assert.NoError(t, s2.Send(&In{In: 2})) + + // Receive from both: each stream gets its own response. + out1, err := s1.Recv() + assert.NoError(t, err) + assert.Equal(t, out1.Out, int64(1)) + + out2, err := s2.Recv() + assert.NoError(t, err) + assert.Equal(t, out2.Out, int64(2)) + + // Close both streams. + assert.NoError(t, s1.CloseSend()) + assert.NoError(t, s2.CloseSend()) + + _, err = s1.Recv() + assert.That(t, errors.Is(err, io.EOF)) + + _, err = s2.Recv() + assert.That(t, errors.Is(err, io.EOF)) +} + +func TestConcurrentStreams(t *testing.T) { + ctx := drpctest.NewTracker(t) + defer ctx.Close() + + echoServer := impl{ + Method1Fn: standardImpl.Method1Fn, + Method2Fn: standardImpl.Method2Fn, + Method3Fn: standardImpl.Method3Fn, + Method4Fn: func(stream DRPCService_Method4Stream) error { + for { + msg, err := stream.Recv() + if err != nil { + return nil + } + if err := stream.Send(&Out{Out: msg.In}); err != nil { + return err + } + } + }, + } + + cli, close := createConnection(t, echoServer) + defer close() + + const numStreams = 10 + const numMessages = 20 + + errs := make(chan error, numStreams) + for i := 0; i < numStreams; i++ { + i := i + ctx.Run(func(ctx context.Context) { + select { + case <-ctx.Done(): + case errs <- func() error { + stream, err := cli.Method4(ctx) + if err != nil { + return fmt.Errorf("stream %d: open: %w", i, err) + } + for j := 0; j < numMessages; j++ { + val := int64(i*1000 + j) + if err := stream.Send(&In{In: val}); err != nil { + return fmt.Errorf("stream %d: send %d: %w", i, j, err) + } + out, err := stream.Recv() + if err != nil { + return fmt.Errorf("stream %d: recv %d: %w", i, j, err) + } + if out.Out != val { + return fmt.Errorf("stream %d: msg %d: got %d, want %d", i, j, out.Out, val) + } + } + if err := stream.CloseSend(); err != nil { + return fmt.Errorf("stream %d: close send: %w", i, err) + } + _, err = stream.Recv() + if !errors.Is(err, io.EOF) { + return fmt.Errorf("stream %d: final recv: got %v, want EOF", i, err) + } + return nil + }(): + } + }) + } + + for i := 0; i < numStreams; i++ { + assert.NoError(t, <-errs) + } +} + func TestConcurrent(t *testing.T) { ctx := drpctest.NewTracker(t) defer ctx.Close() From a0e475d7a59d7d84e0f940ac709884b10e754dfc Mon Sep 17 00:00:00 2001 From: Shubham Dhama Date: Mon, 6 Apr 2026 08:53:46 +0530 Subject: [PATCH 04/15] drpcstream: make invoke sequence atomic under write lock Add Stream.WriteInvoke that writes InvokeMetadata and Invoke frames under a single write lock acquisition. This prevents SendCancel from slipping in between the two frames when a context is cancelled during stream setup. Without this, the following race is possible: 1. Client creates stream, starts manageStream goroutine. 2. doNewStream sends InvokeMetadata, releases write lock. 3. Context cancels. manageStream calls SendCancel, acquires write lock, sends KindCancel, terminates the stream. 4. doNewStream tries to send Invoke, sees stream terminated, returns error. The server receives InvokeMetadata then Cancel, but never the Invoke. It has no registered stream to cancel, so the Cancel is dropped and the partial invokeAssembler entry leaks until the connection closes. With WriteInvoke, SendCancel blocks until both frames are written. The server always sees a complete invoke before any cancel. --- drpcconn/conn.go | 21 ++------------------- drpcstream/stream.go | 16 ++++++++++++++++ 2 files changed, 18 insertions(+), 19 deletions(-) diff --git a/drpcconn/conn.go b/drpcconn/conn.go index 636f346..cbde52e 100644 --- a/drpcconn/conn.go +++ b/drpcconn/conn.go @@ -135,12 +135,7 @@ func (c *Conn) Invoke(ctx context.Context, rpc string, enc drpc.Encoding, in, ou } func (c *Conn) doInvoke(stream *drpcstream.Stream, enc drpc.Encoding, rpc string, data []byte, metadata []byte, out drpc.Message) (err error) { - if len(metadata) > 0 { - if err := stream.RawWrite(drpcwire.KindInvokeMetadata, metadata); err != nil { - return err - } - } - if err := stream.RawWrite(drpcwire.KindInvoke, []byte(rpc)); err != nil { + if err := stream.WriteInvoke(rpc, metadata); err != nil { return err } if err := stream.RawWrite(drpcwire.KindMessage, data); err != nil { @@ -171,25 +166,13 @@ func (c *Conn) NewStream(ctx context.Context, rpc string, enc drpc.Encoding) (_ return nil, err } - if err := c.doNewStream(stream, rpc, metadata); err != nil { + if err := stream.WriteInvoke(rpc, metadata); err != nil { return nil, errs.Combine(err, stream.Close()) } return stream, nil } -func (c *Conn) doNewStream(stream *drpcstream.Stream, rpc string, metadata []byte) error { - if len(metadata) > 0 { - if err := stream.RawWrite(drpcwire.KindInvokeMetadata, metadata); err != nil { - return err - } - } - if err := stream.RawWrite(drpcwire.KindInvoke, []byte(rpc)); err != nil { - return err - } - return nil -} - // encodeMetadata retrieves and encodes metadata from the provided // (outgoing/client) context. func (c *Conn) encodeMetadata(ctx context.Context) (metadata []byte, err error) { diff --git a/drpcstream/stream.go b/drpcstream/stream.go index e9448d2..4dd50ee 100644 --- a/drpcstream/stream.go +++ b/drpcstream/stream.go @@ -371,6 +371,22 @@ func (s *Stream) terminate(err error) { s.checkFinished() } +// WriteInvoke writes the invoke metadata (if any) and invoke frame +// atomically under the write lock. This prevents SendCancel from +// interrupting the invoke sequence. +func (s *Stream) WriteInvoke(rpc string, metadata []byte) error { + defer s.checkFinished() + s.write.Lock() + defer s.write.Unlock() + + if len(metadata) > 0 { + if err := s.rawWriteLocked(drpcwire.KindInvokeMetadata, metadata); err != nil { + return err + } + } + return s.rawWriteLocked(drpcwire.KindInvoke, []byte(rpc)) +} + // // raw read/write // From ef11bd14d7c0d2beef1d67b089f0933ce5e0e11b Mon Sep 17 00:00:00 2001 From: Shubham Dhama Date: Tue, 7 Apr 2026 21:43:46 +0530 Subject: [PATCH 05/15] drpcmanager: move streams cancellation on termination to activeStreams.Close Previously, when a manager terminated, each stream's manageStream goroutine independently detected m.sigs.term.Signal and cancelled its own stream. This was asynchronous. Instead, cancel all streams directly in activeStreams.Close(), called from Manager.terminate. This makes termination synchronous and immediate, and removes the term.Signal case from manageStream. --- drpc.go | 8 -------- drpcconn/conn.go | 2 +- drpcmanager/active_streams.go | 22 +++++++++------------ drpcmanager/manager.go | 36 ++++++++++++++++++++--------------- drpcserver/server.go | 2 +- 5 files changed, 32 insertions(+), 38 deletions(-) diff --git a/drpc.go b/drpc.go index f6f9208..ed03703 100644 --- a/drpc.go +++ b/drpc.go @@ -76,14 +76,6 @@ type Stream interface { // received on it. Context() context.Context - // Kind returns the type of the stream ("unknown", "cli", or "srv"). Client - // and server streams must be treated differently for error handling and - // logging purposes. - // - // Client streams return Unavailable errors when the remote closes the - // connection, while server streams return Canceled errors. - Kind() StreamKind - // MsgSend sends the Message to the remote. MsgSend(msg Message, enc Encoding) error diff --git a/drpcconn/conn.go b/drpcconn/conn.go index cbde52e..eadaa52 100644 --- a/drpcconn/conn.go +++ b/drpcconn/conn.go @@ -56,7 +56,7 @@ func NewWithOptions(tr drpc.Transport, opts Options) *Conn { c.stats = make(map[string]*drpcstats.Stats) } - c.man = drpcmanager.NewWithOptions(tr, opts.Manager) + c.man = drpcmanager.NewWithOptions(tr, drpcmanager.Client, opts.Manager) return c } diff --git a/drpcmanager/active_streams.go b/drpcmanager/active_streams.go index 963f012..86f5548 100644 --- a/drpcmanager/active_streams.go +++ b/drpcmanager/active_streams.go @@ -59,27 +59,23 @@ func (r *activeStreams) Get(id uint64) (*drpcstream.Stream, bool) { r.mu.RLock() defer r.mu.RUnlock() + if r.closed { + return nil, false + } s, ok := r.streams[id] return s, ok } -// Close marks the collection as closed, preventing future Add calls. -// It does not cancel any streams. -func (r *activeStreams) Close() { +// Close cancels all active streams with the given error, clears the +// collection, and marks it as closed to prevent future Add calls. +func (r *activeStreams) Close(err error) { r.mu.Lock() defer r.mu.Unlock() r.closed = true -} - -// ForEach calls fn for each active stream. The collection is read-locked -// during iteration. -func (r *activeStreams) ForEach(fn func(*drpcstream.Stream)) { - r.mu.RLock() - defer r.mu.RUnlock() - - for _, s := range r.streams { - fn(s) + for id, s := range r.streams { + s.Cancel(err) + delete(r.streams, id) } } diff --git a/drpcmanager/manager.go b/drpcmanager/manager.go index 9421958..d804a15 100644 --- a/drpcmanager/manager.go +++ b/drpcmanager/manager.go @@ -81,8 +81,18 @@ type Manager struct { read drpcsignal.Signal // set after the goroutine reading from the transport is done tport drpcsignal.Signal // set after the transport has been closed } + + kind ManagerKind } +type ManagerKind uint8 + +const ( + _ ManagerKind = iota + Client + Server +) + type invokeAssembler struct { metadata map[string]string // accumulated invoke metadata pa drpcwire.PacketAssembler // assembles invoke/metadata frames into packets @@ -97,13 +107,13 @@ type invokeInfo struct { } // New returns a new Manager for the transport. -func New(tr drpc.Transport) *Manager { - return NewWithOptions(tr, Options{}) +func New(tr drpc.Transport, kind ManagerKind) *Manager { + return NewWithOptions(tr, kind, Options{}) } // NewWithOptions returns a new manager for the transport. It uses the provided // options to manage details of how it uses it. -func NewWithOptions(tr drpc.Transport, opts Options) *Manager { +func NewWithOptions(tr drpc.Transport, kind ManagerKind, opts Options) *Manager { m := &Manager{ tr: tr, wr: drpcwire.NewWriter(tr, opts.WriterBufferSize), @@ -111,6 +121,7 @@ func NewWithOptions(tr drpc.Transport, opts Options) *Manager { opts: opts, invokes: make(chan invokeInfo), + kind: kind, } // a buffer of size 1 allows NewServerStream to signal it is done creating a @@ -144,7 +155,13 @@ func (m *Manager) terminate(err error) { if m.sigs.term.Set(err) { m.log("TERM", func() string { return fmt.Sprint(err) }) m.sigs.tport.Set(m.tr.Close()) - m.streams.Close() + if errors.Is(err, io.EOF) { + err = context.Canceled + if m.kind == Client { + err = drpc.ClosedError.New("connection closed") + } + } + m.streams.Close(err) } } @@ -268,17 +285,6 @@ func (m *Manager) manageStream(ctx context.Context, stream *drpcstream.Stream) { defer m.wg.Done() defer m.streams.Remove(stream.ID()) select { - case <-m.sigs.term.Signal(): - err := m.sigs.term.Err() - if errors.Is(err, io.EOF) { - err = context.Canceled - if stream.Kind() == drpc.StreamKindClient { - err = drpc.ClosedError.New("connection closed") - } - } - stream.Cancel(err) - <-stream.Finished() - case <-stream.Finished(): case <-ctx.Done(): diff --git a/drpcserver/server.go b/drpcserver/server.go index 73a2bfb..1849e80 100644 --- a/drpcserver/server.go +++ b/drpcserver/server.go @@ -139,7 +139,7 @@ func (s *Server) ServeOne(ctx context.Context, tr drpc.Transport) (err error) { } } - man := drpcmanager.NewWithOptions(tr, s.opts.Manager) + man := drpcmanager.NewWithOptions(tr, drpcmanager.Server, s.opts.Manager) var wg sync.WaitGroup defer func() { wg.Wait() From 959897e8e5787b5bcf04c18b952d2708f18201bf Mon Sep 17 00:00:00 2001 From: Shubham Dhama Date: Wed, 8 Apr 2026 12:55:54 +0530 Subject: [PATCH 06/15] drpcstream: document two-phase shutdown and stream synchronization --- drpcstream/stream.go | 20 +++++++++++++++++--- 1 file changed, 17 insertions(+), 3 deletions(-) diff --git a/drpcstream/stream.go b/drpcstream/stream.go index 4dd50ee..8eef1da 100644 --- a/drpcstream/stream.go +++ b/drpcstream/stream.go @@ -47,6 +47,11 @@ type Stream struct { opts Options task *trace.Task + // write and read serialize operations within a stream. The data path + // (MsgSend/MsgRecv) and the control path (SendCancel/Close/SendError) + // genuinely race because cancellation arrives from manageStream while the + // application may be mid-send. These are inspectMutex (not sync.Mutex) so + // that checkFinished can test whether ops are in flight without blocking. write inspectMutex read inspectMutex flush sync.Once @@ -62,6 +67,13 @@ type Stream struct { sigs struct { send drpcsignal.Signal // set when done sending messages recv drpcsignal.Signal // set when done receiving messages + // Stream shutdown is two-phase: term then fin. When termination arrives + // (remote error, local cancel, close), there may be an in-flight write + // on the transport that is past the term check and inside WriteFrame. + // term tells new operations to bail out; fin signals that all in-flight + // operations have actually completed. Consumers (manageStream) wait on + // fin before cleaning up, guaranteeing no goroutine is touching the + // stream afterward. term drpcsignal.Signal // set when the stream is terminating and no new ops should begin fin drpcsignal.Signal // set when the stream is finished and all ops are complete cancel drpcsignal.Signal // set when externally canceled @@ -300,9 +312,11 @@ func (s *Stream) handlePacket(pkt drpcwire.Packet) (err error) { // helpers // -// checkFinished checks to see if the stream is terminated, and if so, sets the -// finished flag. This must be called after every read or write is complete, as -// well as when the stream becomes terminated. +// checkFinished bridges the two-phase shutdown. It is called in two places: +// inside terminate() for when no I/O is in flight (fin fires immediately), +// and deferred after every read/write unlock for when an operation was in +// flight at termination time (fin fires once the last operation completes). +// Whichever call site runs last sees term set and both locks free, and sets fin. func (s *Stream) checkFinished() { if s.sigs.term.IsSet() && s.write.Unlocked() && s.read.Unlocked() { if s.sigs.fin.Set(nil) { From 4ffd1de2ba57d88e6e3a60ff172fbb6b78304d7e Mon Sep 17 00:00:00 2001 From: Shubham Dhama Date: Wed, 8 Apr 2026 14:21:44 +0530 Subject: [PATCH 07/15] drpcconn: remove stats and shared write buffer With multiplexing, multiple Invoke calls run concurrently. The shared c.wbuf buffer and its protecting mutex are replaced with per-call allocation. Stats collection is removed (unused by cockroach). --- drpcconn/conn.go | 65 ++++------------------------- drpcmanager/manager.go | 6 +-- internal/integration/simple_test.go | 49 ---------------------- 3 files changed, 11 insertions(+), 109 deletions(-) diff --git a/drpcconn/conn.go b/drpcconn/conn.go index eadaa52..b8bbf49 100644 --- a/drpcconn/conn.go +++ b/drpcconn/conn.go @@ -5,7 +5,6 @@ package drpcconn import ( "context" - "sync" "github.com/zeebo/errs" grpcmetadata "google.golang.org/grpc/metadata" @@ -13,30 +12,20 @@ import ( "storj.io/drpc/drpcenc" "storj.io/drpc/drpcmanager" "storj.io/drpc/drpcmetadata" - "storj.io/drpc/drpcstats" "storj.io/drpc/drpcstream" "storj.io/drpc/drpcwire" - "storj.io/drpc/internal/drpcopts" ) // Options controls configuration settings for a conn. type Options struct { // Manager controls the options we pass to the manager of this conn. Manager drpcmanager.Options - - // CollectStats controls whether the server should collect stats on the - // rpcs it creates. - CollectStats bool } // Conn is a drpc client connection. type Conn struct { - tr drpc.Transport - man *drpcmanager.Manager - mu sync.Mutex - wbuf []byte - - stats map[string]*drpcstats.Stats + tr drpc.Transport + man *drpcmanager.Manager } var _ drpc.Conn = (*Conn)(nil) @@ -51,41 +40,11 @@ func NewWithOptions(tr drpc.Transport, opts Options) *Conn { tr: tr, } - if opts.CollectStats { - drpcopts.SetManagerStatsCB(&opts.Manager.Internal, c.getStats) - c.stats = make(map[string]*drpcstats.Stats) - } - c.man = drpcmanager.NewWithOptions(tr, drpcmanager.Client, opts.Manager) return c } -// Stats returns the collected stats grouped by rpc. -func (c *Conn) Stats() map[string]drpcstats.Stats { - c.mu.Lock() - defer c.mu.Unlock() - - stats := make(map[string]drpcstats.Stats, len(c.stats)) - for k, v := range c.stats { - stats[k] = v.AtomicClone() - } - return stats -} - -// getStats returns the drpcopts.Stats struct for the given rpc. -func (c *Conn) getStats(rpc string) *drpcstats.Stats { - c.mu.Lock() - defer c.mu.Unlock() - - stats := c.stats[rpc] - if stats == nil { - stats = new(drpcstats.Stats) - c.stats[rpc] = stats - } - return stats -} - // Transport returns the transport the conn is using. func (c *Conn) Transport() drpc.Transport { return c.tr } @@ -93,15 +52,15 @@ func (c *Conn) Transport() drpc.Transport { return c.tr } func (c *Conn) Closed() <-chan struct{} { return c.man.Closed() } // Unblocked returns a channel that is closed once the connection is no longer -// blocked by a previously canceled Invoke or NewStream call. It should not -// be called concurrently with Invoke or NewStream. +// blocked. With multiplexing, multiple streams run concurrently and this +// channel is always closed immediately. func (c *Conn) Unblocked() <-chan struct{} { return c.man.Unblocked() } // Close closes the connection. func (c *Conn) Close() (err error) { return c.man.Close() } // Invoke issues the rpc on the transport serializing in, waits for a response, and -// deserializes it into out. Only one Invoke or Stream may be open at a time. +// deserializes it into out. func (c *Conn) Invoke(ctx context.Context, rpc string, enc drpc.Encoding, in, out drpc.Message) (err error) { defer func() { err = drpc.ToRPCErr(err) }() @@ -117,18 +76,13 @@ func (c *Conn) Invoke(ctx context.Context, rpc string, enc drpc.Encoding, in, ou } defer func() { err = errs.Combine(err, stream.Close()) }() - // we have to protect c.wbuf here even though the manager only allows one - // stream at a time because the stream may async close allowing another - // concurrent call to Invoke to proceed. - c.mu.Lock() - defer c.mu.Unlock() - - c.wbuf, err = drpcenc.MarshalAppend(in, enc, c.wbuf[:0]) + // TODO: use buffer pool to reduce allocations + data, err := drpcenc.MarshalAppend(in, enc, nil) if err != nil { return err } - if err := c.doInvoke(stream, enc, rpc, c.wbuf, metadata, out); err != nil { + if err := c.doInvoke(stream, enc, rpc, data, metadata, out); err != nil { return err } return nil @@ -150,8 +104,7 @@ func (c *Conn) doInvoke(stream *drpcstream.Stream, enc drpc.Encoding, rpc string return nil } -// NewStream begins a streaming rpc on the connection. Only one Invoke or Stream may -// be open at a time. +// NewStream begins a streaming rpc on the connection. func (c *Conn) NewStream(ctx context.Context, rpc string, enc drpc.Encoding) (_ drpc.Stream, err error) { defer func() { err = drpc.ToRPCErr(err) }() diff --git a/drpcmanager/manager.go b/drpcmanager/manager.go index d804a15..76f1078 100644 --- a/drpcmanager/manager.go +++ b/drpcmanager/manager.go @@ -309,10 +309,8 @@ func (m *Manager) Closed() <-chan struct{} { } // Unblocked returns a channel that is closed when the manager is no longer -// blocked from creating a new stream due to a previous stream's soft cancel. It -// should not be called concurrently with NewClientStream or NewServerStream and -// the return result is only valid until the next call to NewClientStream or -// NewServerStream. +// blocked. With multiplexing, multiple streams run concurrently and this +// channel is always closed immediately. func (m *Manager) Unblocked() <-chan struct{} { return closedCh } diff --git a/internal/integration/simple_test.go b/internal/integration/simple_test.go index b0c2c24..8dbf47a 100644 --- a/internal/integration/simple_test.go +++ b/internal/integration/simple_test.go @@ -289,52 +289,3 @@ func TestServerStats(t *testing.T) { "/service.Service/Method3": {Read: 2, Written: 6}, }) } - -func TestClientStats(t *testing.T) { - ctx := drpctest.NewTracker(t) - defer ctx.Close() - - c1, c2 := net.Pipe() - mux := drpcmux.New() - _ = DRPCRegisterService(mux, standardImpl) - - srv := drpcserver.New(mux) - ctx.Run(func(ctx context.Context) { _ = srv.ServeOne(ctx, c1) }) - - conn := drpcconn.NewWithOptions(c2, drpcconn.Options{ - CollectStats: true, - }) - defer func() { _ = conn.Close() }() - cli := NewDRPCServiceClient(conn) - - assert.Equal(t, srv.Stats(), map[string]drpcstats.Stats{}) - - _, err := cli.Method1(ctx, in(5)) - assert.Error(t, err) - - assert.Equal(t, conn.Stats(), map[string]drpcstats.Stats{ - "/service.Service/Method1": {Read: 9, Written: 26}, - }) - - _, err = cli.Method1(ctx, in(1)) - assert.NoError(t, err) - - assert.Equal(t, conn.Stats(), map[string]drpcstats.Stats{ - "/service.Service/Method1": {Read: 9 + 2, Written: 26 + 26}, - }) - - stream, err := cli.Method3(ctx, in(3)) - assert.NoError(t, err) - for i := 0; i < 3; i++ { - _, err := stream.Recv() - assert.NoError(t, err) - } - _, err = stream.Recv() - assert.That(t, errors.Is(err, io.EOF)) - assert.NoError(t, stream.Close()) - - assert.Equal(t, conn.Stats(), map[string]drpcstats.Stats{ - "/service.Service/Method1": {Read: 9 + 2, Written: 26 + 26}, - "/service.Service/Method3": {Read: 6, Written: 26}, - }) -} From 2a097d5d1f2ad34929affc8148178f965f68dd63 Mon Sep 17 00:00:00 2001 From: Shubham Dhama Date: Wed, 8 Apr 2026 15:21:18 +0530 Subject: [PATCH 08/15] drpc: update tests for multiplexing API changes Fix compilation errors and remove tests for removed features: - Update New/NewWithOptions calls to include ManagerKind argument - Update activeStreams.Close calls to include error argument - Remove tests for SoftCancel, cross-stream monotonicity, Unblocked semantics, SendCancel busy return, and ForEach (all removed) - Adapt OldStreamFramesIgnored to use stream.Finished instead of Unblocked --- drpcmanager/active_streams_test.go | 33 +---- drpcmanager/manager_test.go | 222 +++-------------------------- drpcmanager/random_test.go | 2 +- drpcstream/stream_test.go | 24 ---- 4 files changed, 26 insertions(+), 255 deletions(-) diff --git a/drpcmanager/active_streams_test.go b/drpcmanager/active_streams_test.go index cdf2ec9..c4159b6 100644 --- a/drpcmanager/active_streams_test.go +++ b/drpcmanager/active_streams_test.go @@ -5,6 +5,7 @@ package drpcmanager import ( "context" + "errors" "testing" "github.com/zeebo/assert" @@ -73,7 +74,7 @@ func TestActiveStreams_DuplicateAdd(t *testing.T) { func TestActiveStreams_AddAfterClose(t *testing.T) { streams := newActiveStreams() - streams.Close() + streams.Close(errors.New("closed")) err := streams.Add(1, testStream(1)) assert.Error(t, err) @@ -84,7 +85,7 @@ func TestActiveStreams_RemoveAfterClose(t *testing.T) { s := testStream(1) assert.NoError(t, streams.Add(1, s)) - streams.Close() + streams.Close(errors.New("closed")) // must not panic streams.Remove(1) @@ -104,31 +105,3 @@ func TestActiveStreams_Len(t *testing.T) { assert.Equal(t, streams.Len(), 1) } -func TestActiveStreams_ForEach(t *testing.T) { - streams := newActiveStreams() - s1 := testStream(1) - s2 := testStream(2) - s3 := testStream(3) - - assert.NoError(t, streams.Add(1, s1)) - assert.NoError(t, streams.Add(2, s2)) - assert.NoError(t, streams.Add(3, s3)) - - seen := make(map[uint64]*drpcstream.Stream) - streams.ForEach(func(s *drpcstream.Stream) { - seen[s.ID()] = s - }) - - assert.Equal(t, len(seen), 3) - assert.Equal(t, seen[1], s1) - assert.Equal(t, seen[2], s2) - assert.Equal(t, seen[3], s3) -} - -func TestActiveStreams_ForEach_Empty(t *testing.T) { - streams := newActiveStreams() - - count := 0 - streams.ForEach(func(_ *drpcstream.Stream) { count++ }) - assert.Equal(t, count, 0) -} diff --git a/drpcmanager/manager_test.go b/drpcmanager/manager_test.go index 95cc375..a98626d 100644 --- a/drpcmanager/manager_test.go +++ b/drpcmanager/manager_test.go @@ -8,7 +8,6 @@ import ( "errors" "io" "net" - "sync" "testing" "time" @@ -21,15 +20,6 @@ import ( "storj.io/drpc/drpcwire" ) -func closed(ch <-chan struct{}) bool { - select { - case <-ch: - return true - default: - return false - } -} - func TestDrpcMetadata(t *testing.T) { ctx := drpctest.NewTracker(t) defer ctx.Close() @@ -38,10 +28,10 @@ func TestDrpcMetadata(t *testing.T) { defer func() { _ = cconn.Close() }() defer func() { _ = sconn.Close() }() - cman := New(cconn) + cman := New(cconn, Client) defer func() { _ = cman.Close() }() - sman := NewWithOptions(sconn, Options{ + sman := NewWithOptions(sconn, Server, Options{ GRPCMetadataCompatMode: false, }) defer func() { _ = sman.Close() }() @@ -59,10 +49,7 @@ func TestDrpcMetadata(t *testing.T) { assert.NoError(t, stream.RawWrite(drpcwire.KindInvoke, []byte("invoke"))) assert.NoError(t, stream.RawWrite(drpcwire.KindMessage, []byte("message"))) assert.NoError(t, stream.RawFlush()) - assert.That(t, !closed(cman.Unblocked())) - assert.NoError(t, stream.Close()) - assert.That(t, closed(cman.Unblocked())) }) ctx.Run(func(ctx context.Context) { @@ -98,10 +85,10 @@ func TestDrpcMetadataWithGRPCMetadataCompatMode(t *testing.T) { defer func() { _ = cconn.Close() }() defer func() { _ = sconn.Close() }() - cman := New(cconn) + cman := New(cconn, Client) defer func() { _ = cman.Close() }() - sman := NewWithOptions(sconn, Options{ + sman := NewWithOptions(sconn, Server, Options{ GRPCMetadataCompatMode: true, }) defer func() { _ = sman.Close() }() @@ -119,10 +106,7 @@ func TestDrpcMetadataWithGRPCMetadataCompatMode(t *testing.T) { assert.NoError(t, stream.RawWrite(drpcwire.KindInvoke, []byte("invoke"))) assert.NoError(t, stream.RawWrite(drpcwire.KindMessage, []byte("message"))) assert.NoError(t, stream.RawFlush()) - assert.That(t, !closed(cman.Unblocked())) - assert.NoError(t, stream.Close()) - assert.That(t, closed(cman.Unblocked())) }) ctx.Run(func(ctx context.Context) { @@ -196,7 +180,7 @@ func TestManageReader_GlobalMonotonicity_SameStream(t *testing.T) { defer func() { _ = cconn.Close() }() defer func() { _ = sconn.Close() }() - man := New(sconn) + man := New(sconn, Server) defer func() { _ = man.Close() }() // Consume the invoke and drain messages so HandleFrame doesn't block. @@ -219,33 +203,6 @@ func TestManageReader_GlobalMonotonicity_SameStream(t *testing.T) { waitForClosed(t, man) } -// Cross-stream monotonicity: after seeing stream 2, a frame for stream 1 -// with a higher message ID is still rejected because {1,x} < {2,y}. -func TestManageReader_GlobalMonotonicity_CrossStream(t *testing.T) { - ctx := drpctest.NewTracker(t) - defer ctx.Close() - - cconn, sconn := net.Pipe() - defer func() { _ = cconn.Close() }() - defer func() { _ = sconn.Close() }() - - man := New(sconn) - defer func() { _ = man.Close() }() - - // Consume both invokes so manageReader can proceed. - ctx.Run(func(ctx context.Context) { - _, _, _ = man.NewServerStream(ctx) - _, _, _ = man.NewServerStream(ctx) - }) - - writeFrames(t, cconn, - createFrame(drpcwire.KindInvoke, 1, 1, "rpc1", true), - createFrame(drpcwire.KindInvoke, 2, 1, "rpc2", true), - createFrame(drpcwire.KindMessage, 1, 4, "bad", true), - ) - - waitForClosed(t, man) -} // Invoke replay: after [s1,m1,invoke,done=true], lastFrameID is bumped to // {1,2}. A replayed [s1,m1,invoke] is caught by the monotonicity check. @@ -257,7 +214,7 @@ func TestManageReader_InvokeReplayBlocked(t *testing.T) { defer func() { _ = cconn.Close() }() defer func() { _ = sconn.Close() }() - man := New(sconn) + man := New(sconn, Server) defer func() { _ = man.Close() }() ctx.Run(func(ctx context.Context) { @@ -282,7 +239,7 @@ func TestManageReader_ContinuationFramesAccepted(t *testing.T) { defer func() { _ = cconn.Close() }() defer func() { _ = sconn.Close() }() - man := New(sconn) + man := New(sconn, Server) defer func() { _ = man.Close() }() recv := make(chan []byte, 1) @@ -303,8 +260,8 @@ func TestManageReader_ContinuationFramesAccepted(t *testing.T) { assert.DeepEqual(t, <-recv, []byte("hello")) } -// Old-stream frames are silently ignored on the client side when the local -// stream ID has advanced past the incoming frame's stream ID. +// Old-stream frames are silently ignored when the stream has been cancelled +// and removed from the registry. func TestManageReader_OldStreamFramesIgnored(t *testing.T) { ctx := drpctest.NewTracker(t) defer ctx.Close() @@ -313,11 +270,10 @@ func TestManageReader_OldStreamFramesIgnored(t *testing.T) { defer func() { _ = cconn.Close() }() defer func() { _ = sconn.Close() }() - cman := NewWithOptions(cconn, Options{SoftCancel: true}) + cman := New(cconn, Client) defer func() { _ = cman.Close() }() - // Drain all client writes so nothing blocks, and write server - // responses once we've seen enough data. + // Drain all client writes so nothing blocks. ctx.Run(func(ctx context.Context) { buf := make([]byte, 4096) for { @@ -328,13 +284,13 @@ func TestManageReader_OldStreamFramesIgnored(t *testing.T) { } }) - // Create stream 1 on the client, then cancel it so the client - // advances to stream 2. + // Create stream 1 on the client, then cancel it so it's removed + // from the registry. subctx, cancel := context.WithCancel(ctx) - _, err := cman.NewClientStream(subctx, "rpc1") + stream1, err := cman.NewClientStream(subctx, "rpc1") assert.NoError(t, err) cancel() - <-cman.Unblocked() + <-stream1.Finished() stream2, err := cman.NewClientStream(ctx, "rpc2") assert.NoError(t, err) @@ -363,7 +319,7 @@ func TestManageReader_ValidInvokeSequence(t *testing.T) { defer func() { _ = cconn.Close() }() defer func() { _ = sconn.Close() }() - man := New(sconn) + man := New(sconn, Server) defer func() { _ = man.Close() }() recv := make(chan []byte, 1) @@ -395,7 +351,7 @@ func TestManageReader_MultiFrameDelivery(t *testing.T) { defer func() { _ = cconn.Close() }() defer func() { _ = sconn.Close() }() - man := New(sconn) + man := New(sconn, Server) defer func() { _ = man.Close() }() recv := make(chan []byte, 1) @@ -428,7 +384,7 @@ func TestManageReader_HigherMsgDiscardsInProgress(t *testing.T) { defer func() { _ = cconn.Close() }() defer func() { _ = sconn.Close() }() - man := New(sconn) + man := New(sconn, Server) defer func() { _ = man.Close() }() recv := make(chan []byte, 1) @@ -459,7 +415,7 @@ func TestManageReader_KindChangeWithinPacket(t *testing.T) { defer func() { _ = cconn.Close() }() defer func() { _ = sconn.Close() }() - man := New(sconn) + man := New(sconn, Server) defer func() { _ = man.Close() }() ctx.Run(func(ctx context.Context) { @@ -492,7 +448,7 @@ func TestManageReader_MultiFrameWithSkippedMessageID(t *testing.T) { defer func() { _ = cconn.Close() }() defer func() { _ = sconn.Close() }() - man := New(sconn) + man := New(sconn, Server) defer func() { _ = man.Close() }() recv := make(chan []byte, 1) @@ -523,7 +479,7 @@ func TestManageReader_InvokeOnExistingStream(t *testing.T) { defer func() { _ = cconn.Close() }() defer func() { _ = sconn.Close() }() - man := New(sconn) + man := New(sconn, Server) defer func() { _ = man.Close() }() ctx.Run(func(ctx context.Context) { @@ -552,7 +508,7 @@ func TestManageReader_WaitsForStreamCreation(t *testing.T) { defer func() { _ = cconn.Close() }() defer func() { _ = sconn.Close() }() - man := New(sconn) + man := New(sconn, Server) defer func() { _ = man.Close() }() // Write invoke + message immediately. The message arrives before @@ -578,137 +534,3 @@ func TestManageReader_WaitsForStreamCreation(t *testing.T) { assert.DeepEqual(t, <-recv, []byte("data")) } -type blockingTransport chan struct{} - -func (b blockingTransport) Read(p []byte) (n int, err error) { <-b; return 0, io.EOF } -func (b blockingTransport) Write(p []byte) (n int, err error) { <-b; return 0, io.EOF } -func (b blockingTransport) Close() error { close(b); return nil } - -func TestUnblocked_NoCancel(t *testing.T) { - ctx := drpctest.NewTracker(t) - defer ctx.Close() - - cconn, sconn := net.Pipe() - defer func() { _ = cconn.Close() }() - defer func() { _ = sconn.Close() }() - - cman := New(cconn) - defer func() { _ = cman.Close() }() - - sman := New(sconn) - defer func() { _ = sman.Close() }() - - ctx.Run(func(ctx context.Context) { - stream, err := cman.NewClientStream(ctx, "rpc") - assert.NoError(t, err) - defer func() { _ = stream.Close() }() - - assert.NoError(t, stream.RawWrite(drpcwire.KindInvoke, []byte("invoke"))) - assert.NoError(t, stream.RawWrite(drpcwire.KindMessage, []byte("message"))) - assert.NoError(t, stream.RawFlush()) - assert.That(t, !closed(cman.Unblocked())) - - assert.NoError(t, stream.Close()) - assert.That(t, closed(cman.Unblocked())) - }) - - ctx.Run(func(ctx context.Context) { - stream, _, err := sman.NewServerStream(ctx) - assert.NoError(t, err) - defer func() { _ = stream.Close() }() - - _, err = stream.RawRecv() - assert.NoError(t, err) - - _, err = stream.RawRecv() - assert.That(t, errors.Is(err, io.EOF)) - }) - - ctx.Wait() -} - -func TestUnblocked_SoftCancel(t *testing.T) { - run := func(t *testing.T, softCancel bool) { - ctx := drpctest.NewTracker(t) - defer ctx.Close() - - tr := newBlockedTransport() - man := NewWithOptions(tr, Options{SoftCancel: softCancel}) - defer func() { _ = man.Close() }() - defer tr.setReadOpen(true) - defer tr.setWriteOpen(true) - - for i := 0; i < 10; i++ { - func() { - subctx, cancel := context.WithCancel(ctx) - defer cancel() - - stream, err := man.NewClientStream(subctx, "rpc") - if softCancel { - assert.NoError(t, err) - } else if i > 0 { - assert.Error(t, err) - return - } - defer func() { _ = stream.Close() }() - - assert.That(t, !closed(man.Unblocked())) - cancel() - - // temporary unblock writing to allow the stream to finish soft cancel - tr.setWriteOpen(true) - <-man.Unblocked() - tr.setWriteOpen(false) - }() - } - } - - t.Run("Enabled", func(t *testing.T) { run(t, true) }) - t.Run("Disabled", func(t *testing.T) { run(t, false) }) -} - -type blockedTransport struct { - mu *sync.Mutex - co *sync.Cond - ro bool - wo bool -} - -func newBlockedTransport() *blockedTransport { - mu := new(sync.Mutex) - co := sync.NewCond(mu) - return &blockedTransport{ - mu: mu, - co: co, - } -} - -func (b *blockedTransport) setWriteOpen(open bool) { - b.mu.Lock() - defer b.mu.Unlock() - - b.wo = open - b.co.Broadcast() -} - -func (b *blockedTransport) setReadOpen(open bool) { - b.mu.Lock() - defer b.mu.Unlock() - - b.ro = open - b.co.Broadcast() -} - -func (b *blockedTransport) wait(p int, rw *bool) (int, error) { - b.mu.Lock() - defer b.mu.Unlock() - - for !*rw { - b.co.Wait() - } - return p, nil -} - -func (b *blockedTransport) Read(p []byte) (n int, err error) { return b.wait(len(p), &b.ro) } -func (b *blockedTransport) Write(p []byte) (n int, err error) { return b.wait(len(p), &b.wo) } -func (b *blockedTransport) Close() error { return nil } diff --git a/drpcmanager/random_test.go b/drpcmanager/random_test.go index db41bc2..7875944 100644 --- a/drpcmanager/random_test.go +++ b/drpcmanager/random_test.go @@ -197,7 +197,7 @@ func runRandomized(t *testing.T, prog []byte, r runner) { defer func() { _ = ps.Close() }() wr := drpcwire.NewWriter(pc, 0) - man := New(ps) + man := New(ps, Server) defer func() { _ = man.Close() }() errch := make(chan error, 1) diff --git a/drpcstream/stream_test.go b/drpcstream/stream_test.go index 62e4944..fe12ad9 100644 --- a/drpcstream/stream_test.go +++ b/drpcstream/stream_test.go @@ -298,30 +298,6 @@ func TestStream_PacketBufferReuse(t *testing.T) { } } -func TestStream_SendCancelBusyDuringBlockedClose(t *testing.T) { - ctx := drpctest.NewTracker(t) - defer ctx.Close() - - pr, pw := io.Pipe() - defer func() { _ = pr.Close() }() - defer func() { _ = pw.Close() }() - - st := New(ctx, 0, drpcwire.NewWriter(pw, 0)) - - // launch a goroutine to close the stream - ctx.Run(func(ctx context.Context) { _ = st.Close() }) - - // read just 1 byte from the pipe to ensure that the Close has started - _, err := pr.Read(make([]byte, 1)) - assert.NoError(t, err) - assert.That(t, st.IsTerminated()) - - // even though the stream is terminated, soft cancel should report that - // the stream is still busy because the close is being sent. - busy, err := st.SendCancel(context.Canceled) - assert.NoError(t, err) - assert.That(t, busy) -} // // HandleFrame tests From b695ff56f657a9ac3c13e6763367ae4f57011b36 Mon Sep 17 00:00:00 2001 From: Shubham Dhama Date: Wed, 8 Apr 2026 20:27:07 +0530 Subject: [PATCH 09/15] drpcstream: replace single-slot packetBuffer with bounded packetQueue MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The old packetBuffer was single-slot: Put blocked until the consumer called Get+Done. This meant manageReader was stuck delivering to one stream and couldn't serve others — a deadlock under multiplexing. Replace it with a ring-buffer packetQueue (capacity 256) that copies data on Put and returns immediately. Get drains queued messages before returning the close error, ensuring graceful shutdown delivers all buffered data. --- drpcmanager/manager.go | 4 +- drpcstream/packet_queue.go | 115 ++++++++++++++++ drpcstream/packet_queue_test.go | 228 ++++++++++++++++++++++++++++++++ drpcstream/pktbuf.go | 85 ------------ drpcstream/stream.go | 9 +- 5 files changed, 348 insertions(+), 93 deletions(-) create mode 100644 drpcstream/packet_queue.go create mode 100644 drpcstream/packet_queue_test.go delete mode 100644 drpcstream/pktbuf.go diff --git a/drpcmanager/manager.go b/drpcmanager/manager.go index 76f1078..c12841c 100644 --- a/drpcmanager/manager.go +++ b/drpcmanager/manager.go @@ -64,9 +64,7 @@ type Manager struct { wg sync.WaitGroup // tracks active manageStream goroutines - // streams tracks active streams. Currently holds at most one active stream; - // a second may briefly coexist during stream handoff (old stream's Remove - // races with new stream's Add). + // streams tracks active streams. streams *activeStreams pdone drpcsignal.Chan // signals when NewServerStream has registered the new stream diff --git a/drpcstream/packet_queue.go b/drpcstream/packet_queue.go new file mode 100644 index 0000000..bc26dec --- /dev/null +++ b/drpcstream/packet_queue.go @@ -0,0 +1,115 @@ +// Copyright (C) 2026 Cockroach Labs. +// See LICENSE for copying information. + +package drpcstream + +import "sync" + +// defaultPacketQueueCapacity is the number of messages the packet queue can +// hold before the producer blocks. This decouples the transport reader +// (manageReader) from the consumer (RPC handler), preventing a slow handler +// from blocking frame delivery to other streams. +// +// TODO: benchmark whether power-of-2 masking improves performance over modulo. +const defaultPacketQueueCapacity = 256 + +// packetQueue is a bounded single-producer / single-consumer queue for +// assembled packet data. It sits between manageReader (producer, calls Put) +// and the application goroutine (consumer, calls Get/Done). +// +// It is implemented as a ring buffer with mutex + cond synchronization. +// Slots are pre-allocated and reused: each slot's backing array grows via +// append to fit incoming data, then stays at its high-water mark, avoiding +// per-message allocation in steady state. +// +// After Close, Get drains any queued messages before returning the close +// error. This ensures graceful shutdown (KindClose/KindCloseSend) delivers +// all buffered data to the consumer. +type packetQueue struct { + mu sync.Mutex + cond sync.Cond + + buf [][]byte // ring buffer of byte slices + head int // next write position (producer) + tail int // next read position (consumer) + count int // number of occupied slots + + held bool // true between Get and Done + err error // terminal error, set by Close +} + +func (pq *packetQueue) init() { + pq.cond.L = &pq.mu + pq.buf = make([][]byte, defaultPacketQueueCapacity) +} + +// Put copies data into the next write slot. If the queue is full, it blocks +// until a slot is freed or the queue is closed. If the queue is closed, Put +// returns silently without enqueuing. +func (pq *packetQueue) Put(data []byte) { + pq.mu.Lock() + defer pq.mu.Unlock() + + for pq.count == len(pq.buf) && pq.err == nil { + pq.cond.Wait() + } + if pq.err != nil { + return + } + + pq.buf[pq.head] = append(pq.buf[pq.head][:0], data...) + pq.head = (pq.head + 1) % len(pq.buf) + pq.count++ + pq.cond.Broadcast() +} + +// Get returns the data from the next read slot. If the queue is empty, it +// blocks until data is available or the queue is closed. The returned slice +// is valid until Done is called. +func (pq *packetQueue) Get() ([]byte, error) { + pq.mu.Lock() + defer pq.mu.Unlock() + + for pq.count == 0 && pq.err == nil { + pq.cond.Wait() + } + if pq.count == 0 { + // Queue is empty and closed — return the close error. + return nil, pq.err + } + + // Return data even if closed, draining pending items first. + pq.held = true + return pq.buf[pq.tail], nil +} + +// Done advances the read pointer, making the slot available for reuse. +// It must be called exactly once after each successful Get. +func (pq *packetQueue) Done() { + pq.mu.Lock() + defer pq.mu.Unlock() + + pq.tail = (pq.tail + 1) % len(pq.buf) + pq.count-- + pq.held = false + pq.cond.Broadcast() +} + +// Close marks the queue as closed with the given error. All blocked Put and +// Get calls are woken and will return. Close waits for any in-progress +// Get/Done pair to complete before setting the error. Subsequent calls are +// no-ops. +func (pq *packetQueue) Close(err error) { + pq.mu.Lock() + defer pq.mu.Unlock() + + for pq.held { + pq.cond.Wait() + } + if pq.err != nil { + return + } + + pq.err = err + pq.cond.Broadcast() +} diff --git a/drpcstream/packet_queue_test.go b/drpcstream/packet_queue_test.go new file mode 100644 index 0000000..7af0a7c --- /dev/null +++ b/drpcstream/packet_queue_test.go @@ -0,0 +1,228 @@ +// Copyright (C) 2026 Cockroach Labs. +// See LICENSE for copying information. + +package drpcstream + +import ( + "io" + "sync" + "testing" + + "github.com/zeebo/assert" +) + +func TestPacketQueue_PutGet(t *testing.T) { + var pq packetQueue + pq.init() + + pq.Put([]byte("hello")) + + data, err := pq.Get() + assert.NoError(t, err) + assert.DeepEqual(t, data, []byte("hello")) + pq.Done() +} + +func TestPacketQueue_FIFO(t *testing.T) { + var pq packetQueue + pq.init() + + pq.Put([]byte("first")) + pq.Put([]byte("second")) + pq.Put([]byte("third")) + + for _, want := range []string{"first", "second", "third"} { + data, err := pq.Get() + assert.NoError(t, err) + assert.DeepEqual(t, data, []byte(want)) + pq.Done() + } +} + +func TestPacketQueue_GetBlocksUntilPut(t *testing.T) { + var pq packetQueue + pq.init() + + got := make(chan []byte, 1) + go func() { + data, err := pq.Get() + assert.NoError(t, err) + got <- data + }() + + pq.Put([]byte("delayed")) + assert.DeepEqual(t, <-got, []byte("delayed")) + pq.Done() +} + +func TestPacketQueue_PutBlocksWhenFull(t *testing.T) { + var pq packetQueue + pq.cond.L = &pq.mu + pq.buf = make([][]byte, 2) // capacity 2 + + pq.Put([]byte("a")) + pq.Put([]byte("b")) + + // Third put should block until we drain one. + done := make(chan struct{}) + go func() { + pq.Put([]byte("c")) + close(done) + }() + + // Drain one slot. + data, err := pq.Get() + assert.NoError(t, err) + assert.DeepEqual(t, data, []byte("a")) + pq.Done() + + // Now the blocked Put should complete. + <-done + + // Verify remaining items. + data, err = pq.Get() + assert.NoError(t, err) + assert.DeepEqual(t, data, []byte("b")) + pq.Done() + + data, err = pq.Get() + assert.NoError(t, err) + assert.DeepEqual(t, data, []byte("c")) + pq.Done() +} + +func TestPacketQueue_CloseUnblocksGet(t *testing.T) { + var pq packetQueue + pq.init() + + errch := make(chan error, 1) + go func() { + _, err := pq.Get() + errch <- err + }() + + pq.Close(io.EOF) + assert.Equal(t, <-errch, io.EOF) +} + +func TestPacketQueue_CloseUnblocksPut(t *testing.T) { + var pq packetQueue + pq.cond.L = &pq.mu + pq.buf = make([][]byte, 1) // capacity 1 + + pq.Put([]byte("fill")) + + done := make(chan struct{}) + go func() { + pq.Put([]byte("blocked")) + close(done) + }() + + pq.Close(io.EOF) + <-done +} + +func TestPacketQueue_CloseDrainsQueued(t *testing.T) { + var pq packetQueue + pq.init() + + pq.Put([]byte("queued")) + pq.Close(io.EOF) + + // Get returns the queued data first. + data, err := pq.Get() + assert.NoError(t, err) + assert.DeepEqual(t, data, []byte("queued")) + pq.Done() + + // Next Get returns the close error. + data, err = pq.Get() + assert.Nil(t, data) + assert.Equal(t, err, io.EOF) +} + +func TestPacketQueue_CloseIdempotent(t *testing.T) { + var pq packetQueue + pq.init() + + pq.Close(io.EOF) + pq.Close(io.ErrUnexpectedEOF) // should not overwrite + + _, err := pq.Get() + assert.Equal(t, err, io.EOF) // original error preserved +} + +func TestPacketQueue_PutAfterClose(t *testing.T) { + var pq packetQueue + pq.init() + + pq.Close(io.EOF) + pq.Put([]byte("dropped")) // should not panic or block +} + +func TestPacketQueue_SlotReuse(t *testing.T) { + var pq packetQueue + pq.cond.L = &pq.mu + pq.buf = make([][]byte, 2) + + // Fill and drain a few rounds to exercise slot reuse. + for round := 0; round < 5; round++ { + pq.Put([]byte("data")) + data, err := pq.Get() + assert.NoError(t, err) + assert.DeepEqual(t, data, []byte("data")) + pq.Done() + } +} + +func TestPacketQueue_CloseWaitsForHeld(t *testing.T) { + var pq packetQueue + pq.init() + + pq.Put([]byte("msg")) + + // Get the data but don't call Done yet. + data, err := pq.Get() + assert.NoError(t, err) + assert.DeepEqual(t, data, []byte("msg")) + + closed := make(chan struct{}) + go func() { + pq.Close(io.EOF) + close(closed) + }() + + // Close should be blocked because held is true. + // Call Done to release it. + pq.Done() + <-closed +} + +func TestPacketQueue_ConcurrentProducerConsumer(t *testing.T) { + var pq packetQueue + pq.init() + + const n = 1000 + var wg sync.WaitGroup + wg.Add(2) + + go func() { + defer wg.Done() + for i := 0; i < n; i++ { + pq.Put([]byte{byte(i)}) + } + }() + + go func() { + defer wg.Done() + for i := 0; i < n; i++ { + data, err := pq.Get() + assert.NoError(t, err) + assert.Equal(t, data[0], byte(i)) + pq.Done() + } + }() + + wg.Wait() + pq.Close(io.EOF) +} diff --git a/drpcstream/pktbuf.go b/drpcstream/pktbuf.go deleted file mode 100644 index db68864..0000000 --- a/drpcstream/pktbuf.go +++ /dev/null @@ -1,85 +0,0 @@ -// Copyright (C) 2019 Storj Labs, Inc. -// See LICENSE for copying information. - -package drpcstream - -import ( - "sync" -) - -type packetBuffer struct { - mu sync.Mutex - cond sync.Cond - err error - data []byte - set bool - held bool -} - -func (pb *packetBuffer) init() { - pb.cond.L = &pb.mu -} - -func (pb *packetBuffer) Close(err error) { - pb.mu.Lock() - defer pb.mu.Unlock() - - for pb.held { - pb.cond.Wait() - } - - if pb.err == nil { - pb.data = nil - pb.set = false - pb.err = err - pb.cond.Broadcast() - } -} - -func (pb *packetBuffer) Put(data []byte) { - pb.mu.Lock() - defer pb.mu.Unlock() - - for pb.set && pb.err == nil { - pb.cond.Wait() - } - if pb.err != nil { - return - } - - pb.data = data - pb.set = true - pb.held = false - pb.cond.Broadcast() - - for pb.set || pb.held { - pb.cond.Wait() - } -} - -func (pb *packetBuffer) Get() ([]byte, error) { - pb.mu.Lock() - defer pb.mu.Unlock() - - for !pb.set && pb.err == nil { - pb.cond.Wait() - } - if pb.err != nil { - return nil, pb.err - } - - pb.held = true - pb.cond.Broadcast() - - return pb.data, nil -} - -func (pb *packetBuffer) Done() { - pb.mu.Lock() - defer pb.mu.Unlock() - - pb.data = nil - pb.set = false - pb.held = false - pb.cond.Broadcast() -} diff --git a/drpcstream/stream.go b/drpcstream/stream.go index 8eef1da..75687e8 100644 --- a/drpcstream/stream.go +++ b/drpcstream/stream.go @@ -60,7 +60,7 @@ type Stream struct { id drpcwire.ID wr *drpcwire.Writer - pbuf packetBuffer + pbuf packetQueue wbuf []byte mu sync.Mutex // protects state transitions @@ -594,10 +594,9 @@ func (s *Stream) SendError(serr error) (err error) { return s.checkCancelError(s.sendPacketLocked(drpcwire.KindError, false, drpcwire.MarshalError(serr))) } -// SendCancel transitions the stream into the canceled state with -// context.Canceled and sends a cancel error to the remote side for a soft -// cancel. It is a no-op if the stream is already terminated. It returns true -// for busy if writes are already blocked and a hard cancel is required. +// SendCancel terminates the stream and sends a cancel to the remote side. It +// blocks until any in-progress write completes. It is a no-op if the stream is +// already terminated. func (s *Stream) SendCancel(err error) error { s.log("CALL", func() string { return "SendCancel()" }) From b4a50fac4c6f5c57077dbdd38963cc7f233712c0 Mon Sep 17 00:00:00 2001 From: Shubham Dhama Date: Sun, 12 Apr 2026 13:39:41 +0530 Subject: [PATCH 10/15] drpcwire: introduce MuxWriter for stream multiplexing MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Replace per-stream drpcwire.Writer with a shared MuxWriter that uses a dedicated drain goroutine. This decouples stream writes from transport I/O, enabling true concurrent stream multiplexing. Core Architecture: MuxWriter (drpcwire/mux_writer.go): - Single instance per Manager, shared across all streams - Dedicated goroutine continuously drains buffered frames to transport - Non-blocking WriteFrame: appends to buffer under lock, signals goroutine - Double-buffer swap (buf/spare) minimizes time spent under lock Stream changes (drpcstream/stream.go): - wr field: *drpcwire.Writer -> *drpcwire.MuxWriter - Removed: ManualFlush option, RawFlush, rawFlushLocked, checkRecvFlush - sendPacketLocked/rawWriteLocked: WriteFrame only, no Flush Manager integration (drpcmanager/manager.go): - Creates MuxWriter with onError: m.terminate - terminate(): wr.Stop() THEN tr.Close() — Stop makes WriteFrame reject immediately; transport close unblocks any in-flight Write in the drain goroutine - Close(): <-wr.Done() to wait for drain goroutine exit Key Design Decisions: 1. No explicit flush needed. The drain goroutine continuously pulls from the buffer. Natural batching occurs because appends accumulate while the goroutine is mid-Write. The old cork pattern (delay flush until first recv) is unnecessary — appending is a memcpy under lock, and the goroutine controls when transport I/O happens. 2. sync.Cond over channels. Signal coalescing: multiple WriteFrame calls while run() is in Write produce a single wakeup. No allocation overhead. Stop uses closed bool + Broadcast. Consistent with packetQueue. 3. Two-phase shutdown (Stop/Done split to avoid deadlock). Stop() is non-blocking: sets closed, Broadcast, returns immediately. Done() returns a channel that closes when run() exits. This split is critical for the onError path: run() -> Write fails -> sets closed -> onError -> terminate -> Stop (finds closed=true, noop) -> run() returns. If Stop blocked until run() exited, this path would self-join. 4. run() owns its lifecycle on write failure. When Write fails, run() sets closed=true itself before calling onError. The subsequent onError -> terminate -> Stop path finds closed already set. No coordination needed; the flag is idempotent. 5. No per-stream FrameWriter wrapper. Initially considered a per-stream FrameWriter wrapping *MuxWriter, but the only value was a closed check before append. That check lives in MuxWriter.WriteFrame directly. Streams hold *MuxWriter and call WriteFrame. What this unlocks: - Concurrent multiplexing: streams no longer serialize on writes - Simplified stream: all flush/cork complexity removed - Natural batching from continuous drain - Direct error propagation: transport write failures fire manager termination via onError callback Breaking changes: - drpcstream.Options.ManualFlush removed - Stream.RawFlush(), SetManualFlush() removed - Stream constructor: *drpcwire.Writer -> *drpcwire.MuxWriter Test coverage: 8 concurrency tests for MuxWriter covering concurrent WriteFrame, write errors, onError->Stop deadlock path, blocked Write unblocked by Close, concurrent Stop, abort semantics (Stop discards buffered data), and write-during-active-drain. A data race in the initial implementation (reading buf capacity without lock) was caught by these tests and fixed. --- drpcconn/conn_test.go | 26 +- drpcmanager/active_streams_test.go | 28 ++- drpcmanager/manager.go | 20 +- drpcmanager/manager_test.go | 2 - drpcmanager/random_test.go | 23 +- drpcstream/stream.go | 113 +-------- drpcstream/stream_test.go | 181 +++----------- drpcwire/mux_writer.go | 101 ++++++++ drpcwire/mux_writer_test.go | 332 ++++++++++++++++++++++++++ drpcwire/writer.go | 107 --------- drpcwire/writer_test.go | 32 --- internal/grpccompat/benchmark_test.go | 7 +- internal/grpccompat/common_test.go | 1 - internal/integration/cancel_test.go | 10 - 14 files changed, 536 insertions(+), 447 deletions(-) create mode 100644 drpcwire/mux_writer.go create mode 100644 drpcwire/mux_writer_test.go delete mode 100644 drpcwire/writer.go delete mode 100644 drpcwire/writer_test.go diff --git a/drpcconn/conn_test.go b/drpcconn/conn_test.go index e7402b6..f16c8a3 100644 --- a/drpcconn/conn_test.go +++ b/drpcconn/conn_test.go @@ -40,31 +40,32 @@ func TestConn_InvokeFlushesSendClose(t *testing.T) { invokeDone := make(chan struct{}) ctx.Run(func(ctx context.Context) { - wr := drpcwire.NewWriter(ps, 64) + wr := drpcwire.NewMuxWriter(ps, nil) + defer wr.StopWait() rd := drpcwire.NewReader(ps) _, _ = rd.ReadFrame() // Invoke _, _ = rd.ReadFrame() // Message pkt, _ := rd.ReadFrame() // CloseSend - _ = wr.WritePacket(drpcwire.Packet{ + _ = wr.WriteFrame(drpcwire.Frame{ Data: []byte("qux"), ID: drpcwire.ID{Stream: pkt.ID.Stream, Message: 1}, Kind: drpcwire.KindMessage, + Done: true, }) - _ = wr.Flush() _, _ = rd.ReadFrame() // Close <-invokeDone // wait for invoke to return // ensure that any later packets are dropped by writing one // before closing the transport. - for i := 0; i < 5; i++ { - _ = wr.WritePacket(drpcwire.Packet{ + for range 5 { + _ = wr.WriteFrame(drpcwire.Frame{ ID: drpcwire.ID{Stream: pkt.ID.Stream, Message: 2}, Kind: drpcwire.KindCloseSend, + Done: true, }) - _ = wr.Flush() } _ = ps.Close() @@ -78,7 +79,7 @@ func TestConn_InvokeFlushesSendClose(t *testing.T) { invokeDone <- struct{}{} // signal invoke has returned - // we should eventually notice the transport is closed + // we should eventually notice the transport is closed due to ps.Close() select { case <-conn.Closed(): case <-time.After(1 * time.Second): @@ -95,7 +96,8 @@ func TestConn_InvokeSendsGrpcAndDrpcMetadata(t *testing.T) { defer func() { assert.NoError(t, ps.Close()) }() ctx.Run(func(ctx context.Context) { - wr := drpcwire.NewWriter(ps, 64) + wr := drpcwire.NewMuxWriter(ps, nil) + defer wr.StopWait() rd := drpcwire.NewReader(ps) md, err := rd.ReadFrame() // Metadata @@ -114,12 +116,12 @@ func TestConn_InvokeSendsGrpcAndDrpcMetadata(t *testing.T) { _, _ = rd.ReadFrame() // Message pkt, _ := rd.ReadFrame() // CloseSend - _ = wr.WritePacket(drpcwire.Packet{ + _ = wr.WriteFrame(drpcwire.Frame{ Data: []byte("qux"), ID: drpcwire.ID{Stream: pkt.ID.Stream, Message: 1}, Kind: drpcwire.KindMessage, + Done: true, }) - _ = wr.Flush() _, _ = rd.ReadFrame() // Close }) @@ -181,6 +183,10 @@ func TestConn_NewStreamSendsGrpcAndDrpcMetadata(t *testing.T) { s, err := conn.NewStream(ctx, "/com.example.Foo/Bar", testEncoding{}) assert.NoError(t, err) _ = s.CloseSend() + + // Wait for the server goroutine to read all frames before defers + // close the pipe. With MuxWriter, writes are asynchronous. + ctx.Wait() } func TestConn_encodeMetadata(t *testing.T) { diff --git a/drpcmanager/active_streams_test.go b/drpcmanager/active_streams_test.go index c4159b6..842acb7 100644 --- a/drpcmanager/active_streams_test.go +++ b/drpcmanager/active_streams_test.go @@ -6,6 +6,7 @@ package drpcmanager import ( "context" "errors" + "io" "testing" "github.com/zeebo/assert" @@ -14,13 +15,19 @@ import ( "storj.io/drpc/drpcwire" ) -func testStream(id uint64) *drpcstream.Stream { - return drpcstream.New(context.Background(), id, &drpcwire.Writer{}) +func testMuxWriter(t *testing.T) *drpcwire.MuxWriter { + mw := drpcwire.NewMuxWriter(io.Discard, func(error) {}) + t.Cleanup(func() { mw.Stop(); <-mw.Done() }) + return mw +} + +func testStream(t *testing.T, id uint64) *drpcstream.Stream { + return drpcstream.New(context.Background(), id, testMuxWriter(t)) } func TestActiveStreams_AddAndGet(t *testing.T) { streams := newActiveStreams() - s := testStream(1) + s := testStream(t, 1) assert.NoError(t, streams.Add(1, s)) @@ -39,7 +46,7 @@ func TestActiveStreams_GetMissing(t *testing.T) { func TestActiveStreams_Remove(t *testing.T) { streams := newActiveStreams() - s := testStream(1) + s := testStream(t, 1) assert.NoError(t, streams.Add(1, s)) assert.Equal(t, streams.Len(), 1) @@ -60,8 +67,8 @@ func TestActiveStreams_RemoveIdempotent(t *testing.T) { func TestActiveStreams_DuplicateAdd(t *testing.T) { streams := newActiveStreams() - s1 := testStream(1) - s2 := testStream(1) + s1 := testStream(t, 1) + s2 := testStream(t, 1) assert.NoError(t, streams.Add(1, s1)) assert.Error(t, streams.Add(1, s2)) @@ -76,13 +83,13 @@ func TestActiveStreams_AddAfterClose(t *testing.T) { streams := newActiveStreams() streams.Close(errors.New("closed")) - err := streams.Add(1, testStream(1)) + err := streams.Add(1, testStream(t, 1)) assert.Error(t, err) } func TestActiveStreams_RemoveAfterClose(t *testing.T) { streams := newActiveStreams() - s := testStream(1) + s := testStream(t, 1) assert.NoError(t, streams.Add(1, s)) streams.Close(errors.New("closed")) @@ -95,13 +102,12 @@ func TestActiveStreams_Len(t *testing.T) { streams := newActiveStreams() assert.Equal(t, streams.Len(), 0) - assert.NoError(t, streams.Add(1, testStream(1))) + assert.NoError(t, streams.Add(1, testStream(t, 1))) assert.Equal(t, streams.Len(), 1) - assert.NoError(t, streams.Add(2, testStream(2))) + assert.NoError(t, streams.Add(2, testStream(t, 2))) assert.Equal(t, streams.Len(), 2) streams.Remove(1) assert.Equal(t, streams.Len(), 1) } - diff --git a/drpcmanager/manager.go b/drpcmanager/manager.go index c12841c..75df2af 100644 --- a/drpcmanager/manager.go +++ b/drpcmanager/manager.go @@ -30,10 +30,6 @@ var managerClosed = errs.Class("manager closed") // Options controls configuration settings for a manager. type Options struct { - // WriterBufferSize controls the size of the buffer that we will fill before - // flushing. Normal writes to streams typically issue a flush explicitly. - WriterBufferSize int - // Reader are passed to any readers the manager creates. Reader drpcwire.ReaderOptions @@ -55,7 +51,7 @@ type Options struct { // to the appropriate stream. type Manager struct { tr drpc.Transport - wr *drpcwire.Writer + wr *drpcwire.MuxWriter rd *drpcwire.Reader opts Options @@ -114,7 +110,6 @@ func New(tr drpc.Transport, kind ManagerKind) *Manager { func NewWithOptions(tr drpc.Transport, kind ManagerKind, opts Options) *Manager { m := &Manager{ tr: tr, - wr: drpcwire.NewWriter(tr, opts.WriterBufferSize), rd: drpcwire.NewReaderWithOptions(tr, opts.Reader), opts: opts, @@ -122,6 +117,8 @@ func NewWithOptions(tr drpc.Transport, kind ManagerKind, opts Options) *Manager kind: kind, } + m.wr = drpcwire.NewMuxWriter(tr, m.terminate) + // a buffer of size 1 allows NewServerStream to signal it is done creating a // new server stream without having to coordinate with manageReader. m.pdone.Make(1) @@ -148,10 +145,14 @@ func (m *Manager) log(what string, cb func() string) { } // terminate puts the Manager into a terminal state and closes any resources -// that need to be closed to signal the state change. +// that need to be closed to signal the state change. The mux writer is stopped +// before closing the transport so that WriteFrame immediately rejects new +// writes; the subsequent transport close unblocks any in-flight Write in the +// drain goroutine. func (m *Manager) terminate(err error) { if m.sigs.term.Set(err) { m.log("TERM", func() string { return fmt.Sprint(err) }) + m.wr.Stop() m.sigs.tport.Set(m.tr.Close()) if errors.Is(err, io.EOF) { err = context.Canceled @@ -317,8 +318,9 @@ func (m *Manager) Unblocked() <-chan struct{} { func (m *Manager) Close() error { m.terminate(managerClosed.New("Close called")) - m.wg.Wait() // wait for all stream goroutines - m.sigs.read.Wait() + <-m.wr.Done() // wait for writer goroutine to exit + m.wg.Wait() // wait for all stream goroutines + m.sigs.read.Wait() // wait for reader goroutine to exit m.sigs.tport.Wait() return m.sigs.tport.Err() diff --git a/drpcmanager/manager_test.go b/drpcmanager/manager_test.go index a98626d..b7b1f79 100644 --- a/drpcmanager/manager_test.go +++ b/drpcmanager/manager_test.go @@ -48,7 +48,6 @@ func TestDrpcMetadata(t *testing.T) { assert.NoError(t, stream.RawWrite(drpcwire.KindInvokeMetadata, buf)) assert.NoError(t, stream.RawWrite(drpcwire.KindInvoke, []byte("invoke"))) assert.NoError(t, stream.RawWrite(drpcwire.KindMessage, []byte("message"))) - assert.NoError(t, stream.RawFlush()) assert.NoError(t, stream.Close()) }) @@ -105,7 +104,6 @@ func TestDrpcMetadataWithGRPCMetadataCompatMode(t *testing.T) { assert.NoError(t, stream.RawWrite(drpcwire.KindInvokeMetadata, buf)) assert.NoError(t, stream.RawWrite(drpcwire.KindInvoke, []byte("invoke"))) assert.NoError(t, stream.RawWrite(drpcwire.KindMessage, []byte("message"))) - assert.NoError(t, stream.RawFlush()) assert.NoError(t, stream.Close()) }) diff --git a/drpcmanager/random_test.go b/drpcmanager/random_test.go index 7875944..d67a026 100644 --- a/drpcmanager/random_test.go +++ b/drpcmanager/random_test.go @@ -45,14 +45,15 @@ func (rc *randClient) newSteam(ctx context.Context, man *Manager) (*drpcstream.S return stream, err } -func (rc *randClient) execute(t *testing.T, wr *drpcwire.Writer, op byte) { +func (rc *randClient) execute(t *testing.T, wr *drpcwire.MuxWriter, op byte) { cmd, arg, done := parseOp(op) if !rc.active { - assert.NoError(t, wr.WritePacket(drpcwire.Packet{ + assert.NoError(t, wr.WriteFrame(drpcwire.Frame{ Data: make([]byte, arg), ID: rc.id.incMessage(), Kind: drpcwire.KindInvoke, + Done: true, })) rc.active = true } @@ -60,9 +61,10 @@ func (rc *randClient) execute(t *testing.T, wr *drpcwire.Writer, op byte) { switch cmd { case 0: // new invoke if rc.active { - assert.NoError(t, wr.WritePacket(drpcwire.Packet{ + assert.NoError(t, wr.WriteFrame(drpcwire.Frame{ ID: rc.id.incMessage(), Kind: drpcwire.KindClose, + Done: true, })) } @@ -99,10 +101,11 @@ func (rc *randClient) execute(t *testing.T, wr *drpcwire.Writer, op byte) { })) case 2: // cause the remote side to close - assert.NoError(t, wr.WritePacket(drpcwire.Packet{ + assert.NoError(t, wr.WriteFrame(drpcwire.Frame{ Data: []byte("remote-close"), ID: rc.id.incMessage(), Kind: drpcwire.KindMessage, + Done: true, })) case 3, 4, 5, 6, 7: // send normal message @@ -130,7 +133,7 @@ func (rs *randServer) newSteam(ctx context.Context, man *Manager) (*drpcstream.S return man.NewClientStream(ctx, "rpc") } -func (rs *randServer) execute(t *testing.T, wr *drpcwire.Writer, op byte) { +func (rs *randServer) execute(t *testing.T, wr *drpcwire.MuxWriter, op byte) { cmd, arg, done := parseOp(op) switch cmd { @@ -160,10 +163,11 @@ func (rs *randServer) execute(t *testing.T, wr *drpcwire.Writer, op byte) { })) case 2: // cause the remote side to close - assert.NoError(t, wr.WritePacket(drpcwire.Packet{ + assert.NoError(t, wr.WriteFrame(drpcwire.Frame{ Data: []byte("remote-close"), ID: rs.id.incMessage(), Kind: drpcwire.KindMessage, + Done: true, })) case 3, 4, 5, 6, 7: // send random message @@ -185,7 +189,7 @@ func (rs *randServer) execute(t *testing.T, wr *drpcwire.Writer, op byte) { type runner interface { newSteam(ctx context.Context, man *Manager) (*drpcstream.Stream, error) - execute(t *testing.T, wr *drpcwire.Writer, op byte) + execute(t *testing.T, wr *drpcwire.MuxWriter, op byte) } func runRandomized(t *testing.T, prog []byte, r runner) { @@ -196,7 +200,9 @@ func runRandomized(t *testing.T, prog []byte, r runner) { defer func() { _ = pc.Close() }() defer func() { _ = ps.Close() }() - wr := drpcwire.NewWriter(pc, 0) + wr := drpcwire.NewMuxWriter(pc, func(error) {}) + defer func() { wr.Stop(); <-wr.Done() }() + man := New(ps, Server) defer func() { _ = man.Close() }() @@ -223,7 +229,6 @@ func runRandomized(t *testing.T, prog []byte, r runner) { for _, op := range prog { r.execute(t, wr, op) - assert.NoError(t, wr.Flush()) } assert.NoError(t, man.Close()) diff --git a/drpcstream/stream.go b/drpcstream/stream.go index 75687e8..343b482 100644 --- a/drpcstream/stream.go +++ b/drpcstream/stream.go @@ -26,12 +26,6 @@ type Options struct { // SplitSize controls the default size we split data packets into frames. SplitSize int - // ManualFlush controls if the stream will automatically flush after every - // message send. Note that flushing is not part of the drpc.Stream - // interface, so if you use this you must be ready to type assert and call - // RawFlush dynamically. - ManualFlush bool - // MaximumBufferSize causes the Stream to drop any internal buffers that are // larger than this amount to control maximum memory usage at the expense of // more allocations. 0 is unlimited. @@ -54,19 +48,18 @@ type Stream struct { // that checkFinished can test whether ops are in flight without blocking. write inspectMutex read inspectMutex - flush sync.Once pa drpcwire.PacketAssembler id drpcwire.ID - wr *drpcwire.Writer + wr *drpcwire.MuxWriter pbuf packetQueue wbuf []byte mu sync.Mutex // protects state transitions sigs struct { - send drpcsignal.Signal // set when done sending messages - recv drpcsignal.Signal // set when done receiving messages + send drpcsignal.Signal // set when done sending messages + recv drpcsignal.Signal // set when done receiving messages // Stream shutdown is two-phase: term then fin. When termination arrives // (remote error, local cancel, close), there may be an in-flight write // on the transport that is past the term check and inside WriteFrame. @@ -85,7 +78,7 @@ var _ drpc.Stream = (*Stream)(nil) // New returns a new stream bound to the context with the given stream id and // will use the writer to write messages on. It is important use monotonically // increasing stream ids within a single transport. -func New(ctx context.Context, sid uint64, wr *drpcwire.Writer) *Stream { +func New(ctx context.Context, sid uint64, wr *drpcwire.MuxWriter) *Stream { return NewWithOptions(ctx, sid, wr, Options{}) } @@ -93,7 +86,7 @@ func New(ctx context.Context, sid uint64, wr *drpcwire.Writer) *Stream { // stream id and will use the writer to write messages on. It is important use // monotonically increasing stream ids within a single transport. The options // are used to control details of how the Stream operates. -func NewWithOptions(ctx context.Context, sid uint64, wr *drpcwire.Writer, opts Options) *Stream { +func NewWithOptions(ctx context.Context, sid uint64, wr *drpcwire.MuxWriter, opts Options) *Stream { var task *trace.Task if trace.IsEnabled() { kind, rpc := drpcopts.GetStreamKind(&opts.Internal), drpcopts.GetStreamRPC(&opts.Internal) @@ -116,7 +109,6 @@ func NewWithOptions(ctx context.Context, sid uint64, wr *drpcwire.Writer, opts O pa: pa, id: drpcwire.ID{Stream: sid}, - // TODO: think more deeply on the consequences here. wr: wr, } @@ -204,30 +196,6 @@ func (s *Stream) Finished() <-chan struct{} { return s.sigs.fin.Signal() } // issue any writes or reads. func (s *Stream) IsFinished() bool { return s.sigs.fin.IsSet() } -// SetManualFlush sets the ManualFlush option. It cannot be called concurrently -// with any sends or receives on the stream. Example use case: -// -// flusher := stream.(interface{ -// GetStream() drpc.Stream -// }).GetStream().(interface{ -// SetManualFlush(bool) -// }) -// -// flusher.SetManualFlush(true) -// err = stream.Send(&pb.Message{Request: "hello, "}) -// flusher.SetManualFlush(false) -// if err != nil { -// return err -// } -// -// // the next send will send both messages in the same write -// // to the underlying connection. -// err = stream.Send(&pb.Message{Request: "world!"}) -// if err != nil { -// return err -// } -func (s *Stream) SetManualFlush(mf bool) { s.opts.ManualFlush = mf } - // // frame handler // @@ -346,9 +314,9 @@ func (s *Stream) newFrameLocked(kind drpcwire.Kind) drpcwire.Frame { return drpcwire.Frame{ID: s.id, Kind: kind} } -// sendPacketLocked sends the packet in a single write and flushes. It does not -// check for any conditions to stop it from writing and is meant for internal -// stream use to do things like signal errors or closes to the remote side. +// sendPacketLocked sends the packet in a single write. It does not check for +// any conditions to stop it from writing and is meant for internal stream use +// to do things like signal errors or closes to the remote side. func (s *Stream) sendPacketLocked(kind drpcwire.Kind, control bool, data []byte) (err error) { fr := s.newFrameLocked(kind) fr.Data = data @@ -361,9 +329,6 @@ func (s *Stream) sendPacketLocked(kind drpcwire.Kind, control bool, data []byte) if err := s.wr.WriteFrame(fr); err != nil { return errs.Wrap(err) } - if err := s.wr.Flush(); err != nil { - return errs.Wrap(err) - } return nil } @@ -443,58 +408,8 @@ func (s *Stream) rawWriteLocked(kind drpcwire.Kind, data []byte) (err error) { } } -// RawFlush flushes any buffers of data. -func (s *Stream) RawFlush() (err error) { - defer s.checkFinished() - s.write.Lock() - defer s.write.Unlock() - - return s.rawFlushLocked() -} - -// rawFlushLocked checks for any conditions that should cause a flush to not -// happen and then issues the flush. It assumes the caller is holding the -// appropriate locks. -func (s *Stream) rawFlushLocked() (err error) { - if s.wr.Empty() { - return nil - } - - switch { - case s.sigs.cancel.IsSet(): - return s.sigs.cancel.Err() - case s.sigs.send.IsSet(): - return s.sigs.send.Err() - case s.sigs.term.IsSet(): - return s.sigs.term.Err() - } - - s.log("FLUSH", func() string { return "" }) - - return s.checkCancelError(errs.Wrap(s.wr.Flush())) -} - -func (s *Stream) checkRecvFlush() (err error) { - s.flush.Do(func() { err = s.RawFlush() }) - if err != nil { - return err - } - - if s.opts.ManualFlush && !s.wr.Empty() { - if err := s.RawFlush(); err != nil { - return err - } - } - - return nil -} - // RawRecv returns the raw bytes received for a message. func (s *Stream) RawRecv() (data []byte, err error) { - if err := s.checkRecvFlush(); err != nil { - return nil, err - } - defer s.checkFinished() s.read.Lock() defer s.read.Unlock() @@ -513,12 +428,9 @@ func (s *Stream) RawRecv() (data []byte, err error) { // msg read/write // -// MsgSend marshals the message with the encoding, writes it, and flushes. +// MsgSend marshals the message with the encoding and writes it. func (s *Stream) MsgSend(msg drpc.Message, enc drpc.Encoding) (err error) { defer func() { err = drpc.ToRPCErr(err) }() - - s.flush.Do(func() {}) - defer s.checkFinished() s.write.Lock() defer s.write.Unlock() @@ -533,9 +445,6 @@ func (s *Stream) MsgSend(msg drpc.Message, enc drpc.Encoding) (err error) { if err := s.rawWriteLocked(drpcwire.KindMessage, wbuf); err != nil { return err } - if !s.opts.ManualFlush { - return s.rawFlushLocked() - } return nil } @@ -543,10 +452,6 @@ func (s *Stream) MsgSend(msg drpc.Message, enc drpc.Encoding) (err error) { func (s *Stream) MsgRecv(msg drpc.Message, enc drpc.Encoding) (err error) { defer func() { err = drpc.ToRPCErr(err) }() - if err := s.checkRecvFlush(); err != nil { - return err - } - defer s.checkFinished() s.read.Lock() defer s.read.Unlock() diff --git a/drpcstream/stream_test.go b/drpcstream/stream_test.go index fe12ad9..9406936 100644 --- a/drpcstream/stream_test.go +++ b/drpcstream/stream_test.go @@ -4,7 +4,6 @@ package drpcstream import ( - "bytes" "context" "errors" "io" @@ -19,6 +18,15 @@ import ( "storj.io/drpc/drpcwire" ) +// testMuxWriter creates a MuxWriter that writes to io.Discard with a no-op +// error handler. The writer goroutine is stopped when the test finishes. +func testMuxWriter(t *testing.T) *drpcwire.MuxWriter { + t.Helper() + mw := drpcwire.NewMuxWriter(io.Discard, func(error) {}) + t.Cleanup(mw.StopWait) + return mw +} + // handleFrame is a helper that sends a single-frame packet to the stream. // It constructs a frame with the given kind, matching the stream's ID, // using the provided message ID, done=true. @@ -34,6 +42,7 @@ func TestStream_StateTransitions(t *testing.T) { ctx := drpctest.NewTracker(t) defer ctx.Close() + mw := testMuxWriter(t) any := errors.New("any sentinel error") checkErrs := func(t *testing.T, exp interface{}, got error) { @@ -108,7 +117,7 @@ func TestStream_StateTransitions(t *testing.T) { } for _, test := range cases { - st := New(ctx, 1, drpcwire.NewWriter(io.Discard, 0)) + st := New(ctx, 1, mw) assert.NoError(t, test.Op(st)) checkErrs(t, test.Send, st.RawWrite(drpcwire.KindMessage, nil)) @@ -125,6 +134,8 @@ func TestStream_Unblocks(t *testing.T) { ctx := drpctest.NewTracker(t) defer ctx.Close() + mw := testMuxWriter(t) + cases := []struct { Op func(st *Stream) error }{ @@ -158,7 +169,7 @@ func TestStream_Unblocks(t *testing.T) { } for _, test := range cases { - st := New(ctx, 1, drpcwire.NewWriter(io.Discard, 0)) + st := New(ctx, 1, mw) ctx.Run(func(ctx context.Context) { _, _ = st.RawRecv() }) assert.NoError(t, test.Op(st)) @@ -168,7 +179,8 @@ func TestStream_Unblocks(t *testing.T) { func TestStream_ContextCancel(t *testing.T) { ctx := context.Background() - st := New(ctx, 0, drpcwire.NewWriter(io.Discard, 0)) + mw := testMuxWriter(t) + st := New(ctx, 0, mw) child, cancel := context.WithCancel(st.Context()) defer cancel() @@ -182,84 +194,32 @@ func TestStream_ConcurrentCloseCancel(t *testing.T) { ctx := drpctest.NewTracker(t) defer ctx.Close() - pr, pw := io.Pipe() - defer func() { _ = pr.Close() }() - defer func() { _ = pw.Close() }() + mw := testMuxWriter(t) + st := New(ctx, 0, mw) - st := New(ctx, 0, drpcwire.NewWriter(pw, 0)) - - // start the Close call + // Close and Cancel concurrently should not panic or deadlock. errch := make(chan error, 1) go func() { errch <- st.Close() }() - // wait for the close to begin writing - _, err := pr.Read(make([]byte, 1)) - assert.NoError(t, err) - - // cancel the context and close the transport st.Cancel(context.Canceled) - assert.NoError(t, pw.Close()) - - // we should always receive the canceled error - assert.That(t, errors.Is(<-errch, context.Canceled)) -} - -func TestStream_CorkUntilFirstRead(t *testing.T) { - run := func() { - ctx := drpctest.NewTracker(t) - defer ctx.Close() - - var buf bytes.Buffer - st := New(ctx, 0, drpcwire.NewWriter(&buf, 50)) - - // concurrently read and write at the same time. - // we should always see the write happen. - - errch := make(chan error, 3) - ctx.Run(func(ctx context.Context) { - errch <- st.MsgSend([]byte("write"), byteEncoding{}) - }) - ctx.Run(func(ctx context.Context) { - _, err := st.RawRecv() - errch <- err - }) - ctx.Run(func(ctx context.Context) { - errch <- st.HandleFrame(drpcwire.Frame{ - Data: []byte("read"), - ID: drpcwire.ID{Message: 1}, - Kind: drpcwire.KindMessage, - Done: true, - }) - }) - assert.NoError(t, <-errch) - assert.NoError(t, <-errch) - assert.NoError(t, <-errch) - - assert.Equal(t, buf.String(), "\x05\x00\x01\x05write") - } - for i := 0; i < 100; i++ { - run() + // Close returns nil or context.Canceled depending on timing. + err := <-errch + if err != nil { + assert.That(t, errors.Is(err, context.Canceled)) } } -type byteEncoding struct{} - -func (byteEncoding) Marshal(msg drpc.Message) ([]byte, error) { return msg.([]byte), nil } -func (byteEncoding) Unmarshal(buf []byte, msg drpc.Message) error { - *msg.(*[]byte) = append(*msg.(*[]byte), buf...) - return nil -} - func TestStream_PacketBufferReuse(t *testing.T) { run := func() { ctx := drpctest.NewTracker(t) defer ctx.Close() defer ctx.Wait() + mw := testMuxWriter(t) data := make([]byte, 20) mid := uint64(1) - st := New(ctx, 1, drpcwire.NewWriter(io.Discard, 0)) + st := New(ctx, 1, mw) ctx.Run(func(ctx context.Context) { for !st.IsTerminated() { @@ -298,18 +258,14 @@ func TestStream_PacketBufferReuse(t *testing.T) { } } - // // HandleFrame tests // func TestHandleFrame_FirstFrameOnFreshStream(t *testing.T) { - // On the client side, the first message received will have ID 1. But on the - // server side, invoke is consumed by the manager. The first frame reaching - // the stream could have msg > 1 (e.g., msg=2). nextMessageID=1, so 2 > 1 - // makes this a valid frame. + mw := testMuxWriter(t) for _, messageID := range []uint64{1, 2} { - st := New(context.Background(), 1, drpcwire.NewWriter(io.Discard, 0)) + st := New(context.Background(), 1, mw) // Close the packet buffer so KindMessage Put doesn't block. st.pbuf.Close(io.EOF) err := st.HandleFrame(drpcwire.Frame{ @@ -321,7 +277,8 @@ func TestHandleFrame_FirstFrameOnFreshStream(t *testing.T) { // Invoke and InvokeMetadata frames are rejected on an already-created stream. func TestHandleFrame_InvokeOnExistingStream(t *testing.T) { - st := New(context.Background(), 1, drpcwire.NewWriter(io.Discard, 0)) + mw := testMuxWriter(t) + st := New(context.Background(), 1, mw) err := handleFrame(st, drpcwire.KindInvoke, 1) assert.Error(t, err) @@ -330,7 +287,8 @@ func TestHandleFrame_InvokeOnExistingStream(t *testing.T) { } func TestHandleFrame_InvokeMetadataOnExistingStream(t *testing.T) { - st := New(context.Background(), 1, drpcwire.NewWriter(io.Discard, 0)) + mw := testMuxWriter(t) + st := New(context.Background(), 1, mw) err := handleFrame(st, drpcwire.KindInvokeMetadata, 1) assert.Error(t, err) @@ -340,7 +298,8 @@ func TestHandleFrame_InvokeMetadataOnExistingStream(t *testing.T) { // Frames arriving after the stream is terminated are silently ignored. func TestHandleFrame_AfterTerminated(t *testing.T) { - st := New(context.Background(), 1, drpcwire.NewWriter(io.Discard, 0)) + mw := testMuxWriter(t) + st := New(context.Background(), 1, mw) // Terminate the stream via cancel. st.Cancel(context.Canceled) @@ -357,7 +316,8 @@ func TestHandleFrame_MessageDeliveredViaRecv(t *testing.T) { ctx := drpctest.NewTracker(t) defer ctx.Close() - st := New(ctx, 1, drpcwire.NewWriter(io.Discard, 0)) + mw := testMuxWriter(t) + st := New(ctx, 1, mw) // Launch receiver before sending to avoid Put blocking. recv := make(chan []byte, 1) @@ -376,74 +336,3 @@ func TestHandleFrame_MessageDeliveredViaRecv(t *testing.T) { assert.DeepEqual(t, <-recv, []byte("payload")) } - -// -// Write-side tests -// - -func TestRawWrite_NonMessageSingleFrame(t *testing.T) { - // Non-KindMessage kinds must produce a single frame (n=0 in - // rawWriteLocked means default 64KB, effectively no split for - // small payloads). Verify they produce exactly one frame with Done=true. - kinds := []drpcwire.Kind{ - drpcwire.KindInvoke, - drpcwire.KindError, - drpcwire.KindCancel, - drpcwire.KindClose, - drpcwire.KindCloseSend, - drpcwire.KindInvokeMetadata, - } - - for _, kind := range kinds { - var buf bytes.Buffer - st := New(context.Background(), 1, drpcwire.NewWriter(&buf, 0)) - - assert.NoError(t, st.RawWrite(kind, []byte("data"))) - assert.NoError(t, st.RawFlush()) - var err error - - // Parse all frames from the buffer — should be exactly one. - data := buf.Bytes() - var frames []drpcwire.Frame - for len(data) > 0 { - var fr drpcwire.Frame - var ok bool - data, fr, ok, err = drpcwire.ParseFrame(data) - assert.NoError(t, err) - assert.That(t, ok) - frames = append(frames, fr) - } - assert.Equal(t, len(frames), 1) - assert.That(t, frames[0].Done) - assert.Equal(t, frames[0].Kind, kind) - } -} - -func TestRawWrite_MessageRespectsSplitSize(t *testing.T) { - var buf bytes.Buffer - st := NewWithOptions(context.Background(), 1, - drpcwire.NewWriter(&buf, 0), - Options{SplitSize: 5}, - ) - - // "helloworld" is 10 bytes, split at 5 → 2 frames. - assert.NoError(t, st.RawWrite(drpcwire.KindMessage, []byte("helloworld"))) - assert.NoError(t, st.RawFlush()) - var err error - - data := buf.Bytes() - var frames []drpcwire.Frame - for len(data) > 0 { - var fr drpcwire.Frame - var ok bool - data, fr, ok, err = drpcwire.ParseFrame(data) - assert.NoError(t, err) - assert.That(t, ok) - frames = append(frames, fr) - } - assert.Equal(t, len(frames), 2) - assert.That(t, !frames[0].Done) - assert.That(t, frames[1].Done) - assert.DeepEqual(t, frames[0].Data, []byte("hello")) - assert.DeepEqual(t, frames[1].Data, []byte("world")) -} diff --git a/drpcwire/mux_writer.go b/drpcwire/mux_writer.go new file mode 100644 index 0000000..5582023 --- /dev/null +++ b/drpcwire/mux_writer.go @@ -0,0 +1,101 @@ +// Copyright (C) 2026 Cockroach Labs. +// See LICENSE for copying information. + +package drpcwire + +import ( + "io" + "sync" + + "storj.io/drpc" +) + +type MuxWriter struct { + w io.Writer + buf []byte + mu sync.Mutex + cond *sync.Cond + closed bool + onError func(error) + done chan struct{} +} + +var defaultBufferCapacity = 4096 + +func NewMuxWriter(w io.Writer, onError func(error)) *MuxWriter { + mw := &MuxWriter{ + w: w, + buf: make([]byte, 0, defaultBufferCapacity), + onError: onError, + done: make(chan struct{}), + } + mw.cond = sync.NewCond(&mw.mu) + go mw.run() + return mw +} + +func (mw *MuxWriter) run() { + defer close(mw.done) + spare := make([]byte, 0, defaultBufferCapacity) + for { + mw.mu.Lock() + for len(mw.buf) == 0 && !mw.closed { + mw.cond.Wait() + } + + if mw.closed { + mw.mu.Unlock() + return + } + + temp := mw.buf + mw.buf = spare + spare = temp + mw.mu.Unlock() + if _, err := mw.w.Write(spare); err != nil { + mw.mu.Lock() + if mw.closed { + mw.mu.Unlock() + return + } + mw.closed = true + mw.mu.Unlock() + if mw.onError != nil { + mw.onError(err) + } + return + } + + spare = spare[:0] + } +} + +func (mw *MuxWriter) WriteFrame(fr Frame) (err error) { + mw.mu.Lock() + defer mw.mu.Unlock() + if mw.closed { + return drpc.ClosedError.New("mux writer closed") + } + mw.buf = AppendFrame(mw.buf, fr) + mw.cond.Signal() + return nil +} + +func (mw *MuxWriter) Stop() { + mw.mu.Lock() + if mw.closed { + } else { + mw.closed = true + mw.cond.Broadcast() + } + mw.mu.Unlock() +} + +func (mw *MuxWriter) Done() <-chan struct{} { + return mw.done +} + +func (mw *MuxWriter) StopWait() { + mw.Stop() + <-mw.Done() +} diff --git a/drpcwire/mux_writer_test.go b/drpcwire/mux_writer_test.go new file mode 100644 index 0000000..093a255 --- /dev/null +++ b/drpcwire/mux_writer_test.go @@ -0,0 +1,332 @@ +// Copyright (C) 2021 Storj Labs, Inc. +// See LICENSE for copying information. + +package drpcwire + +import ( + "bytes" + "errors" + "io" + "sync" + "testing" + "time" + + "github.com/zeebo/assert" + + "storj.io/drpc" +) + +// blockingWriter blocks in Write until unblock is closed, then returns err. +type blockingWriter struct { + unblock chan struct{} + err error // error to return once unblocked + wrote chan []byte // sends a copy of data on each Write entry +} + +func newBlockingWriter() *blockingWriter { + return &blockingWriter{ + unblock: make(chan struct{}), + wrote: make(chan []byte, 10), + } +} + +func (w *blockingWriter) Write(p []byte) (int, error) { + cp := make([]byte, len(p)) + copy(cp, p) + w.wrote <- cp + <-w.unblock + if w.err != nil { + return 0, w.err + } + return len(p), nil +} + +// failWriter returns err on the nth call to Write (1-indexed). Calls before +// that succeed normally. +type failWriter struct { + n int + count int + err error + buf bytes.Buffer +} + +func newFailWriter(n int, err error) *failWriter { + return &failWriter{n: n, err: err} +} + +func (w *failWriter) Write(p []byte) (int, error) { + w.count++ + if w.count >= w.n { + return 0, w.err + } + return w.buf.Write(p) +} + +func TestMuxWriter(t *testing.T) { + var exp []byte + pr, pw := io.Pipe() + mw := NewMuxWriter(pw, func(error) {}) + + for range 1000 { + fr := RandFrame() + exp = AppendFrame(exp, fr) + assert.NoError(t, mw.WriteFrame(fr)) + } + + // Read exactly len(exp) bytes: this blocks until MuxWriter has drained + // all frames through the pipe. + got := make([]byte, len(exp)) + _, err := io.ReadFull(pr, got) + assert.NoError(t, err) + + // Now stop the writer and close the pipe. + mw.Stop() + <-mw.Done() + pw.Close() + pr.Close() + + assert.That(t, bytes.Equal(exp, got)) +} + +func TestMuxWriter_WriteFrameAfterStop(t *testing.T) { + mw := NewMuxWriter(io.Discard, func(error) {}) + mw.Stop() + <-mw.Done() + + err := mw.WriteFrame(RandFrame()) + assert.Error(t, err) + assert.That(t, drpc.ClosedError.Has(err)) +} + +func TestMuxWriter_ConcurrentWriteFrame(t *testing.T) { + pr, pw := io.Pipe() + mw := NewMuxWriter(pw, func(error) {}) + + const numWriters = 10 + const framesPerWriter = 100 + + // Pre-generate frames and compute total expected bytes so we can use + // io.ReadFull to block until everything has drained (Stop has abort + // semantics, so we can't rely on it to drain). + allFrames := make([][]Frame, numWriters) + var expSize int + for i := range numWriters { + allFrames[i] = make([]Frame, framesPerWriter) + for j := range framesPerWriter { + fr := Frame{ + Data: []byte{byte(j)}, + ID: ID{Stream: uint64(i + 1), Message: uint64(j + 1)}, + Kind: KindMessage, + Done: true, + } + allFrames[i][j] = fr + expSize += len(AppendFrame(nil, fr)) + } + } + + var wg sync.WaitGroup + for i := range numWriters { + wg.Add(1) + go func() { + defer wg.Done() + for j := range framesPerWriter { + assert.NoError(t, mw.WriteFrame(allFrames[i][j])) + } + }() + } + + wg.Wait() + + // Block until all bytes have been drained through the pipe. + got := make([]byte, expSize) + _, err := io.ReadFull(pr, got) + assert.NoError(t, err) + mw.Stop() + <-mw.Done() + pw.Close() + pr.Close() + + // Parse received bytes and count frames. + count := 0 + for len(got) > 0 { + rem, _, ok, err := ParseFrame(got) + assert.NoError(t, err) + assert.That(t, ok) + got = rem + count++ + } + assert.Equal(t, count, numWriters*framesPerWriter) +} + +func TestMuxWriter_WriteErrorCallsOnError(t *testing.T) { + writeErr := errors.New("disk full") + fw := newFailWriter(1, writeErr) + + gotErr := make(chan error, 1) + mw := NewMuxWriter(fw, func(err error) { gotErr <- err }) + + assert.NoError(t, mw.WriteFrame(RandFrame())) + + select { + case err := <-gotErr: + assert.Equal(t, err, writeErr) + case <-time.After(5 * time.Second): + t.Fatal("onError not called") + } + + select { + case <-mw.Done(): + case <-time.After(5 * time.Second): + t.Fatal("Done did not return") + } +} + +// Tests the critical deadlock path from the design doc: +// run() → Write fails → sets closed → onError → Stop() → noop → run() returns. +func TestMuxWriter_OnErrorCallingStopDoesNotDeadlock(t *testing.T) { + writeErr := errors.New("broken pipe") + fw := newFailWriter(1, writeErr) + + var mw *MuxWriter + mw = NewMuxWriter(fw, func(err error) { + // Simulate manager.terminate calling Stop. + mw.Stop() + }) + + assert.NoError(t, mw.WriteFrame(RandFrame())) + + select { + case <-mw.Done(): + case <-time.After(5 * time.Second): + t.Fatal("deadlock: Done did not return") + } +} + +// Tests the manager's two-phase shutdown: close transport to unblock a blocked +// Write, then Stop signals the goroutine to exit. +func TestMuxWriter_BlockedWriteUnblockedByClose(t *testing.T) { + bw := newBlockingWriter() + mw := NewMuxWriter(bw, func(error) {}) + + assert.NoError(t, mw.WriteFrame(RandFrame())) + + // Wait for run() to enter Write. + select { + case <-bw.wrote: + case <-time.After(5 * time.Second): + t.Fatal("run() did not enter Write") + } + + // Simulate terminate: Stop, then unblock the writer (like tr.Close()). + mw.Stop() + bw.err = errors.New("closed") + close(bw.unblock) + + select { + case <-mw.Done(): + case <-time.After(5 * time.Second): + t.Fatal("deadlock: Done did not return") + } +} + +func TestMuxWriter_ConcurrentStop(t *testing.T) { + mw := NewMuxWriter(io.Discard, func(error) {}) + + // Write a frame so the goroutine has work. + assert.NoError(t, mw.WriteFrame(RandFrame())) + + const n = 20 + var wg sync.WaitGroup + wg.Add(n) + for range n { + go func() { + defer wg.Done() + mw.Stop() + }() + } + wg.Wait() + + select { + case <-mw.Done(): + case <-time.After(5 * time.Second): + t.Fatal("Done did not return") + } +} + +// Stop has abort semantics: buffered data is discarded, not drained. +func TestMuxWriter_StopDiscardsBufferedData(t *testing.T) { + bw := newBlockingWriter() + mw := NewMuxWriter(bw, func(error) {}) + + // Write several frames while the writer is blocked on the first Write. + for range 10 { + assert.NoError(t, mw.WriteFrame(RandFrame())) + } + + // Wait for run() to enter Write with the first batch. + select { + case <-bw.wrote: + case <-time.After(5 * time.Second): + t.Fatal("run() did not enter Write") + } + + // More frames accumulate in buf while Write is blocked. + for range 10 { + assert.NoError(t, mw.WriteFrame(RandFrame())) + } + + // Stop without letting the blocked Write complete. + mw.Stop() + bw.err = errors.New("closed") + close(bw.unblock) + + select { + case <-mw.Done(): + case <-time.After(5 * time.Second): + t.Fatal("Done did not return") + } + + // Only the first batch was written; the rest were discarded by Stop. + assert.Equal(t, len(bw.wrote), 0) // no more writes after the first +} + +func TestMuxWriter_WriteFrameDuringActiveDrain(t *testing.T) { + // gatedWriter lets us control when each Write completes. + type gate struct{ ch chan struct{} } + gates := make(chan gate, 10) + + gw := writerFunc(func(p []byte) (int, error) { + g := gate{ch: make(chan struct{})} + gates <- g + <-g.ch + return len(p), nil + }) + + mw := NewMuxWriter(gw, func(error) {}) + + // Batch 1: write a frame, wait for run() to pick it up and block in Write. + fr1 := Frame{Data: []byte("batch1"), ID: ID{Stream: 1, Message: 1}, Kind: KindMessage, Done: true} + assert.NoError(t, mw.WriteFrame(fr1)) + + g1 := <-gates // run() is now blocked in Write for batch 1 + + // Batch 2: write another frame while batch 1 is still draining. + fr2 := Frame{Data: []byte("batch2"), ID: ID{Stream: 1, Message: 2}, Kind: KindMessage, Done: true} + assert.NoError(t, mw.WriteFrame(fr2)) + + // Complete batch 1 write. + close(g1.ch) + + // run() loops, picks up batch 2, enters Write again. + g2 := <-gates + close(g2.ch) + + // Both batches were written. Stop and verify. + mw.Stop() + <-mw.Done() +} + +// writerFunc adapts a function to io.Writer. +type writerFunc func([]byte) (int, error) + +func (f writerFunc) Write(p []byte) (int, error) { return f(p) } diff --git a/drpcwire/writer.go b/drpcwire/writer.go deleted file mode 100644 index fe909cf..0000000 --- a/drpcwire/writer.go +++ /dev/null @@ -1,107 +0,0 @@ -// Copyright (C) 2019 Storj Labs, Inc. -// See LICENSE for copying information. - -package drpcwire - -import ( - "fmt" - "io" - "sync" - "sync/atomic" - - "storj.io/drpc/drpcdebug" -) - -// -// Writer -// - -// Writer is a helper to buffer and write packets and frames to an io.Writer. -type Writer struct { - empty uint32 - w io.Writer - size int - mu sync.Mutex - buf []byte -} - -// NewWriter returns a Writer that will attempt to buffer size data before -// sending it to the io.Writer. -func NewWriter(w io.Writer, size int) *Writer { - if size == 0 { - size = 4 * 1024 - } - - return &Writer{ - w: w, - size: size, - buf: make([]byte, 0, size), - } -} - -func (b *Writer) log(what string, cb func() string) { - if drpcdebug.Enabled { - drpcdebug.Log(func() (_, _, _ string) { return fmt.Sprintf("", b), what, cb() }) - } -} - -// WritePacket writes the packet as a single frame, ignoring any size -// constraints. -func (b *Writer) WritePacket(pkt Packet) (err error) { - return b.WriteFrame(Frame{ - Data: pkt.Data, - ID: pkt.ID, - Kind: pkt.Kind, - Control: pkt.Control, - Done: true, - }) -} - -// Empty returns true if there are no bytes buffered in the writer. -func (b *Writer) Empty() bool { - return atomic.LoadUint32(&b.empty) == 0 -} - -// Reset clears any pending data in the buffer. -func (b *Writer) Reset() *Writer { - b.mu.Lock() - defer b.mu.Unlock() - - b.buf = b.buf[:0] - atomic.StoreUint32(&b.empty, 0) - return b -} - -// WriteFrame appends the frame into the buffer, and if the buffer is larger -// than the configured size, flushes it. -func (b *Writer) WriteFrame(fr Frame) (err error) { - b.mu.Lock() - defer b.mu.Unlock() - - if len(b.buf) == 0 { - atomic.StoreUint32(&b.empty, 1) - } - b.buf = AppendFrame(b.buf, fr) - if len(b.buf) >= b.size { - b.log("FLUSH", func() string { return fmt.Sprintf("buffer: %d > %d", len(b.buf), b.size) }) - _, err = b.w.Write(b.buf) - b.buf = b.buf[:0] - atomic.StoreUint32(&b.empty, 0) - } - return err -} - -// Flush forces a flush of any buffered data to the io.Writer. It is a no-op if -// there is no data in the buffer. -func (b *Writer) Flush() (err error) { - b.mu.Lock() - defer b.mu.Unlock() - - if len(b.buf) > 0 { - _, err = b.w.Write(b.buf) - b.log("FLUSH", func() string { return fmt.Sprintf("explicit: %d", len(b.buf)) }) - b.buf = b.buf[:0] - atomic.StoreUint32(&b.empty, 0) - } - return err -} diff --git a/drpcwire/writer_test.go b/drpcwire/writer_test.go deleted file mode 100644 index 840a616..0000000 --- a/drpcwire/writer_test.go +++ /dev/null @@ -1,32 +0,0 @@ -// Copyright (C) 2021 Storj Labs, Inc. -// See LICENSE for copying information. - -package drpcwire - -import ( - "bytes" - "testing" - - "github.com/zeebo/assert" -) - -func TestWriter(t *testing.T) { - run := func(size int) func(t *testing.T) { - return func(t *testing.T) { - var exp []byte - var got bytes.Buffer - - wr := NewWriter(&got, size) - for i := 0; i < 1000; i++ { - fr := RandFrame() - exp = AppendFrame(exp, fr) - assert.NoError(t, wr.WriteFrame(fr)) - } - assert.NoError(t, wr.Flush()) - assert.That(t, bytes.Equal(exp, got.Bytes())) - } - } - - t.Run("Size 0B", run(0)) - t.Run("Size 1MB", run(1024*1024)) -} diff --git a/internal/grpccompat/benchmark_test.go b/internal/grpccompat/benchmark_test.go index 3bd0af7..c59f59a 100644 --- a/internal/grpccompat/benchmark_test.go +++ b/internal/grpccompat/benchmark_test.go @@ -13,7 +13,6 @@ import ( "google.golang.org/protobuf/proto" "storj.io/drpc/drpcmanager" - "storj.io/drpc/drpcstream" ) var benchmarkImpl = &serviceImpl{ @@ -65,11 +64,7 @@ var benchmarkImpl = &serviceImpl{ } func benchmarkBoth(b *testing.B, fn func(b *testing.B, in *In, client Client)) { - options := drpcmanager.Options{ - Stream: drpcstream.Options{ - ManualFlush: true, - }, - } + options := drpcmanager.Options{} for _, size := range []struct { Name string diff --git a/internal/grpccompat/common_test.go b/internal/grpccompat/common_test.go index 8cc5c77..65bb2c9 100644 --- a/internal/grpccompat/common_test.go +++ b/internal/grpccompat/common_test.go @@ -16,7 +16,6 @@ import ( "testing" "time" - "github.com/zeebo/assert" "github.com/zeebo/errs" "google.golang.org/grpc" "google.golang.org/grpc/codes" diff --git a/internal/integration/cancel_test.go b/internal/integration/cancel_test.go index 03a3860..dec012d 100644 --- a/internal/integration/cancel_test.go +++ b/internal/integration/cancel_test.go @@ -15,7 +15,6 @@ import ( "google.golang.org/grpc/codes" "google.golang.org/grpc/status" - "storj.io/drpc" "storj.io/drpc/drpcconn" "storj.io/drpc/drpcpool" "storj.io/drpc/drpctest" @@ -150,15 +149,6 @@ func TestCancellationPropagation_Stream(t *testing.T) { clientctx.Run(func(ctx context.Context) { stream, _ := cli.Method4(ctx) - // this is a weird case where the rpc does not send or receive or even - // close the stream, and neither does the other side, and so we have to - // explicitly flush the invoke. - type ( - getStreamer interface{ GetStream() drpc.Stream } - rawFlusher interface{ RawFlush() error } - ) - _ = stream.(getStreamer).GetStream().(rawFlusher).RawFlush() - called <- struct{}{} select { case <-stream.Context().Done(): From 36292f16e127a044f027d4a458ac61cf255f3ff4 Mon Sep 17 00:00:00 2001 From: Shubham Dhama Date: Sun, 12 Apr 2026 15:21:35 +0530 Subject: [PATCH 11/15] drpcmanager: fix context cancellation error for unary RPCs Fix a race condition where a unary RPC with a cancelled context could return io.EOF instead of codes.Canceled. Two changes, mirroring how gRPC handles this: 1. Early ctx.Err() check in NewClientStream before creating the stream. 2. Deferred stream.CheckCancelError in doInvoke to convert io.EOF to the cancel error if the stream was cancelled mid-operation. The problem: With multiplexing, each stream gets a manageStream goroutine that watches ctx.Done() and calls SendCancel + Cancel when the context is cancelled. This races with doInvoke, which writes invoke and message frames through the same stream. The race has three outcomes depending on who acquires the stream's write lock first: 1. doInvoke wins the lock, completes all writes, and the RPC succeeds even though it should have been cancelled. 2. SendCancel wins, sets send=io.EOF before doInvoke runs. rawWriteLocked sees send.IsSet() and returns io.EOF. Invoke's ToRPCErr passes io.EOF through unchanged, so the caller gets the wrong error code. 3. doInvoke finishes writes, then MsgRecv sees the cancellation and returns codes.Canceled. This is the correct outcome but only happens by luck of timing. Why this didn't happen before multiplexing: The old single-stream manager used a non-blocking SendCancel that returned (busy=true) when the write lock was held by an in-progress write. With SoftCancel=false (the default), the fallback path was: manageStream calls stream.Cancel(ctx.Err()). The stream is not finished because doInvoke holds the write lock, so the manager calls m.terminate(), which closes the entire transport. The in-flight Writer.Write() fails with an IO error, and checkCancelError sees cancel.IsSet() and returns context.Canceled. The correct error surfaced, but through connection termination. This was fine in single-stream mode where one stream is one connection. With multiplexing, we cannot terminate the entire connection for one stream's cancellation. The new SendCancel blocks on the write lock to guarantee the cancel frame is sent, and that introduced this race. How gRPC handles this (verified against grpc-go source): gRPC uses two mechanisms. First, newAttemptLocked (stream.go:408) checks cs.ctx.Err() before creating the transport stream. This catches the already-cancelled case without allocating resources. Second, for unary RPCs, csAttempt.sendMsg (stream.go:1092) swallows write errors and returns nil when !cs.desc.ClientStreams. The real error always surfaces from RecvMsg, which detects context cancellation via recvBufferReader.readClient (transport.go:239) and returns status.Error(codes.Canceled, ...). This means gRPC never returns io.EOF from a unary RPC because it never short-circuits on a send error. For streaming RPCs, gRPC returns io.EOF from Send() after cancel (the stream is done for writing) and codes.Canceled from Recv() (the actual reason). Our grpccompat tests confirm this by comparing gRPC and DRPC error results for identical cancel scenarios. Our fix: Rather than restructuring doInvoke to swallow send errors like gRPC, we use the stream's existing CheckCancelError mechanism. NewClientStream checks ctx.Err() before creating the stream. This mirrors gRPC's newAttemptLocked check and avoids wasting a stream ID, spawning a goroutine, and allocating stream resources. doInvoke defers stream.CheckCancelError on its return value. If any operation in doInvoke fails because SendCancel won the write lock race (returning io.EOF via the send signal), CheckCancelError replaces it with the cancel signal's error (context.Canceled). This is the same function the stream already uses internally for transport write failures. CheckCancelError is exported (was checkCancelError) so that doInvoke in the drpcconn package can call it. On TOCTOU: The NewClientStream check is technically TOCTOU: the context could be cancelled immediately after the check passes. This is acceptable because Go's context cancellation model is cooperative, not preemptive. The context package provides Done() "for use in select statements," and operations check at natural boundaries rather than continuously. The standard library follows this pattern: http.Client.Do checks between redirect hops, database/sql checks before query execution, and gRPC checks in newAttemptLocked before creating the transport stream. If the context is cancelled mid-operation, manageStream handles cleanup and the deferred CheckCancelError corrects the error code. --- drpcconn/conn.go | 1 + drpcmanager/manager.go | 3 +++ drpcstream/stream.go | 14 +++++++------- 3 files changed, 11 insertions(+), 7 deletions(-) diff --git a/drpcconn/conn.go b/drpcconn/conn.go index b8bbf49..00b5303 100644 --- a/drpcconn/conn.go +++ b/drpcconn/conn.go @@ -89,6 +89,7 @@ func (c *Conn) Invoke(ctx context.Context, rpc string, enc drpc.Encoding, in, ou } func (c *Conn) doInvoke(stream *drpcstream.Stream, enc drpc.Encoding, rpc string, data []byte, metadata []byte, out drpc.Message) (err error) { + defer func() { err = stream.CheckCancelError(err) }() if err := stream.WriteInvoke(rpc, metadata); err != nil { return err } diff --git a/drpcmanager/manager.go b/drpcmanager/manager.go index 75df2af..ed716e0 100644 --- a/drpcmanager/manager.go +++ b/drpcmanager/manager.go @@ -328,6 +328,9 @@ func (m *Manager) Close() error { // NewClientStream starts a stream on the managed transport for use by a client. func (m *Manager) NewClientStream(ctx context.Context, rpc string) (stream *drpcstream.Stream, err error) { + if err := ctx.Err(); err != nil { + return nil, err + } return m.newStream(ctx, m.lastStreamID.Add(1), drpc.StreamKindClient, rpc) } diff --git a/drpcstream/stream.go b/drpcstream/stream.go index 343b482..4f4a214 100644 --- a/drpcstream/stream.go +++ b/drpcstream/stream.go @@ -297,10 +297,10 @@ func (s *Stream) checkFinished() { } } -// checkCancelError will replace the error with one from the cancel signal if it +// CheckCancelError will replace the error with one from the cancel signal if it // is set. This is to prevent errors from reads/writes to a transport after it // has been asynchronously closed due to context cancelation. -func (s *Stream) checkCancelError(err error) error { +func (s *Stream) CheckCancelError(err error) error { if s.sigs.cancel.IsSet() { return s.sigs.cancel.Err() } @@ -401,7 +401,7 @@ func (s *Stream) rawWriteLocked(kind drpcwire.Kind, data []byte) (err error) { s.log("SEND", fr.String) if err := s.wr.WriteFrame(fr); err != nil { - return s.checkCancelError(errs.Wrap(err)) + return s.CheckCancelError(errs.Wrap(err)) } else if fr.Done { return nil } @@ -496,7 +496,7 @@ func (s *Stream) SendError(serr error) (err error) { s.terminate(termError) s.mu.Unlock() - return s.checkCancelError(s.sendPacketLocked(drpcwire.KindError, false, drpcwire.MarshalError(serr))) + return s.CheckCancelError(s.sendPacketLocked(drpcwire.KindError, false, drpcwire.MarshalError(serr))) } // SendCancel terminates the stream and sends a cancel to the remote side. It @@ -519,7 +519,7 @@ func (s *Stream) SendCancel(err error) error { s.terminate(err) s.mu.Unlock() - return s.checkCancelError(s.sendPacketLocked(drpcwire.KindCancel, true, nil)) + return s.CheckCancelError(s.sendPacketLocked(drpcwire.KindCancel, true, nil)) } // Close terminates the stream and sends that the stream has been closed to the @@ -540,7 +540,7 @@ func (s *Stream) Close() (err error) { s.terminate(termClosed) s.mu.Unlock() - return s.checkCancelError(s.sendPacketLocked(drpcwire.KindClose, false, nil)) + return s.CheckCancelError(s.sendPacketLocked(drpcwire.KindClose, false, nil)) } // CloseSend informs the remote that no more messages will be sent. If the remote has @@ -563,7 +563,7 @@ func (s *Stream) CloseSend() (err error) { s.terminateIfBothClosed() s.mu.Unlock() - return s.checkCancelError(s.sendPacketLocked(drpcwire.KindCloseSend, false, nil)) + return s.CheckCancelError(s.sendPacketLocked(drpcwire.KindCloseSend, false, nil)) } // Cancel transitions the stream into a state where all writes to the transport will return From 1e27cbc110377eb54ad2baa86ea3f839e8d07e90 Mon Sep 17 00:00:00 2001 From: Shubham Dhama Date: Wed, 15 Apr 2026 15:02:32 +0530 Subject: [PATCH 12/15] fixup! enable stream multiplexing --- drpcconn/conn_test.go | 4 ++-- drpcmanager/active_streams.go | 10 ++++++---- drpcmanager/active_streams_test.go | 2 +- drpcmanager/manager.go | 6 +++--- drpcmanager/random_test.go | 2 +- drpcstream/stream_test.go | 2 +- drpcwire/mux_writer.go | 19 +++++++------------ drpcwire/mux_writer_test.go | 20 +++++++++----------- 8 files changed, 30 insertions(+), 35 deletions(-) diff --git a/drpcconn/conn_test.go b/drpcconn/conn_test.go index f16c8a3..aadda12 100644 --- a/drpcconn/conn_test.go +++ b/drpcconn/conn_test.go @@ -41,7 +41,7 @@ func TestConn_InvokeFlushesSendClose(t *testing.T) { ctx.Run(func(ctx context.Context) { wr := drpcwire.NewMuxWriter(ps, nil) - defer wr.StopWait() + defer func() { wr.Stop(nil); <-wr.Done() }() rd := drpcwire.NewReader(ps) _, _ = rd.ReadFrame() // Invoke @@ -97,7 +97,7 @@ func TestConn_InvokeSendsGrpcAndDrpcMetadata(t *testing.T) { ctx.Run(func(ctx context.Context) { wr := drpcwire.NewMuxWriter(ps, nil) - defer wr.StopWait() + defer func() { wr.Stop(nil); <-wr.Done() }() rd := drpcwire.NewReader(ps) md, err := rd.ReadFrame() // Metadata diff --git a/drpcmanager/active_streams.go b/drpcmanager/active_streams.go index 86f5548..4ee7dc1 100644 --- a/drpcmanager/active_streams.go +++ b/drpcmanager/active_streams.go @@ -12,9 +12,10 @@ import ( // activeStreams is a thread-safe map of stream IDs to stream objects. // It is used by the Manager to track active streams for lifecycle management. type activeStreams struct { - mu sync.RWMutex - streams map[uint64]*drpcstream.Stream - closed bool + mu sync.RWMutex + streams map[uint64]*drpcstream.Stream + closed bool + closeErr error } func newActiveStreams() *activeStreams { @@ -34,7 +35,7 @@ func (r *activeStreams) Add(id uint64, stream *drpcstream.Stream) error { defer r.mu.Unlock() if r.closed { - return managerClosed.New("add to closed collection") + return r.closeErr } if _, ok := r.streams[id]; ok { return managerClosed.New("duplicate stream id") @@ -73,6 +74,7 @@ func (r *activeStreams) Close(err error) { defer r.mu.Unlock() r.closed = true + r.closeErr = err for id, s := range r.streams { s.Cancel(err) delete(r.streams, id) diff --git a/drpcmanager/active_streams_test.go b/drpcmanager/active_streams_test.go index 842acb7..f463b18 100644 --- a/drpcmanager/active_streams_test.go +++ b/drpcmanager/active_streams_test.go @@ -17,7 +17,7 @@ import ( func testMuxWriter(t *testing.T) *drpcwire.MuxWriter { mw := drpcwire.NewMuxWriter(io.Discard, func(error) {}) - t.Cleanup(func() { mw.Stop(); <-mw.Done() }) + t.Cleanup(func() { mw.Stop(nil); <-mw.Done() }) return mw } diff --git a/drpcmanager/manager.go b/drpcmanager/manager.go index ed716e0..2f13d88 100644 --- a/drpcmanager/manager.go +++ b/drpcmanager/manager.go @@ -152,14 +152,14 @@ func (m *Manager) log(what string, cb func() string) { func (m *Manager) terminate(err error) { if m.sigs.term.Set(err) { m.log("TERM", func() string { return fmt.Sprint(err) }) - m.wr.Stop() - m.sigs.tport.Set(m.tr.Close()) if errors.Is(err, io.EOF) { err = context.Canceled if m.kind == Client { err = drpc.ClosedError.New("connection closed") } } + m.wr.Stop(err) + m.sigs.tport.Set(m.tr.Close()) m.streams.Close(err) } } @@ -191,7 +191,7 @@ func (m *Manager) manageReader() { switch { // if the packet is for an active stream, deliver it. - case ok && stream != nil: + case ok: if err := stream.HandleFrame(incomingFrame); err != nil { m.terminate(managerClosed.Wrap(err)) return diff --git a/drpcmanager/random_test.go b/drpcmanager/random_test.go index d67a026..af63bf4 100644 --- a/drpcmanager/random_test.go +++ b/drpcmanager/random_test.go @@ -201,7 +201,7 @@ func runRandomized(t *testing.T, prog []byte, r runner) { defer func() { _ = ps.Close() }() wr := drpcwire.NewMuxWriter(pc, func(error) {}) - defer func() { wr.Stop(); <-wr.Done() }() + defer func() { wr.Stop(nil); <-wr.Done() }() man := New(ps, Server) defer func() { _ = man.Close() }() diff --git a/drpcstream/stream_test.go b/drpcstream/stream_test.go index 9406936..cb087e3 100644 --- a/drpcstream/stream_test.go +++ b/drpcstream/stream_test.go @@ -23,7 +23,7 @@ import ( func testMuxWriter(t *testing.T) *drpcwire.MuxWriter { t.Helper() mw := drpcwire.NewMuxWriter(io.Discard, func(error) {}) - t.Cleanup(mw.StopWait) + t.Cleanup(func() { mw.Stop(nil); <-mw.Done() }) return mw } diff --git a/drpcwire/mux_writer.go b/drpcwire/mux_writer.go index 5582023..c12bc56 100644 --- a/drpcwire/mux_writer.go +++ b/drpcwire/mux_writer.go @@ -6,8 +6,6 @@ package drpcwire import ( "io" "sync" - - "storj.io/drpc" ) type MuxWriter struct { @@ -15,7 +13,8 @@ type MuxWriter struct { buf []byte mu sync.Mutex cond *sync.Cond - closed bool + closed bool + closeErr error onError func(error) done chan struct{} } @@ -59,6 +58,7 @@ func (mw *MuxWriter) run() { return } mw.closed = true + mw.closeErr = err mw.mu.Unlock() if mw.onError != nil { mw.onError(err) @@ -74,18 +74,18 @@ func (mw *MuxWriter) WriteFrame(fr Frame) (err error) { mw.mu.Lock() defer mw.mu.Unlock() if mw.closed { - return drpc.ClosedError.New("mux writer closed") + return mw.closeErr } mw.buf = AppendFrame(mw.buf, fr) mw.cond.Signal() return nil } -func (mw *MuxWriter) Stop() { +func (mw *MuxWriter) Stop(err error) { mw.mu.Lock() - if mw.closed { - } else { + if !mw.closed { mw.closed = true + mw.closeErr = err mw.cond.Broadcast() } mw.mu.Unlock() @@ -94,8 +94,3 @@ func (mw *MuxWriter) Stop() { func (mw *MuxWriter) Done() <-chan struct{} { return mw.done } - -func (mw *MuxWriter) StopWait() { - mw.Stop() - <-mw.Done() -} diff --git a/drpcwire/mux_writer_test.go b/drpcwire/mux_writer_test.go index 093a255..61a60fe 100644 --- a/drpcwire/mux_writer_test.go +++ b/drpcwire/mux_writer_test.go @@ -12,8 +12,6 @@ import ( "time" "github.com/zeebo/assert" - - "storj.io/drpc" ) // blockingWriter blocks in Write until unblock is closed, then returns err. @@ -80,7 +78,7 @@ func TestMuxWriter(t *testing.T) { assert.NoError(t, err) // Now stop the writer and close the pipe. - mw.Stop() + mw.Stop(errors.New("stopped")) <-mw.Done() pw.Close() pr.Close() @@ -90,12 +88,12 @@ func TestMuxWriter(t *testing.T) { func TestMuxWriter_WriteFrameAfterStop(t *testing.T) { mw := NewMuxWriter(io.Discard, func(error) {}) - mw.Stop() + mw.Stop(errors.New("stopped")) <-mw.Done() err := mw.WriteFrame(RandFrame()) assert.Error(t, err) - assert.That(t, drpc.ClosedError.Has(err)) + assert.Equal(t, err.Error(), "stopped") } func TestMuxWriter_ConcurrentWriteFrame(t *testing.T) { @@ -141,7 +139,7 @@ func TestMuxWriter_ConcurrentWriteFrame(t *testing.T) { got := make([]byte, expSize) _, err := io.ReadFull(pr, got) assert.NoError(t, err) - mw.Stop() + mw.Stop(errors.New("stopped")) <-mw.Done() pw.Close() pr.Close() @@ -190,7 +188,7 @@ func TestMuxWriter_OnErrorCallingStopDoesNotDeadlock(t *testing.T) { var mw *MuxWriter mw = NewMuxWriter(fw, func(err error) { // Simulate manager.terminate calling Stop. - mw.Stop() + mw.Stop(errors.New("stopped")) }) assert.NoError(t, mw.WriteFrame(RandFrame())) @@ -218,7 +216,7 @@ func TestMuxWriter_BlockedWriteUnblockedByClose(t *testing.T) { } // Simulate terminate: Stop, then unblock the writer (like tr.Close()). - mw.Stop() + mw.Stop(errors.New("stopped")) bw.err = errors.New("closed") close(bw.unblock) @@ -241,7 +239,7 @@ func TestMuxWriter_ConcurrentStop(t *testing.T) { for range n { go func() { defer wg.Done() - mw.Stop() + mw.Stop(errors.New("stopped")) }() } wg.Wait() @@ -276,7 +274,7 @@ func TestMuxWriter_StopDiscardsBufferedData(t *testing.T) { } // Stop without letting the blocked Write complete. - mw.Stop() + mw.Stop(errors.New("stopped")) bw.err = errors.New("closed") close(bw.unblock) @@ -322,7 +320,7 @@ func TestMuxWriter_WriteFrameDuringActiveDrain(t *testing.T) { close(g2.ch) // Both batches were written. Stop and verify. - mw.Stop() + mw.Stop(errors.New("stopped")) <-mw.Done() } From da196ef0ac8ff6e2e71361d7e15dc705586df36b Mon Sep 17 00:00:00 2001 From: Shubham Dhama Date: Fri, 17 Apr 2026 13:51:20 +0530 Subject: [PATCH 13/15] rename invokesAssembler to pendingStreams --- drpcmanager/manager.go | 37 ++++++++++++++++++++++--------------- 1 file changed, 22 insertions(+), 15 deletions(-) diff --git a/drpcmanager/manager.go b/drpcmanager/manager.go index 2f13d88..6c77e17 100644 --- a/drpcmanager/manager.go +++ b/drpcmanager/manager.go @@ -66,9 +66,10 @@ type Manager struct { pdone drpcsignal.Chan // signals when NewServerStream has registered the new stream invokes chan invokeInfo // completed invoke info from manageReader to NewServerStream - // invokesAssembler is owned by the manageReader goroutine, used in - // handleInvokeFrame. - invokesAssembler map[uint64]*invokeAssembler + // pendingStreams is owned by the manageReader goroutine, used in + // handleInvokeFrame. It tracks streams that are being assembled from + // invoke/metadata frames but haven't been fully created yet. + pendingStreams map[uint64]*pendingStream sigs struct { term drpcsignal.Signal // set when the manager should start terminating @@ -87,7 +88,10 @@ const ( Server ) -type invokeAssembler struct { +// pendingStream accumulates invoke and metadata frames for a stream that is +// being set up but hasn't been fully created yet. Once the invoke packet +// arrives, the pending stream is forwarded to NewServerStream. +type pendingStream struct { metadata map[string]string // accumulated invoke metadata pa drpcwire.PacketAssembler // assembles invoke/metadata frames into packets } @@ -123,7 +127,7 @@ func NewWithOptions(tr drpc.Transport, kind ManagerKind, opts Options) *Manager // new server stream without having to coordinate with manageReader. m.pdone.Make(1) - m.invokesAssembler = make(map[uint64]*invokeAssembler) + m.pendingStreams = make(map[uint64]*pendingStream) m.streams = newActiveStreams() @@ -211,15 +215,15 @@ func (m *Manager) manageReader() { } // handleInvokeFrame assembles invoke/metadata frames into complete packets and -// forwards the finished invoke info to NewServerStream via m.newServerStreamInfo. -// Metadata packets are accumulated; the invoke packet triggers the send. +// forwards the finished invoke info to NewServerStream. Metadata packets are +// accumulated; the invoke packet triggers the send. func (m *Manager) handleInvokeFrame(fr drpcwire.Frame) error { - ia, ok := m.invokesAssembler[fr.ID.Stream] + ps, ok := m.pendingStreams[fr.ID.Stream] if !ok { - ia = &invokeAssembler{pa: drpcwire.NewPacketAssembler()} - m.invokesAssembler[fr.ID.Stream] = ia + ps = &pendingStream{pa: drpcwire.NewPacketAssembler()} + m.pendingStreams[fr.ID.Stream] = ps } - pkt, packetReady, err := ia.pa.AppendFrame(fr) + pkt, packetReady, err := ps.pa.AppendFrame(fr) if err != nil { return err } @@ -233,19 +237,19 @@ func (m *Manager) handleInvokeFrame(fr drpcwire.Frame) error { if err != nil { return err } - ia.metadata = meta + ps.metadata = meta return nil } // Invoke packet completes the sequence. Send to NewServerStream. select { - case m.invokes <- invokeInfo{sid: pkt.ID.Stream, data: pkt.Data, metadata: ia.metadata}: + case m.invokes <- invokeInfo{sid: pkt.ID.Stream, data: pkt.Data, metadata: ps.metadata}: // Wait for NewServerStream to finish stream creation before reading the // next frame. This guarantees curr is set for subsequent non-invoke // packets. m.pdone.Recv() - // TODO: reuse invoke assembler - delete(m.invokesAssembler, fr.ID.Stream) + // TODO: reuse pending stream + delete(m.pendingStreams, fr.ID.Stream) case <-m.sigs.term.Signal(): } return nil @@ -310,6 +314,9 @@ func (m *Manager) Closed() <-chan struct{} { // Unblocked returns a channel that is closed when the manager is no longer // blocked. With multiplexing, multiple streams run concurrently and this // channel is always closed immediately. +// +// TODO(shubham): audit whether this is still useful in a multiplexing world. +// The only meaningful caller is Pool.Take. func (m *Manager) Unblocked() <-chan struct{} { return closedCh } From ef4bf7888517b5fd18d487e0ff69732528c964f3 Mon Sep 17 00:00:00 2001 From: Shubham Dhama Date: Fri, 17 Apr 2026 14:39:41 +0530 Subject: [PATCH 14/15] rename packetQueue to ringBuffer --- drpcstream/packet_queue.go | 115 ---------------- drpcstream/packet_queue_test.go | 228 -------------------------------- drpcstream/ring_buffer.go | 116 ++++++++++++++++ drpcstream/ring_buffer_test.go | 228 ++++++++++++++++++++++++++++++++ drpcstream/stream.go | 20 +-- drpcstream/stream_test.go | 4 +- 6 files changed, 356 insertions(+), 355 deletions(-) delete mode 100644 drpcstream/packet_queue.go delete mode 100644 drpcstream/packet_queue_test.go create mode 100644 drpcstream/ring_buffer.go create mode 100644 drpcstream/ring_buffer_test.go diff --git a/drpcstream/packet_queue.go b/drpcstream/packet_queue.go deleted file mode 100644 index bc26dec..0000000 --- a/drpcstream/packet_queue.go +++ /dev/null @@ -1,115 +0,0 @@ -// Copyright (C) 2026 Cockroach Labs. -// See LICENSE for copying information. - -package drpcstream - -import "sync" - -// defaultPacketQueueCapacity is the number of messages the packet queue can -// hold before the producer blocks. This decouples the transport reader -// (manageReader) from the consumer (RPC handler), preventing a slow handler -// from blocking frame delivery to other streams. -// -// TODO: benchmark whether power-of-2 masking improves performance over modulo. -const defaultPacketQueueCapacity = 256 - -// packetQueue is a bounded single-producer / single-consumer queue for -// assembled packet data. It sits between manageReader (producer, calls Put) -// and the application goroutine (consumer, calls Get/Done). -// -// It is implemented as a ring buffer with mutex + cond synchronization. -// Slots are pre-allocated and reused: each slot's backing array grows via -// append to fit incoming data, then stays at its high-water mark, avoiding -// per-message allocation in steady state. -// -// After Close, Get drains any queued messages before returning the close -// error. This ensures graceful shutdown (KindClose/KindCloseSend) delivers -// all buffered data to the consumer. -type packetQueue struct { - mu sync.Mutex - cond sync.Cond - - buf [][]byte // ring buffer of byte slices - head int // next write position (producer) - tail int // next read position (consumer) - count int // number of occupied slots - - held bool // true between Get and Done - err error // terminal error, set by Close -} - -func (pq *packetQueue) init() { - pq.cond.L = &pq.mu - pq.buf = make([][]byte, defaultPacketQueueCapacity) -} - -// Put copies data into the next write slot. If the queue is full, it blocks -// until a slot is freed or the queue is closed. If the queue is closed, Put -// returns silently without enqueuing. -func (pq *packetQueue) Put(data []byte) { - pq.mu.Lock() - defer pq.mu.Unlock() - - for pq.count == len(pq.buf) && pq.err == nil { - pq.cond.Wait() - } - if pq.err != nil { - return - } - - pq.buf[pq.head] = append(pq.buf[pq.head][:0], data...) - pq.head = (pq.head + 1) % len(pq.buf) - pq.count++ - pq.cond.Broadcast() -} - -// Get returns the data from the next read slot. If the queue is empty, it -// blocks until data is available or the queue is closed. The returned slice -// is valid until Done is called. -func (pq *packetQueue) Get() ([]byte, error) { - pq.mu.Lock() - defer pq.mu.Unlock() - - for pq.count == 0 && pq.err == nil { - pq.cond.Wait() - } - if pq.count == 0 { - // Queue is empty and closed — return the close error. - return nil, pq.err - } - - // Return data even if closed, draining pending items first. - pq.held = true - return pq.buf[pq.tail], nil -} - -// Done advances the read pointer, making the slot available for reuse. -// It must be called exactly once after each successful Get. -func (pq *packetQueue) Done() { - pq.mu.Lock() - defer pq.mu.Unlock() - - pq.tail = (pq.tail + 1) % len(pq.buf) - pq.count-- - pq.held = false - pq.cond.Broadcast() -} - -// Close marks the queue as closed with the given error. All blocked Put and -// Get calls are woken and will return. Close waits for any in-progress -// Get/Done pair to complete before setting the error. Subsequent calls are -// no-ops. -func (pq *packetQueue) Close(err error) { - pq.mu.Lock() - defer pq.mu.Unlock() - - for pq.held { - pq.cond.Wait() - } - if pq.err != nil { - return - } - - pq.err = err - pq.cond.Broadcast() -} diff --git a/drpcstream/packet_queue_test.go b/drpcstream/packet_queue_test.go deleted file mode 100644 index 7af0a7c..0000000 --- a/drpcstream/packet_queue_test.go +++ /dev/null @@ -1,228 +0,0 @@ -// Copyright (C) 2026 Cockroach Labs. -// See LICENSE for copying information. - -package drpcstream - -import ( - "io" - "sync" - "testing" - - "github.com/zeebo/assert" -) - -func TestPacketQueue_PutGet(t *testing.T) { - var pq packetQueue - pq.init() - - pq.Put([]byte("hello")) - - data, err := pq.Get() - assert.NoError(t, err) - assert.DeepEqual(t, data, []byte("hello")) - pq.Done() -} - -func TestPacketQueue_FIFO(t *testing.T) { - var pq packetQueue - pq.init() - - pq.Put([]byte("first")) - pq.Put([]byte("second")) - pq.Put([]byte("third")) - - for _, want := range []string{"first", "second", "third"} { - data, err := pq.Get() - assert.NoError(t, err) - assert.DeepEqual(t, data, []byte(want)) - pq.Done() - } -} - -func TestPacketQueue_GetBlocksUntilPut(t *testing.T) { - var pq packetQueue - pq.init() - - got := make(chan []byte, 1) - go func() { - data, err := pq.Get() - assert.NoError(t, err) - got <- data - }() - - pq.Put([]byte("delayed")) - assert.DeepEqual(t, <-got, []byte("delayed")) - pq.Done() -} - -func TestPacketQueue_PutBlocksWhenFull(t *testing.T) { - var pq packetQueue - pq.cond.L = &pq.mu - pq.buf = make([][]byte, 2) // capacity 2 - - pq.Put([]byte("a")) - pq.Put([]byte("b")) - - // Third put should block until we drain one. - done := make(chan struct{}) - go func() { - pq.Put([]byte("c")) - close(done) - }() - - // Drain one slot. - data, err := pq.Get() - assert.NoError(t, err) - assert.DeepEqual(t, data, []byte("a")) - pq.Done() - - // Now the blocked Put should complete. - <-done - - // Verify remaining items. - data, err = pq.Get() - assert.NoError(t, err) - assert.DeepEqual(t, data, []byte("b")) - pq.Done() - - data, err = pq.Get() - assert.NoError(t, err) - assert.DeepEqual(t, data, []byte("c")) - pq.Done() -} - -func TestPacketQueue_CloseUnblocksGet(t *testing.T) { - var pq packetQueue - pq.init() - - errch := make(chan error, 1) - go func() { - _, err := pq.Get() - errch <- err - }() - - pq.Close(io.EOF) - assert.Equal(t, <-errch, io.EOF) -} - -func TestPacketQueue_CloseUnblocksPut(t *testing.T) { - var pq packetQueue - pq.cond.L = &pq.mu - pq.buf = make([][]byte, 1) // capacity 1 - - pq.Put([]byte("fill")) - - done := make(chan struct{}) - go func() { - pq.Put([]byte("blocked")) - close(done) - }() - - pq.Close(io.EOF) - <-done -} - -func TestPacketQueue_CloseDrainsQueued(t *testing.T) { - var pq packetQueue - pq.init() - - pq.Put([]byte("queued")) - pq.Close(io.EOF) - - // Get returns the queued data first. - data, err := pq.Get() - assert.NoError(t, err) - assert.DeepEqual(t, data, []byte("queued")) - pq.Done() - - // Next Get returns the close error. - data, err = pq.Get() - assert.Nil(t, data) - assert.Equal(t, err, io.EOF) -} - -func TestPacketQueue_CloseIdempotent(t *testing.T) { - var pq packetQueue - pq.init() - - pq.Close(io.EOF) - pq.Close(io.ErrUnexpectedEOF) // should not overwrite - - _, err := pq.Get() - assert.Equal(t, err, io.EOF) // original error preserved -} - -func TestPacketQueue_PutAfterClose(t *testing.T) { - var pq packetQueue - pq.init() - - pq.Close(io.EOF) - pq.Put([]byte("dropped")) // should not panic or block -} - -func TestPacketQueue_SlotReuse(t *testing.T) { - var pq packetQueue - pq.cond.L = &pq.mu - pq.buf = make([][]byte, 2) - - // Fill and drain a few rounds to exercise slot reuse. - for round := 0; round < 5; round++ { - pq.Put([]byte("data")) - data, err := pq.Get() - assert.NoError(t, err) - assert.DeepEqual(t, data, []byte("data")) - pq.Done() - } -} - -func TestPacketQueue_CloseWaitsForHeld(t *testing.T) { - var pq packetQueue - pq.init() - - pq.Put([]byte("msg")) - - // Get the data but don't call Done yet. - data, err := pq.Get() - assert.NoError(t, err) - assert.DeepEqual(t, data, []byte("msg")) - - closed := make(chan struct{}) - go func() { - pq.Close(io.EOF) - close(closed) - }() - - // Close should be blocked because held is true. - // Call Done to release it. - pq.Done() - <-closed -} - -func TestPacketQueue_ConcurrentProducerConsumer(t *testing.T) { - var pq packetQueue - pq.init() - - const n = 1000 - var wg sync.WaitGroup - wg.Add(2) - - go func() { - defer wg.Done() - for i := 0; i < n; i++ { - pq.Put([]byte{byte(i)}) - } - }() - - go func() { - defer wg.Done() - for i := 0; i < n; i++ { - data, err := pq.Get() - assert.NoError(t, err) - assert.Equal(t, data[0], byte(i)) - pq.Done() - } - }() - - wg.Wait() - pq.Close(io.EOF) -} diff --git a/drpcstream/ring_buffer.go b/drpcstream/ring_buffer.go new file mode 100644 index 0000000..5cb620a --- /dev/null +++ b/drpcstream/ring_buffer.go @@ -0,0 +1,116 @@ +// Copyright (C) 2026 Cockroach Labs. +// See LICENSE for copying information. + +package drpcstream + +import "sync" + +// defaultRingBufferCapacity is the number of messages the ring buffer can +// hold before the producer blocks. This decouples the transport reader +// (manageReader) from the consumer (RPC handler), preventing a slow handler +// from blocking frame delivery to other streams. +// +// TODO: benchmark whether power-of-2 masking improves performance over modulo. +const defaultRingBufferCapacity = 256 + +// ringBuffer is a bounded single-producer / single-consumer FIFO queue for +// assembled packet data. It sits between manageReader (producer, calls +// Enqueue) and the application goroutine (consumer, calls Dequeue/Done). +// +// Slots are pre-allocated and reused: each slot's backing array grows via +// append to fit incoming data, then stays at its high-water mark, avoiding +// per-message allocation in steady state. +// +// After Close, Dequeue drains any queued messages before returning the close +// error. This ensures graceful shutdown (KindClose/KindCloseSend) delivers +// all buffered data to the consumer. +type ringBuffer struct { + mu sync.Mutex + cond sync.Cond + + buf [][]byte // ring of byte slices + head int // next write position (producer) + tail int // next read position (consumer) + count int // number of occupied slots + + held bool // true between Dequeue and Done + err error // terminal error, set by Close +} + +func (rb *ringBuffer) init() { + rb.cond.L = &rb.mu + rb.buf = make([][]byte, defaultRingBufferCapacity) +} + +// Enqueue copies data into the next write slot. If the buffer is full, it +// blocks until a slot is freed or the buffer is closed. If the buffer is +// closed, Enqueue returns silently without enqueuing. +func (rb *ringBuffer) Enqueue(data []byte) { + rb.mu.Lock() + defer rb.mu.Unlock() + + for rb.count == len(rb.buf) && rb.err == nil { + rb.cond.Wait() + } + if rb.err != nil { + return + } + + rb.buf[rb.head] = append(rb.buf[rb.head][:0], data...) + rb.head = (rb.head + 1) % len(rb.buf) + rb.count++ + rb.cond.Broadcast() +} + +// Dequeue returns the data from the next read slot. If the buffer is empty, +// it blocks until data is available or the buffer is closed. The returned +// slice is valid until Done is called. +func (rb *ringBuffer) Dequeue() ([]byte, error) { + rb.mu.Lock() + defer rb.mu.Unlock() + + for rb.count == 0 && rb.err == nil { + rb.cond.Wait() + } + if rb.count == 0 && rb.err != nil { + return nil, rb.err + } + + rb.held = true + return rb.buf[rb.tail], nil +} + +// Done advances the read pointer, making the slot available for reuse. +// It must be called exactly once after each successful Dequeue. +// +// TODO(shubham): remove this method once a shared buffer pool is introduced. +// With a pool, Dequeue will advance the tail immediately and the caller will +// return the buffer to the pool directly. +func (rb *ringBuffer) Done() { + rb.mu.Lock() + defer rb.mu.Unlock() + + rb.tail = (rb.tail + 1) % len(rb.buf) + rb.count-- + rb.held = false + rb.cond.Broadcast() +} + +// Close marks the buffer as closed with the given error. All blocked Enqueue +// and Dequeue calls are woken and will return. Close waits for any in-progress +// Dequeue/Done pair to complete before setting the error. Subsequent calls are +// no-ops. +func (rb *ringBuffer) Close(err error) { + rb.mu.Lock() + defer rb.mu.Unlock() + + for rb.held { + rb.cond.Wait() + } + if rb.err != nil { + return + } + + rb.err = err + rb.cond.Broadcast() +} diff --git a/drpcstream/ring_buffer_test.go b/drpcstream/ring_buffer_test.go new file mode 100644 index 0000000..8be9c58 --- /dev/null +++ b/drpcstream/ring_buffer_test.go @@ -0,0 +1,228 @@ +// Copyright (C) 2026 Cockroach Labs. +// See LICENSE for copying information. + +package drpcstream + +import ( + "io" + "sync" + "testing" + + "github.com/zeebo/assert" +) + +func TestRingBuffer_EnqueueDequeue(t *testing.T) { + var rb ringBuffer + rb.init() + + rb.Enqueue([]byte("hello")) + + data, err := rb.Dequeue() + assert.NoError(t, err) + assert.DeepEqual(t, data, []byte("hello")) + rb.Done() +} + +func TestRingBuffer_FIFO(t *testing.T) { + var rb ringBuffer + rb.init() + + rb.Enqueue([]byte("first")) + rb.Enqueue([]byte("second")) + rb.Enqueue([]byte("third")) + + for _, want := range []string{"first", "second", "third"} { + data, err := rb.Dequeue() + assert.NoError(t, err) + assert.DeepEqual(t, data, []byte(want)) + rb.Done() + } +} + +func TestRingBuffer_DequeueBlocksUntilEnqueue(t *testing.T) { + var rb ringBuffer + rb.init() + + got := make(chan []byte, 1) + go func() { + data, err := rb.Dequeue() + assert.NoError(t, err) + got <- data + }() + + rb.Enqueue([]byte("delayed")) + assert.DeepEqual(t, <-got, []byte("delayed")) + rb.Done() +} + +func TestRingBuffer_EnqueueBlocksWhenFull(t *testing.T) { + var rb ringBuffer + rb.cond.L = &rb.mu + rb.buf = make([][]byte, 2) // capacity 2 + + rb.Enqueue([]byte("a")) + rb.Enqueue([]byte("b")) + + // Third enqueue should block until we drain one. + done := make(chan struct{}) + go func() { + rb.Enqueue([]byte("c")) + close(done) + }() + + // Drain one slot. + data, err := rb.Dequeue() + assert.NoError(t, err) + assert.DeepEqual(t, data, []byte("a")) + rb.Done() + + // Now the blocked Enqueue should complete. + <-done + + // Verify remaining items. + data, err = rb.Dequeue() + assert.NoError(t, err) + assert.DeepEqual(t, data, []byte("b")) + rb.Done() + + data, err = rb.Dequeue() + assert.NoError(t, err) + assert.DeepEqual(t, data, []byte("c")) + rb.Done() +} + +func TestRingBuffer_CloseUnblocksDequeue(t *testing.T) { + var rb ringBuffer + rb.init() + + errch := make(chan error, 1) + go func() { + _, err := rb.Dequeue() + errch <- err + }() + + rb.Close(io.EOF) + assert.Equal(t, <-errch, io.EOF) +} + +func TestRingBuffer_CloseUnblocksEnqueue(t *testing.T) { + var rb ringBuffer + rb.cond.L = &rb.mu + rb.buf = make([][]byte, 1) // capacity 1 + + rb.Enqueue([]byte("fill")) + + done := make(chan struct{}) + go func() { + rb.Enqueue([]byte("blocked")) + close(done) + }() + + rb.Close(io.EOF) + <-done +} + +func TestRingBuffer_CloseDrainsQueued(t *testing.T) { + var rb ringBuffer + rb.init() + + rb.Enqueue([]byte("queued")) + rb.Close(io.EOF) + + // Dequeue returns the queued data first. + data, err := rb.Dequeue() + assert.NoError(t, err) + assert.DeepEqual(t, data, []byte("queued")) + rb.Done() + + // Next Dequeue returns the close error. + data, err = rb.Dequeue() + assert.Nil(t, data) + assert.Equal(t, err, io.EOF) +} + +func TestRingBuffer_CloseIdempotent(t *testing.T) { + var rb ringBuffer + rb.init() + + rb.Close(io.EOF) + rb.Close(io.ErrUnexpectedEOF) // should not overwrite + + _, err := rb.Dequeue() + assert.Equal(t, err, io.EOF) // original error preserved +} + +func TestRingBuffer_EnqueueAfterClose(t *testing.T) { + var rb ringBuffer + rb.init() + + rb.Close(io.EOF) + rb.Enqueue([]byte("dropped")) // should not panic or block +} + +func TestRingBuffer_SlotReuse(t *testing.T) { + var rb ringBuffer + rb.cond.L = &rb.mu + rb.buf = make([][]byte, 2) + + // Fill and drain a few rounds to exercise slot reuse. + for round := 0; round < 5; round++ { + rb.Enqueue([]byte("data")) + data, err := rb.Dequeue() + assert.NoError(t, err) + assert.DeepEqual(t, data, []byte("data")) + rb.Done() + } +} + +func TestRingBuffer_CloseWaitsForHeld(t *testing.T) { + var rb ringBuffer + rb.init() + + rb.Enqueue([]byte("msg")) + + // Dequeue the data but don't call Done yet. + data, err := rb.Dequeue() + assert.NoError(t, err) + assert.DeepEqual(t, data, []byte("msg")) + + closed := make(chan struct{}) + go func() { + rb.Close(io.EOF) + close(closed) + }() + + // Close should be blocked because held is true. + // Call Done to release it. + rb.Done() + <-closed +} + +func TestRingBuffer_ConcurrentProducerConsumer(t *testing.T) { + var rb ringBuffer + rb.init() + + const n = 1000 + var wg sync.WaitGroup + wg.Add(2) + + go func() { + defer wg.Done() + for i := 0; i < n; i++ { + rb.Enqueue([]byte{byte(i)}) + } + }() + + go func() { + defer wg.Done() + for i := 0; i < n; i++ { + data, err := rb.Dequeue() + assert.NoError(t, err) + assert.Equal(t, data[0], byte(i)) + rb.Done() + } + }() + + wg.Wait() + rb.Close(io.EOF) +} diff --git a/drpcstream/stream.go b/drpcstream/stream.go index 4f4a214..421a126 100644 --- a/drpcstream/stream.go +++ b/drpcstream/stream.go @@ -53,7 +53,7 @@ type Stream struct { id drpcwire.ID wr *drpcwire.MuxWriter - pbuf packetQueue + recvQueue ringBuffer wbuf []byte mu sync.Mutex // protects state transitions @@ -113,7 +113,7 @@ func NewWithOptions(ctx context.Context, sid uint64, wr *drpcwire.MuxWriter, opt } // initialize the packet buffer - s.pbuf.init() + s.recvQueue.init() return s } @@ -226,7 +226,7 @@ func (s *Stream) handlePacket(pkt drpcwire.Packet) (err error) { s.log("HANDLE", pkt.String) if pkt.Kind == drpcwire.KindMessage { - s.pbuf.Put(pkt.Data) + s.recvQueue.Enqueue(pkt.Data) return nil } @@ -254,13 +254,13 @@ func (s *Stream) handlePacket(pkt drpcwire.Packet) (err error) { case drpcwire.KindClose: s.sigs.recv.Set(io.EOF) - s.pbuf.Close(io.EOF) + s.recvQueue.Close(io.EOF) s.terminate(drpc.ClosedError.New("remote closed the stream")) return nil case drpcwire.KindCloseSend: s.sigs.recv.Set(io.EOF) - s.pbuf.Close(io.EOF) + s.recvQueue.Close(io.EOF) s.terminateIfBothClosed() return nil @@ -346,7 +346,7 @@ func (s *Stream) terminate(err error) { s.sigs.send.Set(err) s.sigs.recv.Set(err) s.sigs.term.Set(err) - s.pbuf.Close(err) + s.recvQueue.Close(err) s.checkFinished() } @@ -414,12 +414,12 @@ func (s *Stream) RawRecv() (data []byte, err error) { s.read.Lock() defer s.read.Unlock() - data, err = s.pbuf.Get() + data, err = s.recvQueue.Dequeue() if err != nil { return nil, err } data = append([]byte(nil), data...) - s.pbuf.Done() + s.recvQueue.Done() return data, nil } @@ -456,12 +456,12 @@ func (s *Stream) MsgRecv(msg drpc.Message, enc drpc.Encoding) (err error) { s.read.Lock() defer s.read.Unlock() - data, err := s.pbuf.Get() + data, err := s.recvQueue.Dequeue() if err != nil { return err } err = enc.Unmarshal(data, msg) - s.pbuf.Done() + s.recvQueue.Done() return err } diff --git a/drpcstream/stream_test.go b/drpcstream/stream_test.go index cb087e3..9cee8dd 100644 --- a/drpcstream/stream_test.go +++ b/drpcstream/stream_test.go @@ -266,8 +266,8 @@ func TestHandleFrame_FirstFrameOnFreshStream(t *testing.T) { mw := testMuxWriter(t) for _, messageID := range []uint64{1, 2} { st := New(context.Background(), 1, mw) - // Close the packet buffer so KindMessage Put doesn't block. - st.pbuf.Close(io.EOF) + // Close the ring buffer so KindMessage Enqueue doesn't block. + st.recvQueue.Close(io.EOF) err := st.HandleFrame(drpcwire.Frame{ ID: drpcwire.ID{Stream: 1, Message: messageID}, Kind: drpcwire.KindMessage, Done: true, }) From e746a7e19451f21808a679bf9a7f74f7f3b7d150 Mon Sep 17 00:00:00 2001 From: Shubham Dhama Date: Mon, 13 Apr 2026 17:59:19 +0530 Subject: [PATCH 15/15] integration: add stress and multiplexing tests Add stress tests targeting concurrency bugs in the multiplexing stack (MuxWriter, packetQueue, manager, activeStreams) and correctness tests for stream isolation under cancel, error, and connection close. Stress tests (stress_test.go): - SustainedConcurrentStreams: 50 bidi streams x 100 echo messages, validates no cross-stream data corruption. - RapidOpenCloseCycles: 500 sequential create/use/destroy cycles on one connection. - CancelStorm: 30 streams with ~50% randomly cancelled mid-traffic, verifying cancel isolation. - ShutdownDuringActivity: conn.Close() with 20 active streams, deadlock detection. - MixedRPCTypes: 30 goroutines x 10 rounds of randomly chosen RPC types (unary, client-streaming, server-streaming, bidi) on one connection. - ConcurrentCancelCloseTransportClose: triple-race shutdown (cancel + conn.Close + transport.Close) with active streams. - ConcurrentUnary: 100 goroutines x 50 unary RPCs (sustained throughput). - BurstUnary: 1000 goroutines x 1 unary RPC (burst contention). Multiplexing tests (multiplex_test.go): - CancelIsolation, ErrorIsolation, ConnCloseWithActiveStreams, TransportCloseTerminatesAllStreams. Transport selection is randomized per test (pipe or TCP loopback), with a -transport flag override for deterministic reproduction. All tests use goleak for goroutine leak detection. Moves TestConcurrentStreams and TestConcurrent from simple_test.go into the stress suite as SustainedConcurrentStreams and BurstUnary. --- drpcwire/mux_writer_test.go | 6 +- internal/integration/alias.go | 1 + internal/integration/common_test.go | 48 ++ internal/integration/go.mod | 1 + internal/integration/go.sum | 10 +- internal/integration/multiplex_test.go | 263 ++++++++++ internal/integration/simple_test.go | 101 ---- internal/integration/stress_test.go | 679 +++++++++++++++++++++++++ 8 files changed, 1001 insertions(+), 108 deletions(-) create mode 100644 internal/integration/multiplex_test.go create mode 100644 internal/integration/stress_test.go diff --git a/drpcwire/mux_writer_test.go b/drpcwire/mux_writer_test.go index 61a60fe..481cdca 100644 --- a/drpcwire/mux_writer_test.go +++ b/drpcwire/mux_writer_test.go @@ -1,4 +1,4 @@ -// Copyright (C) 2021 Storj Labs, Inc. +// Copyright (C) 2026 Cockroach Labs. // See LICENSE for copying information. package drpcwire @@ -17,8 +17,8 @@ import ( // blockingWriter blocks in Write until unblock is closed, then returns err. type blockingWriter struct { unblock chan struct{} - err error // error to return once unblocked - wrote chan []byte // sends a copy of data on each Write entry + err error // error to return once unblocked + wrote chan []byte // sends a copy of data on each Write entry } func newBlockingWriter() *blockingWriter { diff --git a/internal/integration/alias.go b/internal/integration/alias.go index 8e0d803..ab7c56d 100644 --- a/internal/integration/alias.go +++ b/internal/integration/alias.go @@ -25,6 +25,7 @@ type ( DRPCService_Method2Stream = service.DRPCService_Method2Stream DRPCService_Method3Stream = service.DRPCService_Method3Stream DRPCService_Method4Stream = service.DRPCService_Method4Stream + DRPCService_Method4Client = service.DRPCService_Method4Client ) var ( diff --git a/internal/integration/common_test.go b/internal/integration/common_test.go index 1fd555e..f3d1a80 100644 --- a/internal/integration/common_test.go +++ b/internal/integration/common_test.go @@ -5,7 +5,9 @@ package integration import ( "context" + "flag" "io" + "math/rand" "net" "strconv" "sync" @@ -38,7 +40,30 @@ func data(n int64) []byte { func in(n int64) *In { return &In{In: n} } func out(n int64) *Out { return &Out{Out: n} } +var transport = flag.String("transport", "", "force transport for tests: pipe, tcp, or empty for random") + func createRawConnection(t testing.TB, server DRPCServiceServer, ctx *drpctest.Tracker) *drpcconn.Conn { + switch *transport { + case "pipe": + t.Log("transport: pipe") + return createPipeConnection(t, server, ctx) + case "tcp": + t.Log("transport: tcp") + return createTCPConnection(t, server, ctx) + case "": + if rand.Intn(2) == 0 { + t.Log("transport: pipe") + return createPipeConnection(t, server, ctx) + } + t.Log("transport: tcp") + return createTCPConnection(t, server, ctx) + default: + t.Fatalf("unknown -transport value: %q", *transport) + return nil + } +} + +func createPipeConnection(t testing.TB, server DRPCServiceServer, ctx *drpctest.Tracker) *drpcconn.Conn { c1, c2 := net.Pipe() mux := drpcmux.New() assert.NoError(t, DRPCRegisterService(mux, server)) @@ -56,6 +81,29 @@ func createConnection(t testing.TB, server DRPCServiceServer) (DRPCServiceClient } } +func createTCPConnection(t testing.TB, server DRPCServiceServer, ctx *drpctest.Tracker) *drpcconn.Conn { + ln, err := net.Listen("tcp", "127.0.0.1:0") + assert.NoError(t, err) + t.Cleanup(func() { _ = ln.Close() }) + + mux := drpcmux.New() + assert.NoError(t, DRPCRegisterService(mux, server)) + srv := drpcserver.New(mux) + + ctx.Run(func(ctx context.Context) { + conn, err := ln.Accept() + if err != nil { + return + } + _ = srv.ServeOne(ctx, conn) + }) + + conn, err := net.Dial("tcp", ln.Addr().String()) + assert.NoError(t, err) + + return drpcconn.NewWithOptions(conn, drpcconn.Options{}) +} + // // server impl // diff --git a/internal/integration/go.mod b/internal/integration/go.mod index df01970..6438de7 100644 --- a/internal/integration/go.mod +++ b/internal/integration/go.mod @@ -6,6 +6,7 @@ require ( github.com/gogo/protobuf v1.3.2 github.com/zeebo/assert v1.3.1 github.com/zeebo/errs v1.4.0 + go.uber.org/goleak v1.3.0 golang.org/x/exp v0.0.0-20260218203240-3dfff04db8fa google.golang.org/grpc v1.57.2 google.golang.org/protobuf v1.33.0 diff --git a/internal/integration/go.sum b/internal/integration/go.sum index d06e526..970a054 100644 --- a/internal/integration/go.sum +++ b/internal/integration/go.sum @@ -20,6 +20,8 @@ github.com/zeebo/assert v1.3.1 h1:vukIABvugfNMZMQO1ABsyQDJDTVQbn+LWSMy1ol1h6A= github.com/zeebo/assert v1.3.1/go.mod h1:Pq9JiuJQpG8JLJdtkwrJESF0Foym2/D9XMU5ciN/wJ0= github.com/zeebo/errs v1.4.0 h1:XNdoD/RRMKP7HD0UhJnIzUy74ISdGGxURlYG8HSWSfM= github.com/zeebo/errs v1.4.0/go.mod h1:sgbWHsvVuTPHcqJJGQ1WhI5KbWlHYz+2+2C/LSEtCw4= +go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto= +go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= @@ -31,8 +33,8 @@ golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20200226121028-0de0cce0169b/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= -golang.org/x/net v0.38.0 h1:vRMAPTMaeGqVhG5QyLJHqNDwecKTomGeqbnfZyKlBI8= -golang.org/x/net v0.38.0/go.mod h1:ivrbrMbzFq5J41QOQh0siUuly180yBYtLp+CKbEaFx8= +golang.org/x/net v0.9.0 h1:aWJ/m6xSmxWBx+V0XRHTlrYrPG56jKsLdTFmsSsCzOM= +golang.org/x/net v0.9.0/go.mod h1:d48xBJpPfHeWQsugry2m+kC02ZBRGRgulfHnEXEuWns= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= @@ -43,8 +45,8 @@ golang.org/x/sys v0.33.0 h1:q3i8TbbEz+JRD9ywIRlyRAQbM0qF7hu24q3teo2hbuw= golang.org/x/sys v0.33.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= -golang.org/x/text v0.23.0 h1:D71I7dUrlY+VX0gQShAThNGHFxZ13dGLBHQLVl1mJlY= -golang.org/x/text v0.23.0/go.mod h1:/BLNzu4aZCJ1+kcD0DNRotWKage4q2rGVAg4o22unh4= +golang.org/x/text v0.9.0 h1:2sjJmO8cDvYveuX97RDLsxlyUxLl+GHoLxBiRdHllBE= +golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.0.0-20200619180055-7c47624df98f/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE= diff --git a/internal/integration/multiplex_test.go b/internal/integration/multiplex_test.go new file mode 100644 index 0000000..ce69e45 --- /dev/null +++ b/internal/integration/multiplex_test.go @@ -0,0 +1,263 @@ +// Copyright (C) 2026 Cockroach Labs. +// See LICENSE for copying information. + +package integration + +import ( + "context" + "errors" + "io" + "testing" + "time" + + "github.com/zeebo/assert" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" + + "storj.io/drpc/drpctest" +) + +// TestMultiplex_CancelIsolation verifies that canceling one stream's context +// does not affect other concurrent streams on the same connection. +func TestMultiplex_CancelIsolation(t *testing.T) { + ctx := drpctest.NewTracker(t) + defer ctx.Close() + + started := make(chan struct{}, 3) + cli, close := createConnection(t, impl{ + Method4Fn: func(stream DRPCService_Method4Stream) error { + started <- struct{}{} + for { + msg, err := stream.Recv() + if err != nil { + return nil + } + if err := stream.Send(&Out{Out: msg.In}); err != nil { + return err + } + } + }, + }) + defer close() + + // Open 3 bidi streams with independent contexts. + ctx1, cancel1 := context.WithCancel(ctx) + defer cancel1() + s1, err := cli.Method4(ctx1) + assert.NoError(t, err) + + ctx2, cancel2 := context.WithCancel(ctx) + defer cancel2() + s2, err := cli.Method4(ctx2) + assert.NoError(t, err) + + ctx3, cancel3 := context.WithCancel(ctx) + defer cancel3() + s3, err := cli.Method4(ctx3) + assert.NoError(t, err) + + // Wait for all server handlers to start. + <-started + <-started + <-started + + // Cancel stream 2. + cancel2() + + // Verify stream 2 is dead. This blocks until the cancel propagates. + _, err = s2.Recv() + assert.Error(t, err) + st, ok := status.FromError(err) + assert.That(t, ok) + assert.Equal(t, st.Code(), codes.Canceled) + + // Streams 1 and 3 should still work. + assert.NoError(t, s1.Send(in(10))) + out, err := s1.Recv() + assert.NoError(t, err) + assert.Equal(t, out.Out, int64(10)) + + assert.NoError(t, s3.Send(in(30))) + out, err = s3.Recv() + assert.NoError(t, err) + assert.Equal(t, out.Out, int64(30)) + + // Clean up remaining streams. + assert.NoError(t, s1.CloseSend()) + assert.NoError(t, s3.CloseSend()) +} + +// TestMultiplex_ErrorIsolation verifies that a server handler returning an +// error on one stream does not affect other concurrent streams. +func TestMultiplex_ErrorIsolation(t *testing.T) { + ctx := drpctest.NewTracker(t) + defer ctx.Close() + + // Handler reads first message: In == -1 triggers an error, otherwise echo. + cli, close := createConnection(t, impl{ + Method4Fn: func(stream DRPCService_Method4Stream) error { + msg, err := stream.Recv() + if err != nil { + return nil + } + if msg.In == -1 { + return status.Error(codes.InvalidArgument, "bad input") + } + if err := stream.Send(&Out{Out: msg.In}); err != nil { + return err + } + for { + msg, err := stream.Recv() + if err != nil { + return nil + } + if err := stream.Send(&Out{Out: msg.In}); err != nil { + return err + } + } + }, + }) + defer close() + + s1, err := cli.Method4(ctx) + assert.NoError(t, err) + + s2, err := cli.Method4(ctx) + assert.NoError(t, err) + + // Trigger error on stream 1. + assert.NoError(t, s1.Send(in(-1))) + + // Send normal message on stream 2. + assert.NoError(t, s2.Send(in(42))) + + // Stream 1 should receive the server error. + _, err = s1.Recv() + assert.Error(t, err) + st, ok := status.FromError(err) + assert.That(t, ok) + assert.Equal(t, st.Code(), codes.InvalidArgument) + + // Stream 2 should be unaffected. + out, err := s2.Recv() + assert.NoError(t, err) + assert.Equal(t, out.Out, int64(42)) + + // Stream 2 keeps working after stream 1 is dead. + assert.NoError(t, s2.Send(in(100))) + out, err = s2.Recv() + assert.NoError(t, err) + assert.Equal(t, out.Out, int64(100)) + + // Clean up. + assert.NoError(t, s2.CloseSend()) + _, err = s2.Recv() + assert.That(t, errors.Is(err, io.EOF)) +} + +// TestMultiplex_ConnCloseWithActiveStreams verifies that closing a connection +// with multiple active streams terminates all of them and does not deadlock. +func TestMultiplex_ConnCloseWithActiveStreams(t *testing.T) { + ctx := drpctest.NewTracker(t) + defer ctx.Close() + + started := make(chan struct{}, 3) + conn := createRawConnection(t, impl{ + Method4Fn: func(stream DRPCService_Method4Stream) error { + started <- struct{}{} + <-stream.Context().Done() + return stream.Context().Err() + }, + }, ctx) + cli := NewDRPCServiceClient(conn) + + // Open 3 streams whose handlers block until canceled. + const N = 3 + streams := make([]DRPCService_Method4Client, N) + for i := 0; i < N; i++ { + s, err := cli.Method4(ctx) + assert.NoError(t, err) + streams[i] = s + } + + // Wait for all handlers to be running. + for i := 0; i < N; i++ { + <-started + } + + // conn.Close triggers manager.Close which must not deadlock. + done := make(chan error, 1) + go func() { done <- conn.Close() }() + + timer := time.NewTimer(5 * time.Second) + defer timer.Stop() + select { + case <-done: + case <-timer.C: + t.Fatal("conn.Close() deadlocked with active streams") + } + + // All streams should be terminated. + for i, s := range streams { + _, err := s.Recv() + assert.Error(t, err) + t.Logf("stream %d: %v", i, err) + } +} + +// TestMultiplex_TransportCloseTerminatesAllStreams verifies that an external +// transport failure terminates all active streams on the connection. +func TestMultiplex_TransportCloseTerminatesAllStreams(t *testing.T) { + ctx := drpctest.NewTracker(t) + defer ctx.Close() + + started := make(chan struct{}, 3) + conn := createRawConnection(t, impl{ + Method4Fn: func(stream DRPCService_Method4Stream) error { + started <- struct{}{} + for { + msg, err := stream.Recv() + if err != nil { + return nil + } + if err := stream.Send(&Out{Out: msg.In}); err != nil { + return err + } + } + }, + }, ctx) + cli := NewDRPCServiceClient(conn) + + // Open 3 streams. + const N = 3 + streams := make([]DRPCService_Method4Client, N) + for i := 0; i < N; i++ { + s, err := cli.Method4(ctx) + assert.NoError(t, err) + streams[i] = s + } + + // Wait for all handlers. + for i := 0; i < N; i++ { + <-started + } + + // Verify streams work before the failure. + assert.NoError(t, streams[0].Send(in(1))) + out, err := streams[0].Recv() + assert.NoError(t, err) + assert.Equal(t, out.Out, int64(1)) + + // Simulate transport failure. + assert.NoError(t, conn.Transport().Close()) + + // All streams should receive errors. + for i, s := range streams { + _, err := s.Recv() + assert.Error(t, err) + t.Logf("stream %d: %v", i, err) + } + + // Connection should close cleanly after transport failure. + _ = conn.Close() +} diff --git a/internal/integration/simple_test.go b/internal/integration/simple_test.go index 8dbf47a..3347397 100644 --- a/internal/integration/simple_test.go +++ b/internal/integration/simple_test.go @@ -6,7 +6,6 @@ package integration import ( "context" "errors" - "fmt" "io" "net" "testing" @@ -141,106 +140,6 @@ func TestMultiplexedStreams(t *testing.T) { assert.That(t, errors.Is(err, io.EOF)) } -func TestConcurrentStreams(t *testing.T) { - ctx := drpctest.NewTracker(t) - defer ctx.Close() - - echoServer := impl{ - Method1Fn: standardImpl.Method1Fn, - Method2Fn: standardImpl.Method2Fn, - Method3Fn: standardImpl.Method3Fn, - Method4Fn: func(stream DRPCService_Method4Stream) error { - for { - msg, err := stream.Recv() - if err != nil { - return nil - } - if err := stream.Send(&Out{Out: msg.In}); err != nil { - return err - } - } - }, - } - - cli, close := createConnection(t, echoServer) - defer close() - - const numStreams = 10 - const numMessages = 20 - - errs := make(chan error, numStreams) - for i := 0; i < numStreams; i++ { - i := i - ctx.Run(func(ctx context.Context) { - select { - case <-ctx.Done(): - case errs <- func() error { - stream, err := cli.Method4(ctx) - if err != nil { - return fmt.Errorf("stream %d: open: %w", i, err) - } - for j := 0; j < numMessages; j++ { - val := int64(i*1000 + j) - if err := stream.Send(&In{In: val}); err != nil { - return fmt.Errorf("stream %d: send %d: %w", i, j, err) - } - out, err := stream.Recv() - if err != nil { - return fmt.Errorf("stream %d: recv %d: %w", i, j, err) - } - if out.Out != val { - return fmt.Errorf("stream %d: msg %d: got %d, want %d", i, j, out.Out, val) - } - } - if err := stream.CloseSend(); err != nil { - return fmt.Errorf("stream %d: close send: %w", i, err) - } - _, err = stream.Recv() - if !errors.Is(err, io.EOF) { - return fmt.Errorf("stream %d: final recv: got %v, want EOF", i, err) - } - return nil - }(): - } - }) - } - - for i := 0; i < numStreams; i++ { - assert.NoError(t, <-errs) - } -} - -func TestConcurrent(t *testing.T) { - ctx := drpctest.NewTracker(t) - defer ctx.Close() - - cli, close := createConnection(t, standardImpl) - defer close() - - const N = 1000 - errs := make(chan error) - for i := 0; i < N; i++ { - ctx.Run(func(ctx context.Context) { - select { - case <-ctx.Done(): - case errs <- func() error { - out, err := cli.Method1(ctx, &In{In: 1}) - if err != nil { - return err - } else if out.Out != 1 { - return fmt.Errorf("wrong result %d", out.Out) - } else { - return nil - } - }(): - } - }) - } - for i := 0; i < N; i++ { - assert.NoError(t, <-errs) - } -} - func TestServerStats(t *testing.T) { ctx := drpctest.NewTracker(t) defer ctx.Close() diff --git a/internal/integration/stress_test.go b/internal/integration/stress_test.go new file mode 100644 index 0000000..7e67969 --- /dev/null +++ b/internal/integration/stress_test.go @@ -0,0 +1,679 @@ +// Copyright (C) 2026 Cockroach Labs. +// See LICENSE for copying information. + +package integration + +import ( + "context" + "errors" + "fmt" + "io" + "math/rand" + "sync" + "testing" + "time" + + "github.com/zeebo/assert" + "go.uber.org/goleak" + + "storj.io/drpc/drpctest" +) + +// TestStress_SustainedConcurrentStreams opens 50 bidi streams on one +// connection, each exchanging 100 echo messages concurrently. This saturates +// the manageReader dispatch path (one reader fanning out to 50 packetQueues) +// and the MuxWriter batching path (50 goroutines calling WriteFrame). Each +// message encodes the stream's identity so we can detect cross-stream data +// corruption, which would indicate a routing bug in the manager or a buffer +// reuse bug in the packetQueue. +func TestStress_SustainedConcurrentStreams(t *testing.T) { + defer goleak.VerifyNone(t) + + ctx := drpctest.NewTracker(t) + defer ctx.Close() + + conn := createRawConnection(t, impl{ + Method4Fn: func(stream DRPCService_Method4Stream) error { + for { + msg, err := stream.Recv() + if err != nil { + return nil + } + if err := stream.Send(&Out{Out: msg.In}); err != nil { + return err + } + } + }, + }, ctx) + defer func() { _ = conn.Close() }() + cli := NewDRPCServiceClient(conn) + + const N = 50 // concurrent streams + const M = 100 // messages per stream + + errs := make(chan error, N) + for i := 0; i < N; i++ { + i := i + ctx.Run(func(ctx context.Context) { + select { + case <-ctx.Done(): + case errs <- func() error { + stream, err := cli.Method4(ctx) + if err != nil { + return fmt.Errorf("stream %d: open: %w", i, err) + } + for j := 0; j < M; j++ { + val := int64(i*1000 + j) // encode stream identity + if err := stream.Send(&In{In: val}); err != nil { + return fmt.Errorf("stream %d: send %d: %w", i, j, err) + } + out, err := stream.Recv() + if err != nil { + return fmt.Errorf("stream %d: recv %d: %w", i, j, err) + } + if out.Out != val { + return fmt.Errorf("stream %d: msg %d: got %d, want %d (cross-contamination?)", i, j, out.Out, val) + } + } + if err := stream.CloseSend(); err != nil { + return fmt.Errorf("stream %d: close send: %w", i, err) + } + _, err = stream.Recv() + if !errors.Is(err, io.EOF) { + return fmt.Errorf("stream %d: final recv: got %v, want EOF", i, err) + } + return nil + }(): + } + }) + } + + for i := 0; i < N; i++ { + assert.NoError(t, <-errs) + } +} + +// TestStress_RapidOpenCloseCycles opens and closes a bidi stream 500 times +// sequentially on one connection. Each cycle creates a stream, exchanges one +// message, then tears it down. This tests that stream ID allocation, the +// activeStreams map cleanup, and the invoke handshake (pdone channel) work +// correctly across many rapid create/destroy cycles without leaking resources. +func TestStress_RapidOpenCloseCycles(t *testing.T) { + defer goleak.VerifyNone(t) + + ctx := drpctest.NewTracker(t) + defer ctx.Close() + + conn := createRawConnection(t, impl{ + Method4Fn: func(stream DRPCService_Method4Stream) error { + for { + msg, err := stream.Recv() + if err != nil { + return nil + } + if err := stream.Send(&Out{Out: msg.In}); err != nil { + return err + } + } + }, + }, ctx) + defer func() { _ = conn.Close() }() + cli := NewDRPCServiceClient(conn) + + const K = 500 + + for i := 0; i < K; i++ { + stream, err := cli.Method4(ctx) + assert.NoError(t, err) + + val := int64(i) + assert.NoError(t, stream.Send(&In{In: val})) + out, err := stream.Recv() + assert.NoError(t, err) + assert.Equal(t, out.Out, val) + + assert.NoError(t, stream.CloseSend()) + _, err = stream.Recv() + assert.That(t, errors.Is(err, io.EOF)) + } +} + +// TestStress_CancelStorm opens 30 streams that all exchange messages +// concurrently, then cancels ~50% of them (chosen randomly) with jitter so +// cancellations land at different points in the send/recv cycle. This races +// context cancellation against in-flight Send and Recv calls, exercising the +// stream's cancel propagation and cleanup. Surviving (non-cancelled) streams +// must complete without error, verifying cancel isolation. +func TestStress_CancelStorm(t *testing.T) { + defer goleak.VerifyNone(t) + + seed := time.Now().UnixNano() + t.Logf("random seed: %d", seed) + + ctx := drpctest.NewTracker(t) + defer ctx.Close() + + conn := createRawConnection(t, impl{ + Method4Fn: func(stream DRPCService_Method4Stream) error { + for { + msg, err := stream.Recv() + if err != nil { + return nil + } + if err := stream.Send(&Out{Out: msg.In}); err != nil { + return err + } + } + }, + }, ctx) + defer func() { _ = conn.Close() }() + cli := NewDRPCServiceClient(conn) + + const N = 30 + const messagesPerStream = 20 + + type entry struct { + stream DRPCService_Method4Client + cancel context.CancelFunc + } + + entries := make([]entry, N) + for i := 0; i < N; i++ { + sctx, cancel := context.WithCancel(ctx) + stream, err := cli.Method4(sctx) + assert.NoError(t, err) + entries[i] = entry{stream: stream, cancel: cancel} + } + + // Decide which streams to cancel (~50%). + rng := rand.New(rand.NewSource(seed)) + cancelled := make([]bool, N) + for i := range entries { + cancelled[i] = rng.Intn(2) == 0 + } + + // All streams exchange messages concurrently. Cancelled streams + // will hit errors once the cancel goroutine fires their context; + // surviving streams must complete without error. + var wg sync.WaitGroup + wg.Add(N) + for i := range entries { + i := i + ctx.Run(func(_ context.Context) { + defer wg.Done() + defer entries[i].cancel() + for j := 0; j < messagesPerStream; j++ { + val := int64(i*1000 + j) + if err := entries[i].stream.Send(&In{In: val}); err != nil { + if !cancelled[i] { + t.Errorf("surviving stream %d: send %d: %v", i, j, err) + } + return + } + out, err := entries[i].stream.Recv() + if err != nil { + if !cancelled[i] { + t.Errorf("surviving stream %d: recv %d: %v", i, j, err) + } + return + } + if out.Out != val { + t.Errorf("stream %d: msg %d: got %d, want %d", i, j, out.Out, val) + } + } + if !cancelled[i] { + _ = entries[i].stream.CloseSend() + } + }) + } + + // Fire cancels concurrently with traffic, with jitter. + ctx.Run(func(_ context.Context) { + time.Sleep(time.Millisecond) // let streams start sending + for i := range entries { + if cancelled[i] { + time.Sleep(time.Duration(rng.Intn(500)) * time.Microsecond) + entries[i].cancel() + } + } + }) + + wg.Wait() +} + +// TestStress_ShutdownDuringActivity opens 20 streams all actively sending and +// receiving, then calls conn.Close(). This exercises the manager's terminate() +// path with traffic in flight: the MuxWriter must stop, manageReader must +// unblock, and all stream goroutines must exit. The test fails if conn.Close() +// doesn't return within 5 seconds (deadlock). +func TestStress_ShutdownDuringActivity(t *testing.T) { + defer goleak.VerifyNone(t) + + ctx := drpctest.NewTracker(t) + defer ctx.Close() + + started := make(chan struct{}, 20) + conn := createRawConnection(t, impl{ + Method4Fn: func(stream DRPCService_Method4Stream) error { + started <- struct{}{} + for { + msg, err := stream.Recv() + if err != nil { + return nil + } + if err := stream.Send(&Out{Out: msg.In}); err != nil { + return err + } + } + }, + }, ctx) + cli := NewDRPCServiceClient(conn) + + const N = 20 + + streams := make([]DRPCService_Method4Client, N) + for i := 0; i < N; i++ { + s, err := cli.Method4(ctx) + assert.NoError(t, err) + streams[i] = s + } + + // Wait for all server handlers to start. + for i := 0; i < N; i++ { + <-started + } + + // Start sending on all streams continuously. + for i, s := range streams { + i, s := i, s + ctx.Run(func(_ context.Context) { + for { + if err := s.Send(&In{In: int64(i)}); err != nil { + return + } + if _, err := s.Recv(); err != nil { + return + } + } + }) + } + + // Let traffic flow briefly. + time.Sleep(10 * time.Millisecond) + + // Close the connection. Must return within timeout. + done := make(chan error, 1) + go func() { done <- conn.Close() }() + + timer := time.NewTimer(5 * time.Second) + defer timer.Stop() + select { + case <-done: + case <-timer.C: + t.Fatal("conn.Close() deadlocked with active streams") + } +} + +// TestStress_MixedRPCTypes runs 30 goroutines each executing 10 rounds of a +// randomly chosen RPC type (unary, client-streaming, server-streaming, bidi) +// concurrently on one connection. Different RPC types have different frame +// sequences and stream lifecycles: unary is short-lived (invoke, response, +// close), client-streaming sends multiple frames before expecting a response, +// server-streaming reads until EOF, and bidi is long-lived bidirectional. +// Mixing them exercises the manager's ability to correctly interleave +// heterogeneous stream types on a shared transport. +func TestStress_MixedRPCTypes(t *testing.T) { + defer goleak.VerifyNone(t) + + seed := time.Now().UnixNano() + t.Logf("random seed: %d", seed) + + ctx := drpctest.NewTracker(t) + defer ctx.Close() + + conn := createRawConnection(t, impl{ + // Method1: unary echo. + Method1Fn: func(ctx context.Context, in *In) (*Out, error) { + return &Out{Out: in.In}, nil + }, + // Method2: client-streaming — sum inputs. + Method2Fn: func(stream DRPCService_Method2Stream) error { + var total int64 + for { + in, err := stream.Recv() + if errors.Is(err, io.EOF) { + break + } + if err != nil { + return err + } + total += in.In + } + return stream.SendAndClose(&Out{Out: total}) + }, + // Method3: server-streaming — send N copies. + Method3Fn: func(in *In, stream DRPCService_Method3Stream) error { + for i := 0; i < int(in.In); i++ { + if err := stream.Send(&Out{Out: in.In}); err != nil { + return err + } + } + return nil + }, + // Method4: bidi echo. + Method4Fn: func(stream DRPCService_Method4Stream) error { + for { + msg, err := stream.Recv() + if err != nil { + return nil + } + if err := stream.Send(&Out{Out: msg.In}); err != nil { + return err + } + } + }, + }, ctx) + defer func() { _ = conn.Close() }() + cli := NewDRPCServiceClient(conn) + + const N = 30 + const rounds = 10 + + errs := make(chan error, N) + for i := 0; i < N; i++ { + i := i + ctx.Run(func(ctx context.Context) { + select { + case <-ctx.Done(): + case errs <- func() error { + rng := rand.New(rand.NewSource(seed + int64(i))) + for r := 0; r < rounds; r++ { + switch rng.Intn(4) { + case 0: // Unary + out, err := cli.Method1(ctx, &In{In: int64(i*100 + r)}) + if err != nil { + return fmt.Errorf("goroutine %d round %d: unary: %w", i, r, err) + } + if out.Out != int64(i*100+r) { + return fmt.Errorf("goroutine %d round %d: unary: got %d want %d", i, r, out.Out, i*100+r) + } + + case 1: // Client-streaming + stream, err := cli.Method2(ctx) + if err != nil { + return fmt.Errorf("goroutine %d round %d: client-stream open: %w", i, r, err) + } + var want int64 + for k := 0; k < 5; k++ { + v := int64(k + 1) + want += v + if err := stream.Send(&In{In: v}); err != nil { + return fmt.Errorf("goroutine %d round %d: client-stream send: %w", i, r, err) + } + } + out, err := stream.CloseAndRecv() + if err != nil { + return fmt.Errorf("goroutine %d round %d: client-stream close: %w", i, r, err) + } + if out.Out != want { + return fmt.Errorf("goroutine %d round %d: client-stream: got %d want %d", i, r, out.Out, want) + } + + case 2: // Server-streaming + count := int64(3) + stream, err := cli.Method3(ctx, &In{In: count}) + if err != nil { + return fmt.Errorf("goroutine %d round %d: server-stream open: %w", i, r, err) + } + var got int + for { + out, err := stream.Recv() + if errors.Is(err, io.EOF) { + break + } + if err != nil { + return fmt.Errorf("goroutine %d round %d: server-stream recv: %w", i, r, err) + } + if out.Out != count { + return fmt.Errorf("goroutine %d round %d: server-stream: got %d want %d", i, r, out.Out, count) + } + got++ + } + if got != int(count) { + return fmt.Errorf("goroutine %d round %d: server-stream: got %d msgs want %d", i, r, got, count) + } + + case 3: // Bidi + stream, err := cli.Method4(ctx) + if err != nil { + return fmt.Errorf("goroutine %d round %d: bidi open: %w", i, r, err) + } + for k := 0; k < 5; k++ { + val := int64(i*10000 + r*100 + k) + if err := stream.Send(&In{In: val}); err != nil { + return fmt.Errorf("goroutine %d round %d: bidi send: %w", i, r, err) + } + out, err := stream.Recv() + if err != nil { + return fmt.Errorf("goroutine %d round %d: bidi recv: %w", i, r, err) + } + if out.Out != val { + return fmt.Errorf("goroutine %d round %d: bidi: got %d want %d", i, r, out.Out, val) + } + } + if err := stream.CloseSend(); err != nil { + return fmt.Errorf("goroutine %d round %d: bidi close: %w", i, r, err) + } + _, err = stream.Recv() + if !errors.Is(err, io.EOF) { + return fmt.Errorf("goroutine %d round %d: bidi final: got %v want EOF", i, r, err) + } + } + } + return nil + }(): + } + }) + } + + for i := 0; i < N; i++ { + assert.NoError(t, <-errs) + } +} + +// TestStress_ConcurrentCancelCloseTransportClose fires context cancellation, +// conn.Close(), and transport.Close() nearly simultaneously while 10 streams +// are actively exchanging messages. This is the worst-case shutdown scenario: +// three independent shutdown paths (cancel propagation, manager.Close, +// transport EOF) all racing to call terminate(). The manager's terminate() is +// idempotent via sigs.term first-wins, so only one should win; the rest must +// be no-ops. Deadlock within 5 seconds is a failure. +func TestStress_ConcurrentCancelCloseTransportClose(t *testing.T) { + defer goleak.VerifyNone(t) + + ctx := drpctest.NewTracker(t) + defer ctx.Close() + + started := make(chan struct{}, 10) + conn := createRawConnection(t, impl{ + Method4Fn: func(stream DRPCService_Method4Stream) error { + started <- struct{}{} + for { + msg, err := stream.Recv() + if err != nil { + return nil + } + if err := stream.Send(&Out{Out: msg.In}); err != nil { + return err + } + } + }, + }, ctx) + cli := NewDRPCServiceClient(conn) + + const N = 10 + + cancels := make([]context.CancelFunc, N) + streams := make([]DRPCService_Method4Client, N) + for i := 0; i < N; i++ { + sctx, cancel := context.WithCancel(ctx) + s, err := cli.Method4(sctx) + assert.NoError(t, err) + cancels[i] = cancel + streams[i] = s + } + + for i := 0; i < N; i++ { + <-started + } + + // Start sending on all streams. + for i, s := range streams { + i, s := i, s + ctx.Run(func(_ context.Context) { + for { + if err := s.Send(&In{In: int64(i)}); err != nil { + return + } + if _, err := s.Recv(); err != nil { + return + } + } + }) + } + + time.Sleep(10 * time.Millisecond) + + // Fire all three shutdown mechanisms nearly simultaneously. + var wg sync.WaitGroup + wg.Add(3) + + go func() { + defer wg.Done() + for _, c := range cancels { + c() + } + }() + + go func() { + defer wg.Done() + _ = conn.Close() + }() + + go func() { + defer wg.Done() + _ = conn.Transport().Close() + }() + + // Must complete within timeout — no deadlock. + done := make(chan struct{}) + go func() { + wg.Wait() + close(done) + }() + + timer := time.NewTimer(5 * time.Second) + defer timer.Stop() + select { + case <-done: + case <-timer.C: + t.Fatal("triple-race shutdown deadlocked") + } +} + +// TestStress_ConcurrentUnary runs 100 goroutines each making 50 unary RPCs +// on one connection. Each unary call creates a stream, does the invoke +// handshake, sends request, receives response, and tears down, so this +// produces 5000 rapid stream lifecycles. Complements TestStress_BurstUnary +// below, which tests instantaneous burst contention rather than sustained +// throughput. +func TestStress_ConcurrentUnary(t *testing.T) { + defer goleak.VerifyNone(t) + + ctx := drpctest.NewTracker(t) + defer ctx.Close() + + conn := createRawConnection(t, impl{ + Method1Fn: func(ctx context.Context, in *In) (*Out, error) { + return &Out{Out: in.In}, nil + }, + }, ctx) + defer func() { _ = conn.Close() }() + cli := NewDRPCServiceClient(conn) + + const N = 100 + const M = 50 + + errs := make(chan error, N) + for i := 0; i < N; i++ { + i := i + ctx.Run(func(ctx context.Context) { + select { + case <-ctx.Done(): + case errs <- func() error { + for j := 0; j < M; j++ { + val := int64(i*1000 + j) + out, err := cli.Method1(ctx, &In{In: val}) + if err != nil { + return fmt.Errorf("goroutine %d call %d: %w", i, j, err) + } + if out.Out != val { + return fmt.Errorf("goroutine %d call %d: got %d want %d", i, j, out.Out, val) + } + } + return nil + }(): + } + }) + } + + for i := 0; i < N; i++ { + assert.NoError(t, <-errs) + } +} + +// TestStress_BurstUnary fires 1000 goroutines each making a single unary RPC +// simultaneously. All 1000 invokes hit the pdone channel at once, testing +// burst contention on the invoke handshake and the MuxWriter's ability to +// batch a thundering herd of WriteFrame calls. +func TestStress_BurstUnary(t *testing.T) { + defer goleak.VerifyNone(t) + + ctx := drpctest.NewTracker(t) + defer ctx.Close() + + conn := createRawConnection(t, impl{ + Method1Fn: func(ctx context.Context, in *In) (*Out, error) { + return &Out{Out: in.In}, nil + }, + }, ctx) + defer func() { _ = conn.Close() }() + cli := NewDRPCServiceClient(conn) + + const N = 1000 + + errs := make(chan error, N) + for i := 0; i < N; i++ { + i := i + ctx.Run(func(ctx context.Context) { + select { + case <-ctx.Done(): + case errs <- func() error { + val := int64(i) + out, err := cli.Method1(ctx, &In{In: val}) + if err != nil { + return fmt.Errorf("goroutine %d: %w", i, err) + } + if out.Out != val { + return fmt.Errorf("goroutine %d: got %d want %d", i, out.Out, val) + } + return nil + }(): + } + }) + } + + for i := 0; i < N; i++ { + assert.NoError(t, <-errs) + } +}