From 95375d1e96eb8620e0dcfd1f3d0e999ea52a9d4d Mon Sep 17 00:00:00 2001 From: Shubham Dhama Date: Mon, 23 Mar 2026 12:46:57 +0000 Subject: [PATCH] drpc: enable stream multiplexing Squashed result of the upstream stream-multiplexing branch. See the merged PRs for granular history: - #39 drpcmanager: fix race between manageReader and stream creation - #42 *: move frame assembly from reader to stream - #43 *: extract PacketAssembler for frame-to-packet assembly - #44 drpcmanager: replace manageStreams loop with per-stream goroutines - #45 *: use per-stream Finished signal instead of shared sfin channel - #46 drpcmanager: use atomic counter for client stream ID generation - #47 drpcmanager: replace streamBuffer with a streams registry - #51 drpc: enable stream multiplexing A connection now runs multiple concurrent client and server streams over a single transport. Frames carry stream IDs and are interleaved on the wire by a shared MuxWriter. Each stream owns its own packet ring buffer, Finished signal, and goroutine, and the manager tracks live streams in an activeStreams registry. New: drpcwire.MuxWriter, drpcwire.PacketAssembler, drpcstream.ringBuffer, drpcmanager.activeStreams. Removed: the drpccache package, drpcwire/writer (now MuxWriter), drpcstream/pktbuf (now ringBuffer), drpcmanager/streambuf (now activeStreams), drpcmanager.Options.InactivityTimeout, and the drpcconn shared write buffer plus stats infrastructure (CollectStats, Stats, drpcstats wiring). --- drpc.go | 8 - drpccache/README.md | 74 ---- drpccache/cache.go | 92 ----- drpccache/cache_test.go | 71 ---- drpccache/doc.go | 5 - drpcclient/dialoptions.go | 1 - drpcconn/conn.go | 91 +---- drpcconn/conn_test.go | 52 +-- drpcmanager/active_streams.go | 90 +++++ drpcmanager/active_streams_test.go | 113 ++++++ drpcmanager/manager.go | 484 +++++++++---------------- drpcmanager/manager_test.go | 489 +++++++++++++++++++------- drpcmanager/random_test.go | 27 +- drpcmanager/streambuf.go | 57 --- drpcserver/server.go | 27 +- drpcstream/pktbuf.go | 85 ----- drpcstream/ring_buffer.go | 116 ++++++ drpcstream/ring_buffer_test.go | 228 ++++++++++++ drpcstream/stream.go | 249 +++++-------- drpcstream/stream_test.go | 237 +++++++------ drpcwire/mux_writer.go | 94 +++++ drpcwire/mux_writer_test.go | 330 +++++++++++++++++ drpcwire/packet_assembler.go | 89 +++++ drpcwire/packet_assembler_test.go | 235 +++++++++++++ drpcwire/reader.go | 79 +---- drpcwire/reader_test.go | 322 +++++++---------- drpcwire/writer.go | 107 ------ drpcwire/writer_test.go | 32 -- internal/drpcopts/stream.go | 6 - internal/grpccompat/benchmark_test.go | 7 +- internal/grpccompat/common_test.go | 1 - internal/integration/cache_test.go | 63 ---- internal/integration/cancel_test.go | 10 - internal/integration/common_test.go | 7 +- internal/integration/simple_test.go | 174 ++++++--- 35 files changed, 2394 insertions(+), 1758 deletions(-) delete mode 100644 drpccache/README.md delete mode 100644 drpccache/cache.go delete mode 100644 drpccache/cache_test.go delete mode 100644 drpccache/doc.go create mode 100644 drpcmanager/active_streams.go create mode 100644 drpcmanager/active_streams_test.go delete mode 100644 drpcmanager/streambuf.go delete mode 100644 drpcstream/pktbuf.go create mode 100644 drpcstream/ring_buffer.go create mode 100644 drpcstream/ring_buffer_test.go create mode 100644 drpcwire/mux_writer.go create mode 100644 drpcwire/mux_writer_test.go create mode 100644 drpcwire/packet_assembler.go create mode 100644 drpcwire/packet_assembler_test.go delete mode 100644 drpcwire/writer.go delete mode 100644 drpcwire/writer_test.go delete mode 100644 internal/integration/cache_test.go diff --git a/drpc.go b/drpc.go index f6f92084..ed037037 100644 --- a/drpc.go +++ b/drpc.go @@ -76,14 +76,6 @@ type Stream interface { // received on it. Context() context.Context - // Kind returns the type of the stream ("unknown", "cli", or "srv"). Client - // and server streams must be treated differently for error handling and - // logging purposes. - // - // Client streams return Unavailable errors when the remote closes the - // connection, while server streams return Canceled errors. - Kind() StreamKind - // MsgSend sends the Message to the remote. MsgSend(msg Message, enc Encoding) error diff --git a/drpccache/README.md b/drpccache/README.md deleted file mode 100644 index 5739abcc..00000000 --- a/drpccache/README.md +++ /dev/null @@ -1,74 +0,0 @@ -# package drpccache - -`import "storj.io/drpc/drpccache"` - -Package drpccache implements per stream cache for drpc. - -## Usage - -#### func WithContext - -```go -func WithContext(parent context.Context, cache *Cache) context.Context -``` -WithContext returns a context with the value cache associated with the context. - -#### type Cache - -```go -type Cache struct { -} -``` - -Cache is a per stream cache. - -#### func FromContext - -```go -func FromContext(ctx context.Context) *Cache -``` -FromContext returns a cache from a context. - -Example usage: - - cache := drpccache.FromContext(stream.Context()) - if cache != nil { - value := cache.LoadOrCreate("initialized", func() (interface{}) { - return 42 - }) - } - -#### func New - -```go -func New() *Cache -``` -New returns a new cache. - -#### func (*Cache) Clear - -```go -func (cache *Cache) Clear() -``` -Clear clears the cache. - -#### func (*Cache) Load - -```go -func (cache *Cache) Load(key interface{}) interface{} -``` -Load returns the value with the given key. - -#### func (*Cache) LoadOrCreate - -```go -func (cache *Cache) LoadOrCreate(key interface{}, fn func() interface{}) interface{} -``` -LoadOrCreate returns the value with the given key. - -#### func (*Cache) Store - -```go -func (cache *Cache) Store(key, value interface{}) -``` -Store sets the value at a key. diff --git a/drpccache/cache.go b/drpccache/cache.go deleted file mode 100644 index 107b7775..00000000 --- a/drpccache/cache.go +++ /dev/null @@ -1,92 +0,0 @@ -// Copyright (C) 2020 Storj Labs, Inc. -// See LICENSE for copying information. - -package drpccache - -import ( - "context" - "sync" -) - -type cacheKey struct{} - -// Cache is a per stream cache. -type Cache struct { - mu sync.Mutex - values map[interface{}]interface{} -} - -// New returns a new cache. -func New() *Cache { return &Cache{} } - -// FromContext returns a cache from a context. -// -// Example usage: -// -// cache := drpccache.FromContext(stream.Context()) -// if cache != nil { -// value := cache.LoadOrCreate("initialized", func() (interface{}) { -// return 42 -// }) -// } -func FromContext(ctx context.Context) *Cache { - cache, _ := ctx.Value(cacheKey{}).(*Cache) - return cache -} - -// WithContext returns a context with the value cache associated with the context. -func WithContext(parent context.Context, cache *Cache) context.Context { - return context.WithValue(parent, cacheKey{}, cache) -} - -// init ensures that the values map exist. -func (cache *Cache) init() { - if cache.values == nil { - cache.values = map[interface{}]interface{}{} - } -} - -// Clear clears the cache. -func (cache *Cache) Clear() { - cache.mu.Lock() - defer cache.mu.Unlock() - - cache.values = nil -} - -// Store sets the value at a key. -func (cache *Cache) Store(key, value interface{}) { - cache.mu.Lock() - defer cache.mu.Unlock() - - cache.init() - cache.values[key] = value -} - -// Load returns the value with the given key. -func (cache *Cache) Load(key interface{}) interface{} { - cache.mu.Lock() - defer cache.mu.Unlock() - - if cache.values == nil { - return nil - } - - return cache.values[key] -} - -// LoadOrCreate returns the value with the given key. -func (cache *Cache) LoadOrCreate(key interface{}, fn func() interface{}) interface{} { - cache.mu.Lock() - defer cache.mu.Unlock() - - cache.init() - - value, ok := cache.values[key] - if !ok { - value = fn() - cache.values[key] = value - } - - return value -} diff --git a/drpccache/cache_test.go b/drpccache/cache_test.go deleted file mode 100644 index bd98dbf9..00000000 --- a/drpccache/cache_test.go +++ /dev/null @@ -1,71 +0,0 @@ -// Copyright (C) 2021 Storj Labs, Inc. -// See LICENSE for copying information. - -package drpccache - -import ( - "context" - "testing" - - "github.com/zeebo/assert" -) - -func TestWithContext(t *testing.T) { - ctx := context.Background() - assert.Nil(t, FromContext(ctx)) - - cache := New() - ctx = WithContext(ctx, cache) - assert.Equal(t, cache, FromContext(ctx)) -} - -func TestLoad(t *testing.T) { - cache := New() - - assert.Nil(t, cache.Load("key1")) - assert.Nil(t, cache.Load("key2")) - - cache.Store("key1", "val1") - - assert.Equal(t, cache.Load("key1"), "val1") - assert.Nil(t, cache.Load("key2")) - - cache.Store("key2", "val2") - - assert.Equal(t, cache.Load("key1"), "val1") - assert.Equal(t, cache.Load("key2"), "val2") -} - -func TestClear(t *testing.T) { - cache := New() - - cache.Store("key1", "val1") - cache.Store("key2", "val2") - - assert.Equal(t, cache.Load("key1"), "val1") - assert.Equal(t, cache.Load("key2"), "val2") - - cache.Clear() - - assert.Nil(t, cache.Load("key1")) - assert.Nil(t, cache.Load("key2")) -} - -func TestLoadOrCreate(t *testing.T) { - f := func(val interface{}) func() interface{} { - return func() interface{} { return val } - } - - cache := New() - - assert.Nil(t, cache.Load("key1")) - assert.Nil(t, cache.Load("key2")) - - assert.Equal(t, cache.LoadOrCreate("key1", f("key1")), "key1") - assert.Equal(t, cache.LoadOrCreate("key1", f("key2")), "key1") - - assert.Equal(t, cache.Load("key1"), "key1") - assert.Nil(t, cache.Load("key2")) - - cache.LoadOrCreate("key1", func() interface{} { panic("called") }) -} diff --git a/drpccache/doc.go b/drpccache/doc.go deleted file mode 100644 index 6926e5d4..00000000 --- a/drpccache/doc.go +++ /dev/null @@ -1,5 +0,0 @@ -// Copyright (C) 2019 Storj Labs, Inc. -// See LICENSE for copying information. - -// Package drpccache implements per stream cache for drpc. -package drpccache diff --git a/drpcclient/dialoptions.go b/drpcclient/dialoptions.go index 5054a83e..3c87b400 100644 --- a/drpcclient/dialoptions.go +++ b/drpcclient/dialoptions.go @@ -158,7 +158,6 @@ func DialContext( Stream: drpcstream.Options{ MaximumBufferSize: 0, // unlimited }, - SoftCancel: false, }, ShouldRecord: options.shouldRecord, Metrics: *options.metrics, diff --git a/drpcconn/conn.go b/drpcconn/conn.go index 8e10a714..23fbb546 100644 --- a/drpcconn/conn.go +++ b/drpcconn/conn.go @@ -5,7 +5,6 @@ package drpcconn import ( "context" - "sync" "github.com/zeebo/errs" grpcmetadata "google.golang.org/grpc/metadata" @@ -14,10 +13,8 @@ import ( "storj.io/drpc/drpcmanager" "storj.io/drpc/drpcmetadata" "storj.io/drpc/drpcmetrics" - "storj.io/drpc/drpcstats" "storj.io/drpc/drpcstream" "storj.io/drpc/drpcwire" - "storj.io/drpc/internal/drpcopts" ) // Options controls configuration settings for a conn. @@ -25,11 +22,6 @@ type Options struct { // Manager controls the options we pass to the manager of this conn. Manager drpcmanager.Options - // TODO: (server): deprecate this - // CollectStats controls whether the client should collect stats on the - // rpcs it creates. - CollectStats bool - // ShouldRecord, if non-nil, controls whether metrics are recorded. // When it returns true, the transport is wrapped to track bytes // sent and received. @@ -41,12 +33,8 @@ type Options struct { // Conn is a drpc client connection. type Conn struct { - tr drpc.Transport - man *drpcmanager.Manager - mu sync.Mutex - wbuf []byte - - stats map[string]*drpcstats.Stats // TODO (server): deprecate + tr drpc.Transport + man *drpcmanager.Manager } var _ drpc.Conn = (*Conn)(nil) @@ -71,42 +59,11 @@ func NewWithOptions(tr drpc.Transport, opts Options) *Conn { c.tr = mt } - // TODO: (server): deprecate - if opts.CollectStats { - drpcopts.SetManagerStatsCB(&opts.Manager.Internal, c.getStats) - c.stats = make(map[string]*drpcstats.Stats) - } - - c.man = drpcmanager.NewWithOptions(c.tr, opts.Manager) + c.man = drpcmanager.NewWithOptions(c.tr, drpcmanager.Client, opts.Manager) return c } -// Stats returns the collected stats grouped by rpc. -func (c *Conn) Stats() map[string]drpcstats.Stats { - c.mu.Lock() - defer c.mu.Unlock() - - stats := make(map[string]drpcstats.Stats, len(c.stats)) - for k, v := range c.stats { - stats[k] = v.AtomicClone() - } - return stats -} - -// getStats returns the drpcopts.Stats struct for the given rpc. -func (c *Conn) getStats(rpc string) *drpcstats.Stats { - c.mu.Lock() - defer c.mu.Unlock() - - stats := c.stats[rpc] - if stats == nil { - stats = new(drpcstats.Stats) - c.stats[rpc] = stats - } - return stats -} - // Transport returns the transport the conn is using. func (c *Conn) Transport() drpc.Transport { return c.tr } @@ -114,15 +71,15 @@ func (c *Conn) Transport() drpc.Transport { return c.tr } func (c *Conn) Closed() <-chan struct{} { return c.man.Closed() } // Unblocked returns a channel that is closed once the connection is no longer -// blocked by a previously canceled Invoke or NewStream call. It should not -// be called concurrently with Invoke or NewStream. +// blocked. With multiplexing, multiple streams run concurrently and this +// channel is always closed immediately. func (c *Conn) Unblocked() <-chan struct{} { return c.man.Unblocked() } // Close closes the connection. func (c *Conn) Close() (err error) { return c.man.Close() } // Invoke issues the rpc on the transport serializing in, waits for a response, and -// deserializes it into out. Only one Invoke or Stream may be open at a time. +// deserializes it into out. func (c *Conn) Invoke(ctx context.Context, rpc string, enc drpc.Encoding, in, out drpc.Message) (err error) { defer func() { err = drpc.ToRPCErr(err) }() @@ -138,30 +95,21 @@ func (c *Conn) Invoke(ctx context.Context, rpc string, enc drpc.Encoding, in, ou } defer func() { err = errs.Combine(err, stream.Close()) }() - // we have to protect c.wbuf here even though the manager only allows one - // stream at a time because the stream may async close allowing another - // concurrent call to Invoke to proceed. - c.mu.Lock() - defer c.mu.Unlock() - - c.wbuf, err = drpcenc.MarshalAppend(in, enc, c.wbuf[:0]) + // TODO: use buffer pool to reduce allocations + data, err := drpcenc.MarshalAppend(in, enc, nil) if err != nil { return err } - if err := c.doInvoke(stream, enc, rpc, c.wbuf, metadata, out); err != nil { + if err := c.doInvoke(stream, enc, rpc, data, metadata, out); err != nil { return err } return nil } func (c *Conn) doInvoke(stream *drpcstream.Stream, enc drpc.Encoding, rpc string, data []byte, metadata []byte, out drpc.Message) (err error) { - if len(metadata) > 0 { - if err := stream.RawWrite(drpcwire.KindInvokeMetadata, metadata); err != nil { - return err - } - } - if err := stream.RawWrite(drpcwire.KindInvoke, []byte(rpc)); err != nil { + defer func() { err = stream.CheckCancelError(err) }() + if err := stream.WriteInvoke(rpc, metadata); err != nil { return err } if err := stream.RawWrite(drpcwire.KindMessage, data); err != nil { @@ -176,8 +124,7 @@ func (c *Conn) doInvoke(stream *drpcstream.Stream, enc drpc.Encoding, rpc string return nil } -// NewStream begins a streaming rpc on the connection. Only one Invoke or Stream may -// be open at a time. +// NewStream begins a streaming rpc on the connection. func (c *Conn) NewStream(ctx context.Context, rpc string, enc drpc.Encoding) (_ drpc.Stream, err error) { defer func() { err = drpc.ToRPCErr(err) }() @@ -192,25 +139,13 @@ func (c *Conn) NewStream(ctx context.Context, rpc string, enc drpc.Encoding) (_ return nil, err } - if err := c.doNewStream(stream, rpc, metadata); err != nil { + if err := stream.WriteInvoke(rpc, metadata); err != nil { return nil, errs.Combine(err, stream.Close()) } return stream, nil } -func (c *Conn) doNewStream(stream *drpcstream.Stream, rpc string, metadata []byte) error { - if len(metadata) > 0 { - if err := stream.RawWrite(drpcwire.KindInvokeMetadata, metadata); err != nil { - return err - } - } - if err := stream.RawWrite(drpcwire.KindInvoke, []byte(rpc)); err != nil { - return err - } - return nil -} - // encodeMetadata retrieves and encodes metadata from the provided // (outgoing/client) context. func (c *Conn) encodeMetadata(ctx context.Context) (metadata []byte, err error) { diff --git a/drpcconn/conn_test.go b/drpcconn/conn_test.go index 9f74516e..aadda12b 100644 --- a/drpcconn/conn_test.go +++ b/drpcconn/conn_test.go @@ -40,31 +40,32 @@ func TestConn_InvokeFlushesSendClose(t *testing.T) { invokeDone := make(chan struct{}) ctx.Run(func(ctx context.Context) { - wr := drpcwire.NewWriter(ps, 64) + wr := drpcwire.NewMuxWriter(ps, nil) + defer func() { wr.Stop(nil); <-wr.Done() }() rd := drpcwire.NewReader(ps) - _, _ = rd.ReadPacket() // Invoke - _, _ = rd.ReadPacket() // Message - pkt, _ := rd.ReadPacket() // CloseSend + _, _ = rd.ReadFrame() // Invoke + _, _ = rd.ReadFrame() // Message + pkt, _ := rd.ReadFrame() // CloseSend - _ = wr.WritePacket(drpcwire.Packet{ + _ = wr.WriteFrame(drpcwire.Frame{ Data: []byte("qux"), ID: drpcwire.ID{Stream: pkt.ID.Stream, Message: 1}, Kind: drpcwire.KindMessage, + Done: true, }) - _ = wr.Flush() - _, _ = rd.ReadPacket() // Close - <-invokeDone // wait for invoke to return + _, _ = rd.ReadFrame() // Close + <-invokeDone // wait for invoke to return // ensure that any later packets are dropped by writing one // before closing the transport. - for i := 0; i < 5; i++ { - _ = wr.WritePacket(drpcwire.Packet{ + for range 5 { + _ = wr.WriteFrame(drpcwire.Frame{ ID: drpcwire.ID{Stream: pkt.ID.Stream, Message: 2}, Kind: drpcwire.KindCloseSend, + Done: true, }) - _ = wr.Flush() } _ = ps.Close() @@ -78,7 +79,7 @@ func TestConn_InvokeFlushesSendClose(t *testing.T) { invokeDone <- struct{}{} // signal invoke has returned - // we should eventually notice the transport is closed + // we should eventually notice the transport is closed due to ps.Close() select { case <-conn.Closed(): case <-time.After(1 * time.Second): @@ -95,10 +96,11 @@ func TestConn_InvokeSendsGrpcAndDrpcMetadata(t *testing.T) { defer func() { assert.NoError(t, ps.Close()) }() ctx.Run(func(ctx context.Context) { - wr := drpcwire.NewWriter(ps, 64) + wr := drpcwire.NewMuxWriter(ps, nil) + defer func() { wr.Stop(nil); <-wr.Done() }() rd := drpcwire.NewReader(ps) - md, err := rd.ReadPacket() // Metadata + md, err := rd.ReadFrame() // Metadata assert.NoError(t, err) assert.Equal(t, md.Kind, drpcwire.KindInvokeMetadata) metadata, err := drpcmetadata.Decode(md.Data) @@ -110,18 +112,18 @@ func TestConn_InvokeSendsGrpcAndDrpcMetadata(t *testing.T) { "common-key": "common-value2", }) - _, _ = rd.ReadPacket() // Invoke - _, _ = rd.ReadPacket() // Message - pkt, _ := rd.ReadPacket() // CloseSend + _, _ = rd.ReadFrame() // Invoke + _, _ = rd.ReadFrame() // Message + pkt, _ := rd.ReadFrame() // CloseSend - _ = wr.WritePacket(drpcwire.Packet{ + _ = wr.WriteFrame(drpcwire.Frame{ Data: []byte("qux"), ID: drpcwire.ID{Stream: pkt.ID.Stream, Message: 1}, Kind: drpcwire.KindMessage, + Done: true, }) - _ = wr.Flush() - _, _ = rd.ReadPacket() // Close + _, _ = rd.ReadFrame() // Close }) conn := New(pc) @@ -154,7 +156,7 @@ func TestConn_NewStreamSendsGrpcAndDrpcMetadata(t *testing.T) { ctx.Run(func(ctx context.Context) { rd := drpcwire.NewReader(ps) - md, err := rd.ReadPacket() // Metadata + md, err := rd.ReadFrame() // Metadata assert.NoError(t, err) assert.Equal(t, md.Kind, drpcwire.KindInvokeMetadata) metadata, err := drpcmetadata.Decode(md.Data) @@ -164,8 +166,8 @@ func TestConn_NewStreamSendsGrpcAndDrpcMetadata(t *testing.T) { "drpc-key": "drpc-value", }) - _, _ = rd.ReadPacket() // Invoke - _, _ = rd.ReadPacket() // CloseSend + _, _ = rd.ReadFrame() // Invoke + _, _ = rd.ReadFrame() // CloseSend }) conn := New(pc) @@ -181,6 +183,10 @@ func TestConn_NewStreamSendsGrpcAndDrpcMetadata(t *testing.T) { s, err := conn.NewStream(ctx, "/com.example.Foo/Bar", testEncoding{}) assert.NoError(t, err) _ = s.CloseSend() + + // Wait for the server goroutine to read all frames before defers + // close the pipe. With MuxWriter, writes are asynchronous. + ctx.Wait() } func TestConn_encodeMetadata(t *testing.T) { diff --git a/drpcmanager/active_streams.go b/drpcmanager/active_streams.go new file mode 100644 index 00000000..4ee7dc17 --- /dev/null +++ b/drpcmanager/active_streams.go @@ -0,0 +1,90 @@ +// Copyright (C) 2026 Cockroach Labs. +// See LICENSE for copying information. + +package drpcmanager + +import ( + "sync" + + "storj.io/drpc/drpcstream" +) + +// activeStreams is a thread-safe map of stream IDs to stream objects. +// It is used by the Manager to track active streams for lifecycle management. +type activeStreams struct { + mu sync.RWMutex + streams map[uint64]*drpcstream.Stream + closed bool + closeErr error +} + +func newActiveStreams() *activeStreams { + return &activeStreams{ + streams: make(map[uint64]*drpcstream.Stream), + } +} + +// Add adds a stream. It returns an error if the collection is closed or if a +// stream with the same ID already exists. +func (r *activeStreams) Add(id uint64, stream *drpcstream.Stream) error { + if stream == nil { + return managerClosed.New("stream can't be nil") + } + + r.mu.Lock() + defer r.mu.Unlock() + + if r.closed { + return r.closeErr + } + if _, ok := r.streams[id]; ok { + return managerClosed.New("duplicate stream id") + } + r.streams[id] = stream + return nil +} + +// Remove removes a stream. It is a no-op if the stream is not present or if +// the collection has been closed. +func (r *activeStreams) Remove(id uint64) { + r.mu.Lock() + defer r.mu.Unlock() + + if r.streams != nil { + delete(r.streams, id) + } +} + +// Get returns the stream for the given ID and whether it was found. +func (r *activeStreams) Get(id uint64) (*drpcstream.Stream, bool) { + r.mu.RLock() + defer r.mu.RUnlock() + + if r.closed { + return nil, false + } + s, ok := r.streams[id] + return s, ok +} + +// Close cancels all active streams with the given error, clears the +// collection, and marks it as closed to prevent future Add calls. +func (r *activeStreams) Close(err error) { + r.mu.Lock() + defer r.mu.Unlock() + + r.closed = true + r.closeErr = err + for id, s := range r.streams { + s.Cancel(err) + delete(r.streams, id) + } +} + +// Len returns the number of active streams. +func (r *activeStreams) Len() int { + r.mu.RLock() + defer r.mu.RUnlock() + + return len(r.streams) +} diff --git a/drpcmanager/active_streams_test.go b/drpcmanager/active_streams_test.go new file mode 100644 index 00000000..f463b188 --- /dev/null +++ b/drpcmanager/active_streams_test.go @@ -0,0 +1,113 @@ +// Copyright (C) 2026 Cockroach Labs. +// See LICENSE for copying information. + +package drpcmanager + +import ( + "context" + "errors" + "io" + "testing" + + "github.com/zeebo/assert" + + "storj.io/drpc/drpcstream" + "storj.io/drpc/drpcwire" +) + +func testMuxWriter(t *testing.T) *drpcwire.MuxWriter { + mw := drpcwire.NewMuxWriter(io.Discard, func(error) {}) + t.Cleanup(func() { mw.Stop(nil); <-mw.Done() }) + return mw +} + +func testStream(t *testing.T, id uint64) *drpcstream.Stream { + return drpcstream.New(context.Background(), id, testMuxWriter(t)) +} + +func TestActiveStreams_AddAndGet(t *testing.T) { + streams := newActiveStreams() + s := testStream(t, 1) + + assert.NoError(t, streams.Add(1, s)) + + got, ok := streams.Get(1) + assert.That(t, ok) + assert.Equal(t, got, s) +} + +func TestActiveStreams_GetMissing(t *testing.T) { + streams := newActiveStreams() + + got, ok := streams.Get(42) + assert.That(t, !ok) + assert.Nil(t, got) +} + +func TestActiveStreams_Remove(t *testing.T) { + streams := newActiveStreams() + s := testStream(t, 1) + + assert.NoError(t, streams.Add(1, s)) + assert.Equal(t, streams.Len(), 1) + + streams.Remove(1) + + _, ok := streams.Get(1) + assert.That(t, !ok) + assert.Equal(t, streams.Len(), 0) +} + +func TestActiveStreams_RemoveIdempotent(t *testing.T) { + streams := newActiveStreams() + + // must not panic when removing a non-existent ID + streams.Remove(99) +} + +func TestActiveStreams_DuplicateAdd(t *testing.T) { + streams := newActiveStreams() + s1 := testStream(t, 1) + s2 := testStream(t, 1) + + assert.NoError(t, streams.Add(1, s1)) + assert.Error(t, streams.Add(1, s2)) + + // original stream is still present + got, ok := streams.Get(1) + assert.That(t, ok) + assert.Equal(t, got, s1) +} + +func TestActiveStreams_AddAfterClose(t *testing.T) { + streams := newActiveStreams() + streams.Close(errors.New("closed")) + + err := streams.Add(1, testStream(t, 1)) + assert.Error(t, err) +} + +func TestActiveStreams_RemoveAfterClose(t *testing.T) { + streams := newActiveStreams() + s := testStream(t, 1) + assert.NoError(t, streams.Add(1, s)) + + streams.Close(errors.New("closed")) + + // must not panic + streams.Remove(1) +} + +func TestActiveStreams_Len(t *testing.T) { + streams := newActiveStreams() + assert.Equal(t, streams.Len(), 0) + + assert.NoError(t, streams.Add(1, testStream(t, 1))) + assert.Equal(t, streams.Len(), 1) + + assert.NoError(t, streams.Add(2, testStream(t, 2))) + assert.Equal(t, streams.Len(), 2) + + streams.Remove(1) + assert.Equal(t, streams.Len(), 1) +} diff --git a/drpcmanager/manager.go b/drpcmanager/manager.go index d7308366..6c77e174 100644 --- a/drpcmanager/manager.go +++ b/drpcmanager/manager.go @@ -10,8 +10,9 @@ import ( "io" "net" "strings" + "sync" + "sync/atomic" "syscall" - "time" "github.com/zeebo/errs" grpcmetadata "google.golang.org/grpc/metadata" @@ -29,30 +30,12 @@ var managerClosed = errs.Class("manager closed") // Options controls configuration settings for a manager. type Options struct { - // WriterBufferSize controls the size of the buffer that we will fill before - // flushing. Normal writes to streams typically issue a flush explicitly. - WriterBufferSize int - // Reader are passed to any readers the manager creates. Reader drpcwire.ReaderOptions // Stream are passed to any streams the manager creates. Stream drpcstream.Options - // SoftCancel controls if a context cancel will cause the transport to be - // closed or, if true, a soft cancel message will be attempted if possible. - // A soft cancel can reduce the amount of closed and dialed connections at - // the potential cost of higher latencies if there is latent data still - // being flushed when the cancel happens. - SoftCancel bool - - // InactivityTimeout is the amount of time the manager will wait when - // creating a NewServerStream. It only includes the time it is reading - // packets from the remote client. In other words, it only includes the time - // that the client could delay before invoking an RPC. If zero or negative, - // no timeout is used. - InactivityTimeout time.Duration - // Internal contains options that are for internal use only. Internal drpcopts.Manager @@ -68,65 +51,90 @@ type Options struct { // to the appropriate stream. type Manager struct { tr drpc.Transport - wr *drpcwire.Writer + wr *drpcwire.MuxWriter rd *drpcwire.Reader opts Options - sem drpcsignal.Chan // held by the active stream - sbuf streamBuffer // largest stream id created - pkts chan drpcwire.Packet // channel for invoke packets - pdone drpcsignal.Chan // signals when a packets buffers can be reused - sfin chan struct{} // shared signal for stream finished - streams chan streamInfo // channel to signal that a stream should start + // next client stream ID, incremented atomically + lastStreamID atomic.Uint64 + + wg sync.WaitGroup // tracks active manageStream goroutines + + // streams tracks active streams. + streams *activeStreams + + pdone drpcsignal.Chan // signals when NewServerStream has registered the new stream + invokes chan invokeInfo // completed invoke info from manageReader to NewServerStream + + // pendingStreams is owned by the manageReader goroutine, used in + // handleInvokeFrame. It tracks streams that are being assembled from + // invoke/metadata frames but haven't been fully created yet. + pendingStreams map[uint64]*pendingStream sigs struct { - term drpcsignal.Signal // set when the manager should start terminating - stream drpcsignal.Signal // set when the manage streams goroutine is done - read drpcsignal.Signal // set after the goroutine reading from the transport is done - tport drpcsignal.Signal // set after the transport has been closed + term drpcsignal.Signal // set when the manager should start terminating + read drpcsignal.Signal // set after the goroutine reading from the transport is done + tport drpcsignal.Signal // set after the transport has been closed } + + kind ManagerKind } -type streamInfo struct { - ctx context.Context - stream *drpcstream.Stream +type ManagerKind uint8 + +const ( + _ ManagerKind = iota + Client + Server +) + +// pendingStream accumulates invoke and metadata frames for a stream that is +// being set up but hasn't been fully created yet. Once the invoke packet +// arrives, the pending stream is forwarded to NewServerStream. +type pendingStream struct { + metadata map[string]string // accumulated invoke metadata + pa drpcwire.PacketAssembler // assembles invoke/metadata frames into packets +} + +// invokeInfo carries the assembled invoke data from manageReader to +// NewServerStream. It is reused across invocations; call Reset between uses. +type invokeInfo struct { + sid uint64 + metadata map[string]string + data []byte // RPC name bytes from the KindInvoke packet } // New returns a new Manager for the transport. -func New(tr drpc.Transport) *Manager { - return NewWithOptions(tr, Options{}) +func New(tr drpc.Transport, kind ManagerKind) *Manager { + return NewWithOptions(tr, kind, Options{}) } // NewWithOptions returns a new manager for the transport. It uses the provided // options to manage details of how it uses it. -func NewWithOptions(tr drpc.Transport, opts Options) *Manager { +func NewWithOptions(tr drpc.Transport, kind ManagerKind, opts Options) *Manager { m := &Manager{ tr: tr, - wr: drpcwire.NewWriter(tr, opts.WriterBufferSize), rd: drpcwire.NewReaderWithOptions(tr, opts.Reader), opts: opts, - pkts: make(chan drpcwire.Packet), - sfin: make(chan struct{}, 1), - streams: make(chan streamInfo), + invokes: make(chan invokeInfo), + kind: kind, } - // initialize the stream buffer - m.sbuf.init() - - // this semaphore controls the number of concurrent streams. it MUST be 1. - m.sem.Make(1) + m.wr = drpcwire.NewMuxWriter(tr, m.terminate) - // a buffer of size 1 allows the consumer of the packet to signal it is done - // without having to coordinate with the sender of the packet. + // a buffer of size 1 allows NewServerStream to signal it is done creating a + // new server stream without having to coordinate with manageReader. m.pdone.Make(1) + m.pendingStreams = make(map[uint64]*pendingStream) + + m.streams = newActiveStreams() + // set the internal stream options drpcopts.SetStreamTransport(&m.opts.Stream.Internal, m.tr) - drpcopts.SetStreamFin(&m.opts.Stream.Internal, m.sfin) go m.manageReader() - go m.manageStreams() return m } @@ -140,72 +148,23 @@ func (m *Manager) log(what string, cb func() string) { } } -// -// helpers -// - -// acquireSemaphore attempts to acquire the semaphore protecting streams. If the -// context is canceled or the manager is terminated, it returns an error. -func (m *Manager) acquireSemaphore(ctx context.Context) error { - if err, ok := m.sigs.term.Get(); ok { - return err - } else if err := ctx.Err(); err != nil { - return err - } - - select { - case <-ctx.Done(): - return ctx.Err() - - case <-m.sigs.term.Signal(): - return m.sigs.term.Err() - - case m.sem.Get() <- struct{}{}: - if err := m.waitForPreviousStream(ctx); err != nil { - m.sem.Recv() - return err - } - return nil - } -} - -// waitForPreviousStream will, if there was a previous stream, ensure it is -// Closed and then wait until it is in the Finished state, where it will no -// longer make any reads or writes on the transport. It exits early if the -// context is canceled or the manager is terminated. -func (m *Manager) waitForPreviousStream(ctx context.Context) (err error) { - prev := m.sbuf.Get() - if prev == nil { - return nil - } - - // if the stream is not finished yet, we need to wait for it to be - // finished before letting the next stream to start. - if prev.IsFinished() { - return nil - } - - m.log("WAIT", prev.String) - - select { - case <-ctx.Done(): - return ctx.Err() - - case <-m.sigs.term.Signal(): - return m.sigs.term.Err() - - case <-prev.Finished(): - return nil - } -} - // terminate puts the Manager into a terminal state and closes any resources -// that need to be closed to signal the state change. +// that need to be closed to signal the state change. The mux writer is stopped +// before closing the transport so that WriteFrame immediately rejects new +// writes; the subsequent transport close unblocks any in-flight Write in the +// drain goroutine. func (m *Manager) terminate(err error) { if m.sigs.term.Set(err) { m.log("TERM", func() string { return fmt.Sprint(err) }) + if errors.Is(err, io.EOF) { + err = context.Canceled + if m.kind == Client { + err = drpc.ClosedError.New("connection closed") + } + } + m.wr.Stop(err) m.sigs.tport.Set(m.tr.Close()) - m.sbuf.Close() + m.streams.Close(err) } } @@ -213,27 +172,15 @@ func (m *Manager) terminate(err error) { // manage reader // -// manageReader is always reading a packet and dispatching it to the appropriate -// stream or queue. It sets the read signal when it exits so that one can wait -// to ensure that no one is reading on the reader. It sets the term signal if -// there is any error reading packets. +// manageReader reads the frame and dispatches them to the appropriate stream or +// queue. It sets the read signal when it exits so that one can wait to ensure +// that no one is reading on the reader. It sets the term signal if there is any +// error reading frames. func (m *Manager) manageReader() { defer m.sigs.read.Set(nil) - var pkt drpcwire.Packet - var err error - var run int - for !m.sigs.term.IsSet() { - // if we have a run of "small" packets, drop the buffer to release - // memory so that a burst of large packets does not cause eternally - // large heap usage. - if run > 10 { - pkt.Data = nil - run = 0 - } - - pkt, err = m.rd.ReadPacketUsing(pkt.Data[:0]) + incomingFrame, err := m.rd.ReadFrame() if err != nil { if isConnectionReset(err) { err = drpc.ClosedError.Wrap(err) @@ -242,55 +189,70 @@ func (m *Manager) manageReader() { return } - if len(pkt.Data) < cap(pkt.Data)/4 { - run++ - } else { - run = 0 - } + m.log("READ", incomingFrame.String) - m.log("READ", pkt.String) + stream, ok := m.streams.Get(incomingFrame.ID.Stream) - again: - switch curr := m.sbuf.Get(); { - // if the packet is for the current stream, deliver it. - case curr != nil && pkt.ID.Stream == curr.ID(): - if err := curr.HandlePacket(pkt); err != nil { + switch { + // if the packet is for an active stream, deliver it. + case ok: + if err := stream.HandleFrame(incomingFrame); err != nil { m.terminate(managerClosed.Wrap(err)) return } - // if an old message has been sent, just ignore it. - case curr != nil && pkt.ID.Stream < curr.ID(): - - // if any invoke sequence is being sent, close any old unterminated - // stream and forward it to be handled. - case pkt.Kind == drpcwire.KindInvoke || pkt.Kind == drpcwire.KindInvokeMetadata: - if curr != nil && !curr.IsTerminated() { - curr.Cancel(context.Canceled) - } - - select { - case m.pkts <- pkt: - m.pdone.Recv() - - case <-m.sigs.term.Signal(): + case incomingFrame.Kind == drpcwire.KindInvoke || incomingFrame.Kind == drpcwire.KindInvokeMetadata: + if err := m.handleInvokeFrame(incomingFrame); err != nil { + m.terminate(managerClosed.Wrap(err)) return } - // a non-invoke packet should be delivered to some stream so we wait for - // a new stream to be created and try again. like an invoke, we - // implicitly close any previous stream. + // silently drop packet for an unknown stream default: - if curr != nil && !curr.IsTerminated() { - curr.Cancel(context.Canceled) - } + m.log("DROP", incomingFrame.String) + } + } +} - if !m.sbuf.Wait(curr.ID()) { - return - } - goto again +// handleInvokeFrame assembles invoke/metadata frames into complete packets and +// forwards the finished invoke info to NewServerStream. Metadata packets are +// accumulated; the invoke packet triggers the send. +func (m *Manager) handleInvokeFrame(fr drpcwire.Frame) error { + ps, ok := m.pendingStreams[fr.ID.Stream] + if !ok { + ps = &pendingStream{pa: drpcwire.NewPacketAssembler()} + m.pendingStreams[fr.ID.Stream] = ps + } + pkt, packetReady, err := ps.pa.AppendFrame(fr) + if err != nil { + return err + } + if !packetReady { + return nil + } + + // Metadata arrives before invoke; accumulate it and wait for the invoke. + if pkt.Kind == drpcwire.KindInvokeMetadata { + meta, err := drpcmetadata.Decode(pkt.Data) + if err != nil { + return err } + ps.metadata = meta + return nil + } + + // Invoke packet completes the sequence. Send to NewServerStream. + select { + case m.invokes <- invokeInfo{sid: pkt.ID.Stream, data: pkt.Data, metadata: ps.metadata}: + // Wait for NewServerStream to finish stream creation before reading the + // next frame. This guarantees curr is set for subsequent non-invoke + // packets. + m.pdone.Recv() + // TODO: reuse pending stream + delete(m.pendingStreams, fr.ID.Stream) + case <-m.sigs.term.Signal(): } + return nil } // @@ -307,88 +269,36 @@ func (m *Manager) newStream(ctx context.Context, sid uint64, kind drpc.StreamKin } stream := drpcstream.NewWithOptions(ctx, sid, m.wr, opts) - select { - case m.streams <- streamInfo{ctx: ctx, stream: stream}: - m.sbuf.Set(stream) - m.log("STREAM", stream.String) - return stream, nil - case <-m.sigs.term.Signal(): - return nil, m.sigs.term.Err() + if err := m.streams.Add(sid, stream); err != nil { + return nil, err } -} -// manageStreams reads from the streams channel for stream infos and runs the -// manageStream function on them. -func (m *Manager) manageStreams() { - defer m.sigs.stream.Set(nil) + m.wg.Add(1) + go m.manageStream(ctx, stream) - for { - select { - case si := <-m.streams: - m.manageStream(si.ctx, si.stream) + m.log("STREAM", stream.String) - case <-m.sigs.term.Signal(): - return - } - } + return stream, nil } // manageStream watches the context and the stream and returns when the stream // is finished, canceling the stream if the context is canceled. func (m *Manager) manageStream(ctx context.Context, stream *drpcstream.Stream) { + defer m.wg.Done() + defer m.streams.Remove(stream.ID()) select { - case <-m.sigs.term.Signal(): - err := m.sigs.term.Err() - if errors.Is(err, io.EOF) { - err = context.Canceled - if stream.Kind() == drpc.StreamKindClient { - err = drpc.ClosedError.New("connection closed") - } - } - stream.Cancel(err) - <-m.sfin - m.sem.Recv() - - case <-m.sfin: - m.sem.Recv() + case <-stream.Finished(): case <-ctx.Done(): m.log("CANCEL", stream.String) - if m.opts.SoftCancel { - // allow a new stream to begin. - m.sem.Recv() - - // attempt to send the soft cancel. if it fails or if the stream is - // busy sending something else, then we have to hard cancel. - if busy, err := stream.SendCancel(ctx.Err()); err != nil { - m.terminate(err) - } else if busy { - m.log("BUSY", stream.String) - m.terminate(ctx.Err()) - } - stream.Cancel(ctx.Err()) - - // wait for the stream to signal that it is finished. - <-m.sfin - } else { - // If the stream isn't already finished, we have to terminate the - // transport to do an active cancel. If it is already finished, - // there is no need. - if !stream.Cancel(ctx.Err()) { - m.log("UNFIN", stream.String) - m.terminate(ctx.Err()) - } else { - m.log("CLEAN", stream.String) - } - - // wait for the stream to signal that it is finished. - <-m.sfin - - // allow a new stream to begin. - m.sem.Recv() + if err := stream.SendCancel(ctx.Err()); err != nil { + // SendCancel can fail if it's an IO error which reader will catch. + m.log("SendCancel", func() string { return fmt.Sprintf("%s: %s", stream.String(), err) }) } + stream.Cancel(ctx.Err()) + <-stream.Finished() } } @@ -402,14 +312,12 @@ func (m *Manager) Closed() <-chan struct{} { } // Unblocked returns a channel that is closed when the manager is no longer -// blocked from creating a new stream due to a previous stream's soft cancel. It -// should not be called concurrently with NewClientStream or NewServerStream and -// the return result is only valid until the next call to NewClientStream or -// NewServerStream. +// blocked. With multiplexing, multiple streams run concurrently and this +// channel is always closed immediately. +// +// TODO(shubham): audit whether this is still useful in a multiplexing world. +// The only meaningful caller is Pool.Take. func (m *Manager) Unblocked() <-chan struct{} { - if prev := m.sbuf.Get(); prev != nil { - return prev.Context().Done() - } return closedCh } @@ -417,8 +325,9 @@ func (m *Manager) Unblocked() <-chan struct{} { func (m *Manager) Close() error { m.terminate(managerClosed.New("Close called")) - m.sigs.stream.Wait() - m.sigs.read.Wait() + <-m.wr.Done() // wait for writer goroutine to exit + m.wg.Wait() // wait for all stream goroutines + m.sigs.read.Wait() // wait for reader goroutine to exit m.sigs.tport.Wait() return m.sigs.tport.Err() @@ -426,89 +335,46 @@ func (m *Manager) Close() error { // NewClientStream starts a stream on the managed transport for use by a client. func (m *Manager) NewClientStream(ctx context.Context, rpc string) (stream *drpcstream.Stream, err error) { - if err := m.acquireSemaphore(ctx); err != nil { + if err := ctx.Err(); err != nil { return nil, err } - - return m.newStream(ctx, m.sbuf.Get().ID()+1, drpc.StreamKindClient, rpc) + return m.newStream(ctx, m.lastStreamID.Add(1), drpc.StreamKindClient, rpc) } // NewServerStream starts a stream on the managed transport for use by a server. // It does this by waiting for the client to issue an invoke message and // returning the details. func (m *Manager) NewServerStream(ctx context.Context) (stream *drpcstream.Stream, rpc string, err error) { - if err := m.acquireSemaphore(ctx); err != nil { - return nil, "", err - } - defer func() { - if err != nil { - m.sem.Recv() - } - }() - - var meta map[string]string - var metaID uint64 - var timeoutCh <-chan time.Time - - // set up the timeout on the context if necessary. - if timeout := m.opts.InactivityTimeout; timeout > 0 { - timer := time.NewTimer(timeout) - defer timer.Stop() - timeoutCh = timer.C - } - - for { - select { - case <-timeoutCh: - return nil, "", context.DeadlineExceeded - - case <-ctx.Done(): - return nil, "", ctx.Err() - - case <-m.sigs.term.Signal(): - return nil, "", m.sigs.term.Err() - - case pkt := <-m.pkts: - switch pkt.Kind { - // keep track of any metadata being sent before an invoke so that we - // can include it if the stream id matches the eventual invoke. - case drpcwire.KindInvokeMetadata: - meta, err = drpcmetadata.Decode(pkt.Data) - m.pdone.Send() + select { + case <-ctx.Done(): + return nil, "", ctx.Err() - if err != nil { - return nil, "", err - } - metaID = pkt.ID.Stream - - case drpcwire.KindInvoke: - rpc = string(pkt.Data) - m.pdone.Send() - - if metaID == pkt.ID.Stream { - if m.opts.GRPCMetadataCompatMode { - // Populate incoming metadata as grpc metadata in the - // context. This is a short-term fix that will enable us - // to send and receive grpc metadata when DRPC is enabled, - // without any changes in the calling code. - grpcMeta := make(map[string][]string, len(meta)) - for k, v := range meta { - grpcMeta[k] = []string{v} - } - ctx = grpcmetadata.NewIncomingContext(ctx, grpcMeta) - } else { - // Add metadata to the incoming context. - ctx = drpcmetadata.NewIncomingContext(ctx, meta) - } + case <-m.sigs.term.Signal(): + return nil, "", m.sigs.term.Err() + + case pkt := <-m.invokes: + rpc = string(pkt.data) + if pkt.metadata != nil { + if m.opts.GRPCMetadataCompatMode { + // Populate incoming metadata as grpc metadata in the + // context. This is a short-term fix that will enable us + // to send and receive grpc metadata when DRPC is enabled, + // without any changes in the calling code. + grpcMeta := make(map[string][]string, len(pkt.metadata)) + for k, v := range pkt.metadata { + grpcMeta[k] = []string{v} } - stream, err := m.newStream(ctx, pkt.ID.Stream, drpc.StreamKindServer, rpc) - return stream, rpc, err - - default: - // this should never happen, but defensive. - m.pdone.Send() + ctx = grpcmetadata.NewIncomingContext(ctx, grpcMeta) + } else { + // Add metadata to the incoming context. + ctx = drpcmetadata.NewIncomingContext(ctx, pkt.metadata) } } + stream, err := m.newStream(ctx, pkt.sid, drpc.StreamKindServer, rpc) + // Signal pdone only after adding the stream so that manageReader sees + // the new stream in activeStreams when it reads the next frame. + m.pdone.Send() + return stream, rpc, err } } diff --git a/drpcmanager/manager_test.go b/drpcmanager/manager_test.go index 5918113d..00ccd866 100644 --- a/drpcmanager/manager_test.go +++ b/drpcmanager/manager_test.go @@ -8,38 +8,18 @@ import ( "errors" "io" "net" - "sync" "testing" "time" "github.com/zeebo/assert" grpcmetadata "google.golang.org/grpc/metadata" + "storj.io/drpc" "storj.io/drpc/drpcmetadata" "storj.io/drpc/drpctest" "storj.io/drpc/drpcwire" ) -func closed(ch <-chan struct{}) bool { - select { - case <-ch: - return true - default: - return false - } -} - -func TestTimeout(t *testing.T) { - tr := make(blockingTransport) - man := NewWithOptions(tr, Options{ - InactivityTimeout: time.Millisecond, - }) - defer func() { _ = man.Close() }() - - _, _, err := man.NewServerStream(context.Background()) - assert.That(t, errors.Is(err, context.DeadlineExceeded)) -} - func TestDrpcMetadata(t *testing.T) { ctx := drpctest.NewTracker(t) defer ctx.Close() @@ -48,10 +28,10 @@ func TestDrpcMetadata(t *testing.T) { defer func() { _ = cconn.Close() }() defer func() { _ = sconn.Close() }() - cman := New(cconn) + cman := New(cconn, Client) defer func() { _ = cman.Close() }() - sman := NewWithOptions(sconn, Options{ + sman := NewWithOptions(sconn, Server, Options{ GRPCMetadataCompatMode: false, }) defer func() { _ = sman.Close() }() @@ -68,11 +48,7 @@ func TestDrpcMetadata(t *testing.T) { assert.NoError(t, stream.RawWrite(drpcwire.KindInvokeMetadata, buf)) assert.NoError(t, stream.RawWrite(drpcwire.KindInvoke, []byte("invoke"))) assert.NoError(t, stream.RawWrite(drpcwire.KindMessage, []byte("message"))) - assert.NoError(t, stream.RawFlush()) - assert.That(t, !closed(cman.Unblocked())) - assert.NoError(t, stream.Close()) - assert.That(t, closed(cman.Unblocked())) }) ctx.Run(func(ctx context.Context) { @@ -108,10 +84,10 @@ func TestDrpcMetadataWithGRPCMetadataCompatMode(t *testing.T) { defer func() { _ = cconn.Close() }() defer func() { _ = sconn.Close() }() - cman := New(cconn) + cman := New(cconn, Client) defer func() { _ = cman.Close() }() - sman := NewWithOptions(sconn, Options{ + sman := NewWithOptions(sconn, Server, Options{ GRPCMetadataCompatMode: true, }) defer func() { _ = sman.Close() }() @@ -128,11 +104,7 @@ func TestDrpcMetadataWithGRPCMetadataCompatMode(t *testing.T) { assert.NoError(t, stream.RawWrite(drpcwire.KindInvokeMetadata, buf)) assert.NoError(t, stream.RawWrite(drpcwire.KindInvoke, []byte("invoke"))) assert.NoError(t, stream.RawWrite(drpcwire.KindMessage, []byte("message"))) - assert.NoError(t, stream.RawFlush()) - assert.That(t, !closed(cman.Unblocked())) - assert.NoError(t, stream.Close()) - assert.That(t, closed(cman.Unblocked())) }) ctx.Run(func(ctx context.Context) { @@ -161,13 +133,44 @@ func TestDrpcMetadataWithGRPCMetadataCompatMode(t *testing.T) { ctx.Wait() } -type blockingTransport chan struct{} +// writeFrames serializes the given frames and writes them to w. +func writeFrames(t *testing.T, w io.Writer, frames ...drpcwire.Frame) { + t.Helper() + var buf []byte + for _, fr := range frames { + buf = drpcwire.AppendFrame(buf, fr) + } + _, err := w.Write(buf) + assert.NoError(t, err) +} + +// createFrame is a shorthand for constructing a Frame. +func createFrame(kind drpcwire.Kind, sid, mid uint64, data string, done bool) drpcwire.Frame { + return drpcwire.Frame{ + ID: drpcwire.ID{Stream: sid, Message: mid}, + Kind: kind, + Data: []byte(data), + Done: done, + } +} + +// waitForClosed blocks until the manager terminates or the timeout expires. +func waitForClosed(t *testing.T, man *Manager) { + t.Helper() + select { + case <-man.Closed(): + case <-time.After(5 * time.Second): + t.Fatal("manager did not terminate in time") + } +} -func (b blockingTransport) Read(p []byte) (n int, err error) { <-b; return 0, io.EOF } -func (b blockingTransport) Write(p []byte) (n int, err error) { <-b; return 0, io.EOF } -func (b blockingTransport) Close() error { close(b); return nil } +// +// manageReader tests +// -func TestUnblocked_NoCancel(t *testing.T) { +// Global frame monotonicity: a frame with an ID lower than the last seen +// frame causes the manager to terminate with a protocol error. +func TestManageReader_GlobalMonotonicity_SameStream(t *testing.T) { ctx := drpctest.NewTracker(t) defer ctx.Close() @@ -175,123 +178,355 @@ func TestUnblocked_NoCancel(t *testing.T) { defer func() { _ = cconn.Close() }() defer func() { _ = sconn.Close() }() - cman := New(cconn) - defer func() { _ = cman.Close() }() - - sman := New(sconn) - defer func() { _ = sman.Close() }() + man := New(sconn, Server) + defer func() { _ = man.Close() }() + // Consume the invoke and drain messages so HandleFrame doesn't block. ctx.Run(func(ctx context.Context) { - stream, err := cman.NewClientStream(ctx, "rpc") + stream, _, err := man.NewServerStream(ctx) assert.NoError(t, err) - defer func() { _ = stream.Close() }() + for { + if _, err := stream.RawRecv(); err != nil { + return + } + } + }) - assert.NoError(t, stream.RawWrite(drpcwire.KindInvoke, []byte("invoke"))) - assert.NoError(t, stream.RawWrite(drpcwire.KindMessage, []byte("message"))) - assert.NoError(t, stream.RawFlush()) - assert.That(t, !closed(cman.Unblocked())) + writeFrames(t, cconn, + createFrame(drpcwire.KindInvoke, 1, 1, "rpc", true), + createFrame(drpcwire.KindMessage, 1, 5, "ok", true), + createFrame(drpcwire.KindMessage, 1, 4, "bad", true), + ) - assert.NoError(t, stream.Close()) - assert.That(t, closed(cman.Unblocked())) + waitForClosed(t, man) +} + +// Invoke replay: after [s1,m1,invoke,done=true], lastFrameID is bumped to +// {1,2}. A replayed [s1,m1,invoke] is caught by the monotonicity check. +func TestManageReader_InvokeReplayBlocked(t *testing.T) { + ctx := drpctest.NewTracker(t) + defer ctx.Close() + + cconn, sconn := net.Pipe() + defer func() { _ = cconn.Close() }() + defer func() { _ = sconn.Close() }() + + man := New(sconn, Server) + defer func() { _ = man.Close() }() + + ctx.Run(func(ctx context.Context) { + _, _, _ = man.NewServerStream(ctx) }) + writeFrames(t, cconn, + createFrame(drpcwire.KindInvoke, 1, 1, "rpc", true), + createFrame(drpcwire.KindInvoke, 1, 1, "rpc", true), + ) + + waitForClosed(t, man) +} + +// Non-done frames don't bump the message ID, so continuation frames with +// the same ID are accepted. +func TestManageReader_ContinuationFramesAccepted(t *testing.T) { + ctx := drpctest.NewTracker(t) + defer ctx.Close() + + cconn, sconn := net.Pipe() + defer func() { _ = cconn.Close() }() + defer func() { _ = sconn.Close() }() + + man := New(sconn, Server) + defer func() { _ = man.Close() }() + + recv := make(chan []byte, 1) ctx.Run(func(ctx context.Context) { - stream, _, err := sman.NewServerStream(ctx) + stream, _, err := man.NewServerStream(ctx) assert.NoError(t, err) - defer func() { _ = stream.Close() }() - - _, err = stream.RawRecv() + data, err := stream.RawRecv() assert.NoError(t, err) - - _, err = stream.RawRecv() - assert.That(t, errors.Is(err, io.EOF)) + recv <- data }) - ctx.Wait() + writeFrames(t, cconn, + createFrame(drpcwire.KindInvoke, 1, 1, "rpc", true), + createFrame(drpcwire.KindMessage, 1, 2, "hel", false), + createFrame(drpcwire.KindMessage, 1, 2, "lo", true), + ) + + assert.DeepEqual(t, <-recv, []byte("hello")) } -func TestUnblocked_SoftCancel(t *testing.T) { - run := func(t *testing.T, softCancel bool) { - ctx := drpctest.NewTracker(t) - defer ctx.Close() - - tr := newBlockedTransport() - man := NewWithOptions(tr, Options{SoftCancel: softCancel}) - defer func() { _ = man.Close() }() - defer tr.setReadOpen(true) - defer tr.setWriteOpen(true) - - for i := 0; i < 10; i++ { - func() { - subctx, cancel := context.WithCancel(ctx) - defer cancel() - - stream, err := man.NewClientStream(subctx, "rpc") - if softCancel { - assert.NoError(t, err) - } else if i > 0 { - assert.Error(t, err) - return - } - defer func() { _ = stream.Close() }() - - assert.That(t, !closed(man.Unblocked())) - cancel() - - // temporary unblock writing to allow the stream to finish soft cancel - tr.setWriteOpen(true) - <-man.Unblocked() - tr.setWriteOpen(false) - }() +// Old-stream frames are silently ignored when the stream has been cancelled +// and removed from the registry. +func TestManageReader_OldStreamFramesIgnored(t *testing.T) { + ctx := drpctest.NewTracker(t) + defer ctx.Close() + + cconn, sconn := net.Pipe() + defer func() { _ = cconn.Close() }() + defer func() { _ = sconn.Close() }() + + cman := New(cconn, Client) + defer func() { _ = cman.Close() }() + + // Drain all client writes so nothing blocks. + ctx.Run(func(ctx context.Context) { + buf := make([]byte, 4096) + for { + _, err := sconn.Read(buf) + if err != nil { + return + } } - } + }) - t.Run("Enabled", func(t *testing.T) { run(t, true) }) - t.Run("Disabled", func(t *testing.T) { run(t, false) }) + // Create stream 1 on the client, then cancel it so it's removed + // from the registry. + subctx, cancel := context.WithCancel(ctx) + stream1, err := cman.NewClientStream(subctx, "rpc1") + assert.NoError(t, err) + cancel() + <-stream1.Finished() + + stream2, err := cman.NewClientStream(ctx, "rpc2") + assert.NoError(t, err) + + // Write an old-stream frame (s1) then the real response for s2. + // The s1 frame should be silently ignored by the client manager. + writeFrames(t, sconn, + createFrame(drpcwire.KindMessage, 1, 1, "old", true), + createFrame(drpcwire.KindMessage, 2, 1, "new", true), + ) + + data, err := stream2.RawRecv() + assert.NoError(t, err) + assert.DeepEqual(t, data, []byte("new")) + + _ = stream2.Close() } -type blockedTransport struct { - mu *sync.Mutex - co *sync.Cond - ro bool - wo bool +// A valid invoke sequence: Invoke → Message. +// Metadata encoding is covered separately by TestDrpcMetadata. +func TestManageReader_ValidInvokeSequence(t *testing.T) { + ctx := drpctest.NewTracker(t) + defer ctx.Close() + + cconn, sconn := net.Pipe() + defer func() { _ = cconn.Close() }() + defer func() { _ = sconn.Close() }() + + man := New(sconn, Server) + defer func() { _ = man.Close() }() + + recv := make(chan []byte, 1) + ctx.Run(func(ctx context.Context) { + stream, rpc, err := man.NewServerStream(ctx) + assert.NoError(t, err) + assert.Equal(t, rpc, "myrpc") + + data, err := stream.RawRecv() + assert.NoError(t, err) + recv <- data + }) + + writeFrames(t, cconn, + createFrame(drpcwire.KindInvoke, 1, 1, "myrpc", true), + createFrame(drpcwire.KindMessage, 1, 2, "payload", true), + ) + + assert.DeepEqual(t, <-recv, []byte("payload")) } -func newBlockedTransport() *blockedTransport { - mu := new(sync.Mutex) - co := sync.NewCond(mu) - return &blockedTransport{ - mu: mu, - co: co, - } +// Multi-frame message delivered through manager to stream: frames are +// assembled by the stream into a single packet. +func TestManageReader_MultiFrameDelivery(t *testing.T) { + ctx := drpctest.NewTracker(t) + defer ctx.Close() + + cconn, sconn := net.Pipe() + defer func() { _ = cconn.Close() }() + defer func() { _ = sconn.Close() }() + + man := New(sconn, Server) + defer func() { _ = man.Close() }() + + recv := make(chan []byte, 1) + ctx.Run(func(ctx context.Context) { + stream, _, err := man.NewServerStream(ctx) + assert.NoError(t, err) + + data, err := stream.RawRecv() + assert.NoError(t, err) + recv <- data + }) + + writeFrames(t, cconn, + createFrame(drpcwire.KindInvoke, 1, 1, "rpc", true), + createFrame(drpcwire.KindMessage, 1, 2, "hel", false), + createFrame(drpcwire.KindMessage, 1, 2, "lo ", false), + createFrame(drpcwire.KindMessage, 1, 2, "world", true), + ) + + assert.DeepEqual(t, <-recv, []byte("hello world")) } -func (b *blockedTransport) setWriteOpen(open bool) { - b.mu.Lock() - defer b.mu.Unlock() +// When a higher message ID arrives mid-assembly, the partial data is +// discarded and only the new message is delivered. +func TestManageReader_HigherMsgDiscardsInProgress(t *testing.T) { + ctx := drpctest.NewTracker(t) + defer ctx.Close() + + cconn, sconn := net.Pipe() + defer func() { _ = cconn.Close() }() + defer func() { _ = sconn.Close() }() + + man := New(sconn, Server) + defer func() { _ = man.Close() }() - b.wo = open - b.co.Broadcast() + recv := make(chan []byte, 1) + ctx.Run(func(ctx context.Context) { + stream, _, err := man.NewServerStream(ctx) + assert.NoError(t, err) + data, err := stream.RawRecv() + assert.NoError(t, err) + recv <- data + }) + + writeFrames(t, cconn, + createFrame(drpcwire.KindInvoke, 1, 1, "rpc", true), + createFrame(drpcwire.KindMessage, 1, 2, "discard", false), + createFrame(drpcwire.KindMessage, 1, 3, "kept", true), + ) + + assert.DeepEqual(t, <-recv, []byte("kept")) } -func (b *blockedTransport) setReadOpen(open bool) { - b.mu.Lock() - defer b.mu.Unlock() +// A continuation frame with a different kind than the first frame of the +// packet causes the manager to terminate with a protocol error. +func TestManageReader_KindChangeWithinPacket(t *testing.T) { + ctx := drpctest.NewTracker(t) + defer ctx.Close() + + cconn, sconn := net.Pipe() + defer func() { _ = cconn.Close() }() + defer func() { _ = sconn.Close() }() - b.ro = open - b.co.Broadcast() + man := New(sconn, Server) + defer func() { _ = man.Close() }() + + ctx.Run(func(ctx context.Context) { + stream, _, err := man.NewServerStream(ctx) + assert.NoError(t, err) + for { + if _, err := stream.RawRecv(); err != nil { + return + } + } + }) + + writeFrames(t, cconn, + createFrame(drpcwire.KindInvoke, 1, 1, "rpc", true), + createFrame(drpcwire.KindMessage, 1, 2, "data", false), + createFrame(drpcwire.KindClose, 1, 2, "", true), + ) + + waitForClosed(t, man) } -func (b *blockedTransport) wait(p int, rw *bool) (int, error) { - b.mu.Lock() - defer b.mu.Unlock() +// Multi-frame assembly works correctly when the message ID is greater than +// the previous message (e.g., on the server side where invoke consumed +// earlier IDs). +func TestManageReader_MultiFrameWithSkippedMessageID(t *testing.T) { + ctx := drpctest.NewTracker(t) + defer ctx.Close() - for !*rw { - b.co.Wait() - } - return p, nil + cconn, sconn := net.Pipe() + defer func() { _ = cconn.Close() }() + defer func() { _ = sconn.Close() }() + + man := New(sconn, Server) + defer func() { _ = man.Close() }() + + recv := make(chan []byte, 1) + ctx.Run(func(ctx context.Context) { + stream, _, err := man.NewServerStream(ctx) + assert.NoError(t, err) + data, err := stream.RawRecv() + assert.NoError(t, err) + recv <- data + }) + + writeFrames(t, cconn, + createFrame(drpcwire.KindInvoke, 1, 1, "rpc", true), + createFrame(drpcwire.KindMessage, 1, 3, "hel", false), + createFrame(drpcwire.KindMessage, 1, 3, "lo", true), + ) + + assert.DeepEqual(t, <-recv, []byte("hello")) +} + +// A second invoke for the same stream ID is rejected — the stream treats +// it as a protocol error, terminating the manager. +func TestManageReader_InvokeOnExistingStream(t *testing.T) { + ctx := drpctest.NewTracker(t) + defer ctx.Close() + + cconn, sconn := net.Pipe() + defer func() { _ = cconn.Close() }() + defer func() { _ = sconn.Close() }() + + man := New(sconn, Server) + defer func() { _ = man.Close() }() + + ctx.Run(func(ctx context.Context) { + stream, _, err := man.NewServerStream(ctx) + assert.NoError(t, err) + _ = stream + }) + + writeFrames(t, cconn, + createFrame(drpcwire.KindInvoke, 1, 1, "rpc1", true), + createFrame(drpcwire.KindInvoke, 1, 2, "rpc2", true), + ) + + waitForClosed(t, man) + assert.That(t, drpc.ProtocolError.Has(man.sigs.term.Err())) } -func (b *blockedTransport) Read(p []byte) (n int, err error) { return b.wait(len(p), &b.ro) } -func (b *blockedTransport) Write(p []byte) (n int, err error) { return b.wait(len(p), &b.wo) } -func (b *blockedTransport) Close() error { return nil } +// When a non-invoke frame arrives before the stream is created (e.g., +// NewServerStream hasn't returned yet), manageReader waits for the stream +// and retries. +func TestManageReader_WaitsForStreamCreation(t *testing.T) { + ctx := drpctest.NewTracker(t) + defer ctx.Close() + + cconn, sconn := net.Pipe() + defer func() { _ = cconn.Close() }() + defer func() { _ = sconn.Close() }() + + man := New(sconn, Server) + defer func() { _ = man.Close() }() + + // Write invoke + message immediately. The message arrives before + // NewServerStream creates the stream, exercising the default/wait path. + writeFrames(t, cconn, + createFrame(drpcwire.KindInvoke, 1, 1, "rpc", true), + createFrame(drpcwire.KindMessage, 1, 2, "data", true), + ) + + // Small delay to let manageReader process both frames. + time.Sleep(10 * time.Millisecond) + + recv := make(chan []byte, 1) + ctx.Run(func(ctx context.Context) { + stream, _, err := man.NewServerStream(ctx) + assert.NoError(t, err) + + data, err := stream.RawRecv() + assert.NoError(t, err) + recv <- data + }) + + assert.DeepEqual(t, <-recv, []byte("data")) +} diff --git a/drpcmanager/random_test.go b/drpcmanager/random_test.go index f0e140c4..af63bf4d 100644 --- a/drpcmanager/random_test.go +++ b/drpcmanager/random_test.go @@ -22,10 +22,12 @@ import ( ) func TestRandomized_Client(t *testing.T) { + t.Skip("disabled as the generated random workload violates the wire protocol") runRandomized(t, randomBytes(time.Now().UnixNano(), 1024), new(randClient)) } func TestRandomized_Server(t *testing.T) { + t.Skip("disabled as the generated random workload violates the wire protocol") runRandomized(t, randomBytes(time.Now().UnixNano(), 1024), new(randServer)) } @@ -43,14 +45,15 @@ func (rc *randClient) newSteam(ctx context.Context, man *Manager) (*drpcstream.S return stream, err } -func (rc *randClient) execute(t *testing.T, wr *drpcwire.Writer, op byte) { +func (rc *randClient) execute(t *testing.T, wr *drpcwire.MuxWriter, op byte) { cmd, arg, done := parseOp(op) if !rc.active { - assert.NoError(t, wr.WritePacket(drpcwire.Packet{ + assert.NoError(t, wr.WriteFrame(drpcwire.Frame{ Data: make([]byte, arg), ID: rc.id.incMessage(), Kind: drpcwire.KindInvoke, + Done: true, })) rc.active = true } @@ -58,9 +61,10 @@ func (rc *randClient) execute(t *testing.T, wr *drpcwire.Writer, op byte) { switch cmd { case 0: // new invoke if rc.active { - assert.NoError(t, wr.WritePacket(drpcwire.Packet{ + assert.NoError(t, wr.WriteFrame(drpcwire.Frame{ ID: rc.id.incMessage(), Kind: drpcwire.KindClose, + Done: true, })) } @@ -97,10 +101,11 @@ func (rc *randClient) execute(t *testing.T, wr *drpcwire.Writer, op byte) { })) case 2: // cause the remote side to close - assert.NoError(t, wr.WritePacket(drpcwire.Packet{ + assert.NoError(t, wr.WriteFrame(drpcwire.Frame{ Data: []byte("remote-close"), ID: rc.id.incMessage(), Kind: drpcwire.KindMessage, + Done: true, })) case 3, 4, 5, 6, 7: // send normal message @@ -128,7 +133,7 @@ func (rs *randServer) newSteam(ctx context.Context, man *Manager) (*drpcstream.S return man.NewClientStream(ctx, "rpc") } -func (rs *randServer) execute(t *testing.T, wr *drpcwire.Writer, op byte) { +func (rs *randServer) execute(t *testing.T, wr *drpcwire.MuxWriter, op byte) { cmd, arg, done := parseOp(op) switch cmd { @@ -158,10 +163,11 @@ func (rs *randServer) execute(t *testing.T, wr *drpcwire.Writer, op byte) { })) case 2: // cause the remote side to close - assert.NoError(t, wr.WritePacket(drpcwire.Packet{ + assert.NoError(t, wr.WriteFrame(drpcwire.Frame{ Data: []byte("remote-close"), ID: rs.id.incMessage(), Kind: drpcwire.KindMessage, + Done: true, })) case 3, 4, 5, 6, 7: // send random message @@ -183,7 +189,7 @@ func (rs *randServer) execute(t *testing.T, wr *drpcwire.Writer, op byte) { type runner interface { newSteam(ctx context.Context, man *Manager) (*drpcstream.Stream, error) - execute(t *testing.T, wr *drpcwire.Writer, op byte) + execute(t *testing.T, wr *drpcwire.MuxWriter, op byte) } func runRandomized(t *testing.T, prog []byte, r runner) { @@ -194,8 +200,10 @@ func runRandomized(t *testing.T, prog []byte, r runner) { defer func() { _ = pc.Close() }() defer func() { _ = ps.Close() }() - wr := drpcwire.NewWriter(pc, 0) - man := New(ps) + wr := drpcwire.NewMuxWriter(pc, func(error) {}) + defer func() { wr.Stop(nil); <-wr.Done() }() + + man := New(ps, Server) defer func() { _ = man.Close() }() errch := make(chan error, 1) @@ -221,7 +229,6 @@ func runRandomized(t *testing.T, prog []byte, r runner) { for _, op := range prog { r.execute(t, wr, op) - assert.NoError(t, wr.Flush()) } assert.NoError(t, man.Close()) diff --git a/drpcmanager/streambuf.go b/drpcmanager/streambuf.go deleted file mode 100644 index fd0329b6..00000000 --- a/drpcmanager/streambuf.go +++ /dev/null @@ -1,57 +0,0 @@ -// Copyright (C) 2019 Storj Labs, Inc. -// See LICENSE for copying information. - -package drpcmanager - -import ( - "sync" - "sync/atomic" - - "storj.io/drpc/drpcstream" -) - -type streamBuffer struct { - mu sync.Mutex - cond sync.Cond - stream atomic.Pointer[drpcstream.Stream] - closed bool -} - -func (sb *streamBuffer) init() { - sb.cond.L = &sb.mu -} - -func (sb *streamBuffer) Close() { - sb.mu.Lock() - defer sb.mu.Unlock() - - sb.closed = true - sb.cond.Broadcast() -} - -func (sb *streamBuffer) Get() *drpcstream.Stream { - return sb.stream.Load() -} - -func (sb *streamBuffer) Set(stream *drpcstream.Stream) { - sb.mu.Lock() - defer sb.mu.Unlock() - - if sb.closed { - return - } - - sb.stream.Store(stream) - sb.cond.Broadcast() -} - -func (sb *streamBuffer) Wait(sid uint64) bool { - sb.mu.Lock() - defer sb.mu.Unlock() - - for !sb.closed && sb.Get().ID() == sid { - sb.cond.Wait() - } - - return !sb.closed -} diff --git a/drpcserver/server.go b/drpcserver/server.go index 75e034ff..83c11525 100644 --- a/drpcserver/server.go +++ b/drpcserver/server.go @@ -12,7 +12,6 @@ import ( "github.com/zeebo/errs" "storj.io/drpc" - "storj.io/drpc/drpccache" "storj.io/drpc/drpcctx" "storj.io/drpc/drpcmanager" "storj.io/drpc/drpcmetrics" @@ -161,22 +160,28 @@ func (s *Server) ServeOne(ctx context.Context, tr drpc.Transport) (err error) { } } - man := drpcmanager.NewWithOptions(tr, s.opts.Manager) - defer func() { err = errs.Combine(err, man.Close()) }() - - cache := drpccache.New() - defer cache.Clear() - - ctx = drpccache.WithContext(ctx, cache) + man := drpcmanager.NewWithOptions(tr, drpcmanager.Server, s.opts.Manager) + var wg sync.WaitGroup + defer func() { + wg.Wait() + err = errs.Combine(err, man.Close()) + }() for { stream, rpc, err := man.NewServerStream(ctx) if err != nil { return errs.Wrap(err) } - if err := s.handleRPC(stream, rpc); err != nil { - return errs.Wrap(err) - } + // TODO: add worker pool + wg.Add(1) + go func() { + defer wg.Done() + if err := s.handleRPC(stream, rpc); err != nil { + if s.opts.Log != nil { + s.opts.Log(err) + } + } + }() } } diff --git a/drpcstream/pktbuf.go b/drpcstream/pktbuf.go deleted file mode 100644 index db688649..00000000 --- a/drpcstream/pktbuf.go +++ /dev/null @@ -1,85 +0,0 @@ -// Copyright (C) 2019 Storj Labs, Inc. -// See LICENSE for copying information. - -package drpcstream - -import ( - "sync" -) - -type packetBuffer struct { - mu sync.Mutex - cond sync.Cond - err error - data []byte - set bool - held bool -} - -func (pb *packetBuffer) init() { - pb.cond.L = &pb.mu -} - -func (pb *packetBuffer) Close(err error) { - pb.mu.Lock() - defer pb.mu.Unlock() - - for pb.held { - pb.cond.Wait() - } - - if pb.err == nil { - pb.data = nil - pb.set = false - pb.err = err - pb.cond.Broadcast() - } -} - -func (pb *packetBuffer) Put(data []byte) { - pb.mu.Lock() - defer pb.mu.Unlock() - - for pb.set && pb.err == nil { - pb.cond.Wait() - } - if pb.err != nil { - return - } - - pb.data = data - pb.set = true - pb.held = false - pb.cond.Broadcast() - - for pb.set || pb.held { - pb.cond.Wait() - } -} - -func (pb *packetBuffer) Get() ([]byte, error) { - pb.mu.Lock() - defer pb.mu.Unlock() - - for !pb.set && pb.err == nil { - pb.cond.Wait() - } - if pb.err != nil { - return nil, pb.err - } - - pb.held = true - pb.cond.Broadcast() - - return pb.data, nil -} - -func (pb *packetBuffer) Done() { - pb.mu.Lock() - defer pb.mu.Unlock() - - pb.data = nil - pb.set = false - pb.held = false - pb.cond.Broadcast() -} diff --git a/drpcstream/ring_buffer.go b/drpcstream/ring_buffer.go new file mode 100644 index 00000000..5cb620ab --- /dev/null +++ b/drpcstream/ring_buffer.go @@ -0,0 +1,116 @@ +// Copyright (C) 2026 Cockroach Labs. +// See LICENSE for copying information. + +package drpcstream + +import "sync" + +// defaultRingBufferCapacity is the number of messages the ring buffer can +// hold before the producer blocks. This decouples the transport reader +// (manageReader) from the consumer (RPC handler), preventing a slow handler +// from blocking frame delivery to other streams. +// +// TODO: benchmark whether power-of-2 masking improves performance over modulo. +const defaultRingBufferCapacity = 256 + +// ringBuffer is a bounded single-producer / single-consumer FIFO queue for +// assembled packet data. It sits between manageReader (producer, calls +// Enqueue) and the application goroutine (consumer, calls Dequeue/Done). +// +// Slots are pre-allocated and reused: each slot's backing array grows via +// append to fit incoming data, then stays at its high-water mark, avoiding +// per-message allocation in steady state. +// +// After Close, Dequeue drains any queued messages before returning the close +// error. This ensures graceful shutdown (KindClose/KindCloseSend) delivers +// all buffered data to the consumer. +type ringBuffer struct { + mu sync.Mutex + cond sync.Cond + + buf [][]byte // ring of byte slices + head int // next write position (producer) + tail int // next read position (consumer) + count int // number of occupied slots + + held bool // true between Dequeue and Done + err error // terminal error, set by Close +} + +func (rb *ringBuffer) init() { + rb.cond.L = &rb.mu + rb.buf = make([][]byte, defaultRingBufferCapacity) +} + +// Enqueue copies data into the next write slot. If the buffer is full, it +// blocks until a slot is freed or the buffer is closed. If the buffer is +// closed, Enqueue returns silently without enqueuing. +func (rb *ringBuffer) Enqueue(data []byte) { + rb.mu.Lock() + defer rb.mu.Unlock() + + for rb.count == len(rb.buf) && rb.err == nil { + rb.cond.Wait() + } + if rb.err != nil { + return + } + + rb.buf[rb.head] = append(rb.buf[rb.head][:0], data...) + rb.head = (rb.head + 1) % len(rb.buf) + rb.count++ + rb.cond.Broadcast() +} + +// Dequeue returns the data from the next read slot. If the buffer is empty, +// it blocks until data is available or the buffer is closed. The returned +// slice is valid until Done is called. +func (rb *ringBuffer) Dequeue() ([]byte, error) { + rb.mu.Lock() + defer rb.mu.Unlock() + + for rb.count == 0 && rb.err == nil { + rb.cond.Wait() + } + if rb.count == 0 && rb.err != nil { + return nil, rb.err + } + + rb.held = true + return rb.buf[rb.tail], nil +} + +// Done advances the read pointer, making the slot available for reuse. +// It must be called exactly once after each successful Dequeue. +// +// TODO(shubham): remove this method once a shared buffer pool is introduced. +// With a pool, Dequeue will advance the tail immediately and the caller will +// return the buffer to the pool directly. +func (rb *ringBuffer) Done() { + rb.mu.Lock() + defer rb.mu.Unlock() + + rb.tail = (rb.tail + 1) % len(rb.buf) + rb.count-- + rb.held = false + rb.cond.Broadcast() +} + +// Close marks the buffer as closed with the given error. All blocked Enqueue +// and Dequeue calls are woken and will return. Close waits for any in-progress +// Dequeue/Done pair to complete before setting the error. Subsequent calls are +// no-ops. +func (rb *ringBuffer) Close(err error) { + rb.mu.Lock() + defer rb.mu.Unlock() + + for rb.held { + rb.cond.Wait() + } + if rb.err != nil { + return + } + + rb.err = err + rb.cond.Broadcast() +} diff --git a/drpcstream/ring_buffer_test.go b/drpcstream/ring_buffer_test.go new file mode 100644 index 00000000..8be9c587 --- /dev/null +++ b/drpcstream/ring_buffer_test.go @@ -0,0 +1,228 @@ +// Copyright (C) 2026 Cockroach Labs. +// See LICENSE for copying information. + +package drpcstream + +import ( + "io" + "sync" + "testing" + + "github.com/zeebo/assert" +) + +func TestRingBuffer_EnqueueDequeue(t *testing.T) { + var rb ringBuffer + rb.init() + + rb.Enqueue([]byte("hello")) + + data, err := rb.Dequeue() + assert.NoError(t, err) + assert.DeepEqual(t, data, []byte("hello")) + rb.Done() +} + +func TestRingBuffer_FIFO(t *testing.T) { + var rb ringBuffer + rb.init() + + rb.Enqueue([]byte("first")) + rb.Enqueue([]byte("second")) + rb.Enqueue([]byte("third")) + + for _, want := range []string{"first", "second", "third"} { + data, err := rb.Dequeue() + assert.NoError(t, err) + assert.DeepEqual(t, data, []byte(want)) + rb.Done() + } +} + +func TestRingBuffer_DequeueBlocksUntilEnqueue(t *testing.T) { + var rb ringBuffer + rb.init() + + got := make(chan []byte, 1) + go func() { + data, err := rb.Dequeue() + assert.NoError(t, err) + got <- data + }() + + rb.Enqueue([]byte("delayed")) + assert.DeepEqual(t, <-got, []byte("delayed")) + rb.Done() +} + +func TestRingBuffer_EnqueueBlocksWhenFull(t *testing.T) { + var rb ringBuffer + rb.cond.L = &rb.mu + rb.buf = make([][]byte, 2) // capacity 2 + + rb.Enqueue([]byte("a")) + rb.Enqueue([]byte("b")) + + // Third enqueue should block until we drain one. + done := make(chan struct{}) + go func() { + rb.Enqueue([]byte("c")) + close(done) + }() + + // Drain one slot. + data, err := rb.Dequeue() + assert.NoError(t, err) + assert.DeepEqual(t, data, []byte("a")) + rb.Done() + + // Now the blocked Enqueue should complete. + <-done + + // Verify remaining items. + data, err = rb.Dequeue() + assert.NoError(t, err) + assert.DeepEqual(t, data, []byte("b")) + rb.Done() + + data, err = rb.Dequeue() + assert.NoError(t, err) + assert.DeepEqual(t, data, []byte("c")) + rb.Done() +} + +func TestRingBuffer_CloseUnblocksDequeue(t *testing.T) { + var rb ringBuffer + rb.init() + + errch := make(chan error, 1) + go func() { + _, err := rb.Dequeue() + errch <- err + }() + + rb.Close(io.EOF) + assert.Equal(t, <-errch, io.EOF) +} + +func TestRingBuffer_CloseUnblocksEnqueue(t *testing.T) { + var rb ringBuffer + rb.cond.L = &rb.mu + rb.buf = make([][]byte, 1) // capacity 1 + + rb.Enqueue([]byte("fill")) + + done := make(chan struct{}) + go func() { + rb.Enqueue([]byte("blocked")) + close(done) + }() + + rb.Close(io.EOF) + <-done +} + +func TestRingBuffer_CloseDrainsQueued(t *testing.T) { + var rb ringBuffer + rb.init() + + rb.Enqueue([]byte("queued")) + rb.Close(io.EOF) + + // Dequeue returns the queued data first. + data, err := rb.Dequeue() + assert.NoError(t, err) + assert.DeepEqual(t, data, []byte("queued")) + rb.Done() + + // Next Dequeue returns the close error. + data, err = rb.Dequeue() + assert.Nil(t, data) + assert.Equal(t, err, io.EOF) +} + +func TestRingBuffer_CloseIdempotent(t *testing.T) { + var rb ringBuffer + rb.init() + + rb.Close(io.EOF) + rb.Close(io.ErrUnexpectedEOF) // should not overwrite + + _, err := rb.Dequeue() + assert.Equal(t, err, io.EOF) // original error preserved +} + +func TestRingBuffer_EnqueueAfterClose(t *testing.T) { + var rb ringBuffer + rb.init() + + rb.Close(io.EOF) + rb.Enqueue([]byte("dropped")) // should not panic or block +} + +func TestRingBuffer_SlotReuse(t *testing.T) { + var rb ringBuffer + rb.cond.L = &rb.mu + rb.buf = make([][]byte, 2) + + // Fill and drain a few rounds to exercise slot reuse. + for round := 0; round < 5; round++ { + rb.Enqueue([]byte("data")) + data, err := rb.Dequeue() + assert.NoError(t, err) + assert.DeepEqual(t, data, []byte("data")) + rb.Done() + } +} + +func TestRingBuffer_CloseWaitsForHeld(t *testing.T) { + var rb ringBuffer + rb.init() + + rb.Enqueue([]byte("msg")) + + // Dequeue the data but don't call Done yet. + data, err := rb.Dequeue() + assert.NoError(t, err) + assert.DeepEqual(t, data, []byte("msg")) + + closed := make(chan struct{}) + go func() { + rb.Close(io.EOF) + close(closed) + }() + + // Close should be blocked because held is true. + // Call Done to release it. + rb.Done() + <-closed +} + +func TestRingBuffer_ConcurrentProducerConsumer(t *testing.T) { + var rb ringBuffer + rb.init() + + const n = 1000 + var wg sync.WaitGroup + wg.Add(2) + + go func() { + defer wg.Done() + for i := 0; i < n; i++ { + rb.Enqueue([]byte{byte(i)}) + } + }() + + go func() { + defer wg.Done() + for i := 0; i < n; i++ { + data, err := rb.Dequeue() + assert.NoError(t, err) + assert.Equal(t, data[0], byte(i)) + rb.Done() + } + }() + + wg.Wait() + rb.Close(io.EOF) +} diff --git a/drpcstream/stream.go b/drpcstream/stream.go index 29ccd636..1fa0460f 100644 --- a/drpcstream/stream.go +++ b/drpcstream/stream.go @@ -23,15 +23,9 @@ import ( // Options controls configuration settings for a stream. type Options struct { - // SplitSize controls the default size we split packets into frames. + // SplitSize controls the default size we split data packets into frames. SplitSize int - // ManualFlush controls if the stream will automatically flush after every - // message send. Note that flushing is not part of the drpc.Stream - // interface, so if you use this you must be ready to type assert and call - // RawFlush dynamically. - ManualFlush bool - // MaximumBufferSize causes the Stream to drop any internal buffers that are // larger than this amount to control maximum memory usage at the expense of // more allocations. 0 is unlimited. @@ -45,22 +39,34 @@ type Options struct { type Stream struct { ctx streamCtx opts Options - fin chan<- struct{} task *trace.Task + // write and read serialize operations within a stream. The data path + // (MsgSend/MsgRecv) and the control path (SendCancel/Close/SendError) + // genuinely race because cancellation arrives from manageStream while the + // application may be mid-send. These are inspectMutex (not sync.Mutex) so + // that checkFinished can test whether ops are in flight without blocking. write inspectMutex read inspectMutex - flush sync.Once - id drpcwire.ID - wr *drpcwire.Writer - pbuf packetBuffer - wbuf []byte + pa drpcwire.PacketAssembler + + id drpcwire.ID + wr *drpcwire.MuxWriter + recvQueue ringBuffer + wbuf []byte mu sync.Mutex // protects state transitions sigs struct { - send drpcsignal.Signal // set when done sending messages - recv drpcsignal.Signal // set when done receiving messages + send drpcsignal.Signal // set when done sending messages + recv drpcsignal.Signal // set when done receiving messages + // Stream shutdown is two-phase: term then fin. When termination arrives + // (remote error, local cancel, close), there may be an in-flight write + // on the transport that is past the term check and inside WriteFrame. + // term tells new operations to bail out; fin signals that all in-flight + // operations have actually completed. Consumers (manageStream) wait on + // fin before cleaning up, guaranteeing no goroutine is touching the + // stream afterward. term drpcsignal.Signal // set when the stream is terminating and no new ops should begin fin drpcsignal.Signal // set when the stream is finished and all ops are complete cancel drpcsignal.Signal // set when externally canceled @@ -72,7 +78,7 @@ var _ drpc.Stream = (*Stream)(nil) // New returns a new stream bound to the context with the given stream id and // will use the writer to write messages on. It is important use monotonically // increasing stream ids within a single transport. -func New(ctx context.Context, sid uint64, wr *drpcwire.Writer) *Stream { +func New(ctx context.Context, sid uint64, wr *drpcwire.MuxWriter) *Stream { return NewWithOptions(ctx, sid, wr, Options{}) } @@ -80,7 +86,7 @@ func New(ctx context.Context, sid uint64, wr *drpcwire.Writer) *Stream { // stream id and will use the writer to write messages on. It is important use // monotonically increasing stream ids within a single transport. The options // are used to control details of how the Stream operates. -func NewWithOptions(ctx context.Context, sid uint64, wr *drpcwire.Writer, opts Options) *Stream { +func NewWithOptions(ctx context.Context, sid uint64, wr *drpcwire.MuxWriter, opts Options) *Stream { var task *trace.Task if trace.IsEnabled() { kind, rpc := drpcopts.GetStreamKind(&opts.Internal), drpcopts.GetStreamRPC(&opts.Internal) @@ -89,21 +95,25 @@ func NewWithOptions(ctx context.Context, sid uint64, wr *drpcwire.Writer, opts O } } + pa := drpcwire.NewPacketAssembler() + pa.SetStreamID(sid) + s := &Stream{ ctx: streamCtx{ Context: ctx, tr: drpcopts.GetStreamTransport(&opts.Internal), }, opts: opts, - fin: drpcopts.GetStreamFin(&opts.Internal), task: task, + pa: pa, + id: drpcwire.ID{Stream: sid}, - wr: wr.Reset(), + wr: wr, } // initialize the packet buffer - s.pbuf.init() + s.recvQueue.init() return s } @@ -186,53 +196,37 @@ func (s *Stream) Finished() <-chan struct{} { return s.sigs.fin.Signal() } // issue any writes or reads. func (s *Stream) IsFinished() bool { return s.sigs.fin.IsSet() } -// SetManualFlush sets the ManualFlush option. It cannot be called concurrently -// with any sends or receives on the stream. Example use case: -// -// flusher := stream.(interface{ -// GetStream() drpc.Stream -// }).GetStream().(interface{ -// SetManualFlush(bool) -// }) // -// flusher.SetManualFlush(true) -// err = stream.Send(&pb.Message{Request: "hello, "}) -// flusher.SetManualFlush(false) -// if err != nil { -// return err -// } +// frame handler // -// // the next send will send both messages in the same write -// // to the underlying connection. -// err = stream.Send(&pb.Message{Request: "world!"}) -// if err != nil { -// return err -// } -func (s *Stream) SetManualFlush(mf bool) { s.opts.ManualFlush = mf } -// -// packet handler -// - -// HandlePacket advances the stream state machine by inspecting the packet. It -// returns any major errors that should terminate the transport the stream is -// operating on as well as a boolean indicating if the stream expects more -// packets. -func (s *Stream) HandlePacket(pkt drpcwire.Packet) (err error) { - if pkt.ID.Stream != s.id.Stream { +// HandleFrame processes an incoming frame, assembling multi-frame packets +// and dispatching complete packets to the stream state machine. +func (s *Stream) HandleFrame(fr drpcwire.Frame) (err error) { + if s.sigs.term.IsSet() { return nil } - drpcopts.GetStreamStats(&s.opts.Internal).AddRead(uint64(len(pkt.Data))) - - if s.sigs.term.IsSet() { + packet, packetReady, err := s.pa.AppendFrame(fr) + if err != nil { + return err + } + if !packetReady { return nil } + return s.handlePacket(packet) +} + +// handlePacket advances the stream state machine by inspecting the packet. It +// returns any major errors that should terminate the transport the stream is +// operating on. +func (s *Stream) handlePacket(pkt drpcwire.Packet) (err error) { + drpcopts.GetStreamStats(&s.opts.Internal).AddRead(uint64(len(pkt.Data))) s.log("HANDLE", pkt.String) if pkt.Kind == drpcwire.KindMessage { - s.pbuf.Put(pkt.Data) + s.recvQueue.Enqueue(pkt.Data) return nil } @@ -240,7 +234,7 @@ func (s *Stream) HandlePacket(pkt drpcwire.Packet) (err error) { defer s.mu.Unlock() switch pkt.Kind { - case drpcwire.KindInvoke: + case drpcwire.KindInvoke, drpcwire.KindInvokeMetadata: err := drpc.ProtocolError.New("invoke on existing stream") s.terminate(err) return err @@ -260,13 +254,13 @@ func (s *Stream) HandlePacket(pkt drpcwire.Packet) (err error) { case drpcwire.KindClose: s.sigs.recv.Set(io.EOF) - s.pbuf.Close(io.EOF) + s.recvQueue.Close(io.EOF) s.terminate(drpc.ClosedError.New("remote closed the stream")) return nil case drpcwire.KindCloseSend: s.sigs.recv.Set(io.EOF) - s.pbuf.Close(io.EOF) + s.recvQueue.Close(io.EOF) s.terminateIfBothClosed() return nil @@ -286,17 +280,16 @@ func (s *Stream) HandlePacket(pkt drpcwire.Packet) (err error) { // helpers // -// checkFinished checks to see if the stream is terminated, and if so, sets the -// finished flag. This must be called after every read or write is complete, as -// well as when the stream becomes terminated. +// checkFinished bridges the two-phase shutdown. It is called in two places: +// inside terminate() for when no I/O is in flight (fin fires immediately), +// and deferred after every read/write unlock for when an operation was in +// flight at termination time (fin fires once the last operation completes). +// Whichever call site runs last sees term set and both locks free, and sets fin. func (s *Stream) checkFinished() { if s.sigs.term.IsSet() && s.write.Unlocked() && s.read.Unlocked() { if s.sigs.fin.Set(nil) { s.log("FIN", func() string { return "" }) s.ctx.sig.Set(context.Canceled) - if s.fin != nil { - s.fin <- struct{}{} - } if s.task != nil { s.task.End() } @@ -304,10 +297,10 @@ func (s *Stream) checkFinished() { } } -// checkCancelError will replace the error with one from the cancel signal if it +// CheckCancelError will replace the error with one from the cancel signal if it // is set. This is to prevent errors from reads/writes to a transport after it // has been asynchronously closed due to context cancelation. -func (s *Stream) checkCancelError(err error) error { +func (s *Stream) CheckCancelError(err error) error { if s.sigs.cancel.IsSet() { return s.sigs.cancel.Err() } @@ -321,9 +314,9 @@ func (s *Stream) newFrameLocked(kind drpcwire.Kind) drpcwire.Frame { return drpcwire.Frame{ID: s.id, Kind: kind} } -// sendPacketLocked sends the packet in a single write and flushes. It does not -// check for any conditions to stop it from writing and is meant for internal -// stream use to do things like signal errors or closes to the remote side. +// sendPacketLocked sends the packet in a single write. It does not check for +// any conditions to stop it from writing and is meant for internal stream use +// to do things like signal errors or closes to the remote side. func (s *Stream) sendPacketLocked(kind drpcwire.Kind, control bool, data []byte) (err error) { fr := s.newFrameLocked(kind) fr.Data = data @@ -336,9 +329,6 @@ func (s *Stream) sendPacketLocked(kind drpcwire.Kind, control bool, data []byte) if err := s.wr.WriteFrame(fr); err != nil { return errs.Wrap(err) } - if err := s.wr.Flush(); err != nil { - return errs.Wrap(err) - } return nil } @@ -356,10 +346,26 @@ func (s *Stream) terminate(err error) { s.sigs.send.Set(err) s.sigs.recv.Set(err) s.sigs.term.Set(err) - s.pbuf.Close(err) + s.recvQueue.Close(err) s.checkFinished() } +// WriteInvoke writes the invoke metadata (if any) and invoke frame +// atomically under the write lock. This prevents SendCancel from +// interrupting the invoke sequence. +func (s *Stream) WriteInvoke(rpc string, metadata []byte) error { + defer s.checkFinished() + s.write.Lock() + defer s.write.Unlock() + + if len(metadata) > 0 { + if err := s.rawWriteLocked(drpcwire.KindInvokeMetadata, metadata); err != nil { + return err + } + } + return s.rawWriteLocked(drpcwire.KindInvoke, []byte(rpc)) +} + // // raw read/write // @@ -375,6 +381,7 @@ func (s *Stream) RawWrite(kind drpcwire.Kind, data []byte) (err error) { // rawWriteLocked does the body of RawWrite assuming the caller is holding the // appropriate locks. +// TODO(shubham): can we merge this with sendPacketLocked? func (s *Stream) rawWriteLocked(kind drpcwire.Kind, data []byte) (err error) { fr := s.newFrameLocked(kind) n := s.opts.SplitSize @@ -394,75 +401,25 @@ func (s *Stream) rawWriteLocked(kind drpcwire.Kind, data []byte) (err error) { s.log("SEND", fr.String) if err := s.wr.WriteFrame(fr); err != nil { - return s.checkCancelError(errs.Wrap(err)) + return s.CheckCancelError(errs.Wrap(err)) } else if fr.Done { return nil } } } -// RawFlush flushes any buffers of data. -func (s *Stream) RawFlush() (err error) { - defer s.checkFinished() - s.write.Lock() - defer s.write.Unlock() - - return s.rawFlushLocked() -} - -// rawFlushLocked checks for any conditions that should cause a flush to not -// happen and then issues the flush. It assumes the caller is holding the -// appropriate locks. -func (s *Stream) rawFlushLocked() (err error) { - if s.wr.Empty() { - return nil - } - - switch { - case s.sigs.cancel.IsSet(): - return s.sigs.cancel.Err() - case s.sigs.send.IsSet(): - return s.sigs.send.Err() - case s.sigs.term.IsSet(): - return s.sigs.term.Err() - } - - s.log("FLUSH", func() string { return "" }) - - return s.checkCancelError(errs.Wrap(s.wr.Flush())) -} - -func (s *Stream) checkRecvFlush() (err error) { - s.flush.Do(func() { err = s.RawFlush() }) - if err != nil { - return err - } - - if s.opts.ManualFlush && !s.wr.Empty() { - if err := s.RawFlush(); err != nil { - return err - } - } - - return nil -} - // RawRecv returns the raw bytes received for a message. func (s *Stream) RawRecv() (data []byte, err error) { - if err := s.checkRecvFlush(); err != nil { - return nil, err - } - defer s.checkFinished() s.read.Lock() defer s.read.Unlock() - data, err = s.pbuf.Get() + data, err = s.recvQueue.Dequeue() if err != nil { return nil, err } data = append([]byte(nil), data...) - s.pbuf.Done() + s.recvQueue.Done() return data, nil } @@ -471,12 +428,9 @@ func (s *Stream) RawRecv() (data []byte, err error) { // msg read/write // -// MsgSend marshals the message with the encoding, writes it, and flushes. +// MsgSend marshals the message with the encoding and writes it. func (s *Stream) MsgSend(msg drpc.Message, enc drpc.Encoding) (err error) { defer func() { err = drpc.ToRPCErr(err) }() - - s.flush.Do(func() {}) - defer s.checkFinished() s.write.Lock() defer s.write.Unlock() @@ -491,9 +445,6 @@ func (s *Stream) MsgSend(msg drpc.Message, enc drpc.Encoding) (err error) { if err := s.rawWriteLocked(drpcwire.KindMessage, wbuf); err != nil { return err } - if !s.opts.ManualFlush { - return s.rawFlushLocked() - } return nil } @@ -501,20 +452,16 @@ func (s *Stream) MsgSend(msg drpc.Message, enc drpc.Encoding) (err error) { func (s *Stream) MsgRecv(msg drpc.Message, enc drpc.Encoding) (err error) { defer func() { err = drpc.ToRPCErr(err) }() - if err := s.checkRecvFlush(); err != nil { - return err - } - defer s.checkFinished() s.read.Lock() defer s.read.Unlock() - data, err := s.pbuf.Get() + data, err := s.recvQueue.Dequeue() if err != nil { return err } err = enc.Unmarshal(data, msg) - s.pbuf.Done() + s.recvQueue.Done() return err } @@ -549,25 +496,19 @@ func (s *Stream) SendError(serr error) (err error) { s.terminate(termError) s.mu.Unlock() - return s.checkCancelError(s.sendPacketLocked(drpcwire.KindError, false, drpcwire.MarshalError(serr))) + return s.CheckCancelError(s.sendPacketLocked(drpcwire.KindError, false, drpcwire.MarshalError(serr))) } -// SendCancel transitions the stream into the canceled state with -// context.Canceled and sends a cancel error to the remote side for a soft -// cancel. It is a no-op if the stream is already terminated. It returns true -// for busy if writes are already blocked and a hard cancel is required. -func (s *Stream) SendCancel(err error) (busy bool, _ error) { +// SendCancel terminates the stream and sends a cancel to the remote side. It +// blocks until any in-progress write completes. It is a no-op if the stream is +// already terminated. +func (s *Stream) SendCancel(err error) error { s.log("CALL", func() string { return "SendCancel()" }) s.mu.Lock() - if !s.write.Unlocked() { // if writes are happening, then we have to do a hard cancel. - s.mu.Unlock() - return true, nil - } - if s.sigs.term.IsSet() { s.mu.Unlock() - return false, nil + return nil } defer s.checkFinished() @@ -578,7 +519,7 @@ func (s *Stream) SendCancel(err error) (busy bool, _ error) { s.terminate(err) s.mu.Unlock() - return false, s.checkCancelError(s.sendPacketLocked(drpcwire.KindCancel, true, nil)) + return s.CheckCancelError(s.sendPacketLocked(drpcwire.KindCancel, true, nil)) } // Close terminates the stream and sends that the stream has been closed to the @@ -599,7 +540,7 @@ func (s *Stream) Close() (err error) { s.terminate(termClosed) s.mu.Unlock() - return s.checkCancelError(s.sendPacketLocked(drpcwire.KindClose, false, nil)) + return s.CheckCancelError(s.sendPacketLocked(drpcwire.KindClose, false, nil)) } // CloseSend informs the remote that no more messages will be sent. If the remote has @@ -622,7 +563,7 @@ func (s *Stream) CloseSend() (err error) { s.terminateIfBothClosed() s.mu.Unlock() - return s.checkCancelError(s.sendPacketLocked(drpcwire.KindCloseSend, false, nil)) + return s.CheckCancelError(s.sendPacketLocked(drpcwire.KindCloseSend, false, nil)) } // Cancel transitions the stream into a state where all writes to the transport will return diff --git a/drpcstream/stream_test.go b/drpcstream/stream_test.go index 3cf4ca38..9cee8ddf 100644 --- a/drpcstream/stream_test.go +++ b/drpcstream/stream_test.go @@ -4,10 +4,10 @@ package drpcstream import ( - "bytes" "context" "errors" "io" + "strings" "testing" "github.com/zeebo/assert" @@ -18,16 +18,33 @@ import ( "storj.io/drpc/drpcwire" ) +// testMuxWriter creates a MuxWriter that writes to io.Discard with a no-op +// error handler. The writer goroutine is stopped when the test finishes. +func testMuxWriter(t *testing.T) *drpcwire.MuxWriter { + t.Helper() + mw := drpcwire.NewMuxWriter(io.Discard, func(error) {}) + t.Cleanup(func() { mw.Stop(nil); <-mw.Done() }) + return mw +} + +// handleFrame is a helper that sends a single-frame packet to the stream. +// It constructs a frame with the given kind, matching the stream's ID, +// using the provided message ID, done=true. +func handleFrame(st *Stream, kind drpcwire.Kind, mid uint64) error { + return st.HandleFrame(drpcwire.Frame{ + ID: drpcwire.ID{Stream: st.ID(), Message: mid}, + Kind: kind, + Done: true, + }) +} + func TestStream_StateTransitions(t *testing.T) { ctx := drpctest.NewTracker(t) defer ctx.Close() + mw := testMuxWriter(t) any := errors.New("any sentinel error") - handlePacket := func(st *Stream, kind drpcwire.Kind) error { - return st.HandlePacket(drpcwire.Packet{Kind: kind}) - } - checkErrs := func(t *testing.T, exp interface{}, got error) { t.Helper() @@ -81,32 +98,32 @@ func TestStream_StateTransitions(t *testing.T) { }, { // recv close - Op: func(st *Stream) error { return handlePacket(st, drpcwire.KindClose) }, + Op: func(st *Stream) error { return handleFrame(st, drpcwire.KindClose, 1) }, Send: &drpc.ClosedError, Recv: io.EOF, }, { // recv error - Op: func(st *Stream) error { return handlePacket(st, drpcwire.KindError) }, + Op: func(st *Stream) error { return handleFrame(st, drpcwire.KindError, 1) }, Send: io.EOF, Recv: any, }, { // recv closesend - Op: func(st *Stream) error { return handlePacket(st, drpcwire.KindCloseSend) }, + Op: func(st *Stream) error { return handleFrame(st, drpcwire.KindCloseSend, 1) }, Send: nil, Recv: io.EOF, }, } for _, test := range cases { - st := New(ctx, 0, drpcwire.NewWriter(io.Discard, 0)) + st := New(ctx, 1, mw) assert.NoError(t, test.Op(st)) checkErrs(t, test.Send, st.RawWrite(drpcwire.KindMessage, nil)) if test.Recv == nil { - ctx.Run(func(ctx context.Context) { _ = handlePacket(st, drpcwire.KindMessage) }) + ctx.Run(func(ctx context.Context) { _ = handleFrame(st, drpcwire.KindMessage, 2) }) } _, err := st.RawRecv() checkErrs(t, test.Recv, err) @@ -117,9 +134,7 @@ func TestStream_Unblocks(t *testing.T) { ctx := drpctest.NewTracker(t) defer ctx.Close() - handlePacket := func(st *Stream, kind drpcwire.Kind) error { - return st.HandlePacket(drpcwire.Packet{Kind: kind}) - } + mw := testMuxWriter(t) cases := []struct { Op func(st *Stream) error @@ -141,20 +156,20 @@ func TestStream_Unblocks(t *testing.T) { }, { // recv close - Op: func(st *Stream) error { return handlePacket(st, drpcwire.KindClose) }, + Op: func(st *Stream) error { return handleFrame(st, drpcwire.KindClose, 1) }, }, { // recv error - Op: func(st *Stream) error { return handlePacket(st, drpcwire.KindError) }, + Op: func(st *Stream) error { return handleFrame(st, drpcwire.KindError, 1) }, }, { // recv closesend - Op: func(st *Stream) error { return handlePacket(st, drpcwire.KindCloseSend) }, + Op: func(st *Stream) error { return handleFrame(st, drpcwire.KindCloseSend, 1) }, }, } for _, test := range cases { - st := New(ctx, 0, drpcwire.NewWriter(io.Discard, 0)) + st := New(ctx, 1, mw) ctx.Run(func(ctx context.Context) { _, _ = st.RawRecv() }) assert.NoError(t, test.Op(st)) @@ -164,7 +179,8 @@ func TestStream_Unblocks(t *testing.T) { func TestStream_ContextCancel(t *testing.T) { ctx := context.Background() - st := New(ctx, 0, drpcwire.NewWriter(io.Discard, 0)) + mw := testMuxWriter(t) + st := New(ctx, 0, mw) child, cancel := context.WithCancel(st.Context()) defer cancel() @@ -178,108 +194,47 @@ func TestStream_ConcurrentCloseCancel(t *testing.T) { ctx := drpctest.NewTracker(t) defer ctx.Close() - pr, pw := io.Pipe() - defer func() { _ = pr.Close() }() - defer func() { _ = pw.Close() }() + mw := testMuxWriter(t) + st := New(ctx, 0, mw) - st := New(ctx, 0, drpcwire.NewWriter(pw, 0)) - - // start the Close call + // Close and Cancel concurrently should not panic or deadlock. errch := make(chan error, 1) go func() { errch <- st.Close() }() - // wait for the close to begin writing - _, err := pr.Read(make([]byte, 1)) - assert.NoError(t, err) - - // cancel the context and close the transport st.Cancel(context.Canceled) - assert.NoError(t, pw.Close()) - - // we should always receive the canceled error - assert.That(t, errors.Is(<-errch, context.Canceled)) -} - -func TestStream_Control(t *testing.T) { - st := New(context.Background(), 0, drpcwire.NewWriter(io.Discard, 0)) - - // N.B. the stream will return nil on any HandlePacket calls after the - // stream has been terminated for any reason, including if an invalid - // packet has been sent. the order of these two assertions is important! - // an invalid packet is not an error if the control bit is set - assert.NoError(t, st.HandlePacket(drpcwire.Packet{Control: true})) - - // an invalid packet is an error if the control bit it not set - assert.That(t, drpc.InternalError.Has(st.HandlePacket(drpcwire.Packet{}))) -} - -func TestStream_CorkUntilFirstRead(t *testing.T) { - run := func() { - ctx := drpctest.NewTracker(t) - defer ctx.Close() - - var buf bytes.Buffer - st := New(ctx, 0, drpcwire.NewWriter(&buf, 50)) - - // concurrently read and write at the same time. - // we should always see the write happen. - - errch := make(chan error, 3) - ctx.Run(func(ctx context.Context) { - errch <- st.MsgSend([]byte("write"), byteEncoding{}) - }) - ctx.Run(func(ctx context.Context) { - _, err := st.RawRecv() - errch <- err - }) - ctx.Run(func(ctx context.Context) { - errch <- st.HandlePacket(drpcwire.Packet{ - Data: []byte("read"), - ID: drpcwire.ID{Message: 1}, - Kind: drpcwire.KindMessage, - }) - }) - - assert.NoError(t, <-errch) - assert.NoError(t, <-errch) - assert.NoError(t, <-errch) - - assert.Equal(t, buf.String(), "\x05\x00\x01\x05write") - } - for i := 0; i < 100; i++ { - run() + // Close returns nil or context.Canceled depending on timing. + err := <-errch + if err != nil { + assert.That(t, errors.Is(err, context.Canceled)) } } -type byteEncoding struct{} - -func (byteEncoding) Marshal(msg drpc.Message) ([]byte, error) { return msg.([]byte), nil } -func (byteEncoding) Unmarshal(buf []byte, msg drpc.Message) error { - *msg.(*[]byte) = append(*msg.(*[]byte), buf...) - return nil -} - func TestStream_PacketBufferReuse(t *testing.T) { run := func() { ctx := drpctest.NewTracker(t) defer ctx.Close() defer ctx.Wait() - buf := make([]byte, 20) - st := New(ctx, 0, drpcwire.NewWriter(io.Discard, 0)) + mw := testMuxWriter(t) + data := make([]byte, 20) + mid := uint64(1) + st := New(ctx, 1, mw) ctx.Run(func(ctx context.Context) { for !st.IsTerminated() { - err := st.HandlePacket(drpcwire.Packet{ - Data: buf, + err := st.HandleFrame(drpcwire.Frame{ + Data: data, + ID: drpcwire.ID{Stream: 1, Message: mid}, Kind: drpcwire.KindMessage, + Done: true, }) if err != nil { return } - for i := range buf { - buf[i]++ + mid++ + for i := range data { + data[i]++ } } }) @@ -303,27 +258,81 @@ func TestStream_PacketBufferReuse(t *testing.T) { } } -func TestStream_SendCancelBusyDuringBlockedClose(t *testing.T) { - ctx := drpctest.NewTracker(t) - defer ctx.Close() +// +// HandleFrame tests +// + +func TestHandleFrame_FirstFrameOnFreshStream(t *testing.T) { + mw := testMuxWriter(t) + for _, messageID := range []uint64{1, 2} { + st := New(context.Background(), 1, mw) + // Close the ring buffer so KindMessage Enqueue doesn't block. + st.recvQueue.Close(io.EOF) + err := st.HandleFrame(drpcwire.Frame{ + ID: drpcwire.ID{Stream: 1, Message: messageID}, Kind: drpcwire.KindMessage, Done: true, + }) + assert.NoError(t, err) + } +} - pr, pw := io.Pipe() - defer func() { _ = pr.Close() }() - defer func() { _ = pw.Close() }() +// Invoke and InvokeMetadata frames are rejected on an already-created stream. +func TestHandleFrame_InvokeOnExistingStream(t *testing.T) { + mw := testMuxWriter(t) + st := New(context.Background(), 1, mw) - st := New(ctx, 0, drpcwire.NewWriter(pw, 0)) + err := handleFrame(st, drpcwire.KindInvoke, 1) + assert.Error(t, err) + assert.That(t, drpc.ProtocolError.Has(err)) + assert.That(t, strings.Contains(err.Error(), "invoke on existing stream")) +} - // launch a goroutine to close the stream - ctx.Run(func(ctx context.Context) { _ = st.Close() }) +func TestHandleFrame_InvokeMetadataOnExistingStream(t *testing.T) { + mw := testMuxWriter(t) + st := New(context.Background(), 1, mw) - // read just 1 byte from the pipe to ensure that the Close has started - _, err := pr.Read(make([]byte, 1)) - assert.NoError(t, err) - assert.That(t, st.IsTerminated()) + err := handleFrame(st, drpcwire.KindInvokeMetadata, 1) + assert.Error(t, err) + assert.That(t, drpc.ProtocolError.Has(err)) + assert.That(t, strings.Contains(err.Error(), "invoke on existing stream")) +} + +// Frames arriving after the stream is terminated are silently ignored. +func TestHandleFrame_AfterTerminated(t *testing.T) { + mw := testMuxWriter(t) + st := New(context.Background(), 1, mw) - // even though the stream is terminated, soft cancel should report that - // the stream is still busy because the close is being sent. - busy, err := st.SendCancel(context.Canceled) + // Terminate the stream via cancel. + st.Cancel(context.Canceled) + + // Frames after termination are silently ignored. + err := st.HandleFrame(drpcwire.Frame{ + ID: drpcwire.ID{Stream: 1, Message: 1}, Kind: drpcwire.KindMessage, Done: true, + }) assert.NoError(t, err) - assert.That(t, busy) +} + +// A completed KindMessage frame delivers its data through RawRecv. +func TestHandleFrame_MessageDeliveredViaRecv(t *testing.T) { + ctx := drpctest.NewTracker(t) + defer ctx.Close() + + mw := testMuxWriter(t) + st := New(ctx, 1, mw) + + // Launch receiver before sending to avoid Put blocking. + recv := make(chan []byte, 1) + ctx.Run(func(ctx context.Context) { + data, err := st.RawRecv() + assert.NoError(t, err) + recv <- data + }) + + assert.NoError(t, st.HandleFrame(drpcwire.Frame{ + ID: drpcwire.ID{Stream: 1, Message: 1}, + Kind: drpcwire.KindMessage, + Data: []byte("payload"), + Done: true, + })) + + assert.DeepEqual(t, <-recv, []byte("payload")) } diff --git a/drpcwire/mux_writer.go b/drpcwire/mux_writer.go new file mode 100644 index 00000000..972ed320 --- /dev/null +++ b/drpcwire/mux_writer.go @@ -0,0 +1,94 @@ +// Copyright (C) 2026 Cockroach Labs. +// See LICENSE for copying information. + +package drpcwire + +import ( + "io" + "sync" +) + +type MuxWriter struct { + w io.Writer + buf []byte + mu sync.Mutex + cond *sync.Cond + closed bool + closeErr error + onError func(error) + done chan struct{} +} + +var defaultBufferCapacity = 4096 + +func NewMuxWriter(w io.Writer, onError func(error)) *MuxWriter { + mw := &MuxWriter{ + w: w, + buf: make([]byte, 0, defaultBufferCapacity), + onError: onError, + done: make(chan struct{}), + } + mw.cond = sync.NewCond(&mw.mu) + go mw.run() + return mw +} + +func (mw *MuxWriter) run() { + defer close(mw.done) + spare := make([]byte, 0, defaultBufferCapacity) + for { + mw.mu.Lock() + for len(mw.buf) == 0 && !mw.closed { + mw.cond.Wait() + } + + if mw.closed { + mw.mu.Unlock() + return + } + + mw.buf, spare = spare, mw.buf + mw.mu.Unlock() + if _, err := mw.w.Write(spare); err != nil { + mw.mu.Lock() + if mw.closed { + mw.mu.Unlock() + return + } + mw.closed = true + mw.closeErr = err + mw.mu.Unlock() + if mw.onError != nil { + mw.onError(err) + } + return + } + + spare = spare[:0] + } +} + +func (mw *MuxWriter) WriteFrame(fr Frame) (err error) { + mw.mu.Lock() + defer mw.mu.Unlock() + if mw.closed { + return mw.closeErr + } + mw.buf = AppendFrame(mw.buf, fr) + mw.cond.Signal() + return nil +} + +func (mw *MuxWriter) Stop(err error) { + mw.mu.Lock() + if !mw.closed { + mw.closed = true + mw.closeErr = err + mw.cond.Broadcast() + } + mw.mu.Unlock() +} + +func (mw *MuxWriter) Done() <-chan struct{} { + return mw.done +} diff --git a/drpcwire/mux_writer_test.go b/drpcwire/mux_writer_test.go new file mode 100644 index 00000000..4ee75c51 --- /dev/null +++ b/drpcwire/mux_writer_test.go @@ -0,0 +1,330 @@ +// Copyright (C) 2021 Storj Labs, Inc. +// See LICENSE for copying information. + +package drpcwire + +import ( + "bytes" + "errors" + "io" + "sync" + "testing" + "time" + + "github.com/zeebo/assert" +) + +// blockingWriter blocks in Write until unblock is closed, then returns err. +type blockingWriter struct { + unblock chan struct{} + err error // error to return once unblocked + wrote chan []byte // sends a copy of data on each Write entry +} + +func newBlockingWriter() *blockingWriter { + return &blockingWriter{ + unblock: make(chan struct{}), + wrote: make(chan []byte, 10), + } +} + +func (w *blockingWriter) Write(p []byte) (int, error) { + cp := make([]byte, len(p)) + copy(cp, p) + w.wrote <- cp + <-w.unblock + if w.err != nil { + return 0, w.err + } + return len(p), nil +} + +// failWriter returns err on the nth call to Write (1-indexed). Calls before +// that succeed normally. +type failWriter struct { + n int + count int + err error + buf bytes.Buffer +} + +func newFailWriter(n int, err error) *failWriter { + return &failWriter{n: n, err: err} +} + +func (w *failWriter) Write(p []byte) (int, error) { + w.count++ + if w.count >= w.n { + return 0, w.err + } + return w.buf.Write(p) +} + +func TestMuxWriter(t *testing.T) { + var exp []byte + pr, pw := io.Pipe() + mw := NewMuxWriter(pw, func(error) {}) + + for range 1000 { + fr := RandFrame() + exp = AppendFrame(exp, fr) + assert.NoError(t, mw.WriteFrame(fr)) + } + + // Read exactly len(exp) bytes: this blocks until MuxWriter has drained + // all frames through the pipe. + got := make([]byte, len(exp)) + _, err := io.ReadFull(pr, got) + assert.NoError(t, err) + + // Now stop the writer and close the pipe. + mw.Stop(errors.New("stopped")) + <-mw.Done() + pw.Close() + pr.Close() + + assert.That(t, bytes.Equal(exp, got)) +} + +func TestMuxWriter_WriteFrameAfterStop(t *testing.T) { + mw := NewMuxWriter(io.Discard, func(error) {}) + mw.Stop(errors.New("stopped")) + <-mw.Done() + + err := mw.WriteFrame(RandFrame()) + assert.Error(t, err) + assert.Equal(t, err.Error(), "stopped") +} + +func TestMuxWriter_ConcurrentWriteFrame(t *testing.T) { + pr, pw := io.Pipe() + mw := NewMuxWriter(pw, func(error) {}) + + const numWriters = 10 + const framesPerWriter = 100 + + // Pre-generate frames and compute total expected bytes so we can use + // io.ReadFull to block until everything has drained (Stop has abort + // semantics, so we can't rely on it to drain). + allFrames := make([][]Frame, numWriters) + var expSize int + for i := range numWriters { + allFrames[i] = make([]Frame, framesPerWriter) + for j := range framesPerWriter { + fr := Frame{ + Data: []byte{byte(j)}, + ID: ID{Stream: uint64(i + 1), Message: uint64(j + 1)}, + Kind: KindMessage, + Done: true, + } + allFrames[i][j] = fr + expSize += len(AppendFrame(nil, fr)) + } + } + + var wg sync.WaitGroup + for i := range numWriters { + wg.Add(1) + go func() { + defer wg.Done() + for j := range framesPerWriter { + assert.NoError(t, mw.WriteFrame(allFrames[i][j])) + } + }() + } + + wg.Wait() + + // Block until all bytes have been drained through the pipe. + got := make([]byte, expSize) + _, err := io.ReadFull(pr, got) + assert.NoError(t, err) + mw.Stop(errors.New("stopped")) + <-mw.Done() + pw.Close() + pr.Close() + + // Parse received bytes and count frames. + count := 0 + for len(got) > 0 { + rem, _, ok, err := ParseFrame(got) + assert.NoError(t, err) + assert.That(t, ok) + got = rem + count++ + } + assert.Equal(t, count, numWriters*framesPerWriter) +} + +func TestMuxWriter_WriteErrorCallsOnError(t *testing.T) { + writeErr := errors.New("disk full") + fw := newFailWriter(1, writeErr) + + gotErr := make(chan error, 1) + mw := NewMuxWriter(fw, func(err error) { gotErr <- err }) + + assert.NoError(t, mw.WriteFrame(RandFrame())) + + select { + case err := <-gotErr: + assert.Equal(t, err, writeErr) + case <-time.After(5 * time.Second): + t.Fatal("onError not called") + } + + select { + case <-mw.Done(): + case <-time.After(5 * time.Second): + t.Fatal("Done did not return") + } +} + +// Tests the critical deadlock path from the design doc: +// run() → Write fails → sets closed → onError → Stop() → noop → run() returns. +func TestMuxWriter_OnErrorCallingStopDoesNotDeadlock(t *testing.T) { + writeErr := errors.New("broken pipe") + fw := newFailWriter(1, writeErr) + + var mw *MuxWriter + mw = NewMuxWriter(fw, func(err error) { + // Simulate manager.terminate calling Stop. + mw.Stop(errors.New("stopped")) + }) + + assert.NoError(t, mw.WriteFrame(RandFrame())) + + select { + case <-mw.Done(): + case <-time.After(5 * time.Second): + t.Fatal("deadlock: Done did not return") + } +} + +// Tests the manager's two-phase shutdown: close transport to unblock a blocked +// Write, then Stop signals the goroutine to exit. +func TestMuxWriter_BlockedWriteUnblockedByClose(t *testing.T) { + bw := newBlockingWriter() + mw := NewMuxWriter(bw, func(error) {}) + + assert.NoError(t, mw.WriteFrame(RandFrame())) + + // Wait for run() to enter Write. + select { + case <-bw.wrote: + case <-time.After(5 * time.Second): + t.Fatal("run() did not enter Write") + } + + // Simulate terminate: Stop, then unblock the writer (like tr.Close()). + mw.Stop(errors.New("stopped")) + bw.err = errors.New("closed") + close(bw.unblock) + + select { + case <-mw.Done(): + case <-time.After(5 * time.Second): + t.Fatal("deadlock: Done did not return") + } +} + +func TestMuxWriter_ConcurrentStop(t *testing.T) { + mw := NewMuxWriter(io.Discard, func(error) {}) + + // Write a frame so the goroutine has work. + assert.NoError(t, mw.WriteFrame(RandFrame())) + + const n = 20 + var wg sync.WaitGroup + wg.Add(n) + for range n { + go func() { + defer wg.Done() + mw.Stop(errors.New("stopped")) + }() + } + wg.Wait() + + select { + case <-mw.Done(): + case <-time.After(5 * time.Second): + t.Fatal("Done did not return") + } +} + +// Stop has abort semantics: buffered data is discarded, not drained. +func TestMuxWriter_StopDiscardsBufferedData(t *testing.T) { + bw := newBlockingWriter() + mw := NewMuxWriter(bw, func(error) {}) + + // Write several frames while the writer is blocked on the first Write. + for range 10 { + assert.NoError(t, mw.WriteFrame(RandFrame())) + } + + // Wait for run() to enter Write with the first batch. + select { + case <-bw.wrote: + case <-time.After(5 * time.Second): + t.Fatal("run() did not enter Write") + } + + // More frames accumulate in buf while Write is blocked. + for range 10 { + assert.NoError(t, mw.WriteFrame(RandFrame())) + } + + // Stop without letting the blocked Write complete. + mw.Stop(errors.New("stopped")) + bw.err = errors.New("closed") + close(bw.unblock) + + select { + case <-mw.Done(): + case <-time.After(5 * time.Second): + t.Fatal("Done did not return") + } + + // Only the first batch was written; the rest were discarded by Stop. + assert.Equal(t, len(bw.wrote), 0) // no more writes after the first +} + +func TestMuxWriter_WriteFrameDuringActiveDrain(t *testing.T) { + // gatedWriter lets us control when each Write completes. + type gate struct{ ch chan struct{} } + gates := make(chan gate, 10) + + gw := writerFunc(func(p []byte) (int, error) { + g := gate{ch: make(chan struct{})} + gates <- g + <-g.ch + return len(p), nil + }) + + mw := NewMuxWriter(gw, func(error) {}) + + // Batch 1: write a frame, wait for run() to pick it up and block in Write. + fr1 := Frame{Data: []byte("batch1"), ID: ID{Stream: 1, Message: 1}, Kind: KindMessage, Done: true} + assert.NoError(t, mw.WriteFrame(fr1)) + + g1 := <-gates // run() is now blocked in Write for batch 1 + + // Batch 2: write another frame while batch 1 is still draining. + fr2 := Frame{Data: []byte("batch2"), ID: ID{Stream: 1, Message: 2}, Kind: KindMessage, Done: true} + assert.NoError(t, mw.WriteFrame(fr2)) + + // Complete batch 1 write. + close(g1.ch) + + // run() loops, picks up batch 2, enters Write again. + g2 := <-gates + close(g2.ch) + + // Both batches were written. Stop and verify. + mw.Stop(errors.New("stopped")) + <-mw.Done() +} + +// writerFunc adapts a function to io.Writer. +type writerFunc func([]byte) (int, error) + +func (f writerFunc) Write(p []byte) (int, error) { return f(p) } diff --git a/drpcwire/packet_assembler.go b/drpcwire/packet_assembler.go new file mode 100644 index 00000000..2cf7fae3 --- /dev/null +++ b/drpcwire/packet_assembler.go @@ -0,0 +1,89 @@ +package drpcwire + +import ( + "storj.io/drpc" +) + +// PacketAssembler assembles frames into complete packets, enforcing wire +// protocol invariants: +// - All frames must belong to the same stream ID (set explicitly via +// SetStreamID, or inferred from the first frame). +// - Message IDs must be monotonically increasing. +// - Frame kind must not change within a single packet (multi-frame). +// +// It is not safe for concurrent use. +type PacketAssembler struct { + pk Packet + assembling bool + streamInitialized bool +} + +// NewPacketAssembler returns a new PacketAssembler ready to assemble frames. +func NewPacketAssembler() PacketAssembler { + return PacketAssembler{ + pk: Packet{ + ID: ID{Stream: 0, Message: 1}, + }, + } +} + +// SetStreamID sets the expected stream ID. Frames for a different stream will +// be rejected. If not called, the stream ID is inferred from the first frame. +func (pa *PacketAssembler) SetStreamID(streamID uint64) { + pa.pk.ID.Stream = streamID + pa.streamInitialized = true +} + +// Reset clears all assembly state, preparing the assembler for a new stream. +func (pa *PacketAssembler) Reset() { + pa.pk = Packet{ + ID: ID{Stream: 0, Message: 1}, + } + pa.assembling = false + pa.streamInitialized = false +} + +// AppendFrame adds a frame to the in-progress packet. It returns the completed +// packet and true when a frame with Done=true is received. It returns false +// when more frames are needed to complete the packet. +func (pa *PacketAssembler) AppendFrame(fr Frame) (packet Packet, packetReady bool, err error) { + // Enforce stream ID consistency: infer from first frame or reject mismatches. + if !pa.streamInitialized { + pa.pk.ID.Stream = fr.ID.Stream + pa.streamInitialized = true + } else if fr.ID.Stream != pa.pk.ID.Stream { + return Packet{}, false, drpc.ProtocolError.New( + "frame stream mismatch: got stream %d, expected %d", fr.ID.Stream, pa.pk.ID.Stream) + } + + if fr.ID.Message < pa.pk.ID.Message { + return Packet{}, false, drpc.ProtocolError.New( + "message id monotonicity violation: got %v, expected >= %v", fr.ID.Message, pa.pk.ID.Message) + } else if fr.ID.Message > pa.pk.ID.Message || !pa.assembling { + // New message: reset the buffer and start assembling. + pa.pk.Data = pa.pk.Data[:0] + pa.assembling = true + pa.pk.ID.Message = fr.ID.Message + } else if fr.Kind != pa.pk.Kind { + return Packet{}, false, drpc.ProtocolError.New( + "frame kind changed mid-packet: got %v, expected %v", fr.Kind, pa.pk.Kind) + } + + // TODO(shubham): add buf reuse + pa.pk.Data = append(pa.pk.Data, fr.Data...) + pa.pk.Kind = fr.Kind + pa.pk.Control = fr.Control + + if !fr.Done { + return Packet{}, false, nil + } + + packet = pa.pk + + pa.assembling = false + pa.pk.ID.Message = fr.ID.Message + 1 + // Reuse the backing array: the caller must consume packet.Data before the + // next AppendFrame call, as it will be overwritten. + pa.pk.Data = pa.pk.Data[:0] + return packet, true, nil +} diff --git a/drpcwire/packet_assembler_test.go b/drpcwire/packet_assembler_test.go new file mode 100644 index 00000000..41cf70da --- /dev/null +++ b/drpcwire/packet_assembler_test.go @@ -0,0 +1,235 @@ +// Copyright (C) 2026 Cockroach Labs. +// See LICENSE for copying information. + +package drpcwire + +import ( + "strings" + "testing" + + "github.com/zeebo/assert" + + "storj.io/drpc" +) + +func TestPacketAssembler_WrongStreamID(t *testing.T) { + pa := NewPacketAssembler() + pa.SetStreamID(1) + + _, _, err := pa.AppendFrame(Frame{ + ID: ID{Stream: 2, Message: 1}, + Kind: KindMessage, + Done: true, + }) + assert.Error(t, err) + assert.That(t, drpc.ProtocolError.Has(err)) + assert.That(t, strings.Contains(err.Error(), "frame stream mismatch")) +} + +func TestPacketAssembler_StreamIDInferredFromFirstFrame(t *testing.T) { + pa := NewPacketAssembler() + + // First frame sets the stream ID implicitly. + _, _, err := pa.AppendFrame(Frame{ + ID: ID{Stream: 5, Message: 1}, + Kind: KindMessage, + Done: true, + }) + assert.NoError(t, err) + + // Second frame for a different stream is rejected. + _, _, err = pa.AppendFrame(Frame{ + ID: ID{Stream: 6, Message: 2}, + Kind: KindMessage, + Done: true, + }) + assert.Error(t, err) + assert.That(t, strings.Contains(err.Error(), "frame stream mismatch")) +} + +// A frame with a message ID lower than a previously completed message is rejected. +func TestPacketAssembler_MessageMonotonicity(t *testing.T) { + pa := NewPacketAssembler() + pa.SetStreamID(1) + + // m3 completes, next expected becomes 4. + _, _, err := pa.AppendFrame(Frame{ + ID: ID{Stream: 1, Message: 3}, Kind: KindMessage, Done: true, + }) + assert.NoError(t, err) + + // m2 < 4 → error. + _, _, err = pa.AppendFrame(Frame{ + ID: ID{Stream: 1, Message: 2}, Kind: KindMessage, Done: true, + }) + assert.Error(t, err) + assert.That(t, drpc.ProtocolError.Has(err)) + assert.That(t, strings.Contains(err.Error(), "monotonicity")) +} + +// When a higher message ID arrives mid-assembly, the in-progress data is +// silently discarded and a new packet begins. +func TestPacketAssembler_HigherMsgDiscardsInProgress(t *testing.T) { + pa := NewPacketAssembler() + pa.SetStreamID(1) + + // Start accumulating m1. + _, ready, err := pa.AppendFrame(Frame{ + ID: ID{Stream: 1, Message: 1}, Kind: KindMessage, Data: []byte("discard"), Done: false, + }) + assert.NoError(t, err) + assert.That(t, !ready) + + // m2 arrives, m1 data should be silently discarded. + pkt, ready, err := pa.AppendFrame(Frame{ + ID: ID{Stream: 1, Message: 2}, Kind: KindMessage, Data: []byte("kept"), Done: true, + }) + assert.NoError(t, err) + assert.That(t, ready) + assert.DeepEqual(t, pkt.Data, []byte("kept")) +} + +// Continuation frames (same message ID, mid-assembly) must carry the same +// kind as the first frame. A kind change mid-packet is a protocol error. +func TestPacketAssembler_KindChangeWithinPacket(t *testing.T) { + pa := NewPacketAssembler() + pa.SetStreamID(1) + + _, _, err := pa.AppendFrame(Frame{ + ID: ID{Stream: 1, Message: 1}, Kind: KindMessage, Done: false, + }) + assert.NoError(t, err) + + _, _, err = pa.AppendFrame(Frame{ + ID: ID{Stream: 1, Message: 1}, Kind: KindError, Done: true, + }) + assert.Error(t, err) + assert.That(t, drpc.ProtocolError.Has(err)) + assert.That(t, strings.Contains(err.Error(), "kind change")) +} + +// Multiple continuation frames for the same message accumulate data correctly. +func TestPacketAssembler_MultiFrameDataAccumulation(t *testing.T) { + pa := NewPacketAssembler() + pa.SetStreamID(1) + + _, ready, err := pa.AppendFrame(Frame{ + ID: ID{Stream: 1, Message: 1}, Kind: KindMessage, Data: []byte("hel"), Done: false, + }) + assert.NoError(t, err) + assert.That(t, !ready) + + _, ready, err = pa.AppendFrame(Frame{ + ID: ID{Stream: 1, Message: 1}, Kind: KindMessage, Data: []byte("lo "), Done: false, + }) + assert.NoError(t, err) + assert.That(t, !ready) + + pkt, ready, err := pa.AppendFrame(Frame{ + ID: ID{Stream: 1, Message: 1}, Kind: KindMessage, Data: []byte("world"), Done: true, + }) + assert.NoError(t, err) + assert.That(t, ready) + assert.DeepEqual(t, pkt.Data, []byte("hello world")) +} + +// Multi-frame assembly works when the message ID is greater than the initial +// expected ID (e.g., on the server side where invoke consumed earlier message +// IDs). Continuation frames must accumulate data, not reset on each frame. +func TestPacketAssembler_MultiFrameWithSkippedMessageID(t *testing.T) { + pa := NewPacketAssembler() + pa.SetStreamID(1) + + // msg=3 is greater than initial expected message ID=1. + _, ready, err := pa.AppendFrame(Frame{ + ID: ID{Stream: 1, Message: 3}, Kind: KindMessage, Data: []byte("hel"), Done: false, + }) + assert.NoError(t, err) + assert.That(t, !ready) + + _, ready, err = pa.AppendFrame(Frame{ + ID: ID{Stream: 1, Message: 3}, Kind: KindMessage, Data: []byte("lo"), Done: false, + }) + assert.NoError(t, err) + assert.That(t, !ready) + + pkt, ready, err := pa.AppendFrame(Frame{ + ID: ID{Stream: 1, Message: 3}, Kind: KindMessage, Data: []byte(" world"), Done: true, + }) + assert.NoError(t, err) + assert.That(t, ready) + assert.DeepEqual(t, pkt.Data, []byte("hello world")) +} + +// Once a message completes (done=true), the same message ID is rejected. +func TestPacketAssembler_DonePreventsReplay(t *testing.T) { + pa := NewPacketAssembler() + pa.SetStreamID(1) + + // m1 completes → next expected becomes 2. + _, _, err := pa.AppendFrame(Frame{ + ID: ID{Stream: 1, Message: 1}, Kind: KindMessage, Done: true, + }) + assert.NoError(t, err) + + // Same message ID again → error. + _, _, err = pa.AppendFrame(Frame{ + ID: ID{Stream: 1, Message: 1}, Kind: KindMessage, Done: true, + }) + assert.Error(t, err) + assert.That(t, drpc.ProtocolError.Has(err)) + assert.That(t, strings.Contains(err.Error(), "monotonicity")) +} + +// Kind consistency is only enforced within a packet (continuation frames), not +// across messages. A KindMessage followed by a KindClose for the next message +// should be accepted without error. +func TestPacketAssembler_KindChangeAcrossMessages(t *testing.T) { + pa := NewPacketAssembler() + pa.SetStreamID(1) + + // Multi-frame message 1 with KindMessage. + _, _, err := pa.AppendFrame(Frame{ + ID: ID{Stream: 1, Message: 1}, Kind: KindMessage, Data: []byte("ab"), Done: false, + }) + assert.NoError(t, err) + + pkt, ready, err := pa.AppendFrame(Frame{ + ID: ID{Stream: 1, Message: 1}, Kind: KindMessage, Data: []byte("cd"), Done: true, + }) + assert.NoError(t, err) + assert.That(t, ready) + assert.DeepEqual(t, pkt.Data, []byte("abcd")) + + // Message 2 with a different kind — should not trigger kind check. + pkt, ready, err = pa.AppendFrame(Frame{ + ID: ID{Stream: 1, Message: 2}, Kind: KindClose, Done: true, + }) + assert.NoError(t, err) + assert.That(t, ready) + assert.Equal(t, pkt.Kind, KindClose) +} + +// Reset clears all state so the assembler can be reused for a new stream. +func TestPacketAssembler_Reset(t *testing.T) { + pa := NewPacketAssembler() + pa.SetStreamID(1) + + // Complete a packet on stream 1. + _, _, err := pa.AppendFrame(Frame{ + ID: ID{Stream: 1, Message: 1}, Kind: KindMessage, Done: true, + }) + assert.NoError(t, err) + + // After reset, stream ID is cleared and must be re-inferred. + pa.Reset() + + // A frame for stream 2 should now be accepted. + pkt, ready, err := pa.AppendFrame(Frame{ + ID: ID{Stream: 2, Message: 1}, Kind: KindMessage, Data: []byte("new"), Done: true, + }) + assert.NoError(t, err) + assert.That(t, ready) + assert.DeepEqual(t, pkt.Data, []byte("new")) + assert.Equal(t, pkt.ID.Stream, uint64(2)) +} diff --git a/drpcwire/reader.go b/drpcwire/reader.go index c9ac3979..d5ab5800 100644 --- a/drpcwire/reader.go +++ b/drpcwire/reader.go @@ -16,13 +16,12 @@ type ReaderOptions struct { MaximumBufferSize int } -// Reader reconstructs packets from frames read from an io.Reader. +// Reader reads frames from an io.Reader. type Reader struct { opts ReaderOptions r io.Reader curr []byte buf []byte - id ID rerr error } @@ -35,12 +34,12 @@ type Reader struct { // 9: maximum varint data length const maxFrameOverhead = 1 + 9 + 9 + 9 -// NewReader constructs a Reader to read Packets from the io.Reader. +// NewReader constructs a Reader to read Frames from the io.Reader. func NewReader(r io.Reader) *Reader { return NewReaderWithOptions(r, ReaderOptions{}) } -// NewReaderWithOptions constructs a Reader to read Packets from +// NewReaderWithOptions constructs a Reader to read Frames from // the io.Reader. It uses the provided options to manage buffering. func NewReaderWithOptions(r io.Reader, opts ReaderOptions) *Reader { if opts.MaximumBufferSize == 0 { @@ -50,10 +49,9 @@ func NewReaderWithOptions(r io.Reader, opts ReaderOptions) *Reader { return &Reader{ opts: opts, r: r, - // Err on the side of a smaller buffer since ReadPacket will lazily + // Err on the side of a smaller buffer since ReadFrame will lazily // grow this buffer. curr: make([]byte, 0, 4096), - id: ID{Stream: 1, Message: 1}, } } @@ -76,29 +74,14 @@ func (r *Reader) read(p []byte) (n int, err error) { return 0, drpc.InternalError.Wrap(io.ErrNoProgress) } -// ReadPacket reads a packet from the io.Reader. It is equivalent to -// calling ReadPacketUsing(nil). -func (r *Reader) ReadPacket() (pkt Packet, err error) { - return r.ReadPacketUsing(nil) -} - -// ReadPacketUsing reads a packet from the io.Reader. IDs read from -// frames must be monotonically increasing. When a new ID is read, the -// old data is discarded. This allows for easier asynchronous interrupts. -// If the amount of data in the Packet becomes too large, an error is -// returned. The returned packet's Data field is constructed by appending -// to the provided buf after it has been resliced to be zero length. -func (r *Reader) ReadPacketUsing(buf []byte) (pkt Packet, err error) { - pkt.Data = buf[:0] - - var fr Frame - var ok bool - +// ReadFrame reads a single frame from the io.Reader. +func (r *Reader) ReadFrame() (fr Frame, err error) { for { + var ok bool r.curr, fr, ok, err = ParseFrame(r.curr) switch { case err != nil: - return Packet{}, drpc.ProtocolError.Wrap(err) + return Frame{}, drpc.ProtocolError.Wrap(err) case !ok: // r.curr doesn't have enough data for a full frame, so prepend @@ -115,62 +98,28 @@ func (r *Reader) ReadPacketUsing(buf []byte) (pkt Packet, err error) { n, err := r.read(r.buf[len(r.buf):cap(r.buf)]) if err != nil { - return Packet{}, err + return Frame{}, err } ncap := uint(len(r.buf) + n) if ncap > uint(cap(r.buf)) { - return Packet{}, drpc.ProtocolError.New("data overflow") + return Frame{}, drpc.ProtocolError.New("data overflow") } r.buf = r.buf[:ncap] if len(r.buf)-maxFrameOverhead > r.opts.MaximumBufferSize { - return Packet{}, drpc.ProtocolError.New("data overflow") + return Frame{}, drpc.ProtocolError.New("data overflow") } r.curr = r.buf continue } - // since we got a packet, signal that we need to restore buf with - // whatever remains in r.curr the next time we don't have a packet. + // since we got a frame, signal that we need to restore buf with + // whatever remains in r.curr the next time we don't have a frame. if len(r.buf) > 0 { r.buf = r.buf[:0] } - - // If any frames are set to control, then the whole packet is - // considered to be control. - pkt.Control = pkt.Control || fr.Control - - switch { - case fr.ID.Less(r.id): - return Packet{}, drpc.ProtocolError.New("id monotonicity violation (fr:%v r:%v)", fr.ID, r.id) - - case r.id != fr.ID || pkt.ID == ID{}: - r.id = fr.ID - - pkt = Packet{ - Data: pkt.Data[:0], - ID: fr.ID, - Kind: fr.Kind, - Control: fr.Control, - } - - case fr.Kind != pkt.Kind: - return Packet{}, drpc.ProtocolError.New("packet kind change (fr:%v pkt:%v)", fr.Kind, pkt.Kind) - } - - pkt.Data = append(pkt.Data, fr.Data...) - - switch { - case len(pkt.Data) > r.opts.MaximumBufferSize: - return Packet{}, drpc.ProtocolError.New("data overflow (len:%v)", len(pkt.Data)) - - case fr.Done: - // increment the message id so that we do not accept any frames - // with the same id. - r.id.Message++ - return pkt, nil - } + return fr, nil } } diff --git a/drpcwire/reader_test.go b/drpcwire/reader_test.go index d57145b8..4a551ff2 100644 --- a/drpcwire/reader_test.go +++ b/drpcwire/reader_test.go @@ -15,176 +15,145 @@ import ( "github.com/zeebo/assert" ) -func TestReader(t *testing.T) { - type testCase struct { - Packets []Packet - Frames []Frame - Error string - Options ReaderOptions - } - - p := func(kind Kind, id uint64, control bool, data string) Packet { - return Packet{ - Data: []byte(data), - ID: ID{Stream: 1, Message: id}, - Kind: kind, - Control: control, - } - } - - f := func(kind Kind, id uint64, data string, done, control bool) Frame { +func TestReadFrame(t *testing.T) { + f := func(kind Kind, sid, mid uint64, data string, done, control bool) Frame { return Frame{ Data: []byte(data), - ID: ID{Stream: 1, Message: id}, + ID: ID{Stream: sid, Message: mid}, Kind: kind, Done: done, Control: control, } } - m := func(pkt Packet, frames ...Frame) testCase { - return testCase{ - Packets: []Packet{pkt}, - Frames: frames, + t.Run("SingleFrame", func(t *testing.T) { + fr := f(KindMessage, 1, 1, "hello", true, false) + + buf := AppendFrame(nil, fr) + rd := NewReader(bytes.NewReader(buf)) + + got, err := rd.ReadFrame() + assert.NoError(t, err) + assert.DeepEqual(t, got, fr) + + _, err = rd.ReadFrame() + assert.That(t, errors.Is(err, io.EOF)) + }) + + t.Run("MultipleFrames", func(t *testing.T) { + // Frames are returned individually even when they share a message + // ID and have done=false. Reader does no assembly — that's the + // stream's job. + frames := []Frame{ + f(KindMessage, 1, 1, "hello", false, false), + f(KindMessage, 1, 1, " ", false, false), + f(KindMessage, 1, 1, "world", true, false), + f(KindClose, 1, 2, "", true, false), } - } - megaFrames := make([]Frame, 0, 10*1024) - for i := 0; i < 10*1024; i++ { - megaFrames = append(megaFrames, f(KindMessage, 1, strings.Repeat("X", 1024), false, false)) - } - megaFrames = append(megaFrames, f(KindMessage, 1, "", true, false)) - - // 1 more than the maximum frame overhead is the minimum required to overflow - const overFrame = maxFrameOverhead + 1 - - cases := []testCase{ - m(p(KindMessage, 1, false, "hello world"), - f(KindMessage, 1, "hello", false, false), - f(KindMessage, 1, " ", false, false), - f(KindMessage, 1, "world", true, false)), - - m(p(KindMessage, 1, true, "hello world"), - f(KindMessage, 1, "hello", false, false), - f(KindMessage, 1, " ", false, true), - f(KindMessage, 1, "world", true, false)), - - m(p(KindClose, 2, false, ""), - f(KindMessage, 1, "hello", false, false), - f(KindMessage, 1, " ", false, false), - f(KindClose, 2, "", true, false)), - - { - Packets: []Packet{ - p(KindClose, 2, false, ""), - }, - Frames: []Frame{ - f(KindMessage, 1, "1", false, false), - f(KindClose, 2, "", true, false), - f(KindMessage, 1, "1", true, false), - }, - Error: "id monotonicity violation", - }, - - { // a single frame that's too large - Frames: []Frame{f(KindMessage, 1, strings.Repeat("X", 4<<20+overFrame), true, false)}, - Error: "data overflow", - }, - - { // a single frame that's too large with limited size - Frames: []Frame{f(KindMessage, 1, strings.Repeat("X", 1000+overFrame), true, false)}, - Error: "data overflow", - Options: ReaderOptions{MaximumBufferSize: 1000}, - }, - - { // multiple frames that make too large a packet - Frames: megaFrames, - Error: "data overflow", - }, - - { // multiple frames that make too large a packet with limited size - Frames: []Frame{ - f(KindMessage, 1, strings.Repeat("X", 500), false, false), - f(KindMessage, 1, strings.Repeat("X", 400), false, false), - f(KindMessage, 1, strings.Repeat("X", 100), false, false), - f(KindMessage, 1, strings.Repeat("X", overFrame), true, false), - }, - Error: "data overflow", - Options: ReaderOptions{MaximumBufferSize: 1000}, - }, - - { // Control bit is preserved - Packets: []Packet{ - p(KindClose, 2, false, ""), - p(KindMessage, 3, true, "ab"), - }, - Frames: []Frame{ - f(KindMessage, 1, "1", false, false), - f(KindClose, 2, "", true, false), - f(KindMessage, 3, "a", false, true), - f(KindMessage, 3, "b", true, false), - }, - }, - - { // packet kind changes - Frames: []Frame{ - f(KindMessage, 1, "", false, false), - f(KindClose, 1, "", false, false), - }, - Error: "packet kind change", - }, - - { // id monotonicity from id reuse - Packets: []Packet{ - p(KindMessage, 1, false, "1"), - }, - Frames: []Frame{ - f(KindMessage, 1, "1", true, false), - f(KindMessage, 1, "2", true, false), - }, - Error: "id monotonicity violation", - }, - - { // message id zero is not allowed - Frames: []Frame{{ID: ID{Stream: 1, Message: 0}}}, - Error: "id monotonicity violation", - }, - - { // stream id zero is not allowed - Frames: []Frame{{ID: ID{Stream: 0, Message: 1}}}, - Error: "id monotonicity violation", - }, - } + var buf []byte + for _, fr := range frames { + buf = AppendFrame(buf, fr) + } + + rd := NewReader(bytes.NewReader(buf)) + for _, exp := range frames { + got, err := rd.ReadFrame() + assert.NoError(t, err) + assert.DeepEqual(t, got, exp) + } + + _, err := rd.ReadFrame() + assert.That(t, errors.Is(err, io.EOF)) + }) + + t.Run("NoMonotonicity", func(t *testing.T) { + // Reader no longer enforces monotonicity. Frames with decreasing + // IDs should be returned without error. + frames := []Frame{ + f(KindMessage, 1, 5, "a", true, false), + f(KindMessage, 1, 3, "b", true, false), + } - for _, tc := range cases { var buf []byte - for _, fr := range tc.Frames { + for _, fr := range frames { buf = AppendFrame(buf, fr) } - rd := NewReaderWithOptions(bytes.NewReader(buf), tc.Options) - for _, expPkt := range tc.Packets { - pkt, err := rd.ReadPacket() + rd := NewReader(bytes.NewReader(buf)) + for _, exp := range frames { + got, err := rd.ReadFrame() assert.NoError(t, err) - assert.DeepEqual(t, expPkt, pkt) + assert.DeepEqual(t, got, exp) } + }) + + t.Run("BufferOverflow_SingleLargeFrame", func(t *testing.T) { + // 1 more than the maximum frame overhead is the minimum required to overflow. + const overFrame = maxFrameOverhead + 1 + fr := f(KindMessage, 1, 1, strings.Repeat("X", 4<<20+overFrame), true, false) - _, err := rd.ReadPacket() + buf := AppendFrame(nil, fr) + rd := NewReader(bytes.NewReader(buf)) + + _, err := rd.ReadFrame() assert.Error(t, err) - if tc.Error != "" { - assert.That(t, strings.Contains(err.Error(), tc.Error)) - } else { - assert.Equal(t, err, io.EOF) - } - } + assert.That(t, strings.Contains(err.Error(), "data overflow")) + }) + + t.Run("BufferOverflow_CustomLimit", func(t *testing.T) { + const overFrame = maxFrameOverhead + 1 + fr := f(KindMessage, 1, 1, strings.Repeat("X", 1000+overFrame), true, false) + + buf := AppendFrame(nil, fr) + rd := NewReaderWithOptions(bytes.NewReader(buf), ReaderOptions{MaximumBufferSize: 1000}) + + _, err := rd.ReadFrame() + assert.Error(t, err) + assert.That(t, strings.Contains(err.Error(), "data overflow")) + }) + + t.Run("ErrorWithData", func(t *testing.T) { + // If the underlying reader returns data and an error together, + // the frame should still be parsed from the data. + rd := NewReader(readerFunc(func(b []byte) (int, error) { + out := AppendFrame(b[:0:8], Frame{ + Data: []byte("test"), + ID: ID{1, 1}, + Kind: KindMessage, + Done: true, + }) + return len(out), io.EOF + })) + + got, err := rd.ReadFrame() + assert.NoError(t, err) + assert.DeepEqual(t, got, Frame{ + Data: []byte("test"), + ID: ID{1, 1}, + Kind: KindMessage, + Done: true, + }) + + _, err = rd.ReadFrame() + assert.That(t, errors.Is(err, io.EOF)) + }) + + t.Run("ErrorNoProgress", func(t *testing.T) { + rd := NewReader(readerFunc(func(b []byte) (int, error) { + return 0, nil + })) + + _, err := rd.ReadFrame() + assert.That(t, errors.Is(err, io.ErrNoProgress)) + }) } -func TestReaderRandomized(t *testing.T) { +func TestReadFrame_Randomized(t *testing.T) { seed := time.Now().UnixNano() t.Log("seed:", seed) rng := rand.New(rand.NewSource(seed)) - // create a function to get a predefined sequence of bytes bid := 0 get := func(n int) []byte { out := make([]byte, n) @@ -195,75 +164,40 @@ func TestReaderRandomized(t *testing.T) { return out } - // construct a random sequence of frames of different sizes - // to attempt to capture any bugs from buffer management + // Build a random sequence of frames with varying sizes. + var frames []Frame var buf []byte mid := uint64(1) done := false for i := 0; i < 1000; i++ { - buf = AppendFrame(buf, Frame{ + data := get(rng.Intn(8192)) + fr := Frame{ ID: ID{Stream: 1, Message: mid}, - Data: get(rng.Intn(8192)), + Data: data, Done: done, - }) + } + frames = append(frames, fr) + buf = AppendFrame(buf, fr) if done { mid++ } - done = rng.Intn(10) == 0 } - // read all of the packets back which should have the - // exact sequence of bytes, so we reset bid to generate - // the sequence again. + // ReadFrame should return each frame individually. bid = 0 r := NewReader(bytes.NewBuffer(buf)) - for i := 1; ; i++ { - pkt, err := r.ReadPacket() - if errors.Is(err, io.EOF) { - break - } + for _, exp := range frames { + got, err := r.ReadFrame() assert.NoError(t, err) - assert.Equal(t, pkt.ID.Message, i) - assert.Equal(t, pkt.Data, get(len(pkt.Data))) + assert.Equal(t, got.ID, exp.ID) + assert.Equal(t, got.Done, exp.Done) + assert.Equal(t, got.Data, get(len(exp.Data))) } } type readerFunc func([]byte) (int, error) func (fn readerFunc) Read(p []byte) (int, error) { return fn(p) } - -func TestReaderErrorWithData(t *testing.T) { - r := NewReader(readerFunc(func(b []byte) (int, error) { - out := AppendFrame(b[:0:8], Frame{ - Data: []byte("test"), - ID: ID{1, 1}, - Kind: KindMessage, - Done: true, - }) - return len(out), io.EOF - })) - - pkt, err := r.ReadPacket() - assert.NoError(t, err) - assert.Equal(t, pkt, Packet{ - Data: []byte("test"), - ID: ID{1, 1}, - Kind: KindMessage, - Control: false, - }) - - _, err = r.ReadPacket() - assert.Equal(t, err, io.EOF) -} - -func TestReaderErrorNoProgress(t *testing.T) { - r := NewReader(readerFunc(func(b []byte) (int, error) { - return 0, nil - })) - - _, err := r.ReadPacket() - assert.That(t, errors.Is(err, io.ErrNoProgress)) -} diff --git a/drpcwire/writer.go b/drpcwire/writer.go deleted file mode 100644 index fe909cff..00000000 --- a/drpcwire/writer.go +++ /dev/null @@ -1,107 +0,0 @@ -// Copyright (C) 2019 Storj Labs, Inc. -// See LICENSE for copying information. - -package drpcwire - -import ( - "fmt" - "io" - "sync" - "sync/atomic" - - "storj.io/drpc/drpcdebug" -) - -// -// Writer -// - -// Writer is a helper to buffer and write packets and frames to an io.Writer. -type Writer struct { - empty uint32 - w io.Writer - size int - mu sync.Mutex - buf []byte -} - -// NewWriter returns a Writer that will attempt to buffer size data before -// sending it to the io.Writer. -func NewWriter(w io.Writer, size int) *Writer { - if size == 0 { - size = 4 * 1024 - } - - return &Writer{ - w: w, - size: size, - buf: make([]byte, 0, size), - } -} - -func (b *Writer) log(what string, cb func() string) { - if drpcdebug.Enabled { - drpcdebug.Log(func() (_, _, _ string) { return fmt.Sprintf("", b), what, cb() }) - } -} - -// WritePacket writes the packet as a single frame, ignoring any size -// constraints. -func (b *Writer) WritePacket(pkt Packet) (err error) { - return b.WriteFrame(Frame{ - Data: pkt.Data, - ID: pkt.ID, - Kind: pkt.Kind, - Control: pkt.Control, - Done: true, - }) -} - -// Empty returns true if there are no bytes buffered in the writer. -func (b *Writer) Empty() bool { - return atomic.LoadUint32(&b.empty) == 0 -} - -// Reset clears any pending data in the buffer. -func (b *Writer) Reset() *Writer { - b.mu.Lock() - defer b.mu.Unlock() - - b.buf = b.buf[:0] - atomic.StoreUint32(&b.empty, 0) - return b -} - -// WriteFrame appends the frame into the buffer, and if the buffer is larger -// than the configured size, flushes it. -func (b *Writer) WriteFrame(fr Frame) (err error) { - b.mu.Lock() - defer b.mu.Unlock() - - if len(b.buf) == 0 { - atomic.StoreUint32(&b.empty, 1) - } - b.buf = AppendFrame(b.buf, fr) - if len(b.buf) >= b.size { - b.log("FLUSH", func() string { return fmt.Sprintf("buffer: %d > %d", len(b.buf), b.size) }) - _, err = b.w.Write(b.buf) - b.buf = b.buf[:0] - atomic.StoreUint32(&b.empty, 0) - } - return err -} - -// Flush forces a flush of any buffered data to the io.Writer. It is a no-op if -// there is no data in the buffer. -func (b *Writer) Flush() (err error) { - b.mu.Lock() - defer b.mu.Unlock() - - if len(b.buf) > 0 { - _, err = b.w.Write(b.buf) - b.log("FLUSH", func() string { return fmt.Sprintf("explicit: %d", len(b.buf)) }) - b.buf = b.buf[:0] - atomic.StoreUint32(&b.empty, 0) - } - return err -} diff --git a/drpcwire/writer_test.go b/drpcwire/writer_test.go deleted file mode 100644 index 840a6164..00000000 --- a/drpcwire/writer_test.go +++ /dev/null @@ -1,32 +0,0 @@ -// Copyright (C) 2021 Storj Labs, Inc. -// See LICENSE for copying information. - -package drpcwire - -import ( - "bytes" - "testing" - - "github.com/zeebo/assert" -) - -func TestWriter(t *testing.T) { - run := func(size int) func(t *testing.T) { - return func(t *testing.T) { - var exp []byte - var got bytes.Buffer - - wr := NewWriter(&got, size) - for i := 0; i < 1000; i++ { - fr := RandFrame() - exp = AppendFrame(exp, fr) - assert.NoError(t, wr.WriteFrame(fr)) - } - assert.NoError(t, wr.Flush()) - assert.That(t, bytes.Equal(exp, got.Bytes())) - } - } - - t.Run("Size 0B", run(0)) - t.Run("Size 1MB", run(1024*1024)) -} diff --git a/internal/drpcopts/stream.go b/internal/drpcopts/stream.go index 6ab0511a..15fadac6 100644 --- a/internal/drpcopts/stream.go +++ b/internal/drpcopts/stream.go @@ -23,12 +23,6 @@ func GetStreamTransport(opts *Stream) drpc.Transport { return opts.transport } // SetStreamTransport sets the drpc.Transport stored in the options. func SetStreamTransport(opts *Stream, tr drpc.Transport) { opts.transport = tr } -// GetStreamFin returns the chan<- struct{} stored in the options. -func GetStreamFin(opts *Stream) chan<- struct{} { return opts.fin } - -// SetStreamFin sets the chan<- struct{} stored in the options. -func SetStreamFin(opts *Stream, fin chan<- struct{}) { opts.fin = fin } - // GetStreamKind returns the StreamKind stored in the options. func GetStreamKind(opts *Stream) drpc.StreamKind { return opts.kind } diff --git a/internal/grpccompat/benchmark_test.go b/internal/grpccompat/benchmark_test.go index 3bd0af71..c59f59a8 100644 --- a/internal/grpccompat/benchmark_test.go +++ b/internal/grpccompat/benchmark_test.go @@ -13,7 +13,6 @@ import ( "google.golang.org/protobuf/proto" "storj.io/drpc/drpcmanager" - "storj.io/drpc/drpcstream" ) var benchmarkImpl = &serviceImpl{ @@ -65,11 +64,7 @@ var benchmarkImpl = &serviceImpl{ } func benchmarkBoth(b *testing.B, fn func(b *testing.B, in *In, client Client)) { - options := drpcmanager.Options{ - Stream: drpcstream.Options{ - ManualFlush: true, - }, - } + options := drpcmanager.Options{} for _, size := range []struct { Name string diff --git a/internal/grpccompat/common_test.go b/internal/grpccompat/common_test.go index 8cc5c777..65bb2c9b 100644 --- a/internal/grpccompat/common_test.go +++ b/internal/grpccompat/common_test.go @@ -16,7 +16,6 @@ import ( "testing" "time" - "github.com/zeebo/assert" "github.com/zeebo/errs" "google.golang.org/grpc" "google.golang.org/grpc/codes" diff --git a/internal/integration/cache_test.go b/internal/integration/cache_test.go deleted file mode 100644 index d4e1c959..00000000 --- a/internal/integration/cache_test.go +++ /dev/null @@ -1,63 +0,0 @@ -// Copyright (C) 2020 Storj Labs, Inc. -// See LICENSE for copying information. - -package integration - -import ( - "context" - "testing" - - "github.com/zeebo/assert" - "github.com/zeebo/errs" - - "storj.io/drpc/drpccache" - "storj.io/drpc/drpctest" -) - -func TestCache(t *testing.T) { - ctx := drpctest.NewTracker(t) - defer ctx.Close() - - // create a server that signals then waits for the context to die - cli, close := createConnection(t, impl{ - Method1Fn: func(ctx context.Context, _ *In) (*Out, error) { - cache := drpccache.FromContext(ctx) - if cache == nil { - return nil, errs.New("no cache associated with context") - } - cache.LoadOrCreate("value", func() interface{} { return 42 }) - return &Out{Out: 123}, nil - }, - Method2Fn: func(stream DRPCService_Method2Stream) error { - cache := drpccache.FromContext(stream.Context()) - if cache == nil { - return errs.New("no cache associated with context") - } - value, _ := cache.Load("value").(int) - return stream.SendAndClose(&Out{Out: int64(value)}) - }, - }) - defer close() - - { // value not yet cached - stream, err := cli.Method2(ctx) - assert.NoError(t, err) - out, err := stream.CloseAndRecv() - assert.NoError(t, err) - assert.True(t, Equal(out, &Out{Out: 0})) - } - - { // store value in the cache - out, err := cli.Method1(ctx, in(1)) - assert.NoError(t, err) - assert.True(t, Equal(out, &Out{Out: 123})) - } - - { // expected value in the cache - stream, err := cli.Method2(ctx) - assert.NoError(t, err) - out, err := stream.CloseAndRecv() - assert.NoError(t, err) - assert.True(t, Equal(out, &Out{Out: 42})) - } -} diff --git a/internal/integration/cancel_test.go b/internal/integration/cancel_test.go index 03a38604..dec012da 100644 --- a/internal/integration/cancel_test.go +++ b/internal/integration/cancel_test.go @@ -15,7 +15,6 @@ import ( "google.golang.org/grpc/codes" "google.golang.org/grpc/status" - "storj.io/drpc" "storj.io/drpc/drpcconn" "storj.io/drpc/drpcpool" "storj.io/drpc/drpctest" @@ -150,15 +149,6 @@ func TestCancellationPropagation_Stream(t *testing.T) { clientctx.Run(func(ctx context.Context) { stream, _ := cli.Method4(ctx) - // this is a weird case where the rpc does not send or receive or even - // close the stream, and neither does the other side, and so we have to - // explicitly flush the invoke. - type ( - getStreamer interface{ GetStream() drpc.Stream } - rawFlusher interface{ RawFlush() error } - ) - _ = stream.(getStreamer).GetStream().(rawFlusher).RawFlush() - called <- struct{}{} select { case <-stream.Context().Done(): diff --git a/internal/integration/common_test.go b/internal/integration/common_test.go index 5acb61a1..1fd555ef 100644 --- a/internal/integration/common_test.go +++ b/internal/integration/common_test.go @@ -16,7 +16,6 @@ import ( "google.golang.org/grpc/status" "storj.io/drpc/drpcconn" - "storj.io/drpc/drpcmanager" "storj.io/drpc/drpcmetadata" "storj.io/drpc/drpcmux" "storj.io/drpc/drpcserver" @@ -45,11 +44,7 @@ func createRawConnection(t testing.TB, server DRPCServiceServer, ctx *drpctest.T assert.NoError(t, DRPCRegisterService(mux, server)) srv := drpcserver.New(mux) ctx.Run(func(ctx context.Context) { _ = srv.ServeOne(ctx, c1) }) - return drpcconn.NewWithOptions(c2, drpcconn.Options{ - Manager: drpcmanager.Options{ - SoftCancel: true, - }, - }) + return drpcconn.NewWithOptions(c2, drpcconn.Options{}) } func createConnection(t testing.TB, server DRPCServiceServer) (DRPCServiceClient, func()) { diff --git a/internal/integration/simple_test.go b/internal/integration/simple_test.go index 26a2e8c9..8dbf47a5 100644 --- a/internal/integration/simple_test.go +++ b/internal/integration/simple_test.go @@ -85,6 +85,131 @@ func TestSimple(t *testing.T) { } } +func TestMultiplexedStreams(t *testing.T) { + ctx := drpctest.NewTracker(t) + defer ctx.Close() + + // Echo server: sends back each received message immediately. + echoServer := impl{ + Method1Fn: standardImpl.Method1Fn, + Method2Fn: standardImpl.Method2Fn, + Method3Fn: standardImpl.Method3Fn, + Method4Fn: func(stream DRPCService_Method4Stream) error { + for { + msg, err := stream.Recv() + if err != nil { + return nil + } + if err := stream.Send(&Out{Out: msg.In}); err != nil { + return err + } + } + }, + } + + cli, close := createConnection(t, echoServer) + defer close() + + // Open two bidi streams on the same connection. + s1, err := cli.Method4(ctx) + assert.NoError(t, err) + + s2, err := cli.Method4(ctx) + assert.NoError(t, err) + + // Send on both streams interleaved. + assert.NoError(t, s1.Send(&In{In: 1})) + assert.NoError(t, s2.Send(&In{In: 2})) + + // Receive from both: each stream gets its own response. + out1, err := s1.Recv() + assert.NoError(t, err) + assert.Equal(t, out1.Out, int64(1)) + + out2, err := s2.Recv() + assert.NoError(t, err) + assert.Equal(t, out2.Out, int64(2)) + + // Close both streams. + assert.NoError(t, s1.CloseSend()) + assert.NoError(t, s2.CloseSend()) + + _, err = s1.Recv() + assert.That(t, errors.Is(err, io.EOF)) + + _, err = s2.Recv() + assert.That(t, errors.Is(err, io.EOF)) +} + +func TestConcurrentStreams(t *testing.T) { + ctx := drpctest.NewTracker(t) + defer ctx.Close() + + echoServer := impl{ + Method1Fn: standardImpl.Method1Fn, + Method2Fn: standardImpl.Method2Fn, + Method3Fn: standardImpl.Method3Fn, + Method4Fn: func(stream DRPCService_Method4Stream) error { + for { + msg, err := stream.Recv() + if err != nil { + return nil + } + if err := stream.Send(&Out{Out: msg.In}); err != nil { + return err + } + } + }, + } + + cli, close := createConnection(t, echoServer) + defer close() + + const numStreams = 10 + const numMessages = 20 + + errs := make(chan error, numStreams) + for i := 0; i < numStreams; i++ { + i := i + ctx.Run(func(ctx context.Context) { + select { + case <-ctx.Done(): + case errs <- func() error { + stream, err := cli.Method4(ctx) + if err != nil { + return fmt.Errorf("stream %d: open: %w", i, err) + } + for j := 0; j < numMessages; j++ { + val := int64(i*1000 + j) + if err := stream.Send(&In{In: val}); err != nil { + return fmt.Errorf("stream %d: send %d: %w", i, j, err) + } + out, err := stream.Recv() + if err != nil { + return fmt.Errorf("stream %d: recv %d: %w", i, j, err) + } + if out.Out != val { + return fmt.Errorf("stream %d: msg %d: got %d, want %d", i, j, out.Out, val) + } + } + if err := stream.CloseSend(); err != nil { + return fmt.Errorf("stream %d: close send: %w", i, err) + } + _, err = stream.Recv() + if !errors.Is(err, io.EOF) { + return fmt.Errorf("stream %d: final recv: got %v, want EOF", i, err) + } + return nil + }(): + } + }) + } + + for i := 0; i < numStreams; i++ { + assert.NoError(t, <-errs) + } +} + func TestConcurrent(t *testing.T) { ctx := drpctest.NewTracker(t) defer ctx.Close() @@ -164,52 +289,3 @@ func TestServerStats(t *testing.T) { "/service.Service/Method3": {Read: 2, Written: 6}, }) } - -func TestClientStats(t *testing.T) { - ctx := drpctest.NewTracker(t) - defer ctx.Close() - - c1, c2 := net.Pipe() - mux := drpcmux.New() - _ = DRPCRegisterService(mux, standardImpl) - - srv := drpcserver.New(mux) - ctx.Run(func(ctx context.Context) { _ = srv.ServeOne(ctx, c1) }) - - conn := drpcconn.NewWithOptions(c2, drpcconn.Options{ - CollectStats: true, - }) - defer func() { _ = conn.Close() }() - cli := NewDRPCServiceClient(conn) - - assert.Equal(t, srv.Stats(), map[string]drpcstats.Stats{}) - - _, err := cli.Method1(ctx, in(5)) - assert.Error(t, err) - - assert.Equal(t, conn.Stats(), map[string]drpcstats.Stats{ - "/service.Service/Method1": {Read: 9, Written: 26}, - }) - - _, err = cli.Method1(ctx, in(1)) - assert.NoError(t, err) - - assert.Equal(t, conn.Stats(), map[string]drpcstats.Stats{ - "/service.Service/Method1": {Read: 9 + 2, Written: 26 + 26}, - }) - - stream, err := cli.Method3(ctx, in(3)) - assert.NoError(t, err) - for i := 0; i < 3; i++ { - _, err := stream.Recv() - assert.NoError(t, err) - } - _, err = stream.Recv() - assert.That(t, errors.Is(err, io.EOF)) - assert.NoError(t, stream.Close()) - - assert.Equal(t, conn.Stats(), map[string]drpcstats.Stats{ - "/service.Service/Method1": {Read: 9 + 2, Written: 26 + 26}, - "/service.Service/Method3": {Read: 6, Written: 26}, - }) -}