diff --git a/drpc.go b/drpc.go index f6f92084..ed037037 100644 --- a/drpc.go +++ b/drpc.go @@ -76,14 +76,6 @@ type Stream interface { // received on it. Context() context.Context - // Kind returns the type of the stream ("unknown", "cli", or "srv"). Client - // and server streams must be treated differently for error handling and - // logging purposes. - // - // Client streams return Unavailable errors when the remote closes the - // connection, while server streams return Canceled errors. - Kind() StreamKind - // MsgSend sends the Message to the remote. MsgSend(msg Message, enc Encoding) error diff --git a/drpccache/README.md b/drpccache/README.md deleted file mode 100644 index 5739abcc..00000000 --- a/drpccache/README.md +++ /dev/null @@ -1,74 +0,0 @@ -# package drpccache - -`import "storj.io/drpc/drpccache"` - -Package drpccache implements per stream cache for drpc. - -## Usage - -#### func WithContext - -```go -func WithContext(parent context.Context, cache *Cache) context.Context -``` -WithContext returns a context with the value cache associated with the context. - -#### type Cache - -```go -type Cache struct { -} -``` - -Cache is a per stream cache. - -#### func FromContext - -```go -func FromContext(ctx context.Context) *Cache -``` -FromContext returns a cache from a context. - -Example usage: - - cache := drpccache.FromContext(stream.Context()) - if cache != nil { - value := cache.LoadOrCreate("initialized", func() (interface{}) { - return 42 - }) - } - -#### func New - -```go -func New() *Cache -``` -New returns a new cache. - -#### func (*Cache) Clear - -```go -func (cache *Cache) Clear() -``` -Clear clears the cache. - -#### func (*Cache) Load - -```go -func (cache *Cache) Load(key interface{}) interface{} -``` -Load returns the value with the given key. - -#### func (*Cache) LoadOrCreate - -```go -func (cache *Cache) LoadOrCreate(key interface{}, fn func() interface{}) interface{} -``` -LoadOrCreate returns the value with the given key. - -#### func (*Cache) Store - -```go -func (cache *Cache) Store(key, value interface{}) -``` -Store sets the value at a key. diff --git a/drpccache/cache.go b/drpccache/cache.go deleted file mode 100644 index 107b7775..00000000 --- a/drpccache/cache.go +++ /dev/null @@ -1,92 +0,0 @@ -// Copyright (C) 2020 Storj Labs, Inc. -// See LICENSE for copying information. - -package drpccache - -import ( - "context" - "sync" -) - -type cacheKey struct{} - -// Cache is a per stream cache. -type Cache struct { - mu sync.Mutex - values map[interface{}]interface{} -} - -// New returns a new cache. -func New() *Cache { return &Cache{} } - -// FromContext returns a cache from a context. -// -// Example usage: -// -// cache := drpccache.FromContext(stream.Context()) -// if cache != nil { -// value := cache.LoadOrCreate("initialized", func() (interface{}) { -// return 42 -// }) -// } -func FromContext(ctx context.Context) *Cache { - cache, _ := ctx.Value(cacheKey{}).(*Cache) - return cache -} - -// WithContext returns a context with the value cache associated with the context. -func WithContext(parent context.Context, cache *Cache) context.Context { - return context.WithValue(parent, cacheKey{}, cache) -} - -// init ensures that the values map exist. -func (cache *Cache) init() { - if cache.values == nil { - cache.values = map[interface{}]interface{}{} - } -} - -// Clear clears the cache. -func (cache *Cache) Clear() { - cache.mu.Lock() - defer cache.mu.Unlock() - - cache.values = nil -} - -// Store sets the value at a key. -func (cache *Cache) Store(key, value interface{}) { - cache.mu.Lock() - defer cache.mu.Unlock() - - cache.init() - cache.values[key] = value -} - -// Load returns the value with the given key. -func (cache *Cache) Load(key interface{}) interface{} { - cache.mu.Lock() - defer cache.mu.Unlock() - - if cache.values == nil { - return nil - } - - return cache.values[key] -} - -// LoadOrCreate returns the value with the given key. -func (cache *Cache) LoadOrCreate(key interface{}, fn func() interface{}) interface{} { - cache.mu.Lock() - defer cache.mu.Unlock() - - cache.init() - - value, ok := cache.values[key] - if !ok { - value = fn() - cache.values[key] = value - } - - return value -} diff --git a/drpccache/cache_test.go b/drpccache/cache_test.go deleted file mode 100644 index bd98dbf9..00000000 --- a/drpccache/cache_test.go +++ /dev/null @@ -1,71 +0,0 @@ -// Copyright (C) 2021 Storj Labs, Inc. -// See LICENSE for copying information. - -package drpccache - -import ( - "context" - "testing" - - "github.com/zeebo/assert" -) - -func TestWithContext(t *testing.T) { - ctx := context.Background() - assert.Nil(t, FromContext(ctx)) - - cache := New() - ctx = WithContext(ctx, cache) - assert.Equal(t, cache, FromContext(ctx)) -} - -func TestLoad(t *testing.T) { - cache := New() - - assert.Nil(t, cache.Load("key1")) - assert.Nil(t, cache.Load("key2")) - - cache.Store("key1", "val1") - - assert.Equal(t, cache.Load("key1"), "val1") - assert.Nil(t, cache.Load("key2")) - - cache.Store("key2", "val2") - - assert.Equal(t, cache.Load("key1"), "val1") - assert.Equal(t, cache.Load("key2"), "val2") -} - -func TestClear(t *testing.T) { - cache := New() - - cache.Store("key1", "val1") - cache.Store("key2", "val2") - - assert.Equal(t, cache.Load("key1"), "val1") - assert.Equal(t, cache.Load("key2"), "val2") - - cache.Clear() - - assert.Nil(t, cache.Load("key1")) - assert.Nil(t, cache.Load("key2")) -} - -func TestLoadOrCreate(t *testing.T) { - f := func(val interface{}) func() interface{} { - return func() interface{} { return val } - } - - cache := New() - - assert.Nil(t, cache.Load("key1")) - assert.Nil(t, cache.Load("key2")) - - assert.Equal(t, cache.LoadOrCreate("key1", f("key1")), "key1") - assert.Equal(t, cache.LoadOrCreate("key1", f("key2")), "key1") - - assert.Equal(t, cache.Load("key1"), "key1") - assert.Nil(t, cache.Load("key2")) - - cache.LoadOrCreate("key1", func() interface{} { panic("called") }) -} diff --git a/drpccache/doc.go b/drpccache/doc.go deleted file mode 100644 index 6926e5d4..00000000 --- a/drpccache/doc.go +++ /dev/null @@ -1,5 +0,0 @@ -// Copyright (C) 2019 Storj Labs, Inc. -// See LICENSE for copying information. - -// Package drpccache implements per stream cache for drpc. -package drpccache diff --git a/drpcclient/dialoptions.go b/drpcclient/dialoptions.go index 473f018f..91b406d5 100644 --- a/drpcclient/dialoptions.go +++ b/drpcclient/dialoptions.go @@ -129,7 +129,6 @@ func DialContext(ctx context.Context, address string, opts ...DialOption) (conn Stream: drpcstream.Options{ MaximumBufferSize: 0, // unlimited }, - SoftCancel: false, }, }), nil } diff --git a/drpcconn/conn.go b/drpcconn/conn.go index 636f3468..00b53038 100644 --- a/drpcconn/conn.go +++ b/drpcconn/conn.go @@ -5,7 +5,6 @@ package drpcconn import ( "context" - "sync" "github.com/zeebo/errs" grpcmetadata "google.golang.org/grpc/metadata" @@ -13,30 +12,20 @@ import ( "storj.io/drpc/drpcenc" "storj.io/drpc/drpcmanager" "storj.io/drpc/drpcmetadata" - "storj.io/drpc/drpcstats" "storj.io/drpc/drpcstream" "storj.io/drpc/drpcwire" - "storj.io/drpc/internal/drpcopts" ) // Options controls configuration settings for a conn. type Options struct { // Manager controls the options we pass to the manager of this conn. Manager drpcmanager.Options - - // CollectStats controls whether the server should collect stats on the - // rpcs it creates. - CollectStats bool } // Conn is a drpc client connection. type Conn struct { - tr drpc.Transport - man *drpcmanager.Manager - mu sync.Mutex - wbuf []byte - - stats map[string]*drpcstats.Stats + tr drpc.Transport + man *drpcmanager.Manager } var _ drpc.Conn = (*Conn)(nil) @@ -51,41 +40,11 @@ func NewWithOptions(tr drpc.Transport, opts Options) *Conn { tr: tr, } - if opts.CollectStats { - drpcopts.SetManagerStatsCB(&opts.Manager.Internal, c.getStats) - c.stats = make(map[string]*drpcstats.Stats) - } - - c.man = drpcmanager.NewWithOptions(tr, opts.Manager) + c.man = drpcmanager.NewWithOptions(tr, drpcmanager.Client, opts.Manager) return c } -// Stats returns the collected stats grouped by rpc. -func (c *Conn) Stats() map[string]drpcstats.Stats { - c.mu.Lock() - defer c.mu.Unlock() - - stats := make(map[string]drpcstats.Stats, len(c.stats)) - for k, v := range c.stats { - stats[k] = v.AtomicClone() - } - return stats -} - -// getStats returns the drpcopts.Stats struct for the given rpc. -func (c *Conn) getStats(rpc string) *drpcstats.Stats { - c.mu.Lock() - defer c.mu.Unlock() - - stats := c.stats[rpc] - if stats == nil { - stats = new(drpcstats.Stats) - c.stats[rpc] = stats - } - return stats -} - // Transport returns the transport the conn is using. func (c *Conn) Transport() drpc.Transport { return c.tr } @@ -93,15 +52,15 @@ func (c *Conn) Transport() drpc.Transport { return c.tr } func (c *Conn) Closed() <-chan struct{} { return c.man.Closed() } // Unblocked returns a channel that is closed once the connection is no longer -// blocked by a previously canceled Invoke or NewStream call. It should not -// be called concurrently with Invoke or NewStream. +// blocked. With multiplexing, multiple streams run concurrently and this +// channel is always closed immediately. func (c *Conn) Unblocked() <-chan struct{} { return c.man.Unblocked() } // Close closes the connection. func (c *Conn) Close() (err error) { return c.man.Close() } // Invoke issues the rpc on the transport serializing in, waits for a response, and -// deserializes it into out. Only one Invoke or Stream may be open at a time. +// deserializes it into out. func (c *Conn) Invoke(ctx context.Context, rpc string, enc drpc.Encoding, in, out drpc.Message) (err error) { defer func() { err = drpc.ToRPCErr(err) }() @@ -117,30 +76,21 @@ func (c *Conn) Invoke(ctx context.Context, rpc string, enc drpc.Encoding, in, ou } defer func() { err = errs.Combine(err, stream.Close()) }() - // we have to protect c.wbuf here even though the manager only allows one - // stream at a time because the stream may async close allowing another - // concurrent call to Invoke to proceed. - c.mu.Lock() - defer c.mu.Unlock() - - c.wbuf, err = drpcenc.MarshalAppend(in, enc, c.wbuf[:0]) + // TODO: use buffer pool to reduce allocations + data, err := drpcenc.MarshalAppend(in, enc, nil) if err != nil { return err } - if err := c.doInvoke(stream, enc, rpc, c.wbuf, metadata, out); err != nil { + if err := c.doInvoke(stream, enc, rpc, data, metadata, out); err != nil { return err } return nil } func (c *Conn) doInvoke(stream *drpcstream.Stream, enc drpc.Encoding, rpc string, data []byte, metadata []byte, out drpc.Message) (err error) { - if len(metadata) > 0 { - if err := stream.RawWrite(drpcwire.KindInvokeMetadata, metadata); err != nil { - return err - } - } - if err := stream.RawWrite(drpcwire.KindInvoke, []byte(rpc)); err != nil { + defer func() { err = stream.CheckCancelError(err) }() + if err := stream.WriteInvoke(rpc, metadata); err != nil { return err } if err := stream.RawWrite(drpcwire.KindMessage, data); err != nil { @@ -155,8 +105,7 @@ func (c *Conn) doInvoke(stream *drpcstream.Stream, enc drpc.Encoding, rpc string return nil } -// NewStream begins a streaming rpc on the connection. Only one Invoke or Stream may -// be open at a time. +// NewStream begins a streaming rpc on the connection. func (c *Conn) NewStream(ctx context.Context, rpc string, enc drpc.Encoding) (_ drpc.Stream, err error) { defer func() { err = drpc.ToRPCErr(err) }() @@ -171,25 +120,13 @@ func (c *Conn) NewStream(ctx context.Context, rpc string, enc drpc.Encoding) (_ return nil, err } - if err := c.doNewStream(stream, rpc, metadata); err != nil { + if err := stream.WriteInvoke(rpc, metadata); err != nil { return nil, errs.Combine(err, stream.Close()) } return stream, nil } -func (c *Conn) doNewStream(stream *drpcstream.Stream, rpc string, metadata []byte) error { - if len(metadata) > 0 { - if err := stream.RawWrite(drpcwire.KindInvokeMetadata, metadata); err != nil { - return err - } - } - if err := stream.RawWrite(drpcwire.KindInvoke, []byte(rpc)); err != nil { - return err - } - return nil -} - // encodeMetadata retrieves and encodes metadata from the provided // (outgoing/client) context. func (c *Conn) encodeMetadata(ctx context.Context) (metadata []byte, err error) { diff --git a/drpcconn/conn_test.go b/drpcconn/conn_test.go index e7402b6a..aadda12b 100644 --- a/drpcconn/conn_test.go +++ b/drpcconn/conn_test.go @@ -40,31 +40,32 @@ func TestConn_InvokeFlushesSendClose(t *testing.T) { invokeDone := make(chan struct{}) ctx.Run(func(ctx context.Context) { - wr := drpcwire.NewWriter(ps, 64) + wr := drpcwire.NewMuxWriter(ps, nil) + defer func() { wr.Stop(nil); <-wr.Done() }() rd := drpcwire.NewReader(ps) _, _ = rd.ReadFrame() // Invoke _, _ = rd.ReadFrame() // Message pkt, _ := rd.ReadFrame() // CloseSend - _ = wr.WritePacket(drpcwire.Packet{ + _ = wr.WriteFrame(drpcwire.Frame{ Data: []byte("qux"), ID: drpcwire.ID{Stream: pkt.ID.Stream, Message: 1}, Kind: drpcwire.KindMessage, + Done: true, }) - _ = wr.Flush() _, _ = rd.ReadFrame() // Close <-invokeDone // wait for invoke to return // ensure that any later packets are dropped by writing one // before closing the transport. - for i := 0; i < 5; i++ { - _ = wr.WritePacket(drpcwire.Packet{ + for range 5 { + _ = wr.WriteFrame(drpcwire.Frame{ ID: drpcwire.ID{Stream: pkt.ID.Stream, Message: 2}, Kind: drpcwire.KindCloseSend, + Done: true, }) - _ = wr.Flush() } _ = ps.Close() @@ -78,7 +79,7 @@ func TestConn_InvokeFlushesSendClose(t *testing.T) { invokeDone <- struct{}{} // signal invoke has returned - // we should eventually notice the transport is closed + // we should eventually notice the transport is closed due to ps.Close() select { case <-conn.Closed(): case <-time.After(1 * time.Second): @@ -95,7 +96,8 @@ func TestConn_InvokeSendsGrpcAndDrpcMetadata(t *testing.T) { defer func() { assert.NoError(t, ps.Close()) }() ctx.Run(func(ctx context.Context) { - wr := drpcwire.NewWriter(ps, 64) + wr := drpcwire.NewMuxWriter(ps, nil) + defer func() { wr.Stop(nil); <-wr.Done() }() rd := drpcwire.NewReader(ps) md, err := rd.ReadFrame() // Metadata @@ -114,12 +116,12 @@ func TestConn_InvokeSendsGrpcAndDrpcMetadata(t *testing.T) { _, _ = rd.ReadFrame() // Message pkt, _ := rd.ReadFrame() // CloseSend - _ = wr.WritePacket(drpcwire.Packet{ + _ = wr.WriteFrame(drpcwire.Frame{ Data: []byte("qux"), ID: drpcwire.ID{Stream: pkt.ID.Stream, Message: 1}, Kind: drpcwire.KindMessage, + Done: true, }) - _ = wr.Flush() _, _ = rd.ReadFrame() // Close }) @@ -181,6 +183,10 @@ func TestConn_NewStreamSendsGrpcAndDrpcMetadata(t *testing.T) { s, err := conn.NewStream(ctx, "/com.example.Foo/Bar", testEncoding{}) assert.NoError(t, err) _ = s.CloseSend() + + // Wait for the server goroutine to read all frames before defers + // close the pipe. With MuxWriter, writes are asynchronous. + ctx.Wait() } func TestConn_encodeMetadata(t *testing.T) { diff --git a/drpcmanager/active_streams.go b/drpcmanager/active_streams.go index f4223755..4ee7dc17 100644 --- a/drpcmanager/active_streams.go +++ b/drpcmanager/active_streams.go @@ -12,9 +12,10 @@ import ( // activeStreams is a thread-safe map of stream IDs to stream objects. // It is used by the Manager to track active streams for lifecycle management. type activeStreams struct { - mu sync.RWMutex - streams map[uint64]*drpcstream.Stream - closed bool + mu sync.RWMutex + streams map[uint64]*drpcstream.Stream + closed bool + closeErr error } func newActiveStreams() *activeStreams { @@ -34,7 +35,7 @@ func (r *activeStreams) Add(id uint64, stream *drpcstream.Stream) error { defer r.mu.Unlock() if r.closed { - return managerClosed.New("add to closed collection") + return r.closeErr } if _, ok := r.streams[id]; ok { return managerClosed.New("duplicate stream id") @@ -59,41 +60,24 @@ func (r *activeStreams) Get(id uint64) (*drpcstream.Stream, bool) { r.mu.RLock() defer r.mu.RUnlock() + if r.closed { + return nil, false + } s, ok := r.streams[id] return s, ok } -// GetLatest returns the stream with the highest ID, or nil if empty. -func (r *activeStreams) GetLatest() *drpcstream.Stream { - r.mu.RLock() - defer r.mu.RUnlock() - - var latest *drpcstream.Stream - for _, s := range r.streams { - if latest == nil || latest.ID() < s.ID() { - latest = s - } - } - return latest -} - -// Close marks the collection as closed, preventing future Add calls. -// It does not cancel any streams. -func (r *activeStreams) Close() { +// Close cancels all active streams with the given error, clears the +// collection, and marks it as closed to prevent future Add calls. +func (r *activeStreams) Close(err error) { r.mu.Lock() defer r.mu.Unlock() r.closed = true -} - -// ForEach calls fn for each active stream. The collection is read-locked -// during iteration. -func (r *activeStreams) ForEach(fn func(*drpcstream.Stream)) { - r.mu.RLock() - defer r.mu.RUnlock() - - for _, s := range r.streams { - fn(s) + r.closeErr = err + for id, s := range r.streams { + s.Cancel(err) + delete(r.streams, id) } } diff --git a/drpcmanager/active_streams_test.go b/drpcmanager/active_streams_test.go index cdf2ec99..f463b188 100644 --- a/drpcmanager/active_streams_test.go +++ b/drpcmanager/active_streams_test.go @@ -5,6 +5,8 @@ package drpcmanager import ( "context" + "errors" + "io" "testing" "github.com/zeebo/assert" @@ -13,13 +15,19 @@ import ( "storj.io/drpc/drpcwire" ) -func testStream(id uint64) *drpcstream.Stream { - return drpcstream.New(context.Background(), id, &drpcwire.Writer{}) +func testMuxWriter(t *testing.T) *drpcwire.MuxWriter { + mw := drpcwire.NewMuxWriter(io.Discard, func(error) {}) + t.Cleanup(func() { mw.Stop(nil); <-mw.Done() }) + return mw +} + +func testStream(t *testing.T, id uint64) *drpcstream.Stream { + return drpcstream.New(context.Background(), id, testMuxWriter(t)) } func TestActiveStreams_AddAndGet(t *testing.T) { streams := newActiveStreams() - s := testStream(1) + s := testStream(t, 1) assert.NoError(t, streams.Add(1, s)) @@ -38,7 +46,7 @@ func TestActiveStreams_GetMissing(t *testing.T) { func TestActiveStreams_Remove(t *testing.T) { streams := newActiveStreams() - s := testStream(1) + s := testStream(t, 1) assert.NoError(t, streams.Add(1, s)) assert.Equal(t, streams.Len(), 1) @@ -59,8 +67,8 @@ func TestActiveStreams_RemoveIdempotent(t *testing.T) { func TestActiveStreams_DuplicateAdd(t *testing.T) { streams := newActiveStreams() - s1 := testStream(1) - s2 := testStream(1) + s1 := testStream(t, 1) + s2 := testStream(t, 1) assert.NoError(t, streams.Add(1, s1)) assert.Error(t, streams.Add(1, s2)) @@ -73,18 +81,18 @@ func TestActiveStreams_DuplicateAdd(t *testing.T) { func TestActiveStreams_AddAfterClose(t *testing.T) { streams := newActiveStreams() - streams.Close() + streams.Close(errors.New("closed")) - err := streams.Add(1, testStream(1)) + err := streams.Add(1, testStream(t, 1)) assert.Error(t, err) } func TestActiveStreams_RemoveAfterClose(t *testing.T) { streams := newActiveStreams() - s := testStream(1) + s := testStream(t, 1) assert.NoError(t, streams.Add(1, s)) - streams.Close() + streams.Close(errors.New("closed")) // must not panic streams.Remove(1) @@ -94,41 +102,12 @@ func TestActiveStreams_Len(t *testing.T) { streams := newActiveStreams() assert.Equal(t, streams.Len(), 0) - assert.NoError(t, streams.Add(1, testStream(1))) + assert.NoError(t, streams.Add(1, testStream(t, 1))) assert.Equal(t, streams.Len(), 1) - assert.NoError(t, streams.Add(2, testStream(2))) + assert.NoError(t, streams.Add(2, testStream(t, 2))) assert.Equal(t, streams.Len(), 2) streams.Remove(1) assert.Equal(t, streams.Len(), 1) } - -func TestActiveStreams_ForEach(t *testing.T) { - streams := newActiveStreams() - s1 := testStream(1) - s2 := testStream(2) - s3 := testStream(3) - - assert.NoError(t, streams.Add(1, s1)) - assert.NoError(t, streams.Add(2, s2)) - assert.NoError(t, streams.Add(3, s3)) - - seen := make(map[uint64]*drpcstream.Stream) - streams.ForEach(func(s *drpcstream.Stream) { - seen[s.ID()] = s - }) - - assert.Equal(t, len(seen), 3) - assert.Equal(t, seen[1], s1) - assert.Equal(t, seen[2], s2) - assert.Equal(t, seen[3], s3) -} - -func TestActiveStreams_ForEach_Empty(t *testing.T) { - streams := newActiveStreams() - - count := 0 - streams.ForEach(func(_ *drpcstream.Stream) { count++ }) - assert.Equal(t, count, 0) -} diff --git a/drpcmanager/manager.go b/drpcmanager/manager.go index e2e0993b..6c77e174 100644 --- a/drpcmanager/manager.go +++ b/drpcmanager/manager.go @@ -13,7 +13,6 @@ import ( "sync" "sync/atomic" "syscall" - "time" "github.com/zeebo/errs" grpcmetadata "google.golang.org/grpc/metadata" @@ -31,30 +30,12 @@ var managerClosed = errs.Class("manager closed") // Options controls configuration settings for a manager. type Options struct { - // WriterBufferSize controls the size of the buffer that we will fill before - // flushing. Normal writes to streams typically issue a flush explicitly. - WriterBufferSize int - // Reader are passed to any readers the manager creates. Reader drpcwire.ReaderOptions // Stream are passed to any streams the manager creates. Stream drpcstream.Options - // SoftCancel controls if a context cancel will cause the transport to be - // closed or, if true, a soft cancel message will be attempted if possible. - // A soft cancel can reduce the amount of closed and dialed connections at - // the potential cost of higher latencies if there is latent data still - // being flushed when the cancel happens. - SoftCancel bool - - // InactivityTimeout is the amount of time the manager will wait when - // creating a NewServerStream. It only includes the time it is reading - // packets from the remote client. In other words, it only includes the time - // that the client could delay before invoking an RPC. If zero or negative, - // no timeout is used. - InactivityTimeout time.Duration - // Internal contains options that are for internal use only. Internal drpcopts.Manager @@ -70,37 +51,49 @@ type Options struct { // to the appropriate stream. type Manager struct { tr drpc.Transport - wr *drpcwire.Writer + wr *drpcwire.MuxWriter rd *drpcwire.Reader opts Options - lastFrameID drpcwire.ID - lastFrameKind drpcwire.Kind - // next client stream ID, incremented atomically lastStreamID atomic.Uint64 wg sync.WaitGroup // tracks active manageStream goroutines - // streams tracks active streams. Currently holds at most one active stream; - // a second may briefly coexist during stream handoff (old stream's Remove - // races with new stream's Add). + // streams tracks active streams. streams *activeStreams - sem drpcsignal.Chan // held by the active stream - - pdone drpcsignal.Chan // signals when NewServerStream has added the new stream + pdone drpcsignal.Chan // signals when NewServerStream has registered the new stream invokes chan invokeInfo // completed invoke info from manageReader to NewServerStream - // Below fields are owned by the manageReader goroutine, used in handleInvokeFrame. - metadata map[string]string // accumulated invoke metadata - pa drpcwire.PacketAssembler // assembles invoke/metadata frames into packets + // pendingStreams is owned by the manageReader goroutine, used in + // handleInvokeFrame. It tracks streams that are being assembled from + // invoke/metadata frames but haven't been fully created yet. + pendingStreams map[uint64]*pendingStream sigs struct { term drpcsignal.Signal // set when the manager should start terminating read drpcsignal.Signal // set after the goroutine reading from the transport is done tport drpcsignal.Signal // set after the transport has been closed } + + kind ManagerKind +} + +type ManagerKind uint8 + +const ( + _ ManagerKind = iota + Client + Server +) + +// pendingStream accumulates invoke and metadata frames for a stream that is +// being set up but hasn't been fully created yet. Once the invoke packet +// arrives, the pending stream is forwarded to NewServerStream. +type pendingStream struct { + metadata map[string]string // accumulated invoke metadata + pa drpcwire.PacketAssembler // assembles invoke/metadata frames into packets } // invokeInfo carries the assembled invoke data from manageReader to @@ -112,30 +105,30 @@ type invokeInfo struct { } // New returns a new Manager for the transport. -func New(tr drpc.Transport) *Manager { - return NewWithOptions(tr, Options{}) +func New(tr drpc.Transport, kind ManagerKind) *Manager { + return NewWithOptions(tr, kind, Options{}) } // NewWithOptions returns a new manager for the transport. It uses the provided // options to manage details of how it uses it. -func NewWithOptions(tr drpc.Transport, opts Options) *Manager { +func NewWithOptions(tr drpc.Transport, kind ManagerKind, opts Options) *Manager { m := &Manager{ tr: tr, - wr: drpcwire.NewWriter(tr, opts.WriterBufferSize), rd: drpcwire.NewReaderWithOptions(tr, opts.Reader), opts: opts, invokes: make(chan invokeInfo), + kind: kind, } - // this semaphore controls the number of concurrent streams. it MUST be 1. - m.sem.Make(1) + m.wr = drpcwire.NewMuxWriter(tr, m.terminate) // a buffer of size 1 allows NewServerStream to signal it is done creating a // new server stream without having to coordinate with manageReader. m.pdone.Make(1) - m.pa = drpcwire.NewPacketAssembler() + m.pendingStreams = make(map[uint64]*pendingStream) + m.streams = newActiveStreams() // set the internal stream options @@ -155,72 +148,23 @@ func (m *Manager) log(what string, cb func() string) { } } -// -// helpers -// - -// acquireSemaphore attempts to acquire the semaphore protecting streams. If the -// context is canceled or the manager is terminated, it returns an error. -func (m *Manager) acquireSemaphore(ctx context.Context) error { - if err, ok := m.sigs.term.Get(); ok { - return err - } else if err := ctx.Err(); err != nil { - return err - } - - select { - case <-ctx.Done(): - return ctx.Err() - - case <-m.sigs.term.Signal(): - return m.sigs.term.Err() - - case m.sem.Get() <- struct{}{}: - if err := m.waitForPreviousStream(ctx); err != nil { - m.sem.Recv() - return err - } - return nil - } -} - -// waitForPreviousStream will, if there was a previous stream, ensure it is -// Closed and then wait until it is in the Finished state, where it will no -// longer make any reads or writes on the transport. It exits early if the -// context is canceled or the manager is terminated. -func (m *Manager) waitForPreviousStream(ctx context.Context) (err error) { - prev := m.streams.GetLatest() - if prev == nil { - return nil - } - - // if the stream is not finished yet, we need to wait for it to be - // finished before letting the next stream to start. - if prev.IsFinished() { - return nil - } - - m.log("WAIT", prev.String) - - select { - case <-ctx.Done(): - return ctx.Err() - - case <-m.sigs.term.Signal(): - return m.sigs.term.Err() - - case <-prev.Finished(): - return nil - } -} - // terminate puts the Manager into a terminal state and closes any resources -// that need to be closed to signal the state change. +// that need to be closed to signal the state change. The mux writer is stopped +// before closing the transport so that WriteFrame immediately rejects new +// writes; the subsequent transport close unblocks any in-flight Write in the +// drain goroutine. func (m *Manager) terminate(err error) { if m.sigs.term.Set(err) { m.log("TERM", func() string { return fmt.Sprint(err) }) + if errors.Is(err, io.EOF) { + err = context.Canceled + if m.kind == Client { + err = drpc.ClosedError.New("connection closed") + } + } + m.wr.Stop(err) m.sigs.tport.Set(m.tr.Close()) - m.streams.Close() + m.streams.Close(err) } } @@ -247,54 +191,39 @@ func (m *Manager) manageReader() { m.log("READ", incomingFrame.String) - if ok := m.checkStreamMonotonicity(incomingFrame); !ok { - m.terminate(managerClosed.Wrap(drpc.ProtocolError.New("id monotonicity violation"))) - return - } + stream, ok := m.streams.Get(incomingFrame.ID.Stream) - switch curr := m.streams.GetLatest(); { - // If the frame is for the current stream, deliver it. - case curr != nil && incomingFrame.ID.Stream == curr.ID(): - if err := curr.HandleFrame(incomingFrame); err != nil { + switch { + // if the packet is for an active stream, deliver it. + case ok: + if err := stream.HandleFrame(incomingFrame); err != nil { m.terminate(managerClosed.Wrap(err)) return } - // If a frame arrives for an old stream, just ignore it. - case curr != nil && incomingFrame.ID.Stream < curr.ID(): - - // If an invoke sequence is being sent for a new stream, close any - // old unterminated stream and forward it to be handled. case incomingFrame.Kind == drpcwire.KindInvoke || incomingFrame.Kind == drpcwire.KindInvokeMetadata: - if curr != nil && !curr.IsTerminated() { - curr.Cancel(context.Canceled) - } if err := m.handleInvokeFrame(incomingFrame); err != nil { m.terminate(managerClosed.Wrap(err)) return } + // silently drop packet for an unknown stream default: m.log("DROP", incomingFrame.String) } } } -func (m *Manager) checkStreamMonotonicity(incomingFrame drpcwire.Frame) bool { - ok := incomingFrame.ID.Stream >= m.lastFrameID.Stream - m.lastFrameKind = incomingFrame.Kind - m.lastFrameID = incomingFrame.ID - if incomingFrame.Done { - m.lastFrameID.Message += 1 - } - return ok -} - // handleInvokeFrame assembles invoke/metadata frames into complete packets and -// forwards the finished invoke info to NewServerStream via m.newServerStreamInfo. -// Metadata packets are accumulated; the invoke packet triggers the send. +// forwards the finished invoke info to NewServerStream. Metadata packets are +// accumulated; the invoke packet triggers the send. func (m *Manager) handleInvokeFrame(fr drpcwire.Frame) error { - pkt, packetReady, err := m.pa.AppendFrame(fr) + ps, ok := m.pendingStreams[fr.ID.Stream] + if !ok { + ps = &pendingStream{pa: drpcwire.NewPacketAssembler()} + m.pendingStreams[fr.ID.Stream] = ps + } + pkt, packetReady, err := ps.pa.AppendFrame(fr) if err != nil { return err } @@ -308,20 +237,19 @@ func (m *Manager) handleInvokeFrame(fr drpcwire.Frame) error { if err != nil { return err } - m.metadata = meta + ps.metadata = meta return nil } // Invoke packet completes the sequence. Send to NewServerStream. select { - case m.invokes <- invokeInfo{sid: pkt.ID.Stream, data: pkt.Data, metadata: m.metadata}: + case m.invokes <- invokeInfo{sid: pkt.ID.Stream, data: pkt.Data, metadata: ps.metadata}: // Wait for NewServerStream to finish stream creation before reading the // next frame. This guarantees curr is set for subsequent non-invoke // packets. m.pdone.Recv() - - m.pa.Reset() - m.metadata = nil + // TODO: reuse pending stream + delete(m.pendingStreams, fr.ID.Stream) case <-m.sigs.term.Signal(): } return nil @@ -360,57 +288,17 @@ func (m *Manager) manageStream(ctx context.Context, stream *drpcstream.Stream) { defer m.wg.Done() defer m.streams.Remove(stream.ID()) select { - case <-m.sigs.term.Signal(): - err := m.sigs.term.Err() - if errors.Is(err, io.EOF) { - err = context.Canceled - if stream.Kind() == drpc.StreamKindClient { - err = drpc.ClosedError.New("connection closed") - } - } - stream.Cancel(err) - <-stream.Finished() - m.sem.Recv() - case <-stream.Finished(): - m.sem.Recv() case <-ctx.Done(): m.log("CANCEL", stream.String) - if m.opts.SoftCancel { - // allow a new stream to begin. - m.sem.Recv() - - // attempt to send the soft cancel. if it fails or if the stream is - // busy sending something else, then we have to hard cancel. - if busy, err := stream.SendCancel(ctx.Err()); err != nil { - m.terminate(err) - } else if busy { - m.log("BUSY", stream.String) - m.terminate(ctx.Err()) - } - stream.Cancel(ctx.Err()) - - // wait for the stream to signal that it is finished. - <-stream.Finished() - } else { - // If the stream isn't already finished, we have to terminate the - // transport to do an active cancel. If it is already finished, - // there is no need. - if !stream.Cancel(ctx.Err()) { - m.log("UNFIN", stream.String) - m.terminate(ctx.Err()) - } else { - m.log("CLEAN", stream.String) - } - - // wait for the stream to signal that it is finished. - <-stream.Finished() - - // allow a new stream to begin. - m.sem.Recv() + if err := stream.SendCancel(ctx.Err()); err != nil { + // SendCancel can fail if it's an IO error which reader will catch. + m.log("SendCancel", func() string { return fmt.Sprintf("%s: %s", stream.String(), err) }) } + stream.Cancel(ctx.Err()) + <-stream.Finished() } } @@ -424,14 +312,12 @@ func (m *Manager) Closed() <-chan struct{} { } // Unblocked returns a channel that is closed when the manager is no longer -// blocked from creating a new stream due to a previous stream's soft cancel. It -// should not be called concurrently with NewClientStream or NewServerStream and -// the return result is only valid until the next call to NewClientStream or -// NewServerStream. +// blocked. With multiplexing, multiple streams run concurrently and this +// channel is always closed immediately. +// +// TODO(shubham): audit whether this is still useful in a multiplexing world. +// The only meaningful caller is Pool.Take. func (m *Manager) Unblocked() <-chan struct{} { - if prev := m.streams.GetLatest(); prev != nil { - return prev.Context().Done() - } return closedCh } @@ -439,8 +325,9 @@ func (m *Manager) Unblocked() <-chan struct{} { func (m *Manager) Close() error { m.terminate(managerClosed.New("Close called")) - m.wg.Wait() // wait for all stream goroutines - m.sigs.read.Wait() + <-m.wr.Done() // wait for writer goroutine to exit + m.wg.Wait() // wait for all stream goroutines + m.sigs.read.Wait() // wait for reader goroutine to exit m.sigs.tport.Wait() return m.sigs.tport.Err() @@ -448,10 +335,9 @@ func (m *Manager) Close() error { // NewClientStream starts a stream on the managed transport for use by a client. func (m *Manager) NewClientStream(ctx context.Context, rpc string) (stream *drpcstream.Stream, err error) { - if err := m.acquireSemaphore(ctx); err != nil { + if err := ctx.Err(); err != nil { return nil, err } - return m.newStream(ctx, m.lastStreamID.Add(1), drpc.StreamKindClient, rpc) } @@ -459,28 +345,7 @@ func (m *Manager) NewClientStream(ctx context.Context, rpc string) (stream *drpc // It does this by waiting for the client to issue an invoke message and // returning the details. func (m *Manager) NewServerStream(ctx context.Context) (stream *drpcstream.Stream, rpc string, err error) { - if err := m.acquireSemaphore(ctx); err != nil { - return nil, "", err - } - defer func() { - if err != nil { - m.sem.Recv() - } - }() - - var timeoutCh <-chan time.Time - - // set up the timeout on the context if necessary. - if timeout := m.opts.InactivityTimeout; timeout > 0 { - timer := time.NewTimer(timeout) - defer timer.Stop() - timeoutCh = timer.C - } - select { - case <-timeoutCh: - return nil, "", context.DeadlineExceeded - case <-ctx.Done(): return nil, "", ctx.Err() diff --git a/drpcmanager/manager_test.go b/drpcmanager/manager_test.go index 0b7beeac..b7b1f791 100644 --- a/drpcmanager/manager_test.go +++ b/drpcmanager/manager_test.go @@ -8,7 +8,6 @@ import ( "errors" "io" "net" - "sync" "testing" "time" @@ -21,26 +20,6 @@ import ( "storj.io/drpc/drpcwire" ) -func closed(ch <-chan struct{}) bool { - select { - case <-ch: - return true - default: - return false - } -} - -func TestTimeout(t *testing.T) { - tr := make(blockingTransport) - man := NewWithOptions(tr, Options{ - InactivityTimeout: time.Millisecond, - }) - defer func() { _ = man.Close() }() - - _, _, err := man.NewServerStream(context.Background()) - assert.That(t, errors.Is(err, context.DeadlineExceeded)) -} - func TestDrpcMetadata(t *testing.T) { ctx := drpctest.NewTracker(t) defer ctx.Close() @@ -49,10 +28,10 @@ func TestDrpcMetadata(t *testing.T) { defer func() { _ = cconn.Close() }() defer func() { _ = sconn.Close() }() - cman := New(cconn) + cman := New(cconn, Client) defer func() { _ = cman.Close() }() - sman := NewWithOptions(sconn, Options{ + sman := NewWithOptions(sconn, Server, Options{ GRPCMetadataCompatMode: false, }) defer func() { _ = sman.Close() }() @@ -69,11 +48,7 @@ func TestDrpcMetadata(t *testing.T) { assert.NoError(t, stream.RawWrite(drpcwire.KindInvokeMetadata, buf)) assert.NoError(t, stream.RawWrite(drpcwire.KindInvoke, []byte("invoke"))) assert.NoError(t, stream.RawWrite(drpcwire.KindMessage, []byte("message"))) - assert.NoError(t, stream.RawFlush()) - assert.That(t, !closed(cman.Unblocked())) - assert.NoError(t, stream.Close()) - assert.That(t, closed(cman.Unblocked())) }) ctx.Run(func(ctx context.Context) { @@ -109,10 +84,10 @@ func TestDrpcMetadataWithGRPCMetadataCompatMode(t *testing.T) { defer func() { _ = cconn.Close() }() defer func() { _ = sconn.Close() }() - cman := New(cconn) + cman := New(cconn, Client) defer func() { _ = cman.Close() }() - sman := NewWithOptions(sconn, Options{ + sman := NewWithOptions(sconn, Server, Options{ GRPCMetadataCompatMode: true, }) defer func() { _ = sman.Close() }() @@ -129,11 +104,7 @@ func TestDrpcMetadataWithGRPCMetadataCompatMode(t *testing.T) { assert.NoError(t, stream.RawWrite(drpcwire.KindInvokeMetadata, buf)) assert.NoError(t, stream.RawWrite(drpcwire.KindInvoke, []byte("invoke"))) assert.NoError(t, stream.RawWrite(drpcwire.KindMessage, []byte("message"))) - assert.NoError(t, stream.RawFlush()) - assert.That(t, !closed(cman.Unblocked())) - assert.NoError(t, stream.Close()) - assert.That(t, closed(cman.Unblocked())) }) ctx.Run(func(ctx context.Context) { @@ -207,7 +178,7 @@ func TestManageReader_GlobalMonotonicity_SameStream(t *testing.T) { defer func() { _ = cconn.Close() }() defer func() { _ = sconn.Close() }() - man := New(sconn) + man := New(sconn, Server) defer func() { _ = man.Close() }() // Consume the invoke and drain messages so HandleFrame doesn't block. @@ -230,33 +201,6 @@ func TestManageReader_GlobalMonotonicity_SameStream(t *testing.T) { waitForClosed(t, man) } -// Cross-stream monotonicity: after seeing stream 2, a frame for stream 1 -// with a higher message ID is still rejected because {1,x} < {2,y}. -func TestManageReader_GlobalMonotonicity_CrossStream(t *testing.T) { - ctx := drpctest.NewTracker(t) - defer ctx.Close() - - cconn, sconn := net.Pipe() - defer func() { _ = cconn.Close() }() - defer func() { _ = sconn.Close() }() - - man := New(sconn) - defer func() { _ = man.Close() }() - - // Consume both invokes so manageReader can proceed. - ctx.Run(func(ctx context.Context) { - _, _, _ = man.NewServerStream(ctx) - _, _, _ = man.NewServerStream(ctx) - }) - - writeFrames(t, cconn, - createFrame(drpcwire.KindInvoke, 1, 1, "rpc1", true), - createFrame(drpcwire.KindInvoke, 2, 1, "rpc2", true), - createFrame(drpcwire.KindMessage, 1, 4, "bad", true), - ) - - waitForClosed(t, man) -} // Invoke replay: after [s1,m1,invoke,done=true], lastFrameID is bumped to // {1,2}. A replayed [s1,m1,invoke] is caught by the monotonicity check. @@ -268,7 +212,7 @@ func TestManageReader_InvokeReplayBlocked(t *testing.T) { defer func() { _ = cconn.Close() }() defer func() { _ = sconn.Close() }() - man := New(sconn) + man := New(sconn, Server) defer func() { _ = man.Close() }() ctx.Run(func(ctx context.Context) { @@ -293,7 +237,7 @@ func TestManageReader_ContinuationFramesAccepted(t *testing.T) { defer func() { _ = cconn.Close() }() defer func() { _ = sconn.Close() }() - man := New(sconn) + man := New(sconn, Server) defer func() { _ = man.Close() }() recv := make(chan []byte, 1) @@ -314,8 +258,8 @@ func TestManageReader_ContinuationFramesAccepted(t *testing.T) { assert.DeepEqual(t, <-recv, []byte("hello")) } -// Old-stream frames are silently ignored on the client side when the local -// stream ID has advanced past the incoming frame's stream ID. +// Old-stream frames are silently ignored when the stream has been cancelled +// and removed from the registry. func TestManageReader_OldStreamFramesIgnored(t *testing.T) { ctx := drpctest.NewTracker(t) defer ctx.Close() @@ -324,11 +268,10 @@ func TestManageReader_OldStreamFramesIgnored(t *testing.T) { defer func() { _ = cconn.Close() }() defer func() { _ = sconn.Close() }() - cman := NewWithOptions(cconn, Options{SoftCancel: true}) + cman := New(cconn, Client) defer func() { _ = cman.Close() }() - // Drain all client writes so nothing blocks, and write server - // responses once we've seen enough data. + // Drain all client writes so nothing blocks. ctx.Run(func(ctx context.Context) { buf := make([]byte, 4096) for { @@ -339,13 +282,13 @@ func TestManageReader_OldStreamFramesIgnored(t *testing.T) { } }) - // Create stream 1 on the client, then cancel it so the client - // advances to stream 2. + // Create stream 1 on the client, then cancel it so it's removed + // from the registry. subctx, cancel := context.WithCancel(ctx) - _, err := cman.NewClientStream(subctx, "rpc1") + stream1, err := cman.NewClientStream(subctx, "rpc1") assert.NoError(t, err) cancel() - <-cman.Unblocked() + <-stream1.Finished() stream2, err := cman.NewClientStream(ctx, "rpc2") assert.NoError(t, err) @@ -374,7 +317,7 @@ func TestManageReader_ValidInvokeSequence(t *testing.T) { defer func() { _ = cconn.Close() }() defer func() { _ = sconn.Close() }() - man := New(sconn) + man := New(sconn, Server) defer func() { _ = man.Close() }() recv := make(chan []byte, 1) @@ -406,7 +349,7 @@ func TestManageReader_MultiFrameDelivery(t *testing.T) { defer func() { _ = cconn.Close() }() defer func() { _ = sconn.Close() }() - man := New(sconn) + man := New(sconn, Server) defer func() { _ = man.Close() }() recv := make(chan []byte, 1) @@ -439,7 +382,7 @@ func TestManageReader_HigherMsgDiscardsInProgress(t *testing.T) { defer func() { _ = cconn.Close() }() defer func() { _ = sconn.Close() }() - man := New(sconn) + man := New(sconn, Server) defer func() { _ = man.Close() }() recv := make(chan []byte, 1) @@ -470,7 +413,7 @@ func TestManageReader_KindChangeWithinPacket(t *testing.T) { defer func() { _ = cconn.Close() }() defer func() { _ = sconn.Close() }() - man := New(sconn) + man := New(sconn, Server) defer func() { _ = man.Close() }() ctx.Run(func(ctx context.Context) { @@ -503,7 +446,7 @@ func TestManageReader_MultiFrameWithSkippedMessageID(t *testing.T) { defer func() { _ = cconn.Close() }() defer func() { _ = sconn.Close() }() - man := New(sconn) + man := New(sconn, Server) defer func() { _ = man.Close() }() recv := make(chan []byte, 1) @@ -534,7 +477,7 @@ func TestManageReader_InvokeOnExistingStream(t *testing.T) { defer func() { _ = cconn.Close() }() defer func() { _ = sconn.Close() }() - man := New(sconn) + man := New(sconn, Server) defer func() { _ = man.Close() }() ctx.Run(func(ctx context.Context) { @@ -563,7 +506,7 @@ func TestManageReader_WaitsForStreamCreation(t *testing.T) { defer func() { _ = cconn.Close() }() defer func() { _ = sconn.Close() }() - man := New(sconn) + man := New(sconn, Server) defer func() { _ = man.Close() }() // Write invoke + message immediately. The message arrives before @@ -589,137 +532,3 @@ func TestManageReader_WaitsForStreamCreation(t *testing.T) { assert.DeepEqual(t, <-recv, []byte("data")) } -type blockingTransport chan struct{} - -func (b blockingTransport) Read(p []byte) (n int, err error) { <-b; return 0, io.EOF } -func (b blockingTransport) Write(p []byte) (n int, err error) { <-b; return 0, io.EOF } -func (b blockingTransport) Close() error { close(b); return nil } - -func TestUnblocked_NoCancel(t *testing.T) { - ctx := drpctest.NewTracker(t) - defer ctx.Close() - - cconn, sconn := net.Pipe() - defer func() { _ = cconn.Close() }() - defer func() { _ = sconn.Close() }() - - cman := New(cconn) - defer func() { _ = cman.Close() }() - - sman := New(sconn) - defer func() { _ = sman.Close() }() - - ctx.Run(func(ctx context.Context) { - stream, err := cman.NewClientStream(ctx, "rpc") - assert.NoError(t, err) - defer func() { _ = stream.Close() }() - - assert.NoError(t, stream.RawWrite(drpcwire.KindInvoke, []byte("invoke"))) - assert.NoError(t, stream.RawWrite(drpcwire.KindMessage, []byte("message"))) - assert.NoError(t, stream.RawFlush()) - assert.That(t, !closed(cman.Unblocked())) - - assert.NoError(t, stream.Close()) - assert.That(t, closed(cman.Unblocked())) - }) - - ctx.Run(func(ctx context.Context) { - stream, _, err := sman.NewServerStream(ctx) - assert.NoError(t, err) - defer func() { _ = stream.Close() }() - - _, err = stream.RawRecv() - assert.NoError(t, err) - - _, err = stream.RawRecv() - assert.That(t, errors.Is(err, io.EOF)) - }) - - ctx.Wait() -} - -func TestUnblocked_SoftCancel(t *testing.T) { - run := func(t *testing.T, softCancel bool) { - ctx := drpctest.NewTracker(t) - defer ctx.Close() - - tr := newBlockedTransport() - man := NewWithOptions(tr, Options{SoftCancel: softCancel}) - defer func() { _ = man.Close() }() - defer tr.setReadOpen(true) - defer tr.setWriteOpen(true) - - for i := 0; i < 10; i++ { - func() { - subctx, cancel := context.WithCancel(ctx) - defer cancel() - - stream, err := man.NewClientStream(subctx, "rpc") - if softCancel { - assert.NoError(t, err) - } else if i > 0 { - assert.Error(t, err) - return - } - defer func() { _ = stream.Close() }() - - assert.That(t, !closed(man.Unblocked())) - cancel() - - // temporary unblock writing to allow the stream to finish soft cancel - tr.setWriteOpen(true) - <-man.Unblocked() - tr.setWriteOpen(false) - }() - } - } - - t.Run("Enabled", func(t *testing.T) { run(t, true) }) - t.Run("Disabled", func(t *testing.T) { run(t, false) }) -} - -type blockedTransport struct { - mu *sync.Mutex - co *sync.Cond - ro bool - wo bool -} - -func newBlockedTransport() *blockedTransport { - mu := new(sync.Mutex) - co := sync.NewCond(mu) - return &blockedTransport{ - mu: mu, - co: co, - } -} - -func (b *blockedTransport) setWriteOpen(open bool) { - b.mu.Lock() - defer b.mu.Unlock() - - b.wo = open - b.co.Broadcast() -} - -func (b *blockedTransport) setReadOpen(open bool) { - b.mu.Lock() - defer b.mu.Unlock() - - b.ro = open - b.co.Broadcast() -} - -func (b *blockedTransport) wait(p int, rw *bool) (int, error) { - b.mu.Lock() - defer b.mu.Unlock() - - for !*rw { - b.co.Wait() - } - return p, nil -} - -func (b *blockedTransport) Read(p []byte) (n int, err error) { return b.wait(len(p), &b.ro) } -func (b *blockedTransport) Write(p []byte) (n int, err error) { return b.wait(len(p), &b.wo) } -func (b *blockedTransport) Close() error { return nil } diff --git a/drpcmanager/random_test.go b/drpcmanager/random_test.go index db41bc20..af63bf4d 100644 --- a/drpcmanager/random_test.go +++ b/drpcmanager/random_test.go @@ -45,14 +45,15 @@ func (rc *randClient) newSteam(ctx context.Context, man *Manager) (*drpcstream.S return stream, err } -func (rc *randClient) execute(t *testing.T, wr *drpcwire.Writer, op byte) { +func (rc *randClient) execute(t *testing.T, wr *drpcwire.MuxWriter, op byte) { cmd, arg, done := parseOp(op) if !rc.active { - assert.NoError(t, wr.WritePacket(drpcwire.Packet{ + assert.NoError(t, wr.WriteFrame(drpcwire.Frame{ Data: make([]byte, arg), ID: rc.id.incMessage(), Kind: drpcwire.KindInvoke, + Done: true, })) rc.active = true } @@ -60,9 +61,10 @@ func (rc *randClient) execute(t *testing.T, wr *drpcwire.Writer, op byte) { switch cmd { case 0: // new invoke if rc.active { - assert.NoError(t, wr.WritePacket(drpcwire.Packet{ + assert.NoError(t, wr.WriteFrame(drpcwire.Frame{ ID: rc.id.incMessage(), Kind: drpcwire.KindClose, + Done: true, })) } @@ -99,10 +101,11 @@ func (rc *randClient) execute(t *testing.T, wr *drpcwire.Writer, op byte) { })) case 2: // cause the remote side to close - assert.NoError(t, wr.WritePacket(drpcwire.Packet{ + assert.NoError(t, wr.WriteFrame(drpcwire.Frame{ Data: []byte("remote-close"), ID: rc.id.incMessage(), Kind: drpcwire.KindMessage, + Done: true, })) case 3, 4, 5, 6, 7: // send normal message @@ -130,7 +133,7 @@ func (rs *randServer) newSteam(ctx context.Context, man *Manager) (*drpcstream.S return man.NewClientStream(ctx, "rpc") } -func (rs *randServer) execute(t *testing.T, wr *drpcwire.Writer, op byte) { +func (rs *randServer) execute(t *testing.T, wr *drpcwire.MuxWriter, op byte) { cmd, arg, done := parseOp(op) switch cmd { @@ -160,10 +163,11 @@ func (rs *randServer) execute(t *testing.T, wr *drpcwire.Writer, op byte) { })) case 2: // cause the remote side to close - assert.NoError(t, wr.WritePacket(drpcwire.Packet{ + assert.NoError(t, wr.WriteFrame(drpcwire.Frame{ Data: []byte("remote-close"), ID: rs.id.incMessage(), Kind: drpcwire.KindMessage, + Done: true, })) case 3, 4, 5, 6, 7: // send random message @@ -185,7 +189,7 @@ func (rs *randServer) execute(t *testing.T, wr *drpcwire.Writer, op byte) { type runner interface { newSteam(ctx context.Context, man *Manager) (*drpcstream.Stream, error) - execute(t *testing.T, wr *drpcwire.Writer, op byte) + execute(t *testing.T, wr *drpcwire.MuxWriter, op byte) } func runRandomized(t *testing.T, prog []byte, r runner) { @@ -196,8 +200,10 @@ func runRandomized(t *testing.T, prog []byte, r runner) { defer func() { _ = pc.Close() }() defer func() { _ = ps.Close() }() - wr := drpcwire.NewWriter(pc, 0) - man := New(ps) + wr := drpcwire.NewMuxWriter(pc, func(error) {}) + defer func() { wr.Stop(nil); <-wr.Done() }() + + man := New(ps, Server) defer func() { _ = man.Close() }() errch := make(chan error, 1) @@ -223,7 +229,6 @@ func runRandomized(t *testing.T, prog []byte, r runner) { for _, op := range prog { r.execute(t, wr, op) - assert.NoError(t, wr.Flush()) } assert.NoError(t, man.Close()) diff --git a/drpcserver/server.go b/drpcserver/server.go index 13c81860..1849e805 100644 --- a/drpcserver/server.go +++ b/drpcserver/server.go @@ -12,7 +12,6 @@ import ( "github.com/zeebo/errs" "storj.io/drpc" - "storj.io/drpc/drpccache" "storj.io/drpc/drpcctx" "storj.io/drpc/drpcmanager" "storj.io/drpc/drpcstats" @@ -140,22 +139,28 @@ func (s *Server) ServeOne(ctx context.Context, tr drpc.Transport) (err error) { } } - man := drpcmanager.NewWithOptions(tr, s.opts.Manager) - defer func() { err = errs.Combine(err, man.Close()) }() - - cache := drpccache.New() - defer cache.Clear() - - ctx = drpccache.WithContext(ctx, cache) + man := drpcmanager.NewWithOptions(tr, drpcmanager.Server, s.opts.Manager) + var wg sync.WaitGroup + defer func() { + wg.Wait() + err = errs.Combine(err, man.Close()) + }() for { stream, rpc, err := man.NewServerStream(ctx) if err != nil { return errs.Wrap(err) } - if err := s.handleRPC(stream, rpc); err != nil { - return errs.Wrap(err) - } + // TODO: add worker pool + wg.Add(1) + go func() { + defer wg.Done() + if err := s.handleRPC(stream, rpc); err != nil { + if s.opts.Log != nil { + s.opts.Log(err) + } + } + }() } } diff --git a/drpcstream/pktbuf.go b/drpcstream/pktbuf.go deleted file mode 100644 index db688649..00000000 --- a/drpcstream/pktbuf.go +++ /dev/null @@ -1,85 +0,0 @@ -// Copyright (C) 2019 Storj Labs, Inc. -// See LICENSE for copying information. - -package drpcstream - -import ( - "sync" -) - -type packetBuffer struct { - mu sync.Mutex - cond sync.Cond - err error - data []byte - set bool - held bool -} - -func (pb *packetBuffer) init() { - pb.cond.L = &pb.mu -} - -func (pb *packetBuffer) Close(err error) { - pb.mu.Lock() - defer pb.mu.Unlock() - - for pb.held { - pb.cond.Wait() - } - - if pb.err == nil { - pb.data = nil - pb.set = false - pb.err = err - pb.cond.Broadcast() - } -} - -func (pb *packetBuffer) Put(data []byte) { - pb.mu.Lock() - defer pb.mu.Unlock() - - for pb.set && pb.err == nil { - pb.cond.Wait() - } - if pb.err != nil { - return - } - - pb.data = data - pb.set = true - pb.held = false - pb.cond.Broadcast() - - for pb.set || pb.held { - pb.cond.Wait() - } -} - -func (pb *packetBuffer) Get() ([]byte, error) { - pb.mu.Lock() - defer pb.mu.Unlock() - - for !pb.set && pb.err == nil { - pb.cond.Wait() - } - if pb.err != nil { - return nil, pb.err - } - - pb.held = true - pb.cond.Broadcast() - - return pb.data, nil -} - -func (pb *packetBuffer) Done() { - pb.mu.Lock() - defer pb.mu.Unlock() - - pb.data = nil - pb.set = false - pb.held = false - pb.cond.Broadcast() -} diff --git a/drpcstream/ring_buffer.go b/drpcstream/ring_buffer.go new file mode 100644 index 00000000..5cb620ab --- /dev/null +++ b/drpcstream/ring_buffer.go @@ -0,0 +1,116 @@ +// Copyright (C) 2026 Cockroach Labs. +// See LICENSE for copying information. + +package drpcstream + +import "sync" + +// defaultRingBufferCapacity is the number of messages the ring buffer can +// hold before the producer blocks. This decouples the transport reader +// (manageReader) from the consumer (RPC handler), preventing a slow handler +// from blocking frame delivery to other streams. +// +// TODO: benchmark whether power-of-2 masking improves performance over modulo. +const defaultRingBufferCapacity = 256 + +// ringBuffer is a bounded single-producer / single-consumer FIFO queue for +// assembled packet data. It sits between manageReader (producer, calls +// Enqueue) and the application goroutine (consumer, calls Dequeue/Done). +// +// Slots are pre-allocated and reused: each slot's backing array grows via +// append to fit incoming data, then stays at its high-water mark, avoiding +// per-message allocation in steady state. +// +// After Close, Dequeue drains any queued messages before returning the close +// error. This ensures graceful shutdown (KindClose/KindCloseSend) delivers +// all buffered data to the consumer. +type ringBuffer struct { + mu sync.Mutex + cond sync.Cond + + buf [][]byte // ring of byte slices + head int // next write position (producer) + tail int // next read position (consumer) + count int // number of occupied slots + + held bool // true between Dequeue and Done + err error // terminal error, set by Close +} + +func (rb *ringBuffer) init() { + rb.cond.L = &rb.mu + rb.buf = make([][]byte, defaultRingBufferCapacity) +} + +// Enqueue copies data into the next write slot. If the buffer is full, it +// blocks until a slot is freed or the buffer is closed. If the buffer is +// closed, Enqueue returns silently without enqueuing. +func (rb *ringBuffer) Enqueue(data []byte) { + rb.mu.Lock() + defer rb.mu.Unlock() + + for rb.count == len(rb.buf) && rb.err == nil { + rb.cond.Wait() + } + if rb.err != nil { + return + } + + rb.buf[rb.head] = append(rb.buf[rb.head][:0], data...) + rb.head = (rb.head + 1) % len(rb.buf) + rb.count++ + rb.cond.Broadcast() +} + +// Dequeue returns the data from the next read slot. If the buffer is empty, +// it blocks until data is available or the buffer is closed. The returned +// slice is valid until Done is called. +func (rb *ringBuffer) Dequeue() ([]byte, error) { + rb.mu.Lock() + defer rb.mu.Unlock() + + for rb.count == 0 && rb.err == nil { + rb.cond.Wait() + } + if rb.count == 0 && rb.err != nil { + return nil, rb.err + } + + rb.held = true + return rb.buf[rb.tail], nil +} + +// Done advances the read pointer, making the slot available for reuse. +// It must be called exactly once after each successful Dequeue. +// +// TODO(shubham): remove this method once a shared buffer pool is introduced. +// With a pool, Dequeue will advance the tail immediately and the caller will +// return the buffer to the pool directly. +func (rb *ringBuffer) Done() { + rb.mu.Lock() + defer rb.mu.Unlock() + + rb.tail = (rb.tail + 1) % len(rb.buf) + rb.count-- + rb.held = false + rb.cond.Broadcast() +} + +// Close marks the buffer as closed with the given error. All blocked Enqueue +// and Dequeue calls are woken and will return. Close waits for any in-progress +// Dequeue/Done pair to complete before setting the error. Subsequent calls are +// no-ops. +func (rb *ringBuffer) Close(err error) { + rb.mu.Lock() + defer rb.mu.Unlock() + + for rb.held { + rb.cond.Wait() + } + if rb.err != nil { + return + } + + rb.err = err + rb.cond.Broadcast() +} diff --git a/drpcstream/ring_buffer_test.go b/drpcstream/ring_buffer_test.go new file mode 100644 index 00000000..8be9c587 --- /dev/null +++ b/drpcstream/ring_buffer_test.go @@ -0,0 +1,228 @@ +// Copyright (C) 2026 Cockroach Labs. +// See LICENSE for copying information. + +package drpcstream + +import ( + "io" + "sync" + "testing" + + "github.com/zeebo/assert" +) + +func TestRingBuffer_EnqueueDequeue(t *testing.T) { + var rb ringBuffer + rb.init() + + rb.Enqueue([]byte("hello")) + + data, err := rb.Dequeue() + assert.NoError(t, err) + assert.DeepEqual(t, data, []byte("hello")) + rb.Done() +} + +func TestRingBuffer_FIFO(t *testing.T) { + var rb ringBuffer + rb.init() + + rb.Enqueue([]byte("first")) + rb.Enqueue([]byte("second")) + rb.Enqueue([]byte("third")) + + for _, want := range []string{"first", "second", "third"} { + data, err := rb.Dequeue() + assert.NoError(t, err) + assert.DeepEqual(t, data, []byte(want)) + rb.Done() + } +} + +func TestRingBuffer_DequeueBlocksUntilEnqueue(t *testing.T) { + var rb ringBuffer + rb.init() + + got := make(chan []byte, 1) + go func() { + data, err := rb.Dequeue() + assert.NoError(t, err) + got <- data + }() + + rb.Enqueue([]byte("delayed")) + assert.DeepEqual(t, <-got, []byte("delayed")) + rb.Done() +} + +func TestRingBuffer_EnqueueBlocksWhenFull(t *testing.T) { + var rb ringBuffer + rb.cond.L = &rb.mu + rb.buf = make([][]byte, 2) // capacity 2 + + rb.Enqueue([]byte("a")) + rb.Enqueue([]byte("b")) + + // Third enqueue should block until we drain one. + done := make(chan struct{}) + go func() { + rb.Enqueue([]byte("c")) + close(done) + }() + + // Drain one slot. + data, err := rb.Dequeue() + assert.NoError(t, err) + assert.DeepEqual(t, data, []byte("a")) + rb.Done() + + // Now the blocked Enqueue should complete. + <-done + + // Verify remaining items. + data, err = rb.Dequeue() + assert.NoError(t, err) + assert.DeepEqual(t, data, []byte("b")) + rb.Done() + + data, err = rb.Dequeue() + assert.NoError(t, err) + assert.DeepEqual(t, data, []byte("c")) + rb.Done() +} + +func TestRingBuffer_CloseUnblocksDequeue(t *testing.T) { + var rb ringBuffer + rb.init() + + errch := make(chan error, 1) + go func() { + _, err := rb.Dequeue() + errch <- err + }() + + rb.Close(io.EOF) + assert.Equal(t, <-errch, io.EOF) +} + +func TestRingBuffer_CloseUnblocksEnqueue(t *testing.T) { + var rb ringBuffer + rb.cond.L = &rb.mu + rb.buf = make([][]byte, 1) // capacity 1 + + rb.Enqueue([]byte("fill")) + + done := make(chan struct{}) + go func() { + rb.Enqueue([]byte("blocked")) + close(done) + }() + + rb.Close(io.EOF) + <-done +} + +func TestRingBuffer_CloseDrainsQueued(t *testing.T) { + var rb ringBuffer + rb.init() + + rb.Enqueue([]byte("queued")) + rb.Close(io.EOF) + + // Dequeue returns the queued data first. + data, err := rb.Dequeue() + assert.NoError(t, err) + assert.DeepEqual(t, data, []byte("queued")) + rb.Done() + + // Next Dequeue returns the close error. + data, err = rb.Dequeue() + assert.Nil(t, data) + assert.Equal(t, err, io.EOF) +} + +func TestRingBuffer_CloseIdempotent(t *testing.T) { + var rb ringBuffer + rb.init() + + rb.Close(io.EOF) + rb.Close(io.ErrUnexpectedEOF) // should not overwrite + + _, err := rb.Dequeue() + assert.Equal(t, err, io.EOF) // original error preserved +} + +func TestRingBuffer_EnqueueAfterClose(t *testing.T) { + var rb ringBuffer + rb.init() + + rb.Close(io.EOF) + rb.Enqueue([]byte("dropped")) // should not panic or block +} + +func TestRingBuffer_SlotReuse(t *testing.T) { + var rb ringBuffer + rb.cond.L = &rb.mu + rb.buf = make([][]byte, 2) + + // Fill and drain a few rounds to exercise slot reuse. + for round := 0; round < 5; round++ { + rb.Enqueue([]byte("data")) + data, err := rb.Dequeue() + assert.NoError(t, err) + assert.DeepEqual(t, data, []byte("data")) + rb.Done() + } +} + +func TestRingBuffer_CloseWaitsForHeld(t *testing.T) { + var rb ringBuffer + rb.init() + + rb.Enqueue([]byte("msg")) + + // Dequeue the data but don't call Done yet. + data, err := rb.Dequeue() + assert.NoError(t, err) + assert.DeepEqual(t, data, []byte("msg")) + + closed := make(chan struct{}) + go func() { + rb.Close(io.EOF) + close(closed) + }() + + // Close should be blocked because held is true. + // Call Done to release it. + rb.Done() + <-closed +} + +func TestRingBuffer_ConcurrentProducerConsumer(t *testing.T) { + var rb ringBuffer + rb.init() + + const n = 1000 + var wg sync.WaitGroup + wg.Add(2) + + go func() { + defer wg.Done() + for i := 0; i < n; i++ { + rb.Enqueue([]byte{byte(i)}) + } + }() + + go func() { + defer wg.Done() + for i := 0; i < n; i++ { + data, err := rb.Dequeue() + assert.NoError(t, err) + assert.Equal(t, data[0], byte(i)) + rb.Done() + } + }() + + wg.Wait() + rb.Close(io.EOF) +} diff --git a/drpcstream/stream.go b/drpcstream/stream.go index 5f099b60..421a1269 100644 --- a/drpcstream/stream.go +++ b/drpcstream/stream.go @@ -26,12 +26,6 @@ type Options struct { // SplitSize controls the default size we split data packets into frames. SplitSize int - // ManualFlush controls if the stream will automatically flush after every - // message send. Note that flushing is not part of the drpc.Stream - // interface, so if you use this you must be ready to type assert and call - // RawFlush dynamically. - ManualFlush bool - // MaximumBufferSize causes the Stream to drop any internal buffers that are // larger than this amount to control maximum memory usage at the expense of // more allocations. 0 is unlimited. @@ -47,21 +41,32 @@ type Stream struct { opts Options task *trace.Task + // write and read serialize operations within a stream. The data path + // (MsgSend/MsgRecv) and the control path (SendCancel/Close/SendError) + // genuinely race because cancellation arrives from manageStream while the + // application may be mid-send. These are inspectMutex (not sync.Mutex) so + // that checkFinished can test whether ops are in flight without blocking. write inspectMutex read inspectMutex - flush sync.Once pa drpcwire.PacketAssembler id drpcwire.ID - wr *drpcwire.Writer - pbuf packetBuffer + wr *drpcwire.MuxWriter + recvQueue ringBuffer wbuf []byte mu sync.Mutex // protects state transitions sigs struct { - send drpcsignal.Signal // set when done sending messages - recv drpcsignal.Signal // set when done receiving messages + send drpcsignal.Signal // set when done sending messages + recv drpcsignal.Signal // set when done receiving messages + // Stream shutdown is two-phase: term then fin. When termination arrives + // (remote error, local cancel, close), there may be an in-flight write + // on the transport that is past the term check and inside WriteFrame. + // term tells new operations to bail out; fin signals that all in-flight + // operations have actually completed. Consumers (manageStream) wait on + // fin before cleaning up, guaranteeing no goroutine is touching the + // stream afterward. term drpcsignal.Signal // set when the stream is terminating and no new ops should begin fin drpcsignal.Signal // set when the stream is finished and all ops are complete cancel drpcsignal.Signal // set when externally canceled @@ -73,7 +78,7 @@ var _ drpc.Stream = (*Stream)(nil) // New returns a new stream bound to the context with the given stream id and // will use the writer to write messages on. It is important use monotonically // increasing stream ids within a single transport. -func New(ctx context.Context, sid uint64, wr *drpcwire.Writer) *Stream { +func New(ctx context.Context, sid uint64, wr *drpcwire.MuxWriter) *Stream { return NewWithOptions(ctx, sid, wr, Options{}) } @@ -81,7 +86,7 @@ func New(ctx context.Context, sid uint64, wr *drpcwire.Writer) *Stream { // stream id and will use the writer to write messages on. It is important use // monotonically increasing stream ids within a single transport. The options // are used to control details of how the Stream operates. -func NewWithOptions(ctx context.Context, sid uint64, wr *drpcwire.Writer, opts Options) *Stream { +func NewWithOptions(ctx context.Context, sid uint64, wr *drpcwire.MuxWriter, opts Options) *Stream { var task *trace.Task if trace.IsEnabled() { kind, rpc := drpcopts.GetStreamKind(&opts.Internal), drpcopts.GetStreamRPC(&opts.Internal) @@ -104,11 +109,11 @@ func NewWithOptions(ctx context.Context, sid uint64, wr *drpcwire.Writer, opts O pa: pa, id: drpcwire.ID{Stream: sid}, - wr: wr.Reset(), + wr: wr, } // initialize the packet buffer - s.pbuf.init() + s.recvQueue.init() return s } @@ -191,30 +196,6 @@ func (s *Stream) Finished() <-chan struct{} { return s.sigs.fin.Signal() } // issue any writes or reads. func (s *Stream) IsFinished() bool { return s.sigs.fin.IsSet() } -// SetManualFlush sets the ManualFlush option. It cannot be called concurrently -// with any sends or receives on the stream. Example use case: -// -// flusher := stream.(interface{ -// GetStream() drpc.Stream -// }).GetStream().(interface{ -// SetManualFlush(bool) -// }) -// -// flusher.SetManualFlush(true) -// err = stream.Send(&pb.Message{Request: "hello, "}) -// flusher.SetManualFlush(false) -// if err != nil { -// return err -// } -// -// // the next send will send both messages in the same write -// // to the underlying connection. -// err = stream.Send(&pb.Message{Request: "world!"}) -// if err != nil { -// return err -// } -func (s *Stream) SetManualFlush(mf bool) { s.opts.ManualFlush = mf } - // // frame handler // @@ -245,7 +226,7 @@ func (s *Stream) handlePacket(pkt drpcwire.Packet) (err error) { s.log("HANDLE", pkt.String) if pkt.Kind == drpcwire.KindMessage { - s.pbuf.Put(pkt.Data) + s.recvQueue.Enqueue(pkt.Data) return nil } @@ -273,13 +254,13 @@ func (s *Stream) handlePacket(pkt drpcwire.Packet) (err error) { case drpcwire.KindClose: s.sigs.recv.Set(io.EOF) - s.pbuf.Close(io.EOF) + s.recvQueue.Close(io.EOF) s.terminate(drpc.ClosedError.New("remote closed the stream")) return nil case drpcwire.KindCloseSend: s.sigs.recv.Set(io.EOF) - s.pbuf.Close(io.EOF) + s.recvQueue.Close(io.EOF) s.terminateIfBothClosed() return nil @@ -299,9 +280,11 @@ func (s *Stream) handlePacket(pkt drpcwire.Packet) (err error) { // helpers // -// checkFinished checks to see if the stream is terminated, and if so, sets the -// finished flag. This must be called after every read or write is complete, as -// well as when the stream becomes terminated. +// checkFinished bridges the two-phase shutdown. It is called in two places: +// inside terminate() for when no I/O is in flight (fin fires immediately), +// and deferred after every read/write unlock for when an operation was in +// flight at termination time (fin fires once the last operation completes). +// Whichever call site runs last sees term set and both locks free, and sets fin. func (s *Stream) checkFinished() { if s.sigs.term.IsSet() && s.write.Unlocked() && s.read.Unlocked() { if s.sigs.fin.Set(nil) { @@ -314,10 +297,10 @@ func (s *Stream) checkFinished() { } } -// checkCancelError will replace the error with one from the cancel signal if it +// CheckCancelError will replace the error with one from the cancel signal if it // is set. This is to prevent errors from reads/writes to a transport after it // has been asynchronously closed due to context cancelation. -func (s *Stream) checkCancelError(err error) error { +func (s *Stream) CheckCancelError(err error) error { if s.sigs.cancel.IsSet() { return s.sigs.cancel.Err() } @@ -331,9 +314,9 @@ func (s *Stream) newFrameLocked(kind drpcwire.Kind) drpcwire.Frame { return drpcwire.Frame{ID: s.id, Kind: kind} } -// sendPacketLocked sends the packet in a single write and flushes. It does not -// check for any conditions to stop it from writing and is meant for internal -// stream use to do things like signal errors or closes to the remote side. +// sendPacketLocked sends the packet in a single write. It does not check for +// any conditions to stop it from writing and is meant for internal stream use +// to do things like signal errors or closes to the remote side. func (s *Stream) sendPacketLocked(kind drpcwire.Kind, control bool, data []byte) (err error) { fr := s.newFrameLocked(kind) fr.Data = data @@ -346,9 +329,6 @@ func (s *Stream) sendPacketLocked(kind drpcwire.Kind, control bool, data []byte) if err := s.wr.WriteFrame(fr); err != nil { return errs.Wrap(err) } - if err := s.wr.Flush(); err != nil { - return errs.Wrap(err) - } return nil } @@ -366,10 +346,26 @@ func (s *Stream) terminate(err error) { s.sigs.send.Set(err) s.sigs.recv.Set(err) s.sigs.term.Set(err) - s.pbuf.Close(err) + s.recvQueue.Close(err) s.checkFinished() } +// WriteInvoke writes the invoke metadata (if any) and invoke frame +// atomically under the write lock. This prevents SendCancel from +// interrupting the invoke sequence. +func (s *Stream) WriteInvoke(rpc string, metadata []byte) error { + defer s.checkFinished() + s.write.Lock() + defer s.write.Unlock() + + if len(metadata) > 0 { + if err := s.rawWriteLocked(drpcwire.KindInvokeMetadata, metadata); err != nil { + return err + } + } + return s.rawWriteLocked(drpcwire.KindInvoke, []byte(rpc)) +} + // // raw read/write // @@ -405,75 +401,25 @@ func (s *Stream) rawWriteLocked(kind drpcwire.Kind, data []byte) (err error) { s.log("SEND", fr.String) if err := s.wr.WriteFrame(fr); err != nil { - return s.checkCancelError(errs.Wrap(err)) + return s.CheckCancelError(errs.Wrap(err)) } else if fr.Done { return nil } } } -// RawFlush flushes any buffers of data. -func (s *Stream) RawFlush() (err error) { - defer s.checkFinished() - s.write.Lock() - defer s.write.Unlock() - - return s.rawFlushLocked() -} - -// rawFlushLocked checks for any conditions that should cause a flush to not -// happen and then issues the flush. It assumes the caller is holding the -// appropriate locks. -func (s *Stream) rawFlushLocked() (err error) { - if s.wr.Empty() { - return nil - } - - switch { - case s.sigs.cancel.IsSet(): - return s.sigs.cancel.Err() - case s.sigs.send.IsSet(): - return s.sigs.send.Err() - case s.sigs.term.IsSet(): - return s.sigs.term.Err() - } - - s.log("FLUSH", func() string { return "" }) - - return s.checkCancelError(errs.Wrap(s.wr.Flush())) -} - -func (s *Stream) checkRecvFlush() (err error) { - s.flush.Do(func() { err = s.RawFlush() }) - if err != nil { - return err - } - - if s.opts.ManualFlush && !s.wr.Empty() { - if err := s.RawFlush(); err != nil { - return err - } - } - - return nil -} - // RawRecv returns the raw bytes received for a message. func (s *Stream) RawRecv() (data []byte, err error) { - if err := s.checkRecvFlush(); err != nil { - return nil, err - } - defer s.checkFinished() s.read.Lock() defer s.read.Unlock() - data, err = s.pbuf.Get() + data, err = s.recvQueue.Dequeue() if err != nil { return nil, err } data = append([]byte(nil), data...) - s.pbuf.Done() + s.recvQueue.Done() return data, nil } @@ -482,12 +428,9 @@ func (s *Stream) RawRecv() (data []byte, err error) { // msg read/write // -// MsgSend marshals the message with the encoding, writes it, and flushes. +// MsgSend marshals the message with the encoding and writes it. func (s *Stream) MsgSend(msg drpc.Message, enc drpc.Encoding) (err error) { defer func() { err = drpc.ToRPCErr(err) }() - - s.flush.Do(func() {}) - defer s.checkFinished() s.write.Lock() defer s.write.Unlock() @@ -502,9 +445,6 @@ func (s *Stream) MsgSend(msg drpc.Message, enc drpc.Encoding) (err error) { if err := s.rawWriteLocked(drpcwire.KindMessage, wbuf); err != nil { return err } - if !s.opts.ManualFlush { - return s.rawFlushLocked() - } return nil } @@ -512,20 +452,16 @@ func (s *Stream) MsgSend(msg drpc.Message, enc drpc.Encoding) (err error) { func (s *Stream) MsgRecv(msg drpc.Message, enc drpc.Encoding) (err error) { defer func() { err = drpc.ToRPCErr(err) }() - if err := s.checkRecvFlush(); err != nil { - return err - } - defer s.checkFinished() s.read.Lock() defer s.read.Unlock() - data, err := s.pbuf.Get() + data, err := s.recvQueue.Dequeue() if err != nil { return err } err = enc.Unmarshal(data, msg) - s.pbuf.Done() + s.recvQueue.Done() return err } @@ -560,25 +496,19 @@ func (s *Stream) SendError(serr error) (err error) { s.terminate(termError) s.mu.Unlock() - return s.checkCancelError(s.sendPacketLocked(drpcwire.KindError, false, drpcwire.MarshalError(serr))) + return s.CheckCancelError(s.sendPacketLocked(drpcwire.KindError, false, drpcwire.MarshalError(serr))) } -// SendCancel transitions the stream into the canceled state with -// context.Canceled and sends a cancel error to the remote side for a soft -// cancel. It is a no-op if the stream is already terminated. It returns true -// for busy if writes are already blocked and a hard cancel is required. -func (s *Stream) SendCancel(err error) (busy bool, _ error) { +// SendCancel terminates the stream and sends a cancel to the remote side. It +// blocks until any in-progress write completes. It is a no-op if the stream is +// already terminated. +func (s *Stream) SendCancel(err error) error { s.log("CALL", func() string { return "SendCancel()" }) s.mu.Lock() - if !s.write.Unlocked() { // if writes are happening, then we have to do a hard cancel. - s.mu.Unlock() - return true, nil - } - if s.sigs.term.IsSet() { s.mu.Unlock() - return false, nil + return nil } defer s.checkFinished() @@ -589,7 +519,7 @@ func (s *Stream) SendCancel(err error) (busy bool, _ error) { s.terminate(err) s.mu.Unlock() - return false, s.checkCancelError(s.sendPacketLocked(drpcwire.KindCancel, true, nil)) + return s.CheckCancelError(s.sendPacketLocked(drpcwire.KindCancel, true, nil)) } // Close terminates the stream and sends that the stream has been closed to the @@ -610,7 +540,7 @@ func (s *Stream) Close() (err error) { s.terminate(termClosed) s.mu.Unlock() - return s.checkCancelError(s.sendPacketLocked(drpcwire.KindClose, false, nil)) + return s.CheckCancelError(s.sendPacketLocked(drpcwire.KindClose, false, nil)) } // CloseSend informs the remote that no more messages will be sent. If the remote has @@ -633,7 +563,7 @@ func (s *Stream) CloseSend() (err error) { s.terminateIfBothClosed() s.mu.Unlock() - return s.checkCancelError(s.sendPacketLocked(drpcwire.KindCloseSend, false, nil)) + return s.CheckCancelError(s.sendPacketLocked(drpcwire.KindCloseSend, false, nil)) } // Cancel transitions the stream into a state where all writes to the transport will return diff --git a/drpcstream/stream_test.go b/drpcstream/stream_test.go index 62e49448..9cee8ddf 100644 --- a/drpcstream/stream_test.go +++ b/drpcstream/stream_test.go @@ -4,7 +4,6 @@ package drpcstream import ( - "bytes" "context" "errors" "io" @@ -19,6 +18,15 @@ import ( "storj.io/drpc/drpcwire" ) +// testMuxWriter creates a MuxWriter that writes to io.Discard with a no-op +// error handler. The writer goroutine is stopped when the test finishes. +func testMuxWriter(t *testing.T) *drpcwire.MuxWriter { + t.Helper() + mw := drpcwire.NewMuxWriter(io.Discard, func(error) {}) + t.Cleanup(func() { mw.Stop(nil); <-mw.Done() }) + return mw +} + // handleFrame is a helper that sends a single-frame packet to the stream. // It constructs a frame with the given kind, matching the stream's ID, // using the provided message ID, done=true. @@ -34,6 +42,7 @@ func TestStream_StateTransitions(t *testing.T) { ctx := drpctest.NewTracker(t) defer ctx.Close() + mw := testMuxWriter(t) any := errors.New("any sentinel error") checkErrs := func(t *testing.T, exp interface{}, got error) { @@ -108,7 +117,7 @@ func TestStream_StateTransitions(t *testing.T) { } for _, test := range cases { - st := New(ctx, 1, drpcwire.NewWriter(io.Discard, 0)) + st := New(ctx, 1, mw) assert.NoError(t, test.Op(st)) checkErrs(t, test.Send, st.RawWrite(drpcwire.KindMessage, nil)) @@ -125,6 +134,8 @@ func TestStream_Unblocks(t *testing.T) { ctx := drpctest.NewTracker(t) defer ctx.Close() + mw := testMuxWriter(t) + cases := []struct { Op func(st *Stream) error }{ @@ -158,7 +169,7 @@ func TestStream_Unblocks(t *testing.T) { } for _, test := range cases { - st := New(ctx, 1, drpcwire.NewWriter(io.Discard, 0)) + st := New(ctx, 1, mw) ctx.Run(func(ctx context.Context) { _, _ = st.RawRecv() }) assert.NoError(t, test.Op(st)) @@ -168,7 +179,8 @@ func TestStream_Unblocks(t *testing.T) { func TestStream_ContextCancel(t *testing.T) { ctx := context.Background() - st := New(ctx, 0, drpcwire.NewWriter(io.Discard, 0)) + mw := testMuxWriter(t) + st := New(ctx, 0, mw) child, cancel := context.WithCancel(st.Context()) defer cancel() @@ -182,84 +194,32 @@ func TestStream_ConcurrentCloseCancel(t *testing.T) { ctx := drpctest.NewTracker(t) defer ctx.Close() - pr, pw := io.Pipe() - defer func() { _ = pr.Close() }() - defer func() { _ = pw.Close() }() - - st := New(ctx, 0, drpcwire.NewWriter(pw, 0)) + mw := testMuxWriter(t) + st := New(ctx, 0, mw) - // start the Close call + // Close and Cancel concurrently should not panic or deadlock. errch := make(chan error, 1) go func() { errch <- st.Close() }() - // wait for the close to begin writing - _, err := pr.Read(make([]byte, 1)) - assert.NoError(t, err) - - // cancel the context and close the transport st.Cancel(context.Canceled) - assert.NoError(t, pw.Close()) - - // we should always receive the canceled error - assert.That(t, errors.Is(<-errch, context.Canceled)) -} - -func TestStream_CorkUntilFirstRead(t *testing.T) { - run := func() { - ctx := drpctest.NewTracker(t) - defer ctx.Close() - - var buf bytes.Buffer - st := New(ctx, 0, drpcwire.NewWriter(&buf, 50)) - // concurrently read and write at the same time. - // we should always see the write happen. - - errch := make(chan error, 3) - ctx.Run(func(ctx context.Context) { - errch <- st.MsgSend([]byte("write"), byteEncoding{}) - }) - ctx.Run(func(ctx context.Context) { - _, err := st.RawRecv() - errch <- err - }) - ctx.Run(func(ctx context.Context) { - errch <- st.HandleFrame(drpcwire.Frame{ - Data: []byte("read"), - ID: drpcwire.ID{Message: 1}, - Kind: drpcwire.KindMessage, - Done: true, - }) - }) - - assert.NoError(t, <-errch) - assert.NoError(t, <-errch) - assert.NoError(t, <-errch) - - assert.Equal(t, buf.String(), "\x05\x00\x01\x05write") - } - for i := 0; i < 100; i++ { - run() + // Close returns nil or context.Canceled depending on timing. + err := <-errch + if err != nil { + assert.That(t, errors.Is(err, context.Canceled)) } } -type byteEncoding struct{} - -func (byteEncoding) Marshal(msg drpc.Message) ([]byte, error) { return msg.([]byte), nil } -func (byteEncoding) Unmarshal(buf []byte, msg drpc.Message) error { - *msg.(*[]byte) = append(*msg.(*[]byte), buf...) - return nil -} - func TestStream_PacketBufferReuse(t *testing.T) { run := func() { ctx := drpctest.NewTracker(t) defer ctx.Close() defer ctx.Wait() + mw := testMuxWriter(t) data := make([]byte, 20) mid := uint64(1) - st := New(ctx, 1, drpcwire.NewWriter(io.Discard, 0)) + st := New(ctx, 1, mw) ctx.Run(func(ctx context.Context) { for !st.IsTerminated() { @@ -298,44 +258,16 @@ func TestStream_PacketBufferReuse(t *testing.T) { } } -func TestStream_SendCancelBusyDuringBlockedClose(t *testing.T) { - ctx := drpctest.NewTracker(t) - defer ctx.Close() - - pr, pw := io.Pipe() - defer func() { _ = pr.Close() }() - defer func() { _ = pw.Close() }() - - st := New(ctx, 0, drpcwire.NewWriter(pw, 0)) - - // launch a goroutine to close the stream - ctx.Run(func(ctx context.Context) { _ = st.Close() }) - - // read just 1 byte from the pipe to ensure that the Close has started - _, err := pr.Read(make([]byte, 1)) - assert.NoError(t, err) - assert.That(t, st.IsTerminated()) - - // even though the stream is terminated, soft cancel should report that - // the stream is still busy because the close is being sent. - busy, err := st.SendCancel(context.Canceled) - assert.NoError(t, err) - assert.That(t, busy) -} - // // HandleFrame tests // func TestHandleFrame_FirstFrameOnFreshStream(t *testing.T) { - // On the client side, the first message received will have ID 1. But on the - // server side, invoke is consumed by the manager. The first frame reaching - // the stream could have msg > 1 (e.g., msg=2). nextMessageID=1, so 2 > 1 - // makes this a valid frame. + mw := testMuxWriter(t) for _, messageID := range []uint64{1, 2} { - st := New(context.Background(), 1, drpcwire.NewWriter(io.Discard, 0)) - // Close the packet buffer so KindMessage Put doesn't block. - st.pbuf.Close(io.EOF) + st := New(context.Background(), 1, mw) + // Close the ring buffer so KindMessage Enqueue doesn't block. + st.recvQueue.Close(io.EOF) err := st.HandleFrame(drpcwire.Frame{ ID: drpcwire.ID{Stream: 1, Message: messageID}, Kind: drpcwire.KindMessage, Done: true, }) @@ -345,7 +277,8 @@ func TestHandleFrame_FirstFrameOnFreshStream(t *testing.T) { // Invoke and InvokeMetadata frames are rejected on an already-created stream. func TestHandleFrame_InvokeOnExistingStream(t *testing.T) { - st := New(context.Background(), 1, drpcwire.NewWriter(io.Discard, 0)) + mw := testMuxWriter(t) + st := New(context.Background(), 1, mw) err := handleFrame(st, drpcwire.KindInvoke, 1) assert.Error(t, err) @@ -354,7 +287,8 @@ func TestHandleFrame_InvokeOnExistingStream(t *testing.T) { } func TestHandleFrame_InvokeMetadataOnExistingStream(t *testing.T) { - st := New(context.Background(), 1, drpcwire.NewWriter(io.Discard, 0)) + mw := testMuxWriter(t) + st := New(context.Background(), 1, mw) err := handleFrame(st, drpcwire.KindInvokeMetadata, 1) assert.Error(t, err) @@ -364,7 +298,8 @@ func TestHandleFrame_InvokeMetadataOnExistingStream(t *testing.T) { // Frames arriving after the stream is terminated are silently ignored. func TestHandleFrame_AfterTerminated(t *testing.T) { - st := New(context.Background(), 1, drpcwire.NewWriter(io.Discard, 0)) + mw := testMuxWriter(t) + st := New(context.Background(), 1, mw) // Terminate the stream via cancel. st.Cancel(context.Canceled) @@ -381,7 +316,8 @@ func TestHandleFrame_MessageDeliveredViaRecv(t *testing.T) { ctx := drpctest.NewTracker(t) defer ctx.Close() - st := New(ctx, 1, drpcwire.NewWriter(io.Discard, 0)) + mw := testMuxWriter(t) + st := New(ctx, 1, mw) // Launch receiver before sending to avoid Put blocking. recv := make(chan []byte, 1) @@ -400,74 +336,3 @@ func TestHandleFrame_MessageDeliveredViaRecv(t *testing.T) { assert.DeepEqual(t, <-recv, []byte("payload")) } - -// -// Write-side tests -// - -func TestRawWrite_NonMessageSingleFrame(t *testing.T) { - // Non-KindMessage kinds must produce a single frame (n=0 in - // rawWriteLocked means default 64KB, effectively no split for - // small payloads). Verify they produce exactly one frame with Done=true. - kinds := []drpcwire.Kind{ - drpcwire.KindInvoke, - drpcwire.KindError, - drpcwire.KindCancel, - drpcwire.KindClose, - drpcwire.KindCloseSend, - drpcwire.KindInvokeMetadata, - } - - for _, kind := range kinds { - var buf bytes.Buffer - st := New(context.Background(), 1, drpcwire.NewWriter(&buf, 0)) - - assert.NoError(t, st.RawWrite(kind, []byte("data"))) - assert.NoError(t, st.RawFlush()) - var err error - - // Parse all frames from the buffer — should be exactly one. - data := buf.Bytes() - var frames []drpcwire.Frame - for len(data) > 0 { - var fr drpcwire.Frame - var ok bool - data, fr, ok, err = drpcwire.ParseFrame(data) - assert.NoError(t, err) - assert.That(t, ok) - frames = append(frames, fr) - } - assert.Equal(t, len(frames), 1) - assert.That(t, frames[0].Done) - assert.Equal(t, frames[0].Kind, kind) - } -} - -func TestRawWrite_MessageRespectsSplitSize(t *testing.T) { - var buf bytes.Buffer - st := NewWithOptions(context.Background(), 1, - drpcwire.NewWriter(&buf, 0), - Options{SplitSize: 5}, - ) - - // "helloworld" is 10 bytes, split at 5 → 2 frames. - assert.NoError(t, st.RawWrite(drpcwire.KindMessage, []byte("helloworld"))) - assert.NoError(t, st.RawFlush()) - var err error - - data := buf.Bytes() - var frames []drpcwire.Frame - for len(data) > 0 { - var fr drpcwire.Frame - var ok bool - data, fr, ok, err = drpcwire.ParseFrame(data) - assert.NoError(t, err) - assert.That(t, ok) - frames = append(frames, fr) - } - assert.Equal(t, len(frames), 2) - assert.That(t, !frames[0].Done) - assert.That(t, frames[1].Done) - assert.DeepEqual(t, frames[0].Data, []byte("hello")) - assert.DeepEqual(t, frames[1].Data, []byte("world")) -} diff --git a/drpcwire/mux_writer.go b/drpcwire/mux_writer.go new file mode 100644 index 00000000..c12bc565 --- /dev/null +++ b/drpcwire/mux_writer.go @@ -0,0 +1,96 @@ +// Copyright (C) 2026 Cockroach Labs. +// See LICENSE for copying information. + +package drpcwire + +import ( + "io" + "sync" +) + +type MuxWriter struct { + w io.Writer + buf []byte + mu sync.Mutex + cond *sync.Cond + closed bool + closeErr error + onError func(error) + done chan struct{} +} + +var defaultBufferCapacity = 4096 + +func NewMuxWriter(w io.Writer, onError func(error)) *MuxWriter { + mw := &MuxWriter{ + w: w, + buf: make([]byte, 0, defaultBufferCapacity), + onError: onError, + done: make(chan struct{}), + } + mw.cond = sync.NewCond(&mw.mu) + go mw.run() + return mw +} + +func (mw *MuxWriter) run() { + defer close(mw.done) + spare := make([]byte, 0, defaultBufferCapacity) + for { + mw.mu.Lock() + for len(mw.buf) == 0 && !mw.closed { + mw.cond.Wait() + } + + if mw.closed { + mw.mu.Unlock() + return + } + + temp := mw.buf + mw.buf = spare + spare = temp + mw.mu.Unlock() + if _, err := mw.w.Write(spare); err != nil { + mw.mu.Lock() + if mw.closed { + mw.mu.Unlock() + return + } + mw.closed = true + mw.closeErr = err + mw.mu.Unlock() + if mw.onError != nil { + mw.onError(err) + } + return + } + + spare = spare[:0] + } +} + +func (mw *MuxWriter) WriteFrame(fr Frame) (err error) { + mw.mu.Lock() + defer mw.mu.Unlock() + if mw.closed { + return mw.closeErr + } + mw.buf = AppendFrame(mw.buf, fr) + mw.cond.Signal() + return nil +} + +func (mw *MuxWriter) Stop(err error) { + mw.mu.Lock() + if !mw.closed { + mw.closed = true + mw.closeErr = err + mw.cond.Broadcast() + } + mw.mu.Unlock() +} + +func (mw *MuxWriter) Done() <-chan struct{} { + return mw.done +} diff --git a/drpcwire/mux_writer_test.go b/drpcwire/mux_writer_test.go new file mode 100644 index 00000000..481cdcaf --- /dev/null +++ b/drpcwire/mux_writer_test.go @@ -0,0 +1,330 @@ +// Copyright (C) 2026 Cockroach Labs. +// See LICENSE for copying information. + +package drpcwire + +import ( + "bytes" + "errors" + "io" + "sync" + "testing" + "time" + + "github.com/zeebo/assert" +) + +// blockingWriter blocks in Write until unblock is closed, then returns err. +type blockingWriter struct { + unblock chan struct{} + err error // error to return once unblocked + wrote chan []byte // sends a copy of data on each Write entry +} + +func newBlockingWriter() *blockingWriter { + return &blockingWriter{ + unblock: make(chan struct{}), + wrote: make(chan []byte, 10), + } +} + +func (w *blockingWriter) Write(p []byte) (int, error) { + cp := make([]byte, len(p)) + copy(cp, p) + w.wrote <- cp + <-w.unblock + if w.err != nil { + return 0, w.err + } + return len(p), nil +} + +// failWriter returns err on the nth call to Write (1-indexed). Calls before +// that succeed normally. +type failWriter struct { + n int + count int + err error + buf bytes.Buffer +} + +func newFailWriter(n int, err error) *failWriter { + return &failWriter{n: n, err: err} +} + +func (w *failWriter) Write(p []byte) (int, error) { + w.count++ + if w.count >= w.n { + return 0, w.err + } + return w.buf.Write(p) +} + +func TestMuxWriter(t *testing.T) { + var exp []byte + pr, pw := io.Pipe() + mw := NewMuxWriter(pw, func(error) {}) + + for range 1000 { + fr := RandFrame() + exp = AppendFrame(exp, fr) + assert.NoError(t, mw.WriteFrame(fr)) + } + + // Read exactly len(exp) bytes: this blocks until MuxWriter has drained + // all frames through the pipe. + got := make([]byte, len(exp)) + _, err := io.ReadFull(pr, got) + assert.NoError(t, err) + + // Now stop the writer and close the pipe. + mw.Stop(errors.New("stopped")) + <-mw.Done() + pw.Close() + pr.Close() + + assert.That(t, bytes.Equal(exp, got)) +} + +func TestMuxWriter_WriteFrameAfterStop(t *testing.T) { + mw := NewMuxWriter(io.Discard, func(error) {}) + mw.Stop(errors.New("stopped")) + <-mw.Done() + + err := mw.WriteFrame(RandFrame()) + assert.Error(t, err) + assert.Equal(t, err.Error(), "stopped") +} + +func TestMuxWriter_ConcurrentWriteFrame(t *testing.T) { + pr, pw := io.Pipe() + mw := NewMuxWriter(pw, func(error) {}) + + const numWriters = 10 + const framesPerWriter = 100 + + // Pre-generate frames and compute total expected bytes so we can use + // io.ReadFull to block until everything has drained (Stop has abort + // semantics, so we can't rely on it to drain). + allFrames := make([][]Frame, numWriters) + var expSize int + for i := range numWriters { + allFrames[i] = make([]Frame, framesPerWriter) + for j := range framesPerWriter { + fr := Frame{ + Data: []byte{byte(j)}, + ID: ID{Stream: uint64(i + 1), Message: uint64(j + 1)}, + Kind: KindMessage, + Done: true, + } + allFrames[i][j] = fr + expSize += len(AppendFrame(nil, fr)) + } + } + + var wg sync.WaitGroup + for i := range numWriters { + wg.Add(1) + go func() { + defer wg.Done() + for j := range framesPerWriter { + assert.NoError(t, mw.WriteFrame(allFrames[i][j])) + } + }() + } + + wg.Wait() + + // Block until all bytes have been drained through the pipe. + got := make([]byte, expSize) + _, err := io.ReadFull(pr, got) + assert.NoError(t, err) + mw.Stop(errors.New("stopped")) + <-mw.Done() + pw.Close() + pr.Close() + + // Parse received bytes and count frames. + count := 0 + for len(got) > 0 { + rem, _, ok, err := ParseFrame(got) + assert.NoError(t, err) + assert.That(t, ok) + got = rem + count++ + } + assert.Equal(t, count, numWriters*framesPerWriter) +} + +func TestMuxWriter_WriteErrorCallsOnError(t *testing.T) { + writeErr := errors.New("disk full") + fw := newFailWriter(1, writeErr) + + gotErr := make(chan error, 1) + mw := NewMuxWriter(fw, func(err error) { gotErr <- err }) + + assert.NoError(t, mw.WriteFrame(RandFrame())) + + select { + case err := <-gotErr: + assert.Equal(t, err, writeErr) + case <-time.After(5 * time.Second): + t.Fatal("onError not called") + } + + select { + case <-mw.Done(): + case <-time.After(5 * time.Second): + t.Fatal("Done did not return") + } +} + +// Tests the critical deadlock path from the design doc: +// run() → Write fails → sets closed → onError → Stop() → noop → run() returns. +func TestMuxWriter_OnErrorCallingStopDoesNotDeadlock(t *testing.T) { + writeErr := errors.New("broken pipe") + fw := newFailWriter(1, writeErr) + + var mw *MuxWriter + mw = NewMuxWriter(fw, func(err error) { + // Simulate manager.terminate calling Stop. + mw.Stop(errors.New("stopped")) + }) + + assert.NoError(t, mw.WriteFrame(RandFrame())) + + select { + case <-mw.Done(): + case <-time.After(5 * time.Second): + t.Fatal("deadlock: Done did not return") + } +} + +// Tests the manager's two-phase shutdown: close transport to unblock a blocked +// Write, then Stop signals the goroutine to exit. +func TestMuxWriter_BlockedWriteUnblockedByClose(t *testing.T) { + bw := newBlockingWriter() + mw := NewMuxWriter(bw, func(error) {}) + + assert.NoError(t, mw.WriteFrame(RandFrame())) + + // Wait for run() to enter Write. + select { + case <-bw.wrote: + case <-time.After(5 * time.Second): + t.Fatal("run() did not enter Write") + } + + // Simulate terminate: Stop, then unblock the writer (like tr.Close()). + mw.Stop(errors.New("stopped")) + bw.err = errors.New("closed") + close(bw.unblock) + + select { + case <-mw.Done(): + case <-time.After(5 * time.Second): + t.Fatal("deadlock: Done did not return") + } +} + +func TestMuxWriter_ConcurrentStop(t *testing.T) { + mw := NewMuxWriter(io.Discard, func(error) {}) + + // Write a frame so the goroutine has work. + assert.NoError(t, mw.WriteFrame(RandFrame())) + + const n = 20 + var wg sync.WaitGroup + wg.Add(n) + for range n { + go func() { + defer wg.Done() + mw.Stop(errors.New("stopped")) + }() + } + wg.Wait() + + select { + case <-mw.Done(): + case <-time.After(5 * time.Second): + t.Fatal("Done did not return") + } +} + +// Stop has abort semantics: buffered data is discarded, not drained. +func TestMuxWriter_StopDiscardsBufferedData(t *testing.T) { + bw := newBlockingWriter() + mw := NewMuxWriter(bw, func(error) {}) + + // Write several frames while the writer is blocked on the first Write. + for range 10 { + assert.NoError(t, mw.WriteFrame(RandFrame())) + } + + // Wait for run() to enter Write with the first batch. + select { + case <-bw.wrote: + case <-time.After(5 * time.Second): + t.Fatal("run() did not enter Write") + } + + // More frames accumulate in buf while Write is blocked. + for range 10 { + assert.NoError(t, mw.WriteFrame(RandFrame())) + } + + // Stop without letting the blocked Write complete. + mw.Stop(errors.New("stopped")) + bw.err = errors.New("closed") + close(bw.unblock) + + select { + case <-mw.Done(): + case <-time.After(5 * time.Second): + t.Fatal("Done did not return") + } + + // Only the first batch was written; the rest were discarded by Stop. + assert.Equal(t, len(bw.wrote), 0) // no more writes after the first +} + +func TestMuxWriter_WriteFrameDuringActiveDrain(t *testing.T) { + // gatedWriter lets us control when each Write completes. + type gate struct{ ch chan struct{} } + gates := make(chan gate, 10) + + gw := writerFunc(func(p []byte) (int, error) { + g := gate{ch: make(chan struct{})} + gates <- g + <-g.ch + return len(p), nil + }) + + mw := NewMuxWriter(gw, func(error) {}) + + // Batch 1: write a frame, wait for run() to pick it up and block in Write. + fr1 := Frame{Data: []byte("batch1"), ID: ID{Stream: 1, Message: 1}, Kind: KindMessage, Done: true} + assert.NoError(t, mw.WriteFrame(fr1)) + + g1 := <-gates // run() is now blocked in Write for batch 1 + + // Batch 2: write another frame while batch 1 is still draining. + fr2 := Frame{Data: []byte("batch2"), ID: ID{Stream: 1, Message: 2}, Kind: KindMessage, Done: true} + assert.NoError(t, mw.WriteFrame(fr2)) + + // Complete batch 1 write. + close(g1.ch) + + // run() loops, picks up batch 2, enters Write again. + g2 := <-gates + close(g2.ch) + + // Both batches were written. Stop and verify. + mw.Stop(errors.New("stopped")) + <-mw.Done() +} + +// writerFunc adapts a function to io.Writer. +type writerFunc func([]byte) (int, error) + +func (f writerFunc) Write(p []byte) (int, error) { return f(p) } diff --git a/drpcwire/writer.go b/drpcwire/writer.go deleted file mode 100644 index fe909cff..00000000 --- a/drpcwire/writer.go +++ /dev/null @@ -1,107 +0,0 @@ -// Copyright (C) 2019 Storj Labs, Inc. -// See LICENSE for copying information. - -package drpcwire - -import ( - "fmt" - "io" - "sync" - "sync/atomic" - - "storj.io/drpc/drpcdebug" -) - -// -// Writer -// - -// Writer is a helper to buffer and write packets and frames to an io.Writer. -type Writer struct { - empty uint32 - w io.Writer - size int - mu sync.Mutex - buf []byte -} - -// NewWriter returns a Writer that will attempt to buffer size data before -// sending it to the io.Writer. -func NewWriter(w io.Writer, size int) *Writer { - if size == 0 { - size = 4 * 1024 - } - - return &Writer{ - w: w, - size: size, - buf: make([]byte, 0, size), - } -} - -func (b *Writer) log(what string, cb func() string) { - if drpcdebug.Enabled { - drpcdebug.Log(func() (_, _, _ string) { return fmt.Sprintf("", b), what, cb() }) - } -} - -// WritePacket writes the packet as a single frame, ignoring any size -// constraints. -func (b *Writer) WritePacket(pkt Packet) (err error) { - return b.WriteFrame(Frame{ - Data: pkt.Data, - ID: pkt.ID, - Kind: pkt.Kind, - Control: pkt.Control, - Done: true, - }) -} - -// Empty returns true if there are no bytes buffered in the writer. -func (b *Writer) Empty() bool { - return atomic.LoadUint32(&b.empty) == 0 -} - -// Reset clears any pending data in the buffer. -func (b *Writer) Reset() *Writer { - b.mu.Lock() - defer b.mu.Unlock() - - b.buf = b.buf[:0] - atomic.StoreUint32(&b.empty, 0) - return b -} - -// WriteFrame appends the frame into the buffer, and if the buffer is larger -// than the configured size, flushes it. -func (b *Writer) WriteFrame(fr Frame) (err error) { - b.mu.Lock() - defer b.mu.Unlock() - - if len(b.buf) == 0 { - atomic.StoreUint32(&b.empty, 1) - } - b.buf = AppendFrame(b.buf, fr) - if len(b.buf) >= b.size { - b.log("FLUSH", func() string { return fmt.Sprintf("buffer: %d > %d", len(b.buf), b.size) }) - _, err = b.w.Write(b.buf) - b.buf = b.buf[:0] - atomic.StoreUint32(&b.empty, 0) - } - return err -} - -// Flush forces a flush of any buffered data to the io.Writer. It is a no-op if -// there is no data in the buffer. -func (b *Writer) Flush() (err error) { - b.mu.Lock() - defer b.mu.Unlock() - - if len(b.buf) > 0 { - _, err = b.w.Write(b.buf) - b.log("FLUSH", func() string { return fmt.Sprintf("explicit: %d", len(b.buf)) }) - b.buf = b.buf[:0] - atomic.StoreUint32(&b.empty, 0) - } - return err -} diff --git a/drpcwire/writer_test.go b/drpcwire/writer_test.go deleted file mode 100644 index 840a6164..00000000 --- a/drpcwire/writer_test.go +++ /dev/null @@ -1,32 +0,0 @@ -// Copyright (C) 2021 Storj Labs, Inc. -// See LICENSE for copying information. - -package drpcwire - -import ( - "bytes" - "testing" - - "github.com/zeebo/assert" -) - -func TestWriter(t *testing.T) { - run := func(size int) func(t *testing.T) { - return func(t *testing.T) { - var exp []byte - var got bytes.Buffer - - wr := NewWriter(&got, size) - for i := 0; i < 1000; i++ { - fr := RandFrame() - exp = AppendFrame(exp, fr) - assert.NoError(t, wr.WriteFrame(fr)) - } - assert.NoError(t, wr.Flush()) - assert.That(t, bytes.Equal(exp, got.Bytes())) - } - } - - t.Run("Size 0B", run(0)) - t.Run("Size 1MB", run(1024*1024)) -} diff --git a/internal/grpccompat/benchmark_test.go b/internal/grpccompat/benchmark_test.go index 3bd0af71..c59f59a8 100644 --- a/internal/grpccompat/benchmark_test.go +++ b/internal/grpccompat/benchmark_test.go @@ -13,7 +13,6 @@ import ( "google.golang.org/protobuf/proto" "storj.io/drpc/drpcmanager" - "storj.io/drpc/drpcstream" ) var benchmarkImpl = &serviceImpl{ @@ -65,11 +64,7 @@ var benchmarkImpl = &serviceImpl{ } func benchmarkBoth(b *testing.B, fn func(b *testing.B, in *In, client Client)) { - options := drpcmanager.Options{ - Stream: drpcstream.Options{ - ManualFlush: true, - }, - } + options := drpcmanager.Options{} for _, size := range []struct { Name string diff --git a/internal/grpccompat/common_test.go b/internal/grpccompat/common_test.go index 8cc5c777..65bb2c9b 100644 --- a/internal/grpccompat/common_test.go +++ b/internal/grpccompat/common_test.go @@ -16,7 +16,6 @@ import ( "testing" "time" - "github.com/zeebo/assert" "github.com/zeebo/errs" "google.golang.org/grpc" "google.golang.org/grpc/codes" diff --git a/internal/integration/alias.go b/internal/integration/alias.go index 8e0d8030..ab7c56d2 100644 --- a/internal/integration/alias.go +++ b/internal/integration/alias.go @@ -25,6 +25,7 @@ type ( DRPCService_Method2Stream = service.DRPCService_Method2Stream DRPCService_Method3Stream = service.DRPCService_Method3Stream DRPCService_Method4Stream = service.DRPCService_Method4Stream + DRPCService_Method4Client = service.DRPCService_Method4Client ) var ( diff --git a/internal/integration/cache_test.go b/internal/integration/cache_test.go deleted file mode 100644 index d4e1c959..00000000 --- a/internal/integration/cache_test.go +++ /dev/null @@ -1,63 +0,0 @@ -// Copyright (C) 2020 Storj Labs, Inc. -// See LICENSE for copying information. - -package integration - -import ( - "context" - "testing" - - "github.com/zeebo/assert" - "github.com/zeebo/errs" - - "storj.io/drpc/drpccache" - "storj.io/drpc/drpctest" -) - -func TestCache(t *testing.T) { - ctx := drpctest.NewTracker(t) - defer ctx.Close() - - // create a server that signals then waits for the context to die - cli, close := createConnection(t, impl{ - Method1Fn: func(ctx context.Context, _ *In) (*Out, error) { - cache := drpccache.FromContext(ctx) - if cache == nil { - return nil, errs.New("no cache associated with context") - } - cache.LoadOrCreate("value", func() interface{} { return 42 }) - return &Out{Out: 123}, nil - }, - Method2Fn: func(stream DRPCService_Method2Stream) error { - cache := drpccache.FromContext(stream.Context()) - if cache == nil { - return errs.New("no cache associated with context") - } - value, _ := cache.Load("value").(int) - return stream.SendAndClose(&Out{Out: int64(value)}) - }, - }) - defer close() - - { // value not yet cached - stream, err := cli.Method2(ctx) - assert.NoError(t, err) - out, err := stream.CloseAndRecv() - assert.NoError(t, err) - assert.True(t, Equal(out, &Out{Out: 0})) - } - - { // store value in the cache - out, err := cli.Method1(ctx, in(1)) - assert.NoError(t, err) - assert.True(t, Equal(out, &Out{Out: 123})) - } - - { // expected value in the cache - stream, err := cli.Method2(ctx) - assert.NoError(t, err) - out, err := stream.CloseAndRecv() - assert.NoError(t, err) - assert.True(t, Equal(out, &Out{Out: 42})) - } -} diff --git a/internal/integration/cancel_test.go b/internal/integration/cancel_test.go index 03a38604..dec012da 100644 --- a/internal/integration/cancel_test.go +++ b/internal/integration/cancel_test.go @@ -15,7 +15,6 @@ import ( "google.golang.org/grpc/codes" "google.golang.org/grpc/status" - "storj.io/drpc" "storj.io/drpc/drpcconn" "storj.io/drpc/drpcpool" "storj.io/drpc/drpctest" @@ -150,15 +149,6 @@ func TestCancellationPropagation_Stream(t *testing.T) { clientctx.Run(func(ctx context.Context) { stream, _ := cli.Method4(ctx) - // this is a weird case where the rpc does not send or receive or even - // close the stream, and neither does the other side, and so we have to - // explicitly flush the invoke. - type ( - getStreamer interface{ GetStream() drpc.Stream } - rawFlusher interface{ RawFlush() error } - ) - _ = stream.(getStreamer).GetStream().(rawFlusher).RawFlush() - called <- struct{}{} select { case <-stream.Context().Done(): diff --git a/internal/integration/common_test.go b/internal/integration/common_test.go index 5acb61a1..f3d1a80b 100644 --- a/internal/integration/common_test.go +++ b/internal/integration/common_test.go @@ -5,7 +5,9 @@ package integration import ( "context" + "flag" "io" + "math/rand" "net" "strconv" "sync" @@ -16,7 +18,6 @@ import ( "google.golang.org/grpc/status" "storj.io/drpc/drpcconn" - "storj.io/drpc/drpcmanager" "storj.io/drpc/drpcmetadata" "storj.io/drpc/drpcmux" "storj.io/drpc/drpcserver" @@ -39,17 +40,36 @@ func data(n int64) []byte { func in(n int64) *In { return &In{In: n} } func out(n int64) *Out { return &Out{Out: n} } +var transport = flag.String("transport", "", "force transport for tests: pipe, tcp, or empty for random") + func createRawConnection(t testing.TB, server DRPCServiceServer, ctx *drpctest.Tracker) *drpcconn.Conn { + switch *transport { + case "pipe": + t.Log("transport: pipe") + return createPipeConnection(t, server, ctx) + case "tcp": + t.Log("transport: tcp") + return createTCPConnection(t, server, ctx) + case "": + if rand.Intn(2) == 0 { + t.Log("transport: pipe") + return createPipeConnection(t, server, ctx) + } + t.Log("transport: tcp") + return createTCPConnection(t, server, ctx) + default: + t.Fatalf("unknown -transport value: %q", *transport) + return nil + } +} + +func createPipeConnection(t testing.TB, server DRPCServiceServer, ctx *drpctest.Tracker) *drpcconn.Conn { c1, c2 := net.Pipe() mux := drpcmux.New() assert.NoError(t, DRPCRegisterService(mux, server)) srv := drpcserver.New(mux) ctx.Run(func(ctx context.Context) { _ = srv.ServeOne(ctx, c1) }) - return drpcconn.NewWithOptions(c2, drpcconn.Options{ - Manager: drpcmanager.Options{ - SoftCancel: true, - }, - }) + return drpcconn.NewWithOptions(c2, drpcconn.Options{}) } func createConnection(t testing.TB, server DRPCServiceServer) (DRPCServiceClient, func()) { @@ -61,6 +81,29 @@ func createConnection(t testing.TB, server DRPCServiceServer) (DRPCServiceClient } } +func createTCPConnection(t testing.TB, server DRPCServiceServer, ctx *drpctest.Tracker) *drpcconn.Conn { + ln, err := net.Listen("tcp", "127.0.0.1:0") + assert.NoError(t, err) + t.Cleanup(func() { _ = ln.Close() }) + + mux := drpcmux.New() + assert.NoError(t, DRPCRegisterService(mux, server)) + srv := drpcserver.New(mux) + + ctx.Run(func(ctx context.Context) { + conn, err := ln.Accept() + if err != nil { + return + } + _ = srv.ServeOne(ctx, conn) + }) + + conn, err := net.Dial("tcp", ln.Addr().String()) + assert.NoError(t, err) + + return drpcconn.NewWithOptions(conn, drpcconn.Options{}) +} + // // server impl // diff --git a/internal/integration/go.mod b/internal/integration/go.mod index df019703..6438de7c 100644 --- a/internal/integration/go.mod +++ b/internal/integration/go.mod @@ -6,6 +6,7 @@ require ( github.com/gogo/protobuf v1.3.2 github.com/zeebo/assert v1.3.1 github.com/zeebo/errs v1.4.0 + go.uber.org/goleak v1.3.0 golang.org/x/exp v0.0.0-20260218203240-3dfff04db8fa google.golang.org/grpc v1.57.2 google.golang.org/protobuf v1.33.0 diff --git a/internal/integration/go.sum b/internal/integration/go.sum index d06e5266..970a0549 100644 --- a/internal/integration/go.sum +++ b/internal/integration/go.sum @@ -20,6 +20,8 @@ github.com/zeebo/assert v1.3.1 h1:vukIABvugfNMZMQO1ABsyQDJDTVQbn+LWSMy1ol1h6A= github.com/zeebo/assert v1.3.1/go.mod h1:Pq9JiuJQpG8JLJdtkwrJESF0Foym2/D9XMU5ciN/wJ0= github.com/zeebo/errs v1.4.0 h1:XNdoD/RRMKP7HD0UhJnIzUy74ISdGGxURlYG8HSWSfM= github.com/zeebo/errs v1.4.0/go.mod h1:sgbWHsvVuTPHcqJJGQ1WhI5KbWlHYz+2+2C/LSEtCw4= +go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto= +go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= @@ -31,8 +33,8 @@ golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20200226121028-0de0cce0169b/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= -golang.org/x/net v0.38.0 h1:vRMAPTMaeGqVhG5QyLJHqNDwecKTomGeqbnfZyKlBI8= -golang.org/x/net v0.38.0/go.mod h1:ivrbrMbzFq5J41QOQh0siUuly180yBYtLp+CKbEaFx8= +golang.org/x/net v0.9.0 h1:aWJ/m6xSmxWBx+V0XRHTlrYrPG56jKsLdTFmsSsCzOM= +golang.org/x/net v0.9.0/go.mod h1:d48xBJpPfHeWQsugry2m+kC02ZBRGRgulfHnEXEuWns= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= @@ -43,8 +45,8 @@ golang.org/x/sys v0.33.0 h1:q3i8TbbEz+JRD9ywIRlyRAQbM0qF7hu24q3teo2hbuw= golang.org/x/sys v0.33.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= -golang.org/x/text v0.23.0 h1:D71I7dUrlY+VX0gQShAThNGHFxZ13dGLBHQLVl1mJlY= -golang.org/x/text v0.23.0/go.mod h1:/BLNzu4aZCJ1+kcD0DNRotWKage4q2rGVAg4o22unh4= +golang.org/x/text v0.9.0 h1:2sjJmO8cDvYveuX97RDLsxlyUxLl+GHoLxBiRdHllBE= +golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.0.0-20200619180055-7c47624df98f/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE= diff --git a/internal/integration/multiplex_test.go b/internal/integration/multiplex_test.go new file mode 100644 index 00000000..ce69e45a --- /dev/null +++ b/internal/integration/multiplex_test.go @@ -0,0 +1,263 @@ +// Copyright (C) 2026 Cockroach Labs. +// See LICENSE for copying information. + +package integration + +import ( + "context" + "errors" + "io" + "testing" + "time" + + "github.com/zeebo/assert" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" + + "storj.io/drpc/drpctest" +) + +// TestMultiplex_CancelIsolation verifies that canceling one stream's context +// does not affect other concurrent streams on the same connection. +func TestMultiplex_CancelIsolation(t *testing.T) { + ctx := drpctest.NewTracker(t) + defer ctx.Close() + + started := make(chan struct{}, 3) + cli, close := createConnection(t, impl{ + Method4Fn: func(stream DRPCService_Method4Stream) error { + started <- struct{}{} + for { + msg, err := stream.Recv() + if err != nil { + return nil + } + if err := stream.Send(&Out{Out: msg.In}); err != nil { + return err + } + } + }, + }) + defer close() + + // Open 3 bidi streams with independent contexts. + ctx1, cancel1 := context.WithCancel(ctx) + defer cancel1() + s1, err := cli.Method4(ctx1) + assert.NoError(t, err) + + ctx2, cancel2 := context.WithCancel(ctx) + defer cancel2() + s2, err := cli.Method4(ctx2) + assert.NoError(t, err) + + ctx3, cancel3 := context.WithCancel(ctx) + defer cancel3() + s3, err := cli.Method4(ctx3) + assert.NoError(t, err) + + // Wait for all server handlers to start. + <-started + <-started + <-started + + // Cancel stream 2. + cancel2() + + // Verify stream 2 is dead. This blocks until the cancel propagates. + _, err = s2.Recv() + assert.Error(t, err) + st, ok := status.FromError(err) + assert.That(t, ok) + assert.Equal(t, st.Code(), codes.Canceled) + + // Streams 1 and 3 should still work. + assert.NoError(t, s1.Send(in(10))) + out, err := s1.Recv() + assert.NoError(t, err) + assert.Equal(t, out.Out, int64(10)) + + assert.NoError(t, s3.Send(in(30))) + out, err = s3.Recv() + assert.NoError(t, err) + assert.Equal(t, out.Out, int64(30)) + + // Clean up remaining streams. + assert.NoError(t, s1.CloseSend()) + assert.NoError(t, s3.CloseSend()) +} + +// TestMultiplex_ErrorIsolation verifies that a server handler returning an +// error on one stream does not affect other concurrent streams. +func TestMultiplex_ErrorIsolation(t *testing.T) { + ctx := drpctest.NewTracker(t) + defer ctx.Close() + + // Handler reads first message: In == -1 triggers an error, otherwise echo. + cli, close := createConnection(t, impl{ + Method4Fn: func(stream DRPCService_Method4Stream) error { + msg, err := stream.Recv() + if err != nil { + return nil + } + if msg.In == -1 { + return status.Error(codes.InvalidArgument, "bad input") + } + if err := stream.Send(&Out{Out: msg.In}); err != nil { + return err + } + for { + msg, err := stream.Recv() + if err != nil { + return nil + } + if err := stream.Send(&Out{Out: msg.In}); err != nil { + return err + } + } + }, + }) + defer close() + + s1, err := cli.Method4(ctx) + assert.NoError(t, err) + + s2, err := cli.Method4(ctx) + assert.NoError(t, err) + + // Trigger error on stream 1. + assert.NoError(t, s1.Send(in(-1))) + + // Send normal message on stream 2. + assert.NoError(t, s2.Send(in(42))) + + // Stream 1 should receive the server error. + _, err = s1.Recv() + assert.Error(t, err) + st, ok := status.FromError(err) + assert.That(t, ok) + assert.Equal(t, st.Code(), codes.InvalidArgument) + + // Stream 2 should be unaffected. + out, err := s2.Recv() + assert.NoError(t, err) + assert.Equal(t, out.Out, int64(42)) + + // Stream 2 keeps working after stream 1 is dead. + assert.NoError(t, s2.Send(in(100))) + out, err = s2.Recv() + assert.NoError(t, err) + assert.Equal(t, out.Out, int64(100)) + + // Clean up. + assert.NoError(t, s2.CloseSend()) + _, err = s2.Recv() + assert.That(t, errors.Is(err, io.EOF)) +} + +// TestMultiplex_ConnCloseWithActiveStreams verifies that closing a connection +// with multiple active streams terminates all of them and does not deadlock. +func TestMultiplex_ConnCloseWithActiveStreams(t *testing.T) { + ctx := drpctest.NewTracker(t) + defer ctx.Close() + + started := make(chan struct{}, 3) + conn := createRawConnection(t, impl{ + Method4Fn: func(stream DRPCService_Method4Stream) error { + started <- struct{}{} + <-stream.Context().Done() + return stream.Context().Err() + }, + }, ctx) + cli := NewDRPCServiceClient(conn) + + // Open 3 streams whose handlers block until canceled. + const N = 3 + streams := make([]DRPCService_Method4Client, N) + for i := 0; i < N; i++ { + s, err := cli.Method4(ctx) + assert.NoError(t, err) + streams[i] = s + } + + // Wait for all handlers to be running. + for i := 0; i < N; i++ { + <-started + } + + // conn.Close triggers manager.Close which must not deadlock. + done := make(chan error, 1) + go func() { done <- conn.Close() }() + + timer := time.NewTimer(5 * time.Second) + defer timer.Stop() + select { + case <-done: + case <-timer.C: + t.Fatal("conn.Close() deadlocked with active streams") + } + + // All streams should be terminated. + for i, s := range streams { + _, err := s.Recv() + assert.Error(t, err) + t.Logf("stream %d: %v", i, err) + } +} + +// TestMultiplex_TransportCloseTerminatesAllStreams verifies that an external +// transport failure terminates all active streams on the connection. +func TestMultiplex_TransportCloseTerminatesAllStreams(t *testing.T) { + ctx := drpctest.NewTracker(t) + defer ctx.Close() + + started := make(chan struct{}, 3) + conn := createRawConnection(t, impl{ + Method4Fn: func(stream DRPCService_Method4Stream) error { + started <- struct{}{} + for { + msg, err := stream.Recv() + if err != nil { + return nil + } + if err := stream.Send(&Out{Out: msg.In}); err != nil { + return err + } + } + }, + }, ctx) + cli := NewDRPCServiceClient(conn) + + // Open 3 streams. + const N = 3 + streams := make([]DRPCService_Method4Client, N) + for i := 0; i < N; i++ { + s, err := cli.Method4(ctx) + assert.NoError(t, err) + streams[i] = s + } + + // Wait for all handlers. + for i := 0; i < N; i++ { + <-started + } + + // Verify streams work before the failure. + assert.NoError(t, streams[0].Send(in(1))) + out, err := streams[0].Recv() + assert.NoError(t, err) + assert.Equal(t, out.Out, int64(1)) + + // Simulate transport failure. + assert.NoError(t, conn.Transport().Close()) + + // All streams should receive errors. + for i, s := range streams { + _, err := s.Recv() + assert.Error(t, err) + t.Logf("stream %d: %v", i, err) + } + + // Connection should close cleanly after transport failure. + _ = conn.Close() +} diff --git a/internal/integration/simple_test.go b/internal/integration/simple_test.go index 26a2e8c9..33473970 100644 --- a/internal/integration/simple_test.go +++ b/internal/integration/simple_test.go @@ -6,7 +6,6 @@ package integration import ( "context" "errors" - "fmt" "io" "net" "testing" @@ -85,35 +84,60 @@ func TestSimple(t *testing.T) { } } -func TestConcurrent(t *testing.T) { +func TestMultiplexedStreams(t *testing.T) { ctx := drpctest.NewTracker(t) defer ctx.Close() - cli, close := createConnection(t, standardImpl) - defer close() - - const N = 1000 - errs := make(chan error) - for i := 0; i < N; i++ { - ctx.Run(func(ctx context.Context) { - select { - case <-ctx.Done(): - case errs <- func() error { - out, err := cli.Method1(ctx, &In{In: 1}) + // Echo server: sends back each received message immediately. + echoServer := impl{ + Method1Fn: standardImpl.Method1Fn, + Method2Fn: standardImpl.Method2Fn, + Method3Fn: standardImpl.Method3Fn, + Method4Fn: func(stream DRPCService_Method4Stream) error { + for { + msg, err := stream.Recv() if err != nil { - return err - } else if out.Out != 1 { - return fmt.Errorf("wrong result %d", out.Out) - } else { return nil } - }(): + if err := stream.Send(&Out{Out: msg.In}); err != nil { + return err + } } - }) - } - for i := 0; i < N; i++ { - assert.NoError(t, <-errs) + }, } + + cli, close := createConnection(t, echoServer) + defer close() + + // Open two bidi streams on the same connection. + s1, err := cli.Method4(ctx) + assert.NoError(t, err) + + s2, err := cli.Method4(ctx) + assert.NoError(t, err) + + // Send on both streams interleaved. + assert.NoError(t, s1.Send(&In{In: 1})) + assert.NoError(t, s2.Send(&In{In: 2})) + + // Receive from both: each stream gets its own response. + out1, err := s1.Recv() + assert.NoError(t, err) + assert.Equal(t, out1.Out, int64(1)) + + out2, err := s2.Recv() + assert.NoError(t, err) + assert.Equal(t, out2.Out, int64(2)) + + // Close both streams. + assert.NoError(t, s1.CloseSend()) + assert.NoError(t, s2.CloseSend()) + + _, err = s1.Recv() + assert.That(t, errors.Is(err, io.EOF)) + + _, err = s2.Recv() + assert.That(t, errors.Is(err, io.EOF)) } func TestServerStats(t *testing.T) { @@ -164,52 +188,3 @@ func TestServerStats(t *testing.T) { "/service.Service/Method3": {Read: 2, Written: 6}, }) } - -func TestClientStats(t *testing.T) { - ctx := drpctest.NewTracker(t) - defer ctx.Close() - - c1, c2 := net.Pipe() - mux := drpcmux.New() - _ = DRPCRegisterService(mux, standardImpl) - - srv := drpcserver.New(mux) - ctx.Run(func(ctx context.Context) { _ = srv.ServeOne(ctx, c1) }) - - conn := drpcconn.NewWithOptions(c2, drpcconn.Options{ - CollectStats: true, - }) - defer func() { _ = conn.Close() }() - cli := NewDRPCServiceClient(conn) - - assert.Equal(t, srv.Stats(), map[string]drpcstats.Stats{}) - - _, err := cli.Method1(ctx, in(5)) - assert.Error(t, err) - - assert.Equal(t, conn.Stats(), map[string]drpcstats.Stats{ - "/service.Service/Method1": {Read: 9, Written: 26}, - }) - - _, err = cli.Method1(ctx, in(1)) - assert.NoError(t, err) - - assert.Equal(t, conn.Stats(), map[string]drpcstats.Stats{ - "/service.Service/Method1": {Read: 9 + 2, Written: 26 + 26}, - }) - - stream, err := cli.Method3(ctx, in(3)) - assert.NoError(t, err) - for i := 0; i < 3; i++ { - _, err := stream.Recv() - assert.NoError(t, err) - } - _, err = stream.Recv() - assert.That(t, errors.Is(err, io.EOF)) - assert.NoError(t, stream.Close()) - - assert.Equal(t, conn.Stats(), map[string]drpcstats.Stats{ - "/service.Service/Method1": {Read: 9 + 2, Written: 26 + 26}, - "/service.Service/Method3": {Read: 6, Written: 26}, - }) -} diff --git a/internal/integration/stress_test.go b/internal/integration/stress_test.go new file mode 100644 index 00000000..7e67969a --- /dev/null +++ b/internal/integration/stress_test.go @@ -0,0 +1,679 @@ +// Copyright (C) 2026 Cockroach Labs. +// See LICENSE for copying information. + +package integration + +import ( + "context" + "errors" + "fmt" + "io" + "math/rand" + "sync" + "testing" + "time" + + "github.com/zeebo/assert" + "go.uber.org/goleak" + + "storj.io/drpc/drpctest" +) + +// TestStress_SustainedConcurrentStreams opens 50 bidi streams on one +// connection, each exchanging 100 echo messages concurrently. This saturates +// the manageReader dispatch path (one reader fanning out to 50 packetQueues) +// and the MuxWriter batching path (50 goroutines calling WriteFrame). Each +// message encodes the stream's identity so we can detect cross-stream data +// corruption, which would indicate a routing bug in the manager or a buffer +// reuse bug in the packetQueue. +func TestStress_SustainedConcurrentStreams(t *testing.T) { + defer goleak.VerifyNone(t) + + ctx := drpctest.NewTracker(t) + defer ctx.Close() + + conn := createRawConnection(t, impl{ + Method4Fn: func(stream DRPCService_Method4Stream) error { + for { + msg, err := stream.Recv() + if err != nil { + return nil + } + if err := stream.Send(&Out{Out: msg.In}); err != nil { + return err + } + } + }, + }, ctx) + defer func() { _ = conn.Close() }() + cli := NewDRPCServiceClient(conn) + + const N = 50 // concurrent streams + const M = 100 // messages per stream + + errs := make(chan error, N) + for i := 0; i < N; i++ { + i := i + ctx.Run(func(ctx context.Context) { + select { + case <-ctx.Done(): + case errs <- func() error { + stream, err := cli.Method4(ctx) + if err != nil { + return fmt.Errorf("stream %d: open: %w", i, err) + } + for j := 0; j < M; j++ { + val := int64(i*1000 + j) // encode stream identity + if err := stream.Send(&In{In: val}); err != nil { + return fmt.Errorf("stream %d: send %d: %w", i, j, err) + } + out, err := stream.Recv() + if err != nil { + return fmt.Errorf("stream %d: recv %d: %w", i, j, err) + } + if out.Out != val { + return fmt.Errorf("stream %d: msg %d: got %d, want %d (cross-contamination?)", i, j, out.Out, val) + } + } + if err := stream.CloseSend(); err != nil { + return fmt.Errorf("stream %d: close send: %w", i, err) + } + _, err = stream.Recv() + if !errors.Is(err, io.EOF) { + return fmt.Errorf("stream %d: final recv: got %v, want EOF", i, err) + } + return nil + }(): + } + }) + } + + for i := 0; i < N; i++ { + assert.NoError(t, <-errs) + } +} + +// TestStress_RapidOpenCloseCycles opens and closes a bidi stream 500 times +// sequentially on one connection. Each cycle creates a stream, exchanges one +// message, then tears it down. This tests that stream ID allocation, the +// activeStreams map cleanup, and the invoke handshake (pdone channel) work +// correctly across many rapid create/destroy cycles without leaking resources. +func TestStress_RapidOpenCloseCycles(t *testing.T) { + defer goleak.VerifyNone(t) + + ctx := drpctest.NewTracker(t) + defer ctx.Close() + + conn := createRawConnection(t, impl{ + Method4Fn: func(stream DRPCService_Method4Stream) error { + for { + msg, err := stream.Recv() + if err != nil { + return nil + } + if err := stream.Send(&Out{Out: msg.In}); err != nil { + return err + } + } + }, + }, ctx) + defer func() { _ = conn.Close() }() + cli := NewDRPCServiceClient(conn) + + const K = 500 + + for i := 0; i < K; i++ { + stream, err := cli.Method4(ctx) + assert.NoError(t, err) + + val := int64(i) + assert.NoError(t, stream.Send(&In{In: val})) + out, err := stream.Recv() + assert.NoError(t, err) + assert.Equal(t, out.Out, val) + + assert.NoError(t, stream.CloseSend()) + _, err = stream.Recv() + assert.That(t, errors.Is(err, io.EOF)) + } +} + +// TestStress_CancelStorm opens 30 streams that all exchange messages +// concurrently, then cancels ~50% of them (chosen randomly) with jitter so +// cancellations land at different points in the send/recv cycle. This races +// context cancellation against in-flight Send and Recv calls, exercising the +// stream's cancel propagation and cleanup. Surviving (non-cancelled) streams +// must complete without error, verifying cancel isolation. +func TestStress_CancelStorm(t *testing.T) { + defer goleak.VerifyNone(t) + + seed := time.Now().UnixNano() + t.Logf("random seed: %d", seed) + + ctx := drpctest.NewTracker(t) + defer ctx.Close() + + conn := createRawConnection(t, impl{ + Method4Fn: func(stream DRPCService_Method4Stream) error { + for { + msg, err := stream.Recv() + if err != nil { + return nil + } + if err := stream.Send(&Out{Out: msg.In}); err != nil { + return err + } + } + }, + }, ctx) + defer func() { _ = conn.Close() }() + cli := NewDRPCServiceClient(conn) + + const N = 30 + const messagesPerStream = 20 + + type entry struct { + stream DRPCService_Method4Client + cancel context.CancelFunc + } + + entries := make([]entry, N) + for i := 0; i < N; i++ { + sctx, cancel := context.WithCancel(ctx) + stream, err := cli.Method4(sctx) + assert.NoError(t, err) + entries[i] = entry{stream: stream, cancel: cancel} + } + + // Decide which streams to cancel (~50%). + rng := rand.New(rand.NewSource(seed)) + cancelled := make([]bool, N) + for i := range entries { + cancelled[i] = rng.Intn(2) == 0 + } + + // All streams exchange messages concurrently. Cancelled streams + // will hit errors once the cancel goroutine fires their context; + // surviving streams must complete without error. + var wg sync.WaitGroup + wg.Add(N) + for i := range entries { + i := i + ctx.Run(func(_ context.Context) { + defer wg.Done() + defer entries[i].cancel() + for j := 0; j < messagesPerStream; j++ { + val := int64(i*1000 + j) + if err := entries[i].stream.Send(&In{In: val}); err != nil { + if !cancelled[i] { + t.Errorf("surviving stream %d: send %d: %v", i, j, err) + } + return + } + out, err := entries[i].stream.Recv() + if err != nil { + if !cancelled[i] { + t.Errorf("surviving stream %d: recv %d: %v", i, j, err) + } + return + } + if out.Out != val { + t.Errorf("stream %d: msg %d: got %d, want %d", i, j, out.Out, val) + } + } + if !cancelled[i] { + _ = entries[i].stream.CloseSend() + } + }) + } + + // Fire cancels concurrently with traffic, with jitter. + ctx.Run(func(_ context.Context) { + time.Sleep(time.Millisecond) // let streams start sending + for i := range entries { + if cancelled[i] { + time.Sleep(time.Duration(rng.Intn(500)) * time.Microsecond) + entries[i].cancel() + } + } + }) + + wg.Wait() +} + +// TestStress_ShutdownDuringActivity opens 20 streams all actively sending and +// receiving, then calls conn.Close(). This exercises the manager's terminate() +// path with traffic in flight: the MuxWriter must stop, manageReader must +// unblock, and all stream goroutines must exit. The test fails if conn.Close() +// doesn't return within 5 seconds (deadlock). +func TestStress_ShutdownDuringActivity(t *testing.T) { + defer goleak.VerifyNone(t) + + ctx := drpctest.NewTracker(t) + defer ctx.Close() + + started := make(chan struct{}, 20) + conn := createRawConnection(t, impl{ + Method4Fn: func(stream DRPCService_Method4Stream) error { + started <- struct{}{} + for { + msg, err := stream.Recv() + if err != nil { + return nil + } + if err := stream.Send(&Out{Out: msg.In}); err != nil { + return err + } + } + }, + }, ctx) + cli := NewDRPCServiceClient(conn) + + const N = 20 + + streams := make([]DRPCService_Method4Client, N) + for i := 0; i < N; i++ { + s, err := cli.Method4(ctx) + assert.NoError(t, err) + streams[i] = s + } + + // Wait for all server handlers to start. + for i := 0; i < N; i++ { + <-started + } + + // Start sending on all streams continuously. + for i, s := range streams { + i, s := i, s + ctx.Run(func(_ context.Context) { + for { + if err := s.Send(&In{In: int64(i)}); err != nil { + return + } + if _, err := s.Recv(); err != nil { + return + } + } + }) + } + + // Let traffic flow briefly. + time.Sleep(10 * time.Millisecond) + + // Close the connection. Must return within timeout. + done := make(chan error, 1) + go func() { done <- conn.Close() }() + + timer := time.NewTimer(5 * time.Second) + defer timer.Stop() + select { + case <-done: + case <-timer.C: + t.Fatal("conn.Close() deadlocked with active streams") + } +} + +// TestStress_MixedRPCTypes runs 30 goroutines each executing 10 rounds of a +// randomly chosen RPC type (unary, client-streaming, server-streaming, bidi) +// concurrently on one connection. Different RPC types have different frame +// sequences and stream lifecycles: unary is short-lived (invoke, response, +// close), client-streaming sends multiple frames before expecting a response, +// server-streaming reads until EOF, and bidi is long-lived bidirectional. +// Mixing them exercises the manager's ability to correctly interleave +// heterogeneous stream types on a shared transport. +func TestStress_MixedRPCTypes(t *testing.T) { + defer goleak.VerifyNone(t) + + seed := time.Now().UnixNano() + t.Logf("random seed: %d", seed) + + ctx := drpctest.NewTracker(t) + defer ctx.Close() + + conn := createRawConnection(t, impl{ + // Method1: unary echo. + Method1Fn: func(ctx context.Context, in *In) (*Out, error) { + return &Out{Out: in.In}, nil + }, + // Method2: client-streaming — sum inputs. + Method2Fn: func(stream DRPCService_Method2Stream) error { + var total int64 + for { + in, err := stream.Recv() + if errors.Is(err, io.EOF) { + break + } + if err != nil { + return err + } + total += in.In + } + return stream.SendAndClose(&Out{Out: total}) + }, + // Method3: server-streaming — send N copies. + Method3Fn: func(in *In, stream DRPCService_Method3Stream) error { + for i := 0; i < int(in.In); i++ { + if err := stream.Send(&Out{Out: in.In}); err != nil { + return err + } + } + return nil + }, + // Method4: bidi echo. + Method4Fn: func(stream DRPCService_Method4Stream) error { + for { + msg, err := stream.Recv() + if err != nil { + return nil + } + if err := stream.Send(&Out{Out: msg.In}); err != nil { + return err + } + } + }, + }, ctx) + defer func() { _ = conn.Close() }() + cli := NewDRPCServiceClient(conn) + + const N = 30 + const rounds = 10 + + errs := make(chan error, N) + for i := 0; i < N; i++ { + i := i + ctx.Run(func(ctx context.Context) { + select { + case <-ctx.Done(): + case errs <- func() error { + rng := rand.New(rand.NewSource(seed + int64(i))) + for r := 0; r < rounds; r++ { + switch rng.Intn(4) { + case 0: // Unary + out, err := cli.Method1(ctx, &In{In: int64(i*100 + r)}) + if err != nil { + return fmt.Errorf("goroutine %d round %d: unary: %w", i, r, err) + } + if out.Out != int64(i*100+r) { + return fmt.Errorf("goroutine %d round %d: unary: got %d want %d", i, r, out.Out, i*100+r) + } + + case 1: // Client-streaming + stream, err := cli.Method2(ctx) + if err != nil { + return fmt.Errorf("goroutine %d round %d: client-stream open: %w", i, r, err) + } + var want int64 + for k := 0; k < 5; k++ { + v := int64(k + 1) + want += v + if err := stream.Send(&In{In: v}); err != nil { + return fmt.Errorf("goroutine %d round %d: client-stream send: %w", i, r, err) + } + } + out, err := stream.CloseAndRecv() + if err != nil { + return fmt.Errorf("goroutine %d round %d: client-stream close: %w", i, r, err) + } + if out.Out != want { + return fmt.Errorf("goroutine %d round %d: client-stream: got %d want %d", i, r, out.Out, want) + } + + case 2: // Server-streaming + count := int64(3) + stream, err := cli.Method3(ctx, &In{In: count}) + if err != nil { + return fmt.Errorf("goroutine %d round %d: server-stream open: %w", i, r, err) + } + var got int + for { + out, err := stream.Recv() + if errors.Is(err, io.EOF) { + break + } + if err != nil { + return fmt.Errorf("goroutine %d round %d: server-stream recv: %w", i, r, err) + } + if out.Out != count { + return fmt.Errorf("goroutine %d round %d: server-stream: got %d want %d", i, r, out.Out, count) + } + got++ + } + if got != int(count) { + return fmt.Errorf("goroutine %d round %d: server-stream: got %d msgs want %d", i, r, got, count) + } + + case 3: // Bidi + stream, err := cli.Method4(ctx) + if err != nil { + return fmt.Errorf("goroutine %d round %d: bidi open: %w", i, r, err) + } + for k := 0; k < 5; k++ { + val := int64(i*10000 + r*100 + k) + if err := stream.Send(&In{In: val}); err != nil { + return fmt.Errorf("goroutine %d round %d: bidi send: %w", i, r, err) + } + out, err := stream.Recv() + if err != nil { + return fmt.Errorf("goroutine %d round %d: bidi recv: %w", i, r, err) + } + if out.Out != val { + return fmt.Errorf("goroutine %d round %d: bidi: got %d want %d", i, r, out.Out, val) + } + } + if err := stream.CloseSend(); err != nil { + return fmt.Errorf("goroutine %d round %d: bidi close: %w", i, r, err) + } + _, err = stream.Recv() + if !errors.Is(err, io.EOF) { + return fmt.Errorf("goroutine %d round %d: bidi final: got %v want EOF", i, r, err) + } + } + } + return nil + }(): + } + }) + } + + for i := 0; i < N; i++ { + assert.NoError(t, <-errs) + } +} + +// TestStress_ConcurrentCancelCloseTransportClose fires context cancellation, +// conn.Close(), and transport.Close() nearly simultaneously while 10 streams +// are actively exchanging messages. This is the worst-case shutdown scenario: +// three independent shutdown paths (cancel propagation, manager.Close, +// transport EOF) all racing to call terminate(). The manager's terminate() is +// idempotent via sigs.term first-wins, so only one should win; the rest must +// be no-ops. Deadlock within 5 seconds is a failure. +func TestStress_ConcurrentCancelCloseTransportClose(t *testing.T) { + defer goleak.VerifyNone(t) + + ctx := drpctest.NewTracker(t) + defer ctx.Close() + + started := make(chan struct{}, 10) + conn := createRawConnection(t, impl{ + Method4Fn: func(stream DRPCService_Method4Stream) error { + started <- struct{}{} + for { + msg, err := stream.Recv() + if err != nil { + return nil + } + if err := stream.Send(&Out{Out: msg.In}); err != nil { + return err + } + } + }, + }, ctx) + cli := NewDRPCServiceClient(conn) + + const N = 10 + + cancels := make([]context.CancelFunc, N) + streams := make([]DRPCService_Method4Client, N) + for i := 0; i < N; i++ { + sctx, cancel := context.WithCancel(ctx) + s, err := cli.Method4(sctx) + assert.NoError(t, err) + cancels[i] = cancel + streams[i] = s + } + + for i := 0; i < N; i++ { + <-started + } + + // Start sending on all streams. + for i, s := range streams { + i, s := i, s + ctx.Run(func(_ context.Context) { + for { + if err := s.Send(&In{In: int64(i)}); err != nil { + return + } + if _, err := s.Recv(); err != nil { + return + } + } + }) + } + + time.Sleep(10 * time.Millisecond) + + // Fire all three shutdown mechanisms nearly simultaneously. + var wg sync.WaitGroup + wg.Add(3) + + go func() { + defer wg.Done() + for _, c := range cancels { + c() + } + }() + + go func() { + defer wg.Done() + _ = conn.Close() + }() + + go func() { + defer wg.Done() + _ = conn.Transport().Close() + }() + + // Must complete within timeout — no deadlock. + done := make(chan struct{}) + go func() { + wg.Wait() + close(done) + }() + + timer := time.NewTimer(5 * time.Second) + defer timer.Stop() + select { + case <-done: + case <-timer.C: + t.Fatal("triple-race shutdown deadlocked") + } +} + +// TestStress_ConcurrentUnary runs 100 goroutines each making 50 unary RPCs +// on one connection. Each unary call creates a stream, does the invoke +// handshake, sends request, receives response, and tears down, so this +// produces 5000 rapid stream lifecycles. Complements TestStress_BurstUnary +// below, which tests instantaneous burst contention rather than sustained +// throughput. +func TestStress_ConcurrentUnary(t *testing.T) { + defer goleak.VerifyNone(t) + + ctx := drpctest.NewTracker(t) + defer ctx.Close() + + conn := createRawConnection(t, impl{ + Method1Fn: func(ctx context.Context, in *In) (*Out, error) { + return &Out{Out: in.In}, nil + }, + }, ctx) + defer func() { _ = conn.Close() }() + cli := NewDRPCServiceClient(conn) + + const N = 100 + const M = 50 + + errs := make(chan error, N) + for i := 0; i < N; i++ { + i := i + ctx.Run(func(ctx context.Context) { + select { + case <-ctx.Done(): + case errs <- func() error { + for j := 0; j < M; j++ { + val := int64(i*1000 + j) + out, err := cli.Method1(ctx, &In{In: val}) + if err != nil { + return fmt.Errorf("goroutine %d call %d: %w", i, j, err) + } + if out.Out != val { + return fmt.Errorf("goroutine %d call %d: got %d want %d", i, j, out.Out, val) + } + } + return nil + }(): + } + }) + } + + for i := 0; i < N; i++ { + assert.NoError(t, <-errs) + } +} + +// TestStress_BurstUnary fires 1000 goroutines each making a single unary RPC +// simultaneously. All 1000 invokes hit the pdone channel at once, testing +// burst contention on the invoke handshake and the MuxWriter's ability to +// batch a thundering herd of WriteFrame calls. +func TestStress_BurstUnary(t *testing.T) { + defer goleak.VerifyNone(t) + + ctx := drpctest.NewTracker(t) + defer ctx.Close() + + conn := createRawConnection(t, impl{ + Method1Fn: func(ctx context.Context, in *In) (*Out, error) { + return &Out{Out: in.In}, nil + }, + }, ctx) + defer func() { _ = conn.Close() }() + cli := NewDRPCServiceClient(conn) + + const N = 1000 + + errs := make(chan error, N) + for i := 0; i < N; i++ { + i := i + ctx.Run(func(ctx context.Context) { + select { + case <-ctx.Done(): + case errs <- func() error { + val := int64(i) + out, err := cli.Method1(ctx, &In{In: val}) + if err != nil { + return fmt.Errorf("goroutine %d: %w", i, err) + } + if out.Out != val { + return fmt.Errorf("goroutine %d: got %d want %d", i, out.Out, val) + } + return nil + }(): + } + }) + } + + for i := 0; i < N; i++ { + assert.NoError(t, <-errs) + } +}