diff --git a/drpcstream/pktbuf.go b/drpcstream/pktbuf.go deleted file mode 100644 index db688649..00000000 --- a/drpcstream/pktbuf.go +++ /dev/null @@ -1,85 +0,0 @@ -// Copyright (C) 2019 Storj Labs, Inc. -// See LICENSE for copying information. - -package drpcstream - -import ( - "sync" -) - -type packetBuffer struct { - mu sync.Mutex - cond sync.Cond - err error - data []byte - set bool - held bool -} - -func (pb *packetBuffer) init() { - pb.cond.L = &pb.mu -} - -func (pb *packetBuffer) Close(err error) { - pb.mu.Lock() - defer pb.mu.Unlock() - - for pb.held { - pb.cond.Wait() - } - - if pb.err == nil { - pb.data = nil - pb.set = false - pb.err = err - pb.cond.Broadcast() - } -} - -func (pb *packetBuffer) Put(data []byte) { - pb.mu.Lock() - defer pb.mu.Unlock() - - for pb.set && pb.err == nil { - pb.cond.Wait() - } - if pb.err != nil { - return - } - - pb.data = data - pb.set = true - pb.held = false - pb.cond.Broadcast() - - for pb.set || pb.held { - pb.cond.Wait() - } -} - -func (pb *packetBuffer) Get() ([]byte, error) { - pb.mu.Lock() - defer pb.mu.Unlock() - - for !pb.set && pb.err == nil { - pb.cond.Wait() - } - if pb.err != nil { - return nil, pb.err - } - - pb.held = true - pb.cond.Broadcast() - - return pb.data, nil -} - -func (pb *packetBuffer) Done() { - pb.mu.Lock() - defer pb.mu.Unlock() - - pb.data = nil - pb.set = false - pb.held = false - pb.cond.Broadcast() -} diff --git a/drpcstream/spsc_queue.go b/drpcstream/spsc_queue.go new file mode 100644 index 00000000..c8e6b173 --- /dev/null +++ b/drpcstream/spsc_queue.go @@ -0,0 +1,147 @@ +// Copyright (C) 2019 Storj Labs, Inc. +// See LICENSE for copying information. + +package drpcstream + +import ( + "sync" +) + +// defaultPacketBufferSize is the number of messages the packet buffer can +// hold before the producer blocks. This decouples the transport reader +// from the consumer (RPC handler), preventing deadlocks when the handler +// is delayed before calling Recv. +const defaultPacketBufferSize = 10 + +// spscQueue is a bounded single-producer / single-consumer queue for byte +// slices. It is implemented as a ring buffer with mutex+cond synchronization. +// +// The producer calls Enqueue to copy data into the next write slot. If the +// queue is full, Enqueue blocks until a slot is freed or the queue is closed. +// +// The consumer calls Dequeue to get a reference to the next read slot. The +// returned slice is valid until Done is called. Done advances the read +// pointer and recycles the slot for reuse by the producer. +// +// Close sets an error and wakes all blocked waiters. After Close, Enqueue +// is a no-op and Dequeue returns the close error. +// +// Slots are pre-allocated and reused. Each slot's backing array grows via +// append to fit incoming data, then stays at its high-water mark, avoiding +// per-message allocation in steady state. +type spscQueue struct { + mu sync.Mutex + cond sync.Cond + + slots [][]byte // ring buffer of byte slices + mask int // len(slots) - 1, for fast modulo (capacity is power of 2) + + head int // next write position (producer) + tail int // next read position (consumer) + len int // number of items in the queue + + held bool // true between Dequeue and Done + err error // set by Close +} + +// newSPSCQueue creates a new SPSC queue. The capacity is rounded up to the +// next power of 2 (minimum 2). +func newSPSCQueue(capacity int) *spscQueue { + cap := roundUpPow2(capacity) + q := &spscQueue{ + slots: make([][]byte, cap), + mask: cap - 1, + } + q.cond.L = &q.mu + return q +} + +// Enqueue copies data into the next write slot. If the queue is full, it +// blocks until a slot is freed or the queue is closed. If the queue is +// closed, Enqueue returns silently without enqueuing. +func (q *spscQueue) Enqueue(data []byte) { + q.mu.Lock() + defer q.mu.Unlock() + + for q.len > q.mask && q.err == nil { + q.cond.Wait() + } + if q.err != nil { + return + } + + // Copy data into the slot, reusing the existing backing array if + // it has enough capacity. This avoids allocation in steady state. + q.slots[q.head&q.mask] = append(q.slots[q.head&q.mask][:0], data...) + q.head++ + q.len++ + q.cond.Broadcast() +} + +// Dequeue returns the data from the next read slot. If the queue is empty, +// it blocks until data is available or the queue is closed. The returned +// slice is valid until Done is called. +func (q *spscQueue) Dequeue() ([]byte, error) { + q.mu.Lock() + defer q.mu.Unlock() + + for q.len == 0 && q.err == nil { + q.cond.Wait() + } + if q.len == 0 { + // Queue is empty and closed — return the close error. + return nil, q.err + } + // Return data even if closed, draining pending items first. + + data := q.slots[q.tail&q.mask] + q.held = true + return data, nil +} + +// Done advances the read pointer, making the slot available for reuse. +// It must be called exactly once after each successful Dequeue. +func (q *spscQueue) Done() { + q.mu.Lock() + defer q.mu.Unlock() + + q.tail++ + q.len-- + q.held = false + q.cond.Broadcast() +} + +// Close marks the queue as closed with the given error. All blocked +// Enqueue and Dequeue calls are woken and will return. If Close has +// already been called, subsequent calls are no-ops. +func (q *spscQueue) Close(err error) { + q.mu.Lock() + defer q.mu.Unlock() + + // Wait for any in-progress Dequeue/Done pair to complete so that + // we don't race with the consumer's use of the dequeued data. + for q.held { + q.cond.Wait() + } + + if q.err == nil { + q.err = err + q.cond.Broadcast() + } +} + +// roundUpPow2 rounds n up to the next power of 2. Minimum is 2. +func roundUpPow2(n int) int { + if n <= 2 { + return 2 + } + n-- + n |= n >> 1 + n |= n >> 2 + n |= n >> 4 + n |= n >> 8 + n |= n >> 16 + n |= n >> 32 + n++ + return n +} diff --git a/drpcstream/spsc_queue_test.go b/drpcstream/spsc_queue_test.go new file mode 100644 index 00000000..1b923654 --- /dev/null +++ b/drpcstream/spsc_queue_test.go @@ -0,0 +1,294 @@ +// Copyright (C) 2019 Storj Labs, Inc. +// See LICENSE for copying information. + +package drpcstream + +import ( + "errors" + "sync" + "testing" + "time" + + "github.com/zeebo/assert" +) + +func TestSPSCQueue_BasicEnqueueDequeue(t *testing.T) { + q := newSPSCQueue(4) + q.Enqueue([]byte("hello")) + data, err := q.Dequeue() + assert.NoError(t, err) + assert.DeepEqual(t, data, []byte("hello")) + q.Done() +} + +func TestSPSCQueue_FIFO(t *testing.T) { + q := newSPSCQueue(4) + q.Enqueue([]byte("first")) + q.Enqueue([]byte("second")) + q.Enqueue([]byte("third")) + + for _, want := range []string{"first", "second", "third"} { + data, err := q.Dequeue() + assert.NoError(t, err) + assert.DeepEqual(t, data, []byte(want)) + q.Done() + } +} + +func TestSPSCQueue_SlotReuse(t *testing.T) { + // Capacity 2: verify slots are recycled after Done. + q := newSPSCQueue(2) + + for i := 0; i < 10; i++ { + q.Enqueue([]byte("data")) + data, err := q.Dequeue() + assert.NoError(t, err) + assert.DeepEqual(t, data, []byte("data")) + q.Done() + } +} + +func TestSPSCQueue_BlockOnEmpty(t *testing.T) { + q := newSPSCQueue(4) + done := make(chan []byte, 1) + + go func() { + data, err := q.Dequeue() + assert.NoError(t, err) + done <- append([]byte(nil), data...) + q.Done() + }() + + // Consumer should be blocked. + select { + case <-done: + t.Fatal("Dequeue returned before Enqueue") + case <-time.After(50 * time.Millisecond): + } + + q.Enqueue([]byte("arrived")) + select { + case data := <-done: + assert.DeepEqual(t, data, []byte("arrived")) + case <-time.After(5 * time.Second): + t.Fatal("Dequeue did not unblock after Enqueue") + } +} + +func TestSPSCQueue_BlockOnFull(t *testing.T) { + q := newSPSCQueue(2) + q.Enqueue([]byte("a")) + q.Enqueue([]byte("b")) + + enqueued := make(chan struct{}) + go func() { + q.Enqueue([]byte("c")) // should block — queue is full + close(enqueued) + }() + + // Producer should be blocked. + select { + case <-enqueued: + t.Fatal("Enqueue returned on full queue") + case <-time.After(50 * time.Millisecond): + } + + // Consume one item to free a slot. + data, err := q.Dequeue() + assert.NoError(t, err) + assert.DeepEqual(t, data, []byte("a")) + q.Done() + + select { + case <-enqueued: + case <-time.After(5 * time.Second): + t.Fatal("Enqueue did not unblock after Done") + } +} + +func TestSPSCQueue_CloseUnblocksProducer(t *testing.T) { + q := newSPSCQueue(2) + q.Enqueue([]byte("a")) + q.Enqueue([]byte("b")) + + returned := make(chan struct{}) + go func() { + q.Enqueue([]byte("c")) // blocks — full + close(returned) + }() + + select { + case <-returned: + t.Fatal("Enqueue returned before Close") + case <-time.After(50 * time.Millisecond): + } + + q.Close(errors.New("closed")) + + select { + case <-returned: + case <-time.After(5 * time.Second): + t.Fatal("Enqueue did not unblock after Close") + } +} + +func TestSPSCQueue_CloseUnblocksConsumer(t *testing.T) { + q := newSPSCQueue(4) + + returned := make(chan error, 1) + go func() { + _, err := q.Dequeue() + returned <- err + }() + + select { + case <-returned: + t.Fatal("Dequeue returned before Close") + case <-time.After(50 * time.Millisecond): + } + + closeErr := errors.New("done") + q.Close(closeErr) + + select { + case err := <-returned: + assert.Equal(t, err, closeErr) + case <-time.After(5 * time.Second): + t.Fatal("Dequeue did not unblock after Close") + } +} + +func TestSPSCQueue_CloseDrainsPendingItems(t *testing.T) { + q := newSPSCQueue(4) + q.Enqueue([]byte("pending")) + + closeErr := errors.New("closed") + q.Close(closeErr) + + // Pending items are drained before the close error is returned. + data, err := q.Dequeue() + assert.NoError(t, err) + assert.DeepEqual(t, data, []byte("pending")) + q.Done() + + // Now the queue is empty and closed. + _, err = q.Dequeue() + assert.Equal(t, err, closeErr) +} + +func TestSPSCQueue_EnqueueAfterClose(t *testing.T) { + q := newSPSCQueue(4) + q.Close(errors.New("closed")) + + // Should be a no-op, not panic or block. + q.Enqueue([]byte("ignored")) +} + +func TestSPSCQueue_DoubleClose(t *testing.T) { + q := newSPSCQueue(4) + q.Close(errors.New("first")) + q.Close(errors.New("second")) // no-op + + _, err := q.Dequeue() + assert.Equal(t, err.Error(), "first") +} + +func TestSPSCQueue_CloseWaitsForHeld(t *testing.T) { + q := newSPSCQueue(4) + q.Enqueue([]byte("data")) + + data, err := q.Dequeue() + assert.NoError(t, err) + assert.DeepEqual(t, data, []byte("data")) + // data is held — Done not yet called. + + closed := make(chan struct{}) + go func() { + q.Close(errors.New("closed")) + close(closed) + }() + + // Close should block because data is held. + select { + case <-closed: + t.Fatal("Close returned while data is held") + case <-time.After(50 * time.Millisecond): + } + + q.Done() + + select { + case <-closed: + case <-time.After(5 * time.Second): + t.Fatal("Close did not return after Done") + } +} + +func TestSPSCQueue_ConcurrentStress(t *testing.T) { + const numMessages = 10000 + q := newSPSCQueue(8) + + var wg sync.WaitGroup + wg.Add(2) + + // Producer. + go func() { + defer wg.Done() + for i := 0; i < numMessages; i++ { + q.Enqueue([]byte{byte(i), byte(i >> 8)}) + } + }() + + // Consumer: dequeue exactly numMessages items, then signal done. + received := 0 + go func() { + defer wg.Done() + for i := 0; i < numMessages; i++ { + _, err := q.Dequeue() + assert.NoError(t, err) + received++ + q.Done() + } + }() + + wg.Wait() + assert.Equal(t, received, numMessages) + q.Close(errors.New("done")) +} + +func TestSPSCQueue_DataIsolation(t *testing.T) { + // Verify that Enqueue copies data — modifying the source after + // Enqueue must not affect the queued data. + q := newSPSCQueue(4) + src := []byte("original") + q.Enqueue(src) + src[0] = 'X' // mutate source + + data, err := q.Dequeue() + assert.NoError(t, err) + assert.DeepEqual(t, data, []byte("original")) + q.Done() +} + +func TestRoundUpPow2(t *testing.T) { + tests := []struct { + in, want int + }{ + {0, 2}, + {1, 2}, + {2, 2}, + {3, 4}, + {4, 4}, + {5, 8}, + {7, 8}, + {8, 8}, + {9, 16}, + {10, 16}, + {16, 16}, + {17, 32}, + } + for _, tt := range tests { + got := roundUpPow2(tt.in) + assert.Equal(t, got, tt.want) + } +} diff --git a/drpcstream/stream.go b/drpcstream/stream.go index 29ccd636..3319478a 100644 --- a/drpcstream/stream.go +++ b/drpcstream/stream.go @@ -54,7 +54,7 @@ type Stream struct { id drpcwire.ID wr *drpcwire.Writer - pbuf packetBuffer + pbuf *spscQueue wbuf []byte mu sync.Mutex // protects state transitions @@ -103,7 +103,7 @@ func NewWithOptions(ctx context.Context, sid uint64, wr *drpcwire.Writer, opts O } // initialize the packet buffer - s.pbuf.init() + s.pbuf = newSPSCQueue(defaultPacketBufferSize) return s } @@ -232,7 +232,7 @@ func (s *Stream) HandlePacket(pkt drpcwire.Packet) (err error) { s.log("HANDLE", pkt.String) if pkt.Kind == drpcwire.KindMessage { - s.pbuf.Put(pkt.Data) + s.pbuf.Enqueue(pkt.Data) return nil } @@ -457,7 +457,7 @@ func (s *Stream) RawRecv() (data []byte, err error) { s.read.Lock() defer s.read.Unlock() - data, err = s.pbuf.Get() + data, err = s.pbuf.Dequeue() if err != nil { return nil, err } @@ -509,7 +509,7 @@ func (s *Stream) MsgRecv(msg drpc.Message, enc drpc.Encoding) (err error) { s.read.Lock() defer s.read.Unlock() - data, err := s.pbuf.Get() + data, err := s.pbuf.Dequeue() if err != nil { return err } diff --git a/internal/integration/transport_test.go b/internal/integration/transport_test.go index 03d68fa6..86e1dc3a 100644 --- a/internal/integration/transport_test.go +++ b/internal/integration/transport_test.go @@ -7,13 +7,18 @@ import ( "context" "errors" "io" + "net" "testing" + "time" "github.com/zeebo/assert" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" "storj.io/drpc/drpcconn" + "storj.io/drpc/drpcmanager" + "storj.io/drpc/drpcmux" + "storj.io/drpc/drpcserver" "storj.io/drpc/drpctest" ) @@ -137,3 +142,111 @@ func TestTransport_ErrorCausesCancel(t *testing.T) { assert.That(t, isExpectedError) } } + +// TestTransport_ClosedWhileHandlerBlockedBeforeRecv reproduces a deadlock +// where the server handler is doing work before calling Recv() while +// manageReader has already read a message from the transport and is blocked +// in packetBuffer.Put(). When the client closes the transport, manageReader +// cannot detect the closure because it is stuck in Put(), so the server +// stream's context is never canceled. +// +// This reproduces the issue seen in CockroachDB's TestReceiveSnapshotLogging +// "cancel during receive" subtest, where the snapshot receiver handler is +// blocked in BeforeRecvAcceptedSnapshot (a test knob) before calling +// MsgRecv(), and the delegate's cancellation closes the transport but the +// server never detects it. +func TestTransport_ClosedWhileHandlerBlockedBeforeRecv(t *testing.T) { + ctx := drpctest.NewTracker(t) + defer ctx.Close() + + // Test knobs — channels for synchronization, mirroring CockroachDB's + // receiveStartedCh and svrContextDone pattern. + handlerStarted := make(chan struct{}) + svrCtxDone := make(chan struct{}) + + // Set up server and client manually to use SoftCancel: false on + // the client, matching CockroachDB's configuration. + c1, c2 := net.Pipe() + defer func() { _ = c1.Close() }() + defer func() { _ = c2.Close() }() + + mux := drpcmux.New() + assert.NoError(t, DRPCRegisterService(mux, impl{ + Method2Fn: func(stream DRPCService_Method2Stream) error { + // The handler has started but has work to do before + // reading messages. In CockroachDB, this corresponds to + // the snapshot receiver sending the ACCEPTED response + // and hitting the BeforeRecvAcceptedSnapshot test knob + // before calling MsgRecv(). + close(handlerStarted) + + // Block until the stream context is canceled. With the + // deadlock bug, this never fires because manageReader + // is stuck in packetBuffer.Put() and cannot detect the + // transport closure. + select { + case <-stream.Context().Done(): + close(svrCtxDone) + return stream.Context().Err() + } + }, + })) + srv := drpcserver.New(mux) + ctx.Run(func(ctx context.Context) { _ = srv.ServeOne(ctx, c1) }) + + // Client connection with SoftCancel: false. When the client + // context is canceled, manageStream calls stream.Cancel() and + // then m.terminate() which closes the transport — the same + // code path as CockroachDB's delegate cancellation. + conn := drpcconn.NewWithOptions(c2, drpcconn.Options{ + Manager: drpcmanager.Options{SoftCancel: false}, + }) + defer func() { _ = conn.Close() }() + + // Create a cancelable context for the client RPC, simulating + // the delegate's context that gets canceled when the test + // calls cancel(). + rpcCtx, cancel := context.WithCancel(ctx) + defer cancel() + + cli := NewDRPCServiceClient(conn) + + // Start a client-streaming RPC. NewStream buffers the invoke + // packet but does not flush it — the flush happens on the first + // Send() call. + stream, err := cli.Method2(rpcCtx) + assert.NoError(t, err) + + // Send a message. This flushes both the invoke and the message + // in a single write. The server's manageReader reads the invoke + // (which triggers NewServerStream → handleRPC → handler start) + // and then the KindMessage (which enters packetBuffer.Put() and + // blocks because the handler hasn't called Recv() yet). + assert.NoError(t, stream.Send(in(1))) + + // Wait for the handler to start. + <-handlerStarted + + // Allow manageReader time to enter packetBuffer.Put() after + // delivering the invoke packet. + time.Sleep(100 * time.Millisecond) + + // Cancel the client RPC context. This triggers: + // manageStream detects ctx.Done() + // → stream.Cancel(ctx.Err()) returns false (not finished) + // → m.terminate(ctx.Err()) + // → m.tr.Close() closes the transport + // This is the same code path as CockroachDB's delegate + // cancellation closing the TCP connection to the receiver. + cancel() + + // The server handler's stream context should be canceled. + select { + case <-svrCtxDone: + // Transport closure propagated to the handler. + case <-time.After(5 * time.Second): + t.Fatal("deadlock: server handler's stream context was not " + + "canceled after client transport closed; manageReader is " + + "stuck in packetBuffer.Put()") + } +}